package org.apache.sysds.runtime.controlprogram.parfor.opt;

import java.util.ArrayList;
import java.util.Collection;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.runtime.controlprogram.parfor.opt.OptNode;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimator.class */
/**
 * Base class for cost estimators that aggregate per-leaf estimates over a tree
 * of {@link OptNode}s.
 *
 * <p>Subclasses supply the estimates for leaf nodes via the two
 * {@code getLeafNodeEstimate} methods; this class defines how the estimates of
 * child nodes are combined for inner nodes, depending on the node type and on
 * the requested {@link TestMeasure}.
 *
 * <p>The {@code getEstimate} overloads that take extra options store them in
 * instance fields for the duration of the call, so concurrent calls on the
 * same instance can interfere with each other.
 */
public abstract class CostEstimator {
    protected static final Log LOG = LogFactory.getLog(CostEstimator.class.getName());

    /** Default degree of parallelism assumed by estimators (not used in this class). */
    public static final double DEFAULT_EST_PARALLELISM = 1.0d;

    /** Iteration count assumed for loops whose number of iterations is not known. */
    public static final long FACTOR_NUM_ITERATIONS = 10;

    /** Default estimate returned for {@link TestMeasure#EXEC_TIME}. */
    public static final double DEFAULT_TIME_ESTIMATE = 5.0d;

    /** Default estimate returned for {@link TestMeasure#MEMORY_USAGE}. */
    public static final double DEFAULT_MEM_ESTIMATE_CP = 1024.0d;

    /** Default memory estimate for Spark (20 * 1024 * 1024); not used in this class. */
    public static final double DEFAULT_MEM_ESTIMATE_SP = 20d * 1024 * 1024;

    /** Value reported when no aggregation rule applies to a node or measure. */
    private static final double NO_ESTIMATE = -1.0d;

    /**
     * When {@code true}, a leaf's {@code DATA_PARTITION_COND_MEM} parameter (if
     * present) is used instead of calling {@code getLeafNodeEstimate}.
     */
    protected boolean _inclCondPart = false;

    /** Variable names handed to the current estimation call; not read in this class. */
    protected Collection<String> _exclVars = null;

    /** Exclusion mode of the current estimation call; not read in this class. */
    protected ExcludeType _exclType = ExcludeType.NONE;

    /** Data representation kinds; not referenced in this class. */
    public enum DataFormat {
        DENSE,
        SPARSE
    }

    /** Exclusion modes that can accompany a set of excluded variable names. */
    public enum ExcludeType {
        NONE,
        SHARED_READ,
        RESULT_LIX
    }

    /** The quantity an estimate is requested for. */
    public enum TestMeasure {
        EXEC_TIME,
        MEMORY_USAGE
    }

    /**
     * Estimates the given measure for a leaf node using the node's own
     * execution type.
     *
     * @param testMeasure the measure to estimate
     * @param optNode the leaf node
     * @return the estimate for the leaf
     */
    public abstract double getLeafNodeEstimate(TestMeasure testMeasure, OptNode optNode);

    /**
     * Estimates the given measure for a leaf node under a forced execution type.
     *
     * @param testMeasure the measure to estimate
     * @param optNode the leaf node
     * @param execType the execution type to assume for the leaf
     * @return the estimate for the leaf
     */
    public abstract double getLeafNodeEstimate(TestMeasure testMeasure, OptNode optNode, LopProperties.ExecType execType);

    /**
     * Estimates the given measure for the subtree rooted at {@code optNode}
     * without forcing an execution type.
     *
     * @param testMeasure the measure to estimate
     * @param optNode root of the subtree
     * @return the aggregated estimate, or {@code -1} if no rule applies
     */
    public double getEstimate(TestMeasure testMeasure, OptNode optNode) {
        return getEstimate(testMeasure, optNode, (LopProperties.ExecType) null);
    }

    /**
     * Estimates the given measure for the subtree rooted at {@code optNode},
     * optionally taking conditional-partitioning memory of leaves into account.
     *
     * @param testMeasure the measure to estimate
     * @param optNode root of the subtree
     * @param inclCondPart whether leaves with a {@code DATA_PARTITION_COND_MEM}
     *        parameter report that value instead of their regular estimate
     * @return the aggregated estimate, or {@code -1} if no rule applies
     */
    public double getEstimate(TestMeasure testMeasure, OptNode optNode, boolean inclCondPart) {
        this._inclCondPart = inclCondPart;
        try {
            return getEstimate(testMeasure, optNode, (LopProperties.ExecType) null);
        } finally {
            // Reset even if estimation throws, so later calls are not affected.
            this._inclCondPart = false;
        }
    }

