package org.apache.sysds.runtime.compress;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.cocode.PlanningCoCoder;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.ColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorFactory;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.utils.DblArrayIntListHashMap;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.DMLCompressionStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.class */
/**
 * Factory that turns an uncompressed {@link MatrixBlock} into a {@link CompressedMatrixBlock}.
 *
 * <p>Compression runs through a fixed sequence of phases. Each phase is timed and logged by {@code logPhase()}:
 * <ol>
 * <li>0 Classify: estimate the compressed size of every individual column.</li>
 * <li>1 Grouping: co-code columns into column groups. This runs only if the per-column estimate beats the
 * original size, or if the cost-based matrix-multiplication partitioner is configured.</li>
 * <li>2 Transpose: optionally transpose the input (see {@code transposeHeuristics()}).</li>
 * <li>3 Compress: build the actual column groups.</li>
 * <li>4 Share: merge all empty groups into one group and all constant groups into one group.</li>
 * <li>5 Cleanup: clean up the blocks and reject the result if the compression ratio is below 1.</li>
 * </ol>
 * If compression is aborted, the uncompressed matrix is returned together with the statistics collected so far.
 */
public class CompressedMatrixBlockFactory {
    private static final Log LOG = LogFactory.getLog(CompressedMatrixBlockFactory.class.getName());

    /** Row count above which a sparse input is transposed when the transpose setting is neither "true" nor "false". */
    private static final int TRANSPOSE_ROW_THRESHOLD = 500000;

    /** Duration of the most recently finished phase, as returned by {@code Timing.stop()}. */
    private double lastPhase;
    /** Matrix being compressed. The transpose phase replaces it with a (possibly transposed) copy. */
    private MatrixBlock mb;
    /** Degree of parallelism handed to the estimator, co-coder, transpose and column group compression. */
    private final int k;
    /** Settings for this compression. Note: the transpose phase writes {@code transposed} into this object. */
    private final CompressionSettings compSettings;
    /** Result under construction. It is set to null when the cleanup phase rejects the compression. */
    private CompressedMatrixBlock res;
    /** Co-coding plan. It stays null when the classify phase decides not to co-code. */
    private CompressedSizeInfo coCodeColGroups;
    /**
     * True only once the transpose phase of THIS factory has actually transposed {@link #mb}.
     *
     * <p>Kept separately from {@code compSettings.transposed} because that flag lives on a caller-supplied object.
     * It may still hold a value from an earlier run when compression aborts before the transpose phase.
     */
    private boolean transposedInput = false;
    private final Timing time = new Timing(true);
    private final CompressionStatistics _stats = new CompressionStatistics();
    /** Index of the phase currently being executed; incremented by {@code logPhase()}. */
    private int phase = 0;

    private CompressedMatrixBlockFactory(MatrixBlock mb, int k, CompressionSettings compSettings) {
        this.mb = mb;
        this.k = k;
        this.compSettings = compSettings;
    }

    /**
     * Compresses the given matrix single-threaded using default compression settings.
     *
     * @param mb matrix to compress
     * @return pair of the resulting matrix (compressed, or uncompressed if compression was aborted) and statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb) {
        return compress(mb, 1, new CompressionSettingsBuilder().create());
    }

    /**
     * Compresses the given matrix single-threaded using the given settings.
     *
     * @param mb           matrix to compress
     * @param compSettings compression settings (its {@code transposed} field is overwritten during compression)
     * @return pair of the resulting matrix and statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb,
        CompressionSettings compSettings) {
        return compress(mb, 1, compSettings);
    }

    /**
     * Compresses the given matrix with the given degree of parallelism using default compression settings.
     *
     * @param mb matrix to compress
     * @param k  degree of parallelism
     * @return pair of the resulting matrix and statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k) {
        return compress(mb, k, new CompressionSettingsBuilder().create());
    }

    /**
     * Compresses the given matrix.
     *
     * <p>If {@code mb} is already a {@link CompressedMatrixBlock}, it is returned unchanged with null statistics.
     *
     * @param mb           matrix to compress
     * @param k            degree of parallelism
     * @param compSettings compression settings (its {@code transposed} field is overwritten during compression)
     * @return pair of the resulting matrix (compressed, or uncompressed if compression was aborted) and statistics
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock mb, int k,
        CompressionSettings compSettings) {
        return new CompressedMatrixBlockFactory(mb, k, compSettings).compressMatrix();
    }

    /**
     * Creates a compressed matrix in which every cell holds the same value.
     *
     * @param numRows number of rows
     * @param numCols number of columns
     * @param value   value of every cell
     * @return compressed matrix backed by a single constant column group
     */
    public static CompressedMatrixBlock createConstant(int numRows, int numCols, double value) {
        final CompressedMatrixBlock block = new CompressedMatrixBlock(numRows, numCols);
        block.allocateColGroup(ColGroupFactory.genColGroupConst(numRows, numCols, value));
        block.recomputeNonZeros();
        return block;
    }

