package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Decompression of a {@link CompressedMatrixBlock} into an uncompressed {@link MatrixBlock}.
 *
 * <p>The column groups are decompressed row block by row block, either on the calling thread or in parallel via
 * {@link CommonThreadPool}. When {@code CLALibUtils.containsSDCOrConst} reports SDC/constant groups, a constant
 * row vector is filled by {@code CLALibUtils.filterGroups} and added to every output row after the remaining
 * groups have been decompressed (presumably the constant part those groups carry — see CLALibUtils).
 */
public class CLALibDecompress {
    private static final Log LOG = LogFactory.getLog(CLALibDecompress.class.getName());

    /** Target number of cells per row block; the rows per block are derived from this and the column count. */
    private static final double TARGET_CELLS_PER_BLOCK = 65535.0;

    /**
     * Decompresses the row range [rl, ru) of all column groups into a shared output block.
     *
     * <p>The task returns the number of non-zeros in its range. When the groups overlap it returns 0, and the
     * caller recomputes the count over the whole block afterwards.
     */
    public static class DecompressTask implements Callable<Long> {
        private final List<AColGroup> _colGroups;
        private final MatrixBlock _ret;
        private final double _eps;
        private final int _rl;
        private final int _ru;
        private final double[] _constV;
        private final boolean _overlapping;

        /**
         * @param colGroups   groups to decompress (shared, read-only use)
         * @param ret         output block written for rows [rl, ru)
         * @param eps         zero-snapping tolerance passed to {@link CLALibDecompress#addVector}
         * @param rl          first row (inclusive)
         * @param ru          last row (exclusive)
         * @param overlapping whether the groups overlap; if so, no per-range nnz count is produced
         * @param constV      constant row vector to add, or null
         */
        protected DecompressTask(List<AColGroup> colGroups, MatrixBlock ret, double eps, int rl, int ru,
            boolean overlapping, double[] constV) {
            _colGroups = colGroups;
            _ret = ret;
            _eps = eps;
            _rl = rl;
            _ru = ru;
            _overlapping = overlapping;
            _constV = constV;
        }

        @Override
        public Long call() {
            decompressRange(_ret, _colGroups, _rl, _ru, _constV, _eps);
            // _ru is exclusive; the range overload of recomputeNonZeros is given the last row itself (ru - 1).
            return _overlapping ? 0L : _ret.recomputeNonZeros(_rl, _ru - 1);
        }
    }

    /**
     * Decompresses the given compressed block.
     *
     * @param cmb the compressed input
     * @param k   degree of parallelism; 1 decompresses on the calling thread, any other value uses the thread pool
     * @return the uncompressed block, with its non-zero count set and sparsity examined
     */
    public static MatrixBlock decompress(CompressedMatrixBlock cmb, int k) {
        // Copy, because the helpers below remove and reorder groups.
        final List<AColGroup> groups = new ArrayList<>(cmb.getColGroups());
        final int nRows = cmb.getNumRows();
        final int nCols = cmb.getNumColumns();
        final boolean overlapping = cmb.isOverlapping();
        final long nnz = cmb.getNonZeros();

        MatrixBlock ret = getUncompressedColGroupAndRemoveFromListOfColGroups(groups, overlapping, nRows, nCols);
        if (ret != null && groups.isEmpty()) {
            // The uncompressed group was the only group, so its data already is the result.
            ret.setNonZeros(ret.recomputeNonZeros());
            return ret;
        }
        if (ret == null) {
            ret = new MatrixBlock(nRows, nCols, false, -1L);
            ret.allocateDenseBlock();
        }

        final int blklen = getBlockLength(nCols);
        final boolean containsSDCOrConst = CLALibUtils.containsSDCOrConst(groups);
        double[] constV = containsSDCOrConst ? new double[ret.getNumColumns()] : null;
        final List<AColGroup> filteredGroups = containsSDCOrConst ? CLALibUtils.filterGroups(groups, constV) : groups;
        if (LOG.isDebugEnabled()) {
            LOG.debug("Decompressing with block size: " + blklen);
        }
        sortGroups(filteredGroups, overlapping);
        // If filterGroups handed back the very same list, no constant vector is added.
        if (groups == filteredGroups) {
            constV = null;
        }
        final double eps = getEps(constV);
        if (k == 1) {
            decompressSingleThread(ret, filteredGroups, nRows, blklen, constV, eps, nnz, overlapping);
        } else {
            decompressMultiThread(ret, filteredGroups, nRows, blklen, constV, eps, overlapping, k);
        }
        if (overlapping) {
            // Overlapping groups make per-range counts meaningless; recount over the whole block.
            ret.recomputeNonZeros();
        }
        ret.examSparsity();
        return ret;
    }

