package org.nd4j.linalg.api.parallel.tasks.cpu.vector;

import io.netty.buffer.ByteBuf;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.concurrent.RecursiveAction;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.parallel.tasks.Task;
import org.nd4j.linalg.api.parallel.tasks.TaskExecutorProvider;
import org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/vector/CpuBroadcastOp.class */
public class CpuBroadcastOp extends BaseCPUAction {
    protected final BroadcastOp op;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/nd4j/linalg/api/parallel/tasks/cpu/vector/CpuBroadcastOp$SingleVectorAction.class */
    public class SingleVectorAction extends BaseCPUAction {
        /**
         * Creates an action over an explicitly described 1-d slice.
         *
         * <p>Parameter names mirror how the call sites in this file fill them in
         * (threshold, element count, then x/y/z offsets, then x/y/z increments);
         * the base-class constructor that receives them is not visible here.
         */
        private SingleVectorAction(int threshold, int n, int offsetX, int offsetY, int offsetZ,
                                   int incrX, int incrY, int incrZ) {
            super(threshold, n, offsetX, offsetY, offsetZ, incrX, incrY, incrZ);
        }

        /**
         * Creates an action whose slice is described indirectly, by a tensor index and a
         * dimension of the enclosing op's arrays.
         *
         * <p>Parameter names follow the call sites in this file, which pass the threshold,
         * a running tensor index and {@code dimension[0]}; how the base class stores them
         * is not visible here (presumably as {@code tensorIdx}/{@code tensorDim} - verify).
         */
        private SingleVectorAction(int threshold, int tensorIdx, int tensorDim) {
            super(CpuBroadcastOp.this.op, threshold, tensorIdx, tensorDim);
        }

        /**
         * Hands this action to the executor returned by
         * {@code TaskExecutorProvider.getTaskExecutor()} and stores the handle it returns
         * in {@code future}, which {@link #blockUntilComplete()} later waits on.
         */
        @Override // org.nd4j.linalg.api.parallel.tasks.cpu.BaseCPUAction, org.nd4j.linalg.api.parallel.tasks.Task
        public void invokeAsync() {
            this.future = TaskExecutorProvider.getTaskExecutor().executeAsync(this);
        }

        /**
         * Fork/join entry point.
         *
         * <p>Resolves the slice description first when {@code doTensorFirst} is set, then
         * either processes the slice directly (at most {@code threshold} elements) or
         * splits it into two halves that are forked and joined.
         */
        @Override
        protected void compute() {
            if (this.doTensorFirst) {
                doTensorFirst(CpuBroadcastOp.this.op);
            }
            // A slice of one element (or none) cannot be halved any further; without this
            // guard a non-positive threshold would make the recursion below endless.
            if (this.n <= this.threshold || this.n <= 1) {
                execute();
                return;
            }

            final int firstLength = this.n / 2;
            final int secondLength = this.n - firstLength;

            final SingleVectorAction first = new SingleVectorAction(
                    this.threshold, firstLength,
                    this.offsetX, this.offsetY, this.offsetZ,
                    this.incrX, this.incrY, this.incrZ);
            // The second half starts firstLength elements further along each array.
            final SingleVectorAction second = new SingleVectorAction(
                    this.threshold, secondLength,
                    this.offsetX + (firstLength * this.incrX),
                    this.offsetY + (firstLength * this.incrY),
                    this.offsetZ + (firstLength * this.incrZ),
                    this.incrX, this.incrY, this.incrZ);

            first.fork();
            second.fork();
            // ForkJoinTask recommends joining innermost-first (most recently forked first).
            second.join();
            first.join();
        }

        /**
         * Blocks until this action and every sub-task it spawned have finished.
         *
         * <p>If the action was submitted through {@link #invokeAsync()}, its future is
         * awaited first. {@link #call()} may have split the work into {@code subTasks}
         * while running, so those are awaited afterwards as well; returning right after
         * the future would report completion while sub-tasks are still writing results.
         *
         * @return always {@code null}
         * @throws RuntimeException wrapping any exception thrown while waiting on the future
         */
        @Override
        public Void blockUntilComplete() {
            if (this.future != null) {
                try {
                    this.future.get();
                } catch (Exception e) {
                    if (e instanceof InterruptedException) {
                        // Keep the interrupt visible to callers further up the stack.
                        Thread.currentThread().interrupt();
                    }
                    throw new RuntimeException(e);
                }
            }
            // subTasks is only assigned when call() decided to split; nothing to wait on otherwise.
            if (this.subTasks != null) {
                Iterator<Task<Void>> pending = this.subTasks.iterator();
                while (pending.hasNext()) {
                    pending.next().blockUntilComplete();
                }
            }
            return null;
        }

