package org.nd4j.linalg.api.ops.executioner;

import java.beans.ConstructorProperties;
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/ops/executioner/OpExecutionerUtil.class */
/**
 * Static helpers that inspect the layout of {@link INDArray}s (strides, ordering, buffer length)
 * to decide whether an op can be applied to them directly, and along which dimension an
 * element-wise op should be split into 1-D tensors.
 */
public class OpExecutionerUtil {

    /**
     * Layout summary of the 1-D tensors along a dimension of an array, as populated by
     * {@link OpExecutionerUtil#get1DTensorStats(INDArray, int...)}.
     */
    public static class Tensor1DStats {
        /** Offset of the first tensor; {@code get1DTensorStats} sets this to the array's own offset. */
        public final int firstTensorOffset;
        /** Offset of the second tensor minus the first, or {@code -1} when there is only one tensor. */
        public final int tensorStartSeparation;
        /** Number of tensors along the requested dimension(s). */
        public final int numTensors;
        /** Size of the array along the first requested dimension. */
        public final int tensorLength;
        /** Element-wise stride of the second tensor, or of the whole array when there is only one tensor. */
        public final int elementWiseStride;

        @ConstructorProperties({"firstTensorOffset", "tensorStartSeparation", "numTensors", "tensorLength", "elementWiseStride"})
        public Tensor1DStats(int firstTensorOffset, int tensorStartSeparation, int numTensors, int tensorLength,
                        int elementWiseStride) {
            this.firstTensorOffset = firstTensorOffset;
            this.tensorStartSeparation = tensorStartSeparation;
            this.numTensors = numTensors;
            this.tensorLength = tensorLength;
            this.elementWiseStride = elementWiseStride;
        }

        public int getFirstTensorOffset() {
            return this.firstTensorOffset;
        }

        public int getTensorStartSeparation() {
            return this.tensorStartSeparation;
        }

        public int getNumTensors() {
            return this.numTensors;
        }

        public int getTensorLength() {
            return this.tensorLength;
        }

