package org.nd4j.linalg.api.blas.params;

import java.util.Arrays;
import java.util.Objects;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Operands, dimensions and flags for a BLAS-style GEMM (general matrix multiply) call that
 * multiplies {@code a} by {@code b} into the result array {@code c}.
 *
 * <p>The constructor validates the operand shapes, copies operands whose memory layout is not
 * usable as-is, and derives {@code m}, {@code n}, {@code k}, the leading dimensions and the
 * transpose flags from the operands' orderings and from {@link Nd4j#allowsSpecifyOrdering()}.
 */
public class GemmParams {
    /** Leading dimension of {@link #a}. */
    private int lda;
    /** Leading dimension of {@link #b}. */
    private int ldb;
    /** Leading dimension of {@link #c}. */
    private int ldc;
    private int m;
    private int n;
    /** Inner (shared) dimension of the product. */
    private int k;
    private INDArray a;
    private INDArray b;
    private INDArray c;
    /** Transpose flag for {@link #a}: {@code 'N'} (default) or {@code 'T'}. */
    private char transA = 'N';
    /** Transpose flag for {@link #b}: {@code 'N'} (default) or {@code 'T'}. */
    private char transB = 'N';
    /** Ordering the parameters are expressed in; {@code 'f'} unless derived from the operands. */
    private char ordering = 'f';

    /**
     * Builds GEMM parameters for {@code c = a x b}.
     *
     * @param a left operand
     * @param b right operand; its row count must equal {@code a}'s column count
     * @param c result array; must have {@code a}'s row count and {@code b}'s column count
     * @throws IllegalArgumentException if the three shapes are not compatible
     * @throws ND4JArraySizeException if any dimension exceeds {@link Integer#MAX_VALUE}
     */
    public GemmParams(INDArray a, INDArray b, INDArray c) {
        validateShapes(a, b, c);
        if (!Nd4j.allowsSpecifyOrdering()) {
            initForFortranOnlyBackend(a, b, c);
        } else if (a.ordering() != b.ordering()) {
            initForMixedOrdering(a, b, c);
        } else {
            initForSameOrdering(a, b, c);
        }
    }

    /**
     * Builds GEMM parameters, optionally transposing either operand first.
     *
     * @param a left operand
     * @param b right operand
     * @param c result array
     * @param transposeA if true, {@code a.transpose()} is used in place of {@code a}
     * @param transposeB if true, {@code b.transpose()} is used in place of {@code b}
     */
    public GemmParams(INDArray a, INDArray b, INDArray c, boolean transposeA, boolean transposeB) {
        this(transposeA ? a.transpose() : a, transposeB ? b.transpose() : b, c);
    }

    /** Checks that {@code a x b -> c} is shape-compatible and every dimension fits in an int. */
    private static void validateShapes(INDArray a, INDArray b, INDArray c) {
        if (a.columns() != b.rows()) {
            throw new IllegalArgumentException("A columns must equal B rows. MMul attempt: "
                    + Arrays.toString(a.shape()) + "x" + Arrays.toString(b.shape()));
        }
        if (b.columns() != c.columns()) {
            throw new IllegalArgumentException("B columns must match C columns. MMul attempt: "
                    + Arrays.toString(a.shape()) + "x" + Arrays.toString(b.shape())
                    + "; result array provided: " + Arrays.toString(c.shape()));
        }
        if (a.rows() != c.rows()) {
            throw new IllegalArgumentException("A rows must equal C rows. MMul attempt: "
                    + Arrays.toString(a.shape()) + "x" + Arrays.toString(b.shape())
                    + "; result array provided: " + Arrays.toString(c.shape()));
        }
        // NOTE(review): rows()/columns() are assigned to int fields without a cast elsewhere in
        // this class, so this guard can only trigger if those methods return a wider type.
        for (INDArray arr : new INDArray[] {a, b, c}) {
            if (arr.columns() > Integer.MAX_VALUE || arr.rows() > Integer.MAX_VALUE) {
                throw new ND4JArraySizeException();
            }
        }
    }

    /**
     * Setup used when the backend does not allow specifying an ordering: everything is described
     * column-major ({@link #ordering} stays {@code 'f'}). A 'c'-ordered operand is flagged 'T' and
     * its leading dimension is taken from its column count.
     */
    private void initForFortranOnlyBackend(INDArray a, INDArray b, INDArray c) {
        this.a = copyIfNecessary(a);
        this.b = copyIfNecessary(b);
        this.c = c;
        this.m = c.rows();
        this.n = c.columns();
        this.k = a.columns();
        // Based on the (possibly copied) operands, whose ordering may differ from the inputs'.
        this.lda = this.a.ordering() == 'f' ? this.a.rows() : this.a.columns();
        this.ldb = this.b.ordering() == 'f' ? this.b.rows() : this.b.columns();
        this.ldc = c.rows();
        this.transA = this.a.ordering() == 'c' ? 'T' : 'N';
        this.transB = this.b.ordering() == 'c' ? 'T' : 'N';
    }

    /** Setup when the operands have different orderings: {@code b} is re-laid-out in {@code a}'s. */
    private void initForMixedOrdering(INDArray a, INDArray b, INDArray c) {
        this.a = copyIfNecessary(a);
        // NOTE(review): uses the ORIGINAL a's ordering; if copyIfNecessary returned a dup with a
        // different ordering, this.a and this.b may end up laid out differently - confirm.
        this.b = b.dup(a.ordering());
        this.c = c;
        this.m = c.rows();
        this.n = c.columns();
        this.k = a.columns();
        this.ordering = a.ordering();
        this.lda = a.rows();
        this.ldb = b.rows();
        this.ldc = c.rows();
        this.transA = 'N';
        this.transB = 'N';
    }

