package org.apache.samza.execution;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Multimap;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import org.apache.samza.SamzaException;
import org.apache.samza.config.ApplicationConfig;
import org.apache.samza.config.ClusterManagerConfig;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
import org.apache.samza.operators.StreamGraphImpl;
import org.apache.samza.operators.spec.JoinOperatorSpec;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.system.StreamSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Turns a {@link StreamGraphImpl} into an {@link ExecutionPlan}.
 *
 * <p>Planning validates the config, builds a {@link JobGraph} that places every operator in a
 * single {@link JobNode}, looks up the partition counts of the existing source and sink streams
 * through the {@link StreamManager}, and, when the graph contains intermediate streams, assigns
 * partition counts to them (propagating counts across join inputs first).
 */
public class ExecutionPlanner {
    private static final Logger log = LoggerFactory.getLogger(ExecutionPlanner.class);

    /** Cap applied to an intermediate-stream partition count that is inferred rather than configured. */
    static final int MAX_INFERRED_PARTITIONS = 256;

    private final Config config;
    private final StreamManager streamManager;

    public ExecutionPlanner(Config config, StreamManager streamManager) {
        this.config = config;
        this.streamManager = streamManager;
    }

    /**
     * Builds the execution plan for the given stream graph.
     *
     * @param streamGraphImpl the application's stream graph
     * @return the planned {@link JobGraph}
     * @throws SamzaException if host affinity is enabled in batch mode, if the inputs of a join
     *     have conflicting partition counts, or if an intermediate stream ends up without a
     *     positive partition count
     * @throws Exception as declared; collaborators called here may throw
     */
    public ExecutionPlan plan(StreamGraphImpl streamGraphImpl) throws Exception {
        validateConfig();

        JobGraph jobGraph = createJobGraph(streamGraphImpl);

        // The partition counts of the pre-existing input/output streams feed the
        // intermediate-stream calculation below.
        updateExistingPartitions(jobGraph, this.streamManager);

        if (!jobGraph.getIntermediateStreamEdges().isEmpty()) {
            calculatePartitions(streamGraphImpl, jobGraph);
        }
        return jobGraph;
    }

    /**
     * Rejects configurations the planner does not support.
     *
     * @throws SamzaException if the application runs in batch mode with host affinity enabled
     */
    private void validateConfig() {
        ApplicationConfig appConfig = new ApplicationConfig(this.config);
        ClusterManagerConfig clusterManagerConfig = new ClusterManagerConfig(this.config);
        if (appConfig.getAppMode() == ApplicationConfig.ApplicationMode.BATCH
                && clusterManagerConfig.getHostAffinityEnabled()) {
            throw new SamzaException("Host affinity is not supported in batch mode. Please configure job.host-affinity.enabled=false.");
        }
    }

    /**
     * Creates the job graph for the stream graph.
     *
     * <p>A stream that is both an input and an output of the graph is treated as an intermediate
     * stream; the remaining inputs become sources and the remaining outputs become sinks. All
     * operators go into one job node identified by {@code job.name} and {@code job.id}
     * (defaulting to {@code "1"}).
     *
     * @param streamGraphImpl the application's stream graph
     * @return the validated job graph
     */
    JobGraph createJobGraph(StreamGraphImpl streamGraphImpl) {
        JobGraph jobGraph = new JobGraph(this.config);

        Set<StreamSpec> sourceStreams = new HashSet<>(streamGraphImpl.getInputOperators().keySet());
        Set<StreamSpec> sinkStreams = new HashSet<>(streamGraphImpl.getOutputStreams().keySet());

        // Streams that are both read and written are intermediate; strip them from the
        // pure source and sink sets.
        Set<StreamSpec> intermediateStreams = new HashSet<>(sourceStreams);
        intermediateStreams.retainAll(sinkStreams);
        sourceStreams.removeAll(intermediateStreams);
        sinkStreams.removeAll(intermediateStreams);

        // Single-node plan: intermediate streams therefore connect this node to itself.
        JobNode jobNode = jobGraph.getOrCreateJobNode(
                (String) this.config.get(JobConfig.JOB_NAME()),
                this.config.get(JobConfig.JOB_ID(), "1"),
                streamGraphImpl);

        sourceStreams.forEach(spec -> jobGraph.addSource(spec, jobNode));
        sinkStreams.forEach(spec -> jobGraph.addSink(spec, jobNode));
        intermediateStreams.forEach(spec -> jobGraph.addIntermediateStream(spec, jobNode, jobNode));

        jobGraph.validate();
        return jobGraph;
    }

