package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashSet;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/**
 * Base class for compressed column groups whose values are held in an {@link IDictionary}.
 *
 * <p>Subclasses provide the mapping structure and implement the four decompression hooks, one per
 * combination of target block format (dense/sparse) and dictionary representation (dense/sparse).
 */
public abstract class ADictBasedColGroup extends AColGroupCompressed implements IContainADictionary {
    private static final long serialVersionUID = -3737025296618703668L;

    /** The dictionary backing this column group; never null (enforced by the constructor). */
    protected final IDictionary _dict;

    /**
     * Creates a dictionary-based column group.
     *
     * @param colIndices the column indexes covered by this group
     * @param dict       the dictionary holding the group's values; must not be null
     * @throws NullPointerException if {@code dict} is null
     */
    public ADictBasedColGroup(IColIndex colIndices, IDictionary dict) {
        super(colIndices);
        if (dict == null) {
            throw new NullPointerException("null dict is invalid");
        }
        this._dict = dict;
    }

    @Override
    public IDictionary getDictionary() {
        return _dict;
    }

    /**
     * Decompresses into a dense block, choosing the dense- or sparse-dictionary hook depending on
     * how the dictionary values are physically stored. The int arguments are forwarded unchanged
     * (presumably row range [rl, ru) and row/column output offsets; confirm against AColGroup).
     */
    @Override
    public final void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
        final MatrixBlock mb = getBackingMatrixBlock();
        if (mb == null) {
            decompressToDenseBlockDenseDictionary(db, rl, ru, offR, offC, _dict.getValues());
        } else if (mb.isInSparseFormat()) {
            decompressToDenseBlockSparseDictionary(db, rl, ru, offR, offC, mb.getSparseBlock());
        } else {
            decompressToDenseBlockDenseDictionary(db, rl, ru, offR, offC, mb.getDenseBlockValues());
        }
    }

    /**
     * Decompresses into a sparse block, choosing the dense- or sparse-dictionary hook depending on
     * how the dictionary values are physically stored. The int arguments are forwarded unchanged
     * (presumably row range [rl, ru) and row/column output offsets; confirm against AColGroup).
     */
    @Override
    public final void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC) {
        final MatrixBlock mb = getBackingMatrixBlock();
        if (mb == null) {
            decompressToSparseBlockDenseDictionary(sb, rl, ru, offR, offC, _dict.getValues());
        } else if (mb.isInSparseFormat()) {
            decompressToSparseBlockSparseDictionary(sb, rl, ru, offR, offC, mb.getSparseBlock());
        } else {
            decompressToSparseBlockDenseDictionary(sb, rl, ru, offR, offC, mb.getDenseBlockValues());
        }
    }

    /**
     * Returns the MatrixBlock holding the dictionary values when the dictionary is an
     * IdentityDictionary or a MatrixBlockDictionary. Returns null when the values must instead be
     * read through {@code IDictionary#getValues()}.
     */
    private MatrixBlock getBackingMatrixBlock() {
        // IdentityDictionary is checked first, matching the original dispatch order.
        if (_dict instanceof IdentityDictionary) {
            return ((IdentityDictionary) _dict).getMBDict().getMatrixBlock();
        }
        if (_dict instanceof MatrixBlockDictionary) {
            return ((MatrixBlockDictionary) _dict).getMatrixBlock();
        }
        return null;
    }

    /** Decompression hook: dense target block, sparse dictionary values. */
    public abstract void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC,
        SparseBlock dict);

    /** Decompression hook: dense target block, dense (flat {@code double[]}) dictionary values. */
    protected abstract void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC,
        double[] values);

    /** Decompression hook: sparse target block, sparse dictionary values. */
    protected abstract void decompressToSparseBlockSparseDictionary(SparseBlock sb, int rl, int ru, int offR,
        int offC, SparseBlock dict);

    /** Decompression hook: sparse target block, dense (flat {@code double[]}) dictionary values. */
    protected abstract void decompressToSparseBlockDenseDictionary(SparseBlock sb, int rl, int ru, int offR,
        int offC, double[] values);

    /** Serializes the base column-group state followed by the dictionary. */
    @Override
    public void write(DataOutput out) throws IOException {
        super.write(out);
        _dict.write(out);
    }

    /** Size on disk: base column-group state plus the dictionary, matching {@link #write}. */
    @Override
    public long getExactSizeOnDisk() {
        return super.getExactSizeOnDisk() + _dict.getExactSizeOnDisk();
    }

    /** In-memory estimate: base state, dictionary, plus 8 bytes (presumably the _dict reference). */
    @Override
    public long estimateInMemorySize() {
        return super.estimateInMemorySize() + _dict.getInMemorySize() + 8;
    }

    /**
     * Multiplies this group with the right-hand matrix by first pre-aggregating the dictionary
     * against the right matrix, restricted to the output columns that can be non-zero.
     *
     * @param right   the right-hand matrix
     * @param allCols the output column index to use when all columns are touched; if null, one
     *                covering {@code 0..right.getNumColumns()-1} is created
     * @return the resulting column group, or null if the result is empty
     */
    @Override
    public final AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols) {
        if (right.isEmpty()) {
            return null;
        }
        final int nCol = right.getNumColumns();
        final boolean sparse = right.isInSparseFormat();
        final IColIndex cols = allCols == null ? ColIndexFactory.create(nCol) : allCols;

        // Output columns that can be non-zero, given the rows of `right` selected by _colIndexes.
        final IColIndex aggCols = sparse
            ? rightMMGetColsSparse(right.getSparseBlock(), nCol, cols)
            : rightMMGetColsDense(right.getDenseBlockValues(), nCol, cols, right.getNonZeros());
        if (aggCols == null) {
            return null;
        }

        final int nVals = getNumValues();
        final IDictionary preAgg = sparse
            ? rightMMPreAggSparse(nVals, right.getSparseBlock(), aggCols)
            : _dict.preaggValuesFromDense(nVals, _colIndexes, aggCols, right.getDenseBlockValues(), nCol);
        return allocateRightMultiplication(right, aggCols, preAgg);
    }

    /** Builds the result column group of a right multiplication from the pre-aggregated dictionary. */
    protected abstract AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes,
        IDictionary preAgg);

    /**
     * Finds the columns of a dense row-major right matrix that have a non-zero in any row selected
     * by {@code _colIndexes}.
     *
     * @param b       the dense values of the right matrix, row-major with {@code nCols} columns
     * @param nCols   number of columns in the right matrix
     * @param allCols index returned when all columns are (or are assumed to be) used
     * @param nnz     number of non-zeros in the right matrix
     * @return {@code allCols} if every column is used or the scan is skipped, null if none is
     *         used, otherwise the sorted used columns
     */
    protected IColIndex rightMMGetColsDense(double[] b, int nCols, IColIndex allCols, long nnz) {
        // Wide or dense inputs: assume every column is used rather than scanning.
        if (nCols > 200 || nnz > b.length * 0.7d) {
            return allCols;
        }
        final boolean[] used = new boolean[nCols];
        int nUsed = 0;
        for (int k = 0; k < _colIndexes.size(); k++) {
            final int off = _colIndexes.get(k) * nCols;
            for (int c = 0; c < nCols; c++) {
                if (!used[c] && b[off + c] != 0.0) {
                    used[c] = true;
                    nUsed++;
                }
            }
        }
        if (nUsed == nCols) {
            return allCols;
        }
        if (nUsed == 0) {
            return null;
        }
        // Walking the flags in order yields the columns already sorted.
        final int[] cols = new int[nUsed];
        for (int c = 0, p = 0; c < nCols; c++) {
            if (used[c]) {
                cols[p++] = c;
            }
        }
        return ColIndexFactory.create(cols);
    }

    /**
     * Finds the columns of a sparse right matrix that have an entry in any row selected by
     * {@code _colIndexes}.
     *
     * @param b       the sparse block of the right matrix
     * @param nCols   number of columns in the right matrix
     * @param allCols index returned when all columns are used
     * @return {@code allCols} if every column is used, null if none is used, otherwise the sorted
     *         used columns
     */
    protected IColIndex rightMMGetColsSparse(SparseBlock b, int nCols, IColIndex allCols) {
        final HashSet<Integer> used = new HashSet<>();
        for (int k = 0; k < _colIndexes.size(); k++) {
            final int row = _colIndexes.get(k);
            if (!b.isEmpty(row)) {
                final int apos = b.pos(row);
                final int alen = apos + b.size(row);
                final int[] aix = b.indexes(row);
                for (int p = apos; p < alen; p++) {
                    used.add(aix[p]);
                }
            }
            // Stop early once every output column has been seen.
            if (used.size() == nCols) {
                return allCols;
            }
        }
        if (used.isEmpty()) {
            return null;
        }
        final int[] cols = used.stream().mapToInt(Integer::intValue).toArray();
        Arrays.sort(cols);
        return ColIndexFactory.create(cols);
    }

    /**
     * Pre-aggregates the dictionary against a sparse right matrix.
     *
     * <p>The dictionary is read row-major: value tuple {@code v}, group column {@code h} is at
     * {@code v * _colIndexes.size() + h}. The result is row-major with {@code aggCols.size()}
     * columns.
     *
     * @param nVals   number of distinct value tuples in the dictionary
     * @param b       the sparse block of the right matrix
     * @param aggCols sorted output columns, expected to contain every column index found in the
     *                rows selected by {@code _colIndexes}
     * @return the pre-aggregated dictionary
     */
    private IDictionary rightMMPreAggSparse(int nVals, SparseBlock b, IColIndex aggCols) {
        final int nAgg = aggCols.size();
        final int nGroupCols = _colIndexes.size();
        final double[] ret = new double[nVals * nAgg];
        for (int h = 0; h < nGroupCols; h++) {
            final int row = _colIndexes.get(h);
            if (b.isEmpty(row)) {
                continue;
            }
            final int apos = b.pos(row);
            final int alen = apos + b.size(row);
            final int[] aix = b.indexes(row);
            final double[] avals = b.values(row);
            int retIdx = 0;
            for (int p = apos; p < alen; p++) {
                // Both aix (within a row) and aggCols are sorted, so a forward scan finds the slot.
                while (aggCols.get(retIdx) < aix[p]) {
                    retIdx++;
                }
                final double rv = avals[p];
                // Bounded loop over the value tuples; the decompiled version never terminated here.
                for (int v = 0; v < nVals; v++) {
                    ret[v * nAgg + retIdx] += _dict.getValue(v * nGroupCols + h) * rv;
                }
            }
        }
        return Dictionary.create(ret);
    }

    /** Returns a copy of this group with new column indexes and the same dictionary. */
    @Override
    public final AColGroup copyAndSet(IColIndex colIndexes) {
        return copyAndSet(colIndexes, _dict);
    }

    /** Returns a copy of this group with the same column indexes and a new dictionary. */
    public final AColGroup copyAndSet(IDictionary newDictionary) {
        return copyAndSet(_colIndexes, newDictionary);
    }

    /** Returns a copy of this group with the given column indexes and dictionary. */
    public abstract AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary);
}
