package org.apache.sysds.runtime.lineage;

import java.util.Arrays;
import java.util.Comparator;
import java.util.stream.Stream;
import org.antlr.v4.runtime.tree.xpath.XPath;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.lops.Append;
import org.apache.sysds.lops.AppendG;
import org.apache.sysds.lops.AppendGAlignedSP;
import org.apache.sysds.lops.AppendR;
import org.apache.sysds.lops.LeftIndex;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.MapMultChain;
import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.lops.TernaryAggregate;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.BinaryMatrixMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryScalarScalarCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DataGenCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ListIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
import org.apache.sysds.runtime.instructions.fed.ComputationFEDInstruction;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysds.runtime.instructions.spark.CpmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.util.ProgramConverter;

/**
 * Static configuration for lineage-based reuse: the opcodes eligible for reuse and for
 * RDD persistence, the reuse cache type, the eviction policy with its score weights and
 * comparators, and assorted cache tuning flags and constants.
 */
public class LineageCacheConfig {
    // Opcodes currently considered reusable; reset to OPCODES by resetReusableOpcodes().
    private static String[] REUSE_OPCODES;
    // Union of PERSIST_OPCODES1 and PERSIST_OPCODES2; built in the static initializer below.
    private static String[] CHKPOINT_OPCODES;
    protected static final double CPU_CACHE_FRAC = 0.05d;
    public static final double MIN_SPILL_TIME_ESTIMATE = 10.0d;
    public static final double MIN_SPILL_DATA = 2.0d;

    /** Default set of opcodes whose outputs are eligible for lineage-based reuse. */
    private static final String[] OPCODES = {
        "tsmm", "ba+*", "*", "/", "+", "||", "nrow", "ncol", "round", "exp", "log",
        RightIndex.OPCODE, LeftIndex.OPCODE, "groupedagg", "r'", "solve", "spoof", "isna",
        "uamean", "max", "min", "ifelse", "-", "sqrt", "<", ">", "uak+", "<=", "^",
        "uamax", "uark+", "uacmean", "eigen", "ctable", "ctableexpand", "replace",
        "^2", "*2", "uack+", TernaryAggregate.OPCODE_RC, "uacsqk+", "n+", "uarimax",
        SortKeys.OPCODE, PickByCount.OPCODE, "transformapply", "uarmax", "-*",
        "castdtm", "lowertri", "1-*", "prefetch", MapMult.OPCODE, "contains",
        MapMultChain.OPCODE_CP, MapMultChain.OPCODE, "+*", "==", "rmempty",
        "conv2d_bias_add", "relu_maxpooling", "maxpooling", "batch_norm2d",
        "avgpooling", "softmax"
    };

    /** Spark opcodes eligible for RDD persistence; also reported by {@link #isShuffleOp(String)}. */
    private static final String[] PERSIST_OPCODES1 = {
        "cpmm", "rmm", PMMJ.OPCODE, "zipmm", "rev", "rshape", "rsort", "-", "*", "+",
        "/", "%%", "%/%", "1-*", "^", "^2", "*2", "==", "!=", "<", ">", "<=", ">=",
        "&&", "||", "xor", "max", "min", "rmempty", AppendR.OPCODE, AppendG.OPCODE,
        AppendGAlignedSP.OPCODE, "rbind", "cbind", "nmin", "nmax", "n+", "ctable",
        "ucumack+", "ucumac*", "ucumacmin", "ucumacmax", SortKeys.OPCODE, PickByCount.OPCODE
    };

    /** Additional Spark opcodes eligible for RDD persistence (not counted as shuffle ops). */
    private static final String[] PERSIST_OPCODES2 = {MapMult.OPCODE, "isna", LeftIndex.OPCODE};

    /** GPU opcodes placed in the "heavy" compute group (see {@link #getComputeGroup(String)}). */
    private static final String[] GPU_OPCODE_HEAVY = {
        "conv2d_bias_add", "relu_maxpooling", "maxpooling", "batch_norm2d", "avgpooling"
    };

