package org.apache.flink.runtime.executiongraph.failover;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.concurrent.FutureUtils;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.GlobalModVersionMismatch;
import org.apache.flink.runtime.executiongraph.SchedulingUtils;
import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy;
import org.apache.flink.runtime.executiongraph.failover.adapter.DefaultFailoverTopology;
import org.apache.flink.runtime.executiongraph.restart.RestartCallback;
import org.apache.flink.runtime.jobgraph.JobStatus;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
import org.apache.flink.runtime.scheduler.ExecutionVertexVersion;
import org.apache.flink.runtime.scheduler.ExecutionVertexVersioner;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/flink/runtime/executiongraph/failover/AdaptedRestartPipelinedRegionStrategyNG.class */
/**
 * Failover strategy that restarts only the tasks selected by the FLIP-1
 * {@code RestartPipelinedRegionStrategy}, adapted to the {@link FailoverStrategy} interface that
 * operates directly on an {@link ExecutionGraph}.
 *
 * <p>A region failover proceeds as: cancel the affected tasks, hand a callback to the job's restart
 * strategy, and in that callback reset and reschedule those tasks that were not modified in the
 * meantime. Any unexpected error along the way escalates to a global failover.
 */
public class AdaptedRestartPipelinedRegionStrategyNG extends FailoverStrategy {

    private static final Logger LOG = LoggerFactory.getLogger(AdaptedRestartPipelinedRegionStrategyNG.class);

    /** The execution graph on which this strategy performs failover. */
    private final ExecutionGraph executionGraph;

    /** Records vertex versions so that vertices modified after a failover started can be skipped. */
    private final ExecutionVertexVersioner executionVertexVersioner = new ExecutionVertexVersioner();

    /**
     * Computes which tasks must be restarted for a failed task. Assigned exactly once in
     * {@link #notifyNewVertices(List)} and {@code null} before that.
     */
    // Fully qualified, presumably to avoid a clash with a same-named class in this package.
    private org.apache.flink.runtime.executiongraph.failover.flip1.RestartPipelinedRegionStrategy restartPipelinedRegionStrategy;

    /** Factory creating an {@link AdaptedRestartPipelinedRegionStrategyNG} for a given execution graph. */
    public static class Factory implements FailoverStrategy.Factory {
        @Override
        public FailoverStrategy create(ExecutionGraph executionGraph) {
            return new AdaptedRestartPipelinedRegionStrategyNG(executionGraph);
        }
    }

    /**
     * Creates the strategy for the given execution graph.
     *
     * @param executionGraph the graph to perform failover on; must not be {@code null}
     */
    public AdaptedRestartPipelinedRegionStrategyNG(ExecutionGraph executionGraph) {
        this.executionGraph = Preconditions.checkNotNull(executionGraph);
    }

    /**
     * Handles a task failure by restarting the tasks that the region strategy reports as needing
     * a restart.
     *
     * <p>If the restart strategy does not allow restarting, the failure is escalated to a global
     * failover. If the graph is not RUNNING (see {@link #isLocalFailoverValid(long)}), the region
     * failover is skipped.
     *
     * @param taskExecution the failed execution
     * @param cause the failure cause
     * @throws IllegalStateException if {@link #notifyNewVertices(List)} has not been called yet
     */
    @Override
    public void onTaskFailure(Execution taskExecution, Throwable cause) {
        if (!executionGraph.getRestartStrategy().canRestart()) {
            LOG.info("Fail to pass the restart strategy validation in region failover. Fallback to fail global.");
            failGlobal(cause);
            return;
        }

        if (!isLocalFailoverValid(executionGraph.getGlobalModVersion())) {
            LOG.info("Skip current region failover as a global failover is ongoing.");
            return;
        }

        // Fail with a clear message instead of an NPE if the strategy was never initialized.
        Preconditions.checkState(
            restartPipelinedRegionStrategy != null,
            "notifyNewVertices() must be called before handling task failures");

        final ExecutionVertexID failedVertexId = getExecutionVertexID(taskExecution.getVertex());
        final Set<ExecutionVertexID> tasksToRestart =
            restartPipelinedRegionStrategy.getTasksNeedingRestart(failedVertexId, cause);
        restartTasks(tasksToRestart);
    }

