package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;

import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/ColGroupConst.class */
/**
 * Column group in which every row holds the same tuple of values, one value per column of the group.
 *
 * <p>The tuple is stored in {@link #_dict}. {@link #getNumValues()} is always 1 and
 * {@link #getIdx(int, int)} ignores the row index.
 */
public class ColGroupConst extends AColGroupCompressed {
    private static final long serialVersionUID = -7387793538322386611L;

    /** Dictionary holding the single constant tuple (one value per entry of {@code _colIndexes}). */
    protected ADictionary _dict;

    /**
     * No-argument constructor that leaves the fields unset. Presumably used for deserialization
     * followed by {@link #readFields(DataInput)} (TODO confirm).
     *
     * <p>Note: originally {@code protected}; the decompiler widened it to {@code public}.
     */
    public ColGroupConst() {
    }

    private ColGroupConst(int[] colIndices, ADictionary dict) {
        super(colIndices);
        this._dict = dict;
    }

    /**
     * Creates a constant group over the given columns, or an empty group when no dictionary is given.
     *
     * <p>Note: originally {@code protected}; the decompiler widened it to {@code public}.
     *
     * @param colIndices the column indices of the group
     * @param dict the dictionary holding the constant tuple, may be {@code null}
     * @return a {@link ColGroupEmpty} if {@code dict} is {@code null}, otherwise a {@link ColGroupConst}
     */
    public static AColGroup create(int[] colIndices, ADictionary dict) {
        return dict == null ? new ColGroupEmpty(colIndices) : new ColGroupConst(colIndices, dict);
    }

    /**
     * Creates a constant group with one generated column index per value
     * (via {@code Util.genColsIndices(values.length)}).
     *
     * @param values the constant tuple
     * @return the constant column group
     */
    public static AColGroup create(double[] values) {
        return create(Util.genColsIndices(values.length), values);
    }

    /**
     * Creates a constant group where every column holds the same value.
     *
     * <p>Unlike {@link #create(int, double)}, this overload does not map {@code value == 0} to an empty group.
     *
     * @param colIndices the column indices of the group
     * @param value the value repeated in every column
     * @return the constant column group
     */
    public static AColGroup create(int[] colIndices, double value) {
        final double[] values = new double[colIndices.length];
        Arrays.fill(values, value);
        return create(colIndices, values);
    }

    /**
     * Creates a constant group from explicit column indices and values.
     *
     * @param colIndices the column indices of the group
     * @param values the constant tuple, one value per column index
     * @return the constant column group
     * @throws DMLCompressionException if {@code colIndices} and {@code values} differ in length
     */
    public static AColGroup create(int[] colIndices, double[] values) {
        if (colIndices.length != values.length) {
            throw new DMLCompressionException("Invalid size of values compared to columns");
        }
        return create(colIndices, new Dictionary(values));
    }

    /**
     * Creates a constant group over {@code numCols} generated column indices.
     *
     * @param numCols the number of columns
     * @param dict the dictionary holding the constant tuple
     * @return the constant column group
     * @throws DMLCompressionException if {@code numCols} differs from {@code dict.getValues().length}
     */
    public static AColGroup create(int numCols, ADictionary dict) {
        if (numCols != dict.getValues().length) {
            throw new DMLCompressionException(
                "Invalid construction of const column group with different number of columns in arguments");
        }
        return create(Util.genColsIndices(numCols), dict);
    }

    /**
     * Creates a group over {@code numCols} generated column indices, all holding {@code value}.
     *
     * @param numCols the number of columns, must be positive
     * @param value the value repeated in every column
     * @return a {@link ColGroupEmpty} when {@code value == 0}, otherwise a constant column group
     * @throws DMLCompressionException if {@code numCols <= 0}
     */
    public static AColGroup create(int numCols, double value) {
        if (numCols <= 0) {
            throw new DMLCompressionException("Invalid construction of constant column group with cols: " + numCols);
        }
        final int[] colIndices = Util.genColsIndices(numCols);
        // An all-zero tuple is represented by the dedicated empty group instead.
        return value == 0.0 ? new ColGroupEmpty(colIndices) : create(colIndices, value);
    }

