package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/ColGroupConst.class */
/**
 * Column group in which every row holds the same tuple of values for the covered columns.
 *
 * <p>The tuple is stored in a dictionary ({@link #_dict}) with a single entry ({@link #getNumValues()}
 * returns 1); the value at position {@code j} of that tuple belongs to column {@code _colIndexes[j]}.
 * Since every row is identical, row-wise results are the same for all rows and column-wise aggregates
 * follow from the tuple together with the number of rows.
 */
public class ColGroupConst extends ColGroupCompressed {
    private static final long serialVersionUID = -7387793538322386611L;

    /** Dictionary holding the single constant tuple, one value per entry of {@code _colIndexes}. */
    protected ADictionary _dict;

    /**
     * Creates an uninitialized group; the fields are presumably populated afterwards via
     * {@link #readFields(DataInput)}.
     *
     * <p>NOTE(review): the decompiler reports that this constructor was originally {@code protected}; it is
     * kept {@code public} so that no existing caller breaks.
     */
    public ColGroupConst() {
        // Intentionally empty: state is filled in later (e.g. by readFields).
    }

    /**
     * Creates a constant column group.
     *
     * @param colIndexes column indexes covered by this group
     * @param dict       dictionary holding the single constant tuple
     */
    public ColGroupConst(int[] colIndexes, ADictionary dict) {
        super(colIndexes);
        this._dict = dict;
    }

    /**
     * Adds the sum of the constant tuple to every row in {@code [rl, ru)} of {@code c}.
     *
     * @param c      row-aggregate output, indexed by absolute row
     * @param square flag forwarded to {@code sumAllRowsToDouble} (presumably selects squared sums)
     * @param rl     first row, inclusive
     * @param ru     last row, exclusive
     */
    @Override
    protected void computeRowSums(double[] c, boolean square, int rl, int ru) {
        final double rowSum = _dict.sumAllRowsToDouble(square, _colIndexes.length)[0];
        for (int r = rl; r < ru; r++) {
            c[r] += rowSum;
        }
    }