    /** Runs all compression phases and returns the compressed result or, if aborted, the uncompressed matrix. */
    private Pair<MatrixBlock, CompressionStatistics> compressMatrix() {
        if (mb instanceof CompressedMatrixBlock) {
            LOG.info("MatrixBlock already compressed");
            return new ImmutablePair<>(mb, null);
        }

        _stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0d);
        _stats.originalSize = mb.getInMemorySize();
        res = new CompressedMatrixBlock(mb);

        classifyPhase();
        if (coCodeColGroups == null) {
            return abortCompression();
        }

        transposePhase();
        compressPhase();
        sharePhase();
        cleanupPhase();
        if (res == null) {
            return abortCompression();
        }

        res.recomputeNonZeros();
        return new ImmutablePair<>(res, _stats);
    }

    /** Phase 0: estimate per-column sizes and co-code only when compression looks worthwhile. */
    private void classifyPhase() {
        final CompressedSizeEstimator sizeEstimator = CompressedSizeEstimatorFactory.getSizeEstimator(mb,
            compSettings);
        final CompressedSizeInfo sizeInfos = sizeEstimator.computeCompressedSizeInfos(k);
        _stats.estimatedSizeCols = sizeInfos.memoryEstimate();
        logPhase();

        final boolean smallerThanOriginal = _stats.estimatedSizeCols < _stats.originalSize;
        final boolean costBased = compSettings.columnPartitioner == PlanningCoCoder.PartitionerType.COST_MATRIX_MULT;
        if (smallerThanOriginal || costBased) {
            coCodePhase(sizeEstimator, sizeInfos, mb.getNumRows());
        } else {
            LOG.info("Estimated Size of singleColGroups: " + _stats.estimatedSizeCols);
            LOG.info("Original size                    : " + _stats.originalSize);
        }
    }

    /** Phase 1: partition the columns into co-coded column groups. */
    private void coCodePhase(CompressedSizeEstimator sizeEstimator, CompressedSizeInfo sizeInfos, int numRows) {
        coCodeColGroups = PlanningCoCoder.findCoCodesByPartitioning(sizeEstimator, sizeInfos, numRows, k,
            compSettings);
        _stats.estimatedSizeCoCoded = coCodeColGroups.memoryEstimate();
        logPhase();
    }

    /** Phase 2: replace {@link #mb} by a transposed copy or by a shallow copy, depending on the heuristics. */
    private void transposePhase() {
        final boolean sparse = mb.isInSparseFormat();
        transposeHeuristics();
        transposedInput = compSettings.transposed;
        if (transposedInput) {
            mb = LibMatrixReorg.transpose(mb, new MatrixBlock(mb.getNumColumns(), mb.getNumRows(), sparse), k);
        } else {
            mb = new MatrixBlock(mb.getNumRows(), mb.getNumColumns(), sparse).copyShallow(mb);
        }
        logPhase();
    }

    /**
     * Decides whether to transpose the input and stores the decision in {@code compSettings.transposed}.
     *
     * <p>{@code transposeInput} equal to "true" or "false" forces the decision. Any other value falls back to a
     * heuristic that transposes only sparse inputs. Those are transposed when they have more than
     * {@link #TRANSPOSE_ROW_THRESHOLD} rows, or when the co-coding plan has more than half as many groups as columns.
     */
    private void transposeHeuristics() {
        switch (compSettings.transposeInput) {
            case "true":
                compSettings.transposed = true;
                break;
            case "false":
                compSettings.transposed = false;
                break;
            default:
                if (mb.isInSparseFormat()) {
                    compSettings.transposed = mb.getNumRows() > TRANSPOSE_ROW_THRESHOLD
                        || coCodeColGroups.getNumberColGroups() > mb.getNumColumns() / 2;
                } else {
                    compSettings.transposed = false;
                }
                break;
        }
    }

    /** Phase 3: build the column groups according to the co-coding plan. */
    private void compressPhase() {
        res.allocateColGroupList(ColGroupFactory.compressColGroups(mb, coCodeColGroups, compSettings, k));
        _stats.compressedInitialSize = res.getInMemorySize();
        logPhase();
    }

    /**
     * Phase 4: merge all empty groups into a single empty group and all constant groups into a single constant
     * group. These are appended after the remaining groups.
     */
    private void sharePhase() {
        final List<AColGroup> emptyGroups = new ArrayList<>();
        final List<AColGroup> constGroups = new ArrayList<>();
        final List<AColGroup> result = new ArrayList<>();
        for (AColGroup g : res.getColGroups()) {
            if (g instanceof ColGroupEmpty) {
                emptyGroups.add(g);
            } else if (g instanceof ColGroupConst) {
                constGroups.add(g);
            } else {
                result.add(g);
            }
        }
        if (!emptyGroups.isEmpty()) {
            result.add(combineEmpty(emptyGroups));
        }
        if (!constGroups.isEmpty()) {
            result.add(combineConst(constGroups));
        }
        res.allocateColGroupList(result);
        logPhase();
    }

    /** Combines non-empty list {@code groups} of empty column groups into one empty group over all their columns. */
    private static AColGroup combineEmpty(List<AColGroup> groups) {
        return new ColGroupEmpty(combineColIndexes(groups), groups.get(0).getNumRows());
    }

    /**
     * Combines non-empty list {@code groups}, all of which must be {@link ColGroupConst}, into one constant group.
     * Each column's value is taken from the first group containing that column.
     */
    private static AColGroup combineConst(List<AColGroup> groups) {
        final int[] colIndexes = combineColIndexes(groups);
        final double[] values = new double[colIndexes.length];
        for (int i = 0; i < colIndexes.length; i++) {
            values[i] = constValueOf(groups, colIndexes[i]);
        }
        return new ColGroupConst(colIndexes, groups.get(0).getNumRows(), new Dictionary(values));
    }

    /**
     * Looks up the constant value of column {@code col} in the first constant group containing it.
     *
     * @throws IllegalStateException if no group contains the column. The previous implementation looped forever
     *                               in that case.
     */
    private static double constValueOf(List<AColGroup> groups, int col) {
        for (AColGroup g : groups) {
            final ColGroupConst constGroup = (ColGroupConst) g;
            // binarySearch relies on the group's column indexes being sorted.
            final int pos = Arrays.binarySearch(constGroup.getColIndices(), col);
            if (pos >= 0) {
                return constGroup.getDictionary().getValue(pos);
            }
        }
        throw new IllegalStateException("Column " + col + " not found in any constant column group");
    }

    /** Returns the sorted concatenation of the column indexes of all given groups. */
    private static int[] combineColIndexes(List<AColGroup> groups) {
        int numCols = 0;
        for (AColGroup g : groups) {
            numCols += g.getNumCols();
        }
        final int[] colIndexes = new int[numCols];
        int pos = 0;
        for (AColGroup g : groups) {
            for (int col : g.getColIndices()) {
                colIndexes[pos++] = col;
            }
        }
        Arrays.sort(colIndexes);
        return colIndexes;
    }

    /** Phase 5: clean up and reject the result (set {@link #res} to null) if the compression ratio is below 1. */
    private void cleanupPhase() {
        res.cleanupBlock(true, true);
        _stats.size = res.estimateCompressedSizeInMemory();
        final double ratio = _stats.getRatio();
        if (ratio >= 1.0d || compSettings.columnPartitioner == PlanningCoCoder.PartitionerType.COST_MATRIX_MULT) {
            mb.cleanupBlock(true, true);
            _stats.setColGroupsCounts(res.getColGroups());
            logPhase();
        } else {
            LOG.info("--dense size:        " + _stats.denseSize);
            LOG.info("--original size:     " + _stats.originalSize);
            LOG.info("--compressed size:   " + _stats.size);
            LOG.info("--compression ratio: " + ratio);
            LOG.info("Abort block compression because compression ratio is less than 1.");
            res = null;
        }
    }

    /**
     * Returns the uncompressed matrix together with the statistics collected so far.
     * If this factory transposed the input, the matrix is transposed back first.
     */
    private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
        LOG.warn("Compression aborted at phase: " + phase);
        // Use our own flag rather than compSettings.transposed: the settings object may carry a stale value from
        // an earlier run, and compression may abort before the transpose phase ever ran.
        if (transposedInput) {
            LibMatrixReorg.transposeInPlace(mb, k);
        }
        return new ImmutablePair<>(mb, _stats);
    }

    /** Records the duration of the current phase, emits debug output for it and advances to the next phase. */
    private void logPhase() {
        setNextTimePhase(time.stop());
        DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase);
        if (LOG.isDebugEnabled()) {
            switch (phase) {
                case 0:
                    LOG.debug("--compression phase " + phase + " Classify  : " + getLastTimePhase());
                    LOG.debug("--Individual Columns Estimated Compression: " + _stats.estimatedSizeCols);
                    break;
                case 1:
                    LOG.debug("--compression phase " + phase + " Grouping  : " + getLastTimePhase());
                    LOG.debug("Grouping using: " + compSettings.columnPartitioner);
                    LOG.debug("--Cocoded Columns estimated Compression:" + _stats.estimatedSizeCoCoded);
                    break;
                case 2:
                    LOG.debug("--compression phase " + phase + " Transpose : " + getLastTimePhase());
                    LOG.debug("Did transpose: " + compSettings.transposed);
                    break;
                case 3:
                    LOG.debug("--compression phase " + phase + " Compress  : " + getLastTimePhase());
                    LOG.debug("--compression Hash collisions:" + DblArrayIntListHashMap.hashMissCount);
                    // The counter is global; reset it so the next compression reports only its own collisions.
                    DblArrayIntListHashMap.hashMissCount = 0;
                    LOG.debug("--compressed initial actual size:" + _stats.compressedInitialSize);
                    break;
                case 4:
                    LOG.debug("--compression phase " + phase + " Share     : " + getLastTimePhase());
                    break;
                case 5:
                    logCleanupDetails();
                    break;
                default:
                    break;
            }
        }
        phase++;
    }

    /** Emits the debug (and, if enabled, trace) summary of a successful compression. */
    private void logCleanupDetails() {
        LOG.debug("--num col groups: " + res.getColGroups().size());
        LOG.debug("--compression phase " + phase + " Cleanup   : " + getLastTimePhase());
        LOG.debug("--col groups types " + _stats.getGroupsTypesString());
        LOG.debug("--col groups sizes " + _stats.getGroupsSizesString());
        LOG.debug("--dense size:        " + _stats.denseSize);
        LOG.debug("--original size:     " + _stats.originalSize);
        LOG.debug("--compressed size:   " + _stats.size);
        LOG.debug("--compression ratio: " + _stats.getRatio());

        final int[] dictionarySizes = new int[res.getColGroups().size()];
        int idx = 0;
        for (AColGroup g : res.getColGroups()) {
            dictionarySizes[idx++] = g.getNumValues();
        }
        LOG.debug("--compressed colGroup dictionary sizes: " + Arrays.toString(dictionarySizes));

        if (LOG.isTraceEnabled()) {
            for (AColGroup g : res.getColGroups()) {
                LOG.trace("--colGroups type       : " + g.getClass().getSimpleName() + " size: "
                    + g.estimateInMemorySize()
                    + (g instanceof ColGroupValue ? "  numValues :" + ((ColGroupValue) g).getNumValues() : "")
                    + "  colIndexes : " + Arrays.toString(g.getColIndices()));
            }
        }
    }

    /**
     * Stores the duration of the most recently finished phase.
     *
     * @param duration phase duration as returned by {@code Timing.stop()}
     */
    public void setNextTimePhase(double duration) {
        this.lastPhase = duration;
    }

    /**
     * Returns the duration of the most recently finished phase.
     *
     * @return last recorded phase duration
     */
    public double getLastTimePhase() {
        return this.lastPhase;
    }
}
