package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Minus;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.LeftScalarOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * DDC column group: every row of the group stores an index (held in {@link #_data}) into a dictionary of distinct
 * value tuples, one tuple of {@code _colIndexes.length} values per dictionary entry.
 */
public class ColGroupDDC extends APreAgg {
    private static final long serialVersionUID = -5769772089913918987L;

    /** Row-to-dictionary mapping: row {@code r} uses dictionary entry {@code _data.getIndex(r)}. */
    protected AMapToData _data;

    /**
     * Constructs an uninitialized group (fields are expected to be filled later, e.g. via {@link #readFields}).
     *
     * @param numRows number of rows in the group
     */
    public ColGroupDDC(int numRows) {
        super(numRows);
    }

    /**
     * Constructs a DDC group and validates that the mapping and dictionary agree on the number of distinct tuples.
     *
     * @throws DMLCompressionException if {@code data.getUnique()} differs from the dictionary's value count
     */
    private ColGroupDDC(int[] colIndices, int numRows, ADictionary dict, AMapToData data, int[] cachedCounts) {
        super(colIndices, numRows, dict, cachedCounts);
        if (data.getUnique() != dict.getNumberOfValues(colIndices.length)) {
            throw new DMLCompressionException("Invalid construction of DDC group " + data.getUnique() + " vs. " + dict.getNumberOfValues(colIndices.length));
        }
        this._zeros = false;
        this._data = data;
    }

    /**
     * Factory that returns a {@link ColGroupEmpty} when no dictionary is given, otherwise a DDC group.
     */
    public static AColGroup create(int[] colIndices, int numRows, ADictionary dict, AMapToData data, int[] cachedCounts) {
        return dict == null ? new ColGroupEmpty(colIndices) : new ColGroupDDC(colIndices, numRows, dict, data, cachedCounts);
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.DDC;
    }

    /**
     * Adds rows {@code [rl, ru)} of this group into a dense block, reading tuples from a sparse dictionary.
     *
     * @param db   output block; row {@code r} is written to output row {@code r + offR}
     * @param rl   first group row (inclusive)
     * @param ru   last group row (exclusive)
     * @param offR row offset into the output
     * @param offC column offset into the output
     * @param sb   dictionary stored as a sparse block (one row per dictionary entry)
     */
    @Override
    public void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock sb) {
        for (int r = rl, outRow = rl + offR; r < ru; r++, outRow++) {
            final int dictRow = _data.getIndex(r);
            if (sb.isEmpty(dictRow)) {
                continue;
            }
            final double[] c = db.values(outRow);
            final int off = db.pos(outRow) + offC;
            final int apos = sb.pos(dictRow);
            final int alen = sb.size(dictRow) + apos;
            final int[] aix = sb.indexes(dictRow);
            final double[] aval = sb.values(dictRow);
            for (int j = apos; j < alen; j++) {
                c[off + _colIndexes[aix[j]]] += aval[j];
            }
        }
    }

    /**
     * Adds rows {@code [rl, ru)} of this group into a dense block, reading tuples from a dense dictionary. Dispatches
     * to a specialized loop depending on the output layout.
     */
    @Override
    protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final boolean contiguous = db.isContiguous();
        if (contiguous && _colIndexes.length == 1) {
            if (db.getDim(1) == 1) {
                decompressToDenseBlockDenseDictSingleColOutContiguous(db, rl, ru, offR, offC, values);
            } else {
                decompressToDenseBlockDenseDictSingleColContiguous(db, rl, ru, offR, offC, values);
            }
        } else if (contiguous && _colIndexes.length == db.getDim(1) && offC == 0) {
            decompressToDenseBlockDenseDictAllColumnsContiguous(db, rl, ru, offR, values);
        } else if (contiguous && offC == 0) {
            decompressToDenseBlockDenseDictNoColOffset(db, rl, ru, offR, values);
        } else {
            decompressToDenseBlockDenseDictGeneric(db, rl, ru, offR, offC, values);
        }
    }

    /** Single group column written into a contiguous multi-column output: stride by the output width. */
    private void decompressToDenseBlockDenseDictSingleColContiguous(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final double[] c = db.values(0);
        final int nColOut = db.getDim(1);
        int off = (rl + offR) * nColOut + _colIndexes[0] + offC;
        for (int r = rl; r < ru; r++, off += nColOut) {
            c[off] += values[_data.getIndex(r)];
        }
    }

    /** Single group column written into a contiguous single-column output: consecutive cells. */
    private void decompressToDenseBlockDenseDictSingleColOutContiguous(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final double[] c = db.values(0);
        int off = rl + offR + _colIndexes[0] + offC;
        for (int r = rl; r < ru; r++, off++) {
            c[off] += values[_data.getIndex(r)];
        }
    }

    /** The group covers every output column in order (caller guarantees {@code offC == 0}). */
    private void decompressToDenseBlockDenseDictAllColumnsContiguous(DenseBlock db, int rl, int ru, int offR, double[] values) {
        final double[] c = db.values(0);
        final int nCol = _colIndexes.length;
        for (int r = rl; r < ru; r++) {
            final int start = _data.getIndex(r) * nCol;
            final int end = start + nCol;
            for (int j = start, off = (offR + r) * nCol; j < end; j++, off++) {
                c[off] += values[j];
            }
        }
    }

    /** Contiguous output without column offset; offsets into the single backing array are global. */
    private void decompressToDenseBlockDenseDictNoColOffset(DenseBlock db, int rl, int ru, int offR, double[] values) {
        final int nCol = _colIndexes.length;
        final int nColOut = db.getDim(1);
        // The block is contiguous, so every row shares one backing array; fetch it once.
        final double[] c = db.values(rl + offR);
        int off = (rl + offR) * nColOut;
        for (int r = rl; r < ru; r++, off += nColOut) {
            final int rowIndex = _data.getIndex(r) * nCol;
            for (int j = 0; j < nCol; j++) {
                c[off + _colIndexes[j]] += values[rowIndex + j];
            }
        }
    }

    /** Fallback for any dense output layout: resolves each output row through {@code db.values/pos}. */
    private void decompressToDenseBlockDenseDictGeneric(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final int nCol = _colIndexes.length;
        for (int r = rl, outRow = rl + offR; r < ru; r++, outRow++) {
            final double[] c = db.values(outRow);
            final int off = db.pos(outRow) + offC;
            final int rowIndex = _data.getIndex(r) * nCol;
            for (int j = 0; j < nCol; j++) {
                c[off + _colIndexes[j]] += values[rowIndex + j];
            }
        }
    }

    /**
     * Appends rows {@code [rl, ru)} of this group into a sparse block, reading tuples from a sparse dictionary. Only
     * the non-zero entries stored in each referenced dictionary row are appended.
     *
     * @param ret  output sparse block; row {@code r} is appended to output row {@code r + offR}
     * @param rl   first group row (inclusive)
     * @param ru   last group row (exclusive)
     * @param offR row offset into the output
     * @param offC column offset into the output
     * @param sb   dictionary stored as a sparse block (one row per dictionary entry)
     */
    @Override
    protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, SparseBlock sb) {
        for (int r = rl, outRow = rl + offR; r < ru; r++, outRow++) {
            final int dictRow = _data.getIndex(r);
            if (sb.isEmpty(dictRow)) {
                continue;
            }
            final int apos = sb.pos(dictRow);
            final int alen = sb.size(dictRow) + apos;
            final int[] aix = sb.indexes(dictRow);
            final double[] aval = sb.values(dictRow);
            for (int j = apos; j < alen; j++) {
                ret.append(outRow, _colIndexes[aix[j]] + offC, aval[j]);
            }
        }
    }

    /** Appends rows {@code [rl, ru)} of this group into a sparse block, reading tuples from a dense dictionary. */
    @Override
    protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, double[] values) {
        final int nCol = _colIndexes.length;
        for (int r = rl, outRow = rl + offR; r < ru; r++, outRow++) {
            final int rowIndex = _data.getIndex(r) * nCol;
            for (int j = 0; j < nCol; j++) {
                ret.append(outRow, _colIndexes[j] + offC, values[rowIndex + j]);
            }
        }
    }

    /**
     * @param r      group row
     * @param colIdx position within this group's column list (not the global column index)
     * @return the stored value at that cell
     */
    @Override
    public double getIdx(int r, int colIdx) {
        return _dict.getValue(_data.getIndex(r) * _colIndexes.length + colIdx);
    }

    /** Adds the precomputed per-tuple sum of the referenced dictionary entry to each row in {@code [rl, ru)}. */
    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        for (int r = rl; r < ru; r++) {
            c[r] += preAgg[_data.getIndex(r)];
        }
    }

    /** Combines each row's current value with the precomputed per-tuple aggregate using {@code builtin}. */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        for (int r = rl; r < ru; r++) {
            c[r] = builtin.execute(c[r], preAgg[_data.getIndex(r)]);
        }
    }

    @Override
    public int[] getCounts(int[] counts) {
        return _data.getCounts(counts);
    }

    /**
     * Left matrix multiplication {@code result += matrix %*% this} without pre-aggregation, restricted to rows
     * {@code [rl, ru)} of {@code matrix} and its columns {@code [cl, cu)} (equivalently, group rows).
     */
    @Override
    public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        if (_colIndexes.length == 1) {
            leftMultByMatrixNoPreAggSingleCol(matrix, result, rl, ru, cl, cu);
        } else {
            lmMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        }
    }

    private void leftMultByMatrixNoPreAggSingleCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColM = matrix.getNumColumns();
        final int nColRet = result.getNumColumns();
        final double[] dictVals = _dict.getValues();
        if (matrix.isInSparseFormat()) {
            lmSparseMatrixNoPreAggSingleCol(matrix.getSparseBlock(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu);
        } else {
            lmDenseMatrixNoPreAggSingleCol(matrix.getDenseBlockValues(), nColM, retV, nColRet, dictVals, rl, ru, cl, cu);
        }
    }

    /** Sparse left operand, single group column; only non-zeros in columns {@code [cl, cu)} contribute. */
    private void lmSparseMatrixNoPreAggSingleCol(SparseBlock sb, int nColM, double[] retV, int nColRet, double[] vals, int rl, int ru, int cl, int cu) {
        final int colOut = _colIndexes[0];
        for (int r = rl; r < ru; r++) {
            if (sb.isEmpty(r)) {
                continue;
            }
            final int apos = sb.pos(r);
            final int alen = sb.size(r) + apos;
            final int[] aix = sb.indexes(r);
            final double[] aval = sb.values(r);
            final int off = r * nColRet + colOut;
            for (int j = apos; j < alen; j++) {
                final int c = aix[j];
                // Honour the column range like the dense variant does.
                if (c >= cl && c < cu) {
                    retV[off] += aval[j] * vals[_data.getIndex(c)];
                }
            }
        }
    }

    /**
     * Dense left operand, single group column: {@code ret[r, colOut] += sum_c M[r, c] * dict[_data.getIndex(c)]}.
     * The column {@code c} of {@code M} selects the group row, so the dictionary lookup must use {@code c}.
     */
    private void lmDenseMatrixNoPreAggSingleCol(double[] mV, int nColM, double[] retV, int nColRet, double[] vals, int rl, int ru, int cl, int cu) {
        final int colOut = _colIndexes[0];
        for (int r = rl; r < ru; r++) {
            final int offL = r * nColM;
            final int off = r * nColRet + colOut;
            for (int c = cl; c < cu; c++) {
                retV[off] += mV[offL + c] * vals[_data.getIndex(c)];
            }
        }
    }

    private void lmMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        if (matrix.isInSparseFormat()) {
            lmSparseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        } else {
            lmDenseMatrixNoPreAggMultiCol(matrix, result, rl, ru, cl, cu);
        }
    }

    /** Sparse left operand, multiple group columns; only non-zeros in columns {@code [cl, cu)} contribute. */
    private void lmSparseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColRet = result.getNumColumns();
        final SparseBlock sb = matrix.getSparseBlock();
        for (int r = rl; r < ru; r++) {
            if (sb.isEmpty(r)) {
                continue;
            }
            final int apos = sb.pos(r);
            final int alen = sb.size(r) + apos;
            final int[] aix = sb.indexes(r);
            final double[] aval = sb.values(r);
            final int offR = r * nColRet;
            for (int j = apos; j < alen; j++) {
                final int c = aix[j];
                // Honour the column range like the dense variant does.
                if (c >= cl && c < cu) {
                    _dict.multiplyScalar(aval[j], retV, offR, _data.getIndex(c), _colIndexes);
                }
            }
        }
    }

    /** Dense left operand, multiple group columns. */
    private void lmDenseMatrixNoPreAggMultiCol(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        final double[] retV = result.getDenseBlockValues();
        final int nColM = matrix.getNumColumns();
        final int nColRet = result.getNumColumns();
        final double[] mV = matrix.getDenseBlockValues();
        for (int r = rl; r < ru; r++) {
            final int offL = r * nColM;
            final int offR = r * nColRet;
            for (int c = cl; c < cu; c++) {
                _dict.multiplyScalar(mV[offL + c], retV, offR, _data.getIndex(c), _colIndexes);
            }
        }
    }

    @Override
    public void preAggregateDense(MatrixBlock m, double[] preAgg, int rl, int ru, int cl, int cu) {
        _data.preAggregateDense(m, preAgg, rl, ru, cl, cu);
    }

    @Override
    public void preAggregateSparse(SparseBlock sb, double[] preAgg, int rl, int ru) {
        _data.preAggregateSparse(sb, preAgg, rl, ru);
    }

    @Override
    public void preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) {
        _data.preAggregateDDC_DDC(that._data, that._dict, ret, that._colIndexes.length);
    }

    @Override
    public void preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret) {
        _data.preAggregateDDC_SDCZ(that._data, that._dict, that._indexes, ret, that._colIndexes.length);
    }

    /**
     * For every offset stored in {@code that}, adds {@code that}'s single dictionary entry (entry 0) into the row of
     * {@code ret} selected by this group's mapping at that offset.
     */
    @Override
    public void preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that, Dictionary ret) {
        final AIterator it = that._indexes.getIterator();
        final int nCol = that._colIndexes.length;
        final int last = that._indexes.getOffsetToLast();
        final double[] retV = ret.getValues();
        while (true) {
            that._dict.addToEntry(retV, 0, _data.getIndex(it.value()), nCol);
            if (it.value() == last) {
                return;
            }
            it.next();
        }
    }

    /** Two DDC groups share index structure only if they reference the very same mapping object. */
    @Override
    public boolean sameIndexStructure(AColGroupCompressed that) {
        return that instanceof ColGroupDDC && ((ColGroupDDC) that)._data == _data;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.DDC;
    }

    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize() + _data.getInMemorySize();
    }

    /**
     * Applies a scalar operation. For additive shifts ({@code x + s}, {@code x - s}) on a sparse matrix-block
     * dictionary the dictionary is kept as is and a frame-of-reference group is returned, which preserves sparsity.
     * A left-scalar minus ({@code s - x}) is not a shift and is therefore applied to the dictionary directly.
     */
    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        final boolean additiveShift = op.fn instanceof Plus || (op.fn instanceof Minus && !(op instanceof LeftScalarOperator));
        final boolean sparseMBDict = _dict instanceof MatrixBlockDictionary && ((MatrixBlockDictionary) _dict).getMatrixBlock().isInSparseFormat();
        if (!additiveShift || !sparseMBDict) {
            return create(_colIndexes, _numRows, _dict.applyScalarOp(op), _data, getCachedCounts());
        }
        // For a shift f(x) = x + v0, v0 = f(0) is the reference added to every value.
        final double v0 = op.executeScalar(0.0);
        if (v0 == 0.0) {
            return this;
        }
        return ColGroupDDCFOR.create(_colIndexes, _numRows, _dict, _data, getCachedCounts(), FORUtil.createReference(_colIndexes.length, v0));
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        return create(_colIndexes, _numRows, _dict.applyUnaryOp(op), _data, getCachedCounts());
    }

    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        return create(_colIndexes, _numRows, _dict.binOpLeft(op, v, _colIndexes), _data, getCachedCounts());
    }

    /**
     * Applies {@code this op v} row-wise. Plus/minus on a sparse matrix-block dictionary becomes a
     * frame-of-reference group so the dictionary stays sparse; otherwise the dictionary is transformed.
     */
    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        final boolean additive = op.fn instanceof Plus || op.fn instanceof Minus;
        final boolean sparseMBDict = _dict instanceof MatrixBlockDictionary && ((MatrixBlockDictionary) _dict).getMatrixBlock().isInSparseFormat();
        if (additive && sparseMBDict) {
            return ColGroupDDCFOR.create(_colIndexes, _numRows, _dict, _data, getCachedCounts(), ColGroupUtils.binaryDefRowRight(op, v, _colIndexes));
        }
        return create(_colIndexes, _numRows, _dict.binOpRight(op, v, _colIndexes), _data, getCachedCounts());
    }

    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _data.write(out);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _data = MapToFactory.readIn(in);
    }

    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + _data.getExactSizeOnDisk();
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        return e.getCost(nRows, nRows, getNumCols(), getNumValues(), _dict.getSparsity());
    }

    @Override
    protected int numRowsToMultiply() {
        return _numRows;
    }

    @Override
    public String toString() {
        return super.toString() + String.format("\n%15s ", "Data: ") + _data;
    }
}