    /**
     * Restarts the given tasks. It cancels them and then, on the JobMaster main thread executor,
     * asks the restart strategy to reset and reschedule them. Any error in this chain escalates to a
     * global failover.
     *
     * @param verticesToRestart ids of the execution vertices to restart
     */
    @VisibleForTesting
    protected void restartTasks(Set<ExecutionVertexID> verticesToRestart) {
        final long globalModVersion = executionGraph.getGlobalModVersion();

        // These versions are later passed to getUnmodifiedExecutionVertices, so that only vertices
        // not modified in the meantime get reset and rescheduled.
        final Set<ExecutionVertexVersion> vertexVersions =
            new HashSet<>(executionVertexVersioner.recordVertexModifications(verticesToRestart).values());

        executionGraph.incrementRestarts();

        FutureUtils.assertNoException(
            cancelTasks(verticesToRestart)
                .thenComposeAsync(
                    resetAndRescheduleTasks(globalModVersion, vertexVersions),
                    executionGraph.getJobMasterMainThreadExecutor())
                .handle(failGlobalOnError()));
    }

    /**
     * Returns a function that ignores its input and hands the reset-and-reschedule callback to the
     * job's restart strategy.
     */
    private Function<Object, CompletableFuture<Void>> resetAndRescheduleTasks(
            long globalModVersion, Set<ExecutionVertexVersion> vertexVersions) {
        return ignored -> executionGraph.getRestartStrategy().restart(
            createResetAndRescheduleTasksCallback(globalModVersion, vertexVersions),
            executionGraph.getJobMasterMainThreadExecutor());
    }

    /**
     * Creates the callback that resets and reschedules the still-unmodified vertices.
     *
     * <p>The callback does nothing if the local failover is no longer valid for
     * {@code globalModVersion}. A {@link GlobalModVersionMismatch} is rethrown as an
     * {@link IllegalStateException}; any other exception is wrapped in a
     * {@link CompletionException}.
     */
    private RestartCallback createResetAndRescheduleTasksCallback(
            long globalModVersion, Set<ExecutionVertexVersion> vertexVersions) {
        return () -> {
            if (!isLocalFailoverValid(globalModVersion)) {
                LOG.info("Skip current region failover as a global failover is ongoing.");
                return;
            }

            // Vertices modified since the versions were recorded are left out.
            final Set<ExecutionVertex> unmodifiedVertices = executionVertexVersioner
                .getUnmodifiedExecutionVertices(vertexVersions)
                .stream()
                .map(this::getExecutionVertex)
                .collect(Collectors.toSet());

            try {
                LOG.info("Finally restart {} tasks to recover from task failure.", unmodifiedVertices.size());
                resetTasks(unmodifiedVertices, globalModVersion);
                rescheduleTasks(unmodifiedVertices, globalModVersion);
            } catch (GlobalModVersionMismatch e) {
                throw new IllegalStateException("Bug: ExecutionGraph was concurrently modified outside of main thread", e);
            } catch (Exception e) {
                throw new CompletionException(e);
            }
        };
    }

    /** Returns a completion handler that triggers a global failover if the stage failed. */
    private BiFunction<Object, Throwable, Object> failGlobalOnError() {
        return (ignored, failure) -> {
            if (failure != null) {
                LOG.info("Unexpected error happens in region failover. Fail globally.", failure);
                failGlobal(failure);
            }
            return null;
        };
    }

    /**
     * Cancels the given vertices.
     *
     * @param vertices ids of the vertices to cancel
     * @return a future combining all individual cancellation futures
     */
    @VisibleForTesting
    protected CompletableFuture<?> cancelTasks(Set<ExecutionVertexID> vertices) {
        final List<CompletableFuture<?>> cancelFutures = vertices.stream()
            .map(this::cancelExecutionVertex)
            .collect(Collectors.toList());
        return FutureUtils.combineAll(cancelFutures);
    }

    /**
     * Resets the given vertices for a new execution.
     *
     * <p>It resets the constraints of each involved co-location group once. If a checkpoint
     * coordinator exists, it aborts pending checkpoints and restores the latest checkpointed state
     * of the involved job vertices.
     */
    private void resetTasks(Set<ExecutionVertex> vertices, long globalModVersion) throws Exception {
        final Set<CoLocationGroup> resetCoLocationGroups = new HashSet<>();
        final long restartTimestamp = System.currentTimeMillis();

        for (ExecutionVertex vertex : vertices) {
            final CoLocationGroup coLocationGroup = vertex.getJobVertex().getCoLocationGroup();
            // Set.add returns false for an already-seen group, so each group is reset only once.
            if (coLocationGroup != null && resetCoLocationGroups.add(coLocationGroup)) {
                coLocationGroup.resetConstraints();
            }
            vertex.resetForNewExecution(restartTimestamp, globalModVersion);
        }

        if (executionGraph.getCheckpointCoordinator() != null) {
            executionGraph.getCheckpointCoordinator().abortPendingCheckpoints(
                new CheckpointException(CheckpointFailureReason.JOB_FAILOVER_REGION));
            // NOTE(review): the two flags are presumably (errorIfNoCheckpoint=false,
            // allowNonRestoredState=true); confirm against CheckpointCoordinator.
            executionGraph.getCheckpointCoordinator().restoreLatestCheckpointedState(
                getInvolvedExecutionJobVertices(vertices), false, true);
        }
    }

