/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.core.construction.graph;

import java.util.ArrayDeque;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.Environments;
import org.apache.beam.runners.core.construction.NativeTransforms;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.graph.Networks;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.core.construction.graph.SideInputReference;
import org.apache.beam.runners.core.construction.graph.TimerReference;
import org.apache.beam.runners.core.construction.graph.UserStateReference;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.grpc.v1p13p1.com.google.protobuf.ProtocolStringList;
import org.apache.beam.vendor.guava.v20_0.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Sets;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.MutableNetwork;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.Network;
import org.apache.beam.vendor.guava.v20_0.com.google.common.graph.NetworkBuilder;

public class QueryablePipeline {
    private final RunnerApi.Components components;
    private final Network<PipelineNode, PipelineEdge> pipelineNetwork;
    private static final Set<String> PRIMITIVE_URNS = ImmutableSet.of(PTransformTranslation.PAR_DO_TRANSFORM_URN, PTransformTranslation.FLATTEN_TRANSFORM_URN, PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN, PTransformTranslation.IMPULSE_TRANSFORM_URN, PTransformTranslation.ASSIGN_WINDOWS_TRANSFORM_URN, PTransformTranslation.TEST_STREAM_TRANSFORM_URN, new String[]{PTransformTranslation.MAP_WINDOWS_TRANSFORM_URN, PTransformTranslation.READ_TRANSFORM_URN, PTransformTranslation.CREATE_VIEW_TRANSFORM_URN, PTransformTranslation.COMBINE_PER_KEY_PRECOMBINE_TRANSFORM_URN, PTransformTranslation.COMBINE_PER_KEY_MERGE_ACCUMULATORS_TRANSFORM_URN, PTransformTranslation.COMBINE_PER_KEY_EXTRACT_OUTPUTS_TRANSFORM_URN, PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN, PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN});

    public static QueryablePipeline forPrimitivesIn(RunnerApi.Components components) {
        return new QueryablePipeline(QueryablePipeline.getPrimitiveTransformIds(components), components);
    }

    public static QueryablePipeline forPipeline(RunnerApi.Pipeline p) {
        return QueryablePipeline.forTransforms(p.getRootTransformIdsList(), p.getComponents());
    }

    public static QueryablePipeline forTransforms(Collection<String> transformIds, RunnerApi.Components components) {
        return new QueryablePipeline(transformIds, components);
    }

    private QueryablePipeline(Collection<String> transformIds, RunnerApi.Components components) {
        this.components = components;
        this.pipelineNetwork = this.buildNetwork(transformIds, this.components);
    }

    @VisibleForTesting
    static Collection<String> getPrimitiveTransformIds(RunnerApi.Components components) {
        LinkedHashSet<String> ids = new LinkedHashSet<String>();
        for (Map.Entry<String, RunnerApi.PTransform> transformEntry : components.getTransformsMap().entrySet()) {
            RunnerApi.PTransform transform = transformEntry.getValue();
            boolean isPrimitive = QueryablePipeline.isPrimitiveTransform(transform);
            if (!isPrimitive) continue;
            ArrayDeque<String> transforms = new ArrayDeque<String>();
            transforms.push(transformEntry.getKey());
            while (!transforms.isEmpty()) {
                String id = (String)transforms.pop();
                RunnerApi.PTransform next = components.getTransformsMap().get(id);
                ProtocolStringList subtransforms = next.getSubtransformsList();
                if (subtransforms.isEmpty()) {
                    ids.add(id);
                    continue;
                }
                transforms.addAll(subtransforms);
            }
        }
        return ids;
    }

    private static boolean isPrimitiveTransform(RunnerApi.PTransform transform) {
        String urn = PTransformTranslation.urnForTransformOrNull(transform);
        return PRIMITIVE_URNS.contains(urn) || NativeTransforms.isNative(transform);
    }

