package org.apache.samza.operators.impl;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.samza.config.Config;
import org.apache.samza.container.TaskContextImpl;
import org.apache.samza.job.model.JobModel;
import org.apache.samza.operators.KV;
import org.apache.samza.operators.StreamGraphImpl;
import org.apache.samza.operators.functions.JoinFunction;
import org.apache.samza.operators.functions.PartialJoinFunction;
import org.apache.samza.operators.impl.store.TimestampedValue;
import org.apache.samza.operators.spec.InputOperatorSpec;
import org.apache.samza.operators.spec.JoinOperatorSpec;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.operators.spec.OutputOperatorSpec;
import org.apache.samza.operators.spec.PartitionByOperatorSpec;
import org.apache.samza.operators.spec.SinkOperatorSpec;
import org.apache.samza.operators.spec.StreamOperatorSpec;
import org.apache.samza.operators.spec.WindowOperatorSpec;
import org.apache.samza.storage.kv.KeyValueStore;
import org.apache.samza.system.StreamSpec;
import org.apache.samza.system.SystemStream;
import org.apache.samza.task.TaskContext;
import org.apache.samza.util.Clock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/samza/operators/impl/OperatorImplGraph.class */
/**
 * Builds and owns the {@link OperatorImpl} DAG for a single task from a {@link StreamGraphImpl}.
 *
 * <p>Starting from every input operator spec of the stream graph, the corresponding impls are created,
 * initialized, registered and wired to their downstream impls. The graph also registers
 * {@link EndOfStreamStates} and {@link WatermarkStates} with the task context, seeded with the number of
 * producer tasks for each intermediate stream.
 */
public class OperatorImplGraph {
    private static final Logger LOG = LoggerFactory.getLogger(OperatorImplGraph.class);

    /** Every registered impl keyed by its op impl id, kept in registration order. */
    private final Map<String, OperatorImpl> operatorImpls = new LinkedHashMap<>();

    /** Root impl for each input {@link SystemStream} of the graph. */
    private final Map<SystemStream, InputOperatorImpl> inputOperators = new HashMap<>();

    /** Left/right partial join functions per join op id, shared by both partial join impls of that join. */
    private final Map<String, KV<PartialJoinFunction, PartialJoinFunction>> joinFunctions = new HashMap<>();

    private final Clock clock;

    /**
     * Creates the impl graph for one task.
     *
     * @param streamGraphImpl the logical stream graph to instantiate
     * @param config the job config, passed to every impl's constructor/init
     * @param taskContext the task context; it must be a {@link TaskContextImpl} (it is cast unconditionally)
     * @param clock the clock handed to window and partial join impls
     */
    public OperatorImplGraph(StreamGraphImpl streamGraphImpl, Config config, TaskContext taskContext, Clock clock) {
        this.clock = clock;
        // registerObject/getJobModel are only available on the concrete implementation.
        TaskContextImpl taskContextImpl = (TaskContextImpl) taskContext;

        Map<SystemStream, Integer> producerTaskCounts = hasIntermediateStreams(streamGraphImpl)
            ? getProducerTaskCountForIntermediateStreams(
                getStreamToConsumerTasks(taskContextImpl.getJobModel()),
                getIntermediateToInputStreamsMap(streamGraphImpl))
            : Collections.<SystemStream, Integer>emptyMap();
        producerTaskCounts.forEach((systemStream, count) -> LOG.info("{} has {} producer tasks.", systemStream, count));

        taskContextImpl.registerObject(EndOfStreamStates.class.getName(),
            new EndOfStreamStates(taskContext.getSystemStreamPartitions(), producerTaskCounts));
        taskContextImpl.registerObject(WatermarkStates.class.getName(),
            new WatermarkStates(taskContext.getSystemStreamPartitions(), producerTaskCounts));

        streamGraphImpl.getInputOperators().forEach((streamSpec, inputOpSpec) -> {
            SystemStream inputStream = new SystemStream(streamSpec.getSystemName(), streamSpec.getPhysicalName());
            InputOperatorImpl inputOpImpl =
                (InputOperatorImpl) createAndRegisterOperatorImpl(null, inputOpSpec, inputStream, config, taskContext);
            this.inputOperators.put(inputStream, inputOpImpl);
        });
    }

    /**
     * Returns the root impl for {@code systemStream}, or {@code null} if it is not an input of this graph.
     */
    public InputOperatorImpl getInputOperator(SystemStream systemStream) {
        return this.inputOperators.get(systemStream);
    }

