package org.nd4j.linalg.api.shape.tensor;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;

/* loaded from: input_file:org/nd4j/linalg/api/shape/tensor/TensorCalculator1d.class */
/**
 * Describes the 1-d tensors of an array taken along a single dimension.
 *
 * <p>Every tensor has the same shape {@code {1, shape[dimension]}} and the same stride
 * {@code {1, stride[dimension]}}; tensors differ only in their starting offset. The tensor count
 * is {@code ArrayUtil.prod} of the shape with the tensor dimension removed.
 *
 * <p>NOTE(review): the shape and stride arrays supplied to the constructor are stored by
 * reference (not copied), and {@link #getShape()} / {@link #getStride()} hand out the internal
 * arrays directly, so callers must not mutate them.
 */
public class TensorCalculator1d implements TensorCalculator {
    /** Offset of the first element of the underlying array. */
    private final int baseOffset;
    /** Shape of the underlying array (aliased, not copied). */
    private final int[] shape;
    /** Stride of the underlying array (aliased, not copied). */
    private final int[] stride;
    /** Non-negative index of the dimension the tensors run along. */
    private final int tensorDim;
    /** {@link #shape} with the entry at {@link #tensorDim} removed. */
    private final int[] shapeMinusTensorDim;
    /** Distance between consecutive elements inside one tensor. */
    private final int elementWiseStride;
    /** Shape reported for every tensor: {@code {1, shape[tensorDim]}}. */
    private final int[] tensorShape;
    /** Stride reported for every tensor: {@code {1, elementWiseStride}}. */
    private final int[] tensorStride;
    /** Number of tensors along {@link #tensorDim}. */
    private final int numTensors;

    /**
     * Creates a calculator for the tensors of {@code arr} along {@code dimension}.
     *
     * @param arr       the array whose offset, shape and stride are used
     * @param dimension the tensor dimension; negative values count from the end
     */
    public TensorCalculator1d(INDArray arr, int dimension) {
        this(arr.offset(), arr.shape(), arr.stride(), dimension);
    }

    /**
     * Creates a calculator from raw array metadata.
     *
     * @param offset    offset of the first element of the underlying array
     * @param shape     shape of the underlying array (stored by reference)
     * @param stride    stride of the underlying array (stored by reference)
     * @param dimension the tensor dimension; negative values count from the end
     *                  ({@code -1} is the last dimension)
     */
    public TensorCalculator1d(int offset, int[] shape, int[] stride, int dimension) {
        this.baseOffset = offset;
        this.shape = shape;
        this.stride = stride;
        // Map a negative dimension onto its position counted from the end.
        final int dim = (dimension >= 0) ? dimension : shape.length + dimension;
        this.tensorDim = dim;
        this.shapeMinusTensorDim = ArrayUtil.removeIndex(shape, dim);
        this.elementWiseStride = stride[dim];
        this.tensorShape = new int[] {1, shape[dim]};
        this.tensorStride = new int[] {1, elementWiseStride};
        this.numTensors = ArrayUtil.prod(shapeMinusTensorDim);
    }

    /** Returns the number of tensors along the tensor dimension. */
    @Override
    public int getNumTensors() {
        return numTensors;
    }

    /**
     * Returns the offset of the first element of tensor {@code index}, including the base offset.
     *
     * <p>{@code index} is turned into one subscript per non-tensor dimension by
     * {@code Shape.ind2subC} over the reduced shape (presumably row-major order — confirm), and
     * each subscript is weighted by the stride of the dimension it belongs to.
     *
     * @param index the tensor index
     * @return the absolute offset of that tensor's first element
     */
    @Override
    public int getOffsetForTensor(int index) {
        final int[] subscripts = Shape.ind2subC(shapeMinusTensorDim, index);
        int offset = baseOffset;
        // Dimensions before the tensor dimension map one-to-one onto the subscripts...
        for (int d = 0; d < tensorDim; d++) {
            offset += subscripts[d] * stride[d];
        }
        // ...dimensions after it are shifted down by one, because tensorDim was removed.
        for (int d = tensorDim + 1; d < shape.length; d++) {
            offset += subscripts[d - 1] * stride[d];
        }
        return offset;
    }

    /** Returns the shape shared by every tensor, {@code {1, shape[tensorDim]}} (internal array). */
    @Override
    public int[] getShape() {
        return tensorShape;
    }

    /** Returns the stride shared by every tensor, {@code {1, stride[tensorDim]}} (internal array). */
    @Override
    public int[] getStride() {
        return tensorStride;
    }

    /** Returns the offset of the first element of the underlying array. */
    @Override
    public int getBaseOffset() {
        return baseOffset;
    }

    /** Returns the distance between consecutive elements within a single tensor. */
    @Override
    public int getElementWiseStrideForTensor() {
        return elementWiseStride;
    }

    /** Returns the number of elements in each tensor, i.e. the size of the tensor dimension. */
    @Override
    public int getTensorLength() {
        return shape[tensorDim];
    }
}
