package org.apache.sysds.runtime.compress.lib;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibMatrixMult.class */
/**
 * Entry points for matrix multiplication where at least one operand is a {@link CompressedMatrixBlock}.
 *
 * <p>
 * The methods dispatch to {@code CLALibLeftMultBy} / {@code CLALibRightMultBy} depending on which side of the
 * product the compressed operand ends up on, rewriting transposed products where needed.
 */
public class CLALibMatrixMult {
    private static final Log LOG = LogFactory.getLog(CLALibMatrixMult.class.getName());

    /**
     * Multiplies two blocks without transposing either operand.
     *
     * @param matrixBlock  left operand
     * @param matrixBlock2 right operand
     * @param matrixBlock3 output block handed to the underlying multiply routines (may be replaced by them)
     * @param i            forwarded unchanged to the underlying routines; presumably the degree of parallelism
     * @return the product as returned by {@link #matrixMultiply}
     */
    public static MatrixBlock matrixMult(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, int i) {
        return matrixMultiply(matrixBlock, matrixBlock2, matrixBlock3, i, false, false);
    }

    /**
     * Multiplies two blocks, optionally treating either operand as transposed.
     *
     * <p>
     * If both operands are compressed the call is delegated to a dedicated routine. Otherwise, when the compressed
     * operand is the one flagged for transposition, the product is rewritten using
     * {@code A %*% B == t( t(B) %*% t(A) )}: operands and flags are swapped and the result is transposed at the
     * end. Any transposition still pending on the uncompressed operand is materialized via
     * {@link LibMatrixReorg#transpose}.
     *
     * <p>
     * NOTE(review): the code casts one of the operands to {@link CompressedMatrixBlock} unconditionally, so callers
     * are expected to pass at least one compressed operand.
     *
     * @param matrixBlock  left operand
     * @param matrixBlock2 right operand
     * @param matrixBlock3 output block handed to the underlying multiply routines
     * @param i            forwarded to transpose / decompress / multiply calls; presumably the degree of parallelism
     * @param z            whether the left operand is to be treated as transposed
     * @param z2           whether the right operand is to be treated as transposed
     * @return the resulting block
     */
    public static MatrixBlock matrixMultiply(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, MatrixBlock matrixBlock3, int i, boolean z, boolean z2) {
        final Timing timing = LOG.isTraceEnabled() ? new Timing(true) : null;

        if (matrixBlock instanceof CompressedMatrixBlock && matrixBlock2 instanceof CompressedMatrixBlock) {
            return doubleCompressedMatrixMultiply((CompressedMatrixBlock) matrixBlock, (CompressedMatrixBlock) matrixBlock2, matrixBlock3, i, z, z2);
        }

        // Work on local aliases so the swap below is explicit and the parameters stay untouched.
        MatrixBlock left = matrixBlock;
        MatrixBlock right = matrixBlock2;
        boolean transposeLeft = z;
        boolean transposeRight = z2;
        boolean transposeOutput = false;

        if (transposeLeft || transposeRight) {
            final boolean compressedSideTransposed = (left instanceof CompressedMatrixBlock && transposeLeft)
                || (right instanceof CompressedMatrixBlock && transposeRight);
            if (compressedSideTransposed) {
                // Rewrite left %*% right as t( t(right) %*% t(left) ).
                // Temporaries are required: swapping without them loses the original left operand and flag.
                transposeOutput = true;
                final MatrixBlock previousLeft = left;
                left = right;
                right = previousLeft;
                final boolean previousTransposeLeft = transposeLeft;
                transposeLeft = !transposeRight;
                transposeRight = !previousTransposeLeft;
            }

            // Only the uncompressed operand is transposed eagerly.
            if (!(left instanceof CompressedMatrixBlock) && transposeLeft) {
                left = LibMatrixReorg.transpose(left, i);
            } else if (!(right instanceof CompressedMatrixBlock) && transposeRight) {
                right = LibMatrixReorg.transpose(right, i);
            }
        }

        final boolean compressedOnLeft = left instanceof CompressedMatrixBlock;
        final CompressedMatrixBlock compressed = (CompressedMatrixBlock) (compressedOnLeft ? left : right);
        final MatrixBlock other = compressedOnLeft ? right : left;

        MatrixBlock result;
        if (compressedOnLeft) {
            result = CLALibRightMultBy.rightMultByMatrix(compressed, other, matrixBlock3, i);
        } else {
            result = CLALibLeftMultBy.leftMultByMatrix(compressed, other, matrixBlock3, i);
        }

        // Guard on the timer itself: trace may have been switched on after the timer was (not) created.
        if (timing != null && LOG.isTraceEnabled()) {
            LOG.trace("MM: Time block w/ sharedDim: " + left.getNumColumns() + " rowLeft: " + left.getNumRows() + " colRight:" + right.getNumColumns() + " in " + timing.stop() + "ms.");
        }

        if (transposeOutput) {
            if (result instanceof CompressedMatrixBlock) {
                LOG.warn("Transposing decompression");
                result = ((CompressedMatrixBlock) result).decompress(i);
            }
            result = LibMatrixReorg.transpose(result, i);
        }
        return result;
    }