    private static ReuseCacheType _cacheType = null;
    private static CachedItemHead _itemH = null;
    private static CachedItemTail _itemT = null;
    private static boolean _compilerAssistedRW = false;
    private static boolean _onlyEstimate = false;
    private static boolean _reuseLineageTraces = true;
    private static boolean DELAYED_CACHING = false;
    private static boolean DELAYED_CACHING_GPU = true;
    private static boolean DELAYED_CACHING_RDD = true;

    // NOTE(review): these look like estimated I/O / transfer bandwidths used by cost
    // models elsewhere; their units are not visible in this file -- confirm before tuning.
    public static double FSREAD_DENSE = 500.0d;
    public static double FSREAD_SPARSE = 400.0d;
    public static double FSWRITE_DENSE = 450.0d;
    public static double FSWRITE_SPARSE = 225.0d;
    public static double D2HCOPYBANDWIDTH = 1500.0d;
    public static double D2HMAXBANDWIDTH = 8192.0d;

    private static LineageCachePolicy _cachepolicy = null;

    /**
     * Eviction score weights: index 0 = cost-and-size, 1 = time-based (LRU),
     * 2 = DAG height. Overwritten by {@link #setCachePolicy(LineageCachePolicy)}.
     */
    protected static double[] WEIGHTS = {1.0d, 0.0d, 0.0d};
    public static boolean GPU2HOSTEVICTION = false;

    /**
     * Orders cache entries by ascending score. Ties are broken by a policy-dependent
     * secondary key, and finally by lineage-item ID so distinct entries never compare equal.
     */
    protected static Comparator<LineageCacheEntry> LineageCacheComparator = (e1, e2) -> {
        if (e1.score != e2.score) {
            return e1.score < e2.score ? -1 : 1;
        }
        switch (_cachepolicy) {
            case LRU:
            case DAGHEIGHT:
                // same score: order by cost-and-size ratio, then by ID
                return compareWithIdTieBreak(e1.getCostNsize(), e2.getCostNsize(), e1, e2);
            case COSTNSIZE:
                // same score: order by timestamp, then by ID
                return compareWithIdTieBreak(e1.getTimestamp(), e2.getTimestamp(), e1, e2);
            default:
                // defensive: an unhandled policy must still yield a total order
                return Long.compare(e1._key.getId(), e2._key.getId());
        }
    };

    /**
     * Orders GPU cache entries: entries with the same lineage-item ID are equal; otherwise
     * by ascending score, ties broken by ID.
     */
    protected static Comparator<LineageCacheEntry> LineageGPUCacheComparator = (e1, e2) -> {
        long id1 = e1._key.getId();
        long id2 = e2._key.getId();
        if (id1 == id2) {
            return 0;
        }
        if (e1.score == e2.score) {
            return Long.compare(id1, id2);
        }
        return e1.score < e2.score ? -1 : 1;
    };

    protected static boolean ENABLE_LOCAL_ONLY_RDD_CACHING = false;

    /** Selects which head operation is cached (used by {@link #setConfig(ReuseCacheType, CachedItemHead, CachedItemTail)}). */
    public enum CachedItemHead {
        TSMM,
        ALL
    }

    /** Selects which tail operation is cached (used by {@link #setConfig(ReuseCacheType, CachedItemHead, CachedItemTail)}). */
    public enum CachedItemTail {
        CBIND,
        RBIND,
        INDEX,
        ALL
    }

    /** Eviction policies; each maps to a weight vector in {@link #setCachePolicy(LineageCachePolicy)}. */
    public enum LineageCachePolicy {
        LRU,
        COSTNSIZE,
        DAGHEIGHT
    }

    /** Lifecycle states of a lineage cache entry. */
    public enum LineageCacheStatus {
        EMPTY,
        NOTCACHED,
        TOCACHE,
        CACHED,
        SPILLED,
        RELOADED,
        PINNED,
        TOCACHEGPU,
        GPUCACHED,
        PERSISTEDRDD,
        TOPERSISTRDD,
        TOSPILL,
        TODELETE;

        /** @return true only for {@code CACHED} and {@code RELOADED} entries */
        public boolean canEvict() {
            return this == CACHED || this == RELOADED;
        }
    }

