package org.apache.sysds.runtime.compress.lib;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/lib/CLALibStack.class */
/**
 * Static utilities to stack a map of (1-based) indexed matrix blocks into one matrix block.
 *
 * <p>The preferred path keeps the data compressed by appending the column groups of all blocks. If that is not
 * possible (e.g. the blocks have mismatching column-group layouts, a block is not compressed, or an append fails),
 * the blocks are decompressed into a single dense allocation instead.
 */
public final class CLALibStack {
    protected static final Log LOG = LogFactory.getLog(CLALibStack.class.getName());

    private CLALibStack() {
        // static utility class, no instances
    }

    /**
     * Combines the given blocks and infers the output dimensions from the map itself.
     *
     * <p>The block size is taken as the larger dimension of block (1,1). The number of rows is the sum of the row
     * counts of blocks (1,1), (2,1), ... until the first missing key. The number of columns is the sum of the column
     * counts of blocks (1,1), (1,2), ... until the first missing key.
     *
     * @param blocks       the blocks keyed by their 1-based block indexes
     * @param dictionaries optional dictionaries forwarded to {@code CLALibSeparator.combine}; may be {@code null}
     * @param k            parallelization degree; forwarded but not read by this class
     * @return the combined matrix block, compressed if possible
     * @throws DMLRuntimeException if block (1,1) is missing or the combined dimensions exceed the int range
     */
    public static MatrixBlock combine(Map<MatrixIndexes, MatrixBlock> blocks, Map<Integer, List<IDictionary>> dictionaries, int k) {
        final MatrixIndexes lookup = new MatrixIndexes(1L, 1L);
        final MatrixBlock first = blocks.get(lookup);
        if (first == null) {
            throw new DMLRuntimeException("Invalid map to combine, it does not contain block (1,1)");
        }
        final int blen = Math.max(first.getNumColumns(), first.getNumRows());
        final int rlen = toInt(findRLength(blocks), "rows");
        final int clen = toInt(findCLength(blocks), "columns");
        return combine(blocks, dictionaries, lookup, rlen, clen, blen, k);
    }

    /**
     * Combines the given blocks into a matrix of known dimensions.
     *
     * @param blocks       the blocks keyed by their 1-based block indexes
     * @param dictionaries optional dictionaries forwarded to {@code CLALibSeparator.combine}; may be {@code null}
     * @param rlen         number of rows of the combined matrix
     * @param clen         number of columns of the combined matrix
     * @param blen         block size used to place each block
     * @param k            parallelization degree; forwarded but not read by this class
     * @return the combined matrix block, compressed if possible
     */
    public static MatrixBlock combine(Map<MatrixIndexes, MatrixBlock> blocks, Map<Integer, List<IDictionary>> dictionaries, int rlen, int clen, int blen, int k) {
        return combine(blocks, dictionaries, new MatrixIndexes(), rlen, clen, blen, k);
    }

    /** Converts a summed length to int, failing loudly instead of silently wrapping on overflow. */
    private static int toInt(long value, String what) {
        if (value > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Combined number of " + what + " exceeds integer range: " + value);
        }
        return (int) value;
    }

    /** Sums the row counts of blocks (1,1), (2,1), ... until the first missing block. */
    private static long findRLength(Map<MatrixIndexes, MatrixBlock> blocks) {
        // A fresh key object is mutated for lookups; it is never stored in the map.
        final MatrixIndexes idx = new MatrixIndexes(1L, 1L);
        long rlen = 0;
        MatrixBlock b;
        while ((b = blocks.get(idx)) != null) {
            rlen += b.getNumRows();
            idx.setIndexes(idx.getRowIndex() + 1, 1L);
        }
        return rlen;
    }

    /** Sums the column counts of blocks (1,1), (1,2), ... until the first missing block. */
    private static long findCLength(Map<MatrixIndexes, MatrixBlock> blocks) {
        final MatrixIndexes idx = new MatrixIndexes(1L, 1L);
        long clen = 0;
        MatrixBlock b;
        while ((b = blocks.get(idx)) != null) {
            clen += b.getNumColumns();
            idx.setIndexes(1L, idx.getColumnIndex() + 1);
        }
        return clen;
    }

    /**
     * Tries to combine via column groups. If that path throws any exception, logs a warning and falls back to
     * decompression.
     */
    private static MatrixBlock combine(Map<MatrixIndexes, MatrixBlock> blocks, Map<Integer, List<IDictionary>> dictionaries, MatrixIndexes lookup, int rlen, int clen, int blen, int k) {
        try {
            return combineColumnGroups(blocks, dictionaries, lookup, rlen, clen, blen, k);
        } catch (Exception e) {
            LOG.warn("Failed to combine compressed blocks, fallback to decompression.", e);
            return combineViaDecompression(blocks, rlen, clen, blen, k);
        }
    }