        /**
         * Executor entry point.
         *
         * <p>Resolves the slice description first when {@code doTensorFirst} is set (the
         * same order {@link #compute()} uses, so that {@code n} and the offsets are valid
         * before they are inspected). A slice of at most {@code threshold} elements is
         * processed directly; a longer one is cut into contiguous chunks of at most
         * {@code threshold} elements, each submitted asynchronously and recorded in
         * {@code subTasks}.
         *
         * @return always {@code null}
         */
        @Override
        public Void call() {
            if (this.doTensorFirst) {
                doTensorFirst(CpuBroadcastOp.this.op);
            }
            // n <= 1 cannot be split; this also keeps a non-positive threshold from looping.
            if (this.n <= this.threshold || this.n <= 1) {
                execute();
                return null;
            }

            // Math.max guards the division against a zero threshold.
            final int subTaskCount = 1 + (this.n / Math.max(1, this.threshold));
            this.subTasks = new ArrayList<>(subTaskCount);

            // Spread the elements evenly: the first `remainder` chunks get one extra element.
            // Every chunk is then strictly shorter than n and no longer than the threshold,
            // so a sub-task never has to split again (a trailing "takes the rest" chunk
            // could be as long as n itself when threshold is 1).
            final int baseLength = this.n / subTaskCount;
            final int remainder = this.n % subTaskCount;

            int start = 0;
            for (int k = 0; k < subTaskCount; k++) {
                final int length = baseLength + (k < remainder ? 1 : 0);
                if (length == 0) {
                    // Chunk lengths never increase, so all remaining chunks are empty too.
                    break;
                }
                final SingleVectorAction chunk = new SingleVectorAction(
                        this.threshold, length,
                        this.offsetX + (start * this.incrX),
                        this.offsetY + (start * this.incrY),
                        this.offsetZ + (start * this.incrZ),
                        this.incrX, this.incrY, this.incrZ);
                chunk.invokeAsync();
                this.subTasks.add(chunk);
                start += length;
            }
            return null;
        }

        /**
         * Fills in {@code n} and the x/y/z offsets and increments from the arrays of
         * {@code op}.
         *
         * <p>x is described by its tensor along {@code tensorDim} at index
         * {@code tensorIdx}. y contributes its own offset and element-wise stride, or
         * zeros when absent. z reuses x's values when it is the same array object, uses
         * its own tensor at the same index otherwise, and zeros when absent.
         *
         * @param op operation whose x, y and z arrays are inspected
         */
        @Override
        public void doTensorFirst(Op op) {
            final INDArray xArr = op.x();
            final INDArray yArr = op.y();
            final INDArray zArr = op.z();

            final INDArray xTensor = xArr.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            this.n = xTensor.length();
            this.offsetX = xTensor.offset();
            this.incrX = xTensor.elementWiseStride();

            final boolean hasY = yArr != null;
            this.offsetY = hasY ? yArr.offset() : 0;
            this.incrY = hasY ? yArr.elementWiseStride() : 0;

            if (zArr == null) {
                this.offsetZ = 0;
                this.incrZ = 0;
                return;
            }
            if (zArr == xArr) {
                // Same array object: the output slice is the input slice.
                this.offsetZ = this.offsetX;
                this.incrZ = this.incrX;
                return;
            }
            final INDArray zTensor = zArr.tensorAlongDimension(this.tensorIdx, this.tensorDim);
            this.offsetZ = zTensor.offset();
            this.incrZ = zTensor.elementWiseStride();
        }

