package org.apache.sysds.runtime.compress.estim.encoding;

import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AIterator;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.utils.IntArrayList;
import org.apache.sysds.runtime.util.ProgramConverter;

/* loaded from: input_file:org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.class */
/**
 * Encoding where only a subset of the rows carry a non-default value.
 *
 * <p>The rows listed in {@link #off} carry the id found in {@link #map} at the offset iterator's data index; every
 * other row is treated as the default value. Because of that implicit default, {@link #getUnique()} reports one more
 * distinct value than the mapping itself holds, and when this encoding is expanded to a dense mapping the unlisted
 * rows become id 0.
 */
public class SparseEncoding extends AEncode {
    /** Id per non-default row, addressed via the data index of an iterator over {@link #off}. */
    protected final AMapToData map;
    /** Offsets (row indexes) of the non-default rows. */
    protected final AOffset off;
    /** Total number of rows of the encoded column, including the default rows. */
    protected final int nRows;

    /**
     * Create a sparse encoding.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected}; it is kept
     * {@code public} so existing callers keep compiling.
     *
     * @param map   ids of the non-default rows
     * @param off   offsets of the non-default rows
     * @param nRows total number of rows
     * @throws DMLCompressionException in debug mode, if any id of {@code map} has a count of zero
     */
    public SparseEncoding(AMapToData map, AOffset off, int nRows) {
        this.map = map;
        this.off = off;
        this.nRows = nRows;
        if (CompressedMatrixBlock.debug) {
            for (int count : map.getCounts()) {
                if (count == 0) {
                    throw new DMLCompressionException("Invalid counts in fact contains 0");
                }
            }
        }
    }

    /**
     * Combine this encoding with another one.
     *
     * @param other the encoding to combine with
     * @return {@code this} for empty/constant or identical sparse inputs, the other encoding's result for
     *         non-sparse inputs, otherwise a freshly combined sparse or dense encoding
     */
    @Override
    public IEncode combine(IEncode other) {
        if (other instanceof EmptyEncoding || other instanceof ConstEncoding) {
            return this;
        }
        if (!(other instanceof SparseEncoding)) {
            // Let the other encoding drive the combination.
            return other.combine(this);
        }
        final SparseEncoding o = (SparseEncoding) other;
        if (o.off == this.off && o.map == this.map) {
            return this; // same offsets and mapping instances: nothing new to distinguish
        }
        return combineSparse(o);
    }

