/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.checkpoint;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.OperatorIDPair;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.MappingBasedRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorState;
import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateAssignment;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.api.writer.SubtaskStateMapper;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.OperatorInstanceID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Internal
/**
 * Assigns the operator state of a restored checkpoint to the subtasks of a job's {@link
 * ExecutionJobVertex execution job vertices}.
 *
 * <p>{@link #assignStates()} works in three phases:
 *
 * <ol>
 *   <li>Every vertex gets a {@link TaskStateAssignment} holding the old state of its chained
 *       operators.
 *   <li>Each vertex with state has its managed/raw operator state, channel state and keyed state
 *       redistributed to its current parallelism.
 *   <li>The resulting per-subtask state is set as the initial state of each subtask's current
 *       execution attempt.
 * </ol>
 */
public class StateAssignmentOperation {
    private static final Logger LOG = LoggerFactory.getLogger(StateAssignmentOperation.class);

    /** The execution job vertices whose subtasks receive the restored state. */
    private final Set<ExecutionJobVertex> tasks;

    /** The operator states of the restored checkpoint, keyed by operator id. */
    private final Map<OperatorID, OperatorState> operatorStates;

    /** Id of the restored checkpoint; attached to every {@link JobManagerTaskRestore}. */
    private final long restoreCheckpointId;

    /** Whether state of operators that are not part of {@link #tasks} may be skipped. */
    private final boolean allowNonRestoredState;

    /** One state assignment per execution job vertex, filled by {@link #assignStates()}. */
    private final Map<ExecutionJobVertex, TaskStateAssignment> vertexAssignments;

    /**
     * Maps the id of every intermediate data set consumed by a vertex in {@link #tasks} to the
     * state assignment of that consuming vertex.
     */
    private final Map<IntermediateDataSetID, TaskStateAssignment> consumerAssignment =
            new HashMap<>();

    /**
     * Creates a state assignment for the given checkpoint.
     *
     * @param restoreCheckpointId id of the checkpoint being restored
     * @param tasks the execution job vertices that receive the state
     * @param operatorStates the operator states of the checkpoint, keyed by operator id
     * @param allowNonRestoredState whether state of operators that are no longer part of the job
     *     may be skipped instead of failing the restore
     * @throws NullPointerException if {@code tasks} or {@code operatorStates} is null
     */
    public StateAssignmentOperation(
            long restoreCheckpointId,
            Set<ExecutionJobVertex> tasks,
            Map<OperatorID, OperatorState> operatorStates,
            boolean allowNonRestoredState) {
        this.restoreCheckpointId = restoreCheckpointId;
        this.tasks = Preconditions.checkNotNull(tasks);
        this.operatorStates = Preconditions.checkNotNull(operatorStates);
        this.allowNonRestoredState = allowNonRestoredState;
        this.vertexAssignments = new HashMap<>(tasks.size());
    }

    /**
     * Builds a {@link TaskStateAssignment} for every task and redistributes the old state of all
     * tasks that have state. It then sets the initial state of their current execution attempts.
     *
     * @throws IllegalStateException if the checkpoint contains state of an operator that is not
     *     part of the job while {@code allowNonRestoredState} is false, or if the restored state
     *     is incompatible with the current parallelism/topology (see the checks in this class)
     */
    public void assignStates() {
        Map<OperatorID, OperatorState> localOperators = new HashMap<>(this.operatorStates);

        checkStateMappingCompleteness(this.allowNonRestoredState, this.operatorStates, this.tasks);

        for (ExecutionJobVertex executionJobVertex : this.tasks) {
            List<OperatorIDPair> operatorIDPairs = executionJobVertex.getOperatorIDs();
            Map<OperatorID, OperatorState> operatorStates = new HashMap<>(operatorIDPairs.size());
            for (OperatorIDPair operatorIDPair : operatorIDPairs) {
                // Restored state is looked up by the user-defined id when one is set...
                OperatorID operatorID =
                        operatorIDPair
                                .getUserDefinedOperatorID()
                                .orElse(operatorIDPair.getGeneratedOperatorID());
                OperatorState operatorState = localOperators.remove(operatorID);
                if (operatorState == null) {
                    // No restored state: use an empty placeholder at the current parallelism.
                    operatorState =
                            new OperatorState(
                                    operatorID,
                                    executionJobVertex.getParallelism(),
                                    executionJobVertex.getMaxParallelism());
                }
                // ...but is always keyed by the generated id within the assignment.
                operatorStates.put(operatorIDPair.getGeneratedOperatorID(), operatorState);
            }

            TaskStateAssignment stateAssignment =
                    new TaskStateAssignment(executionJobVertex, operatorStates);
            this.vertexAssignments.put(executionJobVertex, stateAssignment);
            for (IntermediateResult consumedDataSet : executionJobVertex.getInputs()) {
                this.consumerAssignment.put(consumedDataSet.getId(), stateAssignment);
            }
        }

        // All assignments must exist before redistributing: channel-state redistribution looks
        // up the assignments of upstream producers and downstream consumers.
        for (TaskStateAssignment stateAssignment : this.vertexAssignments.values()) {
            if (stateAssignment.hasState) {
                assignAttemptState(stateAssignment);
            }
        }

        // Only hand out state once every assignment has been redistributed, because an
        // assignment may have been linked to its neighbours during the previous loop.
        for (TaskStateAssignment stateAssignment : this.vertexAssignments.values()) {
            if (stateAssignment.hasState) {
                assignTaskStateToExecutionJobVertices(stateAssignment);
            }
        }
    }