        /**
         * Applies the enclosing op to the {@code n} elements of this slice:
         * {@code z[offsetZ + i*incrZ] = op(x[offsetX + i*incrX], y[offsetY + i*incrY])}.
         *
         * <p>The buffer of x selects the code path: heap buffers are accessed through
         * their backing arrays, anything else through the Netty view with byte offsets;
         * a data type other than FLOAT is treated as double.
         *
         * <p>Results are always written at z's own offset and increment. Deciding "in
         * place" from buffer identity alone and then writing at x's positions is wrong
         * when x and z are different views over one shared buffer; when x and z really
         * are the same slice the z and x positions coincide, so writing through z gives
         * the same result.
         *
         * @return always {@code null}
         */
        private Void execute() {
            final BroadcastOp broadcast = CpuBroadcastOp.this.op;
            final DataBuffer xBuffer = broadcast.x().data();
            final DataBuffer yBuffer = broadcast.y().data();
            final DataBuffer zBuffer = broadcast.z().data();
            final boolean floats = xBuffer.dataType() == DataBuffer.Type.FLOAT;

            if (xBuffer.allocationMode() == DataBuffer.AllocationMode.HEAP) {
                if (floats) {
                    broadcastHeapFloats(broadcast,
                            (float[]) xBuffer.array(), (float[]) yBuffer.array(), (float[]) zBuffer.array());
                } else {
                    broadcastHeapDoubles(broadcast,
                            (double[]) xBuffer.array(), (double[]) yBuffer.array(), (double[]) zBuffer.array());
                }
                return null;
            }

            final ByteBuf xBytes = xBuffer.asNetty();
            final ByteBuf yBytes = yBuffer.asNetty();
            final ByteBuf zBytes = zBuffer.asNetty();
            if (floats) {
                broadcastDirectFloats(broadcast, xBytes, yBytes, zBytes);
            } else {
                broadcastDirectDoubles(broadcast, xBytes, yBytes, zBytes);
            }
            return null;
        }

        /** True when x, y and z all advance by exactly one element per step. */
        private boolean allUnitIncrements() {
            return this.incrX == 1 && this.incrY == 1 && this.incrZ == 1;
        }

        /** Element loop over float backing arrays. */
        private void broadcastHeapFloats(BroadcastOp broadcast, float[] x, float[] y, float[] z) {
            if (allUnitIncrements()) {
                // Contiguous case kept separate so the index arithmetic stays multiplication-free.
                for (int i = 0; i < this.n; i++) {
                    z[this.offsetZ + i] = broadcast.op(x[this.offsetX + i], y[this.offsetY + i]);
                }
                return;
            }
            for (int i = 0; i < this.n; i++) {
                z[this.offsetZ + (i * this.incrZ)] =
                        broadcast.op(x[this.offsetX + (i * this.incrX)], y[this.offsetY + (i * this.incrY)]);
            }
        }

        /** Element loop over double backing arrays. */
        private void broadcastHeapDoubles(BroadcastOp broadcast, double[] x, double[] y, double[] z) {
            if (allUnitIncrements()) {
                for (int i = 0; i < this.n; i++) {
                    z[this.offsetZ + i] = broadcast.op(x[this.offsetX + i], y[this.offsetY + i]);
                }
                return;
            }
            for (int i = 0; i < this.n; i++) {
                z[this.offsetZ + (i * this.incrZ)] =
                        broadcast.op(x[this.offsetX + (i * this.incrX)], y[this.offsetY + (i * this.incrY)]);
            }
        }

        /** Element loop over Netty buffers holding 4-byte floats; all indices are byte offsets. */
        private void broadcastDirectFloats(BroadcastOp broadcast, ByteBuf x, ByteBuf y, ByteBuf z) {
            final int xBase = 4 * this.offsetX;
            final int yBase = 4 * this.offsetY;
            final int zBase = 4 * this.offsetZ;
            if (allUnitIncrements()) {
                for (int i = 0; i < this.n; i++) {
                    final int step = 4 * i;
                    z.setFloat(zBase + step, broadcast.op(x.getFloat(xBase + step), y.getFloat(yBase + step)));
                }
                return;
            }
            for (int i = 0; i < this.n; i++) {
                final int step = 4 * i;
                z.setFloat(zBase + (step * this.incrZ),
                        broadcast.op(x.getFloat(xBase + (step * this.incrX)), y.getFloat(yBase + (step * this.incrY))));
            }
        }