    /** Combines each {@code c[r]} for {@code r} in {@code [rl, ru)} with {@code preAgg[0]} using {@code builtin}. */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        final double v = preAgg[0];
        for (int r = rl; r < ru; r++) {
            c[r] = builtin.execute(c[r], v);
        }
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.CONST;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.CONST;
    }

    /**
     * Adds the constant tuple to rows {@code [rl, ru)} of {@code db}. Each target row is
     * {@code r + offR} and each target column is {@code colIndex + offC}.
     *
     * <p>A fast path is taken when {@code db} is contiguous, the group has as many columns as {@code db}
     * and {@code offC == 0}. NOTE(review): that path assumes such a group's {@code _colIndexes} are
     * {@code 0..n-1} in order.
     */
    @Override
    public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
        final double[] tuple = materializeTuple();
        if (db.isContiguous() && _colIndexes.length == db.getDim(1) && offC == 0) {
            decompressToDenseBlockAllColumnsContiguous(db, rl, ru, offR, tuple);
        } else {
            decompressToDenseBlockGeneric(db, rl, ru, offR, offC, tuple);
        }
    }

    /** Fast path: adds {@code tuple} to whole rows of a single contiguous value array. */
    private void decompressToDenseBlockAllColumnsContiguous(DenseBlock db, int rl, int ru, int offR,
        double[] tuple) {
        final double[] out = db.values(0);
        final int nCol = tuple.length;
        for (int r = rl; r < ru; r++) {
            final int rowStart = (r + offR) * nCol;
            for (int j = 0; j < nCol; j++) {
                out[rowStart + j] += tuple[j];
            }
        }
    }

    /** General path: adds {@code tuple} at {@code pos(row) + offC + _colIndexes[j]} for every target row. */
    private void decompressToDenseBlockGeneric(DenseBlock db, int rl, int ru, int offR, int offC,
        double[] tuple) {
        for (int r = rl; r < ru; r++) {
            final int outRow = r + offR;
            final double[] out = db.values(outRow);
            final int pos = db.pos(outRow) + offC;
            for (int j = 0; j < _colIndexes.length; j++) {
                out[pos + _colIndexes[j]] += tuple[j];
            }
        }
    }

    /**
     * Appends the constant tuple to rows {@code [rl, ru)} of {@code sb}. Each target row is
     * {@code r + offR} and each target column is {@code colIndex + offC}.
     */
    @Override
    public void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC) {
        final double[] tuple = materializeTuple();
        for (int r = rl; r < ru; r++) {
            final int outRow = r + offR;
            for (int j = 0; j < _colIndexes.length; j++) {
                sb.append(outRow, _colIndexes[j] + offC, tuple[j]);
            }
        }
    }

    /**
     * Copies the constant tuple into a fresh array, one entry per column of the group.
     *
     * <p>Uses {@link ADictionary#getValue(int)} rather than {@code getValues()}, because
     * {@link #forceValuesToMatrixBlock()} may replace {@code _dict} with a {@link MatrixBlockDictionary}.
     */
    private double[] materializeTuple() {
        final int nCol = _colIndexes.length;
        final double[] tuple = new double[nCol];
        for (int j = 0; j < nCol; j++) {
            tuple[j] = _dict.getValue(j);
        }
        return tuple;
    }

    /** Returns the constant value of column {@code colIdx} (group-local index); the row is ignored. */
    @Override
    public double getIdx(int r, int colIdx) {
        return _dict.getValue(colIdx);
    }

    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        return create(_colIndexes, _dict.applyScalarOp(op));
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        return create(_colIndexes, _dict.applyUnaryOp(op));
    }

    /** Applies {@code op} with {@code v} on the left; the boolean flag is not used by this implementation. */
    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        return create(_colIndexes, _dict.binOpLeft(op, v, _colIndexes));
    }

    /** Applies {@code op} with {@code v} on the right; the boolean flag is not used by this implementation. */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        return create(_colIndexes, _dict.binOpRight(op, v, _colIndexes));
    }

    /**
     * Adds this group's constant values into {@code constV}, indexed by the group's absolute column indices.
     *
     * @param constV accumulator indexed by column; must be large enough for every entry of {@code _colIndexes}
     */
    public void addToCommon(double[] constV) {
        for (int j = 0; j < _colIndexes.length; j++) {
            constV[_colIndexes[j]] += _dict.getValue(j);
        }
    }

    /** Returns the dictionary's raw value array, as produced by {@link ADictionary#getValues()}. */
    public double[] getValues() {
        return _dict.getValues();
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return _dict.aggregate(c, builtin);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        _dict.aggregateCols(c, builtin, _colIndexes);
    }

    /** Adds the sum of the tuple, weighted by the row count {@code nRows}, to {@code c[0]}. */
    @Override
    protected void computeSum(double[] c, int nRows) {
        c[0] += _dict.sum(new int[] {nRows}, _colIndexes.length);
    }

    @Override
    public void computeColSums(double[] c, int nRows) {
        _dict.colSum(c, new int[] {nRows}, _colIndexes);
    }

    /** Adds the sum of squares of the tuple, weighted by the row count {@code nRows}, to {@code c[0]}. */
    @Override
    protected void computeSumSq(double[] c, int nRows) {
        c[0] += _dict.sumSq(new int[] {nRows}, _colIndexes.length);
    }

    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        _dict.colSumSq(c, new int[] {nRows}, _colIndexes);
    }

    /** Adds the pre-aggregated row sum {@code preAgg[0]} to every {@code c[r]} with {@code r} in {@code [rl, ru)}. */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        final double rowSum = preAgg[0];
        for (int r = rl; r < ru; r++) {
            c[r] += rowSum;
        }
    }

    @Override
    public int getNumValues() {
        return 1;
    }

    /**
     * Converts {@code _dict} to its matrix-block form (caching the conversion in {@code _dict}) and returns
     * the underlying {@link MatrixBlock}.
     *
     * @return the matrix block, or {@code null} if {@code getMBDict} yields no dictionary; in that case
     *     {@code _dict} is left untouched instead of being overwritten with {@code null}
     */
    private synchronized MatrixBlock forceValuesToMatrixBlock() {
        final ADictionary mbDict = _dict.getMBDict(_colIndexes.length);
        if (mbDict == null) {
            return null;
        }
        _dict = mbDict;
        return ((MatrixBlockDictionary) mbDict).getMatrixBlock();
    }

    /**
     * Multiplies the constant tuple (as a 1 x n row) by {@code right}.
     *
     * @param right the right-hand matrix; its row count must equal the number of columns in this group
     * @return a constant group over {@code right.getNumColumns()} generated columns, or {@code null} if
     *     {@code right} or the product is empty
     * @throws NotImplementedException if the row count of {@code right} does not match the group's column count
     */
    @Override
    public AColGroup rightMultByMatrix(MatrixBlock right) {
        if (right.isEmpty()) {
            return null;
        }
        final int rowsRight = right.getNumRows();
        final int colsRight = right.getNumColumns();
        if (_colIndexes.length != rowsRight) {
            throw new NotImplementedException("Right matrix multiplication of a constant column group with "
                + _colIndexes.length + " columns by a matrix with " + rowsRight + " rows is not supported");
        }
        final MatrixBlock left = forceValuesToMatrixBlock();
        if (left == null) {
            return null;
        }
        final MatrixBlock result = new MatrixBlock(1, colsRight, false);
        LibMatrixMult.matrixMult(left, right, result);
        if (result.isEmpty()) {
            return null;
        }
        // result is 1 x colsRight by construction, so the column indices are built directly. This avoids
        // create(int, ADictionary), whose length check calls getValues() on the MatrixBlockDictionary.
        return create(Util.genColsIndices(colsRight), new MatrixBlockDictionary(result, colsRight));
    }

    @Override
    public void tsmm(double[] result, int numColumns, int nRows) {
        tsmm(result, numColumns, new int[] {nRows}, _dict, _colIndexes);
    }

    @Override
    public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        throw new DMLCompressionException("This method should never be called");
    }

    @Override
    public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    /** Slices out group-local column {@code idx} as a single-column group; a zero value yields an empty group. */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final double v = _dict.getValue(idx);
        final int[] outCols = {0};
        return v == 0.0 ? new ColGroupEmpty(outCols) : create(outCols, new Dictionary(new double[] {v}));
    }

    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        return create(outputCols, _dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.length));
    }

    @Override
    public AColGroup copy() {
        return create(_colIndexes, _dict.mo477clone());
    }

    @Override
    public boolean containsValue(double pattern) {
        return _dict.containsValue(pattern);
    }

    @Override
    public long getNumberNonZeros(int nRows) {
        return _dict.getNumberNonZeros(new int[] {nRows}, _colIndexes.length);
    }

    @Override
    public AColGroup replace(double pattern, double replace) {
        return create(_colIndexes, _dict.replace(pattern, replace, _colIndexes.length));
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _dict = DictionaryFactory.read(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _dict.write(out);
    }

    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + _dict.getExactSizeOnDisk();
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        _dict.product(c, new int[] {nRows}, _colIndexes.length);
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        throw new NotImplementedException("Row product is not implemented for ColGroupConst");
    }

    @Override
    protected void computeColProduct(double[] c, int nRows) {
        throw new NotImplementedException("Column product is not implemented for ColGroupConst");
    }

    @Override
    protected double[] preAggSumRows() {
        return _dict.sumAllRowsToDouble(_colIndexes.length);
    }

    @Override
    protected double[] preAggSumSqRows() {
        return _dict.sumAllRowsToDoubleSq(_colIndexes.length);
    }

    @Override
    protected double[] preAggProductRows() {
        throw new NotImplementedException("Product row pre-aggregation is not implemented for ColGroupConst");
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return _dict.aggregateRows(builtin, _colIndexes.length);
    }

    /** Estimated size: superclass estimate plus dictionary size plus 8 bytes (presumably the {@code _dict} reference). */
    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize() + _dict.getInMemorySize() + 8;
    }

    /** Computes the central moment of the first constant value, weighted by {@code nRows}. */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        final CM_COV_Object ret = new CM_COV_Object();
        op.fn.execute(ret, _dict.getValue(0), nRows);
        return ret;
    }

    /** Expands into {@code max} columns; returns an empty group when the dictionary expansion yields none. */
    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final ADictionary expanded = _dict.rexpandCols(max, ignore, cast, _colIndexes.length);
        return expanded == null ? ColGroupEmpty.create(max) : create(max, expanded);
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        return e.getCost(nRows, 1, getNumCols(), 1, 1.0d);
    }

    @Override
    public String toString() {
        return super.toString() + String.format("\n%15s", "Values: " + _dict.getClass().getSimpleName())
            + _dict.getString(_colIndexes.length);
    }
}
