package org.apache.sysds.runtime.compress.lib;

import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.matrix.data.LibMatrixBincell;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibMMChain.class */
public final class CLALibMMChain {
    static final Log LOG = LogFactory.getLog(CLALibMMChain.class.getName());

    /** Static utility class; not instantiable. */
    private CLALibMMChain() {
    }

    /**
     * Evaluates a matrix-multiplication chain whose left operand is a compressed matrix.
     *
     * <p>Steps performed:
     * <ol>
     * <li>{@code tmp = x %*% v} via {@code CLALibRightMultBy.rightMultByMatrix};</li>
     * <li>for {@code XtwXv}, {@code tmp} is multiplied cell-wise by {@code w}; for {@code XtXvy},
     * {@code w} is subtracted cell-wise from {@code tmp}; for any other chain type {@code tmp} is
     * used as is;</li>
     * <li>{@code tmp} is combined with {@code x} via {@code CLALibLeftMultBy.leftMultByMatrixTransposed}
     * (presumably {@code t(tmp) %*% x} per the method name; confirm in CLALibLeftMultBy);</li>
     * <li>the result is transposed in place when it has more than one column.</li>
     * </ol>
     *
     * @param x     the compressed left-hand matrix. NOTE: its column-group list may be replaced in
     *              place by the pre-filtering step.
     * @param v     the right-hand operand of the first multiplication
     * @param w     the cell-wise weights for {@code XtwXv}, or the subtracted operand for
     *              {@code XtXvy}; not read for other chain types
     * @param out   optional output block. It is reused when non-null; when null, a new block is
     *              allocated (here for an empty {@code x}, otherwise by the left multiplication).
     * @param ctype the chain type selecting how {@code w} is applied
     * @param k     parallelism parameter forwarded to the underlying operators and library calls
     * @return the result block. For an empty {@code x} it is reset/allocated as
     *         {@code (x.getNumColumns() x 1)}.
     */
    public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, MatrixBlock w, MatrixBlock out,
        MapMultChain.ChainType ctype, int k) {
        if (x.isEmpty()) {
            return returnEmpty(x, out);
        }

        final CompressedMatrixBlock filtered = filterColGroups(x);

        // Request an overlapping intermediate only for a single column group, and only if the
        // configuration enables overlapping compressed output.
        final boolean allowOverlap = filtered.getColGroups().size() == 1 && isOverlappingAllowed();
        MatrixBlock tmp = CLALibRightMultBy.rightMultByMatrix(filtered, v, null, k, allowOverlap);

        if (ctype == MapMultChain.ChainType.XtwXv) {
            tmp = binaryMultW(tmp, w, k);
        } else if (ctype == MapMultChain.ChainType.XtXvy) {
            // Previously XtXvy fell through and was silently evaluated as XtXv, ignoring y (= w).
            tmp = binaryMinusY(tmp, w, k);
        }

        // Keep the block returned by the left multiplication so that a null 'out' is supported here
        // just as in the empty case above (the result was previously discarded).
        // NOTE(review): assumes leftMultByMatrixTransposed returns the block it wrote into, reusing
        // a non-null 'out'. Confirm in CLALibLeftMultBy.
        if (tmp instanceof CompressedMatrixBlock) {
            out = CLALibLeftMultBy.leftMultByMatrixTransposed(filtered, (CompressedMatrixBlock) tmp, out, k);
        } else {
            out = CLALibLeftMultBy.leftMultByMatrixTransposed(filtered, tmp, out, k);
        }

        // Turn a multi-column (row-shaped) result into the column-shaped output.
        if (out.getNumColumns() != 1) {
            out = LibMatrixReorg.transposeInPlace(out, k);
        }
        return out;
    }

    /** Returns whether {@link DMLConfig#COMPRESSED_OVERLAPPING} is enabled in the current DML config. */
    private static boolean isOverlappingAllowed() {
        return ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.COMPRESSED_OVERLAPPING);
    }

    /** Produces the result block for an empty compressed input. */
    private static MatrixBlock returnEmpty(CompressedMatrixBlock x, MatrixBlock out) {
        return prepareReturn(x, out);
    }

    /**
     * Shapes the output as {@code (x.getNumColumns() x 1)}: resets {@code out} when non-null,
     * otherwise allocates a new block.
     */
    private static MatrixBlock prepareReturn(CompressedMatrixBlock x, MatrixBlock out) {
        final int nCol = x.getNumColumns();
        if (out == null) {
            return new MatrixBlock(nCol, 1, false);
        }
        out.reset(nCol, 1, false);
        return out;
    }

    /** Cell-wise {@code tmp * w}; see {@link #binaryCellOp} for in-place semantics. */
    private static MatrixBlock binaryMultW(MatrixBlock tmp, MatrixBlock w, int k) {
        return binaryCellOp(tmp, w, new BinaryOperator(Multiply.getMultiplyFnObject(), k));
    }

    /** Cell-wise {@code tmp - y}; see {@link #binaryCellOp} for in-place semantics. */
    private static MatrixBlock binaryMinusY(MatrixBlock tmp, MatrixBlock y, int k) {
        return binaryCellOp(tmp, y, new BinaryOperator(Minus.getMinusFnObject(), k));
    }

    /**
     * Applies {@code tmp op rhs} cell-wise.
     *
     * <p>A compressed {@code tmp} goes through {@code CLALibBinaryCellOp.binaryOperationsRight}, and
     * the block it returns is used. An uncompressed {@code tmp} is modified in place and returned.
     */
    private static MatrixBlock binaryCellOp(MatrixBlock tmp, MatrixBlock rhs, BinaryOperator op) {
        if (tmp instanceof CompressedMatrixBlock) {
            return CLALibBinaryCellOp.binaryOperationsRight(op, (CompressedMatrixBlock) tmp, rhs, null);
        }
        LibMatrixBincell.bincellOpInPlace(tmp, rhs, op);
        return tmp;
    }

    /**
     * Optionally pre-filters the column groups of {@code x}.
     *
     * <p>When {@code CLALibUtils.shouldPreFilter} approves, the groups are replaced by the output of
     * {@code CLALibUtils.filterGroups} plus one {@code ColGroupConst} built from the per-column array
     * that call fills (presumably the extracted constant contributions; verify in CLALibUtils).
     *
     * <p>NOTE: the same {@code x} instance is modified and returned; no copy is made.
     */
    private static CompressedMatrixBlock filterColGroups(CompressedMatrixBlock x) {
        final List<AColGroup> groups = x.getColGroups();
        if (!CLALibUtils.shouldPreFilter(groups)) {
            return x;
        }
        final double[] constV = new double[x.getNumColumns()];
        final List<AColGroup> filteredGroups = CLALibUtils.filterGroups(groups, constV);
        filteredGroups.add(ColGroupConst.create(constV));
        x.allocateColGroupList(filteredGroups);
        return x;
    }
}