        /** Element loop over Netty buffers holding 8-byte doubles; all indices are byte offsets. */
        private void broadcastDirectDoubles(BroadcastOp broadcast, ByteBuf x, ByteBuf y, ByteBuf z) {
            final int xBase = 8 * this.offsetX;
            final int yBase = 8 * this.offsetY;
            final int zBase = 8 * this.offsetZ;
            if (allUnitIncrements()) {
                for (int i = 0; i < this.n; i++) {
                    final int step = 8 * i;
                    z.setDouble(zBase + step, broadcast.op(x.getDouble(xBase + step), y.getDouble(yBase + step)));
                }
                return;
            }
            for (int i = 0; i < this.n; i++) {
                final int step = 8 * i;
                z.setDouble(zBase + (step * this.incrZ),
                        broadcast.op(x.getDouble(xBase + (step * this.incrX)), y.getDouble(yBase + (step * this.incrY))));
            }
        }
    }

    /**
     * Creates the task that runs {@code broadcastOp}.
     *
     * @param broadcastOp op to execute; passed to the base class and also kept in {@code op}
     *                    so the nested vector actions can reach it
     * @param i           forwarded unchanged to the {@code BaseCPUAction} constructor
     *                    (presumably the split threshold - TODO confirm against the base class)
     */
    public CpuBroadcastOp(BroadcastOp broadcastOp, int i) {
        super(broadcastOp, i);
        this.op = broadcastOp;
    }