    /**
     * Closes every registered impl in reverse registration order. An impl is registered before the impls
     * downstream of it, so downstream impls are closed first.
     */
    public void close() {
        // Copy first; the explicit element type keeps the reversed list typed so OperatorImpl::close resolves.
        Lists.reverse(new ArrayList<OperatorImpl>(this.operatorImpls.values())).forEach(OperatorImpl::close);
    }

    /** Returns an unmodifiable view of all root input impls. */
    public Collection<InputOperatorImpl> getAllInputOperators() {
        return Collections.unmodifiableCollection(this.inputOperators.values());
    }

    /**
     * Creates (or reuses) the impl for {@code opSpec}, registers {@code inputStream} on it, and recurses into
     * the specs registered downstream of {@code opSpec}.
     *
     * <p>A new impl is created when none is registered under {@code opSpec.getOpId()}, or always for a
     * {@link JoinOperatorSpec}, which gets one partial join impl per side. A newly created impl is initialized,
     * registered under its op impl id, and wired to the impls returned by the recursive calls.
     *
     * @param prevOpSpec the upstream spec, or {@code null} for an input operator
     * @param opSpec the spec to instantiate
     * @param inputStream the input stream whose traversal reached {@code opSpec}
     * @param config the job config
     * @param taskContext the task context
     * @return the impl for {@code opSpec}
     */
    OperatorImpl createAndRegisterOperatorImpl(OperatorSpec prevOpSpec, OperatorSpec opSpec, SystemStream inputStream,
        Config config, TaskContext taskContext) {
        if (!this.operatorImpls.containsKey(opSpec.getOpId()) || opSpec instanceof JoinOperatorSpec) {
            OperatorImpl opImpl = createOperatorImpl(prevOpSpec, opSpec, config, taskContext);
            opImpl.init(config, taskContext);
            opImpl.registerInputStream(inputStream);
            // Keyed by op impl id, which may differ from opSpec.getOpId() (e.g. partial join impls).
            this.operatorImpls.put(opImpl.getOpImplId(), opImpl);

            // opSpec is raw, so getRegisteredOperatorSpecs() erases to a raw Collection; the typed local
            // restores the element type for the lambda below.
            Collection<OperatorSpec> nextOpSpecs = opSpec.getRegisteredOperatorSpecs();
            nextOpSpecs.forEach(nextOpSpec -> opImpl.registerNextOperator(
                createAndRegisterOperatorImpl(opSpec, nextOpSpec, inputStream, config, taskContext)));
            return opImpl;
        }

        // Already instantiated: only register the additional input stream, and keep traversing so the
        // downstream impls learn about it too.
        OperatorImpl existingImpl = this.operatorImpls.get(opSpec.getOpId());
        existingImpl.registerInputStream(inputStream);
        // NOTE(review): if a downstream spec is a JoinOperatorSpec, the call below creates and registers a fresh
        // partial join impl whose result is discarded (it is not wired via registerNextOperator). Confirm intended.
        Collection<OperatorSpec> nextOpSpecs = opSpec.getRegisteredOperatorSpecs();
        nextOpSpecs.forEach(nextOpSpec ->
            createAndRegisterOperatorImpl(opSpec, nextOpSpec, inputStream, config, taskContext));
        return existingImpl;
    }

    /**
     * Instantiates (without initializing) the impl matching the concrete type of {@code opSpec}.
     *
     * @param prevOpSpec the upstream spec; used to decide the side of a partial join impl
     * @param opSpec the spec to instantiate
     * @param config the job config
     * @param taskContext the task context
     * @return a new, uninitialized impl
     * @throws IllegalArgumentException if the spec type is not supported
     */
    OperatorImpl createOperatorImpl(OperatorSpec prevOpSpec, OperatorSpec opSpec, Config config,
        TaskContext taskContext) {
        if (opSpec instanceof InputOperatorSpec) {
            return new InputOperatorImpl((InputOperatorSpec) opSpec);
        }
        if (opSpec instanceof StreamOperatorSpec) {
            return new StreamOperatorImpl((StreamOperatorSpec) opSpec, config, taskContext);
        }
        if (opSpec instanceof SinkOperatorSpec) {
            return new SinkOperatorImpl((SinkOperatorSpec) opSpec, config, taskContext);
        }
        if (opSpec instanceof OutputOperatorSpec) {
            return new OutputOperatorImpl((OutputOperatorSpec) opSpec, config, taskContext);
        }
        if (opSpec instanceof PartitionByOperatorSpec) {
            return new PartitionByOperatorImpl((PartitionByOperatorSpec) opSpec, config, taskContext);
        }
        if (opSpec instanceof WindowOperatorSpec) {
            return new WindowOperatorImpl((WindowOperatorSpec) opSpec, this.clock);
        }
        if (opSpec instanceof JoinOperatorSpec) {
            return createPartialJoinOperatorImpl(prevOpSpec, (JoinOperatorSpec) opSpec, config, taskContext, this.clock);
        }
        throw new IllegalArgumentException(String.format("Unsupported OperatorSpec: %s", opSpec.getClass().getName()));
    }

