package org.nd4j.linalg.api.ops.executioner;

import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.complex.IComplexNumber;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.parallel.ParallelExecutioner;
import org.nd4j.linalg.api.parallel.tasks.TaskFactory;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/ops/executioner/DefaultOpExecutioner.class */
/**
 * Reference {@link OpExecutioner}.
 *
 * <p>Ops whose operands are all real-valued are dispatched to the {@link TaskFactory} obtained from
 * {@link Nd4j#getTaskFactory()}. Ops with complex-valued operands ({@link IComplexNDArray}) are
 * evaluated with element-by-element loops in this class.
 */
public class DefaultOpExecutioner implements OpExecutioner {
    protected OpExecutioner.ExecutionMode executionMode = OpExecutioner.ExecutionMode.JAVA;
    protected TaskFactory taskFactory = Nd4j.getTaskFactory();

    /**
     * Not supported by this executioner.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ParallelExecutioner parallelExecutioner() {
        throw new UnsupportedOperationException();
    }

    /**
     * Executes {@code op} over its whole input.
     *
     * <p>Pass-through ops run themselves via {@code op.exec()}. All other ops are dispatched on their
     * concrete op type. An op of none of the handled types is returned without being executed.
     *
     * @param op the op to execute
     * @return {@code op} itself, after execution
     */
    @Override
    public Op exec(Op op) {
        if (op.isPassThrough()) {
            op.exec();
            return op;
        }
        if (op instanceof TransformOp) {
            doTransformOp((TransformOp) op);
        } else if (op instanceof Accumulation) {
            doAccumulationOp((Accumulation) op);
        } else if (op instanceof ScalarOp) {
            doScalarOp((ScalarOp) op);
        } else if (op instanceof IndexAccumulation) {
            doIndexAccumulationOp((IndexAccumulation) op);
        } else if (op instanceof BroadcastOp) {
            doBroadcastOp((BroadcastOp) op);
        }
        return op;
    }

    /**
     * Executes {@code op} and returns its result as an array.
     *
     * <p>Transform, scalar and broadcast ops return their {@code z} array. Accumulations and index
     * accumulations return a scalar array holding the final result.
     *
     * @param op the op to execute
     * @return the result array
     * @throws IllegalArgumentException if {@code op} is none of the supported op types
     */
    @Override
    public INDArray execAndReturn(Op op) {
        if (op instanceof TransformOp) {
            return execAndReturn((TransformOp) op);
        }
        if (op instanceof ScalarOp) {
            return execAndReturn((ScalarOp) op);
        }
        if (op instanceof Accumulation) {
            return Nd4j.scalar(execAndReturn((Accumulation) op).getFinalResult());
        }
        if (op instanceof IndexAccumulation) {
            return Nd4j.scalar(execAndReturn((IndexAccumulation) op).getFinalResult());
        }
        if (op instanceof BroadcastOp) {
            // A dedicated BroadcastOp overload exists; route to it instead of rejecting the op.
            return execAndReturn((BroadcastOp) op);
        }
        throw new IllegalArgumentException("Illegal type of op: " + op.getClass());
    }

    /**
     * Applies {@code op} row by row.
     *
     * <ul>
     *   <li>Vectors are executed directly.</li>
     *   <li>Arrays of rank above 2 are recursed into slice by slice (x, z and, if present, y).</li>
     *   <li>For matrices, each row of x (and y, if present) is copied and the op is executed on the
     *       copies. The resulting z row is then assigned back into the original z.</li>
     * </ul>
     *
     * <p>Note: on return, the op's operands refer to the last row/slice processed, not to the
     * arrays it was called with.
     *
     * @param op the op to execute
     */
    @Override
    public void iterateOverAllRows(Op op) {
        INDArray x = op.x();
        if (x.isVector()) {
            // Operands are re-set before executing; the setters may refresh op-internal state.
            // NOTE(review): confirm this is still required by the current Op implementations.
            op.setX(op.x());
            if (op.y() != null) {
                op.setY(op.y());
            }
            op.setZ(op.z());
            exec(op);
            return;
        }
        if (!x.isMatrix()) {
            INDArray z = op.z();
            INDArray y = op.y();
            for (int i = 0; i < x.slices(); i++) {
                INDArray xSlice = x.slice(i);
                INDArray zSlice = z.slice(i);
                op.setX(xSlice);
                op.setZ(zSlice);
                // Slice y as well; otherwise the matrix branch would index rows of the full y.
                if (y != null) {
                    op.setY(y.slice(i));
                }
                iterateOverAllRows(op);
            }
            return;
        }
        if (x instanceof IComplexNDArray) {
            IComplexNDArray complexX = (IComplexNDArray) x;
            IComplexNDArray complexZ = (IComplexNDArray) op.z();
            IComplexNDArray complexY = (IComplexNDArray) op.y();
            for (int i = 0; i < complexX.rows(); i++) {
                IComplexNDArray xRow = complexX.slice(i);
                IComplexNDArray zRow = complexZ.slice(i);
                // Execute on copies, then write the computed row back into the original z.
                op.setX(xRow.dup());
                op.setZ(zRow.dup());
                if (complexY != null) {
                    op.setY(complexY.slice(i));
                }
                exec(op);
                complexZ.slice(i).assign(op.z());
            }
            return;
        }
        INDArray z = op.z();
        INDArray y = op.y();
        for (int i = 0; i < x.rows(); i++) {
            INDArray xRow = x.getRow(i);
            INDArray zRow = z.getRow(i);
            // Execute on copies, then write the computed row back into the original z.
            op.setX(xRow.dup());
            op.setZ(zRow.dup());
            if (y != null) {
                op.setY(y.getRow(i).dup());
            }
            exec(op);
            zRow.assign(op.z());
        }
    }