    /**
     * Blocks until the whole broadcast has finished.
     *
     * <p>Submits the task via {@code invokeAsync()} if that has not happened yet, waits
     * for its future, then waits for every per-vector sub-task registered in
     * {@code subTasks}.
     *
     * @return always {@code null}
     * @throws RuntimeException wrapping any exception raised while waiting
     */
    @Override
    public Void blockUntilComplete() {
        if (this.future == null) {
            invokeAsync();
        }
        try {
            this.future.get();
            if (this.subTasks != null) {
                Iterator<Task<Void>> pending = this.subTasks.iterator();
                while (pending.hasNext()) {
                    pending.next().blockUntilComplete();
                }
            }
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Wrapping must not hide the interrupt from code further up the stack.
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
        return null;
    }

    /**
     * Executor entry point: creates one {@code SingleVectorAction} per 1-d tensor of x,
     * submits each asynchronously and records them in {@code subTasks}.
     *
     * <p>For rank-2 input the slice of every action is computed up front from the 1-d
     * tensor statistics of x (and of z, unless z is the same array object as x). For
     * any other rank each action is given only its tensor index and the broadcast
     * dimension.
     *
     * @return always {@code null}
     * @throws NullPointerException          if the op has no y array
     * @throws UnsupportedOperationException if the length of y differs from the size of x
     *                                       along the first broadcast dimension
     */
    @Override
    public Void call() {
        final INDArray x = this.op.x();
        final INDArray y = this.op.y();
        final INDArray z = this.op.z();

        final int tensorCount;
        if (x.rank() == 2) {
            tensorCount = this.op.getDimension()[0] == 0 ? x.columns() : x.rows();
        } else {
            tensorCount = ArrayUtil.prod(ArrayUtil.removeIndex(x.shape(), this.op.getDimension()));
        }
        this.subTasks = new ArrayList<>(tensorCount);

        final int[] dimension = this.op.getDimension();
        // y is dereferenced right below, so a missing y is reported explicitly
        // instead of surfacing as an anonymous NullPointerException.
        if (y == null) {
            throw new NullPointerException("BroadcastOp has no y array to broadcast");
        }
        if (x.size(dimension[0]) != y.length()) {
            // dimension[0], not the array itself, which would print as "[I@<hash>".
            throw new UnsupportedOperationException("Array length " + y.length() + " does not match x.shape(" + dimension[0] + ")=" + x.size(dimension[0]));
        }

        if (x.rank() != 2) {
            for (int tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
                final SingleVectorAction action = new SingleVectorAction(this.threshold, tensorIdx, dimension[0]);
                action.invokeAsync();
                this.subTasks.add(action);
            }
            return null;
        }

        final OpExecutionerUtil.Tensor1DStats xStats = OpExecutionerUtil.get1DTensorStats(x, dimension);
        // When z is x itself the output slices coincide with the input slices.
        final OpExecutionerUtil.Tensor1DStats zStats = (x == z) ? xStats : OpExecutionerUtil.get1DTensorStats(z, dimension);

        final int yOffset = y.offset();
        final int yStride = y.elementWiseStride();
        final int tensorLength = xStats.getTensorLength();

        for (int tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
            final int xOffset = xStats.getFirstTensorOffset() + (tensorIdx * xStats.getTensorStartSeparation());
            final int zOffset = zStats.getFirstTensorOffset() + (tensorIdx * zStats.getTensorStartSeparation());
            final SingleVectorAction action = new SingleVectorAction(
                    this.threshold, tensorLength,
                    xOffset, yOffset, zOffset,
                    xStats.getElementWiseStride(), yStride, zStats.getElementWiseStride());
            action.invokeAsync();
            this.subTasks.add(action);
        }
        return null;
    }

    /**
     * Fork/join entry point: forks one {@code SingleVectorAction} per 1-d tensor of x
     * and joins them all before returning.
     *
     * <p>For rank-2 input the slice of every action is computed up front from the 1-d
     * tensor statistics of x (and of z, unless z is the same array object as x). For
     * any other rank each action is given only its tensor index and the broadcast
     * dimension.
     *
     * @throws NullPointerException          if the op has no y array
     * @throws UnsupportedOperationException if the length of y differs from the size of x
     *                                       along the first broadcast dimension
     */
    @Override
    protected void compute() {
        final INDArray x = this.op.x();
        final INDArray y = this.op.y();
        final INDArray z = this.op.z();

        final int tensorCount;
        if (x.rank() == 2) {
            tensorCount = this.op.getDimension()[0] == 0 ? x.columns() : x.rows();
        } else {
            tensorCount = ArrayUtil.prod(ArrayUtil.removeIndex(x.shape(), this.op.getDimension()));
        }
        final ArrayList<SingleVectorAction> forked = new ArrayList<>(tensorCount);

        final int[] dimension = this.op.getDimension();
        // y is dereferenced right below, so a missing y is reported explicitly
        // instead of surfacing as an anonymous NullPointerException.
        if (y == null) {
            throw new NullPointerException("BroadcastOp has no y array to broadcast");
        }
        if (x.size(dimension[0]) != y.length()) {
            throw new UnsupportedOperationException("Vector length " + y.length() + " does not match x.shape(" + dimension[0] + ")= " + x.size(dimension[0]));
        }

        if (x.rank() == 2) {
            final OpExecutionerUtil.Tensor1DStats xStats = OpExecutionerUtil.get1DTensorStats(x, dimension);
            // When z is x itself the output slices coincide with the input slices.
            final OpExecutionerUtil.Tensor1DStats zStats = (x == z) ? xStats : OpExecutionerUtil.get1DTensorStats(z, dimension);

            final int yOffset = y.offset();
            final int yStride = y.elementWiseStride();
            final int tensorLength = xStats.getTensorLength();

            for (int tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
                final int xOffset = xStats.getFirstTensorOffset() + (tensorIdx * xStats.getTensorStartSeparation());
                final int zOffset = zStats.getFirstTensorOffset() + (tensorIdx * zStats.getTensorStartSeparation());
                final SingleVectorAction action = new SingleVectorAction(
                        this.threshold, tensorLength,
                        xOffset, yOffset, zOffset,
                        xStats.getElementWiseStride(), yStride, zStats.getElementWiseStride());
                action.fork();
                forked.add(action);
            }
        } else {
            for (int tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
                final SingleVectorAction action = new SingleVectorAction(this.threshold, tensorIdx, dimension[0]);
                action.fork();
                forked.add(action);
            }
        }

        // ForkJoinTask recommends joining innermost-first, i.e. in reverse fork order.
        for (int k = forked.size() - 1; k >= 0; k--) {
            forked.get(k).join();
        }
    }
}
