/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.runners.direct.portable;

import java.io.File;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.beam.model.fnexecution.v1.ProvisionApi;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.repackaged.beam_runners_direct_java.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.repackaged.beam_runners_direct_java.com.google.common.base.Preconditions;
import org.apache.beam.repackaged.beam_runners_direct_java.com.google.common.collect.ImmutableList;
import org.apache.beam.repackaged.beam_runners_direct_java.com.google.common.collect.Iterables;
import org.apache.beam.repackaged.beam_runners_direct_java.com.google.common.collect.Maps;
import org.apache.beam.repackaged.beam_runners_direct_java.com.google.common.collect.Sets;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.ModelCoders;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.PTransformTranslation;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.SyntheticComponents;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.GreedyPipelineFuser;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.PipelineNode;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.PipelineValidator;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.ProtoOverrides;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.core.construction.graph.QueryablePipeline;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.GrpcContextHeaderAccessorProvider;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.GrpcFnServer;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.InProcessServerFactory;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.ServerFactory;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.artifact.ArtifactRetrievalService;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.control.ControlClientPool;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.control.FnApiControlClientPoolService;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.control.JobBundleFactory;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.control.MapControlClientPool;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.control.SingleEnvironmentInstanceJobBundleFactory;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.data.GrpcDataService;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.environment.DockerEnvironmentFactory;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.environment.EnvironmentFactory;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.environment.InProcessEnvironmentFactory;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.logging.GrpcLoggingService;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.logging.Slf4jLogWriter;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.provisioning.StaticGrpcProvisionService;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.state.GrpcStateService;
import org.apache.beam.repackaged.beam_runners_direct_java.runners.fnexecution.wire.LengthPrefixUnknownCoders;
import org.apache.beam.repackaged.beam_runners_direct_java.sdk.fn.IdGenerators;
import org.apache.beam.repackaged.beam_runners_direct_java.sdk.fn.stream.OutboundObserverFactory;
import org.apache.beam.runners.direct.ExecutableGraph;
import org.apache.beam.runners.direct.portable.EvaluationContext;
import org.apache.beam.runners.direct.portable.EvaluationContextStepStateAndTimersProvider;
import org.apache.beam.runners.direct.portable.ExecutorServiceParallelExecutor;
import org.apache.beam.runners.direct.portable.ImmutableListBundleFactory;
import org.apache.beam.runners.direct.portable.PortableGraph;
import org.apache.beam.runners.direct.portable.RootProviderRegistry;
import org.apache.beam.runners.direct.portable.TransformEvaluatorRegistry;
import org.apache.beam.runners.direct.portable.artifact.LocalFileSystemArtifactRetrievalService;
import org.apache.beam.runners.direct.portable.artifact.UnsupportedArtifactRetrievalService;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.vendor.protobuf.v3.com.google.protobuf.Struct;
import org.joda.time.Duration;
import org.joda.time.Instant;

public class ReferenceRunner {
    private final RunnerApi.Pipeline pipeline;
    private final Struct options;
    @Nullable
    private final File artifactsDir;
    private final EnvironmentType environmentType;

    /**
     * Creates a runner for the given pipeline.
     *
     * <p>The pipeline is passed through {@code executable(...)} before being stored, so the
     * {@code pipeline} field holds the rewritten pipeline rather than {@code p} itself.
     *
     * @param p the pipeline to run
     * @param options the options placed into the provisioning info built by {@code execute()}
     * @param artifactsDir root directory for the artifact retrieval service, or {@code null}
     * @param environmentType selects which server factory and environment factory are used
     */
    private ReferenceRunner(RunnerApi.Pipeline p, Struct options, @Nullable File artifactsDir, EnvironmentType environmentType) {
        this.options = options;
        this.artifactsDir = artifactsDir;
        this.environmentType = environmentType;
        // executable(...) reads no instance state, so it can run after the plain assignments.
        this.pipeline = executable(p);
    }

    /**
     * Creates a runner that uses the {@code DOCKER} environment type.
     *
     * @param p the pipeline to run
     * @param options the pipeline options
     * @param artifactsDir directory passed on to the runner as its artifacts directory
     * @return a new runner for {@code p}
     */
    public static ReferenceRunner forPipeline(RunnerApi.Pipeline p, Struct options, File artifactsDir) {
        EnvironmentType type = EnvironmentType.DOCKER;
        return new ReferenceRunner(p, options, artifactsDir, type);
    }

