package org.apache.sysds.runtime.compress.cocode;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.sysds.runtime.compress.CompressionSettings;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.cost.ACostEstimate;
import org.apache.sysds.runtime.compress.estim.AComEst;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.util.CommonThreadPool;

/* loaded from: input_file:org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue.class */
public class CoCodePriorityQue extends AColumnCoCoder {
    private static final int COL_COMBINE_THRESHOLD = 1024;

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:org/apache/sysds/runtime/compress/cocode/CoCodePriorityQue$PQTask.class */
    /**
     * Callable that co-codes one contiguous slice of a shared list of column groups.
     *
     * <p>The task only stores its arguments; all work is delegated to
     * {@code CoCodePriorityQue.combineBlock} when {@link #call()} runs.
     */
    public static class PQTask implements Callable<List<CompressedSizeInfoColGroup>> {
        /** Full list of column groups; this task only reads indices in {@code [_start, _end)}. */
        private final List<CompressedSizeInfoColGroup> _groups;
        /** First index (inclusive) of the slice handled by this task. */
        private final int _start;
        /** Last index (exclusive) of the slice handled by this task. */
        private final int _end;
        /** Estimator used to build candidate merged groups. */
        private final AComEst _sEst;
        /** Cost model used to decide whether a merge is kept. */
        private final ACostEstimate _cEst;
        /** Merging stops once the number of remaining groups drops below this value. */
        private final int _minNumGroups;

        /**
         * Creates a task for the slice {@code [start, end)} of {@code groups}.
         *
         * @param groups       all column groups (shared with other tasks, not copied)
         * @param start        first index of the slice, inclusive
         * @param end          last index of the slice, exclusive
         * @param sEst         estimator used to combine two groups
         * @param cEst         cost model used to compare merged and separate groups
         * @param minNumGroups lower bound on the group count at which merging stops
         */
        protected PQTask(List<CompressedSizeInfoColGroup> groups, int start, int end, AComEst sEst, ACostEstimate cEst, int minNumGroups) {
            _groups = groups;
            _start = start;
            _end = end;
            _sEst = sEst;
            _cEst = cEst;
            _minNumGroups = minNumGroups;
        }

