package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.ref.SoftReference;
import java.util.Arrays;
import java.util.HashSet;
import org.apache.commons.lang.NotImplementedException;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.data.DenseBlockFP64;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;

/* loaded from: input_file:org/apache/sysds/runtime/compress/colgroup/ColGroupValue.class */
public abstract class ColGroupValue extends ColGroupCompressed implements Cloneable {
    private static final long serialVersionUID = -6835757655517301955L;

    /**
     * Per-thread scratch pair (an int[] and a double[]) handed out by
     * {@link #allocIVector(int, boolean)} and {@link #allocDVector(int, boolean)}. Holds null on a
     * thread until {@link #setupThreadLocalMemory(int)} installs buffers there.
     */
    private static ThreadLocal<Pair<int[], double[]>> memPool = new ThreadLocal<>();

    /**
     * Per-thread scratch double[] installed by {@link #setupLeftMultThreadLocalMemory(int)}; holds
     * null until then.
     */
    private static ThreadLocal<double[]> tmpLeftMultDoubleArray = new ThreadLocal<>();

    /** Number of rows this column group represents. */
    protected final int _numRows;

    /**
     * True if the group additionally contains an all-zero tuple that is not stored in
     * {@link #_dict} (a zero row is appended in {@link #getValuesAsBlock()}).
     */
    protected boolean _zeros;

    /** Dictionary of value tuples; transient, so it is written explicitly by {@link #write}. */
    protected transient ADictionary _dict;

    /** Lazily computed per-tuple counts, softly cached so the GC may reclaim them. */
    private transient SoftReference<int[]> counts;

    /**
     * Creates a group over {@code numRows} rows without column indexes or dictionary; the remaining
     * state is presumably filled in later (e.g. by {@link #readFields}). The decompiler reports this
     * constructor was originally protected.
     */
    public ColGroupValue(int numRows) {
        this._numRows = numRows;
        this._zeros = false;
    }

    /**
     * Creates a group over the given columns and rows, backed by {@code dict}; no zero tuple.
     *
     * @param colIndices column indexes covered by this group
     * @param numRows    number of rows represented
     * @param dict       dictionary of value tuples
     */
    protected ColGroupValue(int[] colIndices, int numRows, ADictionary dict) {
        super(colIndices);
        _numRows = numRows;
        _dict = dict;
        _zeros = false;
    }

    /**
     * Like {@link #ColGroupValue(int[], int, ADictionary)} but also seeds the soft count cache with
     * {@code cachedCounts}. The decompiler reports this constructor was originally protected.
     */
    public ColGroupValue(int[] colIndices, int numRows, ADictionary dict, int[] cachedCounts) {
        this(colIndices, numRows, dict);
        counts = new SoftReference<>(cachedCounts);
    }

    /**
     * Decompresses this group into {@code target}, dispatching on the dictionary representation.
     * An empty MatrixBlock dictionary contributes nothing. The three int arguments are forwarded
     * unchanged to the representation-specific kernel.
     */
    @Override
    public final void decompressToBlock(MatrixBlock target, int rl, int ru, int offT) {
        if (_dict instanceof MatrixBlockDictionary) {
            final MatrixBlock dictBlock = ((MatrixBlockDictionary) _dict).getMatrixBlock();
            if (dictBlock.isEmpty()) {
                return;
            } else if (dictBlock.isInSparseFormat()) {
                decompressToBlockSparseDictionary(target, rl, ru, offT, dictBlock.getSparseBlock());
            } else {
                decompressToBlockDenseDictionary(target, rl, ru, offT, dictBlock.getDenseBlockValues());
            }
        } else {
            decompressToBlockDenseDictionary(target, rl, ru, offT, _dict.getValues());
        }
    }

    /**
     * Kernel used by {@link #decompressToBlock} when the dictionary is a sparse
     * {@link MatrixBlockDictionary}. The int arguments are forwarded unchanged from that method.
     */
    protected abstract void decompressToBlockSparseDictionary(MatrixBlock matrixBlock, int i, int i2, int i3, SparseBlock sparseBlock);

    /**
     * Kernel used by {@link #decompressToBlock} for non-MatrixBlock dictionaries and for dense
     * MatrixBlock dictionaries. {@code dArr} holds the dictionary values as a flat double array.
     */
    protected abstract void decompressToBlockDenseDictionary(MatrixBlock matrixBlock, int i, int i2, int i3, double[] dArr);

    /** Returns the dictionary's number of value tuples, given this group's column count. */
    @Override
    public final int getNumValues() {
        final int nCol = _colIndexes.length;
        return _dict.getNumberOfValues(nCol);
    }

    /** Returns the dictionary values as a flat array, or null when no dictionary is set. */
    @Override
    public final double[] getValues() {
        return _dict == null ? null : _dict.getValues();
    }

    /** @return the dictionary backing this group (may be null, as other methods here check). */
    public final ADictionary getDictionary() {
        final ADictionary dict = _dict;
        return dict;
    }

    /**
     * Returns the dictionary as a MatrixBlock. If {@link #_zeros} is set, a 1 x numCols zero block
     * is appended with {@code cbind=false} (presumably as an extra row) into a new block.
     * Side effect: {@link #_dict} is replaced by its MatrixBlock-backed form.
     */
    @Override
    public final MatrixBlock getValuesAsBlock() {
        final int nCol = _colIndexes.length;
        _dict = _dict.getAsMatrixBlockDictionary(nCol);
        final MatrixBlock values = ((MatrixBlockDictionary) _dict).getMatrixBlock();
        if (_zeros) {
            final MatrixBlock withZeroRow = new MatrixBlock();
            final MatrixBlock zeroRow = new MatrixBlock(1, nCol, 0L);
            values.append(zeroRow, withZeroRow, false);
            return withZeroRow;
        }
        return values;
    }

    /**
     * Returns the per-tuple counts, computing them via {@link #getCounts(int[])} on first use and
     * caching the result behind a {@link SoftReference}. When {@link #_zeros} is set, the array has
     * one extra slot.
     *
     * @return the counts, or null when the group has no dictionary
     */
    public final int[] getCounts() {
        if (_dict == null)
            return null;
        // Dereference the soft reference exactly once: the GC may clear it at any moment, so
        // checking get() != null and then calling get() again could still hand back null.
        final SoftReference<int[]> ref = counts;
        int[] cached = ref == null ? null : ref.get();
        if (cached == null) {
            cached = getCounts(new int[getNumValues() + (_zeros ? 1 : 0)]);
            counts = new SoftReference<>(cached);
        }
        return cached;
    }

    /** Replaces the soft-cached tuple counts with {@code newCounts}. */
    protected final void setCounts(int[] newCounts) {
        final SoftReference<int[]> ref = new SoftReference<>(newCounts);
        counts = ref;
    }

    /** Returns the cached counts if present and not yet reclaimed; never computes them. */
    public final int[] getCachedCounts() {
        return counts == null ? null : counts.get();
    }

    /**
     * Computes counts for the given int range into a reset scratch buffer from
     * {@link #allocIVector(int, boolean)}, sized for one extra slot when {@link #_zeros} is set.
     */
    public final int[] getCounts(int rl, int ru) {
        final int len = _zeros ? getNumValues() + 1 : getNumValues();
        return getCounts(rl, ru, allocIVector(len, true));
    }

    /** Value-based groups report true here; subclasses may override. */
    public boolean getIfCountsType() {
        final boolean countsType = true;
        return countsType;
    }

    /**
     * Dot product of dictionary tuple {@code valIx} (a row of the row-major {@code values}) with
     * {@code b} gathered at this group's column indexes.
     */
    protected final double sumValues(int valIx, double[] b, double[] values) {
        final int nCol = getNumCols();
        final int off = valIx * nCol;
        double acc = 0.0;
        int j = 0;
        while (j < nCol) {
            acc += values[off + j] * b[_colIndexes[j]];
            j++;
        }
        return acc;
    }