    /**
     * Chooses how many rows are decompressed per block, aiming at about {@link #TARGET_CELLS_PER_BLOCK} cells per
     * block. Lengths above 1000 are rounded up to the next multiple of 1000 strictly above them; otherwise the
     * length is at least 64 rows.
     */
    private static int getBlockLength(int nCols) {
        // Guard nCols == 0: x / 0 is Infinity, the int cast saturates at Integer.MAX_VALUE and the rounding
        // below would overflow into a negative block length.
        final int rows = (int) Math.ceil(TARGET_CELLS_PER_BLOCK / Math.max(1, nCols));
        return rows > 1000 ? rows + 1000 - rows % 1000 : Math.max(64, rows);
    }

    /**
     * Looks for an uncompressed column group that covers the full nRows x nCols output. If one is found, it is
     * removed from {@code groups} and its data block is returned, so it can serve as the decompression target.
     *
     * <p>This is only attempted when the groups overlap or exactly one group is present. A sparse candidate is
     * accepted only in the single-group case.
     *
     * <p>NOTE(review): the returned block is the group's own data, not a copy. {@code decompress} then adds the
     * remaining groups into it in place. Confirm that the compressed block is not used again afterwards;
     * otherwise its uncompressed group is silently modified.
     *
     * @return the reusable data block, or null if none qualifies
     */
    private static MatrixBlock getUncompressedColGroupAndRemoveFromListOfColGroups(List<AColGroup> groups,
        boolean overlapping, int nRows, int nCols) {
        if (!overlapping && groups.size() != 1) {
            return null;
        }
        for (int i = 0; i < groups.size(); i++) {
            final AColGroup g = groups.get(i);
            if (!(g instanceof ColGroupUncompressed)) {
                continue;
            }
            final MatrixBlock data = ((ColGroupUncompressed) g).getData();
            final boolean fullSize = data.getNumColumns() == nCols && data.getNumRows() == nRows;
            if (fullSize && (!data.isInSparseFormat() || groups.size() == 1)) {
                groups.remove(i);
                LOG.debug("Using one of the uncompressed ColGroups as base for decompression");
                return data;
            }
        }
        return null;
    }

    /**
     * Decompresses all row blocks on the calling thread, then sets the non-zero count. The known count
     * {@code nnz} is reused unless it is unknown (-1) or the groups overlap.
     */
    private static void decompressSingleThread(MatrixBlock ret, List<AColGroup> groups, int nRows, int blklen,
        double[] constV, double eps, long nnz, boolean overlapping) {
        for (int rl = 0; rl < nRows; rl += blklen) {
            decompressRange(ret, groups, rl, Math.min(rl + blklen, nRows), constV, eps);
        }
        ret.setNonZeros(nnz == -1 || overlapping ? ret.recomputeNonZeros() : nnz);
    }