    /**
     * Estimates the given measure for the subtree rooted at {@code optNode}
     * with conditional-partitioning and variable-exclusion options. The options
     * are exposed to subclasses through the protected fields for the duration
     * of the call and reset afterwards.
     *
     * @param testMeasure the measure to estimate
     * @param optNode root of the subtree
     * @param inclCondPart whether leaves with a {@code DATA_PARTITION_COND_MEM}
     *        parameter report that value instead of their regular estimate
     * @param exclVars variable names made available as {@code _exclVars}
     * @param excludeType exclusion mode made available as {@code _exclType}
     * @return the aggregated estimate, or {@code -1} if no rule applies
     */
    public double getEstimate(TestMeasure testMeasure, OptNode optNode, boolean inclCondPart, Collection<String> exclVars, ExcludeType excludeType) {
        this._inclCondPart = inclCondPart;
        this._exclVars = exclVars;
        this._exclType = excludeType;
        try {
            return getEstimate(testMeasure, optNode, (LopProperties.ExecType) null);
        } finally {
            // Reset even if estimation throws, so later calls are not affected.
            this._inclCondPart = false;
            this._exclVars = null;
            this._exclType = ExcludeType.NONE;
        }
    }

    /**
     * Estimates the given measure for the subtree rooted at {@code optNode}.
     *
     * <p>Leaves are delegated to {@code getLeafNodeEstimate} (or to their
     * {@code DATA_PARTITION_COND_MEM} parameter when conditional partitioning
     * is included); inner nodes aggregate the estimates of their children
     * according to node type and measure.
     *
     * @param testMeasure the measure to estimate
     * @param optNode root of the subtree
     * @param execType execution type to force on leaves, or {@code null} to use
     *        each leaf's default
     * @return the aggregated estimate, or {@code -1} if no rule applies
     */
    public double getEstimate(TestMeasure testMeasure, OptNode optNode, LopProperties.ExecType execType) {
        if (optNode.isLeaf()) {
            return getLeafEstimate(testMeasure, optNode, execType);
        }
        switch (testMeasure) {
            case EXEC_TIME:
                return getExecTimeEstimate(testMeasure, optNode, execType);
            case MEMORY_USAGE:
                return getMemoryUsageEstimate(testMeasure, optNode, execType);
            default:
                return NO_ESTIMATE;
        }
    }

    /** Resolves the estimate of a leaf node, honoring {@code _inclCondPart} and a forced exec type. */
    private double getLeafEstimate(TestMeasure testMeasure, OptNode optNode, LopProperties.ExecType execType) {
        if (this._inclCondPart) {
            String condMem = optNode.getParam(OptNode.ParamType.DATA_PARTITION_COND_MEM);
            if (condMem != null) {
                return Double.parseDouble(condMem);
            }
        }
        if (execType != null) {
            return getLeafNodeEstimate(testMeasure, optNode, execType);
        }
        return getLeafNodeEstimate(testMeasure, optNode);
    }

    /** Aggregates the execution-time estimates of the children of an inner node. */
    private double getExecTimeEstimate(TestMeasure testMeasure, OptNode optNode, LopProperties.ExecType execType) {
        switch (optNode.getNodeType()) {
            case GENERIC:
            case FUNCCALL:
                return getSumEstimate(testMeasure, optNode.getChilds(), execType);
            case IF:
                // Exactly two children are averaged; any other arity takes the maximum.
                if (optNode.getChilds().size() == 2) {
                    return getWeightedEstimate(testMeasure, optNode.getChilds(), execType);
                }
                return getMaxEstimate(testMeasure, optNode.getChilds(), execType);
            case WHILE:
                return FACTOR_NUM_ITERATIONS * getSumEstimate(testMeasure, optNode.getChilds(), execType);
            case FOR:
                return getNumIterations(optNode) * getSumEstimate(testMeasure, optNode.getChilds(), execType);
            case PARFOR:
                // Loop body cost is divided by the node's k, clamped to at least 1.
                return (getNumIterations(optNode) * getSumEstimate(testMeasure, optNode.getChilds(), execType))
                        / Math.max(optNode.getK(), 1);
            default:
                return NO_ESTIMATE;
        }
    }