    /**
     * Applies {@code op} column by column.
     *
     * <ul>
     *   <li>Vectors are executed directly.</li>
     *   <li>Matrices are executed along dimension 1 via {@link #exec(Op, int...)}.</li>
     *   <li>Higher-rank arrays are split with {@code getColumn(i)} for each slice index, and each
     *       part is recursed into.</li>
     * </ul>
     *
     * @param op the op to execute
     */
    @Override
    public void iterateOverAllColumns(Op op) {
        INDArray x = op.x();
        if (x.isVector()) {
            exec(op);
            return;
        }
        if (x.isMatrix() || x.isColumnVector()) {
            exec(op, 1);
            return;
        }
        // Capture the slice count up front: op.setX(...) below replaces op.x(), so re-reading
        // op.x().slices() as the loop bound would take the count from the wrong array.
        int slices = x.slices();
        if (x instanceof IComplexNDArray) {
            IComplexNDArray complexX = (IComplexNDArray) x;
            IComplexNDArray complexZ = (IComplexNDArray) op.z();
            IComplexNDArray complexY = (IComplexNDArray) op.y();
            for (int i = 0; i < slices; i++) {
                op.setX(complexX.getColumn(i));
                op.setZ(complexZ.getColumn(i));
                if (complexY != null) {
                    op.setY(complexY.getColumn(i));
                }
                iterateOverAllColumns(op);
            }
            return;
        }
        INDArray z = op.z();
        INDArray y = op.y();
        for (int i = 0; i < slices; i++) {
            op.setX(x.getColumn(i));
            op.setZ(z.getColumn(i));
            if (y != null) {
                op.setY(y.getColumn(i));
            }
            iterateOverAllColumns(op);
        }
    }

    /**
     * Executes the transform op and returns its {@code z} array.
     */
    @Override
    public INDArray execAndReturn(TransformOp transformOp) {
        return exec(transformOp).z();
    }

    /**
     * Executes the accumulation and returns the same op, from which the result can be read.
     */
    @Override
    public Accumulation execAndReturn(Accumulation accumulation) {
        return (Accumulation) exec(accumulation);
    }

    /**
     * Executes the scalar op and returns its {@code z} array.
     */
    @Override
    public INDArray execAndReturn(ScalarOp scalarOp) {
        return exec(scalarOp).z();
    }

    /**
     * Executes the index accumulation and returns the same op, from which the result can be read.
     */
    @Override
    public IndexAccumulation execAndReturn(IndexAccumulation indexAccumulation) {
        return (IndexAccumulation) exec(indexAccumulation);
    }

    /**
     * Executes the broadcast op and returns its {@code z} array.
     */
    @Override
    public INDArray execAndReturn(BroadcastOp broadcastOp) {
        return exec(broadcastOp).z();
    }

