package org.apache.sysds.runtime.compress.cost;

import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.class */
/**
 * Cost estimator that scores compression plans against an expected workload. The workload is
 * described by the number of scans, decompressions, overlapping decompressions, left and right
 * matrix multiplications, compressed multiplications and dictionary operations.
 *
 * <p>Apart from a fixed per-column-group base cost, every term of the estimate is an instruction
 * count multiplied by a size-dependent factor. An instruction type with a count of zero therefore
 * contributes nothing.
 */
public class ComputationCostEstimator extends ACostEstimate {
    private static final long serialVersionUID = -1205636215389161815L;

    /**
     * Fraction of rows that the most frequent value must exceed before a column group is costed
     * as if only the remaining rows had to be processed.
     */
    private static final double COMMON_VALUE_FRACTION_THRESHOLD = 0.2;

    /** Column-group sparsity above this value is treated as fully dense. */
    private static final double DENSE_SPARSITY_CUTOFF = 0.4;

    /** Fixed cost added to every column group estimate. */
    private static final double BASE_COLUMN_GROUP_COST = 100;

    private final int _scans;
    private final int _decompressions;
    private final int _dictionaryOps;
    private final int _overlappingDecompressions;
    private final int _leftMultiplications;
    private final int _rightMultiplications;
    private final int _compressedMultiplication;
    private final boolean _isDensifying;

    /**
     * Creates an estimator from the instruction counts gathered in a counter.
     *
     * @param counter workload summary; its values are copied, so later changes to it are not seen
     */
    public ComputationCostEstimator(InstructionTypeCounter counter) {
        this(counter.scans, counter.decompressions, counter.overlappingDecompressions,
            counter.leftMultiplications, counter.rightMultiplications,
            counter.compressedMultiplications, counter.dictionaryOps, counter.isDensifying);
        if (LOG.isDebugEnabled()) {
            LOG.debug(this);
        }
    }

    /**
     * Creates an estimator from explicit instruction counts.
     *
     * @param scans                     number of scan operations
     * @param decompressions            number of decompressions
     * @param overlappingDecompressions number of overlapping decompressions
     * @param leftMultiplications       number of left matrix multiplications
     * @param rightMultiplications      number of right matrix multiplications
     * @param compressedMultiplications number of compressed multiplications
     * @param dictionaryOps             number of dictionary operations
     * @param isDensifying              whether the workload densifies the data; if so, sparsity
     *                                  is ignored and everything is costed as dense
     */
    public ComputationCostEstimator(int scans, int decompressions, int overlappingDecompressions,
        int leftMultiplications, int rightMultiplications, int compressedMultiplications,
        int dictionaryOps, boolean isDensifying) {
        this._scans = scans;
        this._decompressions = decompressions;
        this._overlappingDecompressions = overlappingDecompressions;
        this._leftMultiplications = leftMultiplications;
        this._rightMultiplications = rightMultiplications;
        this._compressedMultiplication = compressedMultiplications;
        this._dictionaryOps = dictionaryOps;
        this._isDensifying = isDensifying;
    }

    /**
     * Estimates the cost of a candidate column group from its compression statistics.
     *
     * <p>Empty groups get a near-zero sparsity unless the workload densifies. Constant groups, and
     * empty groups under a densifying workload, are costed as a single dense tuple. Groups whose
     * most frequent value covers more than {@code COMMON_VALUE_FRACTION_THRESHOLD} of the rows are
     * costed as if only the other rows are scanned.
     *
     * @param g size statistics of the candidate column group
     * @return the estimated cost, never negative
     */
    @Override
    public double getCostSafe(CompressedSizeInfoColGroup g) {
        final int nVals = g.getNumVals();
        final int nCols = g.getColumns().length;
        final int nRows = g.getNumRows();
        // Narrow groups and densifying workloads are costed as dense. The epsilon keeps the
        // sparsity strictly positive.
        final double sparsity = (nCols < 3 || _isDensifying) ? 1 : g.getTupleSparsity() + 1e-10;

        if (g.isEmpty() && !_isDensifying) {
            // Small, non-zero cost for empty groups.
            return getCost(nRows, 1, nCols, 1, 0.00001);
        }
        if (g.isEmpty() || g.isConst()) {
            return getCost(nRows, 1, nCols, 1, 1);
        }

        final int largestOff = g.getLargestOffInstances();
        // Compare the share of rows taken by the most frequent value, not its raw count, against
        // the threshold; comparing a count to 0.2 would accept almost every group.
        final double commonFraction = nRows > 0 ? (double) largestOff / nRows : 0;
        if (commonFraction > COMMON_VALUE_FRACTION_THRESHOLD) {
            return getCost(nRows, nRows - largestOff, nCols, nVals, sparsity);
        }
        return getCost(nRows, nRows, nCols, nVals, sparsity);
    }

    /**
     * Computes the workload cost of a column group with the given dimensions.
     *
     * @param nRows        total number of rows in the group
     * @param nRowsScanned number of rows that have to be processed (may be smaller than
     *                     {@code nRows} when a common value is skipped)
     * @param nCols        number of columns in the group
     * @param nVals        number of distinct tuples in the group
     * @param sparsity     tuple sparsity; values above 0.4, or any value when {@code nCols < 3},
     *                     are treated as 1
     * @return the estimated cost, including a fixed base cost of 100
     * @throws DMLCompressionException if the computed cost is negative
     */
    public double getCost(int nRows, int nRowsScanned, int nCols, int nVals, double sparsity) {
        final double s = (nCols < 3 || sparsity > DENSE_SPARSITY_CUTOFF) ? 1 : sparsity;

        double cost = 0;
        cost += leftMultCost(nRowsScanned, nRows, nCols, nVals, s);
        cost += scanCost(nRowsScanned, nCols, nVals, s);
        cost += dictionaryOpsCost(nVals, nCols, s);
        cost += rightMultCost(nCols, nVals, s);
        cost += decompressionCost(nCols, nRowsScanned, s);
        cost += overlappingDecompressionCost(nRowsScanned);
        cost += compressedMultiplicationCost(nRowsScanned, nRows, nVals, nCols, s);
        cost += BASE_COLUMN_GROUP_COST;

        if (cost < 0) {
            throw new DMLCompressionException("Invalid negative cost: " + cost);
        }
        return cost;
    }