    /** Setup when both operands share one ordering, which becomes {@link #ordering}. */
    private void initForSameOrdering(INDArray a, INDArray b, INDArray c) {
        this.ordering = a.ordering();
        this.a = copyIfNecessary(a);
        this.b = copyIfNecessary(b);
        this.c = c;
        // For 'c' ordering the output dimensions are reported swapped.
        if (this.ordering == 'c') {
            this.m = c.columns();
            this.n = c.rows();
        } else {
            this.m = c.rows();
            this.n = c.columns();
        }
        // k is the inner dimension a.columns() (== b.rows(), checked in validateShapes) for
        // either ordering. The 'f' case previously used b.columns(), which is n, not k.
        this.k = a.columns();
        // NOTE(review): row counts are used as leading dimensions even for 'c' ordering, where a
        // row-major matrix would normally have its column count as leading dimension - verify
        // against the GEMM caller before relying on this for non-square inputs.
        this.lda = a.rows();
        this.ldb = b.rows();
        this.ldc = c.rows();
        this.transA = 'N';
        this.transB = 'N';
    }

    /**
     * Returns {@code arr} unchanged when its layout can be used directly, otherwise
     * {@code arr.dup()}. A copy is made when any of these hold:
     * <ul>
     *   <li>the backend cannot take an explicit ordering and a 'c'-ordered {@code arr} does not
     *       have strides {@code {size(1), 1}};</li>
     *   <li>an 'f'-ordered {@code arr} does not have strides {@code {1, size(0)}};</li>
     *   <li>{@code arr.elementWiseStride()} is negative.</li>
     * </ul>
     */
    private static INDArray copyIfNecessary(INDArray arr) {
        if (!Nd4j.allowsSpecifyOrdering() && arr.ordering() == 'c'
                && (arr.stride(0) != arr.size(1) || arr.stride(1) != 1)) {
            return arr.dup();
        }
        if (arr.ordering() == 'f' && (arr.stride(0) != 1 || arr.stride(1) != arr.size(0))) {
            return arr.dup();
        }
        if (arr.elementWiseStride() < 0) {
            return arr.dup();
        }
        return arr;
    }

    public int getLda() {
        return this.lda;
    }

    public int getLdb() {
        return this.ldb;
    }

    public int getLdc() {
        return this.ldc;
    }

    public int getM() {
        return this.m;
    }

    public int getN() {
        return this.n;
    }

    public int getK() {
        return this.k;
    }

    public INDArray getA() {
        return this.a;
    }

    public INDArray getB() {
        return this.b;
    }

    public INDArray getC() {
        return this.c;
    }

    public char getTransA() {
        return this.transA;
    }

    public char getTransB() {
        return this.transB;
    }

    public char getOrdering() {
        return this.ordering;
    }

    public void setLda(int lda) {
        this.lda = lda;
    }

    public void setLdb(int ldb) {
        this.ldb = ldb;
    }

    public void setLdc(int ldc) {
        this.ldc = ldc;
    }

    public void setM(int m) {
        this.m = m;
    }

    public void setN(int n) {
        this.n = n;
    }

    public void setK(int k) {
        this.k = k;
    }

    public void setA(INDArray a) {
        this.a = a;
    }

    public void setB(INDArray b) {
        this.b = b;
    }

    public void setC(INDArray c) {
        this.c = c;
    }

    public void setTransA(char transA) {
        this.transA = transA;
    }

    public void setTransB(char transB) {
        this.transB = transB;
    }

    public void setOrdering(char ordering) {
        this.ordering = ordering;
    }

    /** Field-by-field equality; operands are compared with {@link INDArray#equals(Object)}. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof GemmParams)) {
            return false;
        }
        GemmParams other = (GemmParams) obj;
        return other.canEqual(this)
                && getLda() == other.getLda()
                && getLdb() == other.getLdb()
                && getLdc() == other.getLdc()
                && getM() == other.getM()
                && getN() == other.getN()
                && getK() == other.getK()
                && Objects.equals(getA(), other.getA())
                && Objects.equals(getB(), other.getB())
                && Objects.equals(getC(), other.getC())
                && getTransA() == other.getTransA()
                && getTransB() == other.getTransB()
                && getOrdering() == other.getOrdering();
    }

    /** Lets subclasses opt out of equality with plain {@code GemmParams} instances. */
    protected boolean canEqual(Object obj) {
        return obj instanceof GemmParams;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        int result = 1;
        result = result * prime + getLda();
        result = result * prime + getLdb();
        result = result * prime + getLdc();
        result = result * prime + getM();
        result = result * prime + getN();
        result = result * prime + getK();
        result = result * prime + (getA() == null ? nullHash : getA().hashCode());
        result = result * prime + (getB() == null ? nullHash : getB().hashCode());
        result = result * prime + (getC() == null ? nullHash : getC().hashCode());
        result = result * prime + getTransA();
        result = result * prime + getTransB();
        result = result * prime + getOrdering();
        return result;
    }

    @Override
    public String toString() {
        return "GemmParams(lda=" + getLda() + ", ldb=" + getLdb() + ", ldc=" + getLdc()
                + ", m=" + getM() + ", n=" + getN() + ", k=" + getK()
                + ", a=" + getA() + ", b=" + getB() + ", c=" + getC()
                + ", transA=" + getTransA() + ", transB=" + getTransB()
                + ", ordering=" + getOrdering() + ")";
    }
}