    /**
     * Executes a transform or scalar op along the given dimensions.
     *
     * <p>A dimension list as long as x's rank is treated as "the whole array". Scalar ops are applied
     * to the whole input; the dimensions are not used on that path.
     *
     * @param op the op to execute
     * @param dimension the dimensions to execute along
     * @return {@code op} itself, after execution
     * @throws IllegalStateException if {@code op} is an accumulation or index accumulation (those
     *     have dedicated overloads)
     * @throws UnsupportedOperationException for any other op type
     */
    @Override
    public Op exec(Op op, int... dimension) {
        dimension = normalizeDimensions(op.x(), dimension);
        if (op.isPassThrough()) {
            op.exec(dimension);
            return op;
        }
        if ((op instanceof Accumulation) || (op instanceof IndexAccumulation)) {
            throw new IllegalStateException("exec(Op,int...) should never be invoked for Accumulation/IndexAccumulation");
        }
        if (op instanceof TransformOp) {
            execAndReturn((TransformOp) op, dimension);
            return op;
        }
        if (!(op instanceof ScalarOp)) {
            throw new UnsupportedOperationException("Unknown op type");
        }
        doScalarOp((ScalarOp) op);
        return op;
    }

    /**
     * Reduces {@code accumulation} along the given dimensions.
     *
     * <p>A dimension list as long as x's rank (or {@code Integer.MAX_VALUE} as the first entry)
     * reduces the whole array to a scalar. Real inputs are reduced by the task factory. Complex
     * inputs are reduced here, one tensor along the dimensions at a time.
     *
     * @param accumulation the accumulation to execute
     * @param dimension the dimensions to reduce along
     * @return the reduction result
     */
    @Override
    public INDArray exec(Accumulation accumulation, int... dimension) {
        dimension = normalizeDimensions(accumulation.x(), dimension);
        if (accumulation.isPassThrough()) {
            accumulation.exec(dimension);
            return accumulation.z();
        }
        if (dimension[0] == Integer.MAX_VALUE) {
            if (accumulation.x() instanceof IComplexNDArray) {
                return Nd4j.scalar(execAndReturn(accumulation).getFinalResultComplex());
            }
            return Nd4j.scalar(execAndReturn(accumulation).getFinalResult().doubleValue());
        }
        // Dispatch on the input array's type; the op itself is never an IComplexNDArray.
        if (!(accumulation.x() instanceof IComplexNDArray)) {
            return this.taskFactory.getAccumulationTask(accumulation, dimension).invokeBlocking();
        }
        IComplexNDArray result = Nd4j.createComplex(reducedComplexShape(accumulation.x().shape(), dimension));
        for (int i = 0; i < accumulation.x().tensorssAlongDimension(dimension); i++) {
            Accumulation tadOp = (Accumulation) accumulation.opForDimension(i, dimension);
            result.putScalar(i, execAndReturn(tadOp).getFinalResultComplex());
        }
        return reverseStridesIfCOrdered(result);
    }

    /**
     * Runs {@code indexAccumulation} along the given dimensions.
     *
     * <p>A dimension list as long as x's rank (or {@code Integer.MAX_VALUE} as the first entry)
     * reduces the whole array to a scalar index. Real inputs are handled by the task factory.
     * Complex inputs are handled here, one tensor along the dimensions at a time.
     *
     * @param indexAccumulation the index accumulation to execute
     * @param dimension the dimensions to reduce along
     * @return the resulting index array
     */
    @Override
    public INDArray exec(IndexAccumulation indexAccumulation, int... dimension) {
        dimension = normalizeDimensions(indexAccumulation.x(), dimension);
        if (indexAccumulation.isPassThrough()) {
            indexAccumulation.exec(dimension);
            return indexAccumulation.z();
        }
        if (dimension[0] == Integer.MAX_VALUE) {
            return Nd4j.scalar(execAndReturn(indexAccumulation).getFinalResult());
        }
        if (!(indexAccumulation.x() instanceof IComplexNDArray)) {
            return this.taskFactory.getIndexAccumulationTask(indexAccumulation, dimension).invokeBlocking();
        }
        IComplexNDArray result = Nd4j.createComplex(reducedComplexShape(indexAccumulation.x().shape(), dimension));
        for (int i = 0; i < indexAccumulation.x().tensorssAlongDimension(dimension); i++) {
            IndexAccumulation tadOp = (IndexAccumulation) indexAccumulation.opForDimension(i, dimension);
            result.putScalar(i, execAndReturn(tadOp).getFinalResult());
        }
        return reverseStridesIfCOrdered(result);
    }

    /**
     * Executes the transform op along the given dimensions through the task factory.
     *
     * @return the op's {@code z} array
     */
    @Override
    public INDArray execAndReturn(TransformOp transformOp, int... dimension) {
        dimension = normalizeDimensions(transformOp.x(), dimension);
        if (transformOp.isPassThrough()) {
            transformOp.exec(dimension);
            return transformOp.z();
        }
        this.taskFactory.getTransformAction(transformOp, dimension).invokeBlocking();
        return transformOp.z();
    }