        /**
         * Runs the co-coding for this task's slice.
         *
         * @return the column groups that remain after merging within the slice
         * @throws DMLCompressionException wrapping any exception raised while combining
         */
        @Override // java.util.concurrent.Callable
        public List<CompressedSizeInfoColGroup> call() {
            try {
                return CoCodePriorityQue.combineBlock(_groups, _start, _end, _sEst, _cEst, _minNumGroups);
            } catch (Exception e) {
                // Report which slice failed so a parallel failure can be traced back to its input range.
                throw new DMLCompressionException("Failed PQTask for groups [" + _start + ", " + _end + ")", e);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a priority-queue based co-coder; all three arguments are passed straight to the superclass.
     *
     * @param sizeEstimator estimator handed to {@link AColumnCoCoder}
     * @param costEstimator cost model handed to {@link AColumnCoCoder}
     * @param cs            compression settings handed to {@link AColumnCoCoder}
     */
    public CoCodePriorityQue(AComEst sizeEstimator, ACostEstimate costEstimator, CompressionSettings cs) {
        super(sizeEstimator, costEstimator, cs);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Co-codes the column groups held by {@code colInfos} and stores the result back into it.
     *
     * <p>The merge is run through {@link #join} with a minimum group count of 1.
     *
     * @param colInfos container whose group list is read, merged, and then replaced
     * @param k        degree of parallelism forwarded to {@link #join}
     * @return the same {@code colInfos} instance, now holding the co-coded groups
     */
    @Override // org.apache.sysds.runtime.compress.cocode.AColumnCoCoder
    public CompressedSizeInfo coCodeColumns(CompressedSizeInfo colInfos, int k) {
        final List<CompressedSizeInfoColGroup> current = colInfos.getInfo();
        final List<CompressedSizeInfoColGroup> coCoded = join(current, _sest, _cest, 1, k);
        colInfos.setInfo(coCoded);
        return colInfos;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Merges column groups, choosing between the single-threaded and the parallel implementation.
     *
     * <p>The parallel path is taken only when there are more than 1024 groups (the same value as
     * the class constant {@code COL_COMBINE_THRESHOLD}) and more than one thread is requested.
     *
     * @param groups       column groups to co-code
     * @param sEst         estimator used to combine two groups
     * @param cEst         cost model used to compare merged and separate groups
     * @param minNumGroups lower bound on the group count at which merging stops
     * @param k            requested degree of parallelism
     * @return the column groups that remain after merging
     */
    public static List<CompressedSizeInfoColGroup> join(List<CompressedSizeInfoColGroup> groups, AComEst sEst, ACostEstimate cEst, int minNumGroups, int k) {
        final boolean runParallel = groups.size() > 1024 && k > 1;
        if (runParallel) {
            return combineMultiThreaded(groups, sEst, cEst, minNumGroups, k);
        }
        return combineSingleThreaded(groups, sEst, cEst, minNumGroups);
    }

    /**
     * Co-codes the groups in parallel by splitting the list into contiguous blocks, merging each
     * block in its own {@link PQTask}, and concatenating the per-block results in block order.
     *
     * <p>Groups are only merged within a block, never across block boundaries. Each block spans
     * {@code max(groups.size() / k, 500)} entries, so fewer than {@code k} tasks may be created.
     *
     * @param groups       column groups to co-code
     * @param sEst         estimator used to combine two groups
     * @param cEst         cost model used to compare merged and separate groups
     * @param minNumGroups per-block lower bound on the group count at which merging stops
     * @param k            requested degree of parallelism (pool size and block-size divisor)
     * @return the concatenated results of all blocks; empty if {@code groups} is empty
     * @throws DMLCompressionException wrapping any failure while scheduling or running the tasks
     */
    private static List<CompressedSizeInfoColGroup> combineMultiThreaded(List<CompressedSizeInfoColGroup> groups, AComEst sEst, ACostEstimate cEst, int minNumGroups, int k) {
        ExecutorService pool = null;
        try {
            pool = CommonThreadPool.get(k);
            final int numGroups = groups.size();
            final int blkSize = Math.max(numGroups / k, 500);

            final List<PQTask> tasks = new ArrayList<>();
            for (int start = 0; start < numGroups; start += blkSize) {
                final int end = Math.min(start + blkSize, numGroups);
                tasks.add(new PQTask(groups, start, end, sEst, cEst, minNumGroups));
            }

            // invokeAll returns the futures in task order, so the output keeps the block order.
            final List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
            for (Future<List<CompressedSizeInfoColGroup>> result : pool.invokeAll(tasks)) {
                ret.addAll(result.get());
            }
            return ret;
        } catch (InterruptedException e) {
            // Restore the interrupt flag so callers further up can still observe the interruption.
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Failed parallel priority que cocoding", e);
        } catch (Exception e) {
            throw new DMLCompressionException("Failed parallel priority que cocoding", e);
        } finally {
            // Release the worker threads on every exit path.
            // NOTE(review): assumes shutdown() is safe on whatever CommonThreadPool.get returns
            // (including a shared pool) — confirm against CommonThreadPool.
            if (pool != null) {
                pool.shutdown();
            }
        }
    }

    /**
     * Co-codes the whole list on the calling thread by treating it as one block.
     *
     * @param groups       column groups to co-code
     * @param sEst         estimator used to combine two groups
     * @param cEst         cost model used to compare merged and separate groups
     * @param minNumGroups lower bound on the group count at which merging stops
     * @return the column groups that remain after merging
     */
    private static List<CompressedSizeInfoColGroup> combineSingleThreaded(List<CompressedSizeInfoColGroup> groups, AComEst sEst, ACostEstimate cEst, int minNumGroups) {
        final int numGroups = groups.size();
        return combineBlock(groups, 0, numGroups, sEst, cEst, minNumGroups);
    }

    /**
     * Loads the slice {@code [start, end)} of {@code groups} into a cost-ordered queue and merges it.
     *
     * <p>{@code null} entries in the slice are skipped.
     *
     * @param groups       all column groups; only the given slice is read
     * @param start        first index of the slice, inclusive
     * @param end          last index of the slice, exclusive
     * @param sEst         estimator used to combine two groups
     * @param cEst         cost model used for queue ordering and merge decisions
     * @param minNumGroups lower bound on the group count at which merging stops
     * @return the column groups that remain after merging the slice
     */
    private static List<CompressedSizeInfoColGroup> combineBlock(List<CompressedSizeInfoColGroup> groups, int start, int end, AComEst sEst, ACostEstimate cEst, int minNumGroups) {
        final Queue<CompressedSizeInfoColGroup> que = getQue(end - start, cEst);
        int idx = start;
        while (idx < end) {
            final CompressedSizeInfoColGroup g = groups.get(idx++);
            if (g == null) {
                continue;
            }
            que.add(g);
        }
        return combineBlock(que, sEst, cEst, minNumGroups);
    }

    /**
     * Greedily merges groups taken from {@code que} until a stop condition is reached.
     *
     * <p>Each round takes the head of the queue and tries to combine it with the next head. A
     * merge is kept only if its cost is strictly lower than the summed cost of the two parts.
     * A kept merge with at most 128 columns goes back into the queue and resets the miss counter;
     * a larger one is emitted directly. Every other outcome emits the current group and counts as
     * a miss.
     *
     * <p>Merging stops when the queue is empty, when emitted plus queued groups drop below
     * {@code minNumGroups}, or after five misses in a row. Whatever is left is emitted unmerged.
     *
     * @param que          groups to merge; consumed by this call (the queue built by
     *                     {@code getQue} yields the lowest-cost group first)
     * @param sEst         estimator used to combine two groups; a {@code null} result means no merge
     * @param cEst         cost model used to compare merged and separate groups
     * @param minNumGroups lower bound on the group count at which merging stops
     * @return the resulting groups, in the order they were emitted
     */
    private static List<CompressedSizeInfoColGroup> combineBlock(Queue<CompressedSizeInfoColGroup> que, AComEst sEst, ACostEstimate cEst, int minNumGroups) {
        final List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
        CompressedSizeInfoColGroup left = que.poll();
        int misses = 0;

        // The group count deliberately excludes 'left', which is neither emitted nor queued.
        while (que.peek() != null && ret.size() + que.size() >= minNumGroups && misses < 5) {
            final CompressedSizeInfoColGroup right = que.peek();
            final CompressedSizeInfoColGroup joined = sEst.combine(left, right);
            final boolean cheaper = joined != null && cEst.getCost(joined) < cEst.getCost(left) + cEst.getCost(right);

            if (!cheaper) {
                // No merge possible, or not worth it: emit 'left'; 'right' stays queued.
                misses++;
                ret.add(left);
            } else {
                // Merge accepted: 'right' is consumed as well.
                que.poll();
                if (joined.getColumns().size() <= 128) {
                    misses = 0;
                    que.add(joined);
                } else {
                    // Too wide to keep growing: emit as final.
                    misses++;
                    ret.add(joined);
                }
            }
            left = que.poll();
        }

        // Flush everything that was not processed by the merge loop.
        while (que.peek() != null) {
            ret.add(left);
            left = que.poll();
        }
        if (left != null) {
            ret.add(left);
        }
        for (CompressedSizeInfoColGroup rest : que) {
            ret.add(rest);
        }
        return ret;
    }

    /**
     * Creates an empty priority queue that orders column groups by ascending
     * {@link #getCost(CompressedSizeInfoColGroup, ACostEstimate)}.
     *
     * @param size expected number of elements; values below 1 are treated as 1
     * @param cEst cost model used by the ordering
     * @return a new, empty queue
     */
    private static Queue<CompressedSizeInfoColGroup> getQue(int size, ACostEstimate cEst) {
        // PriorityQueue rejects an initial capacity below 1; clamp so an empty block yields an
        // empty queue instead of an IllegalArgumentException.
        final int capacity = Math.max(size, 1);
        final Comparator<CompressedSizeInfoColGroup> byCost = Comparator.comparingDouble(g -> getCost(g, cEst));
        return new PriorityQueue<>(capacity, byCost);
    }

    /**
     * Computes the ordering key of a column group for the priority queue.
     *
     * <p>The key is the estimated cost plus {@code avgOfIndex() / 100000} of the group's columns.
     * NOTE(review): the second term looks like a small tie-breaker between equal-cost groups —
     * confirm the intent.
     *
     * @param g    column group to score
     * @param cEst cost model providing the base cost
     * @return the ordering key; lower values are dequeued first
     */
    private static double getCost(CompressedSizeInfoColGroup g, ACostEstimate cEst) {
        final double baseCost = cEst.getCost(g);
        final double indexOffset = g.getColumns().avgOfIndex() / 100000.0d;
        return baseCost + indexOffset;
    }
}