    /**
     * Assigns partition counts to the intermediate streams of the job graph.
     *
     * <p>Join inputs are resolved first, then any remaining intermediate streams get the
     * configured or inferred default, and finally every intermediate stream is checked.
     */
    void calculatePartitions(StreamGraphImpl streamGraphImpl, JobGraph jobGraph) {
        calculateJoinInputPartitions(streamGraphImpl, jobGraph);
        calculateIntStreamPartitions(jobGraph, this.config);
        validatePartitions(jobGraph);
    }

    /**
     * Fetches the partition counts of all source and sink streams from the stream manager and
     * stores them on the corresponding {@link StreamEdge}s.
     *
     * <p>Edges are grouped by system so that each system is queried once.
     *
     * @param jobGraph the job graph whose source/sink edges are updated in place
     * @param streamManager used to look up partition counts per system
     */
    static void updateExistingPartitions(JobGraph jobGraph, StreamManager streamManager) {
        Set<StreamEdge> existingStreams = new HashSet<>();
        existingStreams.addAll(jobGraph.getSources());
        existingStreams.addAll(jobGraph.getSinks());

        Multimap<String, StreamEdge> systemToStreamEdges = HashMultimap.create();
        for (StreamEdge edge : existingStreams) {
            systemToStreamEdges.put(edge.getSystemStream().getSystem(), edge);
        }

        for (Map.Entry<String, Collection<StreamEdge>> entry : systemToStreamEdges.asMap().entrySet()) {
            String systemName = entry.getKey();

            // Index this system's edges by stream name so the returned counts can be mapped
            // back to their edges (and so the manager is asked about the right streams).
            Map<String, StreamEdge> streamToStreamEdge = new HashMap<>();
            for (StreamEdge edge : entry.getValue()) {
                streamToStreamEdge.put(edge.getSystemStream().getStream(), edge);
            }

            streamManager.getStreamPartitionCounts(systemName, streamToStreamEdge.keySet())
                    .forEach((stream, partitionCount) -> {
                        streamToStreamEdge.get(stream).setPartitionCount(partitionCount.intValue());
                        log.debug("Partition count is {} for stream {}", partitionCount, stream);
                    });
        }
    }

    /**
     * Propagates partition counts across the inputs of join operators.
     *
     * <p>All input edges that reach the same join must have the same partition count. Starting
     * from joins with at least one input of known (positive) count, the known count is assigned
     * to that join's inputs that have none yet; any other join reached by a newly assigned edge
     * is then processed as well.
     *
     * @throws SamzaException if two inputs of the same join have different known partition counts
     */
    static void calculateJoinInputPartitions(StreamGraphImpl streamGraphImpl, JobGraph jobGraph) {
        // join operator -> input stream edges from which it is reachable
        Multimap<OperatorSpec, StreamEdge> joinSpecToStreamEdges = HashMultimap.create();
        // input stream edge -> join operators reachable from it
        Multimap<StreamEdge, OperatorSpec> streamEdgeToJoinSpecs = HashMultimap.create();
        // joins waiting to be processed
        Queue<OperatorSpec> joinQ = new LinkedList<>();
        // joins that have already been enqueued once
        Set<OperatorSpec> visited = new HashSet<>();

        streamGraphImpl.getInputOperators().forEach((streamSpec, inputOpSpec) ->
                findReachableJoins(inputOpSpec, jobGraph.getOrCreateStreamEdge(streamSpec),
                        joinSpecToStreamEdges, streamEdgeToJoinSpecs, joinQ, visited));

        while (!joinQ.isEmpty()) {
            OperatorSpec join = joinQ.poll();

            // Determine the single partition count shared by this join's known inputs.
            int partitions = -1;
            for (StreamEdge edge : joinSpecToStreamEdges.get(join)) {
                int edgePartitions = edge.getPartitionCount();
                if (edgePartitions != -1) {
                    if (partitions == -1) {
                        partitions = edgePartitions;
                    } else if (partitions != edgePartitions) {
                        throw new SamzaException(String.format("Unable to resolve input partitions of stream %s for join. Expected: %d, Actual: %d", edge.getFormattedSystemStream(), Integer.valueOf(partitions), Integer.valueOf(edgePartitions)));
                    }
                }
            }

            // Assign it to the unresolved inputs and schedule any other joins they feed.
            for (StreamEdge edge : joinSpecToStreamEdges.get(join)) {
                if (edge.getPartitionCount() <= 0) {
                    edge.setPartitionCount(partitions);
                    for (OperatorSpec downstreamJoin : streamEdgeToJoinSpecs.get(edge)) {
                        if (!visited.contains(downstreamJoin)) {
                            joinQ.add(downstreamJoin);
                            visited.add(downstreamJoin);
                        }
                    }
                }
            }
        }
    }

