package org.apache.flink.runtime.scheduler;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory;
import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.shuffle.ShuffleMaster;
import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
import org.apache.flink.util.Preconditions;

/**
 * Utilities for deriving the network (shuffle) memory requirement of a {@link SlotSharingGroup}
 * from the {@link ExecutionJobVertex}es it contains, and for writing that requirement into the
 * group's {@link ResourceProfile}.
 */
public class SsgNetworkMemoryCalculationUtils {

    /**
     * Calculates the network memory required by the job vertices of the given slot sharing group
     * and stores it in the group's {@link ResourceProfile}.
     *
     * <p>The group is left untouched if its profile equals {@link ResourceProfile#UNKNOWN} or if
     * its network memory is already non-zero. Otherwise the profile is rebuilt, keeping CPU cores,
     * task heap / off-heap memory, managed memory and extended resources. Its network memory is
     * set to the sum of {@link ShuffleMaster#computeShuffleMemorySizeForTask} over every job vertex
     * in the group.
     *
     * @param slotSharingGroup the group whose resource profile may be enriched
     * @param function lookup from a job vertex id to its {@link ExecutionJobVertex}. It must
     *     resolve every vertex in the group. For non-dynamic graphs it must also resolve the
     *     consumer vertex of each produced data set.
     * @param shuffleMaster computes the shuffle memory needed by a single task
     */
    public static void enrichNetworkMemory(
            SlotSharingGroup slotSharingGroup,
            Function<JobVertexID, ExecutionJobVertex> function,
            ShuffleMaster<?> shuffleMaster) {

        ResourceProfile original = slotSharingGroup.getResourceProfile();

        // Only known profiles that do not yet carry any network memory are enriched.
        if (original.equals(ResourceProfile.UNKNOWN)
                || !original.getNetworkMemory().equals(MemorySize.ZERO)) {
            return;
        }

        MemorySize networkMemory = MemorySize.ZERO;
        for (JobVertexID jobVertexId : slotSharingGroup.getJobVertexIds()) {
            ExecutionJobVertex ejv = function.apply(jobVertexId);
            TaskInputsOutputsDescriptor descriptor = buildTaskInputsOutputsDescriptor(ejv, function);
            networkMemory =
                    networkMemory.add(shuffleMaster.computeShuffleMemorySizeForTask(descriptor));
        }

        ResourceProfile enriched =
                ResourceProfile.newBuilder()
                        .setCpuCores(original.getCpuCores())
                        .setTaskHeapMemory(original.getTaskHeapMemory())
                        .setTaskOffHeapMemory(original.getTaskOffHeapMemory())
                        .setManagedMemory(original.getManagedMemory())
                        .setNetworkMemory(networkMemory)
                        .setExtendedResources(original.getExtendedResources().values())
                        .build();
        slotSharingGroup.setResourceProfile(enriched);
    }