    /** Reuse modes; {@code REUSE_HYBRID} combines full, partial and multi-level reuse. */
    public enum ReuseCacheType {
        REUSE_FULL,
        REUSE_PARTIAL,
        REUSE_MULTILEVEL,
        REUSE_HYBRID,
        NONE;

        /** @return true for REUSE_FULL, REUSE_MULTILEVEL and REUSE_HYBRID */
        public boolean isFullReuse() {
            return this == REUSE_FULL || this == REUSE_MULTILEVEL || this == REUSE_HYBRID;
        }

        /** @return true for REUSE_PARTIAL and REUSE_HYBRID */
        public boolean isPartialReuse() {
            return this == REUSE_PARTIAL || this == REUSE_HYBRID;
        }

        /** @return true for REUSE_MULTILEVEL and REUSE_HYBRID */
        public boolean isMultilevelReuse() {
            return this == REUSE_MULTILEVEL || this == REUSE_HYBRID;
        }

        /** @return true if the globally configured reuse mode is unset or NONE */
        public static boolean isNone() {
            return DMLScript.LINEAGE_REUSE == null || DMLScript.LINEAGE_REUSE == NONE;
        }
    }

    /** Replaces the set of reusable opcodes (the array is stored, not copied). */
    public static void setReusableOpcodes(String... strArr) {
        REUSE_OPCODES = strArr;
    }

    /** @return the current reusable-opcode array (internal reference, not a copy) */
    public static String[] getReusableOpcodes() {
        return REUSE_OPCODES;
    }

    /** Restores the default reusable opcodes. */
    public static void resetReusableOpcodes() {
        REUSE_OPCODES = OPCODES;
    }

    /**
     * Decides whether the output of an instruction is a candidate for lineage-based reuse.
     *
     * @param instruction the instruction about to be executed
     * @param executionContext context used to inspect input matrices
     * @return true if the instruction type and opcode qualify and it does not update in place
     */
    public static boolean isReusable(Instruction instruction, ExecutionContext executionContext) {
        // Only computation instructions (CP, FED, GPU, SP) qualify, excluding
        // list indexing and scalar-scalar binary operations.
        boolean supportedType = ((instruction instanceof ComputationCPInstruction)
                || (instruction instanceof ComputationFEDInstruction)
                || (instruction instanceof GPUInstruction)
                || (instruction instanceof ComputationSPInstruction))
            && !(instruction instanceof ListIndexingCPInstruction)
            && !(instruction instanceof BinaryScalarScalarCPInstruction);
        if (!supportedType) {
            return false;
        }

        String opcode = instruction.getOpcode();
        boolean reusableOp = ArrayUtils.contains(REUSE_OPCODES, opcode)
            || (opcode.equals(Append.OPCODE) && isVectorAppend(instruction, executionContext))
            || opcode.startsWith("spoof")
            || ((instruction instanceof DataGenCPInstruction)
                && ((DataGenCPInstruction) instruction).isMatrixCall())
            || isReusableRDDType(instruction);
        if (!reusableOp) {
            return false;
        }

        // In-place updates mutate their input, so they are never reused.
        boolean updateInPlace = ((instruction instanceof MatrixIndexingCPInstruction)
                && executionContext.getMatrixObject(((ComputationCPInstruction) instruction).input1)
                    .getUpdateType().isInPlace())
            || ((instruction instanceof BinaryMatrixMatrixCPInstruction)
                && ((BinaryMatrixMatrixCPInstruction) instruction).isInPlace());
        return !updateInPlace;
    }