    /**
     * Redistributes operator, channel and keyed state of one task to its new parallelism.
     *
     * @param taskStateAssignment the assignment whose old state is redistributed in place
     */
    private void assignAttemptState(TaskStateAssignment taskStateAssignment) {
        checkParallelismPreconditions(taskStateAssignment);

        List<KeyGroupRange> keyGroupPartitions =
                createKeyGroupPartitions(
                        taskStateAssignment.executionJobVertex.getMaxParallelism(),
                        taskStateAssignment.newParallelism);

        reDistributePartitionableStates(
                taskStateAssignment.oldState,
                taskStateAssignment.newParallelism,
                OperatorSubtaskState::getManagedOperatorState,
                RoundRobinOperatorStateRepartitioner.INSTANCE,
                taskStateAssignment.subManagedOperatorState);
        reDistributePartitionableStates(
                taskStateAssignment.oldState,
                taskStateAssignment.newParallelism,
                OperatorSubtaskState::getRawOperatorState,
                RoundRobinOperatorStateRepartitioner.INSTANCE,
                taskStateAssignment.subRawOperatorState);

        reDistributeInputChannelStates(taskStateAssignment);
        reDistributeResultSubpartitionStates(taskStateAssignment);

        reDistributeKeyedStates(keyGroupPartitions, taskStateAssignment);
    }

    /**
     * Sets the redistributed state of {@code assignment} as the initial state of the current
     * execution attempt of every subtask of its vertex. Subtasks for which no chained operator
     * has state are left untouched.
     */
    private void assignTaskStateToExecutionJobVertices(TaskStateAssignment assignment) {
        ExecutionJobVertex executionJobVertex = assignment.executionJobVertex;
        List<OperatorIDPair> operatorIDs = executionJobVertex.getOperatorIDs();
        int newParallelism = executionJobVertex.getParallelism();

        for (int subTaskIndex = 0; subTaskIndex < newParallelism; ++subTaskIndex) {
            Execution currentExecutionAttempt =
                    executionJobVertex.getTaskVertices()[subTaskIndex].getCurrentExecutionAttempt();

            TaskStateSnapshot taskState = new TaskStateSnapshot(operatorIDs.size());
            boolean statelessTask = true;
            for (OperatorIDPair operatorID : operatorIDs) {
                OperatorInstanceID instanceID =
                        OperatorInstanceID.of(subTaskIndex, operatorID.getGeneratedOperatorID());
                OperatorSubtaskState operatorSubtaskState =
                        operatorSubtaskStateFrom(instanceID, assignment);
                if (operatorSubtaskState.hasState()) {
                    statelessTask = false;
                }
                taskState.putSubtaskStateByOperatorID(
                        operatorID.getGeneratedOperatorID(), operatorSubtaskState);
            }

            if (statelessTask) {
                continue;
            }
            currentExecutionAttempt.setInitialState(
                    new JobManagerTaskRestore(this.restoreCheckpointId, taskState));
        }
    }

