package org.nd4j.linalg.cpu.nativecpu.ops;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import lombok.NonNull;
import org.bytedeco.javacpp.BooleanPointer;
import org.bytedeco.javacpp.DoublePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.LongPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.PointerPointer;
import org.bytedeco.javacpp.ShortPointer;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.AtomicBoolean;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemcpyDirection;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ndarray.INDArrayStatistics;
import org.nd4j.linalg.api.ops.BaseReduceBoolOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.aggregates.Aggregate;
import org.nd4j.linalg.api.ops.aggregates.Batch;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpStatus;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
import org.nd4j.linalg.api.ops.impl.summarystats.Variance;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.api.ops.random.BaseRandomOp;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.api.shape.TadPack;
import org.nd4j.linalg.api.shape.options.ArrayOptionsHelper;
import org.nd4j.linalg.api.shape.options.ArrayType;
import org.nd4j.linalg.cache.ConstantHandler;
import org.nd4j.linalg.cache.TADManager;
import org.nd4j.linalg.cpu.nativecpu.CpuTADManager;
import org.nd4j.linalg.cpu.nativecpu.buffer.LongBuffer;
import org.nd4j.linalg.cpu.nativecpu.buffer.Utf8Buffer;
import org.nd4j.linalg.cpu.nativecpu.rng.CpuNativeRandom;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.exception.ND4JOpProfilerException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.LongPointerWrapper;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import org.nd4j.nativeblas.Nd4jCpu;
import org.nd4j.nativeblas.OpaqueConstantShapeBuffer;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import org.nd4j.nativeblas.OpaqueShapeList;
import org.nd4j.nativeblas.OpaqueTadPack;
import org.nd4j.nativeblas.OpaqueVariable;
import org.nd4j.nativeblas.OpaqueVariablesSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.class */
public class NativeOpExecutioner extends DefaultOpExecutioner {
    // SLF4J logger for this class.
    private static final Logger log = LoggerFactory.getLogger(NativeOpExecutioner.class);
    // Native operations handle obtained from NativeOpsHolder; the exec* methods in this class
    // issue their native calls and read lastErrorCode()/lastErrorMessage() through it.
    private NativeOps loop = NativeOpsHolder.getInstance().getDeviceNativeOps();
    // Source of constant buffers (e.g. the INT dimension buffers requested in the exec methods).
    private ConstantHandler constantHandler = Nd4j.getConstantHandler();
    // TAD shape-info/offset provider; initialised with loop and constantHandler in the constructor.
    private CpuTADManager tadManager = new CpuTADManager();
    // Per-thread maps of native pointer holders, keyed by Integer.
    // NOTE(review): the meaning of the Integer key is not visible in this part of the file —
    // presumably an argument count; confirm at the usage sites.
    private ThreadLocal<Map<Integer, PointerPointer>> inputShapes = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, PointerPointer>> inputBuffers = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, PointerPointer>> outputShapes = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, PointerPointer>> outputBuffers = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, LongPointer>> iArgsPointer = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, DoublePointer>> tArgsPointer = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, BooleanPointer>> bArgsPointer = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, ShortPointer>> halfArgsPointer = new ThreadLocal<>();
    // Custom op descriptors by name; starts out null (not populated in the visible code).
    protected Map<String, CustomOpDescriptor> customOps = null;
    // Per-thread "extra pointers" holder; exec methods lazily create it with capacity 32.
    protected ThreadLocal<PointerPointer> extraz = new ThreadLocal<>();
    // Set in the constructor from loop.isExperimentalEnabled().
    protected AtomicBoolean experimentalMode = new AtomicBoolean(false);
    // Entries parsed from the ND4J_MKL_FALLBACK environment variable in the constructor
    // (lower-cased, comma-separated tokens mapped to true).
    protected Map<String, Boolean> mklOverrides = new HashMap();
    // Per-thread caches keyed by Integer for batch/aggregate execution.
    // NOTE(review): key semantics are not visible here — presumably an op number; confirm.
    private ThreadLocal<Map<Integer, Pointer>> batchPointers = new ThreadLocal<>();
    private ThreadLocal<Map<Integer, AggregateMemoryBlock>> memoryBlocks = new ThreadLocal<>();

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner$1, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner$1.class */
    /**
     * Synthetic lookup tables used by the {@code switch} statements in this file.
     *
     * <p>Each array is indexed by an enum constant's {@code ordinal()} and holds the dense
     * case index used in the corresponding {@code switch}. Entries that are never assigned
     * stay at 0, which matches no {@code case} label and therefore selects {@code default}.
     *
     * <p>Every assignment is wrapped in its own {@code try/catch (NoSuchFieldError)} so that a
     * constant missing from the enum at run time only leaves its own entry unset instead of
     * failing class initialisation.
     */
    public static /* synthetic */ class AnonymousClass1 {
        // Op.Type ordinal -> case index (1..14); allocated and filled in the static initialiser.
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$ops$Op$Type;
        // DataType ordinal -> case index (FLOAT = 1, DOUBLE = 2).
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$buffer$DataType = new int[DataType.values().length];

        static {
            // DataType table.
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$buffer$DataType[DataType.DOUBLE.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            // Op.Type table: reduce types are 1-4, scalar types 5-6, transform types 7-11,
            // pairwise bool 12, broadcast types 13-14.
            $SwitchMap$org$nd4j$linalg$api$ops$Op$Type = new int[Op.Type.values().length];
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_FLOAT.ordinal()] = 1;
            } catch (NoSuchFieldError e3) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_BOOL.ordinal()] = 2;
            } catch (NoSuchFieldError e4) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_SAME.ordinal()] = 3;
            } catch (NoSuchFieldError e5) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.REDUCE_LONG.ordinal()] = 4;
            } catch (NoSuchFieldError e6) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.SCALAR.ordinal()] = 5;
            } catch (NoSuchFieldError e7) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.SCALAR_BOOL.ordinal()] = 6;
            } catch (NoSuchFieldError e8) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_ANY.ordinal()] = 7;
            } catch (NoSuchFieldError e9) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_FLOAT.ordinal()] = 8;
            } catch (NoSuchFieldError e10) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_STRICT.ordinal()] = 9;
            } catch (NoSuchFieldError e11) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_SAME.ordinal()] = 10;
            } catch (NoSuchFieldError e12) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.TRANSFORM_BOOL.ordinal()] = 11;
            } catch (NoSuchFieldError e13) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.PAIRWISE_BOOL.ordinal()] = 12;
            } catch (NoSuchFieldError e14) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.BROADCAST.ordinal()] = 13;
            } catch (NoSuchFieldError e15) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$ops$Op$Type[Op.Type.BROADCAST_BOOL.ordinal()] = 14;
            } catch (NoSuchFieldError e16) {
            }
        }
    }

    /* loaded from: input_file:org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner$AggregateMemoryBlock.class */
    /**
     * Bundle of native scratch pointers sized from an {@link Aggregate}'s declared maxima.
     *
     * <p>Two blocks are considered equal when they were created for the same
     * {@code opNum}; the pointer contents do not take part in equality.
     */
    private static class AggregateMemoryBlock {
        private List<IntPointer> intArrays;
        private IntPointer indexingPointer;
        private Pointer realArgumentsPointer;
        private PointerPointer shapesPointer;
        private PointerPointer argumentsPointer;
        private PointerPointer arraysPointer;
        private final int opNum;

        /**
         * Allocates every pointer using the limits reported by {@code aggregate}.
         *
         * @param aggregate the aggregate op that provides {@code opNum} and the size limits
         * @throws NullPointerException if {@code aggregate} is null
         */
        private AggregateMemoryBlock(@NonNull Aggregate aggregate) {
            // Validate before touching any state so a null argument allocates nothing.
            if (aggregate == null) {
                throw new NullPointerException("op is marked non-null but is null");
            }
            this.opNum = aggregate.opNum();

            // One IntPointer per int-array slot, each with the maximum declared size.
            final int intArrayCount = aggregate.maxIntArrays();
            this.intArrays = new ArrayList<>();
            for (int i = 0; i < intArrayCount; i++) {
                this.intArrays.add(new IntPointer(aggregate.maxIntArraySize()));
            }

            this.indexingPointer = new IntPointer(aggregate.maxIndexArguments());
            // Real-valued arguments use the globally configured data type: double or float.
            this.realArgumentsPointer = Nd4j.dataType() == DataType.DOUBLE ? new DoublePointer(aggregate.maxRealArguments()) : new FloatPointer(aggregate.maxRealArguments());
            this.shapesPointer = new PointerPointer(aggregate.maxShapes());
            this.argumentsPointer = new PointerPointer(aggregate.maxArguments());
            this.arraysPointer = new PointerPointer(aggregate.maxIntArrays());
        }

        /** Equality is based on the runtime class and {@code opNum} only. */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            return obj != null && getClass() == obj.getClass() && this.opNum == ((AggregateMemoryBlock) obj).opNum;
        }

        /** Consistent with {@link #equals(Object)}: the hash is {@code opNum}. */
        @Override
        public int hashCode() {
            return this.opNum;
        }

        public List<IntPointer> getIntArrays() {
            return this.intArrays;
        }

        public IntPointer getIndexingPointer() {
            return this.indexingPointer;
        }

        public Pointer getRealArgumentsPointer() {
            return this.realArgumentsPointer;
        }

        public PointerPointer getShapesPointer() {
            return this.shapesPointer;
        }

        public PointerPointer getArgumentsPointer() {
            return this.argumentsPointer;
        }

        public PointerPointer getArraysPointer() {
            return this.arraysPointer;
        }

        public int getOpNum() {
            return this.opNum;
        }

        public void setIntArrays(List<IntPointer> list) {
            this.intArrays = list;
        }

        public void setIndexingPointer(IntPointer intPointer) {
            this.indexingPointer = intPointer;
        }

        public void setRealArgumentsPointer(Pointer pointer) {
            this.realArgumentsPointer = pointer;
        }

        public void setShapesPointer(PointerPointer pointerPointer) {
            this.shapesPointer = pointerPointer;
        }

        public void setArgumentsPointer(PointerPointer pointerPointer) {
            this.argumentsPointer = pointerPointer;
        }

        public void setArraysPointer(PointerPointer pointerPointer) {
            this.arraysPointer = pointerPointer;
        }

        @Override
        public String toString() {
            return "NativeOpExecutioner.AggregateMemoryBlock(intArrays=" + getIntArrays() + ", indexingPointer=" + getIndexingPointer() + ", realArgumentsPointer=" + getRealArgumentsPointer() + ", shapesPointer=" + getShapesPointer() + ", argumentsPointer=" + getArgumentsPointer() + ", arraysPointer=" + getArraysPointer() + ", opNum=" + getOpNum() + ")";
        }

        /**
         * Synthetic accessor constructor that lets the enclosing class reach the private
         * constructor; the second argument is ignored.
         */
        /* synthetic */ AggregateMemoryBlock(Aggregate aggregate, AnonymousClass1 anonymousClass1) {
            this(aggregate);
        }
    }

    /**
     * Creates the executioner: initialises the TAD manager, mirrors the native
     * "experimental" flag, and applies the {@code ND4J_MKL_FALLBACK} environment variable.
     *
     * <p>{@code ND4J_MKL_FALLBACK} handling:
     * <ul>
     *   <li>unset: nothing to do;</li>
     *   <li>{@code "true"} (case-insensitive): MKL-DNN is switched off globally;</li>
     *   <li>anything else: the value is lower-cased, split on commas, and each token is
     *       recorded in {@code mklOverrides} with the value {@code true}.</li>
     * </ul>
     */
    public NativeOpExecutioner() {
        this.tadManager.init(this.loop, this.constantHandler);
        this.experimentalMode.set(this.loop.isExperimentalEnabled());

        final String mklFallback = System.getenv("ND4J_MKL_FALLBACK");
        if (mklFallback == null) {
            return;
        }
        if (mklFallback.equalsIgnoreCase("true")) {
            Nd4jCpu.Environment.getInstance().setUseMKLDNN(false);
            return;
        }
        // Locale.ROOT keeps the lower-casing independent of the JVM default locale
        // (e.g. Turkish dotless-i), so the resulting keys are stable across machines.
        for (String opName : mklFallback.toLowerCase(java.util.Locale.ROOT).split(",")) {
            // Boolean.TRUE replaces the deprecated new Boolean(true) and avoids an allocation.
            this.mklOverrides.put(opName, Boolean.TRUE);
        }
    }

    /**
     * Executes {@code op} without an {@link OpContext}.
     *
     * @param op the op to run
     * @return whatever {@link #exec(Op, OpContext)} returns for a null context
     */
    public INDArray exec(Op op) {
        final OpContext noContext = null;
        return exec(op, noContext);
    }

    /**
     * Dispatches a generic {@link Op} to the typed {@code exec} overload matching its
     * runtime type (scalar, transform, reduce, index-accumulation, broadcast, random).
     * Ops of any other type are not executed.
     *
     * @param op        the op to run
     * @param opContext optional context holding the op's arrays; may be null
     * @return the op's output as resolved by {@code getZ(op, opContext)}
     */
    public INDArray exec(Op op, OpContext opContext) {
        checkForCompression(op);
        if (op instanceof ScalarOp) {
            exec((ScalarOp) op, opContext);
        } else if (op instanceof TransformOp) {
            exec((TransformOp) op, opContext);
        } else if (op instanceof ReduceOp) {
            exec((ReduceOp) op, opContext);
        } else if (op instanceof IndexAccumulation) {
            exec((IndexAccumulation) op, opContext);
        } else if (op instanceof BroadcastOp) {
            exec((BroadcastOp) op, opContext);
        } else if (op instanceof RandomOp) {
            exec((RandomOp) op, opContext, Nd4j.getRandom());
        }
        // Resolve the result the same way the typed overloads do. Returning op.z() directly
        // ignored the context, even though the typed overloads store and read the output
        // through setZ/getZ with that context.
        return getZ(op, opContext);
    }

    /**
     * Runs an index-accumulation op without an {@link OpContext}.
     *
     * @param indexAccumulation the op to run
     * @return the result of {@link #exec(IndexAccumulation, OpContext)} with a null context
     */
    public INDArray exec(IndexAccumulation indexAccumulation) {
        final OpContext noContext = null;
        return exec(indexAccumulation, noContext);
    }

    /**
     * Runs an index-accumulation op through the native backend.
     *
     * <p>When no output array is supplied, or the output is the same object as the input,
     * an uninitialised LONG array with the reduction shape is created and installed as the
     * output. Otherwise the supplied output must already have exactly the reduction shape.
     *
     * @param indexAccumulation the op to run
     * @param opContext         optional context holding the op's arrays; may be null
     * @return the output array as resolved by {@code getZ} after execution
     * @throws IllegalStateException if a supplied output array has the wrong shape
     * @throws RuntimeException      if the native layer reports a non-zero error code
     */
    public INDArray exec(IndexAccumulation indexAccumulation, OpContext opContext) {
        checkForCompression(indexAccumulation);
        final INDArray input = getX(indexAccumulation, opContext);
        INDArray output = getZ(indexAccumulation, opContext);

        // Lazily create this thread's extra-pointer holder.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        final int[] axes = Shape.normalizeAxis(input.rank(), indexAccumulation.dimensions().toIntVector());
        if (input.isEmpty()) {
            // An empty input cannot be index-reduced along an axis of length 0.
            for (int k = 0; k < axes.length; k++) {
                Preconditions.checkArgument(input.shape()[axes[k]] != 0, "IndexReduce can't be issued along axis with 0 in shape");
            }
        }

        final long[] expectedShape = Shape.reductionShape(input, axes, true, indexAccumulation.isKeepDims());
        final boolean allocateOutput = output == null || input == output;
        if (allocateOutput) {
            output = Nd4j.createUninitialized(DataType.LONG, expectedShape);
            setZ(output, indexAccumulation, opContext);
        } else if (!Arrays.equals(expectedShape, output.shape())) {
            throw new IllegalStateException("Z array shape does not match expected return type for op " + indexAccumulation + ": expected shape " + Arrays.toString(expectedShape) + ", z.shape()=" + Arrays.toString(output.shape()));
        }

        indexAccumulation.validateDataTypes();
        // The returned pointer is not used; the call is kept as in the original sequence.
        this.constantHandler.getConstantBuffer(axes, DataType.INT).addressPointer();

        final Pair<DataBuffer, DataBuffer> tadInfo = this.tadManager.getTADOnlyShapeInfo(input, axes);
        final Pointer tadShapeInfo = ((DataBuffer) tadInfo.getFirst()).addressPointer();
        final DataBuffer tadOffsets = (DataBuffer) tadInfo.getSecond();
        Pointer tadOffsetsPointer = null;
        if (tadOffsets != null) {
            tadOffsetsPointer = tadOffsets.addressPointer();
        }
        final PointerPointer extras = this.extraz.get().put(new Pointer[]{tadShapeInfo, tadOffsetsPointer});

        final long profilingStart = profilingConfigurableHookIn(indexAccumulation, new DataBuffer[]{(DataBuffer) tadInfo.getFirst()});
        final OpaqueDataBuffer inputBuffer = input.data().getOpaqueDataBuffer();
        final OpaqueDataBuffer outputBuffer = output.data().getOpaqueDataBuffer();

        if (!output.isScalar()) {
            this.loop.execIndexReduce(
                    extras,
                    indexAccumulation.opNum(),
                    inputBuffer,
                    input.shapeInfoDataBuffer().addressPointer(),
                    (LongPointer) null,
                    getPointerForExtraArgs(indexAccumulation, input.dataType()),
                    outputBuffer,
                    output.shapeInfoDataBuffer().addressPointer(),
                    (LongPointer) null,
                    indexAccumulation.dimensions().data().getOpaqueDataBuffer(),
                    indexAccumulation.dimensions().shapeInfoDataBuffer().addressPointer(),
                    (LongPointer) null);
        } else {
            this.loop.execIndexReduceScalar(
                    extras,
                    indexAccumulation.opNum(),
                    inputBuffer,
                    input.shapeInfoDataBuffer().addressPointer(),
                    (LongPointer) null,
                    getPointerForExtraArgs(indexAccumulation, input.dataType()),
                    outputBuffer,
                    output.shapeInfoDataBuffer().addressPointer(),
                    (LongPointer) null);
        }

        final boolean nativeFailed = this.loop.lastErrorCode() != 0;
        if (nativeFailed) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
        profilingConfigurableHookOut(indexAccumulation, opContext, profilingStart);
        return getZ(indexAccumulation, opContext);
    }

    /**
     * Runs a {@link Variance} op by treating it as a plain {@link ReduceOp}.
     *
     * @param variance the variance op to run
     * @return the result of {@link #exec(ReduceOp)}
     */
    public INDArray exec(Variance variance) {
        final ReduceOp asReduceOp = (ReduceOp) variance;
        return exec(asReduceOp);
    }

    /**
     * Runs a reduce op without an {@link OpContext}.
     *
     * @param reduceOp the op to run
     * @return the result of {@link #exec(ReduceOp, OpContext)} with a null context
     */
    public INDArray exec(ReduceOp reduceOp) {
        final OpContext noContext = null;
        return exec(reduceOp, noContext);
    }

    /**
     * Runs a reduction (plain reduce, summary statistics, or reduce3) through the native
     * backend.
     *
     * <p>Early exits, in order:
     * <ol>
     *   <li>"empty reduce" on a {@link BaseReduceOp}: the input is copied to the output;</li>
     *   <li>{@link BaseReduceBoolOp} on an empty input with no axis, or the single axis
     *       {@code Integer.MAX_VALUE}: the output is set to the op's {@code emptyValue()};</li>
     *   <li>vector input whose reduction shape has the same length (&gt; 1) and no {@code y}:
     *       {@code reduceOp.noOp()} is returned.</li>
     * </ol>
     * Otherwise the output is created if missing (or if it aliases the input), or validated
     * if supplied, and the matching native routine is invoked.
     *
     * @param reduceOp  the op to run; its {@code x} must not be null
     * @param opContext optional context holding the op's arrays; may be null
     * @return the output array; on the early-exit paths this is the array that was just
     *         installed via {@code setZ} when no output had been supplied
     * @throws ND4JIllegalStateException     if x/y TADs or the supplied output do not match
     * @throws UnsupportedOperationException if the op type is not a supported reduce type
     * @throws RuntimeException              if the native layer reports a non-zero error code
     */
    public INDArray exec(ReduceOp reduceOp, OpContext opContext) {
        INDArray x = getX(reduceOp, opContext);
        INDArray y = getY(reduceOp, opContext);
        INDArray z = getZ(reduceOp, opContext);
        Preconditions.checkNotNull(x, "Op.x() cannot be null: Was null for op %s", reduceOp);
        reduceOp.validateDataTypes(opContext);

        if ((reduceOp instanceof BaseReduceOp) && ((BaseReduceOp) reduceOp).isEmptyReduce()) {
            if (z == null) {
                setZ(x.dup(), reduceOp, opContext);
                // Re-read the output: the local z is still null here, and returning it
                // handed callers a null result.
                return getZ(reduceOp, opContext);
            }
            Preconditions.checkState(x.equalShapes(z), "For empty reductions, result (z) array must have same shape as x shape. Got: x=%ndShape, z=%ndShape", x, z);
            z.assign(x);
            return z;
        }

        int[] dimension = Shape.normalizeAxis(x.rank(), reduceOp.dimensions() != null ? reduceOp.dimensions().toIntVector() : null);
        if ((reduceOp instanceof BaseReduceBoolOp) && x.isEmpty() && (dimension == null || (dimension.length == 1 && dimension[0] == Integer.MAX_VALUE))) {
            if (z == null) {
                setZ(Nd4j.scalar(((BaseReduceBoolOp) reduceOp).emptyValue()), reduceOp, opContext);
                // Same stale-local problem as above: return the freshly installed scalar.
                return getZ(reduceOp, opContext);
            }
            z.assign(((BaseReduceBoolOp) reduceOp).emptyValue());
            return z;
        }

        // Lazily create this thread's extra-pointer holder.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        long[] retShape = Shape.reductionShape(x, dimension, true, reduceOp.isKeepDims());
        if (x.isVector() && x.length() == ArrayUtil.prod(retShape) && ArrayUtil.prodLong(retShape) > 1 && y == null) {
            return reduceOp.noOp();
        }

        // After this if/else, "result" and "z" refer to the same array.
        INDArray result;
        if (z == null || z == x) {
            if (reduceOp.isComplexAccumulation()) {
                // All-pairs accumulation: one output cell per (x TAD, y TAD) pair.
                result = Nd4j.create(reduceOp.resultType(), new long[]{x.tensorsAlongDimension(dimension), y.tensorsAlongDimension(dimension)});
            } else {
                if (y != null) {
                    if (x.length() != y.length()) {
                        long xTadLength = x.length() / x.tensorsAlongDimension(dimension);
                        if (xTadLength != y.length()) {
                            throw new ND4JIllegalStateException("Size of TADs along dimension don't match for pairwise execution: (x TAD size = " + xTadLength + ", y size = " + y.length());
                        }
                    } else if (x.tensorsAlongDimension(dimension) != y.tensorsAlongDimension(dimension)) {
                        throw new ND4JIllegalStateException("Number of TADs along dimension don't match: (x shape = " + Arrays.toString(x.shape()) + ", y shape = " + Arrays.toString(y.shape()) + ", dimension = " + Arrays.toString(dimension) + ")");
                    }
                }
                result = Nd4j.create(opContext != null ? reduceOp.resultType(opContext) : reduceOp.resultType(), retShape);
            }
            setZ(result, reduceOp, opContext);
            z = result;
        } else {
            long expectedLength = retShape.length == 0 ? 1L : ArrayUtil.prodLong(retShape);
            if (reduceOp.isComplexAccumulation() || z.length() == expectedLength) {
                if (reduceOp.isComplexAccumulation()) {
                    long xTads = x.tensorsAlongDimension(dimension);
                    long yTads = y.tensorsAlongDimension(dimension);
                    if (z.length() != xTads * yTads) {
                        throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(z.shape()) + "] doesn't match expected [" + (xTads * yTads) + "]");
                    }
                }
            } else if (!x.isEmpty() || !reduceOp.isKeepDims()) {
                throw new ND4JIllegalStateException("Shape of target array for reduction [" + Arrays.toString(z.shape()) + "] doesn't match expected [" + Arrays.toString(retShape) + "]");
            }
            result = z;
        }

        Pair<DataBuffer, DataBuffer> tadBuffers = x.isEmpty() ? Pair.makePair(x.data(), (Object) null) : this.tadManager.getTADOnlyShapeInfo(x, dimension);
        Pair<DataBuffer, DataBuffer> yTadBuffers = null;
        // NOTE(review): the two host pointers below and the condition that follows are not
        // consumed by any native call in this method. They are evaluated exactly as before
        // because addressPointer()/tensorAlongDimension() may validate their inputs;
        // confirm that they are side-effect free before removing them.
        Pointer hostTadShapeInfo = x.isEmpty() ? x.shapeInfoDataBuffer().addressPointer() : ((DataBuffer) tadBuffers.getFirst()).addressPointer();
        DataBuffer hostTadOffsetsBuffer = x.isEmpty() ? null : (DataBuffer) tadBuffers.getSecond();
        Pointer hostTadOffsets = hostTadOffsetsBuffer == null ? null : hostTadOffsetsBuffer.addressPointer();
        if (y == null || x.tensorAlongDimension(0L, dimension).length() == y.length()) {
        }
        if (reduceOp.isComplexAccumulation()) {
            yTadBuffers = this.tadManager.getTADOnlyShapeInfo(y, dimension);
            if (x.tensorAlongDimension(0L, dimension).length() != y.tensorAlongDimension(0L, dimension).length()) {
                throw new ND4JIllegalStateException("Impossible to issue AllDistances operation: TAD lengths mismatch along given dimension: x TAD length = " + x.tensorAlongDimension(0L, dimension).length() + ", y TAD length " + y.tensorAlongDimension(0L, dimension).length());
            }
        }
        profilingConfigurableHookIn(reduceOp, new DataBuffer[]{(DataBuffer) tadBuffers.getFirst()});
        // The returned pointer is not used; the call is kept as in the original sequence.
        this.constantHandler.getConstantBuffer(dimension, DataType.INT).addressPointer();
        OpaqueDataBuffer xBuffer = x.data().getOpaqueDataBuffer();
        OpaqueDataBuffer zBuffer = z.data().getOpaqueDataBuffer();

        if (reduceOp instanceof Variance) {
            // Summary statistics (variance / standard deviation).
            if (result.isScalar()) {
                this.loop.execSummaryStatsScalar((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, ((Variance) reduceOp).isBiasCorrected());
            } else {
                this.loop.execSummaryStatsTad((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, reduceOp.dimensions().data().getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, ((Variance) reduceOp).isBiasCorrected(), (LongPointer) null, (LongPointer) null);
            }
        } else if (y != null && reduceOp.getOpType() == Op.Type.REDUCE3) {
            // Two-operand reductions.
            OpaqueDataBuffer yBuffer = y.data().getOpaqueDataBuffer();
            if (reduceOp.isComplexAccumulation()) {
                this.loop.execReduce3All((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), yBuffer, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, reduceOp.dimensions().data().getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, ((DataBuffer) tadBuffers.getFirst()).addressPointer(), new LongPointerWrapper(((DataBuffer) tadBuffers.getSecond()).addressPointer()), ((DataBuffer) yTadBuffers.getFirst()).addressPointer(), new LongPointerWrapper(((DataBuffer) yTadBuffers.getSecond()).addressPointer()));
            } else if (result.isScalar()) {
                this.loop.execReduce3Scalar((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), yBuffer, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, result.shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
            } else {
                this.loop.execReduce3Tad((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), yBuffer, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, reduceOp.dimensions().data().getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, (LongPointer) null, (LongPointer) null, (LongPointer) null, (LongPointer) null);
            }
        } else if (result.isScalar()) {
            // Full reduction to a scalar.
            switch (reduceOp.getOpType()) {
                case REDUCE_FLOAT:
                    this.loop.execReduceFloat((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), zBuffer, result.shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case REDUCE_BOOL:
                    this.loop.execReduceBool((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, x.dataType()), zBuffer, result.shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case REDUCE_SAME:
                    this.loop.execReduceSame((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, x.dataType()), zBuffer, result.shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case REDUCE_LONG:
                    this.loop.execReduceLong((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, x.dataType()), zBuffer, result.shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported op used in reduce: " + reduceOp.getOpType());
            }
        } else {
            // Reduction along the given dimensions.
            switch (reduceOp.getOpType()) {
                case REDUCE_FLOAT:
                    this.loop.execReduceFloat2((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, reduceOp.dimensions().data().getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case REDUCE_BOOL:
                    this.loop.execReduceBool2((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, x.dataType()), zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, reduceOp.dimensions().data().getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case REDUCE_SAME:
                    this.loop.execReduceSame2((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, z.dataType()), zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, reduceOp.dimensions().data().getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                case REDUCE_LONG:
                    this.loop.execReduceLong2((PointerPointer) null, reduceOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(reduceOp, x.dataType()), zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, reduceOp.dimensions().data().getOpaqueDataBuffer(), reduceOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
                    break;
                default:
                    throw new UnsupportedOperationException("Unsupported op used in reduce: " + reduceOp.getOpType());
            }
        }

        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
        return getZ(reduceOp, opContext);
    }

    /**
     * Runs a scalar op along its configured dimensions without an {@link OpContext}.
     *
     * @param scalarOp the op to run
     */
    private void invokeScalarAlongDimension(ScalarOp scalarOp) {
        final OpContext noContext = null;
        invokeScalarAlongDimension(scalarOp, noContext);
    }

    /**
     * Applies a scalar op along the op's dimensions (one scalar from {@code y} per TAD)
     * through the native backend.
     *
     * <p>All arrays are resolved once through {@code getX/getY/getZ} and those same arrays
     * are used for the data buffers, the shape info, and the TAD info. Previously the data
     * buffers came from the resolved arrays while the shape and TAD info were read from
     * {@code scalarOp.x()/y()/z()}, so the two could refer to different arrays whenever an
     * {@link OpContext} was supplied.
     *
     * @param scalarOp  the op to run; only SCALAR and SCALAR_BOOL op types are supported
     * @param opContext optional context holding the op's arrays; may be null
     * @throws UnsupportedOperationException for any other op type
     * @throws RuntimeException              if the native layer reports a non-zero error code
     */
    private void invokeScalarAlongDimension(ScalarOp scalarOp, OpContext opContext) {
        INDArray x = getX(scalarOp, opContext);
        INDArray y = getY(scalarOp, opContext);
        INDArray z = getZ(scalarOp, opContext);
        int[] dimension = scalarOp.dimensions().toIntVector();

        // TAD shape info and offsets for the input ...
        Pair<DataBuffer, DataBuffer> xTadBuffers = this.tadManager.getTADOnlyShapeInfo(x, dimension);
        LongPointer xTadShapeInfo = (LongPointer) ((DataBuffer) xTadBuffers.getFirst()).addressPointer();
        LongPointer xTadOffsets = (LongPointer) ((DataBuffer) xTadBuffers.getSecond()).addressPointer();
        // ... and for the output, so results are written along the same dimensions.
        Pair<DataBuffer, DataBuffer> zTadBuffers = this.tadManager.getTADOnlyShapeInfo(z, dimension);
        LongPointer zTadShapeInfo = (LongPointer) ((DataBuffer) zTadBuffers.getFirst()).addressPointer();
        LongPointer zTadOffsets = (LongPointer) ((DataBuffer) zTadBuffers.getSecond()).addressPointer();

        // Lazily create this thread's extra-pointer holder.
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }

        OpaqueDataBuffer xBuffer = x.data().getOpaqueDataBuffer();
        OpaqueDataBuffer yBuffer = y.data().getOpaqueDataBuffer();
        OpaqueDataBuffer zBuffer = z.data().getOpaqueDataBuffer();

        switch (scalarOp.getOpType()) {
            case SCALAR:
                this.loop.execScalarTad((PointerPointer) null, scalarOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, yBuffer, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(scalarOp, z.dataType()), scalarOp.dimensions().data().getOpaqueDataBuffer(), scalarOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets);
                break;
            case SCALAR_BOOL:
                this.loop.execScalarBoolTad((PointerPointer) null, scalarOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, yBuffer, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(scalarOp, z.dataType()), scalarOp.dimensions().data().getOpaqueDataBuffer(), scalarOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, xTadShapeInfo, xTadOffsets, zTadShapeInfo, zTadOffsets);
                break;
            default:
                throw new UnsupportedOperationException("Unsupported scalar op type: " + scalarOp.getOpType());
        }

        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
    }

    /**
     * Executes a scalar op without an {@link OpContext}.
     *
     * @param scalarOp op to execute
     * @return the result of {@link #exec(ScalarOp, OpContext)} invoked with a {@code null} context
     */
    public INDArray exec(ScalarOp scalarOp) {
        final OpContext noContext = null;
        return exec(scalarOp, noContext);
    }

    /**
     * Executes a scalar op, allocating the output array first when none is available.
     *
     * <p>When {@code scalarOp.dimensions()} is non-null the work is delegated to
     * {@link #invokeScalarAlongDimension(ScalarOp, OpContext)}. The context is now forwarded on
     * that path; previously it was dropped, so arrays held only by the context (including a z
     * allocated just above) were not seen.
     *
     * @param scalarOp  op to execute
     * @param opContext context passed to {@code getX/getZ/setZ}; may be {@code null}
     * @return the output array as resolved by {@code getZ(scalarOp, opContext)}
     * @throws ND4JIllegalStateException if the op type is not handled or x and z differ in length
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    public INDArray exec(ScalarOp scalarOp, OpContext opContext) {
        long hookStart = profilingConfigurableHookIn(scalarOp, new DataBuffer[0]);

        // Allocate z when neither the context nor the op provides one.
        if ((opContext != null && opContext.getOutputArray(0) == null) || getZ(scalarOp, opContext) == null) {
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[scalarOp.getOpType().ordinal()]) {
                case Nd4jCpu.FLOAT32 /* 5 */:
                    setZ(getX(scalarOp, opContext).ulike(), scalarOp, opContext);
                    break;
                case Nd4jCpu.DOUBLE /* 6 */:
                    setZ(Nd4j.createUninitialized(DataType.BOOL, getX(scalarOp, opContext).shape()), scalarOp, opContext);
                    break;
                default:
                    throw new ND4JIllegalStateException("Unknown op type: [" + scalarOp.getOpType() + "]");
            }
        }

        INDArray x = getX(scalarOp, opContext);
        INDArray z = getZ(scalarOp, opContext);
        if (x.length() != z.length()) {
            throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: x.length()=" + x.length() + ", z.length()=" + z.length() + " - x shape info = [" + Arrays.toString(x.shapeInfoDataBuffer().asInt()) + "], z shape info = [" + Arrays.toString(z.shapeInfoDataBuffer().asInt()) + "]");
        }

        if (scalarOp.dimensions() != null) {
            // Forward the context so the dimension path works on the same arrays as this method.
            // NOTE(review): this path returns without calling profilingConfigurableHookOut - confirm that is intended.
            invokeScalarAlongDimension(scalarOp, opContext);
            return getZ(scalarOp, opContext);
        }

        OpaqueDataBuffer xBuffer = x.data().getOpaqueDataBuffer();
        OpaqueDataBuffer scalarBuffer = scalarOp.scalar().data().getOpaqueDataBuffer();
        OpaqueDataBuffer zBuffer = z.data().getOpaqueDataBuffer();
        switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[scalarOp.getOpType().ordinal()]) {
            case Nd4jCpu.FLOAT32 /* 5 */:
                // Extra args are materialised in the data type of z here ...
                this.loop.execScalar((PointerPointer) null, scalarOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, scalarBuffer, scalarOp.scalar().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(scalarOp, z.dataType()));
                break;
            case Nd4jCpu.DOUBLE /* 6 */:
                // ... and in the data type of x for the bool variant.
                this.loop.execScalarBool((PointerPointer) null, scalarOp.opNum(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, scalarBuffer, scalarOp.scalar().shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(scalarOp, x.dataType()));
                break;
            default:
                throw new ND4JIllegalStateException("Unknown op type: [" + scalarOp.getOpType() + "]");
        }
        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
        profilingConfigurableHookOut(scalarOp, opContext, hookStart);
        return getZ(scalarOp, opContext);
    }

    /**
     * Returns the address of the op's extra-argument buffer for the given data type.
     *
     * @param op       op whose extra arguments are requested
     * @param dataType data type passed to {@code op.extraArgsDataBuff}
     * @return the buffer's address pointer, or {@code null} when the op has no extra args or
     *         no buffer is produced for them
     */
    private Pointer getPointerForExtraArgs(Op op, DataType dataType) {
        if (op.extraArgs() != null) {
            DataBuffer argsBuffer = op.extraArgsDataBuff(dataType);
            if (argsBuffer != null) {
                return argsBuffer.addressPointer();
            }
        }
        return null;
    }

    /**
     * Executes a transform op without an {@link OpContext}.
     *
     * @param transformOp op to execute
     */
    private void exec(TransformOp transformOp) {
        final OpContext noContext = null;
        exec(transformOp, noContext);
    }

    /**
     * Executes a transform op through the native layer.
     *
     * <p>When a y array is present the pairwise entry points are used; otherwise the unary
     * {@code execTransform*} entry point matching the op type is called. A missing z is allocated
     * with the shape of x before execution.
     *
     * <p>For "ismax" with extra args, TAD descriptors are stored in slots 0 and 1 of the
     * thread-local extras pointer. The TAD lookup now prefers the context-resolved z and only
     * falls back to {@code transformOp.z()} when that is {@code null}; previously it always read
     * {@code transformOp.z()}, ignoring an output array supplied through the context.
     *
     * @param transformOp op to execute
     * @param opContext   context passed to {@code getX/getY/getZ/setY/setZ}; may be {@code null}
     * @throws UnsupportedOperationException if there is no y and the op type is not a handled transform type
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    private void exec(TransformOp transformOp, OpContext opContext) {
        long profilingConfigurableHookIn;
        INDArray x = getX(transformOp, opContext);
        INDArray y = getY(transformOp, opContext);
        INDArray z = getZ(transformOp, opContext);
        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        PointerPointer pointerPointer = this.extraz.get();
        // A scalar y for op number 31 is expanded to the shape of x and stored via setY.
        // NOTE(review): the local y is not refreshed afterwards, so the native call below still
        // receives the original scalar y - confirm whether that is intended.
        if (transformOp.opNum() == 31 && y != null && y.isScalar()) {
            setY(Nd4j.valueArrayOf(x.shape(), y.getDouble(0L)), transformOp, opContext);
        }
        if (!transformOp.opName().equalsIgnoreCase("ismax") || transformOp.extraArgs() == null || transformOp.extraArgs().length <= 0) {
            profilingConfigurableHookIn = profilingConfigurableHookIn(transformOp, new DataBuffer[0]);
        } else {
            // extraArgs layout used here: [count, dim_0, ..., dim_{count-1}].
            int[] iArr = new int[((Integer) transformOp.extraArgs()[0]).intValue()];
            for (int i = 0; i < iArr.length; i++) {
                iArr[i] = ((Integer) transformOp.extraArgs()[i + 1]).intValue();
            }
            // Prefer the array this execution actually writes to; fall back to the op's own z.
            INDArray tadSource = z != null ? z : transformOp.z();
            Pair<DataBuffer, DataBuffer> tADOnlyShapeInfo = this.tadManager.getTADOnlyShapeInfo(tadSource, iArr);
            Pointer addressPointer = ((DataBuffer) tADOnlyShapeInfo.getFirst()).addressPointer();
            DataBuffer dataBuffer = (DataBuffer) tADOnlyShapeInfo.getSecond();
            Pointer addressPointer2 = dataBuffer == null ? null : dataBuffer.addressPointer();
            pointerPointer.put(0L, addressPointer);
            pointerPointer.put(1L, addressPointer2);
            profilingConfigurableHookIn = profilingConfigurableHookIn(transformOp, new DataBuffer[]{(DataBuffer) tADOnlyShapeInfo.getFirst()});
        }
        if (y != null) {
            if (z == null) {
                setZ(Nd4j.create(transformOp.resultType(), x.shape()), transformOp, opContext);
                z = getZ(transformOp, opContext);
            }
            transformOp.validateDataTypes(opContext, this.experimentalMode.get());
            OpaqueDataBuffer opaqueDataBuffer = x.data().getOpaqueDataBuffer();
            OpaqueDataBuffer opaqueDataBuffer2 = y.data().getOpaqueDataBuffer();
            OpaqueDataBuffer opaqueDataBuffer3 = z.data().getOpaqueDataBuffer();
            // NOTE(review): this switch has no default, so an unhandled op type silently performs
            // no native call, unlike the unary branch below which throws - confirm.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[transformOp.getOpType().ordinal()]) {
                case Nd4jCpu.INT8 /* 7 */:
                case Nd4jCpu.INT16 /* 8 */:
                case Nd4jCpu.INT32 /* 9 */:
                case 10:
                    if (!this.experimentalMode.get()) {
                        Preconditions.checkArgument(x.dataType() == y.dataType() || y.dataType() == DataType.BOOL, "Op.X and Op.Y must have the same data type, but got %s vs. %s", x.dataType(), y.dataType());
                    }
                    this.loop.execPairwiseTransform(pointerPointer, transformOp.opNum(), opaqueDataBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer2, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer3, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(transformOp, z.dataType()));
                    break;
                case Nd4jCpu.UINT8 /* 11 */:
                case Nd4jCpu.UINT16 /* 12 */:
                    this.loop.execPairwiseTransformBool(pointerPointer, transformOp.opNum(), opaqueDataBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer2, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer3, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(transformOp, x.dataType()));
                    break;
            }
        } else {
            if (z == null) {
                setZ(Nd4j.createUninitialized(opContext != null ? transformOp.resultType(opContext) : transformOp.resultType(), x.shape()), transformOp, opContext);
                z = getZ(transformOp, opContext);
            }
            transformOp.validateDataTypes(opContext, this.experimentalMode.get());
            OpaqueDataBuffer opaqueDataBuffer4 = x.data().getOpaqueDataBuffer();
            OpaqueDataBuffer opaqueDataBuffer5 = z.data().getOpaqueDataBuffer();
            // Cases 7 and 11 build extra args in the data type of x; the others use that of z.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[transformOp.getOpType().ordinal()]) {
                case Nd4jCpu.INT8 /* 7 */:
                    this.loop.execTransformAny(pointerPointer, transformOp.opNum(), opaqueDataBuffer4, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer5, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(transformOp, x.dataType()));
                    break;
                case Nd4jCpu.INT16 /* 8 */:
                    this.loop.execTransformFloat(pointerPointer, transformOp.opNum(), opaqueDataBuffer4, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer5, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(transformOp, z.dataType()));
                    break;
                case Nd4jCpu.INT32 /* 9 */:
                    this.loop.execTransformStrict(pointerPointer, transformOp.opNum(), opaqueDataBuffer4, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer5, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(transformOp, z.dataType()));
                    break;
                case 10:
                    this.loop.execTransformSame(pointerPointer, transformOp.opNum(), opaqueDataBuffer4, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer5, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(transformOp, z.dataType()));
                    break;
                case Nd4jCpu.UINT8 /* 11 */:
                    this.loop.execTransformBool(pointerPointer, transformOp.opNum(), opaqueDataBuffer4, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, opaqueDataBuffer5, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, getPointerForExtraArgs(transformOp, x.dataType()));
                    break;
                default:
                    throw new UnsupportedOperationException("Unknown transform type: [" + transformOp.getOpType() + "]");
            }
        }
        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
        profilingConfigurableHookOut(transformOp, opContext, profilingConfigurableHookIn);
    }

    /**
     * Executes a broadcast op without an {@link OpContext}.
     *
     * @param broadcastOp op to execute
     * @return the result of {@link #exec(BroadcastOp, OpContext)} invoked with a {@code null} context
     */
    public INDArray exec(BroadcastOp broadcastOp) {
        final OpContext noContext = null;
        return exec(broadcastOp, noContext);
    }

    /**
     * Executes a broadcast op along {@code broadcastOp.dimensions()} via the native
     * {@code execBroadcast} / {@code execBroadcastBool} entry points.
     *
     * @param broadcastOp op to execute
     * @param opContext   context passed to {@code getX/getY/getZ}; may be {@code null}
     * @return the z array resolved at the start of the call
     * @throws UnsupportedOperationException if the op type maps to neither handled value
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    public INDArray exec(BroadcastOp broadcastOp, OpContext opContext) {
        final INDArray input = getX(broadcastOp, opContext);
        final INDArray operand = getY(broadcastOp, opContext);
        final INDArray result = getZ(broadcastOp, opContext);
        profilingConfigurableHookIn(broadcastOp, new DataBuffer[0]);
        broadcastOp.validateDataTypes(this.experimentalMode.get());
        final int[] axes = broadcastOp.dimensions().toIntVector();

        // TAD descriptors for the input, then for the output.
        final Pair<DataBuffer, DataBuffer> inputTad = this.tadManager.getTADOnlyShapeInfo(input, axes);
        final Pointer inputTadFirst = ((DataBuffer) inputTad.getFirst()).addressPointer();
        final Pointer inputTadSecond = ((DataBuffer) inputTad.getSecond()).addressPointer();
        final Pair<DataBuffer, DataBuffer> resultTad = this.tadManager.getTADOnlyShapeInfo(result, axes);
        final Pointer resultTadFirst = ((DataBuffer) resultTad.getFirst()).addressPointer();
        final Pointer resultTadSecond = ((DataBuffer) resultTad.getSecond()).addressPointer();

        if (this.extraz.get() == null) {
            this.extraz.set(new PointerPointer(32L));
        }
        // The four TAD pointers travel to the native call through the thread-local extras holder.
        final Pointer[] tadPointers = new Pointer[]{inputTadFirst, inputTadSecond, resultTadFirst, resultTadSecond};
        final PointerPointer extras = this.extraz.get().put(tadPointers);

        // Result intentionally unused; the call is kept exactly as in the original.
        this.constantHandler.getConstantBuffer(axes, DataType.INT).addressPointer();

        final OpaqueDataBuffer inputBuffer = input.data().getOpaqueDataBuffer();
        final OpaqueDataBuffer operandBuffer = operand.data().getOpaqueDataBuffer();
        final OpaqueDataBuffer resultBuffer = result.data().getOpaqueDataBuffer();

        final int typeIndex = AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$ops$Op$Type[broadcastOp.getOpType().ordinal()];
        if (typeIndex == Nd4jCpu.UINT32 /* 13 */) {
            this.loop.execBroadcast(extras, broadcastOp.opNum(), inputBuffer, input.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, operandBuffer, operand.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, resultBuffer, result.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, broadcastOp.dimensions().data().getOpaqueDataBuffer(), broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
        } else if (typeIndex == 14) {
            this.loop.execBroadcastBool(extras, broadcastOp.opNum(), inputBuffer, input.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, operandBuffer, operand.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, resultBuffer, result.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, (Pointer) null, broadcastOp.dimensions().data().getOpaqueDataBuffer(), broadcastOp.dimensions().shapeInfoDataBuffer().addressPointer(), (LongPointer) null);
        } else {
            throw new UnsupportedOperationException("Unknown operation type: [" + broadcastOp.getOpType() + "]");
        }

        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
        return result;
    }

    /**
     * Returns the per-thread scratch pointer for the batch's op number, creating and caching it
     * on first use.
     *
     * @param batch batch whose {@code opNum()} keys the cache and whose sample sizes the allocation
     * @param <T>   aggregate type held by the batch
     * @return the cached pointer, or a newly allocated {@link IntPointer} stored in the cache
     */
    protected <T extends Aggregate> Pointer getPointer(Batch<T> batch) {
        if (this.batchPointers.get() == null) {
            this.batchPointers.set(new HashMap());
        }
        final Integer opKey = Integer.valueOf(batch.opNum());
        if (!this.batchPointers.get().containsKey(opKey)) {
            // Element count is the sample's required batch memory size divided by 4.
            Pointer scratch = new IntPointer(batch.getSample().getRequiredBatchMemorySize() / 4);
            this.batchPointers.get().put(opKey, scratch);
            return scratch;
        }
        return this.batchPointers.get().get(opKey);
    }

    /**
     * Packs every aggregate of the batch into the per-op scratch pointer obtained from
     * {@link #getPointer(Batch)} and hands the packed memory to the native
     * {@code execAggregateBatch} call.
     *
     * <p>NOTE(review): the body references {@code r0}, which is not declared anywhere in this
     * method. It looks like a leftover from decompilation (possibly the per-aggregate header
     * offset {@code i * 5}); confirm against the upstream sources before relying on the offsets.
     *
     * @param batch batch of aggregates to execute
     * @param <T>   aggregate type held by the batch
     * @throws ND4JIllegalArgumentException if the first argument's data type hits the switch default
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    public <T extends Aggregate> void exec(Batch<T> batch) {
        IntPointer pointer = getPointer(batch);
        int maxIntArrays = batch.getSample().maxIntArrays();
        int maxIntArraySize = batch.getSample().maxIntArraySize();
        // Offsets below are derived from the batch limit and the sample's maximum argument counts;
        // the divisors depend on whether the global data type is DOUBLE.
        int batchLimit = (((((5 * Batch.getBatchLimit()) + (batch.getSample().maxIndexArguments() * Batch.getBatchLimit())) + ((maxIntArrays * maxIntArraySize) * Batch.getBatchLimit())) / (Nd4j.dataType() == DataType.DOUBLE ? 2 : 1)) + (batch.getSample().maxRealArguments() * Batch.getBatchLimit())) / (Nd4j.dataType() == DataType.DOUBLE ? 1 : 2);
        int maxArguments = batchLimit + (batch.getSample().maxArguments() * Batch.getBatchLimit());
        DataType dataType = null;
        for (int i = 0; i < batch.getNumAggregates(); i++) {
            Aggregate aggregate = (Aggregate) batch.getAggregates().get(i);
            // The data type of the first aggregate's first argument is used for the whole batch.
            if (i == 0) {
                dataType = ((INDArray) aggregate.getArguments().get(0)).dataType();
            }
            // Five-int header per aggregate: argument, shape, index, real and int-array counts.
            pointer.put(i * 5, aggregate.getArguments().size());
            pointer.put(r0 + 1, aggregate.getShapes().size());
            pointer.put(r0 + 2, aggregate.getIndexingArguments().size());
            pointer.put(r0 + 3, aggregate.getRealArguments().size());
            pointer.put(r0 + 4, aggregate.getIntArrayArguments().size());
            // Indexing arguments.
            for (int i2 = 0; i2 < aggregate.getIndexingArguments().size(); i2++) {
                pointer.put(r0 + (i * batch.getSample().maxIndexArguments()) + i2, ((Integer) aggregate.getIndexingArguments().get(i2)).intValue());
            }
            // Int-array arguments; null arrays are skipped.
            int i3 = maxIntArrays * maxIntArraySize;
            for (int i4 = 0; i4 < aggregate.getIntArrayArguments().size(); i4++) {
                int i5 = (i * i3) + (i4 * maxIntArraySize);
                if (aggregate.getIntArrayArguments().get(i4) != null) {
                    for (int i6 = 0; i6 < ((int[]) aggregate.getIntArrayArguments().get(i4)).length; i6++) {
                        pointer.put(r0 + i5 + i6, ((int[]) aggregate.getIntArrayArguments().get(i4))[i6]);
                    }
                }
            }
            // Real arguments are written through a float or double view of the same memory.
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$buffer$DataType[dataType.ordinal()]) {
                case 1:
                    FloatPointer floatPointer = new FloatPointer(pointer);
                    for (int i7 = 0; i7 < aggregate.getRealArguments().size(); i7++) {
                        floatPointer.put(r0 + (i * aggregate.maxRealArguments()) + i7, ((Number) aggregate.getRealArguments().get(i7)).floatValue());
                    }
                    break;
                case 2:
                    DoublePointer doublePointer = new DoublePointer(pointer);
                    for (int i8 = 0; i8 < aggregate.getRealArguments().size(); i8++) {
                        doublePointer.put(r0 + (i * aggregate.maxRealArguments()) + i8, ((Number) aggregate.getRealArguments().get(i8)).doubleValue());
                    }
                    break;
                default:
                    throw new ND4JIllegalArgumentException("Only FLOAT and DOUBLE datatypes are supported");
            }
            if (this.extraz.get() == null) {
                this.extraz.set(new PointerPointer(32L));
            }
            // Array data addresses and shape addresses are written through a pointer view of the
            // same memory; null entries are skipped.
            PointerPointer pointerPointer = new PointerPointer(pointer);
            for (int i9 = 0; i9 < aggregate.getArguments().size(); i9++) {
                int maxArguments2 = batchLimit + (i * batch.getSample().maxArguments());
                if (aggregate.getArguments().get(i9) != null) {
                    pointerPointer.put(maxArguments2 + i9, ((INDArray) aggregate.getArguments().get(i9)).data().addressPointer());
                }
            }
            for (int i10 = 0; i10 < aggregate.getShapes().size(); i10++) {
                int maxShapes = maxArguments + (i * batch.getSample().maxShapes());
                if (aggregate.getShapes().get(i10) != null) {
                    pointerPointer.put(maxShapes + i10, ((DataBuffer) aggregate.getShapes().get(i10)).addressPointer());
                }
            }
        }
        this.loop.execAggregateBatch((PointerPointer) null, batch.getNumAggregates(), batch.opNum(), batch.getSample().maxArguments(), batch.getSample().maxShapes(), batch.getSample().maxIntArrays(), batch.getSample().maxIntArraySize(), batch.getSample().maxIndexArguments(), batch.getSample().maxRealArguments(), pointer, FlatBuffersMapper.getDataTypeAsByte(dataType));
        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
    }

    /**
     * Splits the aggregates into batches via {@code Batch.getBatches} and executes each batch.
     * An empty list is a no-op.
     *
     * @param list aggregates to execute
     */
    public void exec(List<Aggregate> list) {
        if (list.isEmpty()) {
            return;
        }
        for (Object batch : Batch.getBatches(list)) {
            exec((Batch) batch);
        }
    }

    /**
     * Executes a single aggregate through the native {@code execAggregate} call, using a
     * per-thread, per-op {@code AggregateMemoryBlock} as staging memory.
     *
     * <p>Null shape buffers are passed to the native layer as null pointers. Previously the
     * data-type check dereferenced the shape before the null check, so a null shape caused a
     * {@link NullPointerException} instead.
     *
     * @param aggregate aggregate to execute; its first argument determines the data type
     * @throws RuntimeException if a non-null shape buffer is not of type LONG, or the native layer
     *         reports a non-zero error code
     * @throws ND4JIllegalArgumentException if real arguments are present and the data type hits
     *         the switch default
     */
    public void exec(Aggregate aggregate) {
        // Lazily create the thread-local cache and the memory block for this op number.
        if (this.memoryBlocks.get() == null) {
            this.memoryBlocks.set(new HashMap());
        }
        if (this.memoryBlocks.get().get(Integer.valueOf(aggregate.opNum())) == null) {
            this.memoryBlocks.get().put(Integer.valueOf(aggregate.opNum()), new AggregateMemoryBlock(aggregate, null));
        }
        AggregateMemoryBlock aggregateMemoryBlock = this.memoryBlocks.get().get(Integer.valueOf(aggregate.opNum()));
        int numArguments = aggregate.getArguments().size();
        int numIndexArguments = aggregate.getIndexingArguments().size();
        int numRealArguments = aggregate.getRealArguments().size();
        int numShapes = aggregate.getShapes().size();
        int numIntArrays = aggregate.getIntArrayArguments().size();
        PointerPointer argumentsPointer = aggregateMemoryBlock.getArgumentsPointer();
        PointerPointer arraysPointer = aggregateMemoryBlock.getArraysPointer();
        DataType dataType = ((INDArray) aggregate.getArguments().get(0)).dataType();

        // Array arguments: null entries become null pointers.
        for (int i = 0; i < numArguments; i++) {
            argumentsPointer.put(i, aggregate.getArguments().get(i) == null ? null : ((INDArray) aggregate.getArguments().get(i)).data().addressPointer());
        }

        // Shape buffers: validate the type only for non-null entries.
        PointerPointer shapesPointer = aggregateMemoryBlock.getShapesPointer();
        for (int i = 0; i < numShapes; i++) {
            DataBuffer shape = (DataBuffer) aggregate.getShapes().get(i);
            if (shape != null && shape.dataType() != DataType.LONG) {
                throw new RuntimeException("ShapeBuffers should have LONG data opType");
            }
            shapesPointer.put(i, shape == null ? null : shape.addressPointer());
        }

        IntPointer indexingPointer = aggregateMemoryBlock.getIndexingPointer();
        for (int i = 0; i < numIndexArguments; i++) {
            indexingPointer.put(i, ((Integer) aggregate.getIndexingArguments().get(i)).intValue());
        }

        // Real arguments are stored as float or double depending on the data type.
        for (int i = 0; i < numRealArguments; i++) {
            switch (AnonymousClass1.$SwitchMap$org$nd4j$linalg$api$buffer$DataType[dataType.ordinal()]) {
                case 1:
                    aggregateMemoryBlock.getRealArgumentsPointer().put(i, ((Number) aggregate.getRealArguments().get(i)).floatValue());
                    break;
                case 2:
                    aggregateMemoryBlock.getRealArgumentsPointer().put(i, ((Number) aggregate.getRealArguments().get(i)).doubleValue());
                    break;
                default:
                    throw new ND4JIllegalArgumentException("Only FLOAT and DOUBLE datatypes are supported");
            }
        }

        // Int arrays are copied into pointers owned by the memory block.
        for (int i = 0; i < numIntArrays; i++) {
            IntPointer intPointer = aggregateMemoryBlock.getIntArrays().get(i);
            intPointer.put((int[]) aggregate.getIntArrayArguments().get(i), 0, ((int[]) aggregate.getIntArrayArguments().get(i)).length);
            arraysPointer.put(i, intPointer);
        }

        this.loop.execAggregate((PointerPointer) null, aggregate.opNum(), argumentsPointer, numArguments, shapesPointer, numShapes, indexingPointer, numIndexArguments, arraysPointer, numIntArrays, aggregateMemoryBlock.getRealArgumentsPointer(), numRealArguments, FlatBuffersMapper.getDataTypeAsByte(dataType));
        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
    }

    /**
     * Extends the parent's environment properties with CPU-backend details: backend name, OpenMP
     * and BLAS thread counts, BLAS vendor, free off-heap memory and, when a performance tracker
     * exists, the current memory bandwidth.
     *
     * @return the properties object obtained from the superclass, with the extra entries added
     */
    public Properties getEnvironmentInformation() {
        final Properties info = super.getEnvironmentInformation();
        info.put("backend", "CPU");

        final int ompThreads = this.loop.ompGetMaxThreads();
        info.put("omp.threads", Integer.valueOf(ompThreads));

        final int blasThreads = Nd4j.factory().blas().getMaxThreads();
        info.put("blas.threads", Integer.valueOf(blasThreads));

        info.put("blas.vendor", Nd4j.factory().blas().getBlasVendor().toString());

        // Free memory = JavaCPP's maximum allowance minus what it currently tracks as allocated.
        final long freeBytes = Pointer.maxBytes() - Pointer.totalBytes();
        info.put("memory.free", Long.valueOf(freeBytes));

        if (PerformanceTracker.getInstance() != null) {
            info.put("memoryBandwidth", PerformanceTracker.getInstance().getCurrentBandwidth());
        }
        return info;
    }

    /**
     * Executes a random op using the generator returned by {@code Nd4j.getRandom()}.
     *
     * @param randomOp op to execute
     * @return the result of {@link #exec(RandomOp, Random)}
     */
    public INDArray exec(RandomOp randomOp) {
        final Random defaultRandom = Nd4j.getRandom();
        return exec(randomOp, defaultRandom);
    }

    /**
     * Executes a random op with the given generator and no {@link OpContext}.
     *
     * @param randomOp op to execute
     * @param random   random generator to use
     * @return the result of {@link #exec(RandomOp, OpContext, Random)} with a {@code null} context
     */
    public INDArray exec(RandomOp randomOp, Random random) {
        final OpContext noContext = null;
        return exec(randomOp, noContext, random);
    }

    /**
     * Executes a random op through the native layer, choosing the entry point from which arrays
     * are present: {@code execRandom3} when both x and y exist, {@code execRandom} when x is
     * absent, and {@code execRandom2} when only x exists.
     *
     * <p>For a {@code BaseRandomOp} flagged as a triple-argument op that has z but neither x nor
     * y, z is used for x and y as well.
     *
     * <p>The output array is mandatory. The original code dereferenced z before its own later
     * null checks, so a missing z produced a bare {@link NullPointerException}; it now fails with
     * a descriptive one and the unreachable null checks are gone.
     *
     * @param randomOp  op to execute
     * @param opContext context passed to {@code getX/getY/getZ}; may be {@code null}
     * @param random    generator; must be a {@code CpuNativeRandom}
     * @return the z array
     * @throws IllegalStateException if {@code random} is not a {@code CpuNativeRandom}
     * @throws NullPointerException if no z array is available
     * @throws RuntimeException if the native layer reports a non-zero error code
     */
    public INDArray exec(RandomOp randomOp, OpContext opContext, Random random) {
        INDArray x = getX(randomOp, opContext);
        INDArray y = getY(randomOp, opContext);
        INDArray z = getZ(randomOp, opContext);
        if ((randomOp instanceof BaseRandomOp) && ((BaseRandomOp) randomOp).isTripleArgRngOp() && z != null && x == null && y == null) {
            x = z;
            y = z;
        }
        if (!(random instanceof CpuNativeRandom)) {
            throw new IllegalStateException("You should use one of NativeRandom classes for NativeOperations execution. Op class: " + randomOp.getClass().getName());
        }
        long hookStart = profilingConfigurableHookIn(randomOp, new DataBuffer[0]);

        // Every branch below dereferences z, so reject a missing output up front.
        if (z == null) {
            throw new NullPointerException("Op.Z must not be null for random op: " + randomOp.getClass().getName());
        }
        Preconditions.checkArgument(z.isR(), "Op.Z must have one of floating point types");

        OpaqueDataBuffer xBuffer = x == null ? null : x.data().getOpaqueDataBuffer();
        OpaqueDataBuffer yBuffer = y == null ? null : y.data().getOpaqueDataBuffer();
        OpaqueDataBuffer zBuffer = z.data().getOpaqueDataBuffer();
        if (x != null && y != null) {
            DataBuffer extraArgs = randomOp.extraArgsDataBuff(z.dataType());
            this.loop.execRandom3((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, yBuffer, y.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, extraArgs != null ? extraArgs.addressPointer() : null);
        } else if (x == null) {
            // NOTE(review): unlike the other two branches, the extra-args buffer is not
            // null-checked here; left unchanged on purpose.
            this.loop.execRandom((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, randomOp.extraArgsDataBuff(z.dataType()).addressPointer());
        } else {
            DataBuffer extraArgs = randomOp.extraArgsDataBuff(z.dataType());
            this.loop.execRandom2((PointerPointer) null, randomOp.opNum(), random.getStatePointer(), xBuffer, x.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, zBuffer, z.shapeInfoDataBuffer().addressPointer(), (LongPointer) null, extraArgs != null ? extraArgs.addressPointer() : null);
        }
        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }
        profilingConfigurableHookOut(randomOp, opContext, hookStart);
        return z;
    }

    /**
     * Returns the TAD manager held by this executioner.
     *
     * @return the {@code tadManager} field
     */
    public TADManager getTADManager() {
        return tadManager;
    }

    /**
     * Returns the custom-op descriptors reported by the native layer, parsing and caching them
     * on the first call.
     *
     * <p>The native string is split on {@code ';'} into entries and each entry on {@code ':'} into
     * {@code name:hash:numInputs:numOutputs:allowsInplace:numTArgs:numIArgs}.
     *
     * @return an unmodifiable map from op name to descriptor; an empty map (with a warning logged)
     *         when the native layer reports nothing
     */
    public synchronized Map<String, CustomOpDescriptor> getCustomOperations() {
        if (this.customOps != null) {
            return this.customOps;
        }

        final String rawList = this.loop.getAllCustomOps();
        if (rawList == null || rawList.isEmpty()) {
            log.warn("No customs ops available!");
            this.customOps = Collections.emptyMap();
            return this.customOps;
        }

        final Map<String, CustomOpDescriptor> parsed = new HashMap<>();
        for (String entry : rawList.split(";")) {
            if (entry == null || entry.isEmpty()) {
                continue;
            }
            final String[] fields = entry.split(":");
            final String opName = fields[0];
            final CustomOpDescriptor descriptor = CustomOpDescriptor.builder()
                    .hash(Long.parseLong(fields[1]))
                    .numInputs(Integer.parseInt(fields[2]))
                    .numOutputs(Integer.parseInt(fields[3]))
                    .allowsInplace(Integer.parseInt(fields[4]) == 1)
                    .numTArgs(Integer.parseInt(fields[5]))
                    .numIArgs(Integer.parseInt(fields[6]))
                    .build();
            parsed.put(opName, descriptor);
        }
        this.customOps = Collections.unmodifiableMap(parsed);
        return this.customOps;
    }

    /**
     * Returns the {@link PointerPointer} cached under key {@code i} in the thread-local map,
     * creating the map and/or a pointer of size {@code i} when missing.
     *
     * @param threadLocal holder of the per-thread cache
     * @param i           requested size, also used as the cache key
     * @return the cached or newly created pointer
     */
    private PointerPointer getPointerPointerFrom(ThreadLocal<Map<Integer, PointerPointer>> threadLocal, int i) {
        Map<Integer, PointerPointer> cache = threadLocal.get();
        if (cache == null) {
            cache = new HashMap<>();
            threadLocal.set(cache);
        }
        PointerPointer cached = cache.get(i);
        if (cached == null) {
            cached = new PointerPointer(i);
            cache.put(i, cached);
        }
        return cached;
    }

    /**
     * Returns the {@link ShortPointer} cached under key {@code i} in the thread-local map,
     * creating the map and/or a pointer of size {@code i} when missing.
     *
     * @param threadLocal holder of the per-thread cache
     * @param i           requested size, also used as the cache key
     * @return the cached or newly created pointer
     */
    private ShortPointer getShortPointerFrom(ThreadLocal<Map<Integer, ShortPointer>> threadLocal, int i) {
        Map<Integer, ShortPointer> cache = threadLocal.get();
        if (cache == null) {
            cache = new HashMap<>();
            threadLocal.set(cache);
        }
        ShortPointer cached = cache.get(i);
        if (cached == null) {
            cached = new ShortPointer(i);
            cache.put(i, cached);
        }
        return cached;
    }

    /**
     * Returns the {@link LongPointer} cached under key {@code i} in the thread-local map,
     * creating the map and/or a pointer of size {@code i} when missing.
     *
     * @param threadLocal holder of the per-thread cache
     * @param i           requested size, also used as the cache key
     * @return the cached or newly created pointer
     */
    private LongPointer getLongPointerFrom(ThreadLocal<Map<Integer, LongPointer>> threadLocal, int i) {
        Map<Integer, LongPointer> cache = threadLocal.get();
        if (cache == null) {
            cache = new HashMap<>();
            threadLocal.set(cache);
        }
        LongPointer cached = cache.get(i);
        if (cached == null) {
            cached = new LongPointer(i);
            cache.put(i, cached);
        }
        return cached;
    }

    /**
     * Returns the {@link DoublePointer} cached under key {@code i} in the thread-local map,
     * creating the map and/or a pointer of size {@code i} when missing.
     *
     * @param threadLocal holder of the per-thread cache
     * @param i           requested size, also used as the cache key
     * @return the cached or newly created pointer
     */
    private DoublePointer getDoublePointerFrom(ThreadLocal<Map<Integer, DoublePointer>> threadLocal, int i) {
        Map<Integer, DoublePointer> cache = threadLocal.get();
        if (cache == null) {
            cache = new HashMap<>();
            threadLocal.set(cache);
        }
        DoublePointer cached = cache.get(i);
        if (cached == null) {
            cached = new DoublePointer(i);
            cache.put(i, cached);
        }
        return cached;
    }

    /**
     * Returns the {@link BooleanPointer} cached under key {@code i} in the thread-local map,
     * creating the map and/or a pointer of size {@code i} when missing.
     *
     * @param threadLocal holder of the per-thread cache
     * @param i           requested size, also used as the cache key
     * @return the cached or newly created pointer
     */
    private BooleanPointer getBooleanPointerFrom(ThreadLocal<Map<Integer, BooleanPointer>> threadLocal, int i) {
        Map<Integer, BooleanPointer> cache = threadLocal.get();
        if (cache == null) {
            cache = new HashMap<>();
            threadLocal.set(cache);
        }
        BooleanPointer cached = cache.get(i);
        if (cached == null) {
            cached = new BooleanPointer(i);
            cache.put(i, cached);
        }
        return cached;
    }

    /**
     * Returns the per-thread input-shapes pointer of size {@code i}.
     *
     * @param i requested size, also the cache key
     * @return pointer obtained from the {@code inputShapes} thread-local cache
     */
    private PointerPointer getInputShapes(int i) {
        return getPointerPointerFrom(inputShapes, i);
    }

    /**
     * Returns the per-thread input-buffers pointer of size {@code i}.
     *
     * @param i requested size, also the cache key
     * @return pointer obtained from the {@code inputBuffers} thread-local cache
     */
    private PointerPointer getInputBuffers(int i) {
        return getPointerPointerFrom(inputBuffers, i);
    }

    /**
     * Fetches the {@link PointerPointer} for {@code i} from the {@code outputShapes} holder
     * by delegating to {@code getPointerPointerFrom}.
     */
    private PointerPointer getOutputShapes(int i) {
        PointerPointer shapes = getPointerPointerFrom(outputShapes, i);
        return shapes;
    }

    /**
     * Fetches the {@link PointerPointer} for {@code i} from the {@code outputBuffers} holder
     * by delegating to {@code getPointerPointerFrom}.
     */
    private PointerPointer getOutputBuffers(int i) {
        PointerPointer buffers = getPointerPointerFrom(outputBuffers, i);
        return buffers;
    }

    /**
     * Executes a custom op through a freshly built {@link OpContext}.
     *
     * <p>When the op has no output arguments and is not an in-place call, output shapes are
     * calculated first and one new array per shape is added to the op as an output argument; the
     * context is then told to override its shape function. The context is populated from the op
     * (in-place flag, RNG states, input/output arrays, b/i/t/d arguments), executed, and always
     * closed afterwards, whether execution succeeds or fails. After execution the pointer and
     * indexer of every non-empty input and output buffer are refreshed, and the RNG states
     * reported by the context are written back to the global random generator.
     *
     * @param customOp op to execute; must not be null
     * @return the arrays returned by {@link #exec(CustomOp, OpContext)}
     * @throws NullPointerException      if {@code customOp} is null
     * @throws ND4JIllegalStateException if outputs are missing and their shapes cannot be calculated
     * @throws ND4JOpProfilerException   rethrown unchanged when raised during execution
     * @throws RuntimeException          wrapping any other exception raised during execution or
     *                                   while closing the context
     */
    public INDArray[] exec(@NonNull CustomOp customOp) {
        if (customOp == null) {
            throw new NullPointerException("op is marked non-null but is null");
        }

        // Set when the output arrays were allocated here from calculated shapes.
        boolean shapeOverride = false;
        if (customOp.numOutputArguments() == 0 && !customOp.isInplaceCall()) {
            try {
                List<LongShapeDescriptor> shapes = calculateOutputShape(customOp);
                if (shapes.isEmpty()) {
                    throw new ND4JIllegalStateException("Op name " + customOp.opName() + " failed to calculate output datatypes");
                }
                for (LongShapeDescriptor shape : shapes) {
                    customOp.addOutputArgument(Nd4j.create(shape, false));
                }
                shapeOverride = true;
            } catch (ND4JIllegalStateException e) {
                throw e;
            } catch (Exception e) {
                throw new ND4JIllegalStateException("Op name " + customOp.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e);
            }
        }

        String opName = customOp.opName();
        // try-with-resources guarantees the context is closed on every exit path.
        try (OpContext context = buildContext()) {
            if (shapeOverride) {
                context.shapeFunctionOverride(true);
            }
            context.markInplace(customOp.isInplaceCall());
            context.setRngStates(Nd4j.getRandom().rootState(), Nd4j.getRandom().nodeState());
            context.setInputArrays(customOp.inputArguments());
            context.setOutputArrays(customOp.outputArguments());
            context.setBArguments(customOp.bArgs());
            context.setIArguments(customOp.iArgs());
            context.setTArguments(customOp.tArgs());
            context.setDArguments(customOp.dArgs());

            INDArray[] result = exec(customOp, context);
            Pair<?, ?> rngStates = context.getRngStates();

            for (INDArray input : customOp.inputArguments()) {
                if (!input.isEmpty()) {
                    input.data().actualizePointerAndIndexer();
                }
            }
            for (INDArray output : customOp.outputArguments()) {
                if (!output.isEmpty()) {
                    output.data().actualizePointerAndIndexer();
                }
            }

            Nd4j.getRandom().setStates(((Long) rngStates.getFirst()).longValue(), ((Long) rngStates.getSecond()).longValue());
            return result;
        } catch (ND4JOpProfilerException e) {
            // Must precede the generic handler so profiler failures propagate unwrapped.
            throw e;
        } catch (Exception e) {
            throw new RuntimeException("Op [" + opName + "] execution failed", e);
        }
    }

    /**
     * Copies a shape-info buffer out of native memory and converts it to a
     * {@link LongShapeDescriptor}.
     *
     * <p>The number of elements copied is {@code element0 * 2 + 4}, where {@code element0} is the
     * first value in the buffer (presumably the rank - confirm against the native layout).
     *
     * @param longPointer pointer to the shape-info values
     * @return descriptor built from the shape, stride, element-wise stride, order, data type and
     *         empty flag decoded from the copied values
     */
    protected LongShapeDescriptor getShapeFromPointer(LongPointer longPointer) {
        final int leading = (int) longPointer.get(0L);
        final long[] shapeInfo = new long[leading * 2 + 4];

        int idx = 0;
        while (idx < shapeInfo.length) {
            shapeInfo[idx] = longPointer.get(idx);
            idx++;
        }

        return LongShapeDescriptor.fromShape(
                Shape.shape(shapeInfo),
                Shape.stride(shapeInfo),
                Shape.elementWiseStride(shapeInfo),
                Shape.order(shapeInfo),
                ArrayOptionsHelper.dataType(shapeInfo),
                ArrayOptionsHelper.arrayType(shapeInfo) == ArrayType.EMPTY);
    }

    /**
     * Calculates the output shapes of {@code customOp} from the arguments stored on the op itself,
     * i.e. delegates to the two-argument overload with a null context.
     *
     * @param customOp op whose output shapes are requested; must not be null
     * @return the shapes produced by the two-argument overload
     * @throws NullPointerException if {@code customOp} is null
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp customOp) {
        if (customOp != null) {
            return calculateOutputShape(customOp, null);
        }
        throw new NullPointerException("op is marked non-null but is null");
    }

    /**
     * Asks the native backend for the output shapes of a custom op.
     *
     * <p>Inputs and i/t/b/d arguments are read from {@code opContext} when it is non-null and from
     * {@code customOp} otherwise. They are copied into native pointers and handed to
     * {@code calculateOutputShapes2}; every shape in the returned native list is converted with
     * {@link #getShapeFromPointer(LongPointer)} and the native list is deleted afterwards, even
     * when conversion fails.
     *
     * <p>If there are no input arguments although the op descriptor declares at least one input,
     * an empty list is returned without calling native code.
     *
     * @param customOp  op whose output shapes are requested; must not be null
     * @param opContext optional context supplying inputs and arguments; may be null
     * @return the calculated output shapes, or an empty list when inputs are missing
     * @throws NullPointerException if {@code customOp} is null
     * @throws RuntimeException     if the native call reports an error or returns no shape list;
     *                              any failure is logged with the op's inputs before being rethrown
     */
    public List<LongShapeDescriptor> calculateOutputShape(@NonNull CustomOp customOp, OpContext opContext) {
        if (customOp == null) {
            throw new NullPointerException("op is marked non-null but is null");
        }
        final boolean useContext = opContext != null;
        final long opHash = customOp.opHash();
        final int numInputArguments = useContext ? opContext.numInputArguments() : customOp.numInputArguments();
        if (numInputArguments == 0 && customOp.getDescriptor().getNumInputs() >= 1) {
            if (log.isTraceEnabled()) {
                log.trace("Could not calculate output shape for op {}: number of input args was 0", customOp.getClass().getName());
            }
            return Collections.emptyList();
        }

        // Input data and shape-info addresses; empty arrays contribute a shape but no data address.
        PointerPointer inputBufferPointers = new PointerPointer(numInputArguments);
        PointerPointer inputShapePointers = new PointerPointer(numInputArguments);
        final List<INDArray> inputArrays = useContext ? opContext.getInputArrays() : customOp.inputArguments();
        int inputIndex = 0;
        for (INDArray input : inputArrays) {
            if (!input.isEmpty()) {
                inputBufferPointers.put(inputIndex, input.data().addressPointer());
            }
            inputShapePointers.put(inputIndex, input.shapeInfoDataBuffer().addressPointer());
            inputIndex++;
        }

        // Integer arguments.
        final int numIArguments = useContext ? opContext.numIArguments() : customOp.numIArguments();
        LongPointer iArgs = numIArguments > 0 ? new LongPointer(numIArguments) : null;
        int iIndex = 0;
        if (useContext) {
            for (Long value : opContext.getIArguments()) {
                iArgs.put(iIndex++, value.longValue());
            }
        } else {
            for (long value : customOp.iArgs()) {
                iArgs.put(iIndex++, value);
            }
        }

        final int numTArguments = useContext ? opContext.numTArguments() : customOp.numTArguments();
        DoublePointer tArgs = numTArguments > 0 ? new DoublePointer(numTArguments) : null;
        final int numBArguments = useContext ? opContext.numBArguments() : customOp.numBArguments();
        BooleanPointer bArgs = numBArguments > 0 ? new BooleanPointer(numBArguments) : null;
        final int numDArguments = useContext ? opContext.numDArguments() : customOp.numDArguments();
        IntPointer dArgs = numDArguments > 0 ? new IntPointer(numDArguments) : null;

        // Boolean arguments.
        int bIndex = 0;
        if (useContext) {
            for (Boolean value : opContext.getBArguments()) {
                bArgs.put(bIndex++, value.booleanValue());
            }
        } else {
            for (boolean value : customOp.bArgs()) {
                bArgs.put(bIndex++, value);
            }
        }

        // Floating-point arguments.
        int tIndex = 0;
        if (useContext) {
            for (Double value : opContext.getTArguments()) {
                tArgs.put(tIndex++, value.doubleValue());
            }
        } else {
            for (double value : customOp.tArgs()) {
                tArgs.put(tIndex++, value);
            }
        }

        // Data-type arguments, passed as their integer codes.
        int dIndex = 0;
        if (useContext) {
            for (DataType value : opContext.getDArguments()) {
                dArgs.put(dIndex++, value.toInt());
            }
        } else {
            for (DataType value : customOp.dArgs()) {
                dArgs.put(dIndex++, value.toInt());
            }
        }

        try {
            OpaqueShapeList shapeList = this.loop.calculateOutputShapes2((PointerPointer) null, opHash, inputBufferPointers, inputShapePointers, numInputArguments, tArgs, numTArguments, iArgs, numIArguments, bArgs, numBArguments, dArgs, numDArguments);
            if (this.loop.lastErrorCode() != 0) {
                // Build the message defensively: the context may be null and the op need not be a
                // DifferentialFunction; neither may be allowed to mask the native error text.
                String ownName = customOp instanceof DifferentialFunction ? ((DifferentialFunction) customOp).getOwnName() : null;
                String contextInfo = useContext ? opContext.toString() : "";
                throw new RuntimeException("Op " + customOp.opName() + " with name " + ownName + " failed to execute." + contextInfo + " Here is the error from c++: " + this.loop.lastErrorMessage());
            }
            if (shapeList == null) {
                throw new RuntimeException("Op " + customOp.opName() + ": native shape calculation returned no shape list");
            }

            final List<LongShapeDescriptor> shapes = new ArrayList<>();
            try {
                final long shapeCount = this.loop.getShapeListSize(shapeList);
                for (int s = 0; s < shapeCount; s++) {
                    shapes.add(getShapeFromPointer(new PagedPointer(this.loop.getShape(shapeList, s)).asLongPointer()));
                }
            } finally {
                // Release the native list even if a shape could not be converted.
                this.loop.deleteShapeList(shapeList);
            }

            if (log.isTraceEnabled()) {
                String ownName = customOp instanceof DifferentialFunction ? ((DifferentialFunction) customOp).getOwnName() : null;
                log.trace("Calculated output shapes for op  of name {} and type {} - {}", ownName, customOp.getClass().getName(), shapes.toString());
            }
            return shapes;
        } catch (Throwable th) {
            StringBuilder sb = new StringBuilder();
            sb.append("Inputs: [(");
            for (int idx = 0; idx < inputArrays.size(); idx++) {
                if (idx > 0) {
                    sb.append("), (");
                }
                sb.append(Shape.shapeToStringShort(inputArrays.get(idx)));
            }
            sb.append(")]");
            if ((customOp instanceof DifferentialFunction) && ((DifferentialFunction) customOp).getSameDiff() != null) {
                appendSameDiffInfo(sb, (DifferentialFunction) customOp);
            }
            int numOutputArguments = useContext ? opContext.numOutputArguments() : customOp.numOutputArguments();
            log.error("Failed to calculate output shapes for op {}. Attempted to execute with {} inputs, {} outputs, {} targs, {} iargs, {} bargs and {} dargs. {} - Please see above message (printed out from c++) for a possible cause of error.", new Object[]{customOp.opName(), Integer.valueOf(numInputArguments), Integer.valueOf(numOutputArguments), Integer.valueOf(numTArguments), Integer.valueOf(numIArguments), Integer.valueOf(numBArguments), Integer.valueOf(numDArguments), sb.toString()});
            throw th;
        }
    }

    /**
     * Stores the debug flag on this executioner, then forwards it to the native backend.
     *
     * @param z new debug-mode value
     */
    public void enableDebugMode(boolean z) {
        debug.set(z);
        loop.enableDebugMode(z);
    }

    /**
     * Stores the verbose flag on this executioner, then forwards it to the native backend.
     *
     * @param z new verbose-mode value
     */
    public void enableVerboseMode(boolean z) {
        verbose.set(z);
        loop.enableVerboseMode(z);
    }

    /**
     * Registers a graph with the native backend under the given id.
     *
     * @param j       id passed to the native {@code registerGraph} call
     * @param pointer pointer passed to the native {@code registerGraph} call
     * @throws RuntimeException carrying the native error message if the backend reports an error
     */
    public void registerGraph(long j, Pointer pointer) {
        loop.registerGraph((PointerPointer) null, j, pointer);
        if (loop.lastErrorCode() == 0) {
            return;
        }
        throw new RuntimeException(loop.lastErrorMessage());
    }

    /**
     * Executes a graph previously stored in the native backend and copies its result variables
     * into new arrays.
     *
     * <p>For every entry of {@code map} the array's data address, shape-info address and the
     * integer id found in {@code map2} under the same name are passed to
     * {@code executeStoredGraph}. Each variable of the returned set is copied with
     * {@code memcpy} into an array created from the variable's shape, strides and order, and
     * stored under the variable's name. The native variables set is deleted before returning,
     * including when the status check or the copying fails.
     *
     * @param j    id of the stored graph
     * @param map  input arrays keyed by variable name; must not be null
     * @param map2 variable ids keyed by variable name; must not be null and must contain every
     *             key of {@code map}
     * @return result arrays keyed by variable name, in the order the backend reports them
     * @throws NullPointerException      if an argument is null or {@code map2} has no id for an input
     * @throws RuntimeException          if the backend reports an error for the execution call
     * @throws ND4JIllegalStateException if the returned variables set has a non-OK status
     */
    public Map<String, INDArray> executeGraph(long j, @NonNull Map<String, INDArray> map, @NonNull Map<String, Integer> map2) {
        if (map == null) {
            throw new NullPointerException("map is marked non-null but is null");
        }
        if (map2 == null) {
            throw new NullPointerException("reverseMap is marked non-null but is null");
        }

        final int numInputs = map.size();
        PointerPointer bufferPointers = new PointerPointer(numInputs);
        PointerPointer shapePointers = new PointerPointer(numInputs);
        IntPointer variableIds = new IntPointer(numInputs);

        int slot = 0;
        for (Map.Entry<String, INDArray> input : map.entrySet()) {
            String name = input.getKey();
            INDArray array = input.getValue();
            bufferPointers.put(slot, array.data().addressPointer());
            shapePointers.put(slot, array.shapeInfoDataBuffer().addressPointer());
            Integer variableId = map2.get(name);
            if (variableId == null) {
                // Same exception type as the implicit unboxing failure, but with a usable message.
                throw new NullPointerException("No variable id registered in reverseMap for input \"" + name + "\"");
            }
            variableIds.put(slot, variableId.intValue());
            slot++;
        }

        OpaqueVariablesSet resultSet = this.loop.executeStoredGraph((PointerPointer) null, j, bufferPointers, shapePointers, variableIds, numInputs);
        if (this.loop.lastErrorCode() != 0) {
            throw new RuntimeException(this.loop.lastErrorMessage());
        }

        try {
            OpStatus status = OpStatus.byNumber(this.loop.getVariablesSetStatus(resultSet));
            if (status != OpStatus.ND4J_STATUS_OK) {
                throw new ND4JIllegalStateException("Op execution failed: " + status);
            }

            Map<String, INDArray> results = new LinkedHashMap<>();
            final long variableCount = this.loop.getVariablesSetSize(resultSet);
            for (int v = 0; v < variableCount; v++) {
                OpaqueVariable variable = this.loop.getVariable(resultSet, v);
                LongPointer nativeShapeInfo = this.loop.getVariableShape(variable);
                Pointer nativeBuffer = this.loop.getVariableBuffer(variable);

                // Same sizing rule as getShapeFromPointer: element 0 drives the buffer length.
                long[] shapeInfo = new long[(((int) nativeShapeInfo.get(0L)) * 2) + 4];
                for (int k = 0; k < shapeInfo.length; k++) {
                    shapeInfo[k] = nativeShapeInfo.get(k);
                }

                long[] shape = Shape.shapeOf(shapeInfo);
                INDArray copy = Nd4j.create(shape, Shape.stridesOf(shapeInfo), 0L, Shape.order(shapeInfo));
                long byteCount = Shape.lengthOf(shape) * Nd4j.sizeOfDataType(copy.dataType());

                long transaction = PerformanceTracker.getInstance().helperStartTransaction();
                Pointer.memcpy(copy.data().addressPointer(), nativeBuffer, byteCount);
                PerformanceTracker.getInstance().helperRegisterTransaction(0, transaction, byteCount, MemcpyDirection.HOST_TO_HOST);

                results.put(this.loop.getVariableName(variable), copy);
            }
            return results;
        } finally {
            // Always release the native result set, also on failure paths.
            this.loop.deleteVariablesSet(resultSet);
        }
    }

    /**
     * Unregisters the graph with the given id from the native backend.
     *
     * @param j id passed to the native {@code unregisterGraph} call
     * @throws RuntimeException carrying the native error message if the backend reports an error
     */
    public void forgetGraph(long j) {
        loop.unregisterGraph((PointerPointer) null, j);
        if (loop.lastErrorCode() == 0) {
            return;
        }
        throw new RuntimeException(loop.lastErrorMessage());
    }

    /**
     * Forwards the element threshold to the native backend.
     *
     * @param i threshold value passed to {@code setElementThreshold}
     */
    public void setElementsThreshold(int i) {
        loop.setElementThreshold(i);
    }

    /**
     * Forwards the TAD threshold to the native backend.
     *
     * @param i threshold value passed to {@code setTADThreshold}
     */
    public void setTadThreshold(int i) {
        loop.setTADThreshold(i);
    }

    /**
     * Reads one string out of a {@code Utf8Buffer}.
     *
     * <p>The value found at index {@code j} of the buffer's indexer is used as the address of a
     * native {@code utf8string}; the string is decoded from that object's buffer, limited to the
     * length the object reports.
     *
     * @param dataBuffer buffer to read from; must be a {@code Utf8Buffer}
     * @param j          index into the buffer's indexer
     * @return the decoded string
     * @throws IllegalArgumentException presumably raised by {@code Preconditions.checkArgument}
     *                                  when the buffer is not a {@code Utf8Buffer} - confirm
     */
    public String getString(DataBuffer dataBuffer, long j) {
        Preconditions.checkArgument(dataBuffer instanceof Utf8Buffer, "Expected Utf8Buffer");
        // Keep the native wrapper in a local: both its buffer and its length are needed below.
        Nd4jCpu.utf8string nativeString = new Nd4jCpu.utf8string((Pointer) new PagedPointer(dataBuffer.indexer().get(j)));
        return nativeString._buffer().capacity(nativeString._length()).getString();
    }

    /**
     * Identifies this executioner's kind.
     *
     * @return always {@link OpExecutioner.ExecutionerType#NATIVE_CPU}
     */
    public OpExecutioner.ExecutionerType type() {
        final OpExecutioner.ExecutionerType executionerType = OpExecutioner.ExecutionerType.NATIVE_CPU;
        return executionerType;
    }

    /**
     * Reports the current value of the {@code experimentalMode} flag.
     *
     * @return the flag's current value
     */
    public boolean isExperimentalMode() {
        final boolean enabled = experimentalMode.get();
        return enabled;
    }

    /**
     * Runs the native scatter-update operation.
     *
     * <p>TAD shape information is obtained from the TAD manager for both the target array and the
     * updates along {@code iArr}. The length of the second buffer of the updates' TAD pair must
     * equal the number of indices, otherwise the call is rejected before reaching native code.
     *
     * @param updateOp  update operation; its ordinal is passed to the native call
     * @param iNDArray  array being updated; must not be null
     * @param iNDArray2 indices; must not be null
     * @param iNDArray3 updates; must not be null
     * @param iArr      axis/dimensions used for the TAD lookup; must not be null
     * @throws NullPointerException  if any argument marked non-null is null
     * @throws IllegalStateException if the number of updates does not match the number of indices
     * @throws RuntimeException      carrying the native error message if the backend reports an error
     */
    public void scatterUpdate(ScatterUpdate.UpdateOp updateOp, @NonNull INDArray iNDArray, @NonNull INDArray iNDArray2, @NonNull INDArray iNDArray3, @NonNull int[] iArr) {
        if (iNDArray == null) {
            throw new NullPointerException("array is marked non-null but is null");
        }
        if (iNDArray2 == null) {
            throw new NullPointerException("indices is marked non-null but is null");
        }
        if (iNDArray3 == null) {
            throw new NullPointerException("updates is marked non-null but is null");
        }
        if (iArr == null) {
            throw new NullPointerException("axis is marked non-null but is null");
        }

        final Pair<DataBuffer, DataBuffer> arrayTad = tadManager.getTADOnlyShapeInfo(iNDArray, iArr);
        final Pair<DataBuffer, DataBuffer> updatesTad = tadManager.getTADOnlyShapeInfo(iNDArray3, iArr);

        final DataBuffer updatesTadSecond = updatesTad.getSecond();
        final boolean countsMatch = updatesTadSecond.length() == iNDArray2.length();
        if (!countsMatch) {
            throw new IllegalStateException("Number of updates doesn't match number of indices. Bad dimensions used?");
        }

        final DataBuffer arrayTadFirst = arrayTad.getFirst();
        final DataBuffer arrayTadSecond = arrayTad.getSecond();
        final DataBuffer updatesTadFirst = updatesTad.getFirst();

        loop.scatterUpdate(
                (PointerPointer) null,
                updateOp.ordinal(),
                (int) iNDArray2.length(),
                iNDArray.data().addressPointer(),
                arrayTadFirst.addressPointer(),
                arrayTadSecond.addressPointer(),
                (Pointer) null,
                (LongPointer) null,
                (LongPointer) null,
                iNDArray3.data().addressPointer(),
                updatesTadFirst.addressPointer(),
                updatesTadSecond.addressPointer(),
                (Pointer) null,
                (LongPointer) null,
                (LongPointer) null,
                iNDArray2.data().addressPointer(),
                iNDArray2.shapeInfoDataBuffer().addressPointer(),
                (Pointer) null,
                (LongPointer) null);

        if (loop.lastErrorCode() == 0) {
            return;
        }
        throw new RuntimeException(loop.lastErrorMessage());
    }

    /**
     * Creates a new op context for this backend.
     *
     * @return a fresh {@code CpuOpContext}
     */
    public OpContext buildContext() {
        final OpContext context = new CpuOpContext();
        return context;
    }

    /**
     * Executes a custom op natively using an already populated context.
     *
     * <p>The profiling hook is entered before execution and always left afterwards, exactly once,
     * whether execution succeeds or fails. When an MKL-DNN override was applied for this op, the
     * MKL-DNN flag is set back to {@code true} on every exit path as well.
     *
     * @param customOp  op to execute; its hash selects the native implementation
     * @param opContext context holding inputs, outputs and arguments; must not be null
     * @return the context's output arrays as an array (empty when the context has no outputs)
     * @throws NullPointerException if {@code opContext} is null
     * @throws RuntimeException     if the backend reports an error or a non-zero status; the
     *                              failure is logged with inputs, outputs and arguments first
     */
    public INDArray[] exec(CustomOp customOp, @NonNull OpContext opContext) {
        if (opContext == null) {
            throw new NullPointerException("context is marked non-null but is null");
        }
        final long profilingStart = profilingConfigurableHookIn(customOp, opContext);
        boolean mklOverride = false;
        try {
            if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) {
                // NOTE(review): the lookup is keyed by the op instance; if mklOverrides is keyed by
                // op name this never matches - confirm against the field declaration.
                Boolean override = this.mklOverrides.get(customOp);
                if (override != null && override.booleanValue()) {
                    mklOverride = true;
                    Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
                }
            }

            int status = this.loop.execCustomOp2((PointerPointer) null, customOp.opHash(), opContext.contextPointer());
            if (this.loop.lastErrorCode() != 0) {
                throw new RuntimeException(this.loop.lastErrorMessage());
            }
            if (status != 0) {
                // Guard the cast so a non-DifferentialFunction op cannot mask the failure.
                String failedName = customOp instanceof DifferentialFunction ? ((DifferentialFunction) customOp).getOwnName() : null;
                throw new RuntimeException("Op with name " + failedName + " and op type [" + customOp.opName() + "] execution failed");
            }

            // toArray on an empty list yields the zero-length array passed in.
            List<INDArray> produced = opContext.getOutputArrays();
            return produced.toArray(new INDArray[produced.size()]);
        } catch (Exception e) {
            StringBuilder sb = new StringBuilder();

            sb.append("Inputs: [(");
            List<INDArray> inputs = opContext.getInputArrays();
            int numInputs = inputs == null ? 0 : inputs.size();
            for (int idx = 0; idx < numInputs; idx++) {
                if (idx > 0) {
                    sb.append("), (");
                }
                sb.append(Shape.shapeToStringShort(inputs.get(idx)));
            }

            sb.append(")]. Outputs: [(");
            List<INDArray> outputs = opContext.getOutputArrays();
            int numOutputs = outputs == null ? 0 : outputs.size();
            for (int idx = 0; idx < numOutputs; idx++) {
                if (idx > 0) {
                    sb.append("), (");
                }
                sb.append(Shape.shapeToStringShort(outputs.get(idx)));
            }

            sb.append(")]. tArgs: ");
            List<?> tArgs = opContext.getTArguments();
            int numTArgs = tArgs == null ? 0 : tArgs.size();
            if (numTArgs > 0) {
                sb.append(tArgs);
            } else {
                sb.append("-");
            }

            sb.append(". iArgs: ");
            List<?> iArgs = opContext.getIArguments();
            int numIArgs = iArgs == null ? 0 : iArgs.size();
            if (numIArgs > 0) {
                sb.append(iArgs);
            } else {
                sb.append("-");
            }

            sb.append(". bArgs: ");
            List<?> bArgs = opContext.getBArguments();
            int numBArgs = bArgs == null ? 0 : bArgs.size();
            if (numBArgs > 0) {
                sb.append(bArgs);
            } else {
                sb.append("-");
            }

            if (customOp instanceof DifferentialFunction) {
                DifferentialFunction function = (DifferentialFunction) customOp;
                String ownName = function.getOwnName();
                if (ownName != null && !ownName.equals(customOp.opName())) {
                    sb.append(". Op own name: \"").append(ownName).append("\"");
                }
                if (function.getSameDiff() != null) {
                    appendSameDiffInfo(sb, function);
                }
            }

            log.error("Failed to execute op " + customOp.opName() + ". Attempted to execute with " + numInputs + " inputs, " + numOutputs + " outputs, " + numTArgs + " targs," + numBArgs + " bargs and " + numIArgs + " iargs. " + sb.toString() + " - Please see above message (printed out from c++) for a possible cause of error.");
            throw e;
        } finally {
            if (mklOverride) {
                Nd4jCpu.Environment.getInstance().setUseMKLDNN(true);
            }
            profilingConfigurableHookOut(customOp, opContext, profilingStart);
        }
    }

    /**
     * Collects statistics about an array by letting the native backend fill a
     * {@code Nd4jCpu.DebugInfo} structure.
     *
     * @param iNDArray array to inspect
     * @return statistics built from the min, max, mean, standard deviation and the
     *         inf/NaN/negative/positive/zero counts reported in the debug info
     * @throws RuntimeException carrying the native error message if the backend reports an error
     */
    public INDArrayStatistics inspectArray(INDArray iNDArray) {
        final Nd4jCpu.DebugInfo info = new Nd4jCpu.DebugInfo();
        loop.inspectArray(
                (PointerPointer) null,
                iNDArray.data().addressPointer(),
                iNDArray.shapeInfoDataBuffer().addressPointer(),
                (Pointer) null,
                (LongPointer) null,
                info);

        if (loop.lastErrorCode() == 0) {
            return INDArrayStatistics.builder()
                    .minValue(info._minValue())
                    .maxValue(info._maxValue())
                    .meanValue(info._meanValue())
                    .stdDevValue(info._stdDevValue())
                    .countInf(info._infCount())
                    .countNaN(info._nanCount())
                    .countNegative(info._negativeCount())
                    .countPositive(info._positiveCount())
                    .countZero(info._zeroCount())
                    .build();
        }
        throw new RuntimeException(loop.lastErrorMessage());
    }

    /**
     * Builds a shape-info buffer through the native {@code shapeBuffer} call.
     *
     * <p>The native constant shape buffer's primary pointer is wrapped in a {@code LongBuffer} of
     * {@code Shape.shapeInfoLength(rank)} elements, after which the native handle is deleted.
     *
     * @param jArr     shape; its length is passed as the rank
     * @param jArr2    strides
     * @param j        value passed after the order character (presumably the element-wise
     *                 stride - confirm against the native signature)
     * @param c        order character
     * @param dataType data type, passed as its integer code
     * @param z        boolean flag passed last (presumably "empty" - confirm)
     * @return the wrapped shape-info buffer
     * @throws RuntimeException carrying the native error message if the backend reports an error
     */
    public DataBuffer createShapeInfo(long[] jArr, long[] jArr2, long j, char c, DataType dataType, boolean z) {
        final int rank = jArr.length;
        final OpaqueConstantShapeBuffer nativeHandle =
                loop.shapeBuffer(rank, new LongPointer(jArr), new LongPointer(jArr2), dataType.toInt(), c, j, z);
        if (loop.lastErrorCode() != 0) {
            throw new RuntimeException(loop.lastErrorMessage());
        }

        final DataBuffer shapeInfo =
                new LongBuffer(loop.getConstantShapeBufferPrimary(nativeHandle), Shape.shapeInfoLength(rank));
        loop.deleteConstantShapeBuffer(nativeHandle);
        return shapeInfo;
    }

    /**
     * Builds a shape-info buffer through the native {@code shapeBufferEx} call.
     *
     * <p>The native constant shape buffer's primary pointer is wrapped in a {@code LongBuffer} of
     * {@code Shape.shapeInfoLength(rank)} elements, after which the native handle is deleted.
     *
     * @param jArr     shape; its length is passed as the rank
     * @param jArr2    strides
     * @param j        value passed after the order character (presumably the element-wise
     *                 stride - confirm against the native signature)
     * @param c        order character
     * @param dataType data type, passed as its integer code
     * @param j2       extra long value passed last (presumably extra flags - confirm)
     * @return the wrapped shape-info buffer
     * @throws RuntimeException carrying the native error message if the backend reports an error
     */
    public DataBuffer createShapeInfo(long[] jArr, long[] jArr2, long j, char c, DataType dataType, long j2) {
        final int rank = jArr.length;
        final OpaqueConstantShapeBuffer nativeHandle =
                loop.shapeBufferEx(rank, new LongPointer(jArr), new LongPointer(jArr2), dataType.toInt(), c, j, j2);
        if (loop.lastErrorCode() != 0) {
            throw new RuntimeException(loop.lastErrorMessage());
        }

        final DataBuffer shapeInfo =
                new LongBuffer(loop.getConstantShapeBufferPrimary(nativeHandle), Shape.shapeInfoLength(rank));
        loop.deleteConstantShapeBuffer(nativeHandle);
        return shapeInfo;
    }

    /**
     * Obtains TAD shape information and offsets for an array along the given dimensions.
     *
     * <p>The native TAD pack's primary shape info and primary offsets are each wrapped in a
     * {@code LongBuffer} (sized by the pack's shape-info length and number of TADs respectively),
     * after which the native pack is deleted.
     *
     * @param iNDArray array whose shape info is passed to the native call
     * @param iArr     dimensions passed to the native call
     * @return a {@code TadPack} holding the shape-info buffer and the offsets buffer
     * @throws RuntimeException carrying the native error message if the backend reports an error
     */
    public TadPack tadShapeInfoAndOffsets(INDArray iNDArray, int[] iArr) {
        final OpaqueTadPack nativePack =
                loop.tadOnlyShapeInfo(iNDArray.shapeInfoDataBuffer().addressPointer(), new IntPointer(iArr), iArr.length);
        if (loop.lastErrorCode() != 0) {
            throw new RuntimeException(loop.lastErrorMessage());
        }

        final LongBuffer tadShapeInfo =
                new LongBuffer((Pointer) loop.getPrimaryShapeInfo(nativePack), loop.getShapeInfoLength(nativePack));
        final LongBuffer tadOffsets =
                new LongBuffer((Pointer) loop.getPrimaryOffsets(nativePack), loop.getNumberOfTads(nativePack));
        loop.deleteTadPack(nativePack);
        return new TadPack(tadShapeInfo, tadOffsets);
    }

    /**
     * Appends the function's input and output variable names to a diagnostic message.
     * Each section is skipped when the corresponding name array is null.
     *
     * @param sb                   builder receiving the text
     * @param differentialFunction function whose argument and output variable names are appended
     */
    protected void appendSameDiffInfo(StringBuilder sb, DifferentialFunction differentialFunction) {
        // Both lookups happen before anything is appended.
        final String[] inputNames = differentialFunction.argNames();
        final String[] outputNames = differentialFunction.outputVariablesNames();

        if (inputNames != null) {
            sb.append(". Input var names: ");
            sb.append(Arrays.toString(inputNames));
        }
        if (outputNames != null) {
            sb.append(". Output var names: ");
            sb.append(Arrays.toString(outputNames));
        }
    }
}