    /**
     * Combines every row in {@code [rl, ru)} of {@code c} with the tuple aggregated by {@code builtin}
     * (e.g. a row-wise min or max).
     *
     * @param c       row-aggregate output, indexed by absolute row
     * @param builtin aggregation function applied both within the tuple and against {@code c}
     * @param rl      first row, inclusive
     * @param ru      last row, exclusive
     */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru) {
        final double tupleAgg = _dict.aggregateTuples(builtin, _colIndexes.length)[0];
        for (int r = rl; r < ru; r++) {
            c[r] = builtin.execute(c[r], tupleAgg);
        }
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.CONST;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.CONST;
    }

    /**
     * Adds the constant tuple into rows {@code [offT, offT + (ru - rl))} of the dense block of
     * {@code target}.
     *
     * <p>NOTE(review): only the dense block of {@code target} is used; a sparse target is not handled here.
     *
     * @param target matrix receiving the decompressed values (values are added, not assigned)
     * @param rl     first source row, inclusive
     * @param ru     last source row, exclusive
     * @param offT   row in {@code target} that corresponds to source row {@code rl}
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int rl, int ru, int offT) {
        if (ru <= rl) {
            return;
        }
        final DenseBlock db = target.getDenseBlock();
        final int nCol = _colIndexes.length;
        // Every row carries the same tuple, so read it from the dictionary only once.
        final double[] tuple = new double[nCol];
        for (int j = 0; j < nCol; j++) {
            tuple[j] = _dict.getValue(j);
        }
        for (int r = rl, tr = offT; r < ru; r++, tr++) {
            final double[] rowValues = db.values(tr);
            final int pos = db.pos(tr);
            for (int j = 0; j < nCol; j++) {
                rowValues[pos + _colIndexes[j]] += tuple[j];
            }
        }
    }

    /**
     * Returns the constant value of column {@code c}; the row is irrelevant since all rows are equal.
     *
     * <p>NOTE(review): relies on {@code _colIndexes} being sorted ({@link Arrays#binarySearch}); a column not
     * in this group yields a negative index, which the dictionary lookup will most likely reject.
     *
     * @param r row index (unused)
     * @param c column index, expected to be one of this group's columns
     * @return the constant value stored for column {@code c}
     */
    @Override
    public double get(int r, int c) {
        return _dict.getValue(Arrays.binarySearch(_colIndexes, c));
    }

    /**
     * Applies a scalar operation to a copy of the tuple.
     *
     * @param op scalar operator
     * @return a new constant group over the same columns; this group is left unchanged
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        return new ColGroupConst(_colIndexes, _dict.mo453clone().apply(op));
    }

    /**
     * Applies a row-vector binary operation to a copy of the tuple.
     *
     * <p>NOTE(review): {@code true} is always forwarded as the third argument of {@code applyBinaryRowOp};
     * the caller-supplied {@code sparseSafe} flag is ignored. Confirm this is intended for constant groups.
     *
     * @param op         binary operator
     * @param v          row vector operand
     * @param sparseSafe caller flag (not used, see note)
     * @param left       forwarded to {@code applyBinaryRowOp} (presumably whether {@code v} is the left operand)
     * @return a new constant group over the same columns
     */
    @Override
    public AColGroup binaryRowOp(BinaryOperator op, double[] v, boolean sparseSafe, boolean left) {
        return new ColGroupConst(_colIndexes,
            _dict.mo453clone().applyBinaryRowOp(op, v, true, _colIndexes, left));
    }

    /**
     * Adds this group's non-zero count per row to {@code rnnz[0 .. ru - rl)}.
     *
     * <p>The count is accumulated rather than assigned, so contributions of other column groups covering
     * the same rows are preserved. NOTE(review): assumes callers zero-initialize {@code rnnz} and sum all
     * groups into it.
     *
     * @param rnnz per-row non-zero counts, indexed relative to {@code rl}
     * @param rl   first row, inclusive
     * @param ru   last row, exclusive
     */
    @Override
    public void countNonZerosPerRow(int[] rnnz, int rl, int ru) {
        int nnzPerRow = 0;
        for (double v : _dict.getValues()) {
            if (v != 0.0) {
                nnzPerRow++;
            }
        }
        for (int i = 0; i < ru - rl; i++) {
            rnnz[i] += nnzPerRow;
        }
    }

    /**
     * Adds the constant tuple into {@code constV} at this group's column positions.
     *
     * <p>Does nothing if the dictionary exposes no value array or {@code constV} is {@code null}.
     *
     * @param constV full-width row of constants to add into
     */
    public void addToCommon(double[] constV) {
        final double[] values = _dict.getValues();
        if (values == null || constV == null) {
            return;
        }
        for (int j = 0; j < _colIndexes.length; j++) {
            constV[_colIndexes[j]] += values[j];
        }
    }

    /**
     * Returns the dictionary's value array.
     *
     * @return the values, or {@code null} if no dictionary is set
     */
    @Override
    public double[] getValues() {
        return _dict != null ? _dict.getValues() : null;
    }

    @Override
    public final boolean isLossy() {
        return _dict.isLossy();
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return _dict.aggregate(c, builtin);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        _dict.aggregateCols(c, builtin, _colIndexes);
    }

    /**
     * Adds the sum (or sum of squares) of all cells in this group to {@code c[0]}.
     *
     * @param c      single-cell accumulator
     * @param nRows  row count forwarded to the dictionary as the tuple's occurrence count
     * @param square {@code true} to add the sum of squares instead of the plain sum
     */
    @Override
    protected void computeSum(double[] c, int nRows, boolean square) {
        if (_dict == null) {
            return;
        }
        final int[] counts = {nRows};
        final int nCol = _colIndexes.length;
        c[0] += square ? _dict.sumsq(counts, nCol) : _dict.sum(counts, nCol);
    }

    @Override
    protected void computeColSums(double[] c, int nRows, boolean square) {
        _dict.colSum(c, new int[] {nRows}, _colIndexes, square);
    }

    /** A constant group always has exactly one distinct tuple. */
    @Override
    public int getNumValues() {
        return 1;
    }

    /**
     * Returns the tuple as a matrix block.
     *
     * <p>Side effect: {@link #_dict} is replaced by its {@link MatrixBlockDictionary} form.
     */
    @Override
    public MatrixBlock getValuesAsBlock() {
        _dict = _dict.getAsMatrixBlockDictionary(_colIndexes.length);
        return ((MatrixBlockDictionary) _dict).getMatrixBlock();
    }

    /**
     * Multiplies the constant tuple by {@code right}, producing a new constant group.
     *
     * @param right right-hand matrix; its row count must equal the number of columns in this group
     * @return {@code null} if {@code right} is empty, otherwise a constant group of width
     *         {@code right.getNumColumns()}
     * @throws NotImplementedException if the row count of {@code right} does not match this group's width
     */
    @Override
    public AColGroup rightMultByMatrix(MatrixBlock right) {
        if (right.isEmpty()) {
            return null;
        }
        final int rightRows = right.getNumRows();
        final int rightCols = right.getNumColumns();
        if (_colIndexes.length != rightRows) {
            throw new NotImplementedException();
        }
        final MatrixBlock tuple = getValuesAsBlock();
        final MatrixBlock product = new MatrixBlock(1, rightCols, false);
        LibMatrixMult.matrixMult(tuple, right, product);
        return ColGroupFactory.getColGroupConst(rightCols, new MatrixBlockDictionary(product));
    }

    @Override
    public void tsmm(double[] result, int numColumns, int nRows) {
        tsmm(result, numColumns, new int[] {nRows}, _dict, _colIndexes);
    }

    @Override
    public void leftMultByMatrix(MatrixBlock matrix, MatrixBlock result, int rl, int ru) {
        throw new NotImplementedException();
    }

    @Override
    public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    @Override
    public boolean isDense() {
        return true;
    }

    /**
     * Slices out the column at position {@code idx} of this group.
     *
     * @param idx position within {@code _colIndexes}
     * @return an empty group if the value is zero, otherwise a single-column constant group
     */
    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        final int[] colIndexes = {0};
        final double v = _dict.getValue(idx);
        if (v == 0.0) {
            return new ColGroupEmpty(colIndexes);
        }
        return new ColGroupConst(colIndexes, new Dictionary(new double[] {v}));
    }

    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        return new ColGroupConst(outputCols, _dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.length));
    }

    /** Returns a copy with a cloned dictionary; the column index array is shared. */
    @Override
    public AColGroup copy() {
        return new ColGroupConst(_colIndexes, _dict.mo453clone());
    }

    @Override
    public boolean containsValue(double pattern) {
        return _dict.containsValue(pattern);
    }

    @Override
    public long getNumberNonZeros(int nRows) {
        return _dict.getNumberNonZeros(new int[] {nRows}, _colIndexes.length);
    }

    @Override
    public AColGroup replace(double pattern, double replacement) {
        return new ColGroupConst(_colIndexes, _dict.replace(pattern, replacement, _colIndexes.length));
    }

    /** Reads the base fields followed by the dictionary, mirroring {@link #write(DataOutput)}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _dict = DictionaryFactory.read(in);
    }

    /** Writes the base fields followed by the dictionary. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _dict.write(out);
    }

    @Override
    public long getExactSizeOnDisk() {
        long size = super.getExactSizeOnDisk();
        if (_dict != null) {
            size += _dict.getExactSizeOnDisk();
        }
        return size;
    }

    @Override
    public String toString() {
        return super.toString()
            + String.format("\n%15s ", "Values: " + _dict.getClass().getSimpleName())
            + _dict.getString(_colIndexes.length);
    }

    /**
     * Multiplies {@code c[0]} by the product of all cells in this group, i.e. each tuple value raised to
     * {@code nRows}.
     *
     * @param c     single-cell product accumulator
     * @param nRows number of rows the tuple is repeated over
     */
    @Override
    protected void computeProduct(double[] c, int nRows) {
        final double[] values = getValues();
        for (int j = 0; j < _colIndexes.length; j++) {
            final double v = values[j];
            if (v != 0.0) {
                c[0] *= Math.pow(v, nRows);
            } else {
                // A zero cell forces the product to zero.
                c[0] = 0.0d;
            }
        }
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru) {
        throw new NotImplementedException();
    }

    @Override
    protected void computeColProduct(double[] c, int nRows) {
        throw new NotImplementedException();
    }
}