    /**
     * Creates a runner that uses the {@code IN_PROCESS} environment type and has no artifacts
     * directory.
     *
     * @param p the pipeline to run
     * @param options the pipeline options
     * @return a new runner for {@code p}
     */
    static ReferenceRunner forInProcessPipeline(RunnerApi.Pipeline p, Struct options) {
        File noArtifactsDir = null;
        return new ReferenceRunner(p, options, noArtifactsDir, EnvironmentType.IN_PROCESS);
    }

    /**
     * Rewrites {@code original} into the pipeline this runner stores and executes.
     *
     * <p>Steps, in order: validate the input; replace transforms with the
     * splittable-process-keyed URN using {@link SplittableProcessKeyedReplacer}; replace
     * transforms with the group-by-key URN using {@link PortableGroupByKeyReplacer}; fuse the
     * result with {@code GreedyPipelineFuser}; fold FeedSDF transforms into their consuming
     * stages; validate the final pipeline.
     *
     * @param original the pipeline as supplied to the runner
     * @return the rewritten pipeline
     */
    private RunnerApi.Pipeline executable(RunnerApi.Pipeline original) {
        PipelineValidator.validate(original);

        RunnerApi.Pipeline splittableExpanded = ProtoOverrides.updateTransform(
                PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
                original,
                new SplittableProcessKeyedReplacer());
        RunnerApi.Pipeline groupByKeyExpanded = ProtoOverrides.updateTransform(
                PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN,
                splittableExpanded,
                new PortableGroupByKeyReplacer());
        RunnerApi.Pipeline fused = GreedyPipelineFuser.fuse(groupByKeyExpanded).toPipeline();
        RunnerApi.Pipeline folded = ReferenceRunner.foldFeedSDFIntoExecutableStage(fused);

        PipelineValidator.validate(folded);
        return folded;
    }

    /**
     * Collects every PCollection produced by a transform whose URN is one of the two
     * direct-runner URNs listed below ({@code gbko} and {@code gabw}).
     *
     * @param graph the graph whose executables are scanned
     * @return a new mutable set of the PCollections produced by those transforms
     */
    private static Set<PipelineNode.PCollectionNode> getKeyedPCollections(ExecutableGraph<PipelineNode.PTransformNode, PipelineNode.PCollectionNode> graph) {
        Set<String> keyedProducerUrns = new HashSet<>();
        keyedProducerUrns.add("urn:beam:directrunner:transforms:gbko:v1");
        keyedProducerUrns.add("urn:beam:directrunner:transforms:gabw:v1");

        Set<PipelineNode.PCollectionNode> keyed = new HashSet<>();
        for (PipelineNode.PTransformNode producer : graph.getExecutables()) {
            String urn = producer.getTransform().getSpec().getUrn();
            if (keyedProducerUrns.contains(urn)) {
                keyed.addAll(graph.getProduced(producer));
            }
        }
        return keyed;
    }