    /** Returns the state that {@code assignment} holds for the given operator instance. */
    public static OperatorSubtaskState operatorSubtaskStateFrom(
            OperatorInstanceID instanceID, TaskStateAssignment assignment) {
        return assignment.getSubtaskState(instanceID);
    }

    /**
     * Verifies the max parallelism of every old operator state of the task against its vertex.
     * This may override the vertex's max parallelism; see {@link
     * #checkParallelismPreconditions(OperatorState, ExecutionJobVertex)}.
     */
    public void checkParallelismPreconditions(TaskStateAssignment taskStateAssignment) {
        for (OperatorState operatorState : taskStateAssignment.oldState.values()) {
            checkParallelismPreconditions(operatorState, taskStateAssignment.executionJobVertex);
        }
    }

    /**
     * Computes managed and raw keyed state for every (operator, new subtask) pair of the task.
     *
     * @param keyGroupPartitions the key-group range of each new subtask, indexed by subtask
     * @param stateAssignment the assignment whose keyed-state maps are filled
     */
    private void reDistributeKeyedStates(
            List<KeyGroupRange> keyGroupPartitions, TaskStateAssignment stateAssignment) {
        stateAssignment.oldState.forEach(
                (operatorID, operatorState) -> {
                    for (int subTaskIndex = 0;
                            subTaskIndex < stateAssignment.newParallelism;
                            ++subTaskIndex) {
                        OperatorInstanceID instanceID =
                                OperatorInstanceID.of(subTaskIndex, operatorID);
                        Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> subKeyedStates =
                                reAssignSubKeyedStates(
                                        operatorState,
                                        keyGroupPartitions,
                                        subTaskIndex,
                                        stateAssignment.newParallelism,
                                        operatorState.getParallelism());
                        stateAssignment.subManagedKeyedState.put(instanceID, subKeyedStates.f0);
                        stateAssignment.subRawKeyedState.put(instanceID, subKeyedStates.f1);
                    }
                });
    }

    /**
     * Returns the managed (f0) and raw (f1) keyed state handles for one new subtask.
     *
     * <p>If the parallelism is unchanged, the subtask keeps the handles of the old subtask with
     * the same index. Otherwise it receives the intersection of every old handle with its key
     * group range. If both results are empty, two empty lists are returned.
     */
    private Tuple2<List<KeyedStateHandle>, List<KeyedStateHandle>> reAssignSubKeyedStates(
            OperatorState operatorState,
            List<KeyGroupRange> keyGroupPartitions,
            int subTaskIndex,
            int newParallelism,
            int oldParallelism) {
        List<KeyedStateHandle> subManagedKeyedState;
        List<KeyedStateHandle> subRawKeyedState;

        if (newParallelism == oldParallelism) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            if (subtaskState != null) {
                subManagedKeyedState = subtaskState.getManagedKeyedState().asList();
                subRawKeyedState = subtaskState.getRawKeyedState().asList();
            } else {
                subManagedKeyedState = Collections.emptyList();
                subRawKeyedState = Collections.emptyList();
            }
        } else {
            KeyGroupRange subtaskKeyGroupRange = keyGroupPartitions.get(subTaskIndex);
            // The public helpers return null when no old subtask has state; treat that as empty
            // instead of failing with a NullPointerException below.
            subManagedKeyedState =
                    emptyIfNull(getManagedKeyedStateHandles(operatorState, subtaskKeyGroupRange));
            subRawKeyedState =
                    emptyIfNull(getRawKeyedStateHandles(operatorState, subtaskKeyGroupRange));
        }