    /**
     * Like {@link #sumValues(int, double[], double[])} but reads {@code b} at
     * {@code _colIndexes[j] + bOff}.
     */
    protected final double sumValues(int valIx, double[] b, double[] values, int bOff) {
        final int nCol = getNumCols();
        final int off = valIx * nCol;
        double acc = 0.0;
        int j = 0;
        while (j < nCol) {
            acc += values[off + j] * b[_colIndexes[j] + bOff];
            j++;
        }
        return acc;
    }

    /**
     * Collects, in ascending order, every column index c in [cl, cu) for which
     * {@code b[_colIndexes[k] * cut + c]} is non-zero for at least one column k of this group.
     * Scanning stops early once every column of the range has been found.
     *
     * @param b   dense row-major values; {@code cut} is its row stride
     * @param cl  first column (inclusive)
     * @param cu  last column (exclusive)
     * @param cut row stride of {@code b}
     * @return sorted distinct column indexes
     */
    private int[] getAggregateColumnsSetDense(double[] b, int cl, int cu, int cut) {
        final int rangeLen = cu - cl;
        // A flag array over the range replaces the (raw, uncompilable) HashSet and yields the
        // result already sorted.
        final boolean[] seen = new boolean[Math.max(rangeLen, 0)];
        int found = 0;
        for (int k = 0; k < _colIndexes.length && found < rangeLen; k++) {
            final int rowOff = _colIndexes[k] * cut;
            for (int c = cl; c < cu; c++) {
                if (b[rowOff + c] != 0.0 && !seen[c - cl]) {
                    seen[c - cl] = true;
                    found++;
                }
            }
        }
        final int[] result = new int[found];
        int r = 0;
        for (int c = 0; c < seen.length; c++) {
            if (seen[c])
                result[r++] = cl + c;
        }
        return result;
    }

    /**
     * Collects, in ascending order, the distinct column indexes stored in rows
     * {@code _colIndexes} of a sparse block. Stops scanning once {@code retCols} distinct columns
     * have been found.
     *
     * @param b       sparse input
     * @param retCols early-exit threshold on the number of distinct columns
     * @return sorted distinct column indexes
     */
    private int[] getAggregateColumnsSetSparse(SparseBlock b, int retCols) {
        // Typed set: with a raw HashSet the lambda below would see Object and not compile.
        final HashSet<Integer> columns = new HashSet<>();
        for (int k = 0; k < _colIndexes.length; k++) {
            final int row = _colIndexes[k];
            if (!b.isEmpty(row)) {
                final int[] idx = b.indexes(row);
                final int apos = b.pos(row);
                final int alen = apos + b.size(row);
                for (int j = apos; j < alen; j++)
                    columns.add(idx[j]);
            }
            if (columns.size() == retCols)
                break;
        }
        final int[] result = columns.stream().mapToInt(Integer::intValue).toArray();
        Arrays.sort(result);
        return result;
    }

    /**
     * Pre-aggregates the dictionary against rows {@code _colIndexes} of a sparse block. For every
     * stored entry at (row {@code _colIndexes[k]}, column {@code aggCols[t]}) and every dictionary
     * tuple v, adds {@code dict[v][k] * entry} to {@code ret[v * aggCols.length + t]}. Entries whose
     * column is not in {@code aggCols} are skipped.
     *
     * @param numVals number of dictionary tuples (rows of the returned row-major array)
     * @param b       sparse input whose rows are addressed by this group's column indexes
     * @param aggCols ascending column ids to aggregate into; the merge cursor below assumes the
     *                sparse row indexes are ascending as well
     * @param i2      not read by this implementation
     * @param i3      not read by this implementation
     * @param i4      not read by this implementation
     * @return a new array of length {@code numVals * aggCols.length}
     */
    private double[] preaggValuesFromSparse(int numVals, SparseBlock b, int[] aggCols, int i2, int i3, int i4) {
        final int nAgg = aggCols.length;
        final int nCol = _colIndexes.length;
        final double[] ret = new double[numVals * nAgg];
        for (int k = 0; k < nCol; k++) {
            final int row = _colIndexes[k];
            if (b.isEmpty(row))
                continue;
            final double[] vals = b.values(row);
            final int[] idx = b.indexes(row);
            final int apos = b.pos(row);
            final int alen = apos + b.size(row);
            int t = 0; // merge cursor into aggCols
            for (int j = apos; j < alen; j++) {
                while (t < nAgg && aggCols[t] < idx[j])
                    t++;
                if (t == nAgg)
                    break; // remaining entries lie beyond the last aggregate column
                if (aggCols[t] == idx[j]) {
                    final double v = vals[j];
                    // Bounded loop over tuples (the decompiled while(true) never terminated).
                    for (int tuple = 0; tuple < numVals; tuple++)
                        ret[tuple * nAgg + t] += _dict.getValue(tuple * nCol + k) * v;
                }
            }
        }
        return ret;
    }

    /**
     * Folds this group into the running min/max {@code c}. The implicit zero tuple is included
     * when {@link #_zeros} is set.
     */
    @Override
    protected final double computeMxx(double c, Builtin builtin) {
        final double withZero = _zeros ? builtin.execute(c, 0.0) : c;
        if (_dict == null)
            return withZero;
        return _dict.aggregate(withZero, builtin);
    }

    /**
     * Folds this group into per-column min/max {@code c} at this group's column indexes. When
     * {@link #_zeros} is set, the implicit zero is included first.
     */
    @Override
    protected final void computeColMxx(double[] c, Builtin builtin) {
        if (_zeros) {
            for (int col : _colIndexes) {
                c[col] = builtin.execute(c[col], 0.0);
            }
        }
        if (_dict != null) {
            _dict.aggregateCols(c, builtin, _colIndexes);
        }
    }

    /**
     * Applies {@code op} to a clone of the dictionary and returns the result; {@link #_dict} itself
     * is not passed to {@code apply}. The decompiler reports this method was originally protected.
     */
    public final ADictionary applyScalarOp(ScalarOperator op) {
        final ADictionary copy = _dict.mo453clone();
        return copy.apply(op);
    }

    /**
     * Delegates to {@code ADictionary.applyScalarOp} with the extra value and int forwarded
     * unchanged. The decompiler reports this method was originally protected.
     */
    public final ADictionary applyScalarOp(ScalarOperator op, double newVal, int numCols) {
        final ADictionary dict = _dict;
        return dict.applyScalarOp(op, newVal, numCols);
    }

    /**
     * Applies a row-vector binary op to the dictionary. When {@code sparseSafe} is true the op runs
     * on a clone; otherwise it is invoked on {@link #_dict} directly. The decompiler reports this
     * method was originally protected.
     */
    public final ADictionary applyBinaryRowOp(BinaryOperator op, double[] v, boolean sparseSafe, boolean left) {
        final ADictionary target = sparseSafe ? _dict.mo453clone() : _dict;
        return target.applyBinaryRowOp(op, v, sparseSafe, _colIndexes, left);
    }

    /**
     * Ensures the calling thread's scratch pool holds int/double buffers of at least {@code len}
     * entries. A short or missing pool is replaced with freshly allocated arrays.
     */
    public static void setupThreadLocalMemory(int len) {
        final Pair<int[], double[]> current = memPool.get();
        if (current != null && current.getLeft().length >= len)
            return;
        memPool.set(new ImmutablePair<>(new int[len], new double[len]));
    }

    /** Ensures the calling thread's left-mult scratch array holds at least {@code len} doubles. */
    public static void setupLeftMultThreadLocalMemory(int len) {
        final double[] current = tmpLeftMultDoubleArray.get();
        if (current != null && current.length >= len)
            return;
        tmpLeftMultDoubleArray.set(new double[len]);
    }