    /**
     * Executes the scalar op via {@link #exec(Op, int...)} and returns its {@code z} array.
     */
    @Override
    public INDArray execAndReturn(ScalarOp scalarOp, int... dimension) {
        return exec(scalarOp, dimension).z();
    }

    @Override
    public OpExecutioner.ExecutionMode executionMode() {
        return this.executionMode;
    }

    @Override
    public void setExecutionMode(OpExecutioner.ExecutionMode executionMode) {
        this.executionMode = executionMode;
    }

    /**
     * Normalizes a dimension list for {@code x}.
     *
     * @param x the op's input array
     * @param dimension the requested dimensions
     * @return {@code {Integer.MAX_VALUE}} (the "whole array" marker) when {@code dimension} lists as
     *     many entries as {@code x} has dimensions; otherwise {@code dimension} unchanged
     */
    private static int[] normalizeDimensions(INDArray x, int[] dimension) {
        return dimension.length == x.rank() ? new int[] {Integer.MAX_VALUE} : dimension;
    }

    /**
     * Computes the shape of a complex reduction result.
     *
     * <p>The shape is {@code shape} with the reduced {@code dimension}s removed, padded to rank 2:
     * a single remaining dimension becomes a row ({@code dimension[0] == 0}) or a column shape, and
     * no remaining dimension becomes {@code [1, 1]}.
     */
    private static int[] reducedComplexShape(int[] shape, int[] dimension) {
        int[] reduced = ArrayUtil.removeIndex(shape, dimension);
        if (reduced.length == 1) {
            return dimension[0] == 0 ? new int[] {1, reduced[0]} : new int[] {reduced[0], 1};
        }
        if (reduced.length == 0) {
            return new int[] {1, 1};
        }
        return reduced;
    }

    /**
     * Reverses the strides of a 'c'-ordered result and returns it.
     *
     * <p>When {@code result} is 'c'-ordered, its strides are replaced with a reversed copy.
     * NOTE(review): presumably this makes the linear {@code putScalar} indices written by the
     * callers land in the intended layout; confirm.
     */
    private static IComplexNDArray reverseStridesIfCOrdered(IComplexNDArray result) {
        if (result.ordering() == 'c') {
            result.setStride(ArrayUtil.reverseCopy(result.stride()));
        }
        return result;
    }

    /**
     * Executes a transform op over its whole input.
     *
     * <p>All-real operands are dispatched to the task factory. Complex operands are evaluated
     * element by element into a complex z.
     *
     * @throws UnsupportedOperationException for operand combinations that cannot be evaluated:
     *     a real z with any complex operand, or a real x combined with a y
     */
    private void doTransformOp(TransformOp transformOp) {
        INDArray x = transformOp.x();
        INDArray y = transformOp.y();
        INDArray z = transformOp.z();
        if (!(x instanceof IComplexNDArray) && !(y instanceof IComplexNDArray) && !(z instanceof IComplexNDArray)) {
            this.taskFactory.getTransformAction(transformOp).invokeBlocking();
            return;
        }
        if (!(z instanceof IComplexNDArray)) {
            // A complex result cannot be stored in a real-valued z; fail instead of silently skipping.
            if (y == null) {
                throw new UnsupportedOperationException("Invalid op: z is real but x.class=" + x.getClass().getName());
            }
            throw new UnsupportedOperationException("Invalid op: z is real but x.class=" + x.getClass().getName() + ", y.class=" + y.getClass().getName());
        }
        IComplexNDArray complexZ = (IComplexNDArray) z;
        if (y == null) {
            if (x instanceof IComplexNDArray) {
                IComplexNDArray complexX = (IComplexNDArray) x;
                for (int i = 0; i < transformOp.n(); i++) {
                    complexZ.putScalar(i, transformOp.op(complexX.getComplex(i)));
                }
            } else {
                for (int i = 0; i < transformOp.n(); i++) {
                    complexZ.putScalar(i, transformOp.op(x.getDouble(i)));
                }
            }
            return;
        }
        if (!(x instanceof IComplexNDArray)) {
            // No evaluation exists for a real x paired with y into a complex z; fail explicitly.
            throw new UnsupportedOperationException("Invalid op: z is complex but x is real: x.class=" + x.getClass().getName() + ", y.class=" + y.getClass().getName());
        }
        IComplexNDArray complexX = (IComplexNDArray) x;
        if (y instanceof IComplexNDArray) {
            IComplexNDArray complexY = (IComplexNDArray) y;
            for (int i = 0; i < transformOp.n(); i++) {
                complexZ.putScalar(i, transformOp.op(complexX.getComplex(i), complexY.getComplex(i)));
            }
        } else {
            for (int i = 0; i < transformOp.n(); i++) {
                complexZ.putScalar(i, transformOp.op(complexX.getComplex(i), y.getDouble(i)));
            }
        }
    }

