package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.utils.DMLCompressionStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibDecompress.class */
public final class CLALibDecompress {
    private static final Log LOG = LogFactory.getLog(CLALibDecompress.class.getName());

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibDecompress$DecompressDenseSingleColTask.class */
    /**
     * Task that decompresses one column group into the rows {@code [_rl, _ru)} of a dense output block,
     * working in steps of {@code _blklen} rows and optionally adding a constant vector afterwards.
     *
     * <p>The task always reports {@code 0}; it does not count non-zeros.
     */
    public static class DecompressDenseSingleColTask implements Callable<Long> {
        private final AColGroup _grp;
        private final MatrixBlock _ret;
        private final double _eps;
        private final int _rl;
        private final int _ru;
        private final int _blklen;
        private final double[] _constV;

        /**
         * @param aColGroup   column group to decompress
         * @param matrixBlock output block; its dense block is written to
         * @param d           tolerance forwarded to {@code addVector} when a constant vector is given
         * @param i           first row to process (inclusive)
         * @param i2          last row to process (exclusive)
         * @param dArr        constant vector added to every processed row, or {@code null} for none
         */
        protected DecompressDenseSingleColTask(AColGroup aColGroup, MatrixBlock matrixBlock, double d, int i, int i2, double[] dArr) {
            this._grp = aColGroup;
            this._ret = matrixBlock;
            this._eps = d;
            this._rl = i;
            this._ru = i2;
            // Clamp to at least one row per step: for outputs wider than 32768 columns the integer
            // division yields 0, and a zero step would make call() spin forever on the same row.
            this._blklen = Math.max(32768 / matrixBlock.getNumColumns(), 1);
            this._constV = dArr;
        }