    /**
     * Drops the calling thread's scratch pool ({@code memPool}). Note that the left-mult array is
     * not cleared here.
     */
    public static void cleanupThreadLocalMemory() {
        memPool.remove();
    }

    /**
     * Returns a double buffer with at least {@code len} entries, preferring the calling thread's
     * pooled buffer. Without a pool a fresh array is allocated.
     *
     * @param len   minimum number of entries required
     * @param reset whether the first {@code len} entries of a reused buffer must be zeroed
     * @return a buffer of length &gt;= {@code len} (possibly shared with other pool users on this thread)
     */
    protected static double[] allocDVector(int len, boolean reset) {
        final Pair<int[], double[]> pool = memPool.get();
        if (pool == null)
            return new double[len];
        if (pool.getValue().length < len) {
            // Grow the pool and hand out the NEW buffer: the array in the stale pair is too short.
            setupThreadLocalMemory(len);
            return memPool.get().getValue(); // freshly allocated, hence already zeroed
        }
        final double[] buf = pool.getValue();
        if (reset)
            Arrays.fill(buf, 0, len, 0.0);
        return buf;
    }

    /**
     * Returns an int buffer with at least {@code len} entries, preferring the calling thread's
     * pooled buffer. The decompiler reports this method was originally protected.
     *
     * @param len   minimum number of entries required
     * @param reset whether the first {@code len} entries of a reused buffer must be zeroed
     * @return a buffer of length &gt;= {@code len} (possibly shared with other pool users on this thread)
     */
    public static int[] allocIVector(int len, boolean reset) {
        final Pair<int[], double[]> pool = memPool.get();
        if (pool == null) {
            // NOTE(review): without a pool this allocates len + 1 entries; kept because callers
            // may rely on the extra slot.
            return new int[len + 1];
        }
        if (pool.getKey().length < len) {
            // Grow the pool and hand out the NEW buffer: the array in the stale pair is too short.
            setupThreadLocalMemory(len);
            return memPool.get().getKey(); // freshly allocated, hence already zeroed
        }
        final int[] buf = pool.getKey();
        if (reset)
            Arrays.fill(buf, 0, len, 0);
        return buf;
    }