    /**
     * Checks whether an append has two matrix inputs, at least one of which has exactly
     * one column. Callers must pass a CP, FED, SP or GPU instruction; anything else
     * falls through to the GPU cast and fails with a ClassCastException.
     */
    private static boolean isVectorAppend(Instruction instruction, ExecutionContext executionContext) {
        if (instruction instanceof ComputationFEDInstruction) {
            ComputationFEDInstruction fedInst = (ComputationFEDInstruction) instruction;
            if (fedInst.input1.isMatrix() && fedInst.input2.isMatrix()) {
                return executionContext.getMatrixObject(fedInst.input1).getNumColumns() == 1
                    || executionContext.getMatrixObject(fedInst.input2).getNumColumns() == 1;
            }
            return false;
        }
        if (instruction instanceof ComputationCPInstruction) {
            ComputationCPInstruction cpInst = (ComputationCPInstruction) instruction;
            if (cpInst.input1.isMatrix() && cpInst.input2.isMatrix()) {
                return executionContext.getMatrixObject(cpInst.input1).getNumColumns() == 1
                    || executionContext.getMatrixObject(cpInst.input2).getNumColumns() == 1;
            }
            return false;
        }
        if (instruction instanceof ComputationSPInstruction) {
            ComputationSPInstruction spInst = (ComputationSPInstruction) instruction;
            if (spInst.input1.isMatrix() && spInst.input2.isMatrix()) {
                return executionContext.getMatrixObject(spInst.input1).getNumColumns() == 1
                    || executionContext.getMatrixObject(spInst.input2).getNumColumns() == 1;
            }
            return false;
        }
        GPUInstruction gpuInst = (GPUInstruction) instruction;
        if (gpuInst._input1.isMatrix() && gpuInst._input2.isMatrix()) {
            return executionContext.getMatrixObject(gpuInst._input1).getNumColumns() == 1
                || executionContext.getMatrixObject(gpuInst._input2).getNumColumns() == 1;
        }
        return false;
    }

    /**
     * Checks whether a Spark instruction's output RDD is eligible for persistence: its
     * opcode must be in the checkpoint set, and mapmm/cpmm with single-block aggregation
     * are excluded.
     */
    public static boolean isReusableRDDType(Instruction instruction) {
        if (!(instruction instanceof ComputationSPInstruction)) {
            return false;
        }
        if (!ArrayUtils.contains(CHKPOINT_OPCODES, instruction.getOpcode())) {
            return false;
        }
        if ((instruction instanceof MapmmSPInstruction)
            && ((MapmmSPInstruction) instruction).getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            return false;
        }
        if ((instruction instanceof CpmmSPInstruction)
            && ((CpmmSPInstruction) instruction).getAggType() == AggBinaryOp.SparkAggType.SINGLE_BLOCK) {
            return false;
        }
        return true;
    }

    /** @return true if the opcode is in PERSIST_OPCODES1 (presumably shuffle-based Spark ops -- confirm) */
    protected static boolean isShuffleOp(String str) {
        return ArrayUtils.contains(PERSIST_OPCODES1, str);
    }

    /** @return true if the opcode is one of the heavy GPU operations */
    public static boolean isComputeGPUOps(String str) {
        return ArrayUtils.contains(GPU_OPCODE_HEAVY, str);
    }

    /** @return 2 for PERSIST_OPCODES1 or heavy GPU opcodes, 1 otherwise */
    public static int getComputeGroup(String str) {
        return ArrayUtils.contains(PERSIST_OPCODES1, str) || ArrayUtils.contains(GPU_OPCODE_HEAVY, str) ? 2 : 1;
    }

    /** @return true if a federated instruction produced a federated matrix */
    public static boolean isOutputFederated(Instruction instruction, Data data) {
        return (instruction instanceof ComputationFEDInstruction)
            && (data instanceof MatrixObject)
            && ((MatrixObject) data).isFederated();
    }

    /** Sets the cache type and selects the TSMM head / CBIND tail items. */
    public static void setConfigTsmmCbind(ReuseCacheType reuseCacheType) {
        _cacheType = reuseCacheType;
        _itemH = CachedItemHead.TSMM;
        _itemT = CachedItemTail.CBIND;
    }

    /** Sets the cache type only. */
    public static void setConfig(ReuseCacheType reuseCacheType) {
        _cacheType = reuseCacheType;
    }

    /** Sets the cache type together with the head and tail items. */
    public static void setConfig(ReuseCacheType reuseCacheType, CachedItemHead cachedItemHead, CachedItemTail cachedItemTail) {
        _cacheType = reuseCacheType;
        _itemH = cachedItemHead;
        _itemT = cachedItemTail;
    }

    public static void setCompAssRW(boolean z) {
        _compilerAssistedRW = z;
    }

    /** Disables lineage tracing and reuse globally. */
    public static void shutdownReuse() {
        DMLScript.LINEAGE = false;
        DMLScript.LINEAGE_REUSE = ReuseCacheType.NONE;
    }