    /**
     * Decompresses the row blocks in parallel, one {@link DecompressTask} per block, and sets the summed
     * non-zero count. For overlapping groups the tasks report 0; {@code decompress} recounts afterwards.
     *
     * @throws DMLCompressionException if a task fails or the calling thread is interrupted
     */
    private static void decompressMultiThread(MatrixBlock ret, List<AColGroup> groups, int nRows, int blklen,
        double[] constV, double eps, boolean overlapping, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<DecompressTask> tasks = new ArrayList<>();
            for (int rl = 0; rl < nRows; rl += blklen) {
                tasks.add(new DecompressTask(groups, ret, eps, rl, Math.min(rl + blklen, nRows), overlapping, constV));
            }
            long nnz = 0;
            for (Future<Long> f : pool.invokeAll(tasks)) {
                nnz += f.get();
            }
            ret.setNonZeros(nnz);
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Parallel decompression failed", e);
        } catch (ExecutionException e) {
            throw new DMLCompressionException("Parallel decompression failed", e);
        } finally {
            // Also release the pool on failure, not only on success.
            pool.shutdown();
        }
    }

    /**
     * Decompresses rows [rl, ru) of every group into {@code ret}, then adds the constant row vector, if any.
     *
     * @throws DMLCompressionException if a constant vector must be added to a sparse target ({@link #addVector}
     *                                 only writes through the dense block)
     */
    private static void decompressRange(MatrixBlock ret, List<AColGroup> groups, int rl, int ru, double[] constV,
        double eps) {
        for (AColGroup g : groups) {
            g.decompressToBlock(ret, rl, ru);
        }
        if (constV != null) {
            if (ret.isInSparseFormat()) {
                throw new DMLCompressionException("Cannot add constant vector to a sparse decompression target");
            }
            addVector(ret, constV, eps, rl, ru);
        }
    }

    /**
     * When the groups overlap, sorts them in place by ascending {@link #effect}. Uncompressed groups come
     * first, followed by groups with the largest absolute value. The motivation for this order is not visible
     * here (presumably accumulation accuracy or uncompressed groups needing to be applied first) — TODO confirm.
     */
    private static void sortGroups(List<AColGroup> groups, boolean overlapping) {
        if (overlapping) {
            groups.sort(Comparator.comparingDouble(CLALibDecompress::effect));
        }
    }

    /** Sort key for {@link #sortGroups}: -Double.MAX_VALUE for uncompressed groups, else -max(|value|). */
    private static double effect(AColGroup g) {
        if (g instanceof ColGroupUncompressed) {
            return -Double.MAX_VALUE;
        }
        return -Math.max(g.getMax(), Math.abs(g.getMin()));
    }

    /**
     * Computes the tolerance below which values are snapped to zero after the constant vector is added.
     *
     * <p>The tolerance is 1e-10 times the value range of {@code constV}, plus 1e-4 so it stays positive when all
     * entries are equal.
     *
     * @return 0.0 when {@code constV} is null, which tells {@link #addVector} not to snap
     */
    private static double getEps(double[] constV) {
        if (constV == null) {
            return 0.0;
        }
        double max = -Double.MAX_VALUE;
        double min = Double.MAX_VALUE;
        for (double v : constV) {
            if (v > max) {
                max = v;
            }
            if (v < min) {
                min = v;
            }
        }
        return (max + 1e-4 - min) * 1e-10;
    }

    /**
     * Adds {@code rowV} to every row in [rl, ru) of the dense block of {@code ret}. When {@code eps != 0},
     * resulting cells whose absolute value is at most {@code eps} are set to exactly 0.
     *
     * @param ret  target; must have an allocated dense block
     * @param rowV values to add, one per column
     * @param eps  zero-snapping tolerance; 0 disables snapping
     * @param rl   first row (inclusive)
     * @param ru   last row (exclusive)
     */
    public static void addVector(MatrixBlock ret, double[] rowV, double eps, int rl, int ru) {
        final int nCols = ret.getNumColumns();
        final DenseBlock db = ret.getDenseBlock();
        if (eps == 0.0) {
            for (int r = rl; r < ru; r++) {
                final double[] vals = db.values(r);
                final int off = db.pos(r);
                for (int c = 0; c < nCols; c++) {
                    vals[off + c] += rowV[c];
                }
            }
            return;
        }
        for (int r = rl; r < ru; r++) {
            final double[] vals = db.values(r);
            final int off = db.pos(r);
            for (int c = 0; c < nCols; c++) {
                final int ix = off + c;
                vals[ix] += rowV[c];
                // Remove round-off residue left where the constant cancels the decompressed value.
                if (Math.abs(vals[ix]) <= eps) {
                    vals[ix] = 0.0;
                }
            }
        }
    }
}