        if (subManagedKeyedState.isEmpty() && subRawKeyedState.isEmpty()) {
            subManagedKeyedState = Collections.emptyList();
            subRawKeyedState = Collections.emptyList();
        }
        return new Tuple2<>(subManagedKeyedState, subRawKeyedState);
    }

    /** Returns {@code handles}, or an empty list if it is null. */
    private static List<KeyedStateHandle> emptyIfNull(@Nullable List<KeyedStateHandle> handles) {
        return handles != null ? handles : Collections.emptyList();
    }

    /**
     * Repartitions one kind of operator state (selected by {@code extractHandle}) of every
     * operator to {@code newParallelism} and puts the per-instance result into {@code result}.
     */
    public static <T extends StateObject> void reDistributePartitionableStates(
            Map<OperatorID, OperatorState> oldOperatorStates,
            int newParallelism,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
            OperatorStateRepartitioner<T> stateRepartitioner,
            Map<OperatorInstanceID, List<T>> result) {
        Map<OperatorID, List<List<T>>> oldStates =
                splitManagedAndRawOperatorStates(oldOperatorStates, extractHandle);

        oldOperatorStates.forEach(
                (operatorID, oldOperatorState) ->
                        result.putAll(
                                applyRepartitioner(
                                        operatorID,
                                        stateRepartitioner,
                                        oldStates.get(operatorID),
                                        oldOperatorState.getParallelism(),
                                        newParallelism)));
    }

    /**
     * Redistributes the result-subpartition (output channel) state of the first operator of the
     * task's chain.
     *
     * <p>If the number of old subtasks equals the current parallelism, each subtask keeps its
     * own state. Otherwise the state of each produced partition is remapped using the consumer
     * edge's upstream {@link SubtaskStateMapper}, and the task is registered as an upstream
     * assignment of the consuming task.
     *
     * @throws IllegalStateException if other chained operators carry channel state, the
     *     consumed result is not an input of its consumer, or subtask mappings of different
     *     partitions disagree
     * @throws NullPointerException if the consuming edge has no upstream subtask state mapper
     */
    // Raw lists: the element type is the handle type returned by getResultSubpartitionState,
    // which this file does not import; a List<List<T>> declaration cannot type-check here.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <I, T extends AbstractChannelStateHandle<I>> void reDistributeResultSubpartitionStates(
            TaskStateAssignment assignment) {
        ExecutionJobVertex executionJobVertex = assignment.executionJobVertex;
        OperatorID outputOperatorID =
                executionJobVertex.getOperatorIDs().get(0).getGeneratedOperatorID();

        List outputOperatorState =
                getChannelState(
                        assignment.oldState,
                        OperatorSubtaskState::getResultSubpartitionState,
                        outputOperatorID);
        if (outputOperatorState == null) {
            return;
        }

        List<IntermediateDataSet> outputs =
                executionJobVertex.getJobVertex().getProducedDataSets();
        if (outputOperatorState.size() == executionJobVertex.getParallelism()) {
            assignment.resultSubpartitionStates.putAll(
                    toInstanceMap(outputOperatorID, outputOperatorState));
            return;
        }

        // NOTE(review): assumes the job vertex's produced data sets and the execution vertex's
        // produced results share the same order and length.
        for (int partitionIndex = 0; partitionIndex < outputs.size(); ++partitionIndex) {
            IntermediateResult output = executionJobVertex.getProducedDataSets()[partitionIndex];
            TaskStateAssignment downstreamAssignment = this.consumerAssignment.get(output.getId());

            int gateIndex = downstreamAssignment.executionJobVertex.getInputs().indexOf(output);
            Preconditions.checkState(
                    gateIndex >= 0, "Gate index not found for IntermediateResult");
            downstreamAssignment.upstreamAssignments.put(gateIndex, assignment);

            SubtaskStateMapper mapper =
                    Preconditions.checkNotNull(
                            downstreamAssignment
                                    .executionJobVertex
                                    .getJobVertex()
                                    .getInputs()
                                    .get(gateIndex)
                                    .getUpstreamSubtaskStateMapper(),
                            "No channel rescaler found during rescaling of channel state");
            Map<Integer, Set<Integer>> mapping =
                    mapper.getNewToOldSubtasksMapping(
                            outputOperatorState.size(), executionJobVertex.getParallelism());
            assignment.outputSubtaskMappings =
                    checkSubtaskMapping(assignment.outputSubtaskMappings, mapping);

            List partitionState =
                    outputs.size() == 1
                            ? outputOperatorState
                            : getPartitionState(
                                    outputOperatorState,
                                    ResultSubpartitionInfo::getPartitionIdx,
                                    partitionIndex);
            MappingBasedRepartitioner repartitioner = new MappingBasedRepartitioner(mapping);
            Map repartitioned =
                    applyRepartitioner(
                            outputOperatorID,
                            repartitioner,
                            partitionState,
                            outputOperatorState.size(),
                            executionJobVertex.getParallelism());
            addToSubtasks(assignment.resultSubpartitionStates, repartitioned);
        }
    }

    /**
     * Returns {@code mapping} if no mapping was recorded yet (i.e. {@code oldMapping} is empty).
     * Otherwise returns {@code oldMapping} after checking that both mappings are equal.
     *
     * @throws IllegalStateException if a different mapping was already recorded
     */
    private Map<Integer, Set<Integer>> checkSubtaskMapping(
            Map<Integer, Set<Integer>> oldMapping, Map<Integer, Set<Integer>> mapping) {
        if (oldMapping.isEmpty()) {
            return mapping;
        }
        if (!oldMapping.equals(mapping)) {
            throw new IllegalStateException(
                    "Incompatible subtask mappings: are multiple operators ingesting/producing"
                            + " intermediate results with varying degrees of parallelism? Found "
                            + oldMapping
                            + " and "
                            + mapping
                            + ".");
        }
        return oldMapping;
    }

    /**
     * Redistributes the input channel state of the task's input operator.
     *
     * <p>If the number of old subtasks equals the new parallelism, each subtask keeps its own
     * state. Otherwise the state of each input gate is remapped using the gate's downstream
     * {@link SubtaskStateMapper}, and the task is registered as a downstream assignment of the
     * producing task.
     *
     * @throws IllegalStateException if other chained operators carry channel state, the input is
     *     not produced by its producer, or subtask mappings of different gates disagree
     * @throws NullPointerException if an input edge has no downstream subtask state mapper
     */
    // Raw lists: the element type is the handle type returned by getInputChannelState, which
    // this file does not import.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public void reDistributeInputChannelStates(TaskStateAssignment stateAssignment) {
        List inputOperatorState =
                getChannelState(
                        stateAssignment.oldState,
                        OperatorSubtaskState::getInputChannelState,
                        stateAssignment.inputOperatorID);
        if (inputOperatorState == null) {
            return;
        }

        ExecutionJobVertex executionJobVertex = stateAssignment.executionJobVertex;
        List<IntermediateResult> inputs = executionJobVertex.getInputs();
        if (inputOperatorState.size() == executionJobVertex.getParallelism()) {
            stateAssignment.inputChannelStates.putAll(
                    toInstanceMap(stateAssignment.inputOperatorID, inputOperatorState));
            return;
        }

        for (int gateIndex = 0; gateIndex < inputs.size(); ++gateIndex) {
            SubtaskStateMapper mapper =
                    Preconditions.checkNotNull(
                            executionJobVertex
                                    .getJobVertex()
                                    .getInputs()
                                    .get(gateIndex)
                                    .getDownstreamSubtaskStateMapper(),
                            "No channel rescaler found during rescaling of channel state");
            Map<Integer, Set<Integer>> mapping =
                    mapper.getNewToOldSubtasksMapping(
                            inputOperatorState.size(), stateAssignment.newParallelism);
            stateAssignment.inputSubtaskMappings =
                    checkSubtaskMapping(stateAssignment.inputSubtaskMappings, mapping);

            List gateState =
                    inputs.size() == 1
                            ? inputOperatorState
                            : getPartitionState(
                                    inputOperatorState, InputChannelInfo::getGateIdx, gateIndex);
            MappingBasedRepartitioner repartitioner = new MappingBasedRepartitioner(mapping);
            Map repartitioned =
                    applyRepartitioner(
                            stateAssignment.inputOperatorID,
                            repartitioner,
                            gateState,
                            inputOperatorState.size(),
                            stateAssignment.newParallelism);
            addToSubtasks(stateAssignment.inputChannelStates, repartitioned);

            IntermediateResult input = inputs.get(gateIndex);
            TaskStateAssignment upstreamAssignment =
                    this.vertexAssignments.get(input.getProducer());
            int partitionIndex =
                    Arrays.asList(upstreamAssignment.executionJobVertex.getProducedDataSets())
                            .indexOf(input);
            Preconditions.checkState(
                    partitionIndex >= 0, "Partition index not found for IntermediateResult");
            upstreamAssignment.downstreamAssignments.put(partitionIndex, stateAssignment);
        }
    }

    /** Appends every list in {@code toAdd} to the list stored under the same key in target. */
    private static <K, V> void addToSubtasks(Map<K, List<V>> target, Map<K, List<V>> toAdd) {
        toAdd.forEach(
                (key, values) ->
                        target.computeIfAbsent(key, unused -> new ArrayList<>(values.size()))
                                .addAll(values));
    }

    /**
     * Returns the channel state of {@code statefulOperatorID}, split by old subtask.
     *
     * @return the per-subtask channel state, or {@code null} if no handle of that operator has
     *     non-empty offsets
     * @throws IllegalStateException if any other operator in {@code oldOperatorStates} has a
     *     channel state handle with non-empty offsets
     */
    @Nullable
    private <T extends AbstractChannelStateHandle<?>> List<List<T>> getChannelState(
            Map<OperatorID, OperatorState> oldOperatorStates,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle,
            OperatorID statefulOperatorID) {
        List<OperatorID> unexpectedState =
                oldOperatorStates.entrySet().stream()
                        .filter(entry -> !entry.getKey().equals(statefulOperatorID))
                        .filter(entry -> hasChannelState(entry.getValue(), extractHandle))
                        .map(Map.Entry::getKey)
                        .collect(Collectors.toList());
        if (!unexpectedState.isEmpty()) {
            throw new IllegalStateException(
                    "Cannot recover from unaligned checkpoint when topology changes, such that"
                            + " data exchanges with persisted data are now chained.\n"
                            + "The following operators contain channel state: "
                            + unexpectedState);
        }

        OperatorState statefulOperatorState = oldOperatorStates.get(statefulOperatorID);
        if (!hasChannelState(statefulOperatorState, extractHandle)) {
            return null;
        }
        return splitBySubtasks(statefulOperatorState, extractHandle);
    }

    /** Returns whether any subtask of the operator has a channel handle with non-empty offsets. */
    private static <T extends AbstractChannelStateHandle<?>> boolean hasChannelState(
            OperatorState operatorState,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        return operatorState.getSubtaskStates().values().stream()
                .anyMatch(
                        subtaskState ->
                                extractHandle.apply(subtaskState).stream()
                                        .anyMatch(handle -> !handle.getOffsets().isEmpty()));
    }

    /**
     * Keeps, per subtask, only the handles whose info maps to {@code partitionId} via {@code
     * partitionExtractor}.
     */
    private static <T extends AbstractChannelStateHandle<I>, I> List<List<T>> getPartitionState(
            List<List<T>> subtaskStates, Function<I, Integer> partitionExtractor, int partitionId) {
        return subtaskStates.stream()
                .map(
                        subtaskState ->
                                subtaskState.stream()
                                        .filter(
                                                state ->
                                                        partitionExtractor.apply(state.getInfo())
                                                                == partitionId)
                                        .collect(Collectors.toList()))
                .collect(Collectors.toList());
    }

    /** Splits the state selected by {@code extractHandle} of every operator by old subtask. */
    private static <T extends StateObject>
            Map<OperatorID, List<List<T>>> splitManagedAndRawOperatorStates(
                    Map<OperatorID, OperatorState> operatorStates,
                    Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        return operatorStates.entrySet().stream()
                .collect(
                        Collectors.toMap(
                                Map.Entry::getKey,
                                entry -> splitBySubtasks(entry.getValue(), extractHandle)));
    }

    /**
     * Returns one list of handles per old subtask index; subtasks without state get an empty
     * list, so the result never contains {@code null}.
     */
    private static <T extends StateObject> List<List<T>> splitBySubtasks(
            OperatorState operatorState,
            Function<OperatorSubtaskState, StateObjectCollection<T>> extractHandle) {
        List<List<T>> statePerSubtask = new ArrayList<>(operatorState.getParallelism());
        for (int subTaskIndex = 0; subTaskIndex < operatorState.getParallelism(); ++subTaskIndex) {
            OperatorSubtaskState subtaskState = operatorState.getState(subTaskIndex);
            statePerSubtask.add(
                    subtaskState == null
                            ? Collections.emptyList()
                            : extractHandle.apply(subtaskState).asList());
        }
        return statePerSubtask;
    }

    /**
     * Collects the intersection of every managed keyed state handle of the operator with the
     * given key-group range.
     *
     * @return the intersecting handles, or {@code null} if no old subtask has any state
     */
    @Nullable
    public static List<KeyedStateHandle> getManagedKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return collectIntersectingKeyedState(
                operatorState, subtaskKeyGroupRange, OperatorSubtaskState::getManagedKeyedState);
    }

    /**
     * Collects the intersection of every raw keyed state handle of the operator with the given
     * key-group range.
     *
     * @return the intersecting handles, or {@code null} if no old subtask has any state
     */
    @Nullable
    public static List<KeyedStateHandle> getRawKeyedStateHandles(
            OperatorState operatorState, KeyGroupRange subtaskKeyGroupRange) {
        return collectIntersectingKeyedState(
                operatorState, subtaskKeyGroupRange, OperatorSubtaskState::getRawKeyedState);
    }

    /**
     * Shared loop of the two keyed-state getters above. The result list is allocated lazily at
     * the first non-null subtask, so it stays {@code null} when no old subtask has state.
     */
    @Nullable
    private static List<KeyedStateHandle> collectIntersectingKeyedState(
            OperatorState operatorState,
            KeyGroupRange subtaskKeyGroupRange,
            Function<OperatorSubtaskState, StateObjectCollection<KeyedStateHandle>> extractHandle) {
        int parallelism = operatorState.getParallelism();
        List<KeyedStateHandle> collected = null;
        for (int i = 0; i < parallelism; ++i) {
            OperatorSubtaskState subtaskState = operatorState.getState(i);
            if (subtaskState == null) {
                continue;
            }
            StateObjectCollection<KeyedStateHandle> keyedStateHandles =
                    extractHandle.apply(subtaskState);
            if (collected == null) {
                collected = new ArrayList<>(parallelism * keyedStateHandles.size());
            }
            extractIntersectingState(keyedStateHandles, subtaskKeyGroupRange, collected);
        }
        return collected;
    }

    /**
     * Adds to {@code extractedStateCollector} the intersection of each non-null handle with
     * {@code rangeToExtract}, skipping handles whose intersection is {@code null}.
     */
    @VisibleForTesting
    public static void extractIntersectingState(
            Collection<? extends KeyedStateHandle> originalSubtaskStateHandles,
            KeyGroupRange rangeToExtract,
            List<KeyedStateHandle> extractedStateCollector) {
        for (KeyedStateHandle keyedStateHandle : originalSubtaskStateHandles) {
            if (keyedStateHandle == null) {
                continue;
            }
            KeyedStateHandle intersectedKeyedStateHandle =
                    keyedStateHandle.getIntersection(rangeToExtract);
            if (intersectedKeyedStateHandle != null) {
                extractedStateCollector.add(intersectedKeyedStateHandle);
            }
        }
    }

    /**
     * Computes the key-group range of each operator index for the given parallelism.
     *
     * @param numberKeyGroups the number of key groups (the max parallelism)
     * @param parallelism the number of subtasks
     * @return one range per subtask index
     * @throws IllegalArgumentException if {@code numberKeyGroups < parallelism}
     */
    public static List<KeyGroupRange> createKeyGroupPartitions(
            int numberKeyGroups, int parallelism) {
        Preconditions.checkArgument(numberKeyGroups >= parallelism);
        List<KeyGroupRange> result = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; ++i) {
            result.add(
                    KeyGroupRangeAssignment.computeKeyGroupRangeForOperatorIndex(
                            numberKeyGroups, parallelism, i));
        }
        return result;
    }

    /**
     * Checks that the restored max parallelism is compatible with the vertex. If it differs and
     * the vertex's max parallelism was not configured explicitly, the vertex adopts the restored
     * value.
     *
     * @throws IllegalStateException if the restored max parallelism is below the vertex's
     *     parallelism, or differs from an explicitly configured max parallelism
     */
    private static void checkParallelismPreconditions(
            OperatorState operatorState, ExecutionJobVertex executionJobVertex) {
        if (operatorState.getMaxParallelism() < executionJobVertex.getParallelism()) {
            throw new IllegalStateException(
                    "The state for task "
                            + executionJobVertex.getJobVertexId()
                            + " can not be restored. The maximum parallelism ("
                            + operatorState.getMaxParallelism()
                            + ") of the restored state is lower than the configured parallelism ("
                            + executionJobVertex.getParallelism()
                            + "). Please reduce the parallelism of the task to be lower or equal to the maximum parallelism.");
        }

        if (operatorState.getMaxParallelism() == executionJobVertex.getMaxParallelism()) {
            return;
        }
        if (executionJobVertex.isMaxParallelismConfigured()) {
            throw new IllegalStateException(
                    "The maximum parallelism ("
                            + operatorState.getMaxParallelism()
                            + ") with which the latest checkpoint of the execution job vertex "
                            + executionJobVertex
                            + " has been taken and the current maximum parallelism ("
                            + executionJobVertex.getMaxParallelism()
                            + ") changed. This is currently not supported.");
        }
        LOG.debug(
                "Overriding maximum parallelism for JobVertex {} from {} to {}",
                executionJobVertex.getJobVertexId(),
                executionJobVertex.getMaxParallelism(),
                operatorState.getMaxParallelism());
        executionJobVertex.setMaxParallelism(operatorState.getMaxParallelism());
    }

    /**
     * Verifies that every restored operator state belongs to an operator of the job, matching
     * either generated or user-defined operator ids.
     *
     * @throws IllegalStateException if an operator state has no matching operator and {@code
     *     allowNonRestoredState} is false (otherwise the state is skipped and logged)
     */
    private static void checkStateMappingCompleteness(
            boolean allowNonRestoredState,
            Map<OperatorID, OperatorState> operatorStates,
            Set<ExecutionJobVertex> tasks) {
        Set<OperatorID> allOperatorIDs = new HashSet<>();
        for (ExecutionJobVertex executionJobVertex : tasks) {
            for (OperatorIDPair operatorIDPair : executionJobVertex.getOperatorIDs()) {
                allOperatorIDs.add(operatorIDPair.getGeneratedOperatorID());
                operatorIDPair.getUserDefinedOperatorID().ifPresent(allOperatorIDs::add);
            }
        }

        for (Map.Entry<OperatorID, OperatorState> entry : operatorStates.entrySet()) {
            OperatorState operatorState = entry.getValue();
            if (allOperatorIDs.contains(entry.getKey())) {
                continue;
            }
            if (!allowNonRestoredState) {
                throw new IllegalStateException(
                        "There is no operator for the state " + operatorState.getOperatorID());
            }
            LOG.info("Skipped checkpoint state for operator {}.", operatorState.getOperatorID());
        }
    }

    /**
     * Repartitions the per-subtask states and keys the result by operator instance.
     *
     * @return one entry per new subtask; empty if {@code chainOpParallelStates} is null
     */
    public static <T> Map<OperatorInstanceID, List<T>> applyRepartitioner(
            OperatorID operatorID,
            OperatorStateRepartitioner<T> opStateRepartitioner,
            List<List<T>> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        List<List<T>> states =
                applyRepartitioner(
                        opStateRepartitioner, chainOpParallelStates, oldParallelism, newParallelism);
        return toInstanceMap(operatorID, states);
    }

    /**
     * Keys the {@code i}-th list of {@code states} by {@code OperatorInstanceID.of(i,
     * operatorID)}.
     *
     * @throws NullPointerException if any per-subtask list is null
     */
    private static <T> Map<OperatorInstanceID, List<T>> toInstanceMap(
            OperatorID operatorID, List<List<T>> states) {
        Map<OperatorInstanceID, List<T>> result = new HashMap<>(states.size());
        for (int subtaskIndex = 0; subtaskIndex < states.size(); ++subtaskIndex) {
            // Check the element itself; the previous check tested a boxed boolean, which is never
            // null, so null subtask states slipped through.
            List<T> subtaskState =
                    Preconditions.checkNotNull(
                            states.get(subtaskIndex), "states.get(subtaskIndex) is null");
            result.put(OperatorInstanceID.of(subtaskIndex, operatorID), subtaskState);
        }
        return result;
    }

    /**
     * Repartitions per-subtask states from {@code oldParallelism} to {@code newParallelism}.
     *
     * @return the repartitioned states, or an empty list if {@code chainOpParallelStates} is null
     */
    public static <T> List<List<T>> applyRepartitioner(
            OperatorStateRepartitioner<T> opStateRepartitioner,
            List<List<T>> chainOpParallelStates,
            int oldParallelism,
            int newParallelism) {
        if (chainOpParallelStates == null) {
            return Collections.emptyList();
        }
        return opStateRepartitioner.repartitionState(
                chainOpParallelStates, oldParallelism, newParallelism);
    }
}