        public int getElementWiseStride() {
            return this.elementWiseStride;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof Tensor1DStats)) {
                return false;
            }
            Tensor1DStats other = (Tensor1DStats) obj;
            // canEqual keeps equality symmetric should a subclass add state of its own.
            return other.canEqual(this)
                            && getFirstTensorOffset() == other.getFirstTensorOffset()
                            && getTensorStartSeparation() == other.getTensorStartSeparation()
                            && getNumTensors() == other.getNumTensors()
                            && getTensorLength() == other.getTensorLength()
                            && getElementWiseStride() == other.getElementWiseStride();
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof Tensor1DStats;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            result = result * prime + getFirstTensorOffset();
            result = result * prime + getTensorStartSeparation();
            result = result * prime + getNumTensors();
            result = result * prime + getTensorLength();
            result = result * prime + getElementWiseStride();
            return result;
        }

        @Override
        public String toString() {
            return "OpExecutionerUtil.Tensor1DStats(firstTensorOffset=" + getFirstTensorOffset() + ", tensorStartSeparation=" + getTensorStartSeparation() + ", numTensors=" + getNumTensors() + ", tensorLength=" + getTensorLength() + ", elementWiseStride=" + getElementWiseStride() + ")";
        }
    }

    private OpExecutionerUtil() {
        // Static utility class; not instantiable.
    }

    /**
     * Checks whether an op on a single array can be applied directly to its data buffer.
     *
     * <p>Requires an element-wise stride of at least 1, and then either: the array is a vector, its
     * length equals its data buffer's length, or its strides equal those computed for its shape and
     * ordering.
     *
     * @param x the array to inspect
     * @return {@code true} if the array qualifies for direct execution
     */
    public static boolean canDoOpDirectly(INDArray x) {
        if (x.elementWiseStride() < 1) {
            return false;
        }
        if (x.isVector() || x.lengthLong() == x.data().length()) {
            return true;
        }
        return hasDefaultStrides(x, x.stride());
    }

    /**
     * Checks whether an op over two arrays can be applied directly to their data buffers.
     *
     * <p>Returns {@code true} immediately if {@code x} is a vector (only {@code x} is tested for
     * this). Otherwise both arrays must share ordering and strides and have an element-wise stride
     * of at least 1; then either both arrays span their entire data buffers, or the shared strides
     * equal those computed for {@code x}'s shape and ordering.
     *
     * @param x the first array
     * @param y the second array
     * @return {@code true} if the pair qualifies for direct execution
     */
    public static boolean canDoOpDirectly(INDArray x, INDArray y) {
        if (x.isVector()) {
            return true;
        }
        // Outside the vector case, differently-ordered arrays cannot be walked in lockstep.
        if (x.ordering() != y.ordering()) {
            return false;
        }
        if (x.elementWiseStride() < 1 || y.elementWiseStride() < 1) {
            return false;
        }
        int[] strides = x.stride();
        if (!Arrays.equals(strides, y.stride())) {
            return false;
        }
        if (x.lengthLong() == x.data().length() && y.lengthLong() == y.data().length()) {
            return true;
        }
        return hasDefaultStrides(x, strides);
    }

    /**
     * Checks whether an op over three arrays can be applied directly to their data buffers.
     *
     * <p>Same rules as {@link #canDoOpDirectly(INDArray, INDArray)}, extended to {@code z}: all
     * three arrays must share ordering and strides and each must have an element-wise stride of at
     * least 1 (unless {@code x} is a vector, which returns {@code true} immediately).
     *
     * @param x the first array
     * @param y the second array
     * @param z the third array
     * @return {@code true} if the triple qualifies for direct execution
     */
    public static boolean canDoOpDirectly(INDArray x, INDArray y, INDArray z) {
        if (x.isVector()) {
            return true;
        }
        // Outside the vector case, differently-ordered arrays cannot be walked in lockstep.
        if (x.ordering() != y.ordering() || x.ordering() != z.ordering()) {
            return false;
        }
        // z is checked too: a z without a usable element-wise stride cannot be walked linearly either.
        if (x.elementWiseStride() < 1 || y.elementWiseStride() < 1 || z.elementWiseStride() < 1) {
            return false;
        }
        int[] strides = x.stride();
        if (!Arrays.equals(strides, y.stride()) || !Arrays.equals(strides, z.stride())) {
            return false;
        }
        if (x.lengthLong() == x.data().length() && y.lengthLong() == y.data().length()
                        && z.lengthLong() == z.data().length()) {
            return true;
        }
        return hasDefaultStrides(x, strides);
    }

    /**
     * Chooses the dimension along which to split an element-wise op on {@code x} into 1-D tensors.
     *
     * <p>For vectors this is the longest dimension. Otherwise the candidate is the dimension with
     * the smallest stride, subject to the heuristic described on {@link #pickTensorDimension}.
     *
     * @param x the array the op runs over
     * @return the chosen dimension index
     */
    public static int chooseElementWiseTensorDimension(INDArray x) {
        if (x.isVector()) {
            return ArrayUtil.argMax(x.shape());
        }
        return pickTensorDimension(x, ArrayUtil.argMin(x.stride()));
    }

    /**
     * Two-array variant of {@link #chooseElementWiseTensorDimension(INDArray)}. The stride
     * candidate is {@code ArrayUtil.argMinOfMax} over both arrays' strides; shapes are read from
     * {@code x}.
     *
     * @param x the first array (its shape drives the decision)
     * @param y the second array
     * @return the chosen dimension index
     */
    public static int chooseElementWiseTensorDimension(INDArray x, INDArray y) {
        if (x.isVector()) {
            return ArrayUtil.argMax(x.shape());
        }
        return pickTensorDimension(x, ArrayUtil.argMinOfMax(x.stride(), y.stride()));
    }

    /**
     * Three-array variant of {@link #chooseElementWiseTensorDimension(INDArray)}. The stride
     * candidate is {@code ArrayUtil.argMinOfMax} over all three arrays' strides; shapes are read
     * from {@code x}.
     *
     * @param x the first array (its shape drives the decision)
     * @param y the second array
     * @param z the third array
     * @return the chosen dimension index
     */
    public static int chooseElementWiseTensorDimension(INDArray x, INDArray y, INDArray z) {
        if (x.isVector()) {
            return ArrayUtil.argMax(x.shape());
        }
        // One row per array; an int[][] also satisfies an int[]... varargs parameter.
        int[][] strides = new int[][] {x.stride(), y.stride(), z.stride()};
        return pickTensorDimension(x, ArrayUtil.argMinOfMax(strides));
    }

    /**
     * Computes layout statistics for the 1-D tensors of {@code array} along {@code dimension}.
     *
     * <p>The tensor length is {@code array.size(dimension[0])}, so {@code dimension} must be
     * non-empty. When there is more than one tensor, the separation and element-wise stride are
     * taken from the tensor at index 1; with a single tensor the separation is {@code -1} and the
     * element-wise stride is the array's own.
     *
     * @param array     the array to inspect
     * @param dimension the dimension(s) passed to {@code tensorssAlongDimension}/{@code tensorAlongDimension}
     * @return the computed {@link Tensor1DStats}
     */
    public static Tensor1DStats get1DTensorStats(INDArray array, int... dimension) {
        int tensorLength = array.size(dimension[0]);
        int numTensors = array.tensorssAlongDimension(dimension);
        int firstTensorOffset = array.offset();

        int tensorStartSeparation;
        int elementWiseStride;
        if (numTensors == 1) {
            tensorStartSeparation = -1; // no second tensor to measure against
            elementWiseStride = array.elementWiseStride();
        } else {
            INDArray secondTensor = array.tensorAlongDimension(1, dimension);
            tensorStartSeparation = secondTensor.offset() - firstTensorOffset;
            elementWiseStride = secondTensor.elementWiseStride();
        }
        return new Tensor1DStats(firstTensorOffset, tensorStartSeparation, numTensors, tensorLength,
                        elementWiseStride);
    }

    /**
     * Returns whether {@code strides} equal the strides {@code ArrayUtil.calcStrides} (ordering
     * {@code 'c'}) or {@code ArrayUtil.calcStridesFortran} (any other ordering) compute for
     * {@code x}'s shape.
     */
    private static boolean hasDefaultStrides(INDArray x, int[] strides) {
        int[] shape = x.shape();
        int[] expected = x.ordering() == 'c' ? ArrayUtil.calcStrides(shape) : ArrayUtil.calcStridesFortran(shape);
        return Arrays.equals(strides, expected);
    }

    /**
     * Decides between the stride-based candidate dimension and the longest dimension of {@code x}.
     *
     * <p>The longest dimension wins when it is the same dimension, or when the candidate has size 1.
     * Otherwise the candidate wins unless splitting along it yields more than 10x as many tensors as
     * splitting along the longest dimension.
     */
    private static int pickTensorDimension(INDArray x, int minStrideDim) {
        int[] shape = x.shape();
        int maxLengthDim = ArrayUtil.argMax(shape);
        // A size-1 dimension may report a small stride but gives nothing to iterate along.
        if (minStrideDim == maxLengthDim || x.size(minStrideDim) == 1) {
            return maxLengthDim;
        }
        // Number of 1-D tensors obtained by splitting along each candidate dimension.
        long tensorsAlongMinStride = ArrayUtil.prod(ArrayUtil.removeIndex(shape, minStrideDim));
        long tensorsAlongMaxLength = ArrayUtil.prod(ArrayUtil.removeIndex(shape, maxLengthDim));
        // long arithmetic: 10 * count could overflow int for large high-rank arrays.
        return tensorsAlongMinStride <= 10L * tensorsAlongMaxLength ? minStrideDim : maxLengthDim;
    }
}
