package org.nd4j.linalg.cpu.nativecpu.ops;

import java.util.Arrays;
import org.apache.commons.math3.util.Pair;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexNDArray;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Accumulation;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.accum.Variance;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.nativeblas.NativeOps;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.class */
/**
 * CPU {@link DefaultOpExecutioner} that dispatches ops to the native {@link NativeOps} backend.
 *
 * <p>Each op is routed to the double- or float-precision native entry point according to the data
 * type of its {@code x()} buffer. Element-wise ops (scalar, transform, pairwise) use the native
 * overloads that take only an element-wise stride and a length when every operand permits it, and
 * otherwise fall back to the overloads that take full shape-info buffers. Complex arrays and
 * {@link OpExecutioner.ExecutionMode#JAVA} mode are delegated to the superclass for scalar,
 * accumulation and index-accumulation ops.
 */
public class NativeOpExecutioner extends DefaultOpExecutioner {
    /** Native backend that performs the actual computation. */
    private final NativeOps loop = new NativeOps();
    /** Supplies constant buffers; used here to hand dimension arrays to native code. */
    private final ConstantHandler constantHandler = Nd4j.getConstantHandler();
    /** Provides the TAD shape-info buffers passed to dimension-wise native calls. */
    private final CpuTADManager tadManager = new CpuTADManager();

    /** Creates an executioner and initializes its TAD manager with the native backend and constant handler. */
    public NativeOpExecutioner() {
        this.tadManager.init(this.loop, this.constantHandler);
    }

    /**
     * Executes {@code op} by dispatching on its concrete kind.
     *
     * <p>Scalar, transform, accumulation and index-accumulation ops are executed over the whole
     * array; broadcast ops use the dimensions reported by {@link BroadcastOp#getDimension()}. Ops of
     * any other kind are returned without being executed.
     *
     * @param op the op to execute
     * @return {@code op} itself, with its result recorded on it (output array or final result)
     */
    public Op exec(Op op) {
        if (op instanceof ScalarOp) {
            exec((ScalarOp) op);
        } else if (op instanceof TransformOp) {
            exec((TransformOp) op);
        } else if (op instanceof Accumulation) {
            exec((Accumulation) op);
        } else if (op instanceof IndexAccumulation) {
            exec((IndexAccumulation) op);
        } else if (op instanceof BroadcastOp) {
            BroadcastOp broadcastOp = (BroadcastOp) op;
            exec(broadcastOp, broadcastOp.getDimension());
        }
        return op;
    }