    /**
     * Allocates a dense rlen x clen block and copies every non-null block into it at offset
     * ((row index - 1) * blen, (column index - 1) * blen). The non-zero count is then recomputed and the sparsity
     * re-examined.
     */
    private static MatrixBlock combineViaDecompression(Map<MatrixIndexes, MatrixBlock> blocks, int rlen, int clen, int blen, int k) {
        final MatrixBlock out = new MatrixBlock(rlen, clen, false);
        out.allocateDenseBlock();
        for (Map.Entry<MatrixIndexes, MatrixBlock> e : blocks.entrySet()) {
            final MatrixIndexes ix = e.getKey();
            final MatrixBlock block = e.getValue();
            if (block != null) {
                final int rowOffset = ((int) (ix.getRowIndex() - 1)) * blen;
                final int colOffset = ((int) (ix.getColumnIndex() - 1)) * blen;
                block.putInto(out, rowOffset, colOffset, false);
            }
        }
        out.setNonZeros(-1L);
        out.examSparsity(true);
        return out;
    }

    /**
     * Combines the blocks by appending matching column groups across row blocks.
     *
     * <p>Column groups of column block {@code bc > 0} are shifted by {@code bc * blen} columns. Every row of blocks
     * must yield the same number of column groups as the first row of blocks. Otherwise the method falls back to
     * decompression. It also falls back if any append returns {@code null}.
     *
     * @throws DMLRuntimeException if a required block is missing or the parallel append fails
     */
    private static MatrixBlock combineColumnGroups(Map<MatrixIndexes, MatrixBlock> blocks, Map<Integer, List<IDictionary>> dictionaries, MatrixIndexes lookup, int rlen, int clen, int blen, int k) {
        // Total number of column groups in the first row of blocks; all other rows of blocks must match it.
        int nGroups = 0;
        for (int bc = 0; bc * blen < clen; bc++) {
            lookup.setIndexes(1L, bc + 1);
            nGroups += ((CompressedMatrixBlock) blocks.get(lookup)).getColGroups().size();
        }

        final int nRowBlocks = (rlen / blen) + (rlen % blen > 0 ? 1 : 0);
        // groups[g][br]: the g-th column group of row block br, already shifted to its global column offset.
        final AColGroup[][] groups = new AColGroup[nGroups][nRowBlocks];
        for (int br = 0; br * blen < rlen; br++) {
            int g = 0;
            for (int bc = 0; bc * blen < clen; bc++) {
                lookup.setIndexes(br + 1, bc + 1);
                final CompressedMatrixBlock cmb = (CompressedMatrixBlock) blocks.get(lookup);
                if (cmb == null) {
                    throw new DMLRuntimeException("Invalid empty read: " + lookup + "  " + rlen + " " + clen + " " + blen);
                }
                final List<AColGroup> colGroups = cmb.getColGroups();
                if (g + colGroups.size() > nGroups) {
                    LOG.warn("Combining via decompression. The number of columngroups in each block is not identical");
                    return combineViaDecompression(blocks, rlen, clen, blen, k);
                }
                for (AColGroup cg : colGroups) {
                    groups[g++][br] = bc > 0 ? cg.shiftColIndices(bc * blen) : cg;
                }
            }
            if (g != groups.length) {
                LOG.warn("Combining via decompression. The number of columngroups in each block is not identical");
                return combineViaDecompression(blocks, rlen, clen, blen, k);
            }
        }

        final ExecutorService pool = CommonThreadPool.get();
        try {
            List<AColGroup> combined = pool.submit(() -> Arrays.stream(groups).parallel()
                .map(row -> AColGroup.appendN(row, blen, rlen))
                .collect(Collectors.toList())).get();
            if (combined.contains(null)) {
                LOG.warn("Combining via decompression. There was a column group that did not append ");
                return combineViaDecompression(blocks, rlen, clen, blen, k);
            }
            if (dictionaries != null) {
                combined = CLALibSeparator.combine(combined, dictionaries, blen);
            }
            return new CompressedMatrixBlock(rlen, clen, -1L, false, combined);
        } catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Failed to combine column groups", e);
        } catch (ExecutionException e) {
            throw new DMLRuntimeException("Failed to combine column groups", e);
        } finally {
            pool.shutdown();
        }
    }
}
