package org.nd4j.linalg.api.blas.params;

import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/api/blas/params/GemmParams.class */
/**
 * Holds the operands and the BLAS-style dimension / leading-dimension values for a GEMM-style
 * multiplication of A (m x k) by B (k x n) into C (m x n).
 *
 * <p>The constructor normalises operand ordering as follows:
 * <ul>
 *   <li>If A and B have different orderings, every operand that is not 'f'-ordered is copied
 *       element-by-element into a new 'f'-ordered array.</li>
 *   <li>If both are 'c'-ordered, the operands and m/n are swapped.</li>
 * </ul>
 * C is never copied or reordered.
 *
 * <p>The setters assign fields directly and do not re-run any validation.
 */
public class GemmParams {
    private int lda;
    private int ldb;
    private int ldc;
    private int m;
    private int n;
    private int k;
    private INDArray a;
    private INDArray b;
    private INDArray c;

    /**
     * Builds the GEMM parameters for multiplying {@code a} by {@code b} into {@code c}.
     *
     * @param a left operand (m x k)
     * @param b right operand (k x n)
     * @param c result array (m x n)
     * @throws IllegalArgumentException if the shapes of {@code a}, {@code b} and {@code c} are
     *     not compatible for matrix multiplication
     * @throws IllegalStateException if a computed dimension is negative
     */
    public GemmParams(INDArray a, INDArray b, INDArray c) {
        if (b.columns() != c.columns()) {
            throw new IllegalArgumentException("B columns must match c columns");
        }
        if (a.rows() != c.rows()) {
            throw new IllegalArgumentException("A rows must equal c rows");
        }
        // The inner dimensions must agree: k is taken from A's columns and used as B's leading
        // dimension below, so a mismatch would describe B with the wrong row count.
        if (a.columns() != b.rows()) {
            throw new IllegalArgumentException("A columns must equal B rows");
        }

        // Mixed orderings: copy every non-'f' operand into 'f' order so both operands match.
        if (a.ordering() != b.ordering()) {
            if (a.ordering() != 'f') {
                a = toFortranOrder(a);
            }
            if (b.ordering() != 'f') {
                b = toFortranOrder(b);
            }
        }

        this.a = a;
        this.b = b;
        this.c = c;
        this.m = a.rows();
        this.n = b.columns();
        this.k = a.columns();

        if (a.ordering() == 'c' && b.ordering() == 'c') {
            // Both row-major: swap the operands and m/n. Presumably this lets a column-major GEMM
            // compute C^T = B^T * A^T, whose memory layout equals row-major C = A * B (confirm).
            int oldM = this.m;
            this.m = this.n;
            this.n = oldM;
            this.a = b;
            this.b = a;
        }

        // Leading dimensions are derived from the (possibly swapped) m and k; BLAS requires >= 1.
        this.lda = Math.max(1, this.m);
        this.ldb = Math.max(1, this.k);
        this.ldc = Math.max(1, this.m);

        // Check the operands in their final (post-swap) roles so any copy replaces the right field.
        if (unevenStrides(this.a)) {
            this.a = this.a.dup();
        }
        if (unevenStrides(this.b)) {
            this.b = this.b.dup();
        }
        validate();
    }

    /** Returns a new 'f'-ordered array with the same shape, filled element-by-element from {@code source}. */
    private static INDArray toFortranOrder(INDArray source) {
        INDArray copy = Nd4j.create(source.shape(), 'f');
        NdIndexIterator indices = new NdIndexIterator('c', copy.shape());
        while (indices.hasNext()) {
            int[] index = indices.next();
            copy.putScalar(index, source.getDouble(index));
        }
        return copy;
    }

    /**
     * Returns true when {@code iNDArray} is 'f'-ordered and has a non-zero offset. The constructor
     * replaces such an operand with {@code dup()}; presumably offset 'f' views cannot be handed to
     * the backend directly (NOTE(review): confirm).
     *
     * <p>Called from the constructor, so overrides must not depend on subclass state.
     */
    protected boolean unevenStrides(INDArray iNDArray) {
        return iNDArray.ordering() == 'f' && iNDArray.offset() > 0;
    }

    /** Rejects negative m, n or k. */
    private void validate() {
        if (this.m < 0) {
            throw new IllegalStateException("M must be >= 0");
        }
        if (this.n < 0) {
            throw new IllegalStateException("N must be >= 0");
        }
        if (this.k < 0) {
            throw new IllegalStateException("K must be at least 0");
        }
    }

    // Plain accessors. The setters do not re-validate or recompute the leading dimensions.

    public int getLda() {
        return this.lda;
    }

    public int getLdb() {
        return this.ldb;
    }

    public int getLdc() {
        return this.ldc;
    }

    public int getM() {
        return this.m;
    }

    public int getN() {
        return this.n;
    }

    public int getK() {
        return this.k;
    }

    public INDArray getA() {
        return this.a;
    }

    public INDArray getB() {
        return this.b;
    }

    public INDArray getC() {
        return this.c;
    }

    public void setLda(int i) {
        this.lda = i;
    }

    public void setLdb(int i) {
        this.ldb = i;
    }

    public void setLdc(int i) {
        this.ldc = i;
    }

    public void setM(int i) {
        this.m = i;
    }

    public void setN(int i) {
        this.n = i;
    }

    public void setK(int i) {
        this.k = i;
    }

    public void setA(INDArray iNDArray) {
        this.a = iNDArray;
    }

    public void setB(INDArray iNDArray) {
        this.b = iNDArray;
    }

    public void setC(INDArray iNDArray) {
        this.c = iNDArray;
    }

    /** Compares all dimension fields and the three operands (null-safe). */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof GemmParams)) {
            return false;
        }
        GemmParams other = (GemmParams) obj;
        if (!other.canEqual(this)
                || getLda() != other.getLda()
                || getLdb() != other.getLdb()
                || getLdc() != other.getLdc()
                || getM() != other.getM()
                || getN() != other.getN()
                || getK() != other.getK()) {
            return false;
        }
        INDArray thisA = getA();
        INDArray otherA = other.getA();
        if (thisA == null ? otherA != null : !thisA.equals(otherA)) {
            return false;
        }
        INDArray thisB = getB();
        INDArray otherB = other.getB();
        if (thisB == null ? otherB != null : !thisB.equals(otherB)) {
            return false;
        }
        INDArray thisC = getC();
        INDArray otherC = other.getC();
        return thisC == null ? otherC == null : thisC.equals(otherC);
    }

    /** Symmetry hook for {@link #equals(Object)} with subclasses. */
    protected boolean canEqual(Object obj) {
        return obj instanceof GemmParams;
    }

    /** Hash consistent with {@link #equals(Object)}: mixes all fields with multiplier 59. */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * 59 + getLda();
        result = result * 59 + getLdb();
        result = result * 59 + getLdc();
        result = result * 59 + getM();
        result = result * 59 + getN();
        result = result * 59 + getK();
        INDArray thisA = getA();
        result = result * 59 + (thisA == null ? 0 : thisA.hashCode());
        INDArray thisB = getB();
        result = result * 59 + (thisB == null ? 0 : thisB.hashCode());
        INDArray thisC = getC();
        return result * 59 + (thisC == null ? 0 : thisC.hashCode());
    }

    @Override
    public String toString() {
        return "GemmParams(lda=" + getLda() + ", ldb=" + getLdb() + ", ldc=" + getLdc() + ", m=" + getM() + ", n=" + getN() + ", k=" + getK() + ", a=" + getA() + ", b=" + getB() + ", c=" + getC() + ")";
    }
}
