package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple;
import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.encoding.ConstEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.EmptyEncoding;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Static helpers that merge column groups of a {@link CompressedMatrixBlock} into combined column groups.
 *
 * <p>The class is a pure utility holder and cannot be instantiated.
 */
public final class CLALibCombineGroups {
    protected static final Log LOG = LogFactory.getLog(CLALibCombineGroups.class.getName());

    private CLALibCombineGroups() {
        // utility class: no instances
    }

    /**
     * Combines the column groups of the given block, acquiring a thread pool when more than one thread is requested.
     *
     * <p>A pool is requested from {@code CommonThreadPool.get(i)} only if {@code i > 1}. The pool (if any) is shut
     * down before this method returns, including when the delegated combine call throws.
     *
     * <p>NOTE(review): this overload forwards {@code null} as the size info, and the three-argument overload
     * dereferences that argument unconditionally, so as written this call is expected to fail with a
     * {@link NullPointerException}. Confirm the intended behavior with the callers.
     *
     * @param compressedMatrixBlock the compressed block whose column groups should be combined
     * @param i                     the requested degree of parallelism; values of 1 or less use no pool
     * @return the combined column groups
     * @throws DMLCompressionException if acquiring the thread pool fails with an exception
     */
    public static List<AColGroup> combine(CompressedMatrixBlock compressedMatrixBlock, int i) {
        final ExecutorService pool;
        try {
            pool = (i > 1) ? CommonThreadPool.get(i) : null;
        } catch (Exception e) {
            throw new DMLCompressionException("Compression Failed", e);
        }
        try {
            return combine(compressedMatrixBlock, null, pool);
        } finally {
            // Always release the pool, also on failure, so worker threads are not leaked.
            if (pool != null) {
                pool.shutdown();
            }
        }
    }

    /**
     * Combines the column groups of the given block according to the column sets listed in the size info.
     *
     * <p>For every entry of {@code compressedSizeInfo.getInfo()} the groups of the block that share at least one
     * column with that entry are collected and folded into a single group via {@link #combineN(List)}. When
     * {@code CLALibUtils.shouldFilterFOR} reports true for the input groups, the groups are first passed through
     * {@code CLALibUtils.filterFOR} and the vector filled by that call is added back onto each combined group.
     *
     * @param compressedMatrixBlock the compressed block whose column groups should be combined
     * @param compressedSizeInfo    the description of the target column groupings; must not be {@code null}
     * @param executorService       a thread pool; currently not used by this method
     * @return one combined group per entry of the size info, in the same order
     */
    public static List<AColGroup> combine(CompressedMatrixBlock compressedMatrixBlock, CompressedSizeInfo compressedSizeInfo, ExecutorService executorService) {
        List<AColGroup> input = compressedMatrixBlock.getColGroups();
        final boolean filterFor = CLALibUtils.shouldFilterFOR(input);
        final double[] constV = filterFor ? new double[compressedMatrixBlock.getNumColumns()] : null;
        if (filterFor) {
            input = CLALibUtils.filterFOR(input, constV);
        }

        // Collect, per target grouping, the input groups that have to be merged.
        final List<List<AColGroup>> combinations = new ArrayList<>();
        for (CompressedSizeInfoColGroup info : compressedSizeInfo.getInfo()) {
            combinations.add(findGroupsInIndex(info.getColumns(), input));
        }

        final List<AColGroup> result = new ArrayList<>();
        for (List<AColGroup> toCombine : combinations) {
            final AColGroup combined = combineN(toCombine);
            // Re-apply the vector that was filtered out before combining.
            result.add(filterFor ? combined.addVector(constV) : combined);
        }
        return result;
    }

    /**
     * Selects the groups whose column indices overlap with the given index.
     *
     * @param iColIndex the column index to match against
     * @param list      the candidate column groups
     * @return a new list with every group for which {@code getColIndices().containsAny(iColIndex)} is true, in
     *         input order
     */
    public static List<AColGroup> findGroupsInIndex(IColIndex iColIndex, List<AColGroup> list) {
        final List<AColGroup> matches = new ArrayList<>();
        for (AColGroup group : list) {
            if (group.getColIndices().containsAny(iColIndex)) {
                matches.add(group);
            }
        }
        return matches;
    }