    /**
     * Combine with another sparse encoding into a dense encoding and also return the value remapping used.
     *
     * @param other the encoding to combine with
     * @return the combined encoding paired with a map from raw combined value to new id; the map is {@code null}
     *         when {@code this} is returned unchanged
     * @throws DMLCompressionException if {@code other} is neither empty, constant nor sparse
     */
    @Override
    public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode other) {
        if (other instanceof EmptyEncoding || other instanceof ConstEncoding) {
            return new ImmutablePair<>(this, null);
        }
        if (!(other instanceof SparseEncoding)) {
            throw new DMLCompressionException("Not allowed other to be dense");
        }
        final SparseEncoding o = (SparseEncoding) other;
        if (o.off == this.off && o.map == this.map) {
            return new ImmutablePair<>(this, null);
        }
        return combineSparseNoResizeDense(o);
    }

    /**
     * Combine two sparse encodings.
     *
     * <p>If the two sides together list more than half of the rows, the result is built directly as a dense
     * encoding. Otherwise the offsets are merged; the result stays sparse if it lists fewer than a quarter of the
     * rows and is expanded to a dense encoding (id 0 = default on both sides) otherwise.
     *
     * @param other the other sparse encoding
     * @return the combined encoding
     */
    protected IEncode combineSparse(SparseEncoding other) {
        final int maxUnique = other.getUnique() * getUnique();
        final int fl = off.getOffsetToLast();
        final int fr = other.off.getOffsetToLast();
        final AIterator itl = off.getIterator();
        final AIterator itr = other.off.getIterator();
        final int nVl = getUnique();
        final int nVr = other.getUnique();
        final int sizeL = map.size();
        final int sizeR = other.map.size();

        if (sizeL + sizeR > nRows / 2) {
            return combineSparseToDense(map, other.map, itl, itr, fl, fr, nVl, nVr, nRows, maxUnique);
        }

        // Lookup from raw combined value to (new id + 1), 0 meaning "not seen yet". The all-default combination
        // (maxUnique - 1) is never produced by a listed row, hence one slot less.
        final int[] d = new int[maxUnique - 1];
        final IntArrayList retOff = new IntArrayList(Math.max(sizeR, sizeL));
        final IntArrayList tmpVals = new IntArrayList(Math.max(sizeR, sizeL));
        // Returned value is the number of distinct combined ids plus one.
        final int newUID = combineSparse(map, other.map, itl, itr, retOff, tmpVals, fl, fr, nVl, nVr, d);

        if (retOff.size() < nRows / 4) {
            final AMapToData retMap = MapToFactory.create(tmpVals.size(), tmpVals.extractValues(), newUID - 1);
            return new SparseEncoding(retMap, OffsetFactory.createOffset(retOff), nRows);
        }

        // Dense result: shift ids by one so id 0 is free for rows that are default on both sides.
        final AMapToData ret = MapToFactory.create(nRows, newUID);
        for (int k = 0; k < retOff.size(); k++) {
            ret.set(retOff.get(k), tmpVals.get(k) + 1);
        }
        return new DenseEncoding(ret);
    }

    /**
     * Combine two sparse encodings into a dense encoding, compacting ids in order of first row appearance.
     *
     * @param other the other sparse encoding
     * @return the dense encoding paired with the map from raw combined value to compacted id
     */
    private Pair<IEncode, Map<Integer, Integer>> combineSparseNoResizeDense(SparseEncoding other) {
        final AMapToData ret = combineToRawDense(map, other.map, off.getIterator(), other.off.getIterator(),
            off.getOffsetToLast(), other.off.getOffsetToLast(), getUnique(), other.getUnique(), nRows);

        final Map<Integer, Integer> mapping = new HashMap<>();
        for (int row = 0; row < ret.size(); row++) {
            addValHashMap(ret.getIndex(row), row, mapping, ret);
        }
        return new ImmutablePair<>(new DenseEncoding(ret.resize(mapping.size())), mapping);
    }

    /**
     * Replace a raw value at {@code row} with its compacted id, assigning the next id ({@code lookup.size()}) on
     * first sight.
     *
     * @param nv     raw value read from {@code ret} at {@code row}
     * @param row    row to overwrite
     * @param lookup raw value to compacted id, extended in place
     * @param ret    mapping that receives the compacted id
     */
    protected static void addValHashMap(int nv, int row, Map<Integer, Integer> lookup, AMapToData ret) {
        final int newId = lookup.size();
        final Integer prev = lookup.putIfAbsent(nv, newId);
        if (prev == null) {
            ret.set(row, newId);
        } else {
            ret.set(row, prev);
        }
    }

    /**
     * Merge the offsets of two sparse sides in row order, recording one combined id per listed row.
     *
     * <p>Raw combined values are {@code lId + rId * nVl}; a side that does not list the row contributes its
     * default, i.e. {@code defR = (nVr - 1) * nVl} for a missing right side and {@code defL = nVl - 1} for a missing
     * left side.
     *
     * @return the number of distinct combined ids plus one
     */
    private static int combineSparse(AMapToData lMap, AMapToData rMap, AIterator itl, AIterator itr,
        IntArrayList retOff, IntArrayList tmpVals, int fl, int fr, int nVl, int nVr, int[] d) {
        final int defR = (nVr - 1) * nVl;
        final int defL = nVl - 1;
        int newUID = 1;
        int il = itl.value();
        int ir = itr.value();
        while (il < fl && ir < fr) {
            if (il == ir) {
                final int nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
                newUID = addVal(nv, il, d, newUID, tmpVals, retOff);
                il = itl.next();
                ir = itr.next();
            } else if (il < ir) {
                newUID = addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
                il = itl.next();
            } else {
                newUID = addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, ir, d, newUID, tmpVals, retOff);
                ir = itr.next();
            }
        }
        return combineSparseTail(lMap, rMap, itl, itr, retOff, tmpVals, fl, fr, nVl, nVr, d, newUID);
    }

    /**
     * Finish the merge once at least one side has reached its last offset ({@code fl} / {@code fr}), emitting the
     * remaining rows in row order.
     *
     * @return the number of distinct combined ids plus one
     */
    private static int combineSparseTail(AMapToData lMap, AMapToData rMap, AIterator itl, AIterator itr,
        IntArrayList retOff, IntArrayList tmpVals, int fl, int fr, int nVl, int nVr, int[] d, int newUID) {
        final int defR = (nVr - 1) * nVl;
        final int defL = nVl - 1;
        int il = itl.value();
        int ir = itr.value();

        if (il == fl && ir == fr) {
            // Both sides sit on their last offset.
            if (fl == fr) {
                final int nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
                return addVal(nv, il, d, newUID, tmpVals, retOff);
            }
            if (fl < fr) {
                newUID = addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
                return addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, fr, d, newUID, tmpVals, retOff);
            }
            newUID = addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, fr, d, newUID, tmpVals, retOff);
            return addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
        }

        if (il < fl) {
            // The main loop stopped on the right side, which is therefore at its last offset fr.
            if (fl < fr) {
                while (il < fl) {
                    newUID = addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
                    il = itl.next();
                }
                newUID = addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
                return addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, fr, d, newUID, tmpVals, retOff);
            }
            while (il < fr) {
                newUID = addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
                il = itl.next();
            }
            if (fl == fr) {
                final int nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
                return addVal(nv, il, d, newUID, tmpVals, retOff);
            }
            if (il == fr) {
                final int nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
                newUID = addVal(nv, il, d, newUID, tmpVals, retOff);
                il = itl.next();
            } else {
                newUID = addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, fr, d, newUID, tmpVals, retOff);
            }
            while (il < fl) {
                newUID = addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
                il = itl.next();
            }
            return addVal(lMap.getIndex(itl.getDataIndex()) + defR, il, d, newUID, tmpVals, retOff);
        }

        // The left side is at its last offset fl while the right side still has entries before fr.
        if (fr < fl) {
            while (ir < fr) {
                newUID = addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, ir, d, newUID, tmpVals, retOff);
                ir = itr.next();
            }
            newUID = addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, ir, d, newUID, tmpVals, retOff);
            return addVal(lMap.getIndex(itl.getDataIndex()) + defR, fl, d, newUID, tmpVals, retOff);
        }
        while (ir < fl) {
            newUID = addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, ir, d, newUID, tmpVals, retOff);
            ir = itr.next();
        }
        if (fr == fl) {
            final int nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
            return addVal(nv, ir, d, newUID, tmpVals, retOff);
        }
        if (ir == fl) {
            final int nv = lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
            newUID = addVal(nv, ir, d, newUID, tmpVals, retOff);
            ir = itr.next();
        } else {
            newUID = addVal(lMap.getIndex(itl.getDataIndex()) + defR, fl, d, newUID, tmpVals, retOff);
        }
        while (ir < fr) {
            newUID = addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, ir, d, newUID, tmpVals, retOff);
            ir = itr.next();
        }
        return addVal(rMap.getIndex(itr.getDataIndex()) * nVl + defL, ir, d, newUID, tmpVals, retOff);
    }

    /**
     * Record one row of the merged result.
     *
     * <p>{@code d[nv]} holds the assigned id plus one, so 0 means "not seen yet". Ids are 0-based and assigned in
     * order of first appearance; this is what the callers rely on when they size the result mapping with
     * {@code newUID - 1} uniques (sparse) or {@code newUID} uniques with a {@code +1} shift (dense).
     *
     * @param nv      raw combined value
     * @param offset  row of this entry
     * @param d       raw value to (id + 1) lookup, updated in place
     * @param newUID  next id to hand out, plus one
     * @param tmpVals receives the 0-based id of this row
     * @param offsets receives the row
     * @return the updated {@code newUID}
     */
    private static int addVal(int nv, int offset, int[] d, int newUID, IntArrayList tmpVals, IntArrayList offsets) {
        final int mapV = d[nv];
        if (mapV == 0) {
            tmpVals.appendValue(newUID - 1);
            d[nv] = newUID++;
        } else {
            tmpVals.appendValue(mapV - 1);
        }
        offsets.appendValue(offset);
        return newUID;
    }

    /**
     * Write, for every row, the raw combined value {@code (lId + 1) + (rId + 1) * nVl} into a new dense mapping.
     *
     * <p>A side that does not list a row contributes 0 for its term. NOTE(review): rows listed by neither side are
     * expected to stay 0, which assumes {@link MapToFactory#create(int, int)} zero-initialises.
     */
    private static AMapToData combineToRawDense(AMapToData lMap, AMapToData rMap, AIterator itl, AIterator itr,
        int fl, int fr, int nVl, int nVr, int nRows) {
        final AMapToData ret = MapToFactory.create(nRows, (nVl + 1) * (nVr + 1));

        int il = itl.value();
        while (il < fl) {
            ret.set(il, lMap.getIndex(itl.getDataIndex()) + 1);
            il = itl.next();
        }
        ret.set(fl, lMap.getIndex(itl.getDataIndex()) + 1);

        // Add the right side on top of whatever the left side already wrote.
        int ir = itr.value();
        while (ir < fr) {
            ret.set(ir, ret.getIndex(ir) + (rMap.getIndex(itr.getDataIndex()) + 1) * nVl);
            ir = itr.next();
        }
        ret.set(fr, ret.getIndex(fr) + (rMap.getIndex(itr.getDataIndex()) + 1) * nVl);
        return ret;
    }

    /**
     * Combine two sparse sides directly into a dense encoding with ids compacted in order of first appearance.
     *
     * @param maxUnique size of the raw combined value space ({@code nVl * nVr})
     */
    private static DenseEncoding combineSparseToDense(AMapToData lMap, AMapToData rMap, AIterator itl,
        AIterator itr, int fl, int fr, int nVl, int nVr, int nRows, int maxUnique) {
        final AMapToData ret = combineToRawDense(lMap, rMap, itl, itr, fl, fr, nVl, nVr, nRows);

        // lookup holds (compacted id + 1) per raw value; 0 = not seen yet.
        final AMapToData lookup = MapToFactory.create(maxUnique, maxUnique + 1);
        int newUID = 1;
        for (int row = 0; row < ret.size(); row++) {
            final int nv = ret.getIndex(row);
            int id = lookup.getIndex(nv);
            if (id == 0) {
                id = lookup.setAndGet(nv, newUID++);
            }
            ret.set(row, id - 1);
        }
        ret.setUnique(newUID - 1);
        return new DenseEncoding(ret);
    }

    /**
     * Number of distinct values, counting the implicit default value of unlisted rows.
     *
     * @return the mapping's unique count plus one
     */
    @Override
    public int getUnique() {
        return map.getUnique() + 1;
    }

    /**
     * Build estimation factors for this encoding.
     *
     * @param numRows        number of rows; rows not covered by the mapping count as default rows
     * @param tupleSparsity  upper bound that is further capped by the fraction of listed rows (presumably the
     *                       tuple sparsity estimate — confirm against callers)
     * @param matrixSparsity forwarded unchanged to {@link EstimationFactors}
     * @param cs             compression settings; decides whether run counts are computed
     * @return the estimation factors
     */
    @Override
    public EstimationFactors extractFacts(int numRows, double tupleSparsity, double matrixSparsity,
        CompressionSettings cs) {
        final int defaultRows = numRows - map.size();
        // Floating-point division: integer division would truncate the fraction to 0.
        final double cappedTupleSparsity = Math.min((double) map.size() / numRows, tupleSparsity);
        final int[] counts = map.getCounts();
        if (cs.isRLEAllowed()) {
            return new EstimationFactors(map.getUnique(), map.size(), defaultRows, counts, 0, numRows,
                map.countRuns(off), false, true, matrixSparsity, cappedTupleSparsity);
        }
        return new EstimationFactors(map.getUnique(), map.size(), defaultRows, counts, 0, numRows, false, true,
            matrixSparsity, cappedTupleSparsity);
    }

    /** @return always {@code false}; this encoding is sparse. */
    @Override
    public boolean isDense() {
        return false;
    }

    /** @return the offsets of the non-default rows */
    public AOffset getOffsets() {
        return off;
    }

    /** @return the id mapping of the non-default rows */
    public AMapToData getMap() {
        return map;
    }

    /** @return the total number of rows */
    public int getNumRows() {
        return nRows;
    }

    /**
     * Structural equality with another encoding.
     *
     * @param other encoding to compare with
     * @return true if {@code other} is sparse with equal offsets and equal mapping
     */
    @Override
    public boolean equals(IEncode other) {
        if (!(other instanceof SparseEncoding)) {
            return false;
        }
        final SparseEncoding o = (SparseEncoding) other;
        return o.off.equals(this.off) && o.map.equals(this.map);
    }

    @Override
    public String toString() {
        return getClass().getSimpleName() + ProgramConverter.NEWLINE + "mapping: " + this.map
            + ProgramConverter.NEWLINE + "offsets: " + this.off;
    }
}
