package org.apache.sysds.runtime.compress;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.cocode.CoCoderFactory;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupValue;
import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
import org.apache.sysds.runtime.compress.cost.ACostEstimate;
import org.apache.sysds.runtime.compress.cost.CostEstimatorBuilder;
import org.apache.sysds.runtime.compress.cost.CostEstimatorFactory;
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.compress.cost.MemoryCostEstimator;
import org.apache.sysds.runtime.compress.estim.AComEst;
import org.apache.sysds.runtime.compress.estim.ComEstFactory;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.DMLCompressionStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.class */
/**
 * Builds {@link CompressedMatrixBlock}s from uncompressed {@link MatrixBlock}s.
 *
 * <p>Each static {@code compress} overload creates a single-use factory instance and runs
 * {@link #compressMatrix()}. That method proceeds through numbered phases, each timed and logged by
 * {@link #logPhase()}: 0 classify, 1 co-code, 2 transpose, 3 compress, 4 finalize. Whenever the
 * cost model judges compression not worthwhile, the caller's original (never transposed) block is
 * returned together with the statistics gathered so far.
 */
public class CompressedMatrixBlockFactory {
    private static final Log LOG = LogFactory.getLog(CompressedMatrixBlockFactory.class.getName());

    /** Timer whose {@code stop()} value is recorded for every phase. */
    private final Timing time;
    /** Statistics filled in while the phases progress; returned to the caller. */
    private final CompressionStatistics _stats;
    /** Forwarded to the estimator, co-coder, transpose and column-group factory; presumably the degree of parallelism. */
    private final int k;
    /** Settings of this compression; {@code transposed} is written by the transpose phase. */
    private final CompressionSettings compSettings;
    /** Cost model used to decide whether compression is worthwhile. */
    private final ACostEstimate costEstimator;
    /** The block exactly as handed to the factory; returned on abort so callers never receive a transposed block. */
    private final MatrixBlock inputBlock;
    /** Value of {@code time.stop()} recorded for the most recent phase. */
    private double lastPhase;
    /** Working input; may be replaced by its transpose in {@link #transposePhase()}. */
    private MatrixBlock mb;
    /** Result under construction; set to {@code null} when the finalize phase rejects it. */
    private CompressedMatrixBlock res;
    /** Index of the current phase; incremented by {@link #logPhase()}. */
    private int phase;
    /** Estimator created in the classify phase and reused by co-coding. */
    private AComEst informationExtractor;
    /** Column-group plan; {@code null} when classification or co-coding decided to abort. */
    private CompressedSizeInfo compressionGroups;

    private CompressedMatrixBlockFactory(MatrixBlock matrixBlock, int k, CompressionSettingsBuilder compressionSettingsBuilder, ACostEstimate aCostEstimate) {
        this(matrixBlock, k, compressionSettingsBuilder.create(), aCostEstimate);
    }

    private CompressedMatrixBlockFactory(MatrixBlock matrixBlock, int k, CompressionSettings compressionSettings, ACostEstimate aCostEstimate) {
        this.time = new Timing(true);
        this._stats = new CompressionStatistics();
        this.phase = 0;
        this.inputBlock = matrixBlock;
        this.mb = matrixBlock;
        this.k = k;
        this.compSettings = compressionSettings;
        this.costEstimator = aCostEstimate;
    }