    /**
     * Folds a list of column groups into one group by combining them pairwise from left to right.
     *
     * @param list the groups to combine; must contain at least one element
     * @return the single remaining group; the first element itself if the list has exactly one entry
     * @throws IndexOutOfBoundsException if the list is empty
     */
    public static AColGroup combineN(List<AColGroup> list) {
        AColGroup combined = list.get(0);
        for (int idx = 1; idx < list.size(); idx++) {
            combined = combine(combined, list.get(idx));
        }
        return combined;
    }

    /**
     * Combines two column groups into a single group over the union of their columns.
     *
     * <p>Uncompressed inputs are first passed through {@code recompress()}. If both groups are compressed
     * afterwards they are combined on their encodings; if at least one is still uncompressed, both are
     * decompressed into one uncompressed group.
     *
     * @param aColGroup  the first group
     * @param aColGroup2 the second group
     * @return the combined group
     * @throws DMLCompressionException if either input is a frame of reference group, the combination is not
     *                                 implemented, or any other exception occurs while combining (the cause is
     *                                 attached)
     */
    public static AColGroup combine(AColGroup aColGroup, AColGroup aColGroup2) {
        AColGroup left = aColGroup;
        AColGroup right = aColGroup2;
        try {
            if (left instanceof IFrameOfReferenceGroup || right instanceof IFrameOfReferenceGroup) {
                throw new DMLCompressionException("Invalid call with frame of reference group to combine");
            }
            final IColIndex combinedColumns = ColIndexFactory.combine(left, right);

            // Try to get a compressed representation of uncompressed inputs first.
            if (left instanceof ColGroupUncompressed) {
                left = left.recompress();
            }
            if (right instanceof ColGroupUncompressed) {
                right = right.recompress();
            }

            if (left instanceof AColGroupCompressed && right instanceof AColGroupCompressed) {
                return combineCompressed(combinedColumns, (AColGroupCompressed) left, (AColGroupCompressed) right);
            }
            if (left instanceof ColGroupUncompressed || right instanceof ColGroupUncompressed) {
                return combineUC(combinedColumns, left, right);
            }
            throw new NotImplementedException("Not implemented combine for " + left.getClass().getSimpleName() + " - " + right.getClass().getSimpleName());
        } catch (Exception e) {
            // The message reports the groups in their current (possibly recompressed) state.
            throw new DMLCompressionException("Failed to combine:\n\n" + left + "\n\n" + right, e);
        }
    }

    /**
     * Combines two compressed groups by combining their encodings and dictionaries.
     *
     * <p>If only the first group has a sparse encoding, the arguments are swapped so that the sparse encoding is
     * the second operand. The type of the combined encoding selects the resulting group type.
     *
     * @param iColIndex            the combined column index of the result
     * @param aColGroupCompressed  the first compressed group
     * @param aColGroupCompressed2 the second compressed group
     * @return the combined group (DDC, empty, const or SDC depending on the combined encoding)
     * @throws NotImplementedException if the combined encoding is of an unsupported type
     */
    private static AColGroup combineCompressed(IColIndex iColIndex, AColGroupCompressed aColGroupCompressed, AColGroupCompressed aColGroupCompressed2) {
        final IEncode encodingA = aColGroupCompressed.getEncoding();
        final IEncode encodingB = aColGroupCompressed2.getEncoding();
        if (encodingA instanceof SparseEncoding && !(encodingB instanceof SparseEncoding)) {
            // Put the sparse encoding second unless both are sparse.
            return combineCompressed(iColIndex, aColGroupCompressed2, aColGroupCompressed);
        }

        final Pair<IEncode, Map<Integer, Integer>> combinedWithMap = encodingA.combineWithMap(encodingB);
        final IEncode combinedEncoding = combinedWithMap.getLeft();
        final Map<Integer, Integer> filter = combinedWithMap.getRight();

        if (combinedEncoding instanceof DenseEncoding) {
            return ColGroupDDC.create(iColIndex, DictionaryFactory.combineDictionaries(aColGroupCompressed, aColGroupCompressed2, filter), ((DenseEncoding) combinedEncoding).getMap(), null);
        }
        if (combinedEncoding instanceof EmptyEncoding) {
            return new ColGroupEmpty(iColIndex);
        }
        if (combinedEncoding instanceof ConstEncoding) {
            return ColGroupConst.create(iColIndex, DictionaryFactory.combineDictionaries(aColGroupCompressed, aColGroupCompressed2, filter));
        }
        if (combinedEncoding instanceof SparseEncoding) {
            final SparseEncoding sparse = (SparseEncoding) combinedEncoding;
            return ColGroupSDC.create(iColIndex, sparse.getNumRows(), DictionaryFactory.combineDictionariesSparse(aColGroupCompressed, aColGroupCompressed2), constructDefaultTuple(aColGroupCompressed, aColGroupCompressed2), sparse.getOffsets(), sparse.getMap(), null);
        }
        throw new NotImplementedException("Not implemented combine for " + aColGroupCompressed.getClass().getSimpleName() + " - " + aColGroupCompressed2.getClass().getSimpleName());
    }