    /**
     * Multiplies two compressed blocks, selecting a strategy from the transpose flags.
     *
     * <ul>
     * <li>neither transposed: the right operand is replaced by
     * {@code CompressedMatrixBlock.getUncompressed(...)} and the general path is used;</li>
     * <li>only left transposed: {@code CLALibLeftMultBy.leftMultByMatrixTransposed} is called, with the operand
     * order chosen by comparing column counts (and the result index-swapped when the order is kept);</li>
     * <li>only right transposed: not supported;</li>
     * <li>both transposed: the operands are multiplied in swapped order and the result is index-swapped.</li>
     * </ul>
     *
     * @param compressedMatrixBlock  left compressed operand
     * @param compressedMatrixBlock2 right compressed operand
     * @param matrixBlock            output block handed to the underlying routines
     * @param i                      forwarded to the underlying routines; presumably the degree of parallelism
     * @param z                      whether the left operand is to be treated as transposed
     * @param z2                     whether the right operand is to be treated as transposed
     * @return the resulting block
     * @throws DMLCompressionException if only the right operand is flagged as transposed
     */
    private static MatrixBlock doubleCompressedMatrixMultiply(CompressedMatrixBlock compressedMatrixBlock, CompressedMatrixBlock compressedMatrixBlock2, MatrixBlock matrixBlock, int i, boolean z, boolean z2) {
        if (!z && !z2) {
            LOG.warn("Matrix decompression from multiplying two compressed matrices.");
            return matrixMultiply(compressedMatrixBlock, CompressedMatrixBlock.getUncompressed(compressedMatrixBlock2), matrixBlock, i, z, z2);
        }

        if (z && !z2) {
            if (compressedMatrixBlock.getNumColumns() > compressedMatrixBlock2.getNumColumns()) {
                final MatrixBlock product = CLALibLeftMultBy.leftMultByMatrixTransposed(compressedMatrixBlock, compressedMatrixBlock2, matrixBlock, i);
                return swapIndexes(product, i);
            }
            return CLALibLeftMultBy.leftMultByMatrixTransposed(compressedMatrixBlock2, compressedMatrixBlock, matrixBlock, i);
        }

        if (!z) {
            // Remaining case with z == false is (z == false, z2 == true).
            throw new DMLCompressionException("Not Implemented compressed Matrix Mult, to produce larger matrix");
        }

        // Both flagged as transposed: multiply in swapped order, then swap indexes of the result.
        return swapIndexes(matrixMult(compressedMatrixBlock2, compressedMatrixBlock, matrixBlock, i), i);
    }

    /** Applies a {@link SwapIndex} reorg operation to {@code block}, writing into a fresh {@link MatrixBlock}. */
    private static MatrixBlock swapIndexes(MatrixBlock block, int i) {
        final ReorgOperator swapOperator = new ReorgOperator(SwapIndex.getSwapIndexFnObject(), i);
        return block.reorgOperations(swapOperator, new MatrixBlock(), 0, 0, 0);
    }
}
