package org.nd4j.linalg.api.parallel.tasks.cpu.transform;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/transform/CPUTransformAlongDimensionTask.class */
/**
 * Applies a {@link TransformOp} separately to every tensor along the given dimensions of the op's
 * {@code x} array, dispatching one sub-task per tensor.
 *
 * <p>For each tensor index, the op returned by {@code op.opForDimension(index, dimensions)} is run
 * by a {@link CPUTransformOpAction} when {@link OpExecutionerUtil#canDoOpDirectly} permits direct
 * execution on its operands, and by a {@link CPUTransformOpViaTensorTask} otherwise.
 */
public class CPUTransformAlongDimensionTask extends BaseCPUAction {
    /** The op being applied along {@link #dimensions}. */
    protected final TransformOp op;
    /** Dimensions along which {@link #op}'s {@code x} array is split into tensors. */
    protected final int[] dimensions;
    /** Sub-tasks launched by the most recent {@link #call()}; {@code null} before the first call. */
    protected List<Task<Void>> subTasks;

    /**
     * Creates a task that applies {@code transformOp} along {@code dimensions}.
     *
     * @param transformOp the op to apply to each tensor
     * @param threshold value forwarded to {@link BaseCPUAction}'s constructor; presumably the
     *     parallelism threshold later read as {@code threshold} and passed to each sub-task
     *     (NOTE(review): confirm against BaseCPUAction)
     * @param dimensions dimensions along which to split {@code transformOp.x()} into tensors
     */
    public CPUTransformAlongDimensionTask(TransformOp transformOp, int threshold, int... dimensions) {
        super(transformOp, threshold);
        this.op = transformOp;
        this.dimensions = dimensions;
    }

    /**
     * Launches one asynchronous sub-task per tensor along {@link #dimensions} and records each in
     * {@link #subTasks}. This method does not wait for the sub-tasks to finish.
     *
     * @return always {@code null}
     */
    @Override
    public Void call() {
        int tensorCount = op.x().tensorssAlongDimension(dimensions);
        this.subTasks = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            TransformOp tensorOp = (TransformOp) op.opForDimension(i, dimensions);
            Task<Void> subTask = canDoDirectly(tensorOp)
                    ? new CPUTransformOpAction(tensorOp, threshold)
                    : new CPUTransformOpViaTensorTask(tensorOp, threshold);
            subTask.invokeAsync();
            subTasks.add(subTask);
        }
        return null;
    }

    /**
     * Fork/join entry point: forks one sub-action per tensor along {@link #dimensions}, then joins
     * all of them before returning.
     */
    @Override
    protected void compute() {
        int tensorCount = op.x().tensorssAlongDimension(dimensions);
        List<RecursiveAction> forked = new ArrayList<>(tensorCount);
        for (int i = 0; i < tensorCount; i++) {
            TransformOp tensorOp = (TransformOp) op.opForDimension(i, dimensions);
            RecursiveAction subAction = canDoDirectly(tensorOp)
                    ? new CPUTransformOpAction(tensorOp, threshold)
                    : new CPUTransformOpViaTensorTask(tensorOp, threshold);
            subAction.fork();
            forked.add(subAction);
        }
        // Fork everything first so the sub-actions can run in parallel, then wait for all of them.
        for (RecursiveAction subAction : forked) {
            subAction.join();
        }
    }

    /**
     * Decides whether {@code tensorOp} may be executed directly. This checks {@code x} alone when
     * the op has no {@code y} operand, and checks {@code x} together with {@code y} otherwise.
     */
    private static boolean canDoDirectly(TransformOp tensorOp) {
        INDArray x = tensorOp.x();
        INDArray y = tensorOp.y();
        return y == null
                ? OpExecutionerUtil.canDoOpDirectly(x)
                : OpExecutionerUtil.canDoOpDirectly(x, y);
    }
}
