package org.apache.samza.execution;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.HashMultimap;
import java.util.Collection;
import java.util.HashSet;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import org.apache.samza.SamzaException;
import org.apache.samza.config.Config;
import org.apache.samza.config.JobConfig;
import org.apache.samza.execution.ExecutionPlanner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/apache/samza/execution/IntermediateStreamManager.class */
public class IntermediateStreamManager {
    private static final Logger log = LoggerFactory.getLogger(IntermediateStreamManager.class);

    /** Upper bound applied to a partition count that is inferred (not configured). */
    @VisibleForTesting
    static final int MAX_INFERRED_PARTITIONS = 256;

    /**
     * Sentinel meaning "no partition count": used both for stream edges whose count is not yet
     * assigned and as the default when the intermediate-partitions config property is absent.
     */
    private static final int PARTITIONS_UNKNOWN = -1;

    private final Config config;

    /**
     * Creates a manager that reads intermediate-stream settings from the given job config.
     *
     * @param config job configuration; consulted for
     *     {@link JobConfig#JOB_INTERMEDIATE_STREAM_PARTITIONS}
     */
    public IntermediateStreamManager(Config config) {
        this.config = config;
    }

    /**
     * Assigns partition counts to the intermediate streams of {@code jobGraph}, mutating the
     * stream edges in place.
     *
     * <p>Runs three passes in order:
     * <ol>
     *   <li>Within each stream set, copy a known partition count onto the edges whose count is
     *       still unknown, repeating as newly-assigned edges unlock other sets.</li>
     *   <li>Give every intermediate edge still without a positive count the configured (or
     *       inferred) default.</li>
     *   <li>Verify every intermediate edge now has a positive count.</li>
     * </ol>
     *
     * @param jobGraph the job graph whose intermediate stream edges are updated
     * @param joinedStreamSets sets of stream edges whose unknown counts are filled from a known
     *     count in the same set
     * @throws SamzaException if the configured partition count is invalid, or if some
     *     intermediate stream is left without a positive partition count
     */
    public void calculatePartitions(JobGraph jobGraph, Collection<ExecutionPlanner.StreamSet> joinedStreamSets) {
        setJoinedIntermediateStreamPartitions(joinedStreamSets);
        setIntermediateStreamPartitions(jobGraph);
        validateIntermediateStreamPartitions(jobGraph);
    }

    /**
     * Sets the default partition count on every intermediate edge whose count is not positive.
     * Edges that already have a positive count are left untouched.
     */
    private void setIntermediateStreamPartitions(JobGraph jobGraph) {
        int partitions = resolveDefaultPartitionCount(jobGraph);
        for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
            if (edge.getPartitionCount() <= 0) {
                log.info("Set the partition count for intermediate stream {} to {}.", edge.getName(), partitions);
                edge.setPartitionCount(partitions);
            }
        }
    }

    /**
     * Returns the partition count to use for intermediate streams that have none.
     *
     * <p>If {@link JobConfig#JOB_INTERMEDIATE_STREAM_PARTITIONS} is set, its value is used and must be
     * positive. Otherwise the count is inferred as the maximum over the job's input and output
     * streams, capped at {@link #MAX_INFERRED_PARTITIONS}. The inferred value may be
     * {@code -1} when no such stream has a count; the validation pass then rejects it.
     *
     * @throws SamzaException if the configured value is not positive
     */
    private int resolveDefaultPartitionCount(JobGraph jobGraph) {
        int configured = config.getInt(JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS, PARTITIONS_UNKNOWN);
        if (configured != PARTITIONS_UNKNOWN) {
            if (configured <= 0) {
                throw new SamzaException(String.format("Invalid value %d specified for config property %s",
                        configured, JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS));
            }
            log.info("Using partition count value {} specified for config property {}",
                    configured, JobConfig.JOB_INTERMEDIATE_STREAM_PARTITIONS);
            return configured;
        }

        int inferred = Math.max(maxPartitions(jobGraph.getInputStreams()), maxPartitions(jobGraph.getOutputStreams()));
        if (inferred > MAX_INFERRED_PARTITIONS) {
            // Log before clamping so the message reports the actual inferred value.
            log.warn("Inferred intermediate stream partition count {} is greater than the max {}. Using the max.",
                    inferred, MAX_INFERRED_PARTITIONS);
            return MAX_INFERRED_PARTITIONS;
        }
        return inferred;
    }

    /**
     * Propagates known partition counts across stream sets using a worklist.
     *
     * <p>A set that contains at least one edge with a known count gets that count copied onto all
     * of its edges whose count is unknown. Each newly-assigned edge may belong to other sets, so
     * those sets are re-queued unless they were already processed. A set with no known count is
     * dropped from the worklist. Such a set is re-queued later if one of its edges is assigned a
     * count from another set.
     */
    private static void setJoinedIntermediateStreamPartitions(Collection<ExecutionPlanner.StreamSet> joinedStreamSets) {
        // Edge with an unknown count -> every set containing it, so sharing sets can be revisited
        // once the edge's count is assigned.
        HashMultimap<StreamEdge, ExecutionPlanner.StreamSet> setsByUnknownEdge = HashMultimap.create();
        for (ExecutionPlanner.StreamSet streamSet : joinedStreamSets) {
            for (StreamEdge edge : streamSet.getStreamEdges()) {
                if (edge.getPartitionCount() == PARTITIONS_UNKNOWN) {
                    setsByUnknownEdge.put(edge, streamSet);
                }
            }
        }

        Set<ExecutionPlanner.StreamSet> pending = new HashSet<>(joinedStreamSets);
        Set<ExecutionPlanner.StreamSet> processed = new HashSet<>();
        while (!pending.isEmpty()) {
            ExecutionPlanner.StreamSet streamSet = pending.iterator().next();
            pending.remove(streamSet);

            Optional<StreamEdge> known = streamSet.getStreamEdges().stream()
                    .filter(edge -> edge.getPartitionCount() != PARTITIONS_UNKNOWN)
                    .findAny();
            if (!known.isPresent()) {
                continue;
            }

            processed.add(streamSet);
            int partitions = known.get().getPartitionCount();
            for (StreamEdge edge : streamSet.getStreamEdges()) {
                if (edge.getPartitionCount() == PARTITIONS_UNKNOWN) {
                    edge.setPartitionCount(partitions);
                    for (ExecutionPlanner.StreamSet sharingSet : setsByUnknownEdge.get(edge)) {
                        if (!processed.contains(sharingSet)) {
                            pending.add(sharingSet);
                        }
                    }
                }
            }
        }
    }

    /**
     * Ensures every intermediate stream edge has a positive partition count.
     *
     * @throws SamzaException naming the first intermediate stream whose count is not positive
     */
    private static void validateIntermediateStreamPartitions(JobGraph jobGraph) {
        for (StreamEdge edge : jobGraph.getIntermediateStreamEdges()) {
            if (edge.getPartitionCount() <= 0) {
                throw new SamzaException(String.format("Failed to assign valid partition count to Stream %s", edge.getName()));
            }
        }
    }

    /**
     * Returns the largest partition count among {@code edges}, or {@code -1} if {@code edges} is
     * empty.
     */
    static int maxPartitions(Collection<StreamEdge> edges) {
        return edges.stream().mapToInt(StreamEdge::getPartitionCount).max().orElse(PARTITIONS_UNKNOWN);
    }
}
