package org.nd4j.linalg.api.parallel.tasks.cpu.accumulation;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.BaseTask;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUTask;
import org.nd4j.linalg.api.shape.tensor.TensorCalculator;
import org.nd4j.linalg.api.shape.tensor.TensorCalculatorFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Task that applies an {@link Accumulation} along one or more dimensions of {@code op.x()},
 * producing one scalar per tensor along those dimensions.
 *
 * <p>Two execution modes are supported:
 * <ul>
 *   <li>Executor mode: {@link #call()} starts one {@link OpForDimTask} per tensor and returns
 *       {@code null}; {@link #blockUntilComplete()} then gathers the per-tensor scalars into the
 *       result array.</li>
 *   <li>Fork/join mode: {@link #compute()} builds and returns the complete result array itself.</li>
 * </ul>
 * The assembled result array is also passed to {@code op.setZ(...)}.
 *
 * <p>When {@code op.x()} is a matrix reduced along dimension 1 only, {@link #blockUntilComplete()}
 * returns the result as a column of shape {@code [rows, 1]}.
 */
public class CPUAccumulationAlongDimensionTask extends BaseCPUTask<INDArray> {
    /** The accumulation being applied. */
    protected final Accumulation op;
    /** Dimensions to accumulate along; negative entries are normalised in the constructor. */
    protected final int[] dimensions;
    /** Per-tensor tasks started by {@link #call()}; read by {@link #blockUntilComplete()}. */
    protected List<Task<Double>> subTasks;

    /**
     * Executor-mode task for a single tensor along {@code dimensions}.
     *
     * <p>{@link #call()} only creates and starts the per-tensor sub-task. The tensor's value is
     * obtained from {@link #blockUntilComplete()}.
     */
    public class OpForDimTask extends BaseTask<Double> {
        private int tensorNum;
        private BaseCPUTask<Double> subTask;
        private Future<Double> future;

        /** @param tensorNum index of the tensor (along {@code dimensions}) this task handles */
        public OpForDimTask(int tensorNum) {
            this.tensorNum = tensorNum;
        }

        @Override
        public void invokeAsync() {
            this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
        }

        /**
         * Waits for {@link #call()} to start the sub-task, then waits for and returns the
         * sub-task's value. Starts this task first if {@link #invokeAsync()} was never called.
         *
         * @return the accumulated value for this tensor
         * @throws RuntimeException if the wait is interrupted or the submission failed
         */
        @Override
        public Double blockUntilComplete() {
            if (this.future == null) {
                invokeAsync();
            }
            // call() returns null; waiting on it only guarantees subTask has been created and started.
            CPUAccumulationAlongDimensionTask.awaitResult(this.future);
            return this.subTask.blockUntilComplete();
        }

        /** Creates and asynchronously starts the sub-task for this tensor; always returns {@code null}. */
        @Override
        public Double call() {
            this.subTask = CPUAccumulationAlongDimensionTask.this.newSubTaskForTensor(this.tensorNum);
            this.subTask.invokeAsync();
            return null;
        }
    }

    /**
     * Fork/join variant of {@link OpForDimTask}.
     *
     * <p>{@link #compute()} runs the sub-task for a single tensor synchronously and returns its
     * value. {@link #call()} is not supported.
     */
    public class OpForDimTaskFJ extends RecursiveTask<Double> implements Task<Double> {
        private int tensorNum;
        private BaseCPUTask<Double> subTask;
        private Future<Double> future;

        /** @param tensorNum index of the tensor (along {@code dimensions}) this task handles */
        public OpForDimTaskFJ(int tensorNum) {
            this.tensorNum = tensorNum;
        }

        @Override
        public Double invokeBlocking() {
            invokeAsync();
            return blockUntilComplete();
        }

        @Override
        public void invokeAsync() {
            this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
        }

        /**
         * Waits for the submission made by {@link #invokeAsync()} and returns its value.
         * Starts the task first if it has not been submitted yet.
         *
         * <p>If the executor runs this task through {@link #call()}, that call throws and the
         * failure surfaces here as a {@link RuntimeException}.
         *
         * @return the accumulated value for this tensor
         */
        @Override
        public Double blockUntilComplete() {
            if (this.future == null) {
                invokeAsync();
            }
            return CPUAccumulationAlongDimensionTask.awaitResult(this.future);
        }

        @Override
        public Double call() {
            throw new RuntimeException("Callable.call() called as part of ForkJoin task");
        }

        /** Builds the sub-task for this tensor and runs it synchronously via {@code invoke()}. */
        @Override
        public Double compute() {
            this.subTask = CPUAccumulationAlongDimensionTask.this.newSubTaskForTensor(this.tensorNum);
            return (Double) this.subTask.invoke();
        }
    }

    /**
     * @param accumulation      the accumulation to apply
     * @param parallelThreshold threshold forwarded to {@link BaseCPUTask} and to every per-tensor
     *                          sub-task
     * @param dimensions        dimensions to accumulate along; negative values count back from
     *                          {@code accumulation.x().rank()}. The caller's array is not modified.
     */
    public CPUAccumulationAlongDimensionTask(Accumulation accumulation, int parallelThreshold, int... dimensions) {
        super(accumulation, parallelThreshold);
        // Work on a copy so normalising negative dimensions never mutates the caller's array.
        int[] normalized = dimensions.clone();
        for (int d = 0; d < normalized.length; d++) {
            if (normalized[d] < 0) {
                normalized[d] += accumulation.x().rank();
            }
        }
        this.op = accumulation;
        this.dimensions = normalized;
    }

    /**
     * Waits for this task and returns the accumulation result.
     *
     * <p>If the underlying future yields a non-null array (fork/join mode), that array is returned.
     * It is reshaped to {@code [length, 1]} for a matrix reduced along dimension 1. Otherwise
     * (executor mode) the per-tensor results of {@link #subTasks} are gathered into a new array,
     * which is passed to {@code op.setZ(...)} and returned.
     *
     * @return the accumulation result
     * @throws RuntimeException if waiting is interrupted or the computation failed
     */
    @Override
    public INDArray blockUntilComplete() {
        if (this.future == null) {
            // invokeAsync() was never called: start the computation now.
            invokeAsync();
        }
        INDArray forkJoinResult = (INDArray) awaitResult(this.future);
        if (forkJoinResult != null) {
            if (reducesMatrixAlongDimensionOne()) {
                forkJoinResult = forkJoinResult.reshape(forkJoinResult.length(), 1);
            }
            return forkJoinResult;
        }

        // Executor mode: call() only started the sub-tasks; collect one scalar per tensor.
        int[] resultShape;
        if (reducesMatrixAlongDimensionOne()) {
            // One value per row -> column of shape [rows, 1] (not [rows * cols, 1]).
            resultShape = new int[]{this.op.x().shape()[0], 1};
        } else {
            resultShape = ArrayUtil.removeIndex(this.op.x().shape(), this.dimensions);
        }
        INDArray result = Nd4j.create(resultShape);
        int index = 0;
        for (Task<Double> tensorTask : this.subTasks) {
            result.putScalar(index++, tensorTask.blockUntilComplete().doubleValue());
        }
        this.op.setZ(result);
        return result;
    }

    /**
     * Executor-mode entry point: starts one {@link OpForDimTask} per tensor along
     * {@link #dimensions} and records them in {@link #subTasks}.
     *
     * @return always {@code null}; the result is assembled by {@link #blockUntilComplete()}
     */
    @Override
    public INDArray call() {
        int numTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        this.subTasks = new ArrayList<Task<Double>>(numTensors);
        for (int t = 0; t < numTensors; t++) {
            OpForDimTask tensorTask = new OpForDimTask(t);
            tensorTask.invokeAsync();
            this.subTasks.add(tensorTask);
        }
        return null;
    }

    /**
     * Fork/join entry point: computes the complete result, passes it to {@code op.setZ(...)} and
     * returns it.
     *
     * <p>With a single dimension and a non-pass-through op, one {@code CPUAccumulations1dAction}
     * over tensors {@code 0 .. numTensors-1} is given the result array to populate. Otherwise one
     * {@link OpForDimTaskFJ} is forked per tensor and the joined values are stored in order.
     */
    @Override
    public INDArray compute() {
        if (this.dimensions.length == 1 && !this.op.isPassThrough()) {
            int dimension = this.dimensions[0];
            TensorCalculator xCalculator = TensorCalculatorFactory.getTensorCalculator(this.op.x(), dimension);
            TensorCalculator yCalculator = this.op.y() != null
                    ? TensorCalculatorFactory.getTensorCalculator(this.op.y(), dimension)
                    : null;
            INDArray result = Nd4j.create(ArrayUtil.removeIndex(this.op.x().shape(), this.dimensions));
            new CPUAccumulations1dAction(this.op, this.threshold, xCalculator, yCalculator, 0,
                    xCalculator.getNumTensors() - 1, result).invoke();
            this.op.setZ(result);
            return result;
        }

        int numTensors = this.op.x().tensorssAlongDimension(this.dimensions);
        List<OpForDimTaskFJ> forked = new ArrayList<OpForDimTaskFJ>(numTensors);
        for (int t = 0; t < numTensors; t++) {
            OpForDimTaskFJ tensorTask = new OpForDimTaskFJ(t);
            tensorTask.fork();
            forked.add(tensorTask);
        }
        INDArray result = Nd4j.create(ArrayUtil.removeIndex(this.op.x().shape(), this.dimensions));
        int index = 0;
        for (OpForDimTaskFJ tensorTask : forked) {
            result.putScalar(index++, tensorTask.join().doubleValue());
        }
        this.op.setZ(result);
        return result;
    }

    /**
     * Creates the sub-task that accumulates the tensor with index {@code tensorNum} along
     * {@link #dimensions}.
     *
     * <p>Uses {@code CPUAccumulationTask} when {@code OpExecutionerUtil.canDoOpDirectly} accepts
     * the tensor's x (and y, if present); otherwise uses {@code CPUAccumulationViaTensorTask}.
     */
    private BaseCPUTask<Double> newSubTaskForTensor(int tensorNum) {
        Accumulation tensorOp = (Accumulation) this.op.opForDimension(tensorNum, this.dimensions);
        INDArray x = tensorOp.x();
        INDArray y = tensorOp.y();
        boolean direct = (y == null)
                ? OpExecutionerUtil.canDoOpDirectly(x)
                : OpExecutionerUtil.canDoOpDirectly(x, y);
        if (direct) {
            return new CPUAccumulationTask(tensorOp, this.threshold, true);
        }
        return new CPUAccumulationViaTensorTask(tensorOp, this.threshold, true);
    }

    /** True when {@code op.x()} is a matrix and the only accumulation dimension is 1. */
    private boolean reducesMatrixAlongDimensionOne() {
        return this.dimensions.length == 1 && this.dimensions[0] == 1 && this.op.x().isMatrix();
    }

    /**
     * Waits for {@code future} and returns its value.
     *
     * <p>An {@link InterruptedException} restores the thread's interrupt flag before being
     * rethrown wrapped. Both it and an {@link ExecutionException} surface as a
     * {@link RuntimeException} with the original as cause.
     */
    private static <T> T awaitResult(Future<T> future) {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }
}