    /** Re-enables lineage tracing with the given reuse mode (does not touch the local cache type). */
    public static void restartReuse(ReuseCacheType reuseCacheType) {
        DMLScript.LINEAGE = true;
        DMLScript.LINEAGE_REUSE = reuseCacheType;
    }

    public static ReuseCacheType getCacheType() {
        return _cacheType;
    }

    /**
     * @return true if reuse is globally enabled and the configured cache type supports
     *         multi-level reuse; false if no cache type has been configured yet
     */
    public static boolean isMultiLevelReuse() {
        // _cacheType stays null until setConfig*() runs, while reuse can be switched on
        // independently via DMLScript/restartReuse -- guard against an NPE.
        return !ReuseCacheType.isNone() && _cacheType != null && _cacheType.isMultilevelReuse();
    }

    public static boolean getCompAssRW() {
        return _compilerAssistedRW;
    }

    public static void setReuseLineageTraces(boolean z) {
        _reuseLineageTraces = z;
    }

    public static boolean isLineageTraceReuse() {
        return _reuseLineageTraces;
    }

    public static boolean isDelayedCaching() {
        return DELAYED_CACHING;
    }

    public static boolean isDelayedCachingGPU() {
        return DELAYED_CACHING_GPU;
    }

    public static boolean isDelayedCachingRDD() {
        return DELAYED_CACHING_RDD;
    }

    /**
     * Sets the eviction policy and the matching one-hot score weights.
     *
     * @param lineageCachePolicy the policy; must not be null
     */
    public static void setCachePolicy(LineageCachePolicy lineageCachePolicy) {
        switch (lineageCachePolicy) {
            case LRU:
                WEIGHTS[0] = 0.0d;
                WEIGHTS[1] = 1.0d;
                WEIGHTS[2] = 0.0d;
                break;
            case COSTNSIZE:
                WEIGHTS[0] = 1.0d;
                WEIGHTS[1] = 0.0d;
                WEIGHTS[2] = 0.0d;
                break;
            case DAGHEIGHT:
                WEIGHTS[0] = 0.0d;
                WEIGHTS[1] = 0.0d;
                WEIGHTS[2] = 1.0d;
                break;
        }
        _cachepolicy = lineageCachePolicy;
    }

    public static LineageCachePolicy getCachePolicy() {
        return _cachepolicy;
    }

    public static void setEstimator(boolean z) {
        _onlyEstimate = z;
    }

    public static boolean isEstimator() {
        return _onlyEstimate;
    }

    /** @return true if the time-based (LRU) weight is active */
    public static boolean isTimeBased() {
        return WEIGHTS[1] > 0.0d;
    }

    /** @return true if the cost-and-size weight is active */
    public static boolean isCostNsize() {
        return WEIGHTS[0] > 0.0d;
    }

    /** @return true if the DAG-height weight is active */
    public static boolean isDagHeightBased() {
        return WEIGHTS[2] > 0.0d;
    }

    /** @return the {@code LINEAGECACHESPILL} flag of the current DML configuration */
    public static boolean isSetSpill() {
        return ConfigurationManager.getDMLConfig().getBooleanValue(DMLConfig.LINEAGECACHESPILL);
    }

    /** Compares two secondary keys; equal keys fall back to the lineage-item ID. */
    private static int compareWithIdTieBreak(double v1, double v2, LineageCacheEntry e1, LineageCacheEntry e2) {
        if (v1 == v2) {
            return Long.compare(e1._key.getId(), e2._key.getId());
        }
        return v1 < v2 ? -1 : 1;
    }

    // Must stay after all static field initializers: class initialization runs in textual
    // order, and setCachePolicy/setCompAssRW would otherwise be overwritten by them.
    static {
        REUSE_OPCODES = OPCODES;
        CHKPOINT_OPCODES = Stream.concat(Arrays.stream(PERSIST_OPCODES1), Arrays.stream(PERSIST_OPCODES2))
            .toArray(String[]::new);
        setCachePolicy(LineageCachePolicy.COSTNSIZE);
        setCompAssRW(true);
    }
}