    /**
     * Executes an accumulation over its whole input.
     *
     * <p>All-real operands are dispatched to the task factory. Otherwise a complex running value
     * starting at {@code zeroComplex()} is folded with {@code update(...)} over all elements and
     * stored via {@code setFinalResultComplex}.
     *
     * @throws UnsupportedOperationException if y is present and only one of x/y is complex
     */
    private void doAccumulationOp(Accumulation accumulation) {
        INDArray x = accumulation.x();
        INDArray y = accumulation.y();
        if (!(x instanceof IComplexNDArray) && !(y instanceof IComplexNDArray)) {
            this.taskFactory.getAccumulationTask(accumulation).invokeBlocking();
            return;
        }
        if (y == null) {
            // Reaching here with y == null means x is the complex operand.
            IComplexNDArray complexX = (IComplexNDArray) x;
            IComplexNumber accum = accumulation.zeroComplex();
            for (int i = 0; i < accumulation.n(); i++) {
                accum = accumulation.update(accum, complexX.getComplex(i), i);
            }
            accumulation.setFinalResultComplex(accum);
            return;
        }
        if (!(x instanceof IComplexNDArray) || !(y instanceof IComplexNDArray)) {
            throw new UnsupportedOperationException("Invalid input for accumulation op: x.class=" + x.getClass().getName() + ", y.class=" + y.getClass().getName());
        }
        IComplexNDArray complexX = (IComplexNDArray) x;
        IComplexNDArray complexY = (IComplexNDArray) y;
        IComplexNumber accum = accumulation.zeroComplex();
        for (int i = 0; i < accumulation.n(); i++) {
            accum = accumulation.update(accum, complexX.getComplex(i), complexY.getComplex(i));
        }
        accumulation.setFinalResultComplex(accum);
    }

    /**
     * Executes a scalar op over its whole input.
     *
     * <p>All-real operands are dispatched to the task factory. Otherwise each element of x is mapped
     * through the op into the complex z.
     *
     * @throws UnsupportedOperationException if x is complex but z is real
     */
    private void doScalarOp(ScalarOp scalarOp) {
        INDArray x = scalarOp.x();
        INDArray z = scalarOp.z();
        if (!(x instanceof IComplexNDArray) && !(z instanceof IComplexNDArray)) {
            this.taskFactory.getScalarAction(scalarOp).invokeBlocking();
            return;
        }
        if (!(z instanceof IComplexNDArray)) {
            throw new UnsupportedOperationException("Scalar op with complex x but real z: not supported");
        }
        IComplexNDArray complexZ = (IComplexNDArray) z;
        if (!(x instanceof IComplexNDArray)) {
            for (int i = 0; i < scalarOp.n(); i++) {
                complexZ.putScalar(i, scalarOp.op(x.getDouble(i)));
            }
            return;
        }
        IComplexNDArray complexX = (IComplexNDArray) x;
        for (int i = 0; i < scalarOp.n(); i++) {
            complexZ.putScalar(i, scalarOp.op(complexX.getComplex(i)));
        }
    }