    /**
     * Combines two groups, at least one of which is uncompressed, by decompressing both into one dense block.
     *
     * <p>The row count is taken from whichever input is a {@link ColGroupUncompressed} (the first one is
     * preferred). Each input is re-indexed onto the combined columns before it is decompressed into the target.
     *
     * @param iColIndex  the combined column index of the result
     * @param aColGroup  the first group
     * @param aColGroup2 the second group
     * @return an uncompressed group holding the values of both inputs
     * @throws ClassCastException if neither input is a {@link ColGroupUncompressed}
     */
    private static AColGroup combineUC(IColIndex iColIndex, AColGroup aColGroup, AColGroup aColGroup2) {
        final int nRow;
        if (aColGroup instanceof ColGroupUncompressed) {
            nRow = ((ColGroupUncompressed) aColGroup).getData().getNumRows();
        } else {
            nRow = ((ColGroupUncompressed) aColGroup2).getData().getNumRows();
        }

        final MatrixBlock target = new MatrixBlock(nRow, iColIndex.size(), false);
        target.allocateBlock();
        final DenseBlock dense = target.getDenseBlock();

        final IColIndex mappedA = ColIndexFactory.getColumnMapping(iColIndex, aColGroup.getColIndices());
        aColGroup.copyAndSet(mappedA).decompressToDenseBlock(dense, 0, nRow, 0, 0);
        final IColIndex mappedB = ColIndexFactory.getColumnMapping(iColIndex, aColGroup2.getColIndices());
        aColGroup2.copyAndSet(mappedB).decompressToDenseBlock(dense, 0, nRow, 0, 0);

        target.recomputeNonZeros();
        return ColGroupUncompressed.create(iColIndex, target, false);
    }

    /**
     * Builds the default tuple of a combined group by concatenating the default tuples of both inputs.
     *
     * <p>The result has {@code a.getNumCols() + b.getNumCols()} entries. The first group's default tuple is
     * copied to the front and the second group's tuple is copied starting at offset {@code a.getNumCols()}.
     * Positions of a group that does not implement {@link IContainDefaultTuple} stay zero.
     *
     * @param aColGroupCompressed  the first compressed group
     * @param aColGroupCompressed2 the second compressed group
     * @return the concatenated default tuple
     */
    public static double[] constructDefaultTuple(AColGroupCompressed aColGroupCompressed, AColGroupCompressed aColGroupCompressed2) {
        final int numColsA = aColGroupCompressed.getNumCols();
        final double[] combined = new double[numColsA + aColGroupCompressed2.getNumCols()];
        if (aColGroupCompressed instanceof IContainDefaultTuple) {
            final double[] tupleA = ((IContainDefaultTuple) aColGroupCompressed).getDefaultTuple();
            System.arraycopy(tupleA, 0, combined, 0, tupleA.length);
        }
        if (aColGroupCompressed2 instanceof IContainDefaultTuple) {
            final double[] tupleB = ((IContainDefaultTuple) aColGroupCompressed2).getDefaultTuple();
            System.arraycopy(tupleB, 0, combined, aColGroupCompressed.getNumCols(), tupleB.length);
        }
        return combined;
    }
}