    /**
     * Creates the partial join impl for the side of {@code joinOpSpec} that {@code prevOpSpec} feeds.
     * The side is left iff {@code prevOpSpec} equals the join's left input spec. That side's partial function
     * becomes "this" and the opposite side's becomes "other".
     */
    private PartialJoinOperatorImpl createPartialJoinOperatorImpl(OperatorSpec prevOpSpec,
        JoinOperatorSpec joinOpSpec, Config config, TaskContext taskContext, Clock clock) {
        KV<PartialJoinFunction, PartialJoinFunction> partialFns = getOrCreatePartialJoinFunctions(joinOpSpec);
        boolean isLeftSide = joinOpSpec.getLeftInputOpSpec().equals(prevOpSpec);
        PartialJoinFunction thisPartialFn = isLeftSide ? partialFns.getKey() : partialFns.getValue();
        PartialJoinFunction otherPartialFn = isLeftSide ? partialFns.getValue() : partialFns.getKey();
        return new PartialJoinOperatorImpl(joinOpSpec, isLeftSide, thisPartialFn, otherPartialFn, config,
            taskContext, clock);
    }

    /**
     * Returns the cached (left, right) partial join functions for {@code joinOpSpec}, creating them on first use
     * so both partial join impls of the same join share them.
     */
    private KV<PartialJoinFunction, PartialJoinFunction> getOrCreatePartialJoinFunctions(JoinOperatorSpec joinOpSpec) {
        return this.joinFunctions.computeIfAbsent(joinOpSpec.getOpId(),
            opId -> KV.of(createLeftJoinFn(joinOpSpec), createRightJoinFn(joinOpSpec)));
    }

    /**
     * Creates the left-side partial join function. It keys messages with {@code getFirstKey}, applies the join
     * function as {@code (message, otherMessage)}, and reads its state from the store named by
     * {@code getLeftOpId()}. It also initializes and closes the user {@link JoinFunction}.
     */
    private PartialJoinFunction<Object, Object, Object, Object> createLeftJoinFn(final JoinOperatorSpec joinOperatorSpec) {
        return new PartialJoinFunction<Object, Object, Object, Object>() {
            private final JoinFunction joinFn = joinOperatorSpec.getJoinFn();
            private KeyValueStore<Object, TimestampedValue<Object>> leftStreamState;

            @Override
            public Object apply(Object message, Object otherMessage) {
                return this.joinFn.apply(message, otherMessage);
            }

            @Override
            public Object getKey(Object message) {
                return this.joinFn.getFirstKey(message);
            }

            @Override
            public KeyValueStore<Object, TimestampedValue<Object>> getState() {
                return this.leftStreamState;
            }

            @Override
            public void init(Config config, TaskContext taskContext) {
                this.leftStreamState = (KeyValueStore) taskContext.getStore(joinOperatorSpec.getLeftOpId());
                this.joinFn.init(config, taskContext);
            }

            @Override
            public void close() {
                this.joinFn.close();
            }
        };
    }