    /** Aggregates the memory estimates of the children of an inner node. */
    private double getMemoryUsageEstimate(TestMeasure testMeasure, OptNode optNode, LopProperties.ExecType execType) {
        switch (optNode.getNodeType()) {
            case GENERIC:
            case FUNCCALL:
            case IF:
            case WHILE:
            case FOR:
                return getMaxEstimate(testMeasure, optNode.getChilds(), execType);
            case PARFOR:
                if (optNode.getExecType() == OptNode.ExecType.SPARK) {
                    return getMaxEstimate(testMeasure, optNode.getChilds(), execType);
                }
                if (optNode.getExecType() == OptNode.ExecType.CP || optNode.getExecType() == null) {
                    // The child maximum is multiplied by the node's k, clamped to at least 1.
                    return getMaxEstimate(testMeasure, optNode.getChilds(), execType) * Math.max(optNode.getK(), 1);
                }
                // Any other execution type has no aggregation rule.
                return NO_ESTIMATE;
            default:
                return NO_ESTIMATE;
        }
    }

    /**
     * Reads the node's {@code NUM_ITERATIONS} parameter, falling back to
     * {@link #FACTOR_NUM_ITERATIONS} when the parameter is absent.
     */
    private static double getNumIterations(OptNode optNode) {
        String numIterations = optNode.getParam(OptNode.ParamType.NUM_ITERATIONS);
        return (numIterations != null) ? (double) Long.parseLong(numIterations) : FACTOR_NUM_ITERATIONS;
    }

    /**
     * Returns the default estimate for a measure.
     *
     * @param testMeasure the measure
     * @return {@link #DEFAULT_TIME_ESTIMATE} for execution time,
     *         {@link #DEFAULT_MEM_ESTIMATE_CP} for memory usage, otherwise {@code -1}
     */
    protected double getDefaultEstimate(TestMeasure testMeasure) {
        switch (testMeasure) {
            case EXEC_TIME:
                return DEFAULT_TIME_ESTIMATE;
            case MEMORY_USAGE:
                return DEFAULT_MEM_ESTIMATE_CP;
            default:
                return NO_ESTIMATE;
        }
    }

    /**
     * Returns the largest estimate among the given nodes.
     *
     * @param testMeasure the measure to estimate
     * @param nodes the nodes to evaluate
     * @param execType execution type to force on leaves, or {@code null}
     * @return the maximum estimate, or {@link Double#NEGATIVE_INFINITY} for an empty list
     */
    protected double getMaxEstimate(TestMeasure testMeasure, ArrayList<OptNode> nodes, LopProperties.ExecType execType) {
        return nodes.stream()
                .mapToDouble(node -> getEstimate(testMeasure, node, execType))
                .max()
                .orElse(Double.NEGATIVE_INFINITY);
    }

    /**
     * Returns the sum of the estimates of the given nodes.
     *
     * @param testMeasure the measure to estimate
     * @param nodes the nodes to evaluate
     * @param execType execution type to force on leaves, or {@code null}
     * @return the summed estimate ({@code 0} for an empty list)
     */
    protected double getSumEstimate(TestMeasure testMeasure, ArrayList<OptNode> nodes, LopProperties.ExecType execType) {
        return nodes.stream()
                .mapToDouble(node -> getEstimate(testMeasure, node, execType))
                .sum();
    }

    /**
     * Returns the arithmetic mean of the estimates of the given nodes, i.e.
     * every node carries the same weight.
     *
     * @param testMeasure the measure to estimate
     * @param nodes the nodes to evaluate
     * @param execType execution type to force on leaves, or {@code null}
     * @return the mean estimate ({@code NaN} for an empty list)
     */
    protected double getWeightedEstimate(TestMeasure testMeasure, ArrayList<OptNode> nodes, LopProperties.ExecType execType) {
        double total = nodes.stream()
                .mapToDouble(node -> getEstimate(testMeasure, node, execType))
                .sum();
        return total / nodes.size();
    }
}