    private MutableNetwork<PipelineNode, PipelineEdge> buildNetwork(Collection<String> transformIds, RunnerApi.Components components) {
        MutableNetwork<PipelineNode, PipelineEdge> network = NetworkBuilder.directed().allowsParallelEdges(true).allowsSelfLoops(false).build();
        HashSet<PipelineNode.PCollectionNode> unproducedCollections = new HashSet<PipelineNode.PCollectionNode>();
        for (String transformId : transformIds) {
            RunnerApi.PTransform transform = components.getTransformsOrThrow(transformId);
            PipelineNode.PTransformNode transformNode = PipelineNode.pTransform(transformId, this.components.getTransformsOrThrow(transformId));
            network.addNode(transformNode);
            for (String string : transform.getOutputsMap().values()) {
                PipelineNode.PCollectionNode producedNode = PipelineNode.pCollection(string, components.getPcollectionsOrThrow(string));
                network.addNode(producedNode);
                network.addEdge(transformNode, producedNode, new PerElementEdge());
                Preconditions.checkArgument(network.inDegree(producedNode) == 1, "A %s should have exactly one producing %s, but found %s:\nPCollection:\n%s\nProducers:\n%s", PipelineNode.PCollectionNode.class.getSimpleName(), PipelineNode.PTransformNode.class.getSimpleName(), network.predecessors(producedNode).size(), producedNode, network.predecessors(producedNode));
                unproducedCollections.remove(producedNode);
            }
            for (Map.Entry entry : transform.getInputsMap().entrySet()) {
                String pcollectionId = (String)entry.getValue();
                PipelineNode.PCollectionNode consumedNode = PipelineNode.pCollection(pcollectionId, this.components.getPcollectionsOrThrow(pcollectionId));
                if (network.addNode(consumedNode)) {
                    unproducedCollections.add(consumedNode);
                }
                if (this.getLocalSideInputNames(transform).contains(entry.getKey())) {
                    network.addEdge(consumedNode, transformNode, new SingletonEdge());
                    continue;
                }
                network.addEdge(consumedNode, transformNode, new PerElementEdge());
            }
        }
        Preconditions.checkArgument(unproducedCollections.isEmpty(), "%ss %s were consumed but never produced", (Object)PipelineNode.PCollectionNode.class.getSimpleName(), unproducedCollections);
        return network;
    }

    public Collection<PipelineNode.PTransformNode> getTransforms() {
        return this.pipelineNetwork.nodes().stream().filter(PipelineNode.PTransformNode.class::isInstance).map(PipelineNode.PTransformNode.class::cast).collect(Collectors.toList());
    }

    public Iterable<PipelineNode.PTransformNode> getTopologicallyOrderedTransforms() {
        return StreamSupport.stream(Networks.topologicalOrder(this.pipelineNetwork, Comparator.comparing(PipelineNode::getId)).spliterator(), false).filter(PipelineNode.PTransformNode.class::isInstance).map(PipelineNode.PTransformNode.class::cast).collect(Collectors.toList());
    }

    public Set<PipelineNode.PTransformNode> getRootTransforms() {
        return this.pipelineNetwork.nodes().stream().filter(pipelineNode -> this.pipelineNetwork.inEdges(pipelineNode).isEmpty()).map(pipelineNode -> (PipelineNode.PTransformNode)pipelineNode).collect(Collectors.toSet());
    }

    public PipelineNode.PTransformNode getProducer(PipelineNode.PCollectionNode pcollection) {
        return (PipelineNode.PTransformNode)Iterables.getOnlyElement(this.pipelineNetwork.predecessors(pcollection));
    }

    public Set<PipelineNode.PTransformNode> getPerElementConsumers(PipelineNode.PCollectionNode pCollection) {
        return this.pipelineNetwork.successors(pCollection).stream().filter(consumer -> this.pipelineNetwork.edgesConnecting(pCollection, consumer).stream().anyMatch(PipelineEdge::isPerElement)).map(pipelineNode -> (PipelineNode.PTransformNode)pipelineNode).collect(Collectors.toSet());
    }

    public Set<PipelineNode.PTransformNode> getSingletonConsumers(PipelineNode.PCollectionNode pCollection) {
        return this.pipelineNetwork.successors(pCollection).stream().filter(consumer -> this.pipelineNetwork.edgesConnecting(pCollection, consumer).stream().anyMatch(edge -> !edge.isPerElement())).map(pipelineNode -> (PipelineNode.PTransformNode)pipelineNode).collect(Collectors.toSet());
    }

    public Set<PipelineNode.PCollectionNode> getPerElementInputPCollections(PipelineNode.PTransformNode ptransform) {
        return this.pipelineNetwork.inEdges(ptransform).stream().filter(PipelineEdge::isPerElement).map(edge -> (PipelineNode.PCollectionNode)this.pipelineNetwork.incidentNodes(edge).source()).collect(Collectors.toSet());
    }