    /**
     * Reports whether this estimator was configured for a densifying workload.
     *
     * @return the {@code isDensifying} flag given at construction
     */
    public boolean isDense() {
        return _isDensifying;
    }

    /**
     * Estimates the workload cost of keeping a matrix uncompressed.
     *
     * <p>Unlike the column-group estimate, this includes no decompression terms and no base cost.
     *
     * @param mb the uncompressed matrix
     * @return the estimated cost
     * @throws DMLCompressionException if the computed cost is negative
     */
    @Override
    public double getCost(MatrixBlock mb) {
        final double nCols = mb.getNumColumns();
        final double nRows = mb.getNumRows();
        final double sparsity = (nCols < 3 || _isDensifying) ? 1 : mb.getSparsity();
        final double nonZeros = nRows * nCols * sparsity;

        double cost = 0;
        cost += dictionaryOpsCost(nRows, nCols, sparsity);
        cost += leftMultCost(0, nonZeros + nCols);
        cost += rightMultCost(nonZeros, nRows * nCols);
        cost += scanCost(0, nRows, nCols, sparsity);
        cost += compressedMultiplicationCost(0, 0, nRows, nCols, sparsity);

        if (cost < 0) {
            throw new DMLCompressionException("Invalid negative cost : " + cost);
        }
        return cost;
    }

    /**
     * Delegates the estimate to the column group itself.
     *
     * @param colGroup the column group to cost
     * @param nRows    forwarded unchanged to {@code AColGroup.getCost}; presumably the row count
     *                 of the matrix (TODO confirm)
     * @return the group's estimated cost under this estimator
     */
    @Override
    public double getCost(AColGroup colGroup, int nRows) {
        return colGroup.getCost(this, nRows);
    }

    /**
     * Reports whether sparse representations should be preferred, which is the case when the
     * workload contains any kind of matrix multiplication.
     *
     * @return {@code true} if any left, right or compressed multiplications are expected
     */
    @Override
    public boolean shouldSparsify() {
        return _leftMultiplications > 0 || _compressedMultiplication > 0 || _rightMultiplications > 0;
    }

    /** Cost of dictionary operations over {@code nTuples} tuples of {@code nCols} columns. */
    private double dictionaryOpsCost(double nTuples, double nCols, double sparsity) {
        return _dictionaryOps * sparsity * nTuples * nCols * 2;
    }

    /**
     * Left-multiplication cost of a column group. It has a row-dependent part, which is at least
     * a tenth of all rows, and a dictionary-dependent part.
     */
    private double leftMultCost(double nRowsScanned, double nRows, double nCols, double nVals,
        double sparsity) {
        final double preScaling = Math.max(nRowsScanned, nRows / 10) + nVals * 2;
        final double postScaling = sparsity * nVals * nCols;
        return leftMultCost(preScaling, postScaling);
    }

    /** Left-multiplication cost from its two parts, weighted by the instruction count. */
    private double leftMultCost(double preScaling, double postScaling) {
        return _leftMultiplications * (preScaling + postScaling);
    }

    /** Right-multiplication cost of a column group with {@code nVals} distinct tuples. */
    private double rightMultCost(double nCols, double nVals, double sparsity) {
        return rightMultCost(sparsity * nVals * nCols, nCols);
    }

    /** Right-multiplication cost from its two parts, weighted by the instruction count. */
    private double rightMultCost(double dictionaryPart, double outputPart) {
        return _rightMultiplications * (dictionaryPart + outputPart);
    }

    /** Cost of decompressing {@code nRowsScanned} rows of {@code nCols} columns. */
    private double decompressionCost(double nCols, double nRowsScanned, double sparsity) {
        return _decompressions * nCols * nRowsScanned * sparsity;
    }

    /** Cost of overlapping decompression, linear in the rows processed. */
    private double overlappingDecompressionCost(double nRowsScanned) {
        return _overlappingDecompressions * nRowsScanned;
    }

    /** Scan cost: rows processed plus the {@code dim1 x dim2} cells scaled by sparsity. */
    private double scanCost(double nRowsScanned, double dim1, double dim2, double sparsity) {
        return _scans * (nRowsScanned + (dim1 * dim2 * sparsity));
    }

    /** Compressed-multiplication cost, with a row part of at least a tenth of all rows. */
    private double compressedMultiplicationCost(double nRowsScanned, double nRows, double dim1,
        double dim2, double sparsity) {
        return _compressedMultiplication * (Math.max(nRowsScanned, nRows / 10) + (dim1 * dim2 * sparsity));
    }

    @Override
    public String toString() {
        return super.toString() + " --- CostVector:[" + _scans + "," + _decompressions + ","
            + _overlappingDecompressions + "," + _leftMultiplications + "," + _rightMultiplications
            + "," + _compressedMultiplication + "," + _dictionaryOps + "]" + " Densifying:"
            + _isDensifying;
    }
}
