package org.apache.sysds.runtime.compress.colgroup;

import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;
import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.colgroup.ColGroup;
import org.apache.sysds.runtime.compress.utils.ABitmap;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.IJV;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.class */
/**
 * Base class for column groups stored with DDC encoding ({@link ColGroup.CompressionType#DDC}).
 * <p>
 * Every row holds a code ({@link #getIndex(int)}) that selects one tuple of the dictionary. A tuple spans
 * {@code getNumCols()} consecutive entries of {@link #getValues()}, so the value of row {@code r}, column
 * position {@code j} lives at {@code getIndex(r) * getNumCols() + j}. Concrete subclasses define how the
 * per-row codes are stored.
 * <p>
 * Several methods treat a code equal to the number of dictionary tuples as "no tuple", which contributes
 * zeros. NOTE(review): confirm which subclasses actually emit that code.
 */
public abstract class ColGroupDDC extends ColGroupValue {
    private static final long serialVersionUID = -3204391646123465004L;

    /** Cell iterator over rows {@code [rl, ru)} in row-major order, optionally skipping zero cells. */
    private class DDCIterator implements Iterator<IJV> {
        private final int _ru;
        private final boolean _inclZeros;
        private int _rpos;
        private int _cpos;
        private final IJV _buff = new IJV();
        private double _value = 0.0;

        /**
         * @param rl        first row (inclusive)
         * @param ru        last row (exclusive)
         * @param inclZeros whether zero-valued cells are returned
         */
        public DDCIterator(int rl, int ru, boolean inclZeros) {
            _ru = ru;
            _inclZeros = inclZeros;
            _rpos = rl;
            _cpos = -1; // so that the first advance lands on column position 0 of row rl
            getNextValue();
        }

        @Override
        public boolean hasNext() {
            return _rpos < _ru;
        }

        @Override
        public IJV next() {
            if (!hasNext()) {
                throw new NoSuchElementException("DDC iterator exhausted at row " + _rpos);
            }
            // The same IJV buffer is reused across calls; callers must copy it if they keep it.
            _buff.set(_rpos, _colIndexes[_cpos], _value);
            getNextValue();
            return _buff;
        }

        /** Moves to the next cell and loads its value. Unless zeros are included, it also skips zero cells. */
        private void getNextValue() {
            final int ncol = getNumCols();
            do {
                if (_cpos + 1 >= ncol) {
                    _rpos++;
                    _cpos = 0;
                } else {
                    _cpos++;
                }
                if (_rpos >= _ru) {
                    return;
                }
                _value = _dict.getValue(getIndex(_rpos, _cpos));
            } while (!_inclZeros && _value == 0.0);
        }
    }

    /** Row iterator that copies one row of this group into a dense, full-width row buffer. */
    private class DDCRowIterator extends ColGroup.ColGroupRowIterator {
        public DDCRowIterator(int rl, int ru) {
            super();
        }

        /**
         * Writes the values of row {@code rowIx} into {@code buff}, positioned by this group's column indexes.
         * The remaining two parameters are not used by this implementation.
         */
        @Override
        public void next(double[] buff, int rowIx, int segIx, boolean last) {
            final int ncol = ColGroupDDC.this.getNumCols();
            final int off = ColGroupDDC.this.getIndex(rowIx) * ncol;
            final double[] values = ColGroupDDC.this.getValues();
            for (int j = 0; j < ncol; j++) {
                buff[ColGroupDDC.this._colIndexes[j]] = values[off + j];
            }
        }
    }

    /** No-arg constructor. The decompiler reports its original access as protected. */
    public ColGroupDDC() {
    }

    /** Builds the group from a bitmap via {@link ColGroupValue}. */
    public ColGroupDDC(int[] colIndices, int numRows, ABitmap ubm, CompressionSettings cs) {
        super(colIndices, numRows, ubm, cs);
    }

    /** Builds the group around an existing dictionary via {@link ColGroupValue}. */
    public ColGroupDDC(int[] colIndices, int numRows, ADictionary dict) {
        super(colIndices, numRows, dict);
    }

    @Override
    public ColGroup.CompressionType getCompType() {
        return ColGroup.CompressionType.DDC;
    }

