package org.nd4j.linalg.api.parallel.tasks.cpu.accumulation;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.RecursiveTask;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUTask;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/accumulation/CPUAccumulationViaTensorTask.class */
/**
 * Accumulation task that splits the op's input into tensors along one dimension (chosen by
 * {@link OpExecutionerUtil#chooseElementWiseTensorDimension}), runs one
 * {@link CPUAccumulationTask} per tensor, and combines the partial results with
 * {@link Accumulation#combineSubResults(double, double)} starting from
 * {@link Accumulation#zeroDouble()}.
 *
 * <p>Two execution modes are supported:
 * <ul>
 *   <li>fork/join via {@link #compute()}: sub-tasks are forked and joined inside this task;</li>
 *   <li>asynchronous via {@link #call()} followed by {@link #blockUntilComplete()}: sub-tasks are
 *       started with {@code invokeAsync()}, recorded in {@link #subTasks}, and combined when the
 *       caller blocks.</li>
 * </ul>
 * When {@code outerTask} is true, the combined value is passed through
 * {@link Accumulation#getAndSetFinalResult(double)} before it is returned.
 */
public class CPUAccumulationViaTensorTask extends BaseCPUTask<Double> {
    protected final Accumulation op;
    protected final boolean outerTask;
    /** Sub-tasks started in asynchronous mode; {@code null} until {@link #call()} has run. */
    protected List<Task<Double>> subTasks;

    /**
     * @param op        the accumulation to execute
     * @param threshold forwarded to {@link BaseCPUTask} and to every sub-task
     * @param outerTask whether this task should finalize the result via
     *                  {@link Accumulation#getAndSetFinalResult(double)}
     */
    public CPUAccumulationViaTensorTask(Accumulation op, int threshold, boolean outerTask) {
        super(op, threshold);
        this.op = op;
        this.outerTask = outerTask;
    }

    /**
     * Waits for this task and returns the accumulated value. If the task has not been started yet,
     * it is started with {@code invokeAsync()} first.
     *
     * <p>When sub-tasks were recorded, each one is awaited and the results are combined. For an
     * outer task, the combined value is then passed through
     * {@link Accumulation#getAndSetFinalResult(double)}.
     *
     * @return the accumulated value
     * @throws RuntimeException if this task's future fails or the waiting thread is interrupted
     *         (in which case the interrupt status is restored first)
     */
    @Override
    public Double blockUntilComplete() {
        if (future == null) {
            invokeAsync();
        }
        Double result = awaitFuture();
        if (subTasks != null) {
            // The work was split: the answer is the combination of the sub-results,
            // not the value of this task's own future (which is null in that case).
            double accum = op.zeroDouble();
            for (Task<Double> task : subTasks) {
                accum = op.combineSubResults(accum, task.blockUntilComplete().doubleValue());
            }
            if (outerTask) {
                return op.getAndSetFinalResult(accum);
            }
            result = accum;
        }
        return result;
    }

    /**
     * Waits on {@link #future}. Checked failures are rethrown as {@link RuntimeException}.
     * On interruption, the thread's interrupt flag is restored so callers can still observe it.
     */
    private Double awaitFuture() {
        try {
            return future.get();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Asynchronous entry point: starts the sub-tasks and records them in {@link #subTasks}.
     *
     * @return always {@code null}; the result is assembled by {@link #blockUntilComplete()}
     */
    @Override
    public Double call() {
        return execute(false);
    }

    /**
     * Fork/join entry point: forks one sub-task per tensor, joins them, and combines the results.
     * If this is an outer task, the result is finalized as well.
     */
    @Override
    public Double compute() {
        double accum = execute(true).doubleValue();
        return outerTask ? op.getAndSetFinalResult(accum) : accum;
    }

    /**
     * Splits the op into per-tensor {@link CPUAccumulationTask}s and launches them.
     *
     * @param forkJoin {@code true} to fork and join the sub-tasks here and return the combined value;
     *                 {@code false} to start them asynchronously into {@link #subTasks} and return
     *                 {@code null}
     */
    private Double execute(boolean forkJoin) {
        INDArray x = op.x();
        INDArray y = op.y();
        int tensorDim = (y == null)
                ? OpExecutionerUtil.chooseElementWiseTensorDimension(x)
                : OpExecutionerUtil.chooseElementWiseTensorDimension(x, y);
        int nTensors = x.tensorssAlongDimension(tensorDim);

        List<RecursiveTask<Double>> forked = null;
        if (forkJoin) {
            forked = new ArrayList<RecursiveTask<Double>>(nTensors);
        } else {
            subTasks = new ArrayList<Task<Double>>(nTensors);
        }

        if (nTensors == 1) {
            // Nothing to split: hand the whole op to a single accumulation task.
            CPUAccumulationTask whole = new CPUAccumulationTask(op, threshold, false);
            if (forkJoin) {
                return whole.invoke();
            }
            whole.invokeAsync();
            subTasks.add(whole);
            return null;
        }

        if (x.rank() == 2) {
            // Rank 2: every tensor shares the same length/stride, so address each one by offset.
            OpExecutionerUtil.Tensor1DStats xStats = OpExecutionerUtil.get1DTensorStats(x, tensorDim);
            OpExecutionerUtil.Tensor1DStats yStats =
                    (y == null) ? null : OpExecutionerUtil.get1DTensorStats(y, tensorDim);
            int tensorLength = xStats.getTensorLength();
            int xStride = xStats.getElementWiseStride();
            int yStride = (yStats == null) ? 0 : yStats.getElementWiseStride();
            for (int i = 0; i < nTensors; i++) {
                int xOffset = xStats.getFirstTensorOffset() + i * xStats.getTensorStartSeparation();
                int yOffset = (yStats == null)
                        ? 0
                        : yStats.getFirstTensorOffset() + i * yStats.getTensorStartSeparation();
                launchSubTask(new CPUAccumulationTask(op, threshold, tensorLength, xOffset, yOffset,
                        xStride, yStride, false), forkJoin, forked);
            }
        } else {
            // Other ranks: let each sub-task locate its tensor by index along tensorDim.
            for (int i = 0; i < nTensors; i++) {
                launchSubTask(new CPUAccumulationTask(op, threshold, i, tensorDim, false), forkJoin, forked);
            }
        }

        if (!forkJoin) {
            return null;
        }
        // Join in fork order so the combination order matches the tensor order.
        double accum = op.zeroDouble();
        for (RecursiveTask<Double> task : forked) {
            accum = op.combineSubResults(accum, task.join().doubleValue());
        }
        return accum;
    }

    /**
     * Starts one sub-task in the requested mode.
     * In fork/join mode, the task is forked and remembered in {@code forked}.
     * Otherwise, it is started asynchronously and recorded in {@link #subTasks}.
     */
    private void launchSubTask(CPUAccumulationTask task, boolean forkJoin,
                               List<RecursiveTask<Double>> forked) {
        if (forkJoin) {
            task.fork();
            forked.add(task);
        } else {
            task.invokeAsync();
            subTasks.add(task);
        }
    }
}