    /**
     * Executes the pipeline held by this runner.
     *
     * <p>Builds the portable graph, evaluation context and root registry, then opens gRPC
     * servers for logging, artifact retrieval, provisioning, control, data and state inside a
     * try-with-resources statement. With those in place it creates the environment factory, the
     * job bundle factory, the transform evaluator registry and the parallel executor, starts the
     * executor and calls {@code waitUntilFinish(Duration.ZERO)} on it. The data executor service
     * is shut down in a {@code finally} block.
     *
     * <p>NOTE: this source was produced by a decompiler, which reported for this method:
     * "Removed try catching itself - possible behaviour change." The exception handling here may
     * therefore differ from the original bytecode.
     *
     * @throws Exception any exception raised while setting up, running, or closing the servers;
     *     nothing is caught here
     */
    public void execute() throws Exception {
        PortableGraph portableGraph = PortableGraph.forPipeline(this.pipeline);
        ImmutableListBundleFactory bundles = ImmutableListBundleFactory.create();
        EvaluationContext evaluationContext = EvaluationContext.create(
                Instant::new, bundles, portableGraph, ReferenceRunner.getKeyedPCollections(portableGraph));
        RootProviderRegistry roots = RootProviderRegistry.impulseRegistry(bundles);
        // Use at least three workers regardless of the reported processor count.
        int parallelism = Math.max(Runtime.getRuntime().availableProcessors(), 3);
        ServerFactory servers = this.createServerFactory();
        MapControlClientPool clientPool = MapControlClientPool.create();
        ExecutorService dataExecutor = Executors.newCachedThreadPool();
        ProvisionApi.ProvisionInfo provisionInfo = ProvisionApi.ProvisionInfo.newBuilder()
                .setJobId("id")
                .setJobName("reference")
                .setPipelineOptions(this.options)
                .setWorkerId("foo")
                .setResourceLimits(ProvisionApi.Resources.getDefaultInstance())
                .build();
        try (GrpcFnServer<GrpcLoggingService> loggingServer =
                     GrpcFnServer.allocatePortAndCreateFor(GrpcLoggingService.forWriter(Slf4jLogWriter.getDefault()), servers);
             // Without an artifacts directory, the "unsupported" retrieval service is served instead.
             GrpcFnServer<ArtifactRetrievalService> artifactServer = this.artifactsDir == null
                     ? GrpcFnServer.allocatePortAndCreateFor(UnsupportedArtifactRetrievalService.create(), servers)
                     : GrpcFnServer.allocatePortAndCreateFor(LocalFileSystemArtifactRetrievalService.forRootDirectory(this.artifactsDir), servers);
             GrpcFnServer<StaticGrpcProvisionService> provisioningServer =
                     GrpcFnServer.allocatePortAndCreateFor(StaticGrpcProvisionService.create(provisionInfo), servers);
             GrpcFnServer<FnApiControlClientPoolService> controlServer =
                     GrpcFnServer.allocatePortAndCreateFor(
                             FnApiControlClientPoolService.offeringClientsToPool(clientPool.getSink(), GrpcContextHeaderAccessorProvider.getHeaderAccessor()),
                             servers);
             GrpcFnServer<GrpcDataService> dataServer =
                     GrpcFnServer.allocatePortAndCreateFor(GrpcDataService.create(dataExecutor, OutboundObserverFactory.serverDirect()), servers);
             GrpcFnServer<GrpcStateService> stateServer =
                     GrpcFnServer.allocatePortAndCreateFor(GrpcStateService.create(), servers)) {
            EnvironmentFactory environments = this.createEnvironmentFactory(
                    controlServer, loggingServer, artifactServer, provisioningServer, clientPool.getSource());
            JobBundleFactory jobBundles = SingleEnvironmentInstanceJobBundleFactory.create(environments, dataServer, stateServer);
            TransformEvaluatorRegistry evaluators = TransformEvaluatorRegistry.portableRegistry(
                    portableGraph,
                    this.pipeline.getComponents(),
                    bundles,
                    jobBundles,
                    EvaluationContextStepStateAndTimersProvider.forContext(evaluationContext));
            ExecutorServiceParallelExecutor parallelExecutor = ExecutorServiceParallelExecutor.create(
                    parallelism, roots, evaluators, portableGraph, evaluationContext);
            parallelExecutor.start();
            parallelExecutor.waitUntilFinish(Duration.ZERO);
        } finally {
            dataExecutor.shutdown();
        }
    }

    /**
     * Picks the server factory matching this runner's environment type.
     *
     * @return {@code ServerFactory.createDefault()} for {@code DOCKER},
     *     {@code InProcessServerFactory.create()} for {@code IN_PROCESS}
     * @throws IllegalArgumentException if the environment type is neither of the above
     */
    private ServerFactory createServerFactory() {
        switch (this.environmentType) {
            case DOCKER:
                return ServerFactory.createDefault();
            case IN_PROCESS:
                return InProcessServerFactory.create();
            default:
                throw new IllegalArgumentException(
                        String.format("Unknown %s %s", EnvironmentType.class.getSimpleName(), this.environmentType));
        }
    }