    /** Appends the cells of rows {@code [rl, ru)} to {@code target} via {@code appendValue}. */
    @Override
    public void decompressToBlock(MatrixBlock target, int rl, int ru) {
        final int ncol = getNumCols();
        final double[] values = getValues();
        for (int r = rl; r < ru; r++) {
            for (int j = 0; j < ncol; j++) {
                target.appendValue(r, _colIndexes[j], getData(r, j, values));
            }
        }
    }

    /**
     * Writes all rows into {@code target}. Column positions are remapped through {@code colIndexTargets},
     * indexed by this group's column index.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int[] colIndexTargets) {
        final int ncol = getNumCols();
        final double[] values = getValues();
        for (int r = 0; r < _numRows; r++) {
            for (int j = 0; j < ncol; j++) {
                target.quickSetValue(r, colIndexTargets[getColIndex(j)], getData(r, j, values));
            }
        }
    }

    /**
     * Writes column position {@code colpos} of this group into the dense values of {@code target}, one entry
     * per row, and sets the target's non-zero count.
     */
    @Override
    public void decompressToBlock(MatrixBlock target, int colpos) {
        final int ncol = getNumCols();
        final int nVals = getNumValues();
        final double[] c = target.getDenseBlockValues();
        final double[] values = getValues();
        int nnz = 0;
        for (int r = 0; r < _numRows; r++) {
            final int tuple = getIndex(r);
            // The "no tuple" code is the tuple count, as in the other methods of this class. It is not
            // values.length: for multi-column groups that comparison led to an out-of-bounds read.
            if (tuple != nVals) {
                final double v = values[tuple * ncol + colpos];
                c[r] = v;
                if (v != 0.0) {
                    nnz++;
                }
            } else {
                c[r] = 0.0;
            }
        }
        target.setNonZeros(nnz);
    }

    /**
     * Returns the value at row {@code r} and absolute column {@code c}.
     *
     * @throws RuntimeException if {@code c} is not one of this group's columns
     */
    @Override
    public double get(int r, int c) {
        final int ix = Arrays.binarySearch(_colIndexes, c);
        if (ix < 0) {
            throw new RuntimeException("Column index " + c + " not in DDC group.");
        }
        // Test the row's tuple code, not the flat dictionary offset. The offset can legitimately equal the
        // tuple count in multi-column groups.
        if (getIndex(r) == getNumValues()) {
            return 0.0;
        }
        return _dict.getValue(getIndex(r, ix));
    }

    /** Adds the per-row non-zero counts of rows {@code [rl, ru)} to {@code rnnz[r - rl]}. */
    @Override
    public void countNonZerosPerRow(int[] rnnz, int rl, int ru) {
        final int ncol = getNumCols();
        final int nVals = getNumValues();
        for (int r = rl; r < ru; r++) {
            // Rows without a valid tuple contribute nothing. The check uses the tuple code, not the flat offset.
            if (getIndex(r) >= nVals) {
                continue;
            }
            int nnz = 0;
            for (int j = 0; j < ncol; j++) {
                if (_dict.getValue(getIndex(r, j)) != 0.0) {
                    nnz++;
                }
            }
            rnnz[r - rl] += nnz;
        }
    }

    /** Adds the count-weighted dictionary sum to {@code c[0]}. */
    @Override
    protected void computeSum(double[] c, KahanFunction kplus) {
        c[0] += _dict.sum(getCounts(), _colIndexes.length, kplus);
    }

    /** Delegates count-weighted column sums to the dictionary. */
    @Override
    protected void computeColSums(double[] c, KahanFunction kplus) {
        _dict.colSum(c, getCounts(), _colIndexes, kplus);
    }

    /**
     * Adds each row's tuple sum into {@code c} for rows {@code [rl, ru)}. Row {@code r} starts at offset
     * {@code r * 2}, or at {@code r * 3} when {@code extraCol} is true (presumably room for a mean term;
     * TODO confirm).
     */
    @Override
    protected void computeRowSums(double[] c, KahanFunction kplus, int rl, int ru, boolean extraCol) {
        final int nVals = getNumValues();
        final KahanObject kbuff = new KahanObject(0.0, 0.0);
        final KahanPlus kplus2 = KahanPlus.getKahanPlusFnObject();
        final double[] tupleSums = _dict.sumAllRowsToDouble(kplus, kbuff, _colIndexes.length);
        final int stride = extraCol ? 3 : 2;
        for (int r = rl; r < ru; r++) {
            final int tuple = getIndex(r);
            if (tuple != nVals) {
                setandExecute(c, kbuff, kplus2, tupleSums[tuple], r * stride);
            }
        }
    }