    /**
     * Recursively walks the operators registered downstream of {@code operatorSpec} and records
     * every {@link JoinOperatorSpec} reached from {@code streamEdge}, in both directions.
     *
     * <p>A join is enqueued (and marked visited) the first time it is reached from an edge whose
     * partition count is positive.
     *
     * <p>NOTE(review): the recursion has no cycle guard, so it assumes the operator graph is
     * acyclic. Decompiler metadata indicates this method was originally {@code private}; it is
     * kept public so existing callers continue to compile.
     */
    public static void findReachableJoins(OperatorSpec operatorSpec, StreamEdge streamEdge, Multimap<OperatorSpec, StreamEdge> multimap, Multimap<StreamEdge, OperatorSpec> multimap2, Queue<OperatorSpec> queue, Set<OperatorSpec> set) {
        if (operatorSpec instanceof JoinOperatorSpec) {
            multimap.put(operatorSpec, streamEdge);
            multimap2.put(streamEdge, operatorSpec);
            if (!set.contains(operatorSpec) && streamEdge.getPartitionCount() > 0) {
                queue.add(operatorSpec);
                set.add(operatorSpec);
            }
        }

        Collection<OperatorSpec> downstreamSpecs = operatorSpec.getRegisteredOperatorSpecs();
        for (OperatorSpec downstream : downstreamSpecs) {
            findReachableJoins(downstream, streamEdge, multimap, multimap2, queue, set);
        }
    }

    /**
     * Gives every intermediate stream that still lacks a positive partition count a default.
     *
     * <p>The default is {@code job.intermediate.stream.partitions} if it is configured
     * (non-negative). Otherwise it is the largest partition count among the sources and sinks,
     * capped at {@link #MAX_INFERRED_PARTITIONS}.
     */
    private static void calculateIntStreamPartitions(JobGraph jobGraph, Config config) {
        int partitions = config.getInt(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS(), -1);
        if (partitions < 0) {
            partitions = Math.max(maxPartition(jobGraph.getSources()), maxPartition(jobGraph.getSinks()));
            if (partitions > MAX_INFERRED_PARTITIONS) {
                // Log the inferred value before clamping so the message reports what was inferred.
                log.warn("Inferred intermediate stream partition count {} is greater than the max {}. Using the max.",
                        partitions, MAX_INFERRED_PARTITIONS);
                partitions = MAX_INFERRED_PARTITIONS;
            }
        }

        for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
            if (edge.getPartitionCount() <= 0) {
                edge.setPartitionCount(partitions);
            }
        }
    }

    /**
     * Verifies that every intermediate stream has a positive partition count.
     *
     * @throws SamzaException naming the first intermediate stream found without one
     */
    private static void validatePartitions(JobGraph jobGraph) {
        for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
            if (edge.getPartitionCount() <= 0) {
                throw new SamzaException(String.format("Failure to assign the partitions to Stream %s", edge.getFormattedSystemStream()));
            }
        }
    }

    /**
     * Returns the largest partition count among the given edges, or {@code -1} if there are none.
     */
    static int maxPartition(Collection<StreamEdge> collection) {
        return collection.stream()
                .mapToInt(StreamEdge::getPartitionCount)
                .max()
                .orElse(-1);
    }
}