    /**
     * Builds the environment factory matching this runner's environment type.
     *
     * <p>For {@code DOCKER} all four servers, the client source and an incrementing id generator
     * are handed to {@code DockerEnvironmentFactory}. For {@code IN_PROCESS} only the logging and
     * control servers and the client source are used, together with freshly created default
     * pipeline options; {@code artifact} and {@code provisioning} are not used on that path.
     *
     * @param control the control service server
     * @param logging the logging service server
     * @param artifact the artifact retrieval service server
     * @param provisioning the provisioning service server
     * @param controlClientSource source of control clients
     * @return the environment factory for the configured environment type
     * @throws IllegalArgumentException if the environment type is neither {@code DOCKER} nor
     *     {@code IN_PROCESS}
     */
    private EnvironmentFactory createEnvironmentFactory(GrpcFnServer<FnApiControlClientPoolService> control, GrpcFnServer<GrpcLoggingService> logging, GrpcFnServer<ArtifactRetrievalService> artifact, GrpcFnServer<StaticGrpcProvisionService> provisioning, ControlClientPool.Source controlClientSource) {
        switch (this.environmentType) {
            case DOCKER:
                return DockerEnvironmentFactory.forServices(
                        control, logging, artifact, provisioning, controlClientSource, IdGenerators.incrementingLongs());
            case IN_PROCESS:
                return InProcessEnvironmentFactory.create(
                        PipelineOptionsFactory.create(), logging, control, controlClientSource);
            default:
                throw new IllegalArgumentException(
                        String.format("Unknown %s %s", EnvironmentType.class.getSimpleName(), this.environmentType));
        }
    }

    /**
     * Folds every FeedSDF transform into the executable stage that consumes its output.
     *
     * <p>For each transform with the FeedSDF URN, the single per-element consumer of its single
     * output must be an executable stage. That stage is rewritten so that its URN becomes the
     * splittable-remote-stage URN and its only input points at the FeedSDF transform's own input
     * PCollection. FeedSDF transforms are dropped from the root transform ids; their entries in
     * the components are left as they were in {@code p}.
     *
     * @param p the pipeline to rewrite
     * @return a pipeline with FeedSDF transforms folded into their consuming stages
     * @throws IllegalStateException if a FeedSDF transform is consumed by something other than an
     *     executable stage
     */
    private static RunnerApi.Pipeline foldFeedSDFIntoExecutableStage(RunnerApi.Pipeline p) {
        RunnerApi.Pipeline.Builder newPipeline = p.toBuilder();
        RunnerApi.Components.Builder newPipelineComponents = newPipeline.getComponentsBuilder();
        QueryablePipeline q = QueryablePipeline.forPipeline(p);

        String feedSdfUrn = "beam:directrunner:transforms:feed_sdf:v1";
        // Properly parameterized: the raw List/Set left by the decompiler do not type-check in
        // the enhanced-for loop and the stream pipeline below.
        List<PipelineNode.PTransformNode> feedSDFNodes = q.getTransforms().stream()
                .filter(node -> node.getTransform().getSpec().getUrn().equals(feedSdfUrn))
                .collect(Collectors.toList());

        // Maps the id of each consuming stage to the FeedSDF node that feeds it.
        HashMap<String, PipelineNode.PTransformNode> stageToFeeder = Maps.newHashMap();
        for (PipelineNode.PTransformNode feedSDFNode : feedSDFNodes) {
            PipelineNode.PCollectionNode output = Iterables.getOnlyElement(q.getOutputPCollections(feedSDFNode));
            PipelineNode.PTransformNode consumer = Iterables.getOnlyElement(q.getPerElementConsumers(output));
            String consumerUrn = consumer.getTransform().getSpec().getUrn();
            Preconditions.checkState(
                    consumerUrn.equals("beam:runner:executable_stage:v1"),
                    "Expected all FeedSDF nodes to be consumed by an ExecutableStage, but %s is consumed by %s which is %s",
                    feedSDFNode.getId(),
                    consumer.getId(),
                    consumerUrn);
            stageToFeeder.put(consumer.getId(), feedSDFNode);
        }

        Set<String> feedSDFIds = feedSDFNodes.stream()
                .map(PipelineNode.PTransformNode::getId)
                .collect(Collectors.toSet());

        // Rebuild the root list without the FeedSDF transforms.
        newPipeline.clearRootTransformIds();
        for (String rootId : p.getRootTransformIdsList()) {
            if (!feedSDFIds.contains(rootId)) {
                newPipeline.addRootTransformIds(rootId);
            }
        }

        for (PipelineNode.PTransformNode node : q.getTransforms()) {
            // FeedSDF transforms are skipped; the id set gives a constant-time check instead of
            // scanning the node list for every transform.
            if (feedSDFIds.contains(node.getId())) {
                continue;
            }
            PipelineNode.PTransformNode feeder = stageToFeeder.get(node.getId());
            if (feeder == null) {
                // Not fed by a FeedSDF: carried over unchanged.
                newPipelineComponents.putTransforms(node.getId(), node.getTransform());
                continue;
            }
            // Rewire the stage to read what the FeedSDF read, and mark it as a splittable stage.
            PipelineNode.PCollectionNode rawGBKOutput = Iterables.getOnlyElement(q.getPerElementInputPCollections(feeder));
            String stageInputName = Iterables.getOnlyElement(node.getTransform().getInputsMap().keySet());
            RunnerApi.PTransform rewrittenStage = node.getTransform().toBuilder()
                    .mergeSpec(RunnerApi.FunctionSpec.newBuilder()
                            .setUrn("beam:directrunner:transforms:splittable_remote_stage:v1")
                            .build())
                    .putInputs(stageInputName, rawGBKOutput.getId())
                    .build();
            newPipelineComponents.putTransforms(node.getId(), rewrittenStage);
        }
        return newPipeline.build();
    }