        /**
         * Decompresses the group block by block.
         *
         * @return always {@code 0}
         * @throws DMLCompressionException wrapping any exception raised while decompressing
         */
        /* JADX WARN: Can't rename method to resolve collision */
        @Override // java.util.concurrent.Callable
        public Long call() {
            try {
                for (int start = this._rl; start < this._ru; start += this._blklen) {
                    final int end = Math.min(start + this._blklen, this._ru);
                    this._grp.decompressToDenseBlock(this._ret.getDenseBlock(), start, end);
                    if (this._constV != null) {
                        CLALibDecompress.addVector(this._ret, this._constV, this._eps, start, end);
                    }
                }
                return 0L;
            } catch (Exception e) {
                // The cause travels with the rethrown exception, so no separate stack-trace print is needed.
                throw new DMLCompressionException("Failed dense decompression", e);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibDecompress$DecompressDenseTask.class */
    /**
     * Task that decompresses a list of column groups into the rows {@code [_rl, _ru)} of a dense output
     * block, working in steps of {@code _blklen} rows. After each step it optionally adds a constant
     * vector and accumulates the value returned by {@code MatrixBlock.recomputeNonZeros} for that step.
     */
    public static class DecompressDenseTask implements Callable<Long> {
        private final List<AColGroup> _colGroups;
        private final MatrixBlock _ret;
        private final double _eps;
        private final int _rl;
        private final int _ru;
        private final int _blklen;
        private final double[] _constV;

        /**
         * @param list        column groups to decompress
         * @param matrixBlock output block; its dense block is written to
         * @param d           tolerance forwarded to {@code addVector} when a constant vector is given
         * @param i           first row to process (inclusive)
         * @param i2          last row to process (exclusive)
         * @param dArr        constant vector added to every processed row, or {@code null} for none
         */
        protected DecompressDenseTask(List<AColGroup> list, MatrixBlock matrixBlock, double d, int i, int i2, double[] dArr) {
            this._colGroups = list;
            this._ret = matrixBlock;
            this._eps = d;
            this._rl = i;
            this._ru = i2;
            // Clamp to at least one row per step: for outputs wider than 32768 columns the integer
            // division yields 0, and a zero step would make call() spin forever on the same row.
            this._blklen = Math.max(32768 / matrixBlock.getNumColumns(), 1);
            this._constV = dArr;
        }

        /**
         * Decompresses all groups block by block.
         *
         * @return the sum of the per-step {@code recomputeNonZeros} results
         * @throws DMLCompressionException wrapping any exception raised while decompressing
         */
        /* JADX WARN: Can't rename method to resolve collision */
        @Override // java.util.concurrent.Callable
        public Long call() {
            try {
                long nnz = 0;
                for (int start = this._rl; start < this._ru; start += this._blklen) {
                    final int end = Math.min(start + this._blklen, this._ru);
                    for (AColGroup grp : this._colGroups) {
                        grp.decompressToDenseBlock(this._ret.getDenseBlock(), start, end);
                    }
                    if (this._constV != null) {
                        CLALibDecompress.addVector(this._ret, this._constV, this._eps, start, end);
                    }
                    // Upper bound is passed as end - 1; presumably an inclusive row index — confirm.
                    nnz += this._ret.recomputeNonZeros(start, end - 1);
                }
                return nnz;
            } catch (Exception e) {
                // The cause travels with the rethrown exception, so no separate stack-trace print is needed.
                throw new DMLCompressionException("Failed dense decompression", e);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibDecompress$DecompressSparseTask.class */
    /**
     * Task that decompresses a list of column groups into the rows {@code [_rl, _ru)} of a sparse
     * output block and then sorts every non-empty row in that range.
     */
    public static class DecompressSparseTask implements Callable<Object> {
        private final List<AColGroup> _colGroups;
        private final MatrixBlock _ret;
        private final int _rl;
        private final int _ru;

        /**
         * @param list        column groups to decompress
         * @param matrixBlock output block; its sparse block is written to
         * @param i           first row to process (inclusive)
         * @param i2          last row to process (exclusive)
         */
        protected DecompressSparseTask(List<AColGroup> list, MatrixBlock matrixBlock, int i, int i2) {
            this._colGroups = list;
            this._ret = matrixBlock;
            this._rl = i;
            this._ru = i2;
        }

        /**
         * Runs the decompression followed by the per-row sort.
         *
         * @return always {@code null}
         */
        @Override // java.util.concurrent.Callable
        public Object call() {
            final SparseBlock rows = this._ret.getSparseBlock();
            for (AColGroup grp : this._colGroups) {
                grp.decompressToSparseBlock(this._ret.getSparseBlock(), this._rl, this._ru);
            }
            for (int r = this._rl; r < this._ru; r++) {
                if (rows.isEmpty(r)) {
                    continue;
                }
                rows.sort(r);
            }
            return null;
        }
    }

    /** Private constructor: the class only exposes static members and is not meant to be instantiated. */
    private CLALibDecompress() {
    }

    /**
     * Decompresses a compressed matrix block and, when {@code DMLScript.STATISTICS} is set, records
     * the elapsed time in the compression statistics.
     *
     * @param compressedMatrixBlock block to decompress
     * @param i                     value forwarded to {@code decompressExecute} and reported as {@code k}
     *                              in the statistics and trace log
     * @return the block produced by {@code decompressExecute}
     */
    public static MatrixBlock decompress(CompressedMatrixBlock compressedMatrixBlock, int i) {
        final Timing timer = new Timing(true);
        final MatrixBlock result = decompressExecute(compressedMatrixBlock, i);
        if (!DMLScript.STATISTICS) {
            return result;
        }
        final double elapsed = timer.stop();
        DMLCompressionStatistics.addDecompressTime(elapsed, i);
        if (LOG.isTraceEnabled()) {
            LOG.trace("decompressed block w/ k=" + i + " in " + elapsed + "ms.");
        }
        return result;
    }

    /* JADX WARN: Type inference failed for: r0v6, types: [org.apache.sysds.runtime.matrix.data.MatrixBlock] */
    /**
     * Decompresses a compressed block into an existing target block at a row/column offset.
     *
     * <p>If the compressed block placed at the offset would extend past the target's dimensions, the
     * compressed block is sliced and the slice is copied into the target (slow path). Otherwise the
     * groups are decompressed directly into the target's sparse or dense block.
     *
     * @param compressedMatrixBlock source block
     * @param matrixBlock           target block that is written to
     * @param i                     row offset in the target
     * @param i2                    column offset in the target
     * @param i3                    value reported as {@code k} in the statistics and trace log
     * @param z                     when {@code true}, the target's non-zero count is recomputed before returning
     * @throws DMLCompressionException if the target is sparse and the source is overlapping
     */
    public static void decompressTo(CompressedMatrixBlock compressedMatrixBlock, MatrixBlock matrixBlock, int i, int i2, int i3, boolean z) {
        Timing timing = new Timing(true);
        if (compressedMatrixBlock.getNumColumns() + i2 > matrixBlock.getNumColumns() || compressedMatrixBlock.getNumRows() + i > matrixBlock.getNumRows()) {
            LOG.warn("Slow slicing off excess parts for decompressTo because decompression into is implemented for fitting blocks");
            // NOTE(review): Math.min(Math.abs(x), 0) is 0 for every offset except Integer.MIN_VALUE;
            // confirm whether a different lower bound was intended. Left unchanged here.
            compressedMatrixBlock.slice2(Math.min(Math.abs(i), 0), Math.min(compressedMatrixBlock.getNumRows(), matrixBlock.getNumRows() - i) - 1, Math.min(Math.abs(i2), 0), Math.min(compressedMatrixBlock.getNumColumns(), matrixBlock.getNumColumns() - i2) - 1).putInto(matrixBlock, i, i2, false);
            // Honour the caller's request on this path too; previously the early return skipped it.
            if (z) {
                matrixBlock.recomputeNonZeros();
            }
            return;
        }
        boolean isInSparseFormat = matrixBlock.isInSparseFormat();
        if (!compressedMatrixBlock.isEmpty()) {
            if (isInSparseFormat && compressedMatrixBlock.isOverlapping()) {
                throw new DMLCompressionException("Not supported decompression into sparse block from overlapping state");
            }
            if (isInSparseFormat) {
                decompressToSparseBlock(compressedMatrixBlock, matrixBlock, i, i2);
            } else {
                decompressToDenseBlock(compressedMatrixBlock, matrixBlock.getDenseBlock(), i, i2);
            }
        }
        if (DMLScript.STATISTICS) {
            double stop = timing.stop();
            DMLCompressionStatistics.addDecompressToBlockTime(stop, i3);
            if (LOG.isTraceEnabled()) {
                LOG.trace("decompressed block w/ k=" + i3 + " in " + stop + "ms.");
            }
        }
        if (z) {
            matrixBlock.recomputeNonZeros();
        }
    }

    /**
     * Writes the compressed block into the target's sparse block at the given offsets, then sorts the
     * sparse block and calls {@code checkSparseRows} on the target.
     *
     * <p>When {@code CLALibUtils.shouldPreFilter} reports {@code true} for the column groups, the
     * source is first fetched uncompressed and copied into the target; otherwise every group
     * decompresses itself directly.
     *
     * @param compressedMatrixBlock source block
     * @param matrixBlock           target block in sparse format
     * @param i                     row offset in the target
     * @param i2                    column offset in the target
     */
    private static void decompressToSparseBlock(CompressedMatrixBlock compressedMatrixBlock, MatrixBlock matrixBlock, int i, int i2) {
        final SparseBlock target = matrixBlock.getSparseBlock();
        final List<AColGroup> groups = compressedMatrixBlock.getColGroups();
        final int rows = compressedMatrixBlock.getNumRows();
        if (!CLALibUtils.shouldPreFilter(groups)) {
            for (AColGroup grp : groups) {
                grp.decompressToSparseBlock(target, 0, rows, i, i2);
            }
        } else {
            final MatrixBlock plain = compressedMatrixBlock.getUncompressed("Decompression to put into Sparse Block");
            plain.putInto(matrixBlock, i, i2, false);
        }
        target.sort();
        matrixBlock.checkSparseRows();
    }

    /**
     * Writes the compressed block into a dense block at the given offsets.
     *
     * <p>When {@code CLALibUtils.shouldPreFilter} reports {@code true}, the groups are first passed
     * through {@code CLALibUtils.filterGroups} together with a fresh per-column array; the filtered
     * groups are decompressed, followed by a constant group created from that array.
     *
     * @param compressedMatrixBlock source block
     * @param denseBlock            dense target
     * @param i                     row offset in the target
     * @param i2                    column offset in the target
     */
    private static void decompressToDenseBlock(CompressedMatrixBlock compressedMatrixBlock, DenseBlock denseBlock, int i, int i2) {
        final List<AColGroup> groups = compressedMatrixBlock.getColGroups();
        final int rows = compressedMatrixBlock.getNumRows();
        if (CLALibUtils.shouldPreFilter(groups)) {
            final double[] constV = new double[compressedMatrixBlock.getNumColumns()];
            for (AColGroup grp : CLALibUtils.filterGroups(groups, constV)) {
                grp.decompressToDenseBlock(denseBlock, 0, rows, i, i2);
            }
            final AColGroup constGroup = ColGroupConst.create(constV);
            constGroup.decompressToDenseBlock(denseBlock, 0, rows, i, i2);
            return;
        }
        for (AColGroup grp : groups) {
            grp.decompressToDenseBlock(denseBlock, 0, rows, i, i2);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Core decompression routine.
     *
     * <p>Steps, in order: return an empty block for empty input; try to reuse an uncompressed column
     * group as the output; optionally pre-filter the groups into a constant vector; allocate a sparse
     * or dense output if none was reused; dispatch to the single- or multi-threaded, sparse or dense
     * worker; finally recompute non-zeros and let the block re-examine its sparsity.
     *
     * @param compressedMatrixBlock block to decompress
     * @param i                     parallelism; {@code 1} selects the single-threaded workers
     * @return the decompressed block
     */
    private static MatrixBlock decompressExecute(CompressedMatrixBlock compressedMatrixBlock, int i) {
        if (compressedMatrixBlock.isEmpty()) {
            return new MatrixBlock(compressedMatrixBlock.getNumRows(), compressedMatrixBlock.getNumColumns(), true);
        }

        // Work on a copy of the group list because the helper below may remove an element.
        final List<AColGroup> groups = new ArrayList<>(compressedMatrixBlock.getColGroups());
        final int nRows = compressedMatrixBlock.getNumRows();
        final int nCols = compressedMatrixBlock.getNumColumns();
        final boolean overlapping = compressedMatrixBlock.isOverlapping();
        final long nnz = compressedMatrixBlock.getNonZeros();

        MatrixBlock ret = getUncompressedColGroupAndRemoveFromListOfColGroups(groups, overlapping, nRows, nCols);
        if (ret != null && groups.isEmpty()) {
            final long counted = ret.recomputeNonZeros();
            ret.setNonZeros(counted);
            return ret;
        }

        final boolean preFilter = CLALibUtils.shouldPreFilterMorphOrRef(groups);
        double[] constV = null;
        List<AColGroup> filtered = groups;
        if (preFilter) {
            constV = new double[nCols];
            filtered = CLALibUtils.filterGroups(groups, constV);
        }

        if (ret == null) {
            final boolean sparse = !preFilter && !overlapping && MatrixBlock.evalSparseFormatInMemory((long) nRows, (long) nCols, nnz);
            ret = new MatrixBlock(nRows, nCols, sparse);
            if (sparse) {
                ret.allocateSparseRowsBlock();
            } else {
                ret.allocateDenseBlock();
            }
            if (MatrixBlock.evalSparseFormatInMemory(nRows, nCols, nnz) && !sparse) {
                LOG.warn("Decompressing into dense but reallocating after to sparse: overlapping - " + overlapping + ", filter - " + preFilter);
            }
        }

        final int blockLen = Math.max(nRows / i, 512);
        // If filtering handed back the very same list, no constant vector is applied.
        if (groups == filtered) {
            constV = null;
        }
        final double eps = getEps(constV);

        final boolean sparseOut = ret.isInSparseFormat();
        if (i == 1) {
            if (sparseOut) {
                decompressSparseSingleThread(ret, filtered, nRows, blockLen);
            } else {
                decompressDenseSingleThread(ret, filtered, nRows, blockLen, constV, eps, nnz, overlapping);
            }
        } else {
            if (sparseOut) {
                decompressSparseMultiThread(ret, filtered, nRows, blockLen, i);
            } else {
                decompressDenseMultiThread(ret, filtered, nRows, blockLen, constV, eps, i, overlapping);
            }
        }

        ret.recomputeNonZeros();
        ret.examSparsity();
        return ret;
    }

    /**
     * Looks for an uncompressed column group whose data block has exactly the requested dimensions
     * and, if found, removes that group from the list and returns its data block.
     *
     * <p>The search only runs when {@code z} is set or the list holds a single group. A sparse data
     * block is accepted only when it belongs to the sole group in the list.
     *
     * @param list column groups; the matching group is removed from this list
     * @param z    overlapping flag of the compressed block
     * @param i    required number of rows
     * @param i2   required number of columns
     * @return the matching group's data block, or {@code null} when none qualifies
     */
    private static MatrixBlock getUncompressedColGroupAndRemoveFromListOfColGroups(List<AColGroup> list, boolean z, int i, int i2) {
        if (!z && list.size() != 1) {
            return null;
        }
        for (int idx = 0; idx < list.size(); idx++) {
            final AColGroup candidate = list.get(idx);
            if (!(candidate instanceof ColGroupUncompressed)) {
                continue;
            }
            final MatrixBlock data = ((ColGroupUncompressed) candidate).getData();
            final boolean sameShape = data.getNumColumns() == i2 && data.getNumRows() == i;
            if (sameShape && (!data.isInSparseFormat() || list.size() == 1)) {
                list.remove(idx);
                LOG.debug("Using one of the uncompressed ColGroups as base for decompression");
                return data;
            }
        }
        return null;
    }

    /**
     * Decompresses all groups into a sparse output on the calling thread, {@code i2} rows at a time,
     * sorting each non-empty row of a step once all groups have written to it.
     *
     * @param matrixBlock output block in sparse format
     * @param list        column groups to decompress
     * @param i           total number of rows
     * @param i2          number of rows handled per step
     */
    private static void decompressSparseSingleThread(MatrixBlock matrixBlock, List<AColGroup> list, int i, int i2) {
        final SparseBlock rows = matrixBlock.getSparseBlock();
        for (int start = 0; start < i; start += i2) {
            final int end = Math.min(start + i2, i);
            for (AColGroup grp : list) {
                grp.decompressToSparseBlock(matrixBlock.getSparseBlock(), start, end);
            }
            for (int r = start; r < end; r++) {
                if (rows.isEmpty(r)) {
                    continue;
                }
                rows.sort(r);
            }
        }
    }

    /**
     * Decompresses all groups into a dense output on the calling thread, {@code i2} rows at a time,
     * adding the constant vector after each step when one is supplied and the output is not sparse.
     *
     * @param matrixBlock output block
     * @param list        column groups to decompress
     * @param i           total number of rows
     * @param i2          number of rows handled per step
     * @param dArr        constant vector to add, or {@code null}
     * @param d           tolerance forwarded to {@code addVector}
     * @param j           not used by this method
     * @param z           not used by this method
     */
    private static void decompressDenseSingleThread(MatrixBlock matrixBlock, List<AColGroup> list, int i, int i2, double[] dArr, double d, long j, boolean z) {
        for (int start = 0; start < i; start += i2) {
            final int end = Math.min(start + i2, i);
            for (AColGroup grp : list) {
                grp.decompressToDenseBlock(matrixBlock.getDenseBlock(), start, end);
            }
            final boolean applyConst = dArr != null && !matrixBlock.isInSparseFormat();
            if (applyConst) {
                addVector(matrixBlock, dArr, d, start, end);
            }
        }
    }

    /**
     * Convenience overload: derives the row count from the output block, uses a step of
     * {@code max(rows / i, 512)} rows, and computes the tolerance from the constant vector via
     * {@code getEps}.
     *
     * @param matrixBlock output block
     * @param list        column groups to decompress
     * @param dArr        constant vector to add, or {@code null}
     * @param i           parallelism passed to the thread pool
     * @param z           overlapping flag
     */
    protected static void decompressDenseMultiThread(MatrixBlock matrixBlock, List<AColGroup> list, double[] dArr, int i, boolean z) {
        final int rows = matrixBlock.getNumRows();
        final int blockLen = Math.max(rows / i, 512);
        final double eps = getEps(dArr);
        decompressDenseMultiThread(matrixBlock, list, rows, blockLen, dArr, eps, i, z);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Convenience overload with an explicit tolerance: derives the row count from the output block
     * and uses a step of {@code max(rows / i, 512)} rows.
     *
     * @param matrixBlock output block
     * @param list        column groups to decompress
     * @param dArr        constant vector to add, or {@code null}
     * @param d           tolerance forwarded to the worker tasks
     * @param i           parallelism passed to the thread pool
     * @param z           overlapping flag
     */
    public static void decompressDenseMultiThread(MatrixBlock matrixBlock, List<AColGroup> list, double[] dArr, double d, int i, boolean z) {
        final int rows = matrixBlock.getNumRows();
        final int blockLen = Math.max(rows / i, 512);
        decompressDenseMultiThread(matrixBlock, list, rows, blockLen, dArr, d, i, z);
    }

    /**
     * Decompresses all groups into a dense output using a thread pool.
     *
     * <p>When {@code z} is set or a constant vector is given, one {@link DecompressDenseTask} is created
     * per row range and handles every group. Otherwise one {@link DecompressDenseSingleColTask} is
     * created per (row range, group) pair. The sum of the task results is stored as the output's
     * non-zero count; note that {@code DecompressDenseSingleColTask} always returns 0, so in that
     * branch the stored count is 0 and must be recomputed by the caller if needed.
     *
     * @param matrixBlock output block
     * @param list        column groups to decompress
     * @param i           total number of rows
     * @param i2          number of rows per task
     * @param dArr        constant vector to add, or {@code null}
     * @param d           tolerance forwarded to the tasks
     * @param i3          parallelism passed to {@code CommonThreadPool.get}
     * @param z           overlapping flag
     * @throws DMLCompressionException if a task fails or the wait is interrupted
     */
    private static void decompressDenseMultiThread(MatrixBlock matrixBlock, List<AColGroup> list, int i, int i2, double[] dArr, double d, int i3, boolean z) {
        final ExecutorService pool = CommonThreadPool.get(i3);
        try {
            final List<Callable<Long>> tasks = new ArrayList<>();
            if (z || dArr != null) {
                for (int start = 0; start < i; start += i2) {
                    tasks.add(new DecompressDenseTask(list, matrixBlock, d, start, Math.min(start + i2, i), dArr));
                }
            } else {
                for (int start = 0; start < i; start += i2) {
                    final int end = Math.min(start + i2, i);
                    for (AColGroup grp : list) {
                        tasks.add(new DecompressDenseSingleColTask(grp, matrixBlock, d, start, end, null));
                    }
                }
            }
            long nnz = 0;
            for (Future<Long> result : pool.invokeAll(tasks)) {
                nnz += result.get();
            }
            matrixBlock.setNonZeros(nnz);
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Parallel decompression failed", e);
        } catch (ExecutionException e) {
            throw new DMLCompressionException("Parallel decompression failed", e);
        } finally {
            // Release the pool on every exit path; previously it was left running on failure.
            pool.shutdown();
        }
    }

    /**
     * Decompresses all groups into a sparse output using a thread pool, one
     * {@link DecompressSparseTask} per row range.
     *
     * @param matrixBlock output block in sparse format
     * @param list        column groups to decompress
     * @param i           total number of rows
     * @param i2          number of rows per task
     * @param i3          parallelism passed to {@code CommonThreadPool.get}
     * @throws DMLCompressionException if a task fails or the wait is interrupted
     */
    private static void decompressSparseMultiThread(MatrixBlock matrixBlock, List<AColGroup> list, int i, int i2, int i3) {
        final ExecutorService pool = CommonThreadPool.get(i3);
        try {
            final List<Callable<Object>> tasks = new ArrayList<>();
            for (int start = 0; start < i; start += i2) {
                tasks.add(new DecompressSparseTask(list, matrixBlock, start, Math.min(start + i2, i)));
            }
            // get() is called only to surface task failures; the tasks return null.
            for (Future<Object> result : pool.invokeAll(tasks)) {
                result.get();
            }
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Parallel decompression failed", e);
        } catch (ExecutionException e) {
            throw new DMLCompressionException("Parallel decompression failed", e);
        } finally {
            // Single shutdown point that also covers unchecked exceptions.
            pool.shutdown();
        }
    }

    /**
     * Derives a tolerance from the value range of a constant vector:
     * {@code ((max + 1e-4) - min) * 1e-10}.
     *
     * @param dArr constant vector, or {@code null}
     * @return {@code DataExpression.DEFAULT_DELIM_FILL_VALUE} (presumably 0.0 — verify) when
     *         {@code dArr} is {@code null}; otherwise the range-based tolerance
     */
    private static double getEps(double[] dArr) {
        if (dArr == null) {
            return DataExpression.DEFAULT_DELIM_FILL_VALUE;
        }
        double largest = -Double.MAX_VALUE;
        double smallest = Double.MAX_VALUE;
        for (int k = 0; k < dArr.length; k++) {
            final double v = dArr[k];
            if (v > largest) {
                largest = v;
            }
            if (v < smallest) {
                smallest = v;
            }
        }
        return ((largest + 1.0E-4d) - smallest) * 1.0E-10d;
    }

    /**
     * Adds a per-column constant vector to rows {@code [i, i2)} of the block's dense storage,
     * choosing a specialised helper by column count, contiguity and whether a tolerance is in use.
     *
     * <p>A tolerance equal to {@code DataExpression.DEFAULT_DELIM_FILL_VALUE} selects the plain-add
     * helpers; any other value selects the helpers that zero out results within the tolerance.
     *
     * @param matrixBlock block whose dense storage is updated in place
     * @param dArr        constant vector, one entry per column
     * @param d           tolerance
     * @param i           first row (inclusive)
     * @param i2          last row (exclusive)
     */
    private static void addVector(MatrixBlock matrixBlock, double[] dArr, double d, int i, int i2) {
        final int nCols = matrixBlock.getNumColumns();
        final DenseBlock dense = matrixBlock.getDenseBlock();
        final boolean plainAdd = d == DataExpression.DEFAULT_DELIM_FILL_VALUE;

        if (nCols == 1) {
            // Single column: the row index doubles as the array index.
            final double[] vals = dense.values(0);
            if (plainAdd) {
                addValue(vals, dArr[0], i, i2);
            } else {
                addValueEps(vals, dArr[0], d, i, i2);
            }
        } else if (dense.isContiguous()) {
            final double[] vals = dense.values(0);
            if (plainAdd) {
                addVectorContiguousNoEps(vals, dArr, nCols, i, i2);
            } else {
                addVectorContiguousEps(vals, dArr, nCols, d, i, i2);
            }
        } else if (plainAdd) {
            addVectorNoEps(dense, dArr, nCols, i, i2);
        } else {
            addVectorEps(dense, dArr, nCols, d, i, i2);
        }
    }

    /**
     * Adds {@code d} to every element of {@code dArr} in the index range {@code [i, i2)}.
     *
     * @param dArr array updated in place
     * @param d    value to add
     * @param i    first index (inclusive)
     * @param i2   last index (exclusive)
     */
    private static void addValue(double[] dArr, double d, int i, int i2) {
        int idx = i;
        while (idx < i2) {
            dArr[idx] += d;
            idx++;
        }
    }

    /**
     * Adds {@code d} to every element of {@code dArr} in {@code [i, i2)}, storing exactly 0 whenever
     * the absolute value of the sum is at most {@code d2}.
     *
     * @param dArr array updated in place
     * @param d    value to add
     * @param d2   tolerance
     * @param i    first index (inclusive)
     * @param i2   last index (exclusive)
     */
    private static void addValueEps(double[] dArr, double d, double d2, int i, int i2) {
        for (int idx = i; idx < i2; idx++) {
            final double sum = dArr[idx] + d;
            dArr[idx] = Math.abs(sum) <= d2 ? 0.0d : sum;
        }
    }

    /**
     * Adds the vector {@code dArr2} to each row of a row-major array with {@code i} values per row,
     * for rows {@code [i2, i3)}.
     *
     * @param dArr  array updated in place
     * @param dArr2 vector to add, one entry per column
     * @param i     number of columns (row stride)
     * @param i2    first row (inclusive)
     * @param i3    last row (exclusive)
     */
    private static void addVectorContiguousNoEps(double[] dArr, double[] dArr2, int i, int i2, int i3) {
        final int stop = i3 * i;
        for (int rowStart = i2 * i; rowStart < stop; rowStart += i) {
            for (int col = 0; col < i; col++) {
                dArr[rowStart + col] += dArr2[col];
            }
        }
    }

    /**
     * Adds the vector {@code dArr2} to each row of a row-major array with {@code i} values per row,
     * for rows {@code [i2, i3)}, storing exactly 0 whenever the absolute value of a sum is at most
     * {@code d}.
     *
     * @param dArr  array updated in place
     * @param dArr2 vector to add, one entry per column
     * @param i     number of columns (row stride)
     * @param d     tolerance
     * @param i2    first row (inclusive)
     * @param i3    last row (exclusive)
     */
    private static void addVectorContiguousEps(double[] dArr, double[] dArr2, int i, double d, int i2, int i3) {
        final int stop = i3 * i;
        for (int rowStart = i2 * i; rowStart < stop; rowStart += i) {
            for (int col = 0; col < i; col++) {
                final int idx = rowStart + col;
                final double sum = dArr[idx] + dArr2[col];
                dArr[idx] = Math.abs(sum) <= d ? 0.0d : sum;
            }
        }
    }

    /**
     * Adds the vector {@code dArr} to rows {@code [i2, i3)} of a dense block, looking up each row's
     * backing array and start position through the block.
     *
     * @param denseBlock block updated in place
     * @param dArr       vector to add, one entry per column
     * @param i          number of columns
     * @param i2         first row (inclusive)
     * @param i3         last row (exclusive)
     */
    private static void addVectorNoEps(DenseBlock denseBlock, double[] dArr, int i, int i2, int i3) {
        for (int row = i2; row < i3; row++) {
            final double[] vals = denseBlock.values(row);
            final int base = denseBlock.pos(row);
            for (int col = 0; col < i; col++) {
                vals[base + col] += dArr[col];
            }
        }
    }

    /**
     * Adds the vector {@code dArr} to rows {@code [i2, i3)} of a dense block, storing exactly 0
     * whenever the absolute value of a sum is at most {@code d}. Each row's backing array and start
     * position are looked up through the block.
     *
     * @param denseBlock block updated in place
     * @param dArr       vector to add, one entry per column
     * @param i          number of columns
     * @param d          tolerance
     * @param i2         first row (inclusive)
     * @param i3         last row (exclusive)
     */
    private static void addVectorEps(DenseBlock denseBlock, double[] dArr, int i, double d, int i2, int i3) {
        for (int row = i2; row < i3; row++) {
            final double[] vals = denseBlock.values(row);
            final int base = denseBlock.pos(row);
            for (int col = 0; col < i; col++) {
                final int idx = base + col;
                final double sum = vals[idx] + dArr[col];
                vals[idx] = Math.abs(sum) <= d ? 0.0d : sum;
            }
        }
    }
}
