package org.apache.sysds.runtime.compress.colgroup;

import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/AColGroupCompressed.class */
public abstract class AColGroupCompressed extends AColGroup {
    private static final long serialVersionUID = 6219835795420081223L;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a compressed column group through the parent's no-argument constructor.
     */
    public AColGroupCompressed() {
        super();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a compressed column group, forwarding the given array unchanged to the {@link AColGroup} constructor.
     *
     * @param iArr array handed to the parent constructor (presumably the column indexes covered by this group —
     *             confirm against AColGroup)
     */
    public AColGroupCompressed(int[] iArr) {
        super(iArr);
    }

    /**
     * Hook for full min/max aggregation. It is called with {@code +Infinity} and MIN by {@link #getMin()}, with
     * {@code -Infinity} and MAX by {@link #getMax()}, and with the current {@code dArr[0]} when a MIN/MAX builtin is
     * aggregated over the whole group (the result is written back to {@code dArr[0]}).
     *
     * @param d       starting value supplied by the caller
     * @param builtin builtin function object; callers in this class only pass MIN or MAX
     * @return the value the caller treats as the aggregated min/max
     */
    protected abstract double computeMxx(double d, Builtin builtin);

    /**
     * Hook invoked for a MIN/MAX builtin when the operator's index function is {@code ReduceRow}.
     *
     * @param dArr    result array passed to {@code unaryAggregateOperations}
     * @param builtin builtin function object (MIN or MAX)
     */
    protected abstract void computeColMxx(double[] dArr, Builtin builtin);

    /**
     * Hook invoked for {@code Plus}/{@code KahanPlus} when the operator's index function is {@code ReduceAll}.
     *
     * @param dArr result array passed to {@code unaryAggregateOperations}
     * @param i    first int argument of {@code unaryAggregateOperations}, forwarded as-is (presumably the row count —
     *             TODO confirm)
     */
    protected abstract void computeSum(double[] dArr, int i);

    /**
     * Hook invoked for {@code KahanPlusSq} when the operator's index function is {@code ReduceAll}.
     *
     * @param dArr result array passed to {@code unaryAggregateOperations}
     * @param i    first int argument of {@code unaryAggregateOperations}, forwarded as-is (presumably the row count —
     *             TODO confirm)
     */
    protected abstract void computeSumSq(double[] dArr, int i);

    /**
     * Hook invoked for {@code KahanPlusSq} when the operator's index function is {@code ReduceRow}.
     *
     * @param dArr result array passed to {@code unaryAggregateOperations}
     * @param i    first int argument of {@code unaryAggregateOperations}, forwarded as-is (presumably the row count —
     *             TODO confirm)
     */
    protected abstract void computeColSumsSq(double[] dArr, int i);

    /**
     * Hook invoked for {@code Plus}, {@code KahanPlus} and {@code KahanPlusSq} when the operator's index function is
     * {@code ReduceCol}.
     *
     * @param dArr  result array passed to {@code unaryAggregateOperations}
     * @param i     second int argument of {@code unaryAggregateOperations} (presumably the first row — TODO confirm)
     * @param i2    third int argument of {@code unaryAggregateOperations} (presumably the row upper bound — TODO
     *              confirm)
     * @param dArr2 pre-aggregate array supplied by the caller; {@code null} when the five-argument overload of
     *              {@code unaryAggregateOperations} is used
     */
    protected abstract void computeRowSums(double[] dArr, int i, int i2, double[] dArr2);

    /**
     * Hook invoked for a MIN/MAX builtin when the operator's index function is {@code ReduceCol}.
     *
     * @param dArr    result array passed to {@code unaryAggregateOperations}
     * @param builtin builtin function object (MIN or MAX)
     * @param i       second int argument of {@code unaryAggregateOperations} (presumably the first row — TODO confirm)
     * @param i2      third int argument of {@code unaryAggregateOperations} (presumably the row upper bound — TODO
     *                confirm)
     * @param dArr2   pre-aggregate array supplied by the caller; may be {@code null}
     */
    protected abstract void computeRowMxx(double[] dArr, Builtin builtin, int i, int i2, double[] dArr2);

    /**
     * Hook invoked for {@code Multiply} when the operator's index function is {@code ReduceAll}.
     *
     * @param dArr result array passed to {@code unaryAggregateOperations}
     * @param i    first int argument of {@code unaryAggregateOperations}, forwarded as-is (presumably the row count —
     *             TODO confirm)
     */
    protected abstract void computeProduct(double[] dArr, int i);

    /**
     * Hook invoked for {@code Multiply} when the operator's index function is {@code ReduceCol}.
     *
     * @param dArr  result array passed to {@code unaryAggregateOperations}
     * @param i     second int argument of {@code unaryAggregateOperations} (presumably the first row — TODO confirm)
     * @param i2    third int argument of {@code unaryAggregateOperations} (presumably the row upper bound — TODO
     *              confirm)
     * @param dArr2 pre-aggregate array supplied by the caller; may be {@code null}
     */
    protected abstract void computeRowProduct(double[] dArr, int i, int i2, double[] dArr2);

    /**
     * Hook invoked for {@code Multiply} when the operator's index function is {@code ReduceRow}.
     *
     * @param dArr result array passed to {@code unaryAggregateOperations}
     * @param i    first int argument of {@code unaryAggregateOperations}, forwarded as-is (presumably the row count —
     *             TODO confirm)
     */
    protected abstract void computeColProduct(double[] dArr, int i);

    /**
     * Hook whose result {@link #preAggRows(AggregateUnaryOperator)} returns for {@code Plus} and {@code KahanPlus}.
     *
     * @return the pre-aggregated array (layout is defined by the implementing group)
     */
    protected abstract double[] preAggSumRows();

    /**
     * Hook whose result {@link #preAggRows(AggregateUnaryOperator)} returns for {@code KahanPlusSq}.
     *
     * @return the pre-aggregated array (layout is defined by the implementing group)
     */
    protected abstract double[] preAggSumSqRows();

    /**
     * Hook whose result {@link #preAggRows(AggregateUnaryOperator)} returns for {@code Multiply}.
     *
     * @return the pre-aggregated array (layout is defined by the implementing group)
     */
    protected abstract double[] preAggProductRows();

    /**
     * Hook whose result {@link #preAggRows(AggregateUnaryOperator)} returns for a MIN or MAX builtin.
     *
     * @param builtin builtin function object (MIN or MAX)
     * @return the pre-aggregated array (layout is defined by the implementing group)
     */
    protected abstract double[] preAggBuiltinRows(Builtin builtin);

    /**
     * Selects the row pre-aggregation that matches the increment function of the given operator.
     *
     * <p>Sum-of-squares, sum and product functions map to their dedicated pre-aggregation hook; a builtin function is
     * accepted only when its code is MIN or MAX.
     *
     * @param aggregateUnaryOperator operator whose {@code aggOp.increOp.fn} decides which hook runs
     * @return the array produced by the selected {@code preAgg*Rows} hook
     * @throws DMLScriptException if the function is none of the handled kinds, or is a builtin other than MIN/MAX
     */
    public double[] preAggRows(AggregateUnaryOperator aggregateUnaryOperator) {
        final ValueFunction fn = aggregateUnaryOperator.aggOp.increOp.fn;
        final double[] preAggregated;
        // The sum-of-squares test stays first, exactly as in the original dispatch order.
        if (fn instanceof KahanPlusSq) {
            preAggregated = preAggSumSqRows();
        } else if (fn instanceof Plus || fn instanceof KahanPlus) {
            preAggregated = preAggSumRows();
        } else if (fn instanceof Multiply) {
            preAggregated = preAggProductRows();
        } else if (fn instanceof Builtin) {
            final Builtin builtinFn = (Builtin) fn;
            final Builtin.BuiltinCode code = builtinFn.getBuiltinCode();
            final boolean minOrMax = code == Builtin.BuiltinCode.MAX || code == Builtin.BuiltinCode.MIN;
            if (!minOrMax) {
                throw new DMLScriptException("unsupported builtin type: " + builtinFn);
            }
            preAggregated = preAggBuiltinRows(builtinFn);
        } else {
            throw new DMLScriptException("Unknown UnaryAggregate operator on CompressedMatrixBlock " + aggregateUnaryOperator);
        }
        return preAggregated;
    }

    /**
     * Returns the minimum reported by {@link #computeMxx(double, Builtin)} when started from positive infinity with
     * the MIN builtin.
     *
     * @return the result of the min aggregation hook
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public double getMin() {
        final Builtin minFn = Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MIN);
        return computeMxx(Double.POSITIVE_INFINITY, minFn);
    }

    /**
     * Returns the maximum reported by {@link #computeMxx(double, Builtin)} when started from negative infinity with
     * the MAX builtin.
     *
     * @return the result of the max aggregation hook
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public double getMax() {
        final Builtin maxFn = Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAX);
        return computeMxx(Double.NEGATIVE_INFINITY, maxFn);
    }

    /**
     * Runs a unary aggregate without a pre-aggregated row array by delegating to the six-argument overload with
     * {@code null} as the last argument.
     *
     * @param aggregateUnaryOperator operator describing the aggregation
     * @param dArr                   result array
     * @param i                      first int argument, forwarded unchanged
     * @param i2                     second int argument, forwarded unchanged
     * @param i3                     third int argument, forwarded unchanged
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public final void unaryAggregateOperations(AggregateUnaryOperator aggregateUnaryOperator, double[] dArr, int i, int i2, int i3) {
        final double[] noPreAggregate = null;
        unaryAggregateOperations(aggregateUnaryOperator, dArr, i, i2, i3, noPreAggregate);
    }

    /**
     * Dispatches a unary aggregate to the matching {@code compute*} hook.
     *
     * <p>The hook is chosen from the operator's increment function (sum of squares, sum, product, or a MIN/MAX
     * builtin) combined with its index function ({@code ReduceAll}, {@code ReduceCol} or {@code ReduceRow}). For a
     * handled function whose index function is none of those three, nothing is done.
     *
     * @param aggregateUnaryOperator operator whose {@code aggOp.increOp.fn} and {@code indexFn} drive the dispatch
     * @param dArr                   result array handed to the hooks
     * @param i                      int forwarded to the {@code ReduceAll} and {@code ReduceRow} hooks
     * @param i2                     first int forwarded to the {@code ReduceCol} hooks
     * @param i3                     second int forwarded to the {@code ReduceCol} hooks
     * @param dArr2                  pre-aggregate array forwarded to the {@code ReduceCol} hooks; may be {@code null}
     * @throws DMLScriptException if the function is none of the handled kinds, or is a builtin other than MIN/MAX
     */
    public final void unaryAggregateOperations(AggregateUnaryOperator aggregateUnaryOperator, double[] dArr, int i, int i2, int i3, double[] dArr2) {
        final ValueFunction fn = aggregateUnaryOperator.aggOp.increOp.fn;
        final boolean overAll = aggregateUnaryOperator.indexFn instanceof ReduceAll;
        final boolean overCols = aggregateUnaryOperator.indexFn instanceof ReduceCol;
        final boolean overRows = aggregateUnaryOperator.indexFn instanceof ReduceRow;

        if (fn instanceof KahanPlusSq) {
            if (overAll) {
                computeSumSq(dArr, i);
            } else if (overCols) {
                // NOTE(review): the sum-of-squares row case shares computeRowSums with the plain sum; presumably
                // dArr2 already holds squared pre-aggregates from preAggRows — confirm with callers.
                computeRowSums(dArr, i2, i3, dArr2);
            } else if (overRows) {
                computeColSumsSq(dArr, i);
            }
        } else if (fn instanceof Plus || fn instanceof KahanPlus) {
            if (overAll) {
                computeSum(dArr, i);
            } else if (overCols) {
                computeRowSums(dArr, i2, i3, dArr2);
            } else if (overRows) {
                computeColSums(dArr, i);
            }
        } else if (fn instanceof Multiply) {
            if (overAll) {
                computeProduct(dArr, i);
            } else if (overCols) {
                computeRowProduct(dArr, i2, i3, dArr2);
            } else if (overRows) {
                computeColProduct(dArr, i);
            }
        } else if (fn instanceof Builtin) {
            final Builtin builtinFn = (Builtin) fn;
            final Builtin.BuiltinCode code = builtinFn.getBuiltinCode();
            if (code != Builtin.BuiltinCode.MAX && code != Builtin.BuiltinCode.MIN) {
                throw new DMLScriptException("unsupported builtin type: " + builtinFn);
            }
            if (overAll) {
                // The current first cell seeds the running min/max and receives the result.
                dArr[0] = computeMxx(dArr[0], builtinFn);
            } else if (overCols) {
                computeRowMxx(dArr, builtinFn, i2, i3, dArr2);
            } else if (overRows) {
                computeColMxx(dArr, builtinFn);
            }
        } else {
            throw new DMLScriptException("Unknown UnaryAggregate operator on CompressedMatrixBlock");
        }
    }

    /**
     * Unpacks the given matrix block into its dense value array and column count and delegates to
     * {@link #tsmm(double[], int, int)}.
     *
     * @param matrixBlock block whose dense values and column count are forwarded
     * @param i           int forwarded unchanged as the last argument
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public final void tsmm(MatrixBlock matrixBlock, int i) {
        final double[] denseValues = matrixBlock.getDenseBlockValues();
        final int numColumns = matrixBlock.getNumColumns();
        tsmm(denseValues, numColumns, i);
    }

    /**
     * Group-specific hook behind {@link #tsmm(MatrixBlock, int)}.
     *
     * @param dArr dense value array of the block passed to the public overload
     * @param i    column count of that block
     * @param i2   int argument of the public overload, forwarded unchanged (presumably the row count — TODO confirm)
     */
    protected abstract void tsmm(double[] dArr, int i, int i2);

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Accumulates a dictionary's contribution into a dense result array, choosing the sparse or dense kernel from
     * the storage format of the dictionary's matrix block.
     *
     * <p>If the dictionary yields no matrix-block view ({@code getMBDict} returns {@code null}), the dictionary's own
     * dense values are used. The previous code dereferenced the {@code null} view in that case and always failed with
     * a {@code NullPointerException}.
     *
     * @param dArr        dense result array that is accumulated into
     * @param i           row stride of {@code dArr} (multiplied by a result column index to find a row start)
     * @param iArr        one integer multiplier per dictionary row (presumably occurrence counts — TODO confirm)
     * @param aDictionary dictionary providing the values
     * @param iArr2       maps each dictionary column position to its index in the result; its length is the column
     *                    count requested from {@code getMBDict}
     */
    public static void tsmm(double[] dArr, int i, int[] iArr, ADictionary aDictionary, int[] iArr2) {
        final MatrixBlockDictionary mbDict = aDictionary.getMBDict(iArr2.length);
        if (mbDict == null) {
            // No matrix-block view: fall back to the dictionary's dense values (tsmmDense tolerates null values).
            tsmmDense(dArr, i, aDictionary.getValues(), iArr, iArr2);
            return;
        }
        final MatrixBlock dictBlock = mbDict.getMatrixBlock();
        if (dictBlock.isEmpty()) {
            // An empty block contributes nothing.
            return;
        }
        if (dictBlock.isInSparseFormat()) {
            tsmmSparse(dArr, i, dictBlock.getSparseBlock(), iArr, iArr2);
        } else {
            tsmmDense(dArr, i, dictBlock.getDenseBlockValues(), iArr, iArr2);
        }
    }

    /**
     * Accumulates weighted outer products of dense dictionary rows into a dense result array.
     *
     * <p>For every dictionary row {@code r} and every column pair {@code (a, b)} with {@code b >= a} (in dictionary
     * column order), adds {@code dArr2[r][a] * iArr[r] * dArr2[r][b]} to
     * {@code dArr[i * iArr2[a] + iArr2[b]]}. Pairs with {@code b < a} are never written.
     *
     * <p>The skip test uses a literal zero rather than an unrelated parser constant, so this kernel no longer depends
     * on a CSV default value.
     *
     * @param dArr  dense result array that is accumulated into
     * @param i     row stride of {@code dArr}
     * @param dArr2 row-major dictionary values, {@code iArr2.length} entries per row; {@code null} makes this a no-op
     * @param iArr  one integer multiplier per dictionary row (presumably occurrence counts — TODO confirm)
     * @param iArr2 maps each dictionary column position to its index in the result
     */
    protected static void tsmmDense(double[] dArr, int i, double[] dArr2, int[] iArr, int[] iArr2) {
        if (dArr2 == null) {
            return;
        }
        final int nCols = iArr2.length;
        final int nRows = iArr.length;
        for (int row = 0; row < nRows; row++) {
            final int rowStart = nCols * row;
            final int weight = iArr[row];
            for (int left = 0; left < nCols; left++) {
                final double scaled = dArr2[rowStart + left] * weight;
                // A zero left factor (zero value or zero multiplier) cannot change the result; skip the inner loop.
                if (scaled != 0.0) {
                    final int outRow = i * iArr2[left];
                    for (int right = left; right < nCols; right++) {
                        dArr[outRow + iArr2[right]] += scaled * dArr2[rowStart + right];
                    }
                }
            }
        }
    }

    /**
     * Accumulates weighted outer products of sparse dictionary rows into a dense result array.
     *
     * <p>For every non-empty row and every pair of stored entries {@code (a, b)} with {@code b} at or after
     * {@code a} in the row's storage order, adds {@code value[a] * iArr[row] * value[b]} to
     * {@code dArr[iArr2[col[a]] * i + iArr2[col[b]]]}.
     *
     * @param dArr        dense result array that is accumulated into
     * @param i           row stride of {@code dArr}
     * @param sparseBlock sparse dictionary values, queried for rows {@code 0 .. iArr.length - 1}
     * @param iArr        one integer multiplier per dictionary row (presumably occurrence counts — TODO confirm)
     * @param iArr2       maps each dictionary column index to its index in the result
     */
    protected static void tsmmSparse(double[] dArr, int i, SparseBlock sparseBlock, int[] iArr, int[] iArr2) {
        final int nRows = iArr.length;
        for (int row = 0; row < nRows; row++) {
            if (sparseBlock.isEmpty(row)) {
                continue;
            }
            final int start = sparseBlock.pos(row);
            final int end = start + sparseBlock.size(row);
            final int[] colIdx = sparseBlock.indexes(row);
            final double[] vals = sparseBlock.values(row);
            final int weight = iArr[row];
            for (int a = start; a < end; a++) {
                final int outRow = iArr2[colIdx[a]] * i;
                final double scaled = vals[a] * weight;
                for (int b = a; b < end; b++) {
                    dArr[outRow + iArr2[colIdx[b]]] += scaled * vals[b];
                }
            }
        }
    }
}