    private static enum EnvironmentType {
        DOCKER,
        IN_PROCESS;

    }

    /**
     * Replaces a transform carrying the splittable-process-keyed URN with a composite of three
     * subtransforms, added in this order: RawGBK, FeedSDF and RunSDF.
     */
    @VisibleForTesting
    static class SplittableProcessKeyedReplacer
    implements ProtoOverrides.TransformReplacement {
        SplittableProcessKeyedReplacer() {
        }

        /**
         * Builds the replacement for the transform identified by {@code spkId}.
         *
         * <p>RawGBK reads the original inputs and writes a new keyed-work-item PCollection;
         * FeedSDF reads that and writes a new "feed" PCollection whose coder is the
         * length-prefixed form of the input's value coder; RunSDF reads the feed PCollection,
         * carries the original payload and produces the original outputs.
         *
         * @param spkId id of the transform to replace
         * @param components the components in which {@code spkId} is looked up and against which
         *     new ids are made unique
         * @return the updated transform together with the newly created components
         * @throws IllegalArgumentException if the transform's URN is not the
         *     splittable-process-keyed URN
         */
        @Override
        public RunnerApi.MessageWithComponents getReplacement(String spkId, RunnerApi.ComponentsOrBuilder components) {
            RunnerApi.PTransform original = components.getTransformsOrThrow(spkId);
            Preconditions.checkArgument(
                    PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN.equals(original.getSpec().getUrn()),
                    "URN must be %s, got %s",
                    PTransformTranslation.SPLITTABLE_PROCESS_KEYED_URN,
                    original.getSpec().getUrn());

            RunnerApi.Components.Builder addedComponents = RunnerApi.Components.newBuilder();
            // Existing coders are copied in before the length-prefixed coder is added below.
            addedComponents.putAllCoders(components.getCodersMap());
            RunnerApi.PTransform.Builder composite = original.toBuilder();

            String inputId = Iterables.getOnlyElement(original.getInputsMap().values());
            RunnerApi.PCollection input = components.getPcollectionsOrThrow(inputId);
            RunnerApi.Coder inputCoder = components.getCodersOrThrow(input.getCoderId());
            ModelCoders.KvCoderComponents kv = ModelCoders.getKvCoderComponents(inputCoder);
            String windowCoderId = components
                    .getWindowingStrategiesOrThrow(input.getWindowingStrategyId())
                    .getWindowCoderId();

            // --- RawGBK: original inputs -> keyed work items ---
            String kwiCollectionId = SyntheticComponents.uniqueId(
                    String.format("%s.kwi", spkId), components::containsPcollections);
            RunnerApi.Coder kwiCoder = RunnerApi.Coder.newBuilder()
                    .setSpec(RunnerApi.SdkFunctionSpec.newBuilder()
                            .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("beam:direct:keyedworkitem:v1")))
                    .addAllComponentCoderIds(ImmutableList.of(kv.keyCoderId(), kv.valueCoderId(), windowCoderId))
                    .build();
            String kwiCoderId = SyntheticComponents.uniqueId(
                    String.format("keyed_work_item(%s:%s)", kv.keyCoderId(), kv.valueCoderId()),
                    components::containsCoders);
            RunnerApi.PCollection kwiCollection = input.toBuilder()
                    .setUniqueName(kwiCollectionId)
                    .setCoderId(kwiCoderId)
                    .build();
            String rawGbkId = SyntheticComponents.uniqueId(
                    String.format("%s/RawGBK", spkId), components::containsTransforms);
            RunnerApi.PTransform rawGbk = RunnerApi.PTransform.newBuilder()
                    .setUniqueName(String.format("%s/RawGBK", original.getUniqueName()))
                    .putAllInputs(original.getInputsMap())
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:beam:directrunner:transforms:gbko:v1"))
                    .putOutputs("output", kwiCollectionId)
                    .build();
            addedComponents
                    .putCoders(kwiCoderId, kwiCoder)
                    .putPcollections(kwiCollectionId, kwiCollection)
                    .putTransforms(rawGbkId, rawGbk);
            composite.addSubtransforms(rawGbkId);

            // --- FeedSDF: keyed work items -> feed collection ---
            String feedCollectionId = SyntheticComponents.uniqueId(
                    String.format("%s.feed", spkId), components::containsPcollections);
            String elementRestrictionCoderId = kv.valueCoderId();
            String feedCoderId = LengthPrefixUnknownCoders.addLengthPrefixedCoder(
                    elementRestrictionCoderId, addedComponents, false);
            RunnerApi.PCollection feedCollection = input.toBuilder()
                    .setUniqueName(feedCollectionId)
                    .setCoderId(feedCoderId)
                    .build();
            String feedSdfId = SyntheticComponents.uniqueId(
                    String.format("%s/FeedSDF", spkId), components::containsTransforms);
            RunnerApi.PTransform feedSdf = RunnerApi.PTransform.newBuilder()
                    .setUniqueName(String.format("%s/FeedSDF", original.getUniqueName()))
                    .putInputs("input", kwiCollectionId)
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("beam:directrunner:transforms:feed_sdf:v1"))
                    .putOutputs("output", feedCollectionId)
                    .build();
            addedComponents
                    .putPcollections(feedCollectionId, feedCollection)
                    .putTransforms(feedSdfId, feedSdf);
            composite.addSubtransforms(feedSdfId);