    public Set<PipelineNode.PCollectionNode> getOutputPCollections(PipelineNode.PTransformNode ptransform) {
        return this.pipelineNetwork.successors(ptransform).stream().map(pipelineNode -> (PipelineNode.PCollectionNode)pipelineNode).collect(Collectors.toSet());
    }

    public RunnerApi.Components getComponents() {
        return this.components;
    }

    public Collection<SideInputReference> getSideInputs(PipelineNode.PTransformNode transform) {
        return this.getLocalSideInputNames(transform.getTransform()).stream().map(localName -> {
            String transformId = transform.getId();
            RunnerApi.PTransform transformProto = this.components.getTransformsOrThrow(transformId);
            String collectionId = transform.getTransform().getInputsOrThrow((String)localName);
            RunnerApi.PCollection collection = this.components.getPcollectionsOrThrow(collectionId);
            return SideInputReference.of(PipelineNode.pTransform(transformId, transformProto), localName, PipelineNode.pCollection(collectionId, collection));
        }).collect(Collectors.toSet());
    }

    public Collection<UserStateReference> getUserStates(PipelineNode.PTransformNode transform) {
        return this.getLocalUserStateNames(transform.getTransform()).stream().map(localName -> {
            String transformId = transform.getId();
            RunnerApi.PTransform transformProto = this.components.getTransformsOrThrow(transformId);
            String collectionId = transform.getTransform().getInputsOrThrow(Iterables.getOnlyElement(Sets.difference(transform.getTransform().getInputsMap().keySet(), ((ImmutableSet.Builder)((ImmutableSet.Builder)ImmutableSet.builder().addAll(this.getLocalSideInputNames(transformProto))).addAll(this.getLocalTimerNames(transformProto))).build())));
            RunnerApi.PCollection collection = this.components.getPcollectionsOrThrow(collectionId);
            return UserStateReference.of(PipelineNode.pTransform(transformId, transformProto), localName, PipelineNode.pCollection(collectionId, collection));
        }).collect(Collectors.toSet());
    }

    public Collection<TimerReference> getTimers(PipelineNode.PTransformNode transform) {
        return this.getLocalTimerNames(transform.getTransform()).stream().map(localName -> {
            String transformId = transform.getId();
            RunnerApi.PTransform transformProto = this.components.getTransformsOrThrow(transformId);
            return TimerReference.of(PipelineNode.pTransform(transformId, transformProto), localName);
        }).collect(Collectors.toSet());
    }

    private Set<String> getLocalSideInputNames(RunnerApi.PTransform transform) {
        if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(transform.getSpec().getUrn())) {
            try {
                return RunnerApi.ParDoPayload.parseFrom(transform.getSpec().getPayload()).getSideInputsMap().keySet();
            }
            catch (InvalidProtocolBufferException e) {
                throw new RuntimeException(e);
            }
        }
        return Collections.emptySet();
    }

    private Set<String> getLocalUserStateNames(RunnerApi.PTransform transform) {
        if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(transform.getSpec().getUrn())) {
            try {
                return RunnerApi.ParDoPayload.parseFrom(transform.getSpec().getPayload()).getStateSpecsMap().keySet();
            }
            catch (InvalidProtocolBufferException e) {
                throw new RuntimeException(e);
            }
        }
        return Collections.emptySet();
    }

    private Set<String> getLocalTimerNames(RunnerApi.PTransform transform) {
        if (PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(transform.getSpec().getUrn())) {
            try {
                return RunnerApi.ParDoPayload.parseFrom(transform.getSpec().getPayload()).getTimerSpecsMap().keySet();
            }
            catch (InvalidProtocolBufferException e) {
                throw new RuntimeException(e);
            }
        }
        return Collections.emptySet();
    }

    public Optional<RunnerApi.Environment> getEnvironment(PipelineNode.PTransformNode parDo) {
        return Environments.getEnvironment(parDo.getId(), this.components);
    }

    private static class SingletonEdge
    implements PipelineEdge {
        private SingletonEdge() {
        }

        @Override
        public boolean isPerElement() {
            return false;
        }
    }

    private static class PerElementEdge
    implements PipelineEdge {
        private PerElementEdge() {
        }

        @Override
        public boolean isPerElement() {
            return true;
        }
    }

    private static interface PipelineEdge {
        public boolean isPerElement();
    }
}