    /**
     * Describes the group: lossiness, row count, zero-tuple flag, the parent description and,
     * when present, the dictionary contents. Safe to call when {@link #_dict} is null.
     */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        // _dict may be null (it is null-checked below), so it must not be dereferenced unguarded.
        final boolean lossy = _dict != null && _dict.isLossy();
        sb.append(" Is Lossy: " + lossy + " num Rows: " + getNumRows() + " contain zero row:" + _zeros);
        sb.append(super.toString());
        if (_dict != null) {
            sb.append(String.format("\n%15s ", "Values: " + _dict.getClass().getSimpleName()));
            sb.append(_dict.getString(_colIndexes.length));
        }
        return sb.toString();
    }

    /** Reports whether the dictionary is lossy (requires a non-null dictionary). */
    @Override
    public final boolean isLossy() {
        final ADictionary dict = _dict;
        return dict.isLossy();
    }

    /**
     * Reads the parent fields, then the zero-tuple flag, then the dictionary (mirror of
     * {@link #write}).
     */
    @Override
    public void readFields(DataInput in) throws IOException {
        super.readFields(in);
        _zeros = in.readBoolean();
        _dict = DictionaryFactory.read(in);
    }

    /** Writes the parent fields, then the zero-tuple flag, then the dictionary. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        out.writeBoolean(_zeros);
        _dict.write(out);
    }

    /**
     * Serialized size: the parent size, one byte for the {@link #_zeros} flag, plus the dictionary
     * size when a dictionary is present.
     */
    @Override
    public long getExactSizeOnDisk() {
        final long base = super.getExactSizeOnDisk() + 1; // +1 for the boolean flag
        return _dict == null ? base : base + _dict.getExactSizeOnDisk();
    }

    /**
     * Computes per-tuple counts into {@code iArr} (presumably filling and returning it).
     * {@link #getCounts()} passes a fresh array of length getNumValues() + (_zeros ? 1 : 0)
     * and caches the returned array.
     */
    public abstract int[] getCounts(int[] iArr);

    /**
     * Range variant of {@link #getCounts(int[])}. {@link #getCounts(int, int)} passes a reset
     * scratch buffer from allocIVector, which may be longer than the number of tuples.
     */
    public abstract int[] getCounts(int i, int i2, int[] iArr);

    /**
     * Adds this group's total (or total of squares when {@code square}) into {@code c[0]}, weighting
     * each tuple by its count. The int argument is not read here.
     */
    @Override
    protected final void computeSum(double[] c, int nRows, boolean square) {
        if (_dict == null)
            return;
        final int[] tupleCounts = getCounts();
        final int nCol = _colIndexes.length;
        c[0] += square ? _dict.sumsq(tupleCounts, nCol) : _dict.sum(tupleCounts, nCol);
    }

    /**
     * Adds count-weighted column sums (squared when {@code square}) into {@code c} at this group's
     * columns. The int argument is not read here.
     */
    @Override
    protected final void computeColSums(double[] c, int nRows, boolean square) {
        final int[] tupleCounts = getCounts();
        _dict.colSum(c, tupleCounts, _colIndexes, square);
    }

    /** Multiplies {@code c[0]} by the count-weighted dictionary product; the int is unused here. */
    @Override
    protected void computeProduct(double[] c, int nRows) {
        final double groupProduct = _dict.product(getCounts(), _colIndexes.length);
        c[0] *= groupProduct;
    }

    /** Row-wise product is not implemented for value-based groups; always throws. */
    @Override
    protected void computeRowProduct(double[] c, int rl, int ru) {
        throw new NotImplementedException();
    }

    /** Folds count-weighted column products into {@code c}; the int argument is unused here. */
    @Override
    protected void computeColProduct(double[] c, int nRows) {
        final int[] tupleCounts = getCounts();
        _dict.colProduct(c, tupleCounts, _colIndexes);
    }

    /**
     * Shallow copy via {@link Object#clone()}. The dictionary and column index references are
     * shared with the original.
     */
    @Override
    protected Object clone() {
        try {
            return super.clone();
        } catch (CloneNotSupportedException cause) {
            final String msg = "Error while cloning: " + getClass().getSimpleName();
            throw new DMLCompressionException(msg, cause);
        }
    }

    /** Shallow copy of this group using a new {@link Dictionary} wrapping {@code values}. */
    public AColGroup copyAndSet(double[] values) {
        final ADictionary dict = new Dictionary(values);
        return copyAndSet(dict);
    }

    /** Shallow copy of this group with its dictionary replaced by {@code newDict}. */
    public AColGroup copyAndSet(ADictionary newDict) {
        final ColGroupValue ret = (ColGroupValue) clone();
        ret._dict = newDict;
        return ret;
    }

    /** Shallow copy with new column indexes and a new {@link Dictionary} wrapping {@code values}. */
    public AColGroup copyAndSet(int[] colIndices, double[] values) {
        final ADictionary dict = new Dictionary(values);
        return copyAndSet(colIndices, dict);
    }

    /** Shallow copy with its dictionary replaced by {@code newDict} and columns by {@code colIndices}. */
    public AColGroup copyAndSet(int[] colIndices, ADictionary newDict) {
        final ColGroupValue ret = (ColGroupValue) clone();
        ret._dict = newDict;
        ret.setColIndices(colIndices);
        return ret;
    }

    /** Shallow copy of this group (see {@link #clone()}). */
    @Override
    public ColGroupValue copy() {
        final Object cloned = clone();
        return (ColGroupValue) cloned;
    }

    /**
     * Returns a copy restricted to the {@code idx}-th column of this group, re-indexed as column 0.
     * A single-column group's dictionary is cloned; otherwise that column is sliced out.
     */
    @Override
    protected final AColGroup sliceSingleColumn(int idx) {
        final ColGroupValue ret = copy();
        ret._colIndexes = new int[] {0};
        if (ret._dict == null)
            return ret;
        final int nCol = _colIndexes.length; // columns of the ORIGINAL group
        ret._dict = nCol == 1
            ? ret._dict.mo453clone()
            : ret._dict.sliceOutColumnRange(idx, idx + 1, nCol);
        return ret;
    }

    /**
     * Returns a copy whose dictionary keeps group columns [idStart, idEnd) and whose column
     * indexes become {@code outputCols}.
     */
    @Override
    protected final AColGroup sliceMultiColumns(int idStart, int idEnd, int[] outputCols) {
        final ColGroupValue ret = copy();
        if (ret._dict != null)
            ret._dict = ret._dict.sliceOutColumnRange(idStart, idEnd, _colIndexes.length);
        ret._colIndexes = outputCols;
        return ret;
    }

    /**
     * Allocates a dense (ru - rl) x numVals block backed by a reset scratch buffer from
     * {@link #allocDVector(int, boolean)}. The MatrixBlock argument is not read here.
     */
    public static final MatrixBlock allocatePreAggregate(MatrixBlock m, int numVals, int rl, int ru) {
        final int nRow = ru - rl;
        final int[] dims = new int[] {nRow, numVals};
        final double[] buf = allocDVector(nRow * numVals, true);
        return new MatrixBlock(nRow, numVals, new DenseBlockFP64(dims, buf));
    }

    /**
     * Pre-aggregates an input matrix over this group's index structure into an output block
     * (presumably one allocated via allocatePreAggregate — confirm in subclasses).
     */
    public abstract void preAggregate(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2);

    /**
     * Dense-input variant of {@link #preAggregate}. The meaning of the four int bounds is defined by
     * subclasses (TODO confirm: likely a row range plus a column range).
     */
    public abstract void preAggregateDense(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2, int i3, int i4);

    /**
     * Pre-aggregates {@code that}'s dictionary over this group's index structure into a new
     * {@link Dictionary} of size that.numCols * this.numValues. Dispatches on the concrete type of
     * {@code that}; the flag is forwarded to the SDC and SDCSingle variants only.
     *
     * @throws NotImplementedException for unsupported column group types
     */
    public final Dictionary preAggregateThatIndexStructure(ColGroupValue that, boolean preModified) {
        final int size = that._colIndexes.length * getNumValues();
        final Dictionary ret = new Dictionary(new double[size]);
        if (that instanceof ColGroupDDC)
            return preAggregateThatDDCStructure((ColGroupDDC) that, ret);
        else if (that instanceof ColGroupSDC)
            return preAggregateThatSDCStructure((ColGroupSDC) that, ret, preModified);
        else if (that instanceof ColGroupSDCSingle)
            return preAggregateThatSDCSingleStructure((ColGroupSDCSingle) that, ret, preModified);
        else if (that instanceof ColGroupSDCSingleZeros)
            return preAggregateThatSDCSingleZerosStructure((ColGroupSDCSingleZeros) that, ret);
        else if (that instanceof ColGroupSDCZeros)
            return preAggregateThatSDCZerosStructure((ColGroupSDCZeros) that, ret);
        throw new NotImplementedException("Not supported pre aggregate using index structure of :"
            + that.getClass().getSimpleName() + " in " + getClass().getSimpleName());
    }

    /** Base implementation always throws; subclasses must override. */
    protected int getIndexStructureHash() {
        final String msg = "This base function should not be called";
        throw new NotImplementedException(msg);
    }

    /** Base implementation always throws; subclasses supporting DDC pre-aggregation override it. */
    protected Dictionary preAggregateThatDDCStructure(ColGroupDDC that, Dictionary ret) {
        throw new DMLCompressionException("Does not make sense to call this, implement function for sub class");
    }

    /** Base implementation always throws; subclasses supporting SDC pre-aggregation override it. */
    protected Dictionary preAggregateThatSDCStructure(ColGroupSDC that, Dictionary ret, boolean preModified) {
        throw new DMLCompressionException("Does not make sense to call this, implement function for sub class");
    }

    /** Base implementation always throws; subclasses supporting SDCZeros pre-aggregation override it. */
    protected Dictionary preAggregateThatSDCZerosStructure(ColGroupSDCZeros that, Dictionary ret) {
        throw new DMLCompressionException("Does not make sense to call this, implement function for sub class");
    }

    /** Base implementation always throws; subclasses supporting SDCSingleZeros pre-aggregation override it. */
    protected Dictionary preAggregateThatSDCSingleZerosStructure(ColGroupSDCSingleZeros that, Dictionary ret) {
        throw new DMLCompressionException("Does not make sense to call this, implement function for sub class");
    }

    /** Base implementation always throws; subclasses supporting SDCSingle pre-aggregation override it. */
    protected Dictionary preAggregateThatSDCSingleStructure(ColGroupSDCSingle that, Dictionary ret, boolean preModified) {
        throw new DMLCompressionException("Does not make sense to call this, implement function for sub class");
    }

    /**
     * Left-multiplies this group by {@code left} into {@code result}. Empty groups contribute
     * nothing. Value and uncompressed groups are supported; any other type throws.
     */
    @Override
    public final void leftMultByAColGroup(AColGroup left, MatrixBlock result) {
        if (left instanceof ColGroupEmpty)
            return;
        else if (left instanceof ColGroupValue)
            leftMultByColGroupValue((ColGroupValue) left, result);
        else if (left instanceof ColGroupUncompressed)
            leftMultByUncompressedColGroup((ColGroupUncompressed) left, result);
        else
            throw new DMLCompressionException("Not supported left multiplication with A ColGroup of type: "
                + left.getClass().getSimpleName());
    }

    /**
     * Adds the cross product between {@code other} and this group into the upper triangle of
     * {@code result}.
     *
     * <p>The uncompressed path transposes {@code other}'s data, multiplies it with this group into
     * a temporary block (row i = other's i-th column), and accumulates each (other column, this
     * column) entry via addToUpperTriangle.
     */
    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        if (other instanceof ColGroupEmpty)
            return;
        if (other instanceof ColGroupValue) {
            final ColGroupValue o = (ColGroupValue) other;
            final double[] lhs;
            final double[] rhs;
            // NOTE(review): leftMultByColGroupValue pre-aggregates left when shouldPreAggregateLeft
            // is TRUE; here the condition is inverted. Kept as found — confirm which is intended.
            if (!shouldPreAggregateLeft(other)) {
                lhs = o._dict.getValues();
                rhs = preAggLeft(o).getValues();
            } else {
                lhs = preAggRight(o).getValues();
                rhs = _dict.getValues();
            }
            matrixMultDictionariesAndOutputToColIndexesDenseDenseUpperTriangle(lhs, rhs, o._colIndexes, _colIndexes, result);
            return;
        }
        if (!(other instanceof ColGroupUncompressed))
            throw new DMLCompressionException("Unsupported column group type " + other.getClass().getSimpleName());

        LOG.warn("Inefficient multiplication with uncompressed column group");
        final MatrixBlock otherT = LibMatrixReorg.transpose(((ColGroupUncompressed) other).getData());
        final int nCols = result.getNumColumns();
        final MatrixBlock tmp = new MatrixBlock(otherT.getNumRows(), nCols, false);
        tmp.allocateDenseBlock();
        leftMultByMatrix(otherT, tmp);
        final double[] tmpV = tmp.getDenseBlockValues();
        final double[] resV = result.getDenseBlockValues();
        for (int i = 0; i < other._colIndexes.length; i++) {
            // Row i of tmp holds other's i-th column times this group; the original read row 0 only.
            final int offTmp = i * nCols;
            for (int j = 0; j < _colIndexes.length; j++)
                addToUpperTriangle(nCols, _colIndexes[j], other._colIndexes[i], resV, tmpV[offTmp + _colIndexes[j]]);
        }
    }

    /** True when this group's dictionary cell count (values x columns) is smaller than {@code other}'s. */
    private boolean shouldPreAggregateLeft(AColGroup other) {
        final double thisCells = getNumValues() * _colIndexes.length;
        final double otherCells = other.getNumValues() * other._colIndexes.length;
        return thisCells < otherCells;
    }

    /** Pre-aggregates this group's dictionary over {@code other}'s index structure. */
    private ADictionary preAggLeft(ColGroupValue other) {
        final boolean preModified = false;
        return other.preAggregateThatIndexStructure(this, preModified);
    }

    /** Pre-aggregates {@code other}'s dictionary over this group's index structure. */
    private ADictionary preAggRight(ColGroupValue other) {
        final boolean preModified = false;
        return preAggregateThatIndexStructure(other, preModified);
    }

    /**
     * Multiplies {@code left}'s dictionary with this group's into {@code result}.
     *
     * <ul>
     * <li>Identical index structures use count scaling, with the tsmm kernel when both groups share
     * the same dictionary object.</li>
     * <li>Otherwise one side is pre-aggregated over the other's index structure.</li>
     * </ul>
     */
    private void leftMultByColGroupValue(ColGroupValue left, MatrixBlock result) {
        final int[] leftCols = left._colIndexes;
        if (sameIndexStructure(left)) {
            if (_dict == left._dict)
                tsmmDictionaryWithScaling(_dict, getCounts(), leftCols, _colIndexes, result.getDenseBlockValues(), result.getNumColumns());
            else
                matrixMultDictionariesAndOutputToColIndexesWithScaling(left._dict, _dict, leftCols, _colIndexes, result, getCounts());
        } else if (shouldPreAggregateLeft(left)) {
            matrixMultDictionariesAndOutputToColIndexes(left._dict, preAggLeft(left), leftCols, _colIndexes, result);
        } else {
            matrixMultDictionariesAndOutputToColIndexes(preAggRight(left), _dict, leftCols, _colIndexes, result);
        }
    }

    /**
     * Left-multiplies by an uncompressed group by explicitly transposing its data (using local
     * parallelism) and delegating to leftMultByMatrix with that group's column indexes.
     */
    private void leftMultByUncompressedColGroup(ColGroupUncompressed left, MatrixBlock result) {
        final MatrixBlock data = left.getData();
        LOG.warn("Inefficient transpose of uncompressed to fit to template need t(UnCompressedColGroup) %*% AColGroup support");
        final MatrixBlock dataT = new MatrixBlock(data.getNumColumns(), data.getNumRows(), data.isInSparseFormat());
        final int k = InfrastructureAnalyzer.getLocalParallelism();
        LibMatrixReorg.transpose(data, dataT, k);
        leftMultByMatrix(dataT, result, left._colIndexes);
    }

    /**
     * Self cross-product using this group's counts, dictionary and columns. The second int
     * argument is not read here.
     */
    @Override
    protected final void tsmm(double[] result, int numColumns, int nRows) {
        final int[] tupleCounts = getCounts();
        tsmm(result, numColumns, tupleCounts, _dict, _colIndexes);
    }

    /** True if {@code pattern} is zero and the group has a zero tuple, or the dictionary contains it. */
    @Override
    public final boolean containsValue(double pattern) {
        final boolean implicitZero = pattern == 0.0 && _zeros;
        return implicitZero || _dict.containsValue(pattern);
    }

    /** Count-weighted number of non-zeros of the dictionary; the int argument is not read here. */
    @Override
    public final long getNumberNonZeros(int nRows) {
        final int[] tupleCounts = getCounts();
        return _dict.getNumberNonZeros(tupleCounts, _colIndexes.length);
    }

    /**
     * Scales one dictionary's tuples by {@code counts}, then multiplies the two dictionaries into
     * {@code result}. The right dictionary is scaled when its in-memory size is larger than the
     * left one's; otherwise the left is scaled.
     */
    private static void matrixMultDictionariesAndOutputToColIndexesWithScaling(ADictionary left, ADictionary right, int[] leftCols, int[] rightCols, MatrixBlock result, int[] counts) {
        final boolean scaleRight = right.getInMemorySize() > left.getInMemorySize();
        final ADictionary l = scaleRight ? left : left.scaleTuples(counts, leftCols.length);
        final ADictionary r = scaleRight ? right.scaleTuples(counts, rightCols.length) : right;
        matrixMultDictionariesAndOutputToColIndexes(l, r, leftCols, rightCols, result);
    }

    /**
     * Adds the count-scaled self product of one dictionary into {@code res}. For every tuple k and
     * every dictionary column pair a &lt;= b, it adds {@code counts[k] * dict[k][a] * dict[k][b]} at
     * row {@code rows[a]}, column {@code cols[b]} of the row-major output with {@code nCols} columns.
     *
     * @param dict   dictionary shared by both operands; its width is {@code cols.length}
     * @param counts per-tuple counts (may hold extra trailing slots)
     * @param rows   output row index per dictionary column
     * @param cols   output column index per dictionary column
     * @param res    dense row-major output
     * @param nCols  number of output columns
     */
    private static void tsmmDictionaryWithScaling(ADictionary dict, int[] counts, int[] rows, int[] cols, double[] res, int nCols) {
        final double[] values;
        if (dict instanceof MatrixBlockDictionary) {
            final MatrixBlock mb = ((MatrixBlockDictionary) dict).getMatrixBlock();
            if (mb.isEmpty())
                return;
            if (mb.isInSparseFormat()) {
                final SparseBlock sb = mb.getSparseBlock();
                for (int k = 0; k < sb.numRows(); k++) {
                    if (sb.isEmpty(k))
                        continue;
                    final int apos = sb.pos(k);
                    final int alen = apos + sb.size(k);
                    final int[] aix = sb.indexes(k);
                    final double[] avals = sb.values(k);
                    for (int a = apos; a < alen; a++) {
                        final int offOut = rows[aix[a]] * nCols;
                        final double v = avals[a] * counts[k];
                        for (int b = a; b < alen; b++)
                            res[offOut + cols[aix[b]]] += v * avals[b];
                    }
                }
                return;
            }
            // Dense MatrixBlock dictionaries share the dense path below instead of throwing.
            values = mb.getDenseBlockValues();
        } else
            values = dict.getValues();

        final int nColDict = cols.length;
        if (nColDict == 0 || values == null)
            return;
        // Iterate over ALL tuples (not rows.length of them), bounded by the available counts.
        final int numTuples = Math.min(values.length / nColDict, counts.length);
        for (int k = 0; k < numTuples; k++) {
            final int off = k * nColDict;
            final double count = counts[k];
            for (int a = 0; a < nColDict; a++) {
                final double v = values[off + a] * count;
                if (v == 0.0)
                    continue; // contributes nothing to output row rows[a]
                final int offOut = rows[a] * nCols;
                for (int b = a; b < nColDict; b++)
                    res[offOut + cols[b]] += v * values[off + b];
            }
        }
    }

    /**
     * Multiplies two dictionaries (left: tuples x iArr.length, right: tuples x iArr2.length) and
     * accumulates into {@code matrixBlock} at rows {@code iArr} and columns {@code iArr2}, picking a
     * dense/sparse kernel from the dictionaries' representations. An empty MatrixBlock dictionary
     * on either side makes this a no-op. Sparse x sparse is not supported. Any exception (including
     * that NotImplementedException) is wrapped in a DMLCompressionException.
     */
    protected static void matrixMultDictionariesAndOutputToColIndexes(ADictionary aDictionary, ADictionary aDictionary2, int[] iArr, int[] iArr2, MatrixBlock matrixBlock) {
        double[] values;
        double[] values2;
        try {
            if (aDictionary instanceof MatrixBlockDictionary) {
                MatrixBlock matrixBlock2 = aDictionary.getAsMatrixBlockDictionary(iArr.length).getMatrixBlock();
                if (matrixBlock2.isEmpty()) {
                    return;
                }
                // Both sides MatrixBlock-backed: every combination is resolved (and returns) here.
                if (aDictionary2 instanceof MatrixBlockDictionary) {
                    MatrixBlock matrixBlock3 = aDictionary2.getAsMatrixBlockDictionary(iArr2.length).getMatrixBlock();
                    if (matrixBlock3.isEmpty()) {
                        return;
                    }
                    if (matrixBlock3.isInSparseFormat() && matrixBlock2.isInSparseFormat()) {
                        throw new NotImplementedException("Not Supported sparse sparse dictionary multiplication");
                    }
                    if (matrixBlock3.isInSparseFormat()) {
                        matrixMultDictionariesAndOutputToColIndecesDenseSparse(matrixBlock2.getDenseBlockValues(), matrixBlock3.getSparseBlock(), iArr, iArr2, matrixBlock);
                        return;
                    } else if (matrixBlock2.isInSparseFormat()) {
                        matrixMultDictionariesAndOutputToColIndecesSparseDense(matrixBlock2.getSparseBlock(), matrixBlock3.getDenseBlockValues(), iArr, iArr2, matrixBlock);
                        return;
                    } else {
                        matrixMultDictionariesAndOutputToColIndexesDenseDense(matrixBlock2.getDenseBlockValues(), matrixBlock3.getDenseBlockValues(), iArr, iArr2, matrixBlock);
                        return;
                    }
                }
                // Left MatrixBlock-backed, right a plain dictionary.
                if (matrixBlock2.isInSparseFormat()) {
                    matrixMultDictionariesAndOutputToColIndecesSparseDense(matrixBlock2.getSparseBlock(), aDictionary2.getValues(), iArr, iArr2, matrixBlock);
                    return;
                }
                values = matrixBlock2.getDenseBlockValues();
            } else {
                values = aDictionary.getValues();
            }
            // Left is dense at this point; this MatrixBlock branch is only reachable when the left
            // dictionary was not MatrixBlock-backed (the case above returns for MB x MB).
            if (aDictionary2 instanceof MatrixBlockDictionary) {
                MatrixBlock matrixBlock4 = aDictionary2.getAsMatrixBlockDictionary(iArr2.length).getMatrixBlock();
                if (matrixBlock4.isEmpty()) {
                    return;
                }
                if (matrixBlock4.isInSparseFormat()) {
                    matrixMultDictionariesAndOutputToColIndecesDenseSparse(values, matrixBlock4.getSparseBlock(), iArr, iArr2, matrixBlock);
                    return;
                }
                values2 = matrixBlock4.getDenseBlockValues();
            } else {
                values2 = aDictionary2.getValues();
            }
            // Null value arrays are silently skipped rather than reported.
            if (values != null && values2 != null) {
                matrixMultDictionariesAndOutputToColIndexesDenseDense(values, values2, iArr, iArr2, matrixBlock);
            }
        } catch (Exception e) {
            throw new DMLCompressionException("MM of pre aggregated colGroups failed", e);
        }
    }

    /**
     * Accumulates the product of two dense dictionaries into the dense output block.
     * <p>
     * Both value arrays are indexed row-major: the left one with {@code iArr.length} columns, the right
     * one with {@code iArr2.length} columns. Only the dictionary rows present in both are visited. For each
     * such row {@code k}, the outer product of left row {@code k} and right row {@code k} is added to the
     * output. Left column {@code a} maps to output row {@code iArr[a]} and right column {@code b} maps to
     * output column {@code iArr2[b]}.
     *
     * @param dArr        left dictionary values (row-major, {@code iArr.length} columns)
     * @param dArr2       right dictionary values (row-major, {@code iArr2.length} columns)
     * @param iArr        output row index for each left column
     * @param iArr2       output column index for each right column
     * @param matrixBlock output block; its dense value array is updated in place (expected dense-allocated)
     */
    private static void matrixMultDictionariesAndOutputToColIndexesDenseDense(double[] dArr, double[] dArr2, int[] iArr, int[] iArr2, MatrixBlock matrixBlock) {
        final int leftCols = iArr.length;
        final int rightCols = iArr2.length;
        final int rows = Math.min(dArr.length / leftCols, dArr2.length / rightCols);
        final int nColOut = matrixBlock.getNumColumns();
        final double[] out = matrixBlock.getDenseBlockValues();
        for (int k = 0; k < rows; k++) {
            final int offL = k * leftCols;
            final int offR = k * rightCols;
            for (int a = 0; a < leftCols; a++) {
                final double v = dArr[offL + a];
                // A zero left entry contributes nothing to this output row, so skip the inner loop.
                if (v == 0) {
                    continue;
                }
                final int offOut = iArr[a] * nColOut;
                for (int b = 0; b < rightCols; b++) {
                    out[offOut + iArr2[b]] += v * dArr2[offR + b];
                }
            }
        }
    }

    /**
     * Variant of the dense-dense dictionary product that routes every contribution through
     * {@code addToUpperTriangle} instead of writing the output cell directly.
     * <p>
     * Both value arrays are indexed row-major: the left one with {@code iArr.length} columns, the right
     * one with {@code iArr2.length} columns. Only the dictionary rows present in both are visited. For each
     * row {@code k} and each non-zero left entry {@code (k, a)}, the value
     * {@code left[k, a] * right[k, b]} is passed to {@code addToUpperTriangle} with row {@code iArr[a]}
     * and column {@code iArr2[b]}. Presumably {@code addToUpperTriangle} folds (row, col) into the upper
     * triangle of the output (defined elsewhere; TODO confirm).
     *
     * @param dArr        left dictionary values (row-major, {@code iArr.length} columns)
     * @param dArr2       right dictionary values (row-major, {@code iArr2.length} columns)
     * @param iArr        output row index for each left column
     * @param iArr2       output column index for each right column
     * @param matrixBlock output block; its dense value array is updated in place
     */
    private static void matrixMultDictionariesAndOutputToColIndexesDenseDenseUpperTriangle(double[] dArr, double[] dArr2, int[] iArr, int[] iArr2, MatrixBlock matrixBlock) {
        final int leftCols = iArr.length;
        final int rightCols = iArr2.length;
        final int rows = Math.min(dArr.length / leftCols, dArr2.length / rightCols);
        final int nColOut = matrixBlock.getNumColumns();
        final double[] out = matrixBlock.getDenseBlockValues();
        for (int k = 0; k < rows; k++) {
            final int offL = k * leftCols;
            final int offR = k * rightCols;
            for (int a = 0; a < leftCols; a++) {
                final double v = dArr[offL + a];
                // A zero left entry contributes nothing, so skip the inner loop.
                if (v == 0) {
                    continue;
                }
                final int outRow = iArr[a];
                for (int b = 0; b < rightCols; b++) {
                    addToUpperTriangle(nColOut, outRow, iArr2[b], out, v * dArr2[offR + b]);
                }
            }
        }
    }

    /**
     * Accumulates the product of a sparse left dictionary and a dense right dictionary into the dense
     * output block.
     * <p>
     * Only dictionary rows present in both inputs are visited. For row {@code k}, each stored left entry
     * at left column {@code c} targets output row {@code iArr[c]}. That entry is multiplied with every
     * value of right row {@code k}; right column {@code b} lands in output column {@code iArr2[b]}.
     *
     * @param sparseBlock left dictionary, one sparse row per dictionary entry
     * @param dArr        right dictionary values (row-major, {@code iArr2.length} columns)
     * @param iArr        output row index for each left column
     * @param iArr2       output column index for each right column
     * @param matrixBlock output block; its dense value array is updated in place
     */
    private static void matrixMultDictionariesAndOutputToColIndecesSparseDense(SparseBlock sparseBlock, double[] dArr, int[] iArr, int[] iArr2, MatrixBlock matrixBlock) {
        final double[] out = matrixBlock.getDenseBlockValues();
        final int rightCols = iArr2.length;
        final int rows = Math.min(sparseBlock.numRows(), dArr.length / rightCols);
        final int nColOut = matrixBlock.getNumColumns();
        for (int k = 0; k < rows; k++) {
            if (sparseBlock.isEmpty(k)) {
                continue;
            }
            final int start = sparseBlock.pos(k);
            final int end = start + sparseBlock.size(k);
            final int[] leftIdx = sparseBlock.indexes(k);
            final double[] leftVals = sparseBlock.values(k);
            final int offR = k * rightCols;
            for (int j = start; j < end; j++) {
                final int offOut = iArr[leftIdx[j]] * nColOut;
                final double v = leftVals[j];
                for (int b = 0; b < rightCols; b++) {
                    out[offOut + iArr2[b]] += v * dArr[offR + b];
                }
            }
        }
    }

    /**
     * Accumulates the product of a dense left dictionary and a sparse right dictionary into the dense
     * output block.
     * <p>
     * Only dictionary rows present in both inputs are visited. For row {@code k}, each non-zero left
     * value at left column {@code a} targets output row {@code iArr[a]}. It is multiplied with every
     * stored entry of sparse row {@code k}; the entry at right column {@code c} lands in output column
     * {@code iArr2[c]}.
     *
     * @param dArr        left dictionary values (row-major, {@code iArr.length} columns)
     * @param sparseBlock right dictionary, one sparse row per dictionary entry
     * @param iArr        output row index for each left column
     * @param iArr2       output column index for each right column
     * @param matrixBlock output block; its dense value array is updated in place
     */
    private static void matrixMultDictionariesAndOutputToColIndecesDenseSparse(double[] dArr, SparseBlock sparseBlock, int[] iArr, int[] iArr2, MatrixBlock matrixBlock) {
        final double[] out = matrixBlock.getDenseBlockValues();
        final int leftCols = iArr.length;
        final int rows = Math.min(dArr.length / leftCols, sparseBlock.numRows());
        // Loop-invariant: hoisted out of the per-entry loop.
        final int nColOut = matrixBlock.getNumColumns();
        for (int k = 0; k < rows; k++) {
            if (sparseBlock.isEmpty(k)) {
                continue;
            }
            final int start = sparseBlock.pos(k);
            final int end = start + sparseBlock.size(k);
            final int[] rightIdx = sparseBlock.indexes(k);
            final double[] rightVals = sparseBlock.values(k);
            final int offL = k * leftCols;
            for (int a = 0; a < leftCols; a++) {
                final double v = dArr[offL + a];
                // A zero left entry contributes nothing, so skip scanning the sparse row.
                if (v == 0) {
                    continue;
                }
                final int offOut = iArr[a] * nColOut;
                for (int j = start; j < end; j++) {
                    out[offOut + iArr2[rightIdx[j]]] += v * rightVals[j];
                }
            }
        }
    }

    /**
     * Reports whether this group is dense.
     *
     * @return the negation of the {@code _zeros} flag
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public final boolean isDense() {
        final boolean containsZeros = _zeros;
        return !containsZeros;
    }

    /**
     * Left-multiplies rows {@code [i, i2)} of {@code matrixBlock} with this column group and adds the
     * product into {@code matrixBlock2}.
     * <p>
     * The method works in three steps:
     * <ol>
     * <li>The left rows are pre-aggregated into a block sized by {@code getNumValues()}.</li>
     * <li>That block is multiplied with the dictionary.</li>
     * <li>The result is scattered into the output at this group's column indexes.</li>
     * </ol>
     *
     * @param matrixBlock  left-hand matrix
     * @param matrixBlock2 output block to accumulate into
     * @param i            first row (inclusive)
     * @param i2           last row (exclusive)
     * @throws DMLCompressionException wrapping any exception raised by the steps above
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public final void leftMultByMatrix(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2) {
        try {
            final MatrixBlock preAgg = allocatePreAggregate(matrixBlock, getNumValues(), i, i2);
            preAggregate(matrixBlock, preAgg, i, i2);
            preAgg.recomputeNonZeros();
            final MatrixBlock product = leftMultByPreAggregateMatrix(preAgg);
            addMatrixToResult(product, matrixBlock2, i, i2);
        }
        catch (Exception e) {
            throw new DMLCompressionException(getClass().getSimpleName() + " Failed to Left Matrix Multiply", e);
        }
    }

    /**
     * Multiplies a pre-aggregated block with this group's dictionary.
     * <p>
     * The result has {@code matrixBlock.getNumRows()} rows and {@code _colIndexes.length} columns. If the
     * thread-local buffer {@code tmpLeftMultDoubleArray} is large enough, it is wrapped directly (no copy)
     * and reset. Otherwise a fresh dense block is allocated. When the buffer is reused, the returned block
     * aliases it, so callers should presumably consume the result before the next call on this thread.
     *
     * @param matrixBlock pre-aggregated left-hand block
     * @return the product block
     */
    public final MatrixBlock leftMultByPreAggregateMatrix(MatrixBlock matrixBlock) {
        final int nCols = this._colIndexes.length;
        final int nRows = matrixBlock.getNumRows();
        final double[] cached = tmpLeftMultDoubleArray.get();
        final boolean reuseBuffer = cached != null && cached.length >= nCols * nRows;
        final MatrixBlock target;
        if (reuseBuffer) {
            target = new MatrixBlock(nRows, nCols, new DenseBlockFP64(new int[]{nRows, nCols}, cached));
            target.reset();
        }
        else {
            target = new MatrixBlock(nRows, nCols, false);
        }
        return leftMultByPreAggregateMatrix(matrixBlock, target);
    }

    /**
     * Multiplies {@code matrixBlock} with this group's dictionary and writes the product into
     * {@code matrixBlock2}.
     * <p>
     * As a side effect, the dictionary is converted to a {@link MatrixBlockDictionary} if it is not
     * one already.
     *
     * @param matrixBlock  left-hand (pre-aggregated) block
     * @param matrixBlock2 block that receives the product
     * @return {@code matrixBlock2}
     */
    public final MatrixBlock leftMultByPreAggregateMatrix(MatrixBlock matrixBlock, MatrixBlock matrixBlock2) {
        final MatrixBlock dictBlock = forceMatrixBlockDictionary().getMatrixBlock();
        LibMatrixMult.matrixMult(matrixBlock, dictBlock, matrixBlock2);
        return matrixBlock2;
    }

    /**
     * Left-multiplies all rows of {@code matrixBlock} with this column group.
     * <p>
     * Local result row {@code k} is added to output row {@code iArr[k]} of {@code matrixBlock2}.
     *
     * @param matrixBlock  left-hand matrix (all of its rows are used)
     * @param matrixBlock2 output block to accumulate into
     * @param iArr         output row for each row of {@code matrixBlock}
     * @throws DMLCompressionException wrapping any exception raised during the multiplication
     */
    private void leftMultByMatrix(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int[] iArr) {
        try {
            final int rowCount = matrixBlock.getNumRows();
            final MatrixBlock preAgg = allocatePreAggregate(matrixBlock, getNumValues(), 0, rowCount);
            preAggregate(matrixBlock, preAgg, 0, rowCount);
            preAgg.recomputeNonZeros();
            final MatrixBlock product = leftMultByPreAggregateMatrix(preAgg);
            addMatrixToResult(product, matrixBlock2, iArr);
        }
        catch (Exception e) {
            throw new DMLCompressionException(getClass().getSimpleName() + " Failed to multiply with an uncompressed column group", e);
        }
    }

    /**
     * Returns the dictionary as a {@link MatrixBlockDictionary}.
     * <p>
     * If {@code _dict} is not already of that type, it is converted using the group's column count and
     * the converted dictionary replaces {@code _dict}.
     *
     * @return {@code _dict} as a {@link MatrixBlockDictionary}
     */
    private MatrixBlockDictionary forceMatrixBlockDictionary() {
        if (this._dict instanceof MatrixBlockDictionary) {
            return (MatrixBlockDictionary) this._dict;
        }
        this._dict = this._dict.getAsMatrixBlockDictionary(this._colIndexes.length);
        return (MatrixBlockDictionary) this._dict;
    }

    /**
     * Adds a left-multiplication result into a contiguous row range of the dense output block.
     * <p>
     * Local row {@code k} of {@code matrixBlock} is added to output row {@code i + k}, for output rows
     * {@code i} (inclusive) to {@code i2} (exclusive). Local column {@code c} is scattered to output
     * column {@code _colIndexes[c]}. Nothing happens if {@code matrixBlock} is empty.
     *
     * @param matrixBlock  block to add; assumes at least {@code i2 - i} rows and
     *                     {@code _colIndexes.length} columns (TODO confirm at call sites)
     * @param matrixBlock2 output block; its dense value array is updated in place
     * @param i            first output row (inclusive)
     * @param i2           last output row (exclusive)
     */
    public void addMatrixToResult(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int i, int i2) {
        if (matrixBlock.isEmpty()) {
            return;
        }
        final double[] retV = matrixBlock2.getDenseBlockValues();
        final int nColRet = matrixBlock2.getNumColumns();
        if (matrixBlock.isInSparseFormat()) {
            final SparseBlock sb = matrixBlock.getSparseBlock();
            for (int row = i, offT = 0; row < i2; row++, offT++) {
                // Guard before touching indexes/values: a sparse layout may not allocate empty rows.
                if (sb.isEmpty(offT)) {
                    continue;
                }
                final int apos = sb.pos(offT);
                final int aend = apos + sb.size(offT);
                final int[] aix = sb.indexes(offT);
                final double[] avals = sb.values(offT);
                final int offR = row * nColRet;
                for (int j = apos; j < aend; j++) {
                    retV[offR + this._colIndexes[aix[j]]] += avals[j];
                }
            }
            return;
        }
        final double[] tmpV = matrixBlock.getDenseBlockValues();
        final int nCol = this._colIndexes.length;
        for (int row = i, offT = 0; row < i2; row++, offT += nCol) {
            final int offR = row * nColRet;
            for (int c = 0; c < nCol; c++) {
                retV[offR + this._colIndexes[c]] += tmpV[offT + c];
            }
        }
    }

    /**
     * Adds a left-multiplication result into selected rows of the dense output block.
     * <p>
     * Local row {@code k} of {@code matrixBlock} is added to output row {@code iArr[k]}, for
     * {@code k < iArr.length}. Local column {@code c} is scattered to output column
     * {@code _colIndexes[c]}. Nothing happens if {@code matrixBlock} is empty.
     *
     * @param matrixBlock  block to add; assumes at least {@code iArr.length} rows and
     *                     {@code _colIndexes.length} columns (TODO confirm at call sites)
     * @param matrixBlock2 output block; its dense value array is updated in place
     * @param iArr         output row for each local row of {@code matrixBlock}
     */
    private void addMatrixToResult(MatrixBlock matrixBlock, MatrixBlock matrixBlock2, int[] iArr) {
        if (matrixBlock.isEmpty()) {
            return;
        }
        final double[] retV = matrixBlock2.getDenseBlockValues();
        final int nColRet = matrixBlock2.getNumColumns();
        if (matrixBlock.isInSparseFormat()) {
            final SparseBlock sb = matrixBlock.getSparseBlock();
            for (int k = 0; k < iArr.length; k++) {
                // Guard before touching indexes/values: a sparse layout may not allocate empty rows.
                if (sb.isEmpty(k)) {
                    continue;
                }
                final int apos = sb.pos(k);
                final int aend = apos + sb.size(k);
                final int[] aix = sb.indexes(k);
                final double[] avals = sb.values(k);
                final int offR = iArr[k] * nColRet;
                for (int j = apos; j < aend; j++) {
                    retV[offR + this._colIndexes[aix[j]]] += avals[j];
                }
            }
            return;
        }
        final double[] tmpV = matrixBlock.getDenseBlockValues();
        final int nCol = this._colIndexes.length;
        for (int k = 0, offT = 0; k < iArr.length; k++, offT += nCol) {
            final int offR = iArr[k] * nColRet;
            for (int c = 0; c < nCol; c++) {
                retV[offR + this._colIndexes[c]] += tmpV[offT + c];
            }
        }
    }

    /**
     * Right-multiplies this column group with {@code matrixBlock}.
     * <p>
     * The method first determines the set of output columns via the sparse or dense aggregate-column
     * helper. It then pre-aggregates the dictionary values against the right-hand matrix and returns a
     * copy of this group with those columns and the new dictionary.
     *
     * @param matrixBlock right-hand matrix
     * @return the resulting column group, or {@code null} in three cases: {@code matrixBlock} is empty,
     *         no output columns are selected, or (dense path) the pre-aggregation returns {@code null}
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public final AColGroup rightMultByMatrix(MatrixBlock matrixBlock) {
        if (matrixBlock.isEmpty()) {
            return null;
        }
        final int nCol = matrixBlock.getNumColumns();
        final int nVals = getNumValues();
        if (!matrixBlock.isInSparseFormat()) {
            final double[] rightValues = matrixBlock.getDenseBlockValues();
            final int[] denseCols = getAggregateColumnsSetDense(rightValues, 0, nCol, nCol);
            if (denseCols.length == 0) {
                return null;
            }
            final ADictionary denseAgg = this._dict.preaggValuesFromDense(nVals, this._colIndexes, denseCols, rightValues, nCol);
            if (denseAgg == null) {
                return null;
            }
            return copyAndSet(denseCols, denseAgg);
        }
        final SparseBlock rightSparse = matrixBlock.getSparseBlock();
        final int[] sparseCols = getAggregateColumnsSetSparse(rightSparse, nCol);
        if (sparseCols.length == 0) {
            return null;
        }
        final ADictionary sparseAgg = preaggValuesFromSparse(nVals, rightSparse, sparseCols, 0, nCol, nCol);
        return copyAndSet(sparseCols, sparseAgg);
    }

    /**
     * Estimates this group's in-memory size in bytes.
     * <p>
     * The estimate is the sum of three parts: the superclass estimate, a fixed 24-byte overhead
     * (8 + 8 + 4 + 1 + 1 + 2), and the dictionary's own in-memory size. Which fields the fixed overhead
     * accounts for is not visible here (TODO confirm).
     *
     * @return estimated size in bytes
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public long estimateInMemorySize() {
        long size = super.estimateInMemorySize();
        size += 8 + 8 + 4 + 1 + 1 + 2;
        size += this._dict.getInMemorySize();
        return size;
    }

    /**
     * Returns a copy of this group whose dictionary has {@code d} replaced by {@code d2}.
     *
     * @param d  value to replace
     * @param d2 replacement value
     * @return a new group backed by the replaced dictionary
     */
    @Override // org.apache.sysds.runtime.compress.colgroup.AColGroup
    public AColGroup replace(double d, double d2) {
        final int nCol = this._colIndexes.length;
        final ADictionary replaced = this._dict.replace(d, d2, nCol);
        return copyAndSet(replaced);
    }

    /**
     * Returns the stored row count of this group.
     *
     * @return the value of {@code _numRows}
     */
    public final int getNumRows() {
        final int rows = _numRows;
        return rows;
    }

    /**
     * Presumably reports whether {@code that} shares this group's row-index structure. The implementation
     * lives in subclasses; TODO confirm the exact contract there.
     *
     * @param that the other compressed column group
     * @return {@code true} if the index structures match
     */
    protected abstract boolean sameIndexStructure(ColGroupCompressed that);
}
