package org.apache.sysds.runtime.controlprogram.parfor.opt;

import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.parfor.opt.CostEstimator;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptNode;
import org.apache.sysds.runtime.controlprogram.parfor.opt.Optimizer;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimatorHops.class */
/**
 * Static memory-metric cost model for the parfor optimizer that derives leaf-node estimates
 * from the memory estimates of the HOPs mapped to each {@link OptNode}.
 *
 * <p>Only {@link CostEstimator.TestMeasure#MEMORY_USAGE} is supported; any other measure
 * raises a {@link DMLRuntimeException}.
 */
public class CostEstimatorHops extends CostEstimator {
    /**
     * Memory estimate (20 * 1024 * 1024, i.e. 20 MiB) assigned to operations that are not
     * accounted as local CP operations.
     */
    public static final double DEFAULT_MEM_SP = 20 * 1024 * 1024;

    /** Fallback estimate (1024) used when a hop yields no positive memory estimate. */
    private static final double DEFAULT_MEM_FALLBACK = 1024;

    /** Mapping used to resolve opt-tree node IDs to their HOPs. */
    private final OptTreePlanMappingAbstract _map;

    /**
     * Creates a HOP-based cost estimator.
     *
     * @param optTreePlanMappingAbstract mapping used to resolve opt nodes to their HOPs
     */
    public CostEstimatorHops(OptTreePlanMappingAbstract optTreePlanMappingAbstract) {
        _map = optTreePlanMappingAbstract;
    }

    /**
     * Estimates the memory usage of a leaf node based on its mapped HOP, adjusting the raw HOP
     * estimate for remote (Spark) execution, invalid CP estimates, missing estimates and
     * excluded variables.
     *
     * @param testMeasure the measure to estimate; must be {@code MEMORY_USAGE}
     * @param optNode the leaf node to estimate
     * @return the memory estimate, or {@code 0} for nodes that are not HOP nodes
     * @throws DMLRuntimeException if {@code testMeasure} is not {@code MEMORY_USAGE}
     */
    @Override
    public double getLeafNodeEstimate(CostEstimator.TestMeasure testMeasure, OptNode optNode) {
        // opt nodes without an associated HOP contribute no cost
        if (optNode.getNodeType() != OptNode.NodeType.HOP) {
            return 0;
        }
        checkSupportedMeasure(testMeasure);

        Hop hop = _map.getMappedHop(optNode.getID());
        // under SHARED_READ exclusion the hop's input/output size w.r.t. the excluded
        // variables is used instead of its full memory estimate
        double value = (_exclVars != null && _exclType == CostEstimator.ExcludeType.SHARED_READ)
            ? hop.getInputOutputSize(_exclVars)
            : hop.getMemEstimate();

        // remote default only applies when Spark execution is enabled
        double remoteMem = OptimizerUtils.isSparkExecutionMode() ? DEFAULT_MEM_SP : 0;
        // forced execution suppresses the warnings below
        boolean forcedExec = DMLScript.getGlobalExecMode() == Types.ExecMode.SINGLE_NODE
            || hop.getForcedExecType() != null;

        if (value >= remoteMem) {
            LopProperties.ExecType execType = hop.getExecType();
            if (execType == LopProperties.ExecType.SPARK) {
                // Spark operations: remote default plus the size of broadcast inputs
                value = remoteMem + hop.getSpBroadcastSize();
            } else if (execType == LopProperties.ExecType.CP
                && value >= OptimizerUtils.getLocalMemBudget()) {
                // CP operation whose estimate exceeds the local budget is treated as remote
                if (!forcedExec) {
                    LOG.warn("Memory estimate larger than budget but CP exec type (op=" + hop.getOpString() + ", name=" + hop.getName() + ", memest=" + hop.getMemEstimate() + ").");
                }
                value = remoteMem;
            } else if (execType == null) {
                // no exec type assigned to this hop
                value = remoteMem;
            }
        }

        // a forced Spark execution type always maps to the remote default
        if (hop.getForcedExecType() == LopProperties.ExecType.SPARK) {
            value = remoteMem;
        }

        // NOTE(review): under forced execution a non-positive estimate is returned unchanged,
        // whereas the forced-exec-type overload always applies the fallback — confirm intended.
        if (value <= 0 && !forcedExec) {
            LOG.warn("Cannot get memory estimate for hop (op=" + hop.getOpString() + ", name=" + hop.getName() + ", memest=" + hop.getMemEstimate() + ").");
            value = DEFAULT_MEM_FALLBACK;
        }

        // under RESULT_LIX exclusion, left-indexing into an excluded variable is not accounted
        if (_exclVars != null && _exclType == CostEstimator.ExcludeType.RESULT_LIX
            && hop instanceof LeftIndexingOp && _exclVars.contains(hop.getName())) {
            value = 0;
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Memory estimate " + hop.getName() + ", " + hop.getOpString() + "(" + optNode.getExecType() + ")=" + OptimizerRuleBased.toMB(value));
        }
        return value;
    }

    /**
     * Estimates the memory usage of a leaf node under a given execution type: CP nodes use
     * the HOP's memory estimate, every other execution type uses {@link #DEFAULT_MEM_SP}.
     *
     * @param testMeasure the measure to estimate; must be {@code MEMORY_USAGE}
     * @param optNode the leaf node to estimate
     * @param execType the execution type to assume for the node
     * @return the memory estimate, or {@code 0} for nodes that are not HOP nodes
     * @throws DMLRuntimeException if {@code testMeasure} is not {@code MEMORY_USAGE}
     */
    @Override
    public double getLeafNodeEstimate(CostEstimator.TestMeasure testMeasure, OptNode optNode,
            LopProperties.ExecType execType) {
        // opt nodes without an associated HOP contribute no cost
        if (optNode.getNodeType() != OptNode.NodeType.HOP) {
            return 0;
        }
        checkSupportedMeasure(testMeasure);

        Hop hop = _map.getMappedHop(optNode.getID());
        double value = (execType == LopProperties.ExecType.CP)
            ? hop.getMemEstimate()
            : DEFAULT_MEM_SP;
        if (value <= 0) {
            value = DEFAULT_MEM_FALLBACK;
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Memory estimate (forced exec type) " + hop.getName() + ", " + hop.getOpString() + "(" + optNode.getExecType() + ")=" + OptimizerRuleBased.toMB(value));
        }
        return value;
    }

    /**
     * Rejects every measure other than {@code MEMORY_USAGE}, the only one this cost model supports.
     *
     * @param testMeasure the requested measure
     * @throws DMLRuntimeException if {@code testMeasure} is not {@code MEMORY_USAGE}
     */
    private static void checkSupportedMeasure(CostEstimator.TestMeasure testMeasure) {
        if (testMeasure != CostEstimator.TestMeasure.MEMORY_USAGE) {
            throw new DMLRuntimeException("Testmeasure " + testMeasure + " not supported by cost model " + Optimizer.CostModelType.STATIC_MEM_METRIC + ".");
        }
    }
}