    /** Folds each row's values into {@code c[r]} with {@code builtin} (for example min/max) for rows {@code [rl, ru)}. */
    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru) {
        final int nVals = getNumValues();
        final int ncol = getNumCols();
        final double[] values = getValues();
        for (int r = rl; r < ru; r++) {
            final int tuple = getIndex(r);
            if (tuple != nVals) {
                // Fix: a tuple starts at tuple * ncol. The previous code used tuple + j and read the wrong cells.
                final int off = tuple * ncol;
                for (int j = 0; j < ncol; j++) {
                    c[r] = builtin.execute(c[r], values[off + j]);
                }
            } else {
                c[r] = builtin.execute(c[r], 0.0);
            }
        }
    }

    /** Equivalent to {@code postScaling(values, vals, c, numVals, 0, 0)}. */
    public void postScaling(double[] values, double[] vals, double[] c, int numVals) {
        postScaling(values, vals, c, numVals, 0, 0);
    }

    /**
     * Scales dictionary tuples by pre-aggregated weights and adds the result into {@code c}.
     * <p>
     * For each column position {@code j}: {@code c[_colIndexes[j] + row * totalCols] += sum_k vals[k] *
     * values[k * ncol + j]} for {@code k} in {@code [0, numVals)}.
     *
     * @param values    dictionary values, {@code ncol} entries per tuple
     * @param vals      per-tuple weights, for example from {@link #preAggregate(double[], int, int)}
     * @param c         output array, updated in place
     * @param numVals   number of tuples to process
     * @param row       output row
     * @param totalCols row stride of {@code c}
     */
    public void postScaling(double[] values, double[] vals, double[] c, int numVals, int row, int totalCols) {
        final int ncol = getNumCols();
        for (int j = 0; j < ncol; j++) {
            final int colIx = _colIndexes[j] + row * totalCols;
            // Fix: the decompiled loop never terminated. It also skipped tuple k whenever k * ncol == numVals,
            // which dropped a valid dictionary entry.
            for (int k = 0, valOff = 0; k < numVals; k++, valOff += ncol) {
                c[colIx] += vals[k] * values[valOff + j];
            }
        }
    }

    /** Counts the occurrences of each tuple code over all rows into {@code counts}. */
    @Override
    public int[] getCounts(int[] counts) {
        return getCounts(0, _numRows, counts);
    }

    /**
     * Increments {@code counts[getIndex(r)]} for rows {@code [rl, ru)}. No "no tuple" check is applied, so
     * {@code counts} must be large enough for every code that occurs.
     */
    @Override
    public int[] getCounts(int rl, int ru, int[] counts) {
        for (int r = rl; r < ru; r++) {
            counts[getIndex(r)]++;
        }
        return counts;
    }

    /** Not supported for DDC groups. */
    @Override
    public void rightMultByMatrix(double[] b, double[] c, int numVals, double[] values, int rl, int ru,
        int vOff) {
        throw new NotImplementedException("Not Implemented");
    }

    /**
     * Left-multiplies rows of {@code a} (each {@code _numRows} long, starting at row {@code voff}) with this
     * group and accumulates into rows {@code [rl, ru)} of {@code c} (row stride {@code numCols}).
     * {@code numVals == -1} means "use {@link #getNumValues()}". {@code numRows} is not used here.
     */
    @Override
    public void leftMultByMatrix(double[] a, double[] c, int numVals, double[] values, int numRows, int numCols,
        int rl, int ru, int voff) {
        final int nVals = numVals == -1 ? getNumValues() : numVals;
        final int ncol = _colIndexes.length;
        for (int i = rl, j = voff; i < ru; i++, j++) {
            if (8 * nVals < _numRows) {
                // Few distinct tuples: pre-aggregate the input per code, then scale the dictionary once.
                postScaling(values, preAggregate(a, nVals, j), c, nVals, i, numCols);
            } else {
                final int cOff = i * numCols;
                for (int r = 0, aOff = j * _numRows; r < _numRows; r++, aOff++) {
                    final double aval = a[aOff];
                    if (aval == 0.0) {
                        continue;
                    }
                    final int tuple = getIndex(r);
                    // Fix: compare the tuple code (not tuple * ncol) with the "no tuple" code.
                    if (tuple == nVals) {
                        continue;
                    }
                    final int valOff = tuple * ncol;
                    for (int h = 0; h < ncol; h++) {
                        c[_colIndexes[h] + cOff] += aval * values[valOff + h];
                    }
                }
            }
        }
    }

    /** Row-vector left multiplication using this group's own dictionary values. */
    @Override
    public void leftMultByRowVector(double[] a, double[] c, int numVals) {
        leftMultByRowVector(a, c, numVals == -1 ? getNumValues() : numVals, getValues());
    }

    /** Equivalent to {@code preAggregate(a, numVals, 0)}. */
    public double[] preAggregate(double[] a, int numVals) {
        return preAggregate(a, numVals, 0);
    }

    /**
     * Sums the entries of one {@code _numRows}-long segment of {@code a} per tuple code.
     *
     * @param a       input values
     * @param numVals number of tuples; rows whose code equals it are skipped
     * @param aRows   segment of {@code a} to read, starting at offset {@code aRows * _numRows}; values
     *                {@code <= 0} read from offset 0
     * @return a vector of length {@code numVals} obtained from {@code allocDVector}
     */
    public double[] preAggregate(double[] a, int numVals, int aRows) {
        final double[] vals = allocDVector(numVals, true);
        int aOff = aRows > 0 ? _numRows * aRows : 0;
        for (int r = 0; r < _numRows; r++, aOff++) {
            final int tuple = getIndex(r);
            if (tuple != numVals) {
                vals[tuple] += a[aOff];
            }
        }
        return vals;
    }

    /**
     * Accumulates {@code a * this} into {@code c} using dictionary {@code values}. {@code numVals == -1}
     * means "use {@link #getNumValues()}".
     */
    @Override
    public void leftMultByRowVector(double[] a, double[] c, int numVals, double[] values) {
        final int nVals = numVals == -1 ? getNumValues() : numVals;
        if (8 * nVals < _numRows) {
            postScaling(values, preAggregate(a, nVals), c, nVals);
            return;
        }
        final int ncol = _colIndexes.length;
        for (int r = 0; r < _numRows; r++) {
            final double aval = a[r];
            if (aval == 0.0) {
                continue;
            }
            final int tuple = getIndex(r);
            // Fix: compare the tuple code (not tuple * ncol) with the "no tuple" code.
            if (tuple == nVals) {
                continue;
            }
            final int valOff = tuple * ncol;
            for (int j = 0; j < ncol; j++) {
                c[_colIndexes[j]] += aval * values[valOff + j];
            }
        }
    }

    /** Returns a cell iterator over rows {@code [rl, ru)}. The last flag is not used by this group. */
    @Override
    public Iterator<IJV> getIterator(int rl, int ru, boolean inclZeros, boolean rowMajor) {
        return new DDCIterator(rl, ru, inclZeros);
    }

    @Override
    public ColGroup.ColGroupRowIterator getRowIterator(int rl, int ru) {
        return new DDCRowIterator(rl, ru);
    }

    @Override
    public String toString() {
        return super.toString();
    }

    /** Returns the tuple code of row {@code r}. */
    protected abstract int getIndex(int r);

    /** Returns the offset into the dictionary values for row {@code r} at column position {@code colIx}. */
    protected abstract int getIndex(int r, int colIx);

    /** Single-position data accessor; presumably for single-column groups (TODO confirm against subclasses). */
    protected abstract double getData(int r, double[] values);

    /** Returns the value of row {@code r} at column position {@code colIx}, read from {@code values}. */
    protected abstract double getData(int r, int colIx, double[] values);

    /** Sets stored code data; presumably (row, code) (TODO confirm against subclasses). */
    protected abstract void setData(int r, int code);
}