    /**
     * Schedules the given vertices in topological order.
     *
     * <p>If the local failover is still valid afterwards, a later scheduling failure (other than
     * cancellation) triggers a global failover.
     */
    private void rescheduleTasks(Set<ExecutionVertex> vertices, long globalModVersion) throws Exception {
        final CompletableFuture<Void> schedulingFuture = SchedulingUtils.schedule(
            executionGraph.getScheduleMode(),
            sortVerticesTopologically(vertices),
            executionGraph);

        if (!isLocalFailoverValid(globalModVersion)) {
            return;
        }

        schedulingFuture.whenComplete((ignored, failure) -> {
            if (failure == null) {
                return;
            }
            final Throwable cause = ExceptionUtils.stripCompletionException(failure);
            // A cancelled scheduling future is ignored; any other error escalates globally.
            if (!(cause instanceof CancellationException)) {
                failGlobal(cause);
            }
        });
    }

    /**
     * Checks whether a local failover started at {@code globalModVersion} may still proceed. This
     * requires the graph to be RUNNING and its global mod version to be unchanged.
     */
    private boolean isLocalFailoverValid(long globalModVersion) {
        return executionGraph.getState() == JobStatus.RUNNING
            && executionGraph.getGlobalModVersion() == globalModVersion;
    }

    /** Cancels the execution vertex with the given id and returns its cancellation future. */
    private CompletableFuture<?> cancelExecutionVertex(ExecutionVertexID executionVertexId) {
        return getExecutionVertex(executionVertexId).cancel();
    }

    /** Maps each job vertex id of the given vertices to its {@link ExecutionJobVertex}. */
    private Map<JobVertexID, ExecutionJobVertex> getInvolvedExecutionJobVertices(Set<ExecutionVertex> vertices) {
        final Map<JobVertexID, ExecutionJobVertex> involved = new HashMap<>();
        for (ExecutionVertex vertex : vertices) {
            involved.putIfAbsent(vertex.getJobvertexId(), vertex.getJobVertex());
        }
        return involved;
    }

    /** Triggers a global failover of the execution graph. */
    private void failGlobal(Throwable cause) {
        executionGraph.failGlobal(cause);
    }

    /** Looks up the execution vertex by job vertex id and subtask index. */
    private ExecutionVertex getExecutionVertex(ExecutionVertexID executionVertexId) {
        return executionGraph.getAllVertices()
            .get(executionVertexId.getJobVertexId())
            .getTaskVertices()[executionVertexId.getSubtaskIndex()];
    }

    /** Builds the {@link ExecutionVertexID} (job vertex id plus subtask index) of the given vertex. */
    private ExecutionVertexID getExecutionVertexID(ExecutionVertex executionVertex) {
        return new ExecutionVertexID(executionVertex.getJobvertexId(), executionVertex.getParallelSubtaskIndex());
    }

    /**
     * Orders the given vertices following the graph's topological order of job vertices.
     * Vertices of the same job vertex keep the iteration order of the input set.
     */
    private List<ExecutionVertex> sortVerticesTopologically(Set<ExecutionVertex> vertices) {
        final Map<JobVertexID, List<ExecutionVertex>> verticesByJobVertex = new HashMap<>();
        for (ExecutionVertex vertex : vertices) {
            verticesByJobVertex.computeIfAbsent(vertex.getJobvertexId(), id -> new ArrayList<>()).add(vertex);
        }

        final List<ExecutionVertex> sorted = new ArrayList<>(vertices.size());
        for (ExecutionJobVertex jobVertex : executionGraph.getVerticesTopologically()) {
            sorted.addAll(verticesByJobVertex.getOrDefault(jobVertex.getJobVertexId(), Collections.emptyList()));
        }
        return sorted;
    }

    /**
     * Creates the underlying region strategy from the whole execution graph. The passed list is not
     * used. May be called only once.
     *
     * @throws IllegalStateException if called more than once
     */
    @Override
    public void notifyNewVertices(List<ExecutionJobVertex> newJobVerticesTopological) {
        Preconditions.checkState(restartPipelinedRegionStrategy == null, "notifyNewVertices() must be called only once");
        restartPipelinedRegionStrategy = new org.apache.flink.runtime.executiongraph.failover.flip1.RestartPipelinedRegionStrategy(
            new DefaultFailoverTopology(executionGraph),
            executionGraph.getResultPartitionAvailabilityChecker());
    }

    @Override
    public String getStrategyName() {
        return "New Pipelined Region Failover";
    }
}
