package org.apache.sysds.runtime.compress.lib;

import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibAggTernaryOp.class */
/**
 * Aggregate ternary operations on matrix blocks that may be compressed.
 *
 * <p>A compressed-aware shortcut is attempted only for a full aggregation ({@link ReduceAll}) with
 * {@link KahanPlus} over an element-wise {@link Multiply}, i.e. a sum over all cells of the
 * element-wise product of the inputs. Every other case decompresses the inputs and delegates to
 * {@link MatrixBlock#aggregateTernaryOperations}.
 */
public final class CLALibAggTernaryOp {
    private final MatrixBlock m1;
    private final MatrixBlock m2;
    /** Third operand; may be {@code null}, in which case only {@code m1} and {@code m2} take part. */
    private final MatrixBlock m3;
    private final MatrixBlock ret;
    private final AggregateTernaryOperator op;
    private final boolean inCP;
    private static final Log LOG = LogFactory.getLog(CLALibAggTernaryOp.class.getName());
    /**
     * Set after the first decompression warning so later fallbacks stay quiet. Not synchronized, so
     * under concurrency the warning may occasionally be logged more than once (harmless).
     */
    private static boolean warned = false;

    /**
     * Executes an aggregate ternary operation on inputs that may be compressed.
     *
     * @param m1   first input block
     * @param m2   second input block
     * @param m3   third input block; may be {@code null}
     * @param ret  output block to reuse (it is reset), or {@code null} to allocate a new one
     * @param op   the aggregate ternary operator to apply
     * @param inCP flag forwarded to the uncompressed kernel; when set, the empty-input shortcut also
     *             drops the correction row/column from the output (presumably "executing in the
     *             control program" - TODO confirm)
     * @return the result block (normally {@code ret}, or the block returned by the fallback kernel)
     * @throws DMLCompressionException if the fallback kernel produces an output with zero rows or
     *                                 columns
     */
    public static MatrixBlock agg(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
        AggregateTernaryOperator op, boolean inCP) {
        // The output carries one extra row (ReduceRow) or column (otherwise); presumably this holds
        // the aggregation correction, which exec() drops via dropLastRowsOrColumns on the empty path.
        final boolean reduceRow = op.indexFn instanceof ReduceRow;
        final int rows = reduceRow ? 2 : 1;
        final int cols = reduceRow ? m1.getNumColumns() : 2;
        if (ret == null) {
            ret = new MatrixBlock(rows, cols, false);
        } else {
            ret.reset(rows, cols, false);
        }
        return new CLALibAggTernaryOp(m1, m2, m3, ret, op, inCP).exec();
    }

    private CLALibAggTernaryOp(MatrixBlock m1, MatrixBlock m2, MatrixBlock m3, MatrixBlock ret,
        AggregateTernaryOperator op, boolean inCP) {
        this.m1 = m1;
        this.m2 = m2;
        this.m3 = m3;
        this.ret = ret;
        this.op = op;
        this.inCP = inCP;
    }

    /** Tries the compressed-aware shortcuts and otherwise falls back to decompression. */
    private MatrixBlock exec() {
        if (isSumOfProducts()) {
            if (anyEmpty()) {
                // An empty factor makes every product cell zero, so the freshly reset output is
                // already the answer; only the correction row/column may need to be dropped.
                if (op.aggOp.existsCorrection() && inCP) {
                    ret.dropLastRowsOrColumns(op.aggOp.correction);
                }
                return ret;
            }
            // A constant-one factor does not change the product, so drop it and recurse on the
            // remaining operands. This is only valid while two matrix operands remain (m3 != null):
            // otherwise the recursive instance would get m2 == null and fail with an NPE.
            // The shape check keeps a mismatched m1 from being silently ignored; such inputs take
            // the fallback path instead.
            if (m3 != null && isConst(m1) && m1.quickGetValue(0, 0) == 1.0d && sameDims(m1, m2)) {
                return new CLALibAggTernaryOp(m2, m3, null, ret, op, inCP).exec();
            }
        }
        return fallBack();
    }

    /** Whether the operator is a full KahanPlus sum over an element-wise multiplication. */
    private boolean isSumOfProducts() {
        return op.indexFn instanceof ReduceAll
            && op.aggOp.increOp.fn instanceof KahanPlus
            && op.binaryFn instanceof Multiply;
    }

    /** Whether any present operand is an empty block; a {@code null} m3 is treated as absent. */
    private boolean anyEmpty() {
        return m1.isEmptyBlock(false)
            || m2.isEmptyBlock(false)
            || (m3 != null && m3.isEmptyBlock(false));
    }

    /** Whether both blocks have identical row and column counts. */
    private static boolean sameDims(MatrixBlock a, MatrixBlock b) {
        return a.getNumRows() == b.getNumRows() && a.getNumColumns() == b.getNumColumns();
    }

    /**
     * Whether the block is compressed and consists of exactly one constant column group.
     * Returns {@code false} for {@code null}, because {@code instanceof} is false for {@code null}.
     */
    private static boolean isConst(MatrixBlock mb) {
        if (!(mb instanceof CompressedMatrixBlock)) {
            return false;
        }
        final List<AColGroup> groups = ((CompressedMatrixBlock) mb).getColGroups();
        return groups.size() == 1 && groups.get(0) instanceof ColGroupConst;
    }

    /**
     * Decompresses all inputs and runs the uncompressed kernel.
     *
     * @throws DMLCompressionException if the kernel returns a block with zero rows or columns
     */
    private MatrixBlock fallBack() {
        warnDecompression();
        final MatrixBlock result = MatrixBlock.aggregateTernaryOperations(
            CompressedMatrixBlock.getUncompressed(m1),
            CompressedMatrixBlock.getUncompressed(m2),
            CompressedMatrixBlock.getUncompressed(m3),
            ret, op, inCP);
        if (result.getNumRows() == 0 || result.getNumColumns() == 0) {
            throw new DMLCompressionException("Invalid output");
        }
        return result;
    }

    /** Logs (normally once) which operation forced decompression and which inputs were compressed. */
    private void warnDecompression() {
        if (warned) {
            return;
        }
        final boolean m1Compressed = m1 instanceof CompressedMatrixBlock;
        final boolean m2Compressed = m2 instanceof CompressedMatrixBlock;
        final boolean m3Compressed = m3 instanceof CompressedMatrixBlock;
        final StringBuilder sb = new StringBuilder(120);
        sb.append("aggregateTernaryOperations ");
        sb.append(op.aggOp.getClass().getSimpleName());
        sb.append(" ");
        sb.append(op.indexFn.getClass().getSimpleName());
        sb.append(" ");
        sb.append(op.aggOp.increOp.fn.getClass().getSimpleName());
        sb.append(" ");
        sb.append(op.binaryFn.getClass().getSimpleName());
        sb.append(" m1,m2,m3 ");
        sb.append(m1Compressed);
        sb.append(" ");
        sb.append(m2Compressed);
        sb.append(" ");
        sb.append(m3Compressed);
        LOG.warn(sb.toString());
        warned = true;
    }
}
