package com.amazon.randomcutforest.executor;

import com.amazon.randomcutforest.ComponentList;
import com.amazon.randomcutforest.IMultiVisitorFactory;
import com.amazon.randomcutforest.IVisitorFactory;
import com.amazon.randomcutforest.returntypes.ConvergingAccumulator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collector;
import java.util.stream.Collectors;

/* loaded from: input_file:com/amazon/randomcutforest/executor/ParallelForestTraversalExecutor.class */
/**
 * Forest traversal executor that evaluates the forest's components in parallel on a dedicated
 * {@link ForkJoinPool} created with the thread count passed to the constructor.
 *
 * <p>Every traversal is submitted to that pool through {@link #submitAndJoin(Callable)}; the
 * parallel streams are built inside the submitted task, presumably so that they run on this
 * pool's workers rather than on the common pool (a JDK implementation behavior — confirm if
 * relied upon).
 */
public class ParallelForestTraversalExecutor extends AbstractForestTraversalExecutor {

    /** Pool on which all traversals run; re-created by {@link #submitAndJoin} if it is null. */
    ForkJoinPool forkJoinPool;

    /** Parallelism of the pool; also the batch size used by the converging traversal. */
    private final int threadPoolSize;

    /**
     * Creates an executor backed by a new {@link ForkJoinPool} with {@code i} threads.
     *
     * @param componentList the components (trees) to traverse
     * @param i the pool parallelism; {@link ForkJoinPool#ForkJoinPool(int)} throws
     *     {@link IllegalArgumentException} if it is not positive or exceeds its implementation limit
     */
    public ParallelForestTraversalExecutor(ComponentList<?, ?> componentList, int i) {
        super(componentList);
        this.threadPoolSize = i;
        this.forkJoinPool = new ForkJoinPool(i);
    }

    /**
     * Traverses every component in parallel, reduces the per-component results with
     * {@code binaryOperator}, and applies {@code function} to the reduced value.
     *
     * @param fArr the query point passed to each component
     * @param iVisitorFactory factory supplying the visitor used by each component traversal
     * @param binaryOperator reduction applied to the per-component results
     * @param function finisher applied to the reduced value
     * @return the finished result
     * @throws IllegalStateException if the reduction is empty (no components) or
     *     {@code function} returns null
     */
    @Override
    public <R, S> S traverseForest(float[] fArr, IVisitorFactory<R> iVisitorFactory,
            BinaryOperator<R> binaryOperator, Function<R, S> function) {
        Optional<S> result = submitAndJoin(() -> this.components.parallelStream()
                .map(component -> component.traverse(fArr, iVisitorFactory))
                .reduce(binaryOperator)
                .map(function));
        return result.orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Traverses every component in parallel and combines the per-component results with
     * {@code collector}.
     *
     * @param fArr the query point passed to each component
     * @param iVisitorFactory factory supplying the visitor used by each component traversal
     * @param collector collector that combines the per-component results
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForest(float[] fArr, IVisitorFactory<R> iVisitorFactory,
            Collector<R, ?, S> collector) {
        return submitAndJoin(() -> this.components.parallelStream()
                .map(component -> component.traverse(fArr, iVisitorFactory))
                .collect(collector));
    }

    /**
     * Traverses the components in consecutive batches of at most {@code threadPoolSize}.
     *
     * <p>Each batch is traversed in parallel. Its results are then fed, in component order, to
     * {@code convergingAccumulator}. Traversal stops after the first batch at which the accumulator
     * reports convergence, or when all components have been visited.
     *
     * @param fArr the query point passed to each component
     * @param iVisitorFactory factory supplying the visitor used by each component traversal
     * @param convergingAccumulator accumulator receiving per-component results; must not be null
     * @param function finisher applied to the accumulator's accumulated value
     * @return the finished result
     * @throws NullPointerException if {@code convergingAccumulator} is null
     */
    @Override
    public <R, S> S traverseForest(float[] fArr, IVisitorFactory<R> iVisitorFactory,
            ConvergingAccumulator<R> convergingAccumulator, Function<R, S> function) {
        // Fail before doing any traversal work rather than after the first batch.
        Objects.requireNonNull(convergingAccumulator, "convergingAccumulator must not be null");
        int start = 0;
        while (start < this.components.size()) {
            final int from = start;
            // from + min(step, remaining) never exceeds size(), so this cannot overflow.
            final int to = from + Math.min(this.threadPoolSize, this.components.size() - from);
            List<R> batch = submitAndJoin(() -> this.components.subList(from, to).parallelStream()
                    .map(component -> component.traverse(fArr, iVisitorFactory))
                    .collect(Collectors.toList()));
            batch.forEach(convergingAccumulator::accept);
            if (convergingAccumulator.isConverged()) {
                break;
            }
            start = to;
        }
        return function.apply(convergingAccumulator.getAccumulatedValue());
    }

    /**
     * Multi-visitor variant of the reducing traversal: traverses every component in parallel with
     * {@code traverseMulti}, reduces the results with {@code binaryOperator}, and applies
     * {@code function} to the reduced value.
     *
     * @param fArr the query point passed to each component
     * @param iMultiVisitorFactory factory supplying the multi-visitor used by each component
     * @param binaryOperator reduction applied to the per-component results
     * @param function finisher applied to the reduced value
     * @return the finished result
     * @throws IllegalStateException if the reduction is empty (no components) or
     *     {@code function} returns null
     */
    @Override
    public <R, S> S traverseForestMulti(float[] fArr, IMultiVisitorFactory<R> iMultiVisitorFactory,
            BinaryOperator<R> binaryOperator, Function<R, S> function) {
        Optional<S> result = submitAndJoin(() -> this.components.parallelStream()
                .map(component -> component.traverseMulti(fArr, iMultiVisitorFactory))
                .reduce(binaryOperator)
                .map(function));
        return result.orElseThrow(() -> new IllegalStateException("accumulator returned an empty result"));
    }

    /**
     * Multi-visitor variant of the collecting traversal: traverses every component in parallel
     * with {@code traverseMulti} and combines the results with {@code collector}.
     *
     * @param fArr the query point passed to each component
     * @param iMultiVisitorFactory factory supplying the multi-visitor used by each component
     * @param collector collector that combines the per-component results
     * @return the collector's result
     */
    @Override
    public <R, S> S traverseForestMulti(float[] fArr, IMultiVisitorFactory<R> iMultiVisitorFactory,
            Collector<R, ?, S> collector) {
        return submitAndJoin(() -> this.components.parallelStream()
                .map(component -> component.traverseMulti(fArr, iMultiVisitorFactory))
                .collect(collector));
    }

    /**
     * Runs {@code callable} on this executor's pool and blocks until it completes.
     *
     * <p>If {@link #forkJoinPool} is null, a new pool with {@code threadPoolSize} threads is
     * created first.
     * NOTE(review): that null-check-and-assign is not synchronized. Concurrent callers that
     * observe null could each create a pool.
     *
     * @param callable the task to run
     * @return the task's result
     * @throws RuntimeException (or Error) if the task completes abnormally; per
     *     {@link java.util.concurrent.ForkJoinTask#join()}, checked exceptions are not
     *     rethrown as {@code ExecutionException}
     */
    <T> T submitAndJoin(Callable<T> callable) {
        if (this.forkJoinPool == null) {
            this.forkJoinPool = new ForkJoinPool(this.threadPoolSize);
        }
        return this.forkJoinPool.submit(callable).join();
    }
}