    /** Compresses with k = 1, default settings and no workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock) {
        return compress(matrixBlock, 1, new CompressionSettingsBuilder(), (WTreeRoot) null);
    }

    /** Compresses with k = 1 and default settings, deriving the cost model from {@code wTreeRoot} (may be null). */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, WTreeRoot wTreeRoot) {
        return compress(matrixBlock, 1, new CompressionSettingsBuilder(), wTreeRoot);
    }

    /** Compresses with k = 1 and default settings, using the given cost-estimator builder. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, CostEstimatorBuilder costEstimatorBuilder) {
        return compress(matrixBlock, 1, new CompressionSettingsBuilder(), costEstimatorBuilder);
    }

    /** Compresses with k = 1 and default settings; a null counter falls back to the default cost model. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, InstructionTypeCounter instructionTypeCounter) {
        return instructionTypeCounter == null ? compress(matrixBlock, 1, new CompressionSettingsBuilder()) : compress(matrixBlock, 1, new CompressionSettingsBuilder(), new CostEstimatorBuilder(instructionTypeCounter));
    }

    /** Compresses with k = 1 and the given settings, without a workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, CompressionSettingsBuilder compressionSettingsBuilder) {
        return compress(matrixBlock, 1, compressionSettingsBuilder, (WTreeRoot) null);
    }

    /** Compresses with the given k, default settings and no workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k) {
        return compress(matrixBlock, k, new CompressionSettingsBuilder(), (WTreeRoot) null);
    }

    /** Compresses with the given k and default settings, deriving the cost model from {@code wTreeRoot} (may be null). */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, WTreeRoot wTreeRoot) {
        return compress(matrixBlock, k, new CompressionSettingsBuilder(), wTreeRoot);
    }

    /** Compresses with the given k and default settings, using the given cost-estimator builder. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, CostEstimatorBuilder costEstimatorBuilder) {
        return compress(matrixBlock, k, new CompressionSettingsBuilder(), costEstimatorBuilder);
    }

    /**
     * Compresses with the given k and default settings; a null counter falls back to the default cost
     * model. The given {@code k} is honored on both paths (previously the null path forced k = 1).
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, InstructionTypeCounter instructionTypeCounter) {
        return instructionTypeCounter == null ? compress(matrixBlock, k, new CompressionSettingsBuilder()) : compress(matrixBlock, k, new CompressionSettingsBuilder(), new CostEstimatorBuilder(instructionTypeCounter));
    }

    /** Compresses with k = 1 and default settings, using an already built cost estimator. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, ACostEstimate aCostEstimate) {
        return compress(matrixBlock, 1, new CompressionSettingsBuilder(), aCostEstimate);
    }

    /** Compresses with the given k and default settings, using an already built cost estimator. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, ACostEstimate aCostEstimate) {
        return compress(matrixBlock, k, new CompressionSettingsBuilder(), aCostEstimate);
    }

    /** Compresses with the given k and settings, without a workload tree. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, CompressionSettingsBuilder compressionSettingsBuilder) {
        return compress(matrixBlock, k, compressionSettingsBuilder, (WTreeRoot) null);
    }

    /** Asynchronously compresses the matrix variable {@code str} of the context using the default cost model. */
    public static void compressAsync(ExecutionContext executionContext, String str) {
        compressAsync(executionContext, str, null);
    }

    /**
     * Asynchronously compresses the matrix variable {@code str} of the given context and, if a
     * {@link CompressedMatrixBlock} results, installs it in the variable's {@link MatrixObject}.
     * Non-matrix variables are ignored. Failures of the background task are logged as warnings
     * instead of being silently dropped with the discarded future.
     */
    public static void compressAsync(ExecutionContext executionContext, String str, InstructionTypeCounter instructionTypeCounter) {
        LOG.debug("Compressing Async");
        CompletableFuture.runAsync(() -> {
            CacheableData<?> data = executionContext.getCacheableData(str);
            if (data instanceof MatrixObject) {
                MatrixObject matrixObject = (MatrixObject) data;
                // Read once and compress that same block (the original read the object a second time).
                MatrixBlock original = matrixObject.acquireReadAndRelease();
                MatrixBlock compressed = compress(original, instructionTypeCounter).getLeft();
                if (compressed instanceof CompressedMatrixBlock) {
                    // NOTE(review): the return value is discarded; kept from the original in case the call has side effects.
                    ExecutionContext.createCacheableData(original);
                    matrixObject.acquireModify(compressed);
                    matrixObject.release();
                }
            }
        }).exceptionally(ex -> {
            LOG.warn("Async compression of variable '" + str + "' failed", ex);
            return null;
        });
    }

    /**
     * Compresses with the given k and settings. When {@code wTreeRoot} is non-null, the cost model is
     * built from it; otherwise the factory default is used.
     */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, CompressionSettingsBuilder compressionSettingsBuilder, WTreeRoot wTreeRoot) {
        CompressionSettings settings = compressionSettingsBuilder.create();
        ACostEstimate estimator;
        if (wTreeRoot == null) {
            estimator = CostEstimatorFactory.create(settings, null, matrixBlock.getNumRows(), matrixBlock.getNumColumns(), matrixBlock.getSparsity());
        } else {
            estimator = CostEstimatorFactory.create(settings, new CostEstimatorBuilder(wTreeRoot), matrixBlock.getNumRows(), matrixBlock.getNumColumns(), matrixBlock.getSparsity());
        }
        return new CompressedMatrixBlockFactory(matrixBlock, k, settings, estimator).compressMatrix();
    }

    /** Compresses with the given k and settings, building the cost model from {@code costEstimatorBuilder}. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, CompressionSettingsBuilder compressionSettingsBuilder, CostEstimatorBuilder costEstimatorBuilder) {
        CompressionSettings settings = compressionSettingsBuilder.create();
        ACostEstimate estimator = CostEstimatorFactory.create(settings, costEstimatorBuilder, matrixBlock.getNumRows(), matrixBlock.getNumColumns(), matrixBlock.getSparsity());
        return new CompressedMatrixBlockFactory(matrixBlock, k, settings, estimator).compressMatrix();
    }

    /** Compresses with the given k, settings and already built cost estimator. */
    public static Pair<MatrixBlock, CompressionStatistics> compress(MatrixBlock matrixBlock, int k, CompressionSettingsBuilder compressionSettingsBuilder, ACostEstimate aCostEstimate) {
        return new CompressedMatrixBlockFactory(matrixBlock, k, compressionSettingsBuilder, aCostEstimate).compressMatrix();
    }

    /**
     * Wraps {@code matrixBlock} in a {@link CompressedMatrixBlock} holding a single uncompressed
     * column group, copying its dimensions and non-zero count.
     */
    public static CompressedMatrixBlock genUncompressedCompressedMatrixBlock(MatrixBlock matrixBlock) {
        CompressedMatrixBlock compressedMatrixBlock = new CompressedMatrixBlock(matrixBlock.getNumRows(), matrixBlock.getNumColumns());
        compressedMatrixBlock.allocateColGroup(ColGroupUncompressed.create(matrixBlock));
        compressedMatrixBlock.setNonZeros(matrixBlock.getNonZeros());
        return compressedMatrixBlock;
    }

    /**
     * Creates a {@code numRows x numCols} compressed block holding the constant {@code value}.
     *
     * @throws DMLCompressionException if the resulting block has no rows
     */
    public static CompressedMatrixBlock createConstant(int numRows, int numCols, double value) {
        CompressedMatrixBlock compressedMatrixBlock = new CompressedMatrixBlock(numRows, numCols);
        compressedMatrixBlock.allocateColGroup(ColGroupConst.create(numCols, value));
        compressedMatrixBlock.recomputeNonZeros();
        if (compressedMatrixBlock.getNumRows() <= 0) {
            throw new DMLCompressionException("Invalid size of allocated constant compressed matrix block");
        }
        return compressedMatrixBlock;
    }

    /**
     * Runs all compression phases on {@link #mb}.
     *
     * @return the compressed block and statistics, or the original input block (with statistics, or
     *         with {@code null} statistics for unsupported overlapping input) when compression is
     *         aborted
     */
    private Pair<MatrixBlock, CompressionStatistics> compressMatrix() {
        if (this.mb.getNonZeros() < 0) {
            LOG.warn("Recomputing non-zeros since it is unknown in compression");
            this.mb.recomputeNonZeros();
        }
        // Checked independently of the nnz branch so an overlapping input with unknown nnz cannot
        // slip through into an unsupported recompression.
        if ((this.mb instanceof CompressedMatrixBlock) && ((CompressedMatrixBlock) this.mb).isOverlapping()) {
            LOG.warn("Unsupported recompression of overlapping compression");
            return new ImmutablePair<>(this.mb, null);
        }
        this._stats.denseSize = MatrixBlock.estimateSizeInMemory(this.mb.getNumRows(), this.mb.getNumColumns(), 1.0d);
        this._stats.sparseSize = MatrixBlock.estimateSizeSparseInMemory(this.mb.getNumRows(), this.mb.getNumColumns(), this.mb.getSparsity());
        this._stats.originalSize = this.mb.getInMemorySize();
        this._stats.originalCost = this.costEstimator.getCost(this.mb);
        if (this.mb.isEmpty()) {
            return createEmpty();
        }
        this.res = new CompressedMatrixBlock(this.mb);
        classifyPhase();
        if (this.compressionGroups == null) {
            return abortCompression();
        }
        this.compressionGroups.clearMaps();
        this.informationExtractor.clearNNZ();
        transposePhase();
        compressPhase();
        finalizePhase();
        if (this.res == null) {
            return abortCompression();
        }
        return new ImmutablePair<>(this.res, this._stats);
    }

    /**
     * Phase 0: estimates per-column compression. Continues with co-coding when the single-column cost
     * divided by sqrt(#columns) beats the original cost; otherwise sets {@link #compressionGroups} to
     * {@code null} to signal an abort.
     */
    private void classifyPhase() {
        this.informationExtractor = ComEstFactory.createEstimator(this.mb, this.compSettings, this.k);
        this.compressionGroups = this.informationExtractor.computeCompressedSizeInfos(this.k);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Logging all individual columns estimated cost:");
            for (CompressedSizeInfoColGroup info : this.compressionGroups.getInfo()) {
                LOG.trace(String.format("Cost: %8.0f Size: %16.0f %15s", Double.valueOf(this.costEstimator.getCost(info)), Double.valueOf(info.getMinSize()), info.getColumns()));
            }
        }
        this._stats.estimatedSizeCols = this.compressionGroups.memoryEstimate();
        this._stats.estimatedCostCols = this.costEstimator.getCost(this.compressionGroups);
        logPhase();
        int numColumns = this.mb.getNumColumns();
        double threshold = this._stats.estimatedCostCols / Math.sqrt(numColumns);
        if (threshold < this._stats.originalCost) {
            if (numColumns > 1) {
                coCodePhase();
            } else {
                // A single column cannot be co-coded; log to advance past the co-code phase number.
                logPhase();
            }
            return;
        }
        this.compressionGroups = null;
        if (LOG.isInfoEnabled()) {
            LOG.info("Aborting before co-code, because the compression looks bad");
            LOG.info("Threshold was set to : " + threshold + " but it was above original " + this._stats.originalCost);
            LOG.info("Original size       : " + this._stats.originalSize);
            LOG.info("single col size     : " + this._stats.estimatedSizeCols);
            if (!(this.costEstimator instanceof MemoryCostEstimator)) {
                LOG.info("original cost      : " + this._stats.originalCost);
                LOG.info("single col cost    : " + this._stats.estimatedCostCols);
            }
        }
    }

    /**
     * Phase 1: groups columns via {@link CoCoderFactory}. Sets {@link #compressionGroups} to
     * {@code null} (abort) when the co-coded cost exceeds the original cost.
     */
    private void coCodePhase() {
        this.compressionGroups = CoCoderFactory.findCoCodesByPartitioning(this.informationExtractor, this.compressionGroups, this.k, this.costEstimator, this.compSettings);
        this._stats.estimatedSizeCoCoded = this.compressionGroups.memoryEstimate();
        this._stats.estimatedCostCoCoded = this.costEstimator.getCost(this.compressionGroups);
        logPhase();
        if (this._stats.estimatedCostCoCoded > this._stats.originalCost) {
            this.compressionGroups = null;
            if (LOG.isInfoEnabled()) {
                LOG.info("Aborting after co-code, because the compression looks bad");
                LOG.info("co-code size      : " + this._stats.estimatedSizeCoCoded);
                LOG.info("original size     : " + this._stats.originalSize);
                if (!(this.costEstimator instanceof MemoryCostEstimator)) {
                    LOG.info("original cost    : " + this._stats.originalCost);
                    LOG.info("single col cost  : " + this._stats.estimatedCostCols);
                    LOG.info("co-code cost     : " + this._stats.estimatedCostCoCoded);
                }
            }
        }
    }

    /**
     * Phase 2: decides whether to transpose and, if so, replaces {@link #mb} with its transpose.
     * Only attempted when the input is not already marked transposed and the JVM reports free memory
     * exceeding twice the block's estimated size.
     */
    private void transposePhase() {
        boolean haveMemory = Runtime.getRuntime().freeMemory() - (this.mb.estimateSizeInMemory() * 2) > 0;
        if (!this.compSettings.transposed && haveMemory) {
            transposeHeuristics();
            if (this.compSettings.transposed) {
                this.mb = LibMatrixReorg.transpose(this.mb, new MatrixBlock(this.mb.getNumColumns(), this.mb.getNumRows(), this.mb.isInSparseFormat()), this.k, true);
                this.mb.evalSparseFormatInMemory();
            }
        }
        logPhase();
    }

    /**
     * Sets {@code compSettings.transposed} from {@code compSettings.transposeInput}: "true" forces a
     * transpose, "false" forbids it, any other value defers to
     * {@link #transposeHeuristics(int, MatrixBlock)}.
     */
    private void transposeHeuristics() {
        switch (this.compSettings.transposeInput) {
            case "true":
                this.compSettings.transposed = true;
                break;
            case "false":
                this.compSettings.transposed = false;
                break;
            default:
                this.compSettings.transposed = transposeHeuristics(this.compressionGroups.getNumberColGroups(), this.mb);
                break;
        }
    }

    /**
     * Heuristic transpose decision: dense blocks are never transposed; sparse blocks are transposed
     * when either dimension exceeds 10000 or they have fewer than 1000 non-zeros.
     *
     * @param i number of planned column groups
     * @param matrixBlock block to inspect
     * @return whether to transpose before compressing
     */
    public static boolean transposeHeuristics(int i, MatrixBlock matrixBlock) {
        if (!matrixBlock.isInSparseFormat()) {
            return false;
        }
        if (matrixBlock.getNumColumns() > 10000 || matrixBlock.getNumRows() > 10000 || matrixBlock.getNonZeros() < 1000) {
            return true;
        }
        // NOTE(review): rows <= 10000 here, so the "> 500000" term can never hold and this always
        // returns false; the thresholds probably drifted apart - confirm the intended rule.
        return (matrixBlock.getNumRows() > 500000) && (i > matrixBlock.getNumColumns() / 30);
    }

    /** Phase 3: builds the column groups of {@link #res} from the co-coding plan. */
    private void compressPhase() {
        this.res.allocateColGroupList(ColGroupFactory.compressColGroups(this.mb, this.compressionGroups, this.compSettings, this.costEstimator, this.k));
        this._stats.compressedInitialSize = this.res.getInMemorySize();
        logPhase();
    }

    /**
     * Phase 4: cleans up {@link #res} and computes final statistics. Keeps the result when the
     * compression ratio is at least 1 or the dense ratio at least 100; otherwise logs the reason and
     * sets {@link #res} to {@code null} so the caller aborts.
     */
    private void finalizePhase() {
        this.res.cleanupBlock(true, true);
        this._stats.compressedSize = this.res.getInMemorySize();
        this._stats.compressedCost = this.costEstimator.getCost(this.res.getColGroups(), this.res.getNumRows());
        double ratio = this._stats.getRatio();
        double denseRatio = this._stats.getDenseRatio();
        this._stats.setColGroupsCounts(this.res.getColGroups());
        if (ratio >= 1.0d || denseRatio >= 100.0d) {
            if (this.compSettings.isInSparkInstruction) {
                this.res.clearSoftReferenceToDecompressed();
            }
            this.res.setNonZeros(this.mb.getNonZeros());
            logPhase();
            return;
        }
        LOG.info("--dense size:        " + this._stats.denseSize);
        LOG.info("--original size:     " + this._stats.originalSize);
        LOG.info("--compressed size:   " + this._stats.compressedSize);
        LOG.info("--compression ratio: " + ratio);
        LOG.debug("--col groups types   " + this._stats.getGroupsTypesString());
        LOG.debug("--col groups sizes   " + this._stats.getGroupsSizesString());
        logLengths();
        LOG.info("Abort block compression because compression ratio is less than 1.");
        this.res = null;
        setNextTimePhase(this.time.stop());
        DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), this.phase);
    }

    /**
     * Returns the caller's original input block with the statistics collected so far. The original
     * is used instead of {@link #mb}, which may have been transposed by {@link #transposePhase()}.
     */
    private Pair<MatrixBlock, CompressionStatistics> abortCompression() {
        LOG.warn("Compression aborted at phase: " + this.phase);
        return new ImmutablePair<>(this.inputBlock, this._stats);
    }

    /**
     * Records the current phase's time in {@link DMLCompressionStatistics}, emits phase-specific debug
     * output (only a stats summary in phase 4 inside Spark instructions) and advances {@link #phase}.
     */
    private void logPhase() {
        setNextTimePhase(this.time.stop());
        DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), this.phase);
        if (LOG.isDebugEnabled()) {
            if (!this.compSettings.isInSparkInstruction) {
                switch (this.phase) {
                    case 0:
                        LOG.debug("--Seed used for comp : " + this.compSettings.seed);
                        LOG.debug("--compression phase " + this.phase + " Classify  : " + getLastTimePhase());
                        LOG.debug("--Individual Columns Estimated Compression: " + this._stats.estimatedSizeCols);
                        if (this.mb instanceof CompressedMatrixBlock) {
                            LOG.debug("--Recompressing already compressed MatrixBlock");
                        }
                        break;
                    case 1:
                        LOG.debug("--compression phase " + this.phase + " Grouping  : " + getLastTimePhase());
                        LOG.debug("Grouping using: " + this.compSettings.columnPartitioner);
                        LOG.debug("Cost Calculated using: " + this.costEstimator);
                        LOG.debug("--Cocoded Columns estimated Compression:" + this._stats.estimatedSizeCoCoded);
                        if (this.compressionGroups.getInfo().size() >= 1000) {
                            LOG.debug("--CoCoded produce many columns but the first says:\n" + this.compressionGroups.getInfo().get(0));
                        } else {
                            LOG.debug("--Cocoded Columns estimated nr distinct:" + this.compressionGroups.getEstimatedDistinct());
                            LOG.debug("--Cocoded Columns nr columns           :" + this.compressionGroups.getNrColumnsString());
                        }
                        break;
                    case 2:
                        LOG.debug("--compression phase " + this.phase + " Transpose : " + getLastTimePhase());
                        LOG.debug("Did transpose: " + this.compSettings.transposed);
                        break;
                    case 3:
                        LOG.debug("--compression phase " + this.phase + " Compress  : " + getLastTimePhase());
                        LOG.debug("--compressed initial actual size:" + this._stats.compressedInitialSize);
                        break;
                    case 4:
                    default:
                        LOG.debug("--num col groups:    " + this.res.getColGroups().size());
                        LOG.debug("--compression phase  " + this.phase + " Cleanup   : " + getLastTimePhase());
                        LOG.debug("--col groups types   " + this._stats.getGroupsTypesString());
                        LOG.debug("--col groups sizes   " + this._stats.getGroupsSizesString());
                        LOG.debug("--input was compressed " + (this.mb instanceof CompressedMatrixBlock));
                        LOG.debug(String.format("--dense size:        %16d", Long.valueOf(this._stats.denseSize)));
                        LOG.debug(String.format("--sparse size:       %16d", Long.valueOf(this._stats.sparseSize)));
                        LOG.debug(String.format("--original size:     %16d", Long.valueOf(this._stats.originalSize)));
                        LOG.debug(String.format("--compressed size:   %16d", Long.valueOf(this._stats.compressedSize)));
                        LOG.debug(String.format("--compression ratio: %4.3f", Double.valueOf(this._stats.getRatio())));
                        LOG.debug(String.format("--Dense       ratio: %4.3f", Double.valueOf(this._stats.getDenseRatio())));
                        if (!(this.costEstimator instanceof MemoryCostEstimator)) {
                            LOG.debug(String.format("--original cost:     %5.2E", Double.valueOf(this._stats.originalCost)));
                            LOG.debug(String.format("--single col cost:   %5.2E", Double.valueOf(this._stats.estimatedCostCols)));
                            LOG.debug(String.format("--cocode cost:       %5.2E", Double.valueOf(this._stats.estimatedCostCoCoded)));
                            LOG.debug(String.format("--actual cost:       %5.2E", Double.valueOf(this._stats.compressedCost)));
                            LOG.debug(String.format("--relative cost:     %1.4f", Double.valueOf(this._stats.compressedCost / this._stats.originalCost)));
                        }
                        logLengths();
                        break;
                }
            } else if (this.phase == 4) {
                LOG.debug(this._stats);
            }
        }
        this.phase++;
    }

    /**
     * Logs dictionary sizes and column counts of the produced column groups at debug level (only for
     * plans with fewer than 1000 groups), and per-group details at trace level.
     */
    private void logLengths() {
        final List<AColGroup> groups = this.res.getColGroups();
        if (LOG.isDebugEnabled() && this.compressionGroups != null && this.compressionGroups.getInfo().size() < 1000) {
            int[] dictionarySizes = new int[groups.size()];
            int idx = 0;
            for (AColGroup g : groups) {
                dictionarySizes[idx++] = g.getNumValues();
            }
            LOG.debug("--compressed colGroup dictionary sizes: " + Arrays.toString(dictionarySizes));
            LOG.debug("--compressed colGroup nr columns      : " + constructNrColumnString(groups));
        }
        if (LOG.isTraceEnabled()) {
            for (AColGroup g : groups) {
                if (g.estimateInMemorySize() < 1000) {
                    LOG.trace(g);
                } else {
                    String numValues = g instanceof AColGroupValue ? "  numValues :" + ((AColGroupValue) g).getNumValues() : "";
                    LOG.trace("--colGroups type       : " + g.getClass().getSimpleName() + " size: " + g.estimateInMemorySize() + numValues + "  colIndexes : " + formatColIndexes(g.getColIndices()));
                }
            }
        }
    }

    /** Renders column indexes readably whether they are a plain {@code int[]} or an index object. */
    private static String formatColIndexes(Object colIndexes) {
        return colIndexes instanceof int[] ? Arrays.toString((int[]) colIndexes) : String.valueOf(colIndexes);
    }

    private void setNextTimePhase(double d) {
        this.lastPhase = d;
    }

    private double getLastTimePhase() {
        return this.lastPhase;
    }

    /**
     * Builds the result for an all-zero input: a compressed block with a single empty column group.
     * Jumps directly to phase 4 for logging.
     */
    private Pair<MatrixBlock, CompressionStatistics> createEmpty() {
        LOG.info("Empty input to compress, returning a compressed Matrix block with empty column group");
        this.res = new CompressedMatrixBlock(this.mb.getNumRows(), this.mb.getNumColumns());
        this.res.allocateColGroup(ColGroupEmpty.create(this.mb.getNumColumns()));
        this.res.setNonZeros(0L);
        this._stats.compressedSize = this.res.getInMemorySize();
        this._stats.compressedCost = this.costEstimator.getCost(this.res.getColGroups(), this.res.getNumRows());
        this._stats.setColGroupsCounts(this.res.getColGroups());
        this.phase = 4;
        logPhase();
        return new ImmutablePair<>(this.res, this._stats);
    }

    /** Formats the column count of every group as {@code [a, b, c]}; an empty list yields {@code []}. */
    private static String constructNrColumnString(List<AColGroup> list) {
        StringBuilder sb = new StringBuilder("[");
        for (int i = 0; i < list.size(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(list.get(i).getNumCols());
        }
        return sb.append(']').toString();
    }
}