            // --- RunSDF: feed collection -> original outputs, carrying the original payload ---
            String runSdfId = SyntheticComponents.uniqueId(
                    String.format("%s/RunSDF", spkId), components::containsTransforms);
            RunnerApi.PTransform runSdf = RunnerApi.PTransform.newBuilder()
                    .setUniqueName(String.format("%s/RunSDF", original.getUniqueName()))
                    .putInputs("input", feedCollectionId)
                    .setSpec(RunnerApi.FunctionSpec.newBuilder()
                            .setUrn(PTransformTranslation.SPLITTABLE_PROCESS_ELEMENTS_URN)
                            .setPayload(original.getSpec().getPayload()))
                    .putAllOutputs(original.getOutputsMap())
                    .build();
            addedComponents.putTransforms(runSdfId, runSdf);
            composite.addSubtransforms(runSdfId);

            return RunnerApi.MessageWithComponents.newBuilder()
                    .setPtransform(composite.build())
                    .setComponents(addedComponents)
                    .build();
        }
    }

    /**
     * Replaces a transform carrying the group-by-key URN with a composite of two subtransforms,
     * added in this order: GBKO and GABW.
     */
    @VisibleForTesting
    static class PortableGroupByKeyReplacer
    implements ProtoOverrides.TransformReplacement {
        PortableGroupByKeyReplacer() {
        }

        /**
         * Builds the replacement for the transform identified by {@code gbkId}.
         *
         * <p>GBKO reads the original inputs and writes a new keyed-work-item PCollection; GABW
         * reads that PCollection and produces the original outputs.
         *
         * @param gbkId id of the transform to replace
         * @param components the components in which {@code gbkId} is looked up and against which
         *     new ids are made unique
         * @return the updated transform together with the newly created components
         * @throws IllegalArgumentException if the transform's URN is not the group-by-key URN
         */
        @Override
        public RunnerApi.MessageWithComponents getReplacement(String gbkId, RunnerApi.ComponentsOrBuilder components) {
            RunnerApi.PTransform original = components.getTransformsOrThrow(gbkId);
            Preconditions.checkArgument(
                    PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN.equals(original.getSpec().getUrn()),
                    "URN must be %s, got %s",
                    PTransformTranslation.GROUP_BY_KEY_TRANSFORM_URN,
                    original.getSpec().getUrn());

            RunnerApi.PTransform.Builder composite = original.toBuilder();
            RunnerApi.Components.Builder addedComponents = RunnerApi.Components.newBuilder();

            String inputId = Iterables.getOnlyElement(original.getInputsMap().values());
            String kwiCollectionId = SyntheticComponents.uniqueId(
                    String.format("%s.%s", inputId, "kwi"), components::containsPcollections);

            RunnerApi.PCollection input = components.getPcollectionsOrThrow(inputId);
            RunnerApi.Coder inputCoder = components.getCodersOrThrow(input.getCoderId());
            ModelCoders.KvCoderComponents kv = ModelCoders.getKvCoderComponents(inputCoder);
            String windowCoderId = components
                    .getWindowingStrategiesOrThrow(input.getWindowingStrategyId())
                    .getWindowCoderId();

            // Keyed-work-item coder built from the input's key, value and window coders.
            RunnerApi.Coder kwiCoder = RunnerApi.Coder.newBuilder()
                    .setSpec(RunnerApi.SdkFunctionSpec.newBuilder()
                            .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("beam:direct:keyedworkitem:v1")))
                    .addAllComponentCoderIds(ImmutableList.of(kv.keyCoderId(), kv.valueCoderId(), windowCoderId))
                    .build();
            String kwiCoderId = SyntheticComponents.uniqueId(
                    String.format("kwi(%s:%s)", kv.keyCoderId(), kv.valueCoderId()),
                    components::containsCoders);
            RunnerApi.PCollection kwiCollection = input.toBuilder()
                    .setUniqueName(kwiCollectionId)
                    .setCoderId(kwiCoderId)
                    .build();

            // --- GBKO: original inputs -> keyed work items ---
            String gbkoId = SyntheticComponents.uniqueId(
                    String.format("%s/GBKO", gbkId), components::containsTransforms);
            RunnerApi.PTransform gbko = RunnerApi.PTransform.newBuilder()
                    .setUniqueName(String.format("%s/GBKO", original.getUniqueName()))
                    .putAllInputs(original.getInputsMap())
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:beam:directrunner:transforms:gbko:v1"))
                    .putOutputs("output", kwiCollectionId)
                    .build();
            composite.addSubtransforms(gbkoId);
            addedComponents
                    .putCoders(kwiCoderId, kwiCoder)
                    .putPcollections(kwiCollectionId, kwiCollection)
                    .putTransforms(gbkoId, gbko);

            // --- GABW: keyed work items -> original outputs ---
            String gabwId = SyntheticComponents.uniqueId(
                    String.format("%s/GABW", gbkId), components::containsTransforms);
            RunnerApi.PTransform gabw = RunnerApi.PTransform.newBuilder()
                    .setUniqueName(String.format("%s/GABW", original.getUniqueName()))
                    .putInputs("input", kwiCollectionId)
                    .setSpec(RunnerApi.FunctionSpec.newBuilder().setUrn("urn:beam:directrunner:transforms:gabw:v1"))
                    .putAllOutputs(original.getOutputsMap())
                    .build();
            composite.addSubtransforms(gabwId);
            addedComponents.putTransforms(gabwId, gabw);

            return RunnerApi.MessageWithComponents.newBuilder()
                    .setPtransform(composite)
                    .setComponents(addedComponents)
                    .build();
        }
    }
}

