package org.nd4j.linalg.api.blas.params;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/nd4j/linalg/api/blas/params/MMulTranspose.class */
/**
 * Immutable-by-convention set of transpose flags for a matrix multiplication {@code op(A) x op(B)},
 * optionally transposing the product as well.
 *
 * <p>Transposition is only applied to rank-2 arrays (plain {@code transpose()}) and rank-3 arrays
 * (the last two axes are swapped via {@code permute(0, 2, 1)}, leaving the leading axis in place);
 * arrays of any other rank are passed through unchanged.
 */
public class MMulTranspose implements Serializable {
    // Shared "no transposes" instance. Private static fields do not take part in the default
    // serialVersionUID computation, so making this final keeps serialization compatibility.
    private static final MMulTranspose allFalse = builder().build();

    // Kept non-final on purpose: changing the modifiers of these (serialized) instance fields would
    // change the default serialVersionUID of this Serializable class.
    private boolean transposeA;
    private boolean transposeB;
    private boolean transposeResult;

    /** Fluent builder for {@link MMulTranspose}; every flag defaults to {@code false}. */
    public static class MMulTransposeBuilder {
        private boolean transposeA;
        private boolean transposeB;
        private boolean transposeResult;

        MMulTransposeBuilder() {
        }

        /** Sets whether the left operand is transposed before multiplication. */
        public MMulTransposeBuilder transposeA(boolean transposeA) {
            this.transposeA = transposeA;
            return this;
        }

        /** Sets whether the right operand is transposed before multiplication. */
        public MMulTransposeBuilder transposeB(boolean transposeB) {
            this.transposeB = transposeB;
            return this;
        }

        /** Sets whether the product is transposed after multiplication. */
        public MMulTransposeBuilder transposeResult(boolean transposeResult) {
            this.transposeResult = transposeResult;
            return this;
        }

        /** Creates a new {@link MMulTranspose} from the flags set so far. */
        public MMulTranspose build() {
            return new MMulTranspose(this.transposeA, this.transposeB, this.transposeResult);
        }

        @Override
        public String toString() {
            return "MMulTranspose.MMulTransposeBuilder(transposeA=" + this.transposeA + ", transposeB=" + this.transposeB + ", transposeResult=" + this.transposeResult + ")";
        }
    }

    /**
     * @param transposeA      transpose the left operand before multiplying
     * @param transposeB      transpose the right operand before multiplying
     * @param transposeResult transpose the product after multiplying
     */
    public MMulTranspose(boolean transposeA, boolean transposeB, boolean transposeResult) {
        this.transposeA = transposeA;
        this.transposeB = transposeB;
        this.transposeResult = transposeResult;
    }

    /** Returns the shared instance with every transpose flag set to {@code false}. */
    public static MMulTranspose allFalse() {
        return allFalse;
    }

    /**
     * Computes {@code op(a) x op(b)}, then optionally transposes the product.
     *
     * <p>If {@code result} is non-null, the product is computed with {@code mmuli} into
     * {@code result}. When {@code transposeResult} is set, the returned array is the
     * transposed/permuted form of that product rather than {@code result} itself. This is
     * presumably a view over {@code result}; confirm against the INDArray implementation.
     *
     * @param a      left operand
     * @param b      right operand
     * @param result optional output array for the untransposed product; may be {@code null}
     * @return the (optionally transposed) product
     */
    public INDArray exec(INDArray a, INDArray b, INDArray result) {
        INDArray left = transposeIfReq(this.transposeA, a);
        INDArray right = transposeIfReq(this.transposeB, b);
        INDArray product = (result == null) ? left.mmul(right) : left.mmuli(right, result);
        // Use the same rank-aware transposition in both paths. Previously the explicit-output path
        // called transpose(), which reverses *all* axes and so broke rank-3 (batched) results.
        return transposeIfReq(this.transposeResult, product);
    }

    /**
     * Transposes {@code x} when {@code transpose} is set. Rank 2 uses {@code transpose()}; rank 3
     * swaps the last two axes. Any other rank, or {@code transpose == false}, returns {@code x}.
     */
    private static INDArray transposeIfReq(boolean transpose, INDArray x) {
        if (!transpose) {
            return x;
        }
        switch (x.rank()) {
            case 2:
                return x.transpose();
            case 3:
                return x.permute(0, 2, 1);
            default:
                return x;
        }
    }

    /** Returns a new builder with all flags initially {@code false}. */
    public static MMulTransposeBuilder builder() {
        return new MMulTransposeBuilder();
    }

    /** @return whether the left operand is transposed */
    public boolean isTransposeA() {
        return this.transposeA;
    }

    /** @return whether the right operand is transposed */
    public boolean isTransposeB() {
        return this.transposeB;
    }

    /** @return whether the product is transposed */
    public boolean isTransposeResult() {
        return this.transposeResult;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof MMulTranspose)) {
            return false;
        }
        MMulTranspose other = (MMulTranspose) obj;
        return other.canEqual(this)
                && isTransposeA() == other.isTransposeA()
                && isTransposeB() == other.isTransposeB()
                && isTransposeResult() == other.isTransposeResult();
    }

    /** Symmetry hook for subclasses overriding {@link #equals(Object)}. */
    protected boolean canEqual(Object obj) {
        return obj instanceof MMulTranspose;
    }

    @Override
    public int hashCode() {
        // Same formula as before, so hash values are unchanged.
        int h = 1;
        h = h * 59 + (isTransposeA() ? 79 : 97);
        h = h * 59 + (isTransposeB() ? 79 : 97);
        h = h * 59 + (isTransposeResult() ? 79 : 97);
        return h;
    }
}