    /**
     * Executes an index accumulation over its whole input.
     *
     * <p>All-real operands are dispatched to the task factory. Otherwise the best index (starting at
     * -1) is updated per element via {@code update(...)}. Whenever the returned index equals the
     * current position, the running value is replaced with {@code op(...)} of that element. The
     * final index is stored via {@code setFinalResult}.
     *
     * @throws UnsupportedOperationException if y is present and only one of x/y is complex
     */
    private void doIndexAccumulationOp(IndexAccumulation indexAccumulation) {
        INDArray x = indexAccumulation.x();
        INDArray y = indexAccumulation.y();
        if (!(x instanceof IComplexNDArray) && !(y instanceof IComplexNDArray)) {
            this.taskFactory.getIndexAccumulationTask(indexAccumulation).invokeBlocking();
            return;
        }
        if (y == null) {
            int best = -1;
            IComplexNDArray complexX = (IComplexNDArray) x;
            IComplexNumber accum = indexAccumulation.zeroComplex();
            for (int i = 0; i < indexAccumulation.n(); i++) {
                best = indexAccumulation.update(accum, best, complexX.getComplex(i), i);
                if (best == i) {
                    accum = indexAccumulation.op(complexX.getComplex(i));
                }
            }
            indexAccumulation.setFinalResult(best);
            return;
        }
        if (!(x instanceof IComplexNDArray) || !(y instanceof IComplexNDArray)) {
            throw new UnsupportedOperationException("Invalid input for index accumulation op: x.class=" + x.getClass().getName() + ", y.class=" + y.getClass().getName());
        }
        int best = -1;
        IComplexNDArray complexX = (IComplexNDArray) x;
        IComplexNDArray complexY = (IComplexNDArray) y;
        IComplexNumber accum = indexAccumulation.zeroComplex();
        for (int i = 0; i < indexAccumulation.n(); i++) {
            best = indexAccumulation.update(accum, best, complexX.getComplex(i), complexY.getComplex(i), i);
            if (best == i) {
                accum = indexAccumulation.op(complexX.getComplex(i), complexY.getComplex(i));
            }
        }
        indexAccumulation.setFinalResult(best);
    }

    /**
     * Executes a broadcast op.
     *
     * <p>All-real operands are dispatched to the task factory. For complex operands, every tensor of
     * x along the op's dimension is combined element-wise with y (complex, real, or absent), and the
     * results are written into the matching tensor of z.
     *
     * @throws UnsupportedOperationException if some operand is complex but x or z is real
     */
    private void doBroadcastOp(BroadcastOp broadcastOp) {
        INDArray x = broadcastOp.x();
        INDArray y = broadcastOp.y();
        INDArray z = broadcastOp.z();
        if (!(x instanceof IComplexNDArray) && !(y instanceof IComplexNDArray) && !(z instanceof IComplexNDArray)) {
            this.taskFactory.getBroadcastOpAction(broadcastOp).invokeBlocking();
            return;
        }
        int tads = x.tensorssAlongDimension(broadcastOp.getDimension());
        if (!(x instanceof IComplexNDArray)) {
            throw new UnsupportedOperationException("Complex vector op with real x not supported/implemented");
        }
        if (!(z instanceof IComplexNDArray)) {
            // Previously this surfaced as a ClassCastException on the cast below.
            throw new UnsupportedOperationException("Complex vector op with real z not supported/implemented");
        }
        IComplexNDArray complexX = (IComplexNDArray) x;
        IComplexNDArray complexZ = (IComplexNDArray) z;
        if (y instanceof IComplexNDArray) {
            IComplexNDArray complexY = (IComplexNDArray) y;
            for (int t = 0; t < tads; t++) {
                IComplexNDArray xTad = (IComplexNDArray) complexX.tensorAlongDimension(t, broadcastOp.getDimension());
                IComplexNDArray zTad = (IComplexNDArray) complexZ.tensorAlongDimension(t, broadcastOp.getDimension());
                for (int j = 0; j < xTad.length(); j++) {
                    zTad.put(j, Nd4j.scalar(broadcastOp.op(xTad.getComplex(j), complexY.getComplex(j))));
                }
            }
            return;
        }
        if (y == null) {
            for (int t = 0; t < tads; t++) {
                IComplexNDArray xTad = (IComplexNDArray) complexX.tensorAlongDimension(t, broadcastOp.getDimension());
                IComplexNDArray zTad = (IComplexNDArray) complexZ.tensorAlongDimension(t, broadcastOp.getDimension());
                // Index by the element position j; the tensor index t selects only which tensor.
                for (int j = 0; j < zTad.length(); j++) {
                    zTad.put(j, Nd4j.scalar(broadcastOp.op(xTad.getComplex(j))));
                }
            }
            return;
        }
        for (int t = 0; t < tads; t++) {
            IComplexNDArray xTad = (IComplexNDArray) complexX.tensorAlongDimension(t, broadcastOp.getDimension());
            IComplexNDArray zTad = (IComplexNDArray) complexZ.tensorAlongDimension(t, broadcastOp.getDimension());
            for (int j = 0; j < xTad.length(); j++) {
                zTad.put(j, Nd4j.scalar(broadcastOp.op(xTad.getComplex(j), y.getDouble(j))));
            }
        }
    }
}