    /**
     * Executes an index reduction along the given dimensions.
     *
     * @param indexAccumulation the op to execute; its {@code z} is replaced by a freshly allocated
     *     result array filled with {@code zeroDouble()} before the native call
     * @param iArr dimensions to reduce along; negative values count back from the rank of {@code x}.
     *     The array is not modified.
     * @return the op's result array, or {@code x} itself when {@code x} is a vector whose length
     *     equals the element count of the reduced shape
     */
    public INDArray exec(IndexAccumulation indexAccumulation, int... iArr) {
        INDArray x = indexAccumulation.x();
        int[] dimension = normalizeReductionDimensions(x, iArr);
        int[] resultShape = shapeWithoutDimensions(x, dimension);
        // NOTE(review): this returns the input values rather than indices for this case; the
        // behavior is preserved as-is — confirm it is intended.
        if (x.isVector() && x.length() == ArrayUtil.prod(resultShape)) {
            return x;
        }
        indexAccumulation.setZ(Nd4j.valueArrayOf(toMatrixShape(resultShape, dimension), indexAccumulation.zeroDouble()));
        INDArray z = indexAccumulation.z();

        Pointer dimensionPointer = this.constantHandler.getConstantBuffer(dimension).addressPointer();
        PointerPointer tadPointers = buildTadPointers(x, dimension);
        int opNum = indexAccumulation.opNum();
        Pointer xData = x.data().addressPointer();
        Pointer xShapeInfo = x.shapeInfoDataBuffer().addressPointer();
        Pointer zData = z.data().addressPointer();
        Pointer zShapeInfo = z.shapeInfoDataBuffer().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(indexAccumulation);
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.loop.execIndexReduceDouble(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                    zData, zShapeInfo, dimensionPointer, dimension.length);
        } else {
            this.loop.execIndexReduceFloat(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                    zData, zShapeInfo, dimensionPointer, dimension.length);
        }
        return z;
    }

    /**
     * Executes a reduction (plain reduce, reduce3 when {@code y} is set, or summary stats for
     * {@link Variance}) along the given dimensions.
     *
     * @param accumulation the op to execute; its {@code z} is replaced by a freshly allocated result
     *     array filled with {@code zeroDouble()} before the native call
     * @param iArr dimensions to reduce along; negative values count back from the rank of {@code x}.
     *     The array is not modified.
     * @return the result array, or {@code accumulation.noOp()} when {@code x} is a vector whose
     *     length equals the element count of the reduced shape
     */
    public INDArray exec(Accumulation accumulation, int... iArr) {
        INDArray x = accumulation.x();
        int[] dimension = normalizeReductionDimensions(x, iArr);
        int[] resultShape = toMatrixShape(shapeWithoutDimensions(x, dimension), dimension);
        if (x.isVector() && x.length() == ArrayUtil.prod(resultShape)) {
            return accumulation.noOp();
        }
        INDArray result = Nd4j.valueArrayOf(resultShape, accumulation.zeroDouble());
        accumulation.setZ(result);

        PointerPointer tadPointers = buildTadPointers(x, dimension);
        Pointer dimensionPointer = this.constantHandler.getConstantBuffer(dimension).addressPointer();
        int opNum = accumulation.opNum();
        Pointer xData = x.data().addressPointer();
        Pointer xShapeInfo = x.shapeInfoDataBuffer().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(accumulation);
        INDArray z = accumulation.z();
        Pointer zData = z.data().addressPointer();
        Pointer zShapeInfo = z.shapeInfoDataBuffer().addressPointer();
        INDArray y = accumulation.y();
        // A scalar result is written directly from the native scalar entry point's return value.
        boolean scalarResult = result.isScalar();

        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            if (accumulation instanceof Variance) {
                // Honor the op's bias-correction flag on every path (the float path already did).
                boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
                if (scalarResult) {
                    result.putScalar(0, this.loop.execSummaryStatsScalarDouble(tadPointers, opNum, xData, xShapeInfo,
                            extraArgs, biasCorrected));
                } else {
                    this.loop.execSummaryStatsDouble(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                            zData, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
                }
            } else if (y != null) {
                Pointer yData = y.data().addressPointer();
                Pointer yShapeInfo = y.shapeInfoDataBuffer().addressPointer();
                if (scalarResult) {
                    result.putScalar(0, this.loop.execReduce3ScalarDouble(tadPointers, opNum, xData, xShapeInfo,
                            extraArgs, yData, yShapeInfo));
                } else {
                    this.loop.execReduce3Double(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                            yData, yShapeInfo, zData, zShapeInfo, dimensionPointer, dimension.length);
                }
            } else if (scalarResult) {
                result.putScalar(0, this.loop.execReduceScalarDouble(tadPointers, opNum, xData, xShapeInfo, extraArgs));
            } else {
                this.loop.execReduceDouble(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                        zData, zShapeInfo, dimensionPointer, dimension.length);
            }
        } else if (accumulation instanceof Variance) {
            boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
            if (scalarResult) {
                result.putScalar(0, this.loop.execSummaryStatsScalarFloat(tadPointers, opNum, xData, xShapeInfo,
                        extraArgs, biasCorrected));
            } else {
                this.loop.execSummaryStatsFloat(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                        zData, zShapeInfo, dimensionPointer, dimension.length, biasCorrected);
            }
        } else if (y != null) {
            Pointer yData = y.data().addressPointer();
            Pointer yShapeInfo = y.shapeInfoDataBuffer().addressPointer();
            if (scalarResult) {
                result.putScalar(0, this.loop.execReduce3ScalarFloat(tadPointers, opNum, xData, xShapeInfo,
                        extraArgs, yData, yShapeInfo));
            } else {
                this.loop.execReduce3Float(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                        yData, yShapeInfo, zData, zShapeInfo, dimensionPointer, dimension.length);
            }
        } else if (scalarResult) {
            result.putScalar(0, this.loop.execReduceScalarFloat(tadPointers, opNum, xData, xShapeInfo, extraArgs));
        } else {
            this.loop.execReduceFloat(tadPointers, opNum, xData, xShapeInfo, extraArgs,
                    zData, zShapeInfo, dimensionPointer, dimension.length);
        }
        return result;
    }

    /**
     * Applies a scalar op element-wise from {@code x} into {@code z}.
     *
     * <p>Complex inputs and {@link OpExecutioner.ExecutionMode#JAVA} mode are delegated to the
     * superclass.
     */
    private void exec(ScalarOp scalarOp) {
        if ((scalarOp.x() instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(scalarOp);
            return;
        }
        INDArray x = scalarOp.x();
        INDArray z = scalarOp.z();
        PointerPointer dummy = new PointerPointer(new Pointer[]{null});
        Pointer extraArgs = getPointerForExtraArgs(scalarOp);
        int opNum = scalarOp.opNum();
        // The stride-only overloads receive just a stride and a length per buffer, so they are used
        // only when both strides are positive and both arrays share the same ordering.
        boolean strided = !scalarOp.isExecSpecial()
                && x.elementWiseStride() >= 1
                && z.elementWiseStride() >= 1
                && x.ordering() == z.ordering();
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            double scalar = scalarOp.scalar().doubleValue();
            if (strided) {
                this.loop.execScalarDouble(dummy, opNum, x.data().addressPointer(), x.elementWiseStride(),
                        z.data().addressPointer(), z.elementWiseStride(), scalar, extraArgs, scalarOp.n());
            } else {
                this.loop.execScalarDouble(dummy, opNum, x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                        z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), scalar, extraArgs);
            }
        } else {
            float scalar = scalarOp.scalar().floatValue();
            if (strided) {
                this.loop.execScalarFloat(dummy, opNum, x.data().addressPointer(), x.elementWiseStride(),
                        z.data().addressPointer(), z.elementWiseStride(), scalar, extraArgs, scalarOp.n());
            } else {
                this.loop.execScalarFloat(dummy, opNum, x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                        z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), scalar, extraArgs);
            }
        }
    }

    /**
     * Returns the native address of the op's extra-arguments buffer, or {@code null} when the op has
     * no extra arguments.
     */
    private Pointer getPointerForExtraArgs(Op op) {
        if (op.extraArgs() != null) {
            return op.extraArgsDataBuff().addressPointer();
        }
        return null;
    }

    /**
     * Applies a transform op: unary from {@code x} into {@code z} when {@code y} is {@code null},
     * otherwise pairwise over {@code x} and {@code y} into {@code z}.
     */
    private void exec(TransformOp transformOp) {
        INDArray x = transformOp.x();
        INDArray y = transformOp.y();
        INDArray z = transformOp.z();
        PointerPointer dummy = new PointerPointer(new Pointer[]{null});
        Pointer extraArgs = getPointerForExtraArgs(transformOp);
        int opNum = transformOp.opNum();
        boolean isDouble = x.data().dataType() == DataBuffer.Type.DOUBLE;

        if (y == null) {
            // Stride-only overloads require positive strides and matching layouts for x and z.
            boolean strided = !transformOp.isExecSpecial()
                    && x.elementWiseStride() >= 1
                    && z.elementWiseStride() >= 1
                    && x.ordering() == z.ordering();
            if (strided) {
                if (isDouble) {
                    this.loop.execTransformDouble(dummy, opNum, x.data().addressPointer(), x.elementWiseStride(),
                            z.data().addressPointer(), z.elementWiseStride(), extraArgs, transformOp.n());
                } else {
                    this.loop.execTransformFloat(dummy, opNum, x.data().addressPointer(), x.elementWiseStride(),
                            z.data().addressPointer(), z.elementWiseStride(), extraArgs, transformOp.n());
                }
            } else if (isDouble) {
                this.loop.execTransformDouble(dummy, opNum, x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                        z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), extraArgs);
            } else {
                this.loop.execTransformFloat(dummy, opNum, x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                        z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), extraArgs);
            }
            return;
        }

        // Pairwise: x and y must share stride and ordering, and z must match their ordering, for
        // the stride-only overloads to line elements up correctly.
        boolean strided = !transformOp.isExecSpecial()
                && x.elementWiseStride() >= 1
                && y.elementWiseStride() >= 1
                && z.elementWiseStride() >= 1
                && x.elementWiseStride() == y.elementWiseStride()
                && x.ordering() == y.ordering()
                && x.ordering() == z.ordering();
        if (strided) {
            if (isDouble) {
                this.loop.execPairwiseTransformDouble(dummy, opNum, x.data().addressPointer(), x.elementWiseStride(),
                        y.data().addressPointer(), y.elementWiseStride(), z.data().addressPointer(), z.elementWiseStride(),
                        extraArgs, transformOp.n());
            } else {
                this.loop.execPairwiseTransformFloat(dummy, opNum, x.data().addressPointer(), x.elementWiseStride(),
                        y.data().addressPointer(), y.elementWiseStride(), z.data().addressPointer(), z.elementWiseStride(),
                        extraArgs, transformOp.n());
            }
        } else if (isDouble) {
            this.loop.execPairwiseTransformDouble(dummy, opNum, x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                    y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer(),
                    z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), extraArgs);
        } else {
            this.loop.execPairwiseTransformFloat(dummy, opNum, x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                    y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer(),
                    z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), extraArgs);
        }
    }

    /**
     * Executes a broadcast op along the given dimensions.
     *
     * @param broadcastOp the op to execute
     * @param iArr dimensions of {@code x} to broadcast along; sorted before use. The array is not
     *     modified.
     * @return the op's {@code z} array
     */
    public INDArray exec(BroadcastOp broadcastOp, int... iArr) {
        int[] dimension = iArr.clone();
        Arrays.sort(dimension);
        INDArray x = broadcastOp.x();
        INDArray y = broadcastOp.y();
        INDArray z = broadcastOp.z();
        PointerPointer tadPointers = buildTadPointers(x, dimension);
        Pointer dimensionPointer = this.constantHandler.getConstantBuffer(dimension).addressPointer();
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            this.loop.execBroadcastDouble(tadPointers, broadcastOp.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                    y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer(),
                    z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), dimensionPointer, dimension.length);
        } else {
            this.loop.execBroadcastFloat(tadPointers, broadcastOp.opNum(), x.data().addressPointer(), x.shapeInfoDataBuffer().addressPointer(),
                    y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer(),
                    z.data().addressPointer(), z.shapeInfoDataBuffer().addressPointer(), dimensionPointer, dimension.length);
        }
        return broadcastOp.z();
    }

    /**
     * Executes an index reduction over the whole array and stores the (int-truncated) native result
     * via {@code setFinalResult}. Complex inputs and JAVA mode are delegated to the superclass.
     */
    private void exec(IndexAccumulation indexAccumulation) {
        if ((indexAccumulation.x() instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(indexAccumulation);
            return;
        }
        INDArray x = indexAccumulation.x();
        PointerPointer dummy = new PointerPointer(new Pointer[]{null});
        int opNum = indexAccumulation.opNum();
        Pointer xData = x.data().addressPointer();
        Pointer xShapeInfo = x.shapeInfoDataBuffer().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(indexAccumulation);
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            indexAccumulation.setFinalResult((int) this.loop.execIndexReduceScalarDouble(dummy, opNum, xData, xShapeInfo, extraArgs));
        } else {
            indexAccumulation.setFinalResult((int) this.loop.execIndexReduceScalarFloat(dummy, opNum, xData, xShapeInfo, extraArgs));
        }
    }

    /**
     * Executes a reduction over the whole array (summary stats for {@link Variance}, reduce3 when
     * {@code y} is set, plain reduce otherwise) and stores the result via {@code setFinalResult}.
     * Complex inputs and JAVA mode are delegated to the superclass.
     */
    private void exec(Accumulation accumulation) {
        if ((accumulation.x() instanceof IComplexNDArray) || executionMode() == OpExecutioner.ExecutionMode.JAVA) {
            super.exec(accumulation);
            return;
        }
        INDArray x = accumulation.x();
        INDArray y = accumulation.y();
        PointerPointer dummy = new PointerPointer(new Pointer[]{null});
        int opNum = accumulation.opNum();
        Pointer xData = x.data().addressPointer();
        Pointer xShapeInfo = x.shapeInfoDataBuffer().addressPointer();
        Pointer extraArgs = getPointerForExtraArgs(accumulation);
        if (x.data().dataType() == DataBuffer.Type.DOUBLE) {
            if (accumulation instanceof Variance) {
                // Honor the op's bias-correction flag, as the float path does.
                boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
                accumulation.setFinalResult(Double.valueOf(this.loop.execSummaryStatsScalarDouble(dummy, opNum, xData, xShapeInfo,
                        extraArgs, biasCorrected)));
            } else if (y != null) {
                accumulation.setFinalResult(Double.valueOf(this.loop.execReduce3ScalarDouble(dummy, opNum, xData, xShapeInfo,
                        extraArgs, y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer())));
            } else {
                accumulation.setFinalResult(Double.valueOf(this.loop.execReduceScalarDouble(dummy, opNum, xData, xShapeInfo, extraArgs)));
            }
        } else if (accumulation instanceof Variance) {
            boolean biasCorrected = ((Variance) accumulation).isBiasCorrected();
            accumulation.setFinalResult(Float.valueOf(this.loop.execSummaryStatsScalarFloat(dummy, opNum, xData, xShapeInfo,
                    extraArgs, biasCorrected)));
        } else if (y != null) {
            accumulation.setFinalResult(Float.valueOf(this.loop.execReduce3ScalarFloat(dummy, opNum, xData, xShapeInfo,
                    extraArgs, y.data().addressPointer(), y.shapeInfoDataBuffer().addressPointer())));
        } else {
            accumulation.setFinalResult(Float.valueOf(this.loop.execReduceScalarFloat(dummy, opNum, xData, xShapeInfo, extraArgs)));
        }
    }

    /** Returns the TAD manager used by this executioner. */
    public CpuTADManager getTadManager() {
        return this.tadManager;
    }

    /**
     * Returns a normalized copy of reduction dimensions. Negative entries are offset by the rank of
     * {@code x`, and the result is sorted ascending. A request naming as many dimensions as
     * {@code x} has is collapsed to {@code {Integer.MAX_VALUE}}, which the callers then detect through
     * {@link Shape#wholeArrayDimension}. The caller's array is never modified.
     */
    private static int[] normalizeReductionDimensions(INDArray x, int[] dimension) {
        int rank = x.rank();
        int[] normalized = dimension.clone();
        for (int i = 0; i < normalized.length; i++) {
            if (normalized[i] < 0) {
                normalized[i] += rank;
            }
        }
        // Sort after normalizing so that negative inputs end up in the correct position.
        Arrays.sort(normalized);
        if (normalized.length == rank) {
            return new int[]{Integer.MAX_VALUE};
        }
        return normalized;
    }

    /**
     * Returns the shape of {@code x} with the reduced dimensions removed, or {@code {1, 1}} for a
     * whole-array reduction.
     */
    private static int[] shapeWithoutDimensions(INDArray x, int[] dimension) {
        return Shape.wholeArrayDimension(dimension) ? new int[]{1, 1} : ArrayUtil.removeIndex(x.shape(), dimension);
    }

    /**
     * Pads a rank-1 or rank-0 result shape to two dimensions; shapes of rank two or more are
     * returned unchanged.
     */
    private static int[] toMatrixShape(int[] shape, int[] dimension) {
        if (shape.length == 1) {
            // Reducing along dimension 0 leaves a row {1, n}; any other dimension leaves a column {n, 1}.
            return dimension[0] == 0 ? new int[]{1, shape[0]} : new int[]{shape[0], 1};
        }
        if (shape.length == 0) {
            return new int[]{1, 1};
        }
        return shape;
    }

    /**
     * Builds the two-element pointer array of TAD buffers for {@code x} along {@code dimension}.
     * The second slot is {@code null} when the TAD manager supplies no second buffer.
     */
    private PointerPointer buildTadPointers(INDArray x, int[] dimension) {
        Pair<DataBuffer, DataBuffer> tad = this.tadManager.getTADOnlyShapeInfo(x, dimension);
        DataBuffer tadShapeInfo = tad.getFirst();
        // NOTE(review): presumably the TAD offsets buffer — confirm in CpuTADManager.
        DataBuffer tadSecond = tad.getSecond();
        return new PointerPointer(new Pointer[]{
                tadShapeInfo.addressPointer(),
                tadSecond == null ? null : tadSecond.addressPointer()});
    }
}