    /**
     * Builds the descriptor of a single task's inputs and outputs that is handed to the shuffle
     * master.
     *
     * <p>For dynamic graphs the channel and subpartition counts are read from the execution
     * vertices and result partitions. Otherwise they are derived from parallelism and the edges'
     * distribution patterns.
     */
    private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
            ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvLookup) {
        final Map<IntermediateDataSetID, Integer> maxInputChannelNums;
        final Map<IntermediateDataSetID, Integer> maxSubpartitionNums;
        if (ejv.getGraph().isDynamic()) {
            maxInputChannelNums = getMaxInputChannelNumsForDynamicGraph(ejv);
            maxSubpartitionNums = getMaxSubpartitionNumsForDynamicGraph(ejv);
        } else {
            maxInputChannelNums = getMaxInputChannelNums(ejv);
            maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvLookup);
        }
        return TaskInputsOutputsDescriptor.from(
                maxInputChannelNums, maxSubpartitionNums, getPartitionTypes(ejv.getJobVertex()));
    }

    /**
     * For each input of the vertex, computes the maximum number of input channels a single
     * subtask can have. The computation uses the consumer's parallelism, the producer's number of
     * assigned partitions and the edge's distribution pattern.
     *
     * @throws IllegalStateException if the job-level inputs and the execution-level inputs are not
     *     index-aligned (ids differ at the same position)
     */
    private static Map<IntermediateDataSetID, Integer> getMaxInputChannelNums(
            ExecutionJobVertex ejv) {
        Map<IntermediateDataSetID, Integer> result = new HashMap<>();
        List<JobEdge> inputEdges = ejv.getJobVertex().getInputs();
        for (int i = 0; i < inputEdges.size(); i++) {
            JobEdge edge = inputEdges.get(i);
            IntermediateResult input = ejv.getInputs().get(i);
            // The execution-level inputs are expected to mirror the job-level input edges.
            Preconditions.checkState(
                    input.getId().equals(edge.getSourceId()),
                    "Intermediate result %s does not match the source %s of input edge %s.",
                    input.getId(),
                    edge.getSourceId(),
                    i);
            int maxChannels =
                    EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                            ejv.getParallelism(),
                            input.getNumberOfAssignedPartitions(),
                            edge.getDistributionPattern());
            result.put(input.getId(), maxChannels);
        }
        return result;
    }

    /**
     * For each data set produced by the vertex, computes the maximum number of subpartitions a
     * single producer subtask can have. The computation uses the producer's parallelism, the
     * consumer's parallelism (resolved through {@code ejvLookup}) and the edge's distribution
     * pattern.
     *
     * @throws NullPointerException if a produced data set has no consumer edge
     */
    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNums(
            ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvLookup) {
        Map<IntermediateDataSetID, Integer> result = new HashMap<>();
        for (IntermediateDataSet dataSet : ejv.getJobVertex().getProducedDataSets()) {
            JobEdge consumerEdge = Preconditions.checkNotNull(dataSet.getConsumer());
            int consumerParallelism =
                    ejvLookup.apply(consumerEdge.getTarget().getID()).getParallelism();
            int maxSubpartitions =
                    EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                            ejv.getParallelism(),
                            consumerParallelism,
                            consumerEdge.getDistributionPattern());
            result.put(dataSet.getId(), maxSubpartitions);
        }
        return result;
    }

    /**
     * Maps each data set produced by the given job vertex to its result partition type.
     *
     * <p>Previously the map was never populated, because the lambda body was empty. The
     * descriptor therefore carried no partition types.
     */
    private static Map<IntermediateDataSetID, ResultPartitionType> getPartitionTypes(
            JobVertex jobVertex) {
        Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = new HashMap<>();
        for (IntermediateDataSet dataSet : jobVertex.getProducedDataSets()) {
            // Keep the first type seen should a data set id ever appear twice.
            partitionTypes.putIfAbsent(dataSet.getId(), dataSet.getResultType());
        }
        return partitionTypes;
    }

    /**
     * Dynamic-graph variant of {@link #getMaxInputChannelNums}. For every subtask and every
     * consumed partition group, the channel count is the consumed subpartition range size times
     * the number of partitions in the group. The range size is computed for the subtask against
     * the group's first partition. The maximum over all subtasks is kept per data set.
     *
     * <p>NOTE(review): this assumes every partition in a group yields the same range size as the
     * first one. TODO: confirm.
     */
    @VisibleForTesting
    static Map<IntermediateDataSetID, Integer> getMaxInputChannelNumsForDynamicGraph(
            ExecutionJobVertex executionJobVertex) {
        Map<IntermediateDataSetID, Integer> result = new HashMap<>();
        for (ExecutionVertex vertex : executionJobVertex.getTaskVertices()) {
            for (ConsumedPartitionGroup group : vertex.getAllConsumedPartitionGroups()) {
                int consumedSubpartitions =
                        TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
                                        executionJobVertex
                                                .getGraph()
                                                .getResultPartitionOrThrow(group.getFirst()),
                                        vertex.getParallelSubtaskIndex())
                                .size();
                result.merge(
                        group.getIntermediateDataSetID(),
                        consumedSubpartitions * group.size(),
                        Integer::max);
            }
        }
        return result;
    }

    /**
     * Dynamic-graph variant of {@link #getMaxSubpartitionNums}. For each produced intermediate
     * result it takes the largest subpartition count among its partitions, or 0 if it has no
     * partitions.
     */
    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNumsForDynamicGraph(
            ExecutionJobVertex executionJobVertex) {
        Map<IntermediateDataSetID, Integer> result = new HashMap<>();
        for (IntermediateResult produced : executionJobVertex.getProducedDataSets()) {
            int maxSubpartitions =
                    Arrays.stream(produced.getPartitions())
                            .mapToInt(partition -> partition.getNumberOfSubpartitions())
                            .reduce(0, Math::max);
            result.put(produced.getId(), maxSubpartitions);
        }
        return result;
    }

    /** Static utility class; not meant to be instantiated. */
    private SsgNetworkMemoryCalculationUtils() {}
}