    /**
     * Creates the right-side partial join function. It keys messages with {@code getSecondKey}, applies the
     * join function with arguments swapped ({@code (otherMessage, message)}), and reads its state from the
     * store named by {@code getRightOpId()}.
     */
    private PartialJoinFunction<Object, Object, Object, Object> createRightJoinFn(final JoinOperatorSpec joinOperatorSpec) {
        return new PartialJoinFunction<Object, Object, Object, Object>() {
            private final JoinFunction joinFn = joinOperatorSpec.getJoinFn();
            private KeyValueStore<Object, TimestampedValue<Object>> rightStreamState;

            @Override
            public Object apply(Object message, Object otherMessage) {
                // Swap so the user function always sees (left message, right message).
                return this.joinFn.apply(otherMessage, message);
            }

            @Override
            public Object getKey(Object message) {
                return this.joinFn.getSecondKey(message);
            }

            @Override
            public void init(Config config, TaskContext taskContext) {
                this.rightStreamState = (KeyValueStore) taskContext.getStore(joinOperatorSpec.getRightOpId());
                // Deliberately no joinFn.init/close here: the left function does both. This presumably relies on
                // getJoinFn() returning the same instance for both sides -- confirm.
            }

            @Override
            public KeyValueStore<Object, TimestampedValue<Object>> getState() {
                return this.rightStreamState;
            }
        };
    }

    /** Returns true if some stream spec is both an input and an output of the graph. */
    private boolean hasIntermediateStreams(StreamGraphImpl streamGraphImpl) {
        return !Collections.disjoint(streamGraphImpl.getInputOperators().keySet(),
            streamGraphImpl.getOutputStreams().keySet());
    }

    /**
     * Counts, for each intermediate stream, the distinct tasks that consume any of the input streams feeding it.
     *
     * @param streamToConsumerTasks input stream to the names of the tasks consuming it
     * @param intermediateToInputStreams intermediate stream to the input streams that are repartitioned into it
     * @return a new mutable map from intermediate stream to its producer task count
     */
    static Map<SystemStream, Integer> getProducerTaskCountForIntermediateStreams(
        Multimap<SystemStream, String> streamToConsumerTasks,
        Multimap<SystemStream, SystemStream> intermediateToInputStreams) {
        Map<SystemStream, Integer> producerTaskCounts = new HashMap<>();
        intermediateToInputStreams.asMap().forEach((intermediateStream, inputStreams) -> {
            Set<String> producerTasks = inputStreams.stream()
                .flatMap(inputStream -> streamToConsumerTasks.get(inputStream).stream())
                .collect(Collectors.toSet());
            producerTaskCounts.put(intermediateStream, producerTasks.size());
        });
        return producerTaskCounts;
    }

    /**
     * Maps every system stream in the job model to the names of the tasks that own one of its partitions.
     */
    static Multimap<SystemStream, String> getStreamToConsumerTasks(JobModel jobModel) {
        HashMultimap<SystemStream, String> streamToTasks = HashMultimap.create();
        jobModel.getContainers().values().forEach(containerModel ->
            containerModel.getTasks().values().forEach(taskModel ->
                taskModel.getSystemStreamPartitions().forEach(ssp ->
                    streamToTasks.put(ssp.getSystemStream(), taskModel.getTaskName().getTaskName()))));
        return streamToTasks;
    }

    /**
     * Maps each partitionBy output stream reachable from an input of the graph to the input streams reaching it.
     */
    static Multimap<SystemStream, SystemStream> getIntermediateToInputStreamsMap(StreamGraphImpl streamGraphImpl) {
        HashMultimap<SystemStream, SystemStream> outputToInput = HashMultimap.create();
        streamGraphImpl.getInputOperators().forEach((streamSpec, inputOpSpec) ->
            computeOutputToInput(streamSpec.toSystemStream(), inputOpSpec, outputToInput));
        return outputToInput;
    }

    /**
     * Walks the DAG below {@code opSpec}. At each {@link PartitionByOperatorSpec} it records
     * (partitionBy output stream to {@code input}) and stops descending on that branch.
     *
     * <p>Originally private; the decompiler widened it to public, and it is kept public for compatibility.
     */
    public static void computeOutputToInput(SystemStream input, OperatorSpec opSpec,
        Multimap<SystemStream, SystemStream> outputToInputStreams) {
        if (opSpec instanceof PartitionByOperatorSpec) {
            PartitionByOperatorSpec partitionBySpec = (PartitionByOperatorSpec) opSpec;
            outputToInputStreams.put(partitionBySpec.getOutputStream().getStreamSpec().toSystemStream(), input);
        } else {
            // Typed local needed: on a raw OperatorSpec this call erases to a raw Collection.
            Collection<OperatorSpec> nextOpSpecs = opSpec.getRegisteredOperatorSpecs();
            nextOpSpecs.forEach(nextOpSpec -> computeOutputToInput(input, nextOpSpec, outputToInputStreams));
        }
    }
}
