package org.apache.sysds.api;

import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateException;
import java.text.SimpleDateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Collections;
import java.util.Date;
import java.util.IdentityHashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;
import org.apache.commons.cli.AlreadySelectedException;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.LanguageException;
import org.apache.sysds.parser.ParseException;
import org.apache.sysds.parser.ParserFactory;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorker;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.util.IDHandler;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContextPool;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.runtime.util.LocalFileUtils;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Explain;
import org.apache.sysds.utils.NativeHelper;
import org.apache.sysds.utils.Statistics;

/* loaded from: input_file:org/apache/sysds/api/DMLScript.class */
public class DMLScript {

    // ------------------------------------------------------------------------
    // Global flags. Initial values come from DMLOptions.defaultOptions or are
    // literal defaults; executeScript overwrites many of them from the command
    // line, and setGlobalFlags overwrites others from the DML configuration.
    // ------------------------------------------------------------------------

    private static Types.ExecMode EXEC_MODE = DMLOptions.defaultOptions.execMode;
    public static boolean STATISTICS = DMLOptions.defaultOptions.stats;
    public static boolean JMLC_MEM_STATISTICS = false;
    public static int STATISTICS_COUNT = DMLOptions.defaultOptions.statsCount;
    public static int STATISTICS_MAX_WRAP_LEN = 30;
    public static boolean FED_STATISTICS = DMLOptions.defaultOptions.fedStats;
    public static int FED_STATISTICS_COUNT = DMLOptions.defaultOptions.fedStatsCount;
    public static Explain.ExplainType EXPLAIN = DMLOptions.defaultOptions.explainType;
    public static String DML_FILE_PATH_ANTLR_PARSER = DMLOptions.defaultOptions.filePath;
    public static String FLOATING_POINT_PRECISION = Statement.DOUBLE_VALUE_TYPE;
    public static boolean PRINT_GPU_MEMORY_INFO = false;
    public static long EVICTION_SHADOW_BUFFER_MAX_BYTES = 0;
    public static long EVICTION_SHADOW_BUFFER_CURR_BYTES = 0;
    public static double GPU_MEMORY_UTILIZATION_FACTOR = 0.9d;
    public static String GPU_MEMORY_ALLOCATOR = "cuda";
    public static boolean LINEAGE = DMLOptions.defaultOptions.lineage;
    public static boolean LINEAGE_DEDUP = DMLOptions.defaultOptions.lineage_dedup;
    public static LineageCacheConfig.ReuseCacheType LINEAGE_REUSE = DMLOptions.defaultOptions.linReuseType;
    public static LineageCacheConfig.LineageCachePolicy LINEAGE_POLICY = DMLOptions.defaultOptions.linCachePolicy;
    public static boolean LINEAGE_ESTIMATE = DMLOptions.defaultOptions.lineage_estimate;
    public static boolean LINEAGE_DEBUGGER = DMLOptions.defaultOptions.lineage_debugger;
    public static boolean CHECK_PRIVACY = DMLOptions.defaultOptions.checkPrivacy;
    public static boolean USE_ACCELERATOR = DMLOptions.defaultOptions.gpu;
    public static boolean FORCE_ACCELERATOR = DMLOptions.defaultOptions.forceGPU;
    public static boolean SYNCHRONIZE_GPU = true;
    public static boolean EAGER_CUDA_FREE = false;
    public static boolean _suppressPrint2Stdout = false;
    public static boolean USE_LOCAL_SPARK_CONFIG = false;
    public static boolean _activeAM = false;
    public static boolean VALIDATOR_IGNORE_ISSUES = false;
    public static String _uuid = IDHandler.createDistributedUniqueID();
    private static final Log LOG = LogFactory.getLog(DMLScript.class.getName());

    /** Format of the BEGIN/END run timestamps; DateTimeFormatter is immutable and thread-safe. */
    private static final DateTimeFormatter DATE_TIME_FORMAT = DateTimeFormatter.ofPattern("MM/dd/yyyy HH:mm:ss");

    /**
     * Returns the id of this process, initially created via
     * {@code IDHandler.createDistributedUniqueID()}. It is part of the scratch-space
     * directory name removed by {@link #cleanupHadoopExecution(DMLConfig)}.
     */
    public static String getUUID() {
        return _uuid;
    }

    /** Overrides the process id returned by {@link #getUUID()}. */
    public static void setUUID(String uuid) {
        _uuid = uuid;
    }

    /** Returns the value of the {@code _suppressPrint2Stdout} flag. */
    public static boolean suppressPrint2Stdout() {
        return _suppressPrint2Stdout;
    }

    /** Sets the {@code _activeAM} flag to true (presumably marks execution inside an application master — confirm). */
    public static void setActiveAM() {
        _activeAM = true;
    }

    /** Returns the value of the {@code _activeAM} flag. */
    public static boolean isActiveAM() {
        return _activeAM;
    }

    /**
     * Command-line entry point.
     *
     * <p>Hadoop generic options are consumed by {@link GenericOptionsParser}; the remaining
     * arguments are passed to {@link #executeScript(Configuration, String[])}. Any exception is
     * reported via {@link #errorPrint(Exception)}, and its stack trace is printed (once) when some
     * argument contains {@code -debug}.
     *
     * @param args raw command-line arguments
     */
    public static void main(String[] args) {
        try {
            Configuration conf = new Configuration(ConfigurationManager.getCachedJobConf());
            String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
            executeScript(conf, otherArgs);
        } catch (Exception e) {
            errorPrint(e);
            if (containsDebugFlag(args)) {
                e.printStackTrace();
            }
        }
    }

    /** Returns true if any (trimmed) argument contains {@code -debug}. */
    private static boolean containsDebugFlag(String[] args) {
        for (String arg : args) {
            if (arg.trim().contains("-debug")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Parses SystemDS command-line arguments and runs the requested action: print help, clean the
     * workspace, start a federated worker, or compile and execute a DML script.
     *
     * <p>The global exec mode and explain type are restored to their previous values when this
     * method returns. The other global flags set from the command line are not restored.
     *
     * @param conf hadoop configuration (not read by this method)
     * @param args SystemDS arguments, without hadoop generic options
     * @return {@code true} if the requested action completed; {@code false} if the arguments could
     *         not be parsed or the federated worker failed to start
     * @throws IOException if reading the script or configuration fails
     * @throws ParseException on DML parse errors
     * @throws DMLScriptException on script-level errors
     */
    public static boolean executeScript(Configuration conf, String[] args)
        throws IOException, ParseException, DMLScriptException {
        // remember the global state that must be restored on exit
        Types.ExecMode oldExecMode = EXEC_MODE;
        Explain.ExplainType oldExplain = EXPLAIN;

        DMLOptions dmlOptions;
        try {
            dmlOptions = DMLOptions.parseCLArguments(args);
        } catch (AlreadySelectedException e) {
            LOG.error("Mutually exclusive options were selected. " + e.getMessage());
            return false;
        } catch (org.apache.commons.cli.ParseException e) {
            LOG.error("Parsing Exception " + e.getMessage());
            return false;
        }

        try {
            applyCommandLineOptions(dmlOptions);

            String fnameOptConfig = dmlOptions.configFile;
            boolean isFile = dmlOptions.filePath != null;
            String fileOrScript = isFile ? dmlOptions.filePath : dmlOptions.script;

            if (dmlOptions.help) {
                new HelpFormatter().printHelp("systemds", dmlOptions.options);
                return true;
            }
            if (dmlOptions.clean) {
                cleanSystemDSWorkspace();
                return true;
            }
            if (dmlOptions.fedWorker) {
                loadConfiguration(fnameOptConfig);
                try {
                    new FederatedWorker(dmlOptions.fedWorkerPort).run();
                } catch (CertificateException e) {
                    // a worker that could not be started is a failure, not a success
                    LOG.error("Failed to start federated worker on port " + dmlOptions.fedWorkerPort, e);
                    return false;
                }
                return true;
            }

            LineageCacheConfig.setConfig(LINEAGE_REUSE);
            LineageCacheConfig.setCachePolicy(LINEAGE_POLICY);
            LineageCacheConfig.setEstimator(LINEAGE_ESTIMATE);

            String dmlScriptStr = readDMLScript(isFile, fileOrScript);
            Map<String, String> argVals = dmlOptions.argVals;
            DML_FILE_PATH_ANTLR_PARSER = dmlOptions.filePath;

            printInvocationInfo(fileOrScript, fnameOptConfig, argVals);
            execute(dmlScriptStr, fnameOptConfig, argVals);
            return true;
        } finally {
            setGlobalExecMode(oldExecMode);
            EXPLAIN = oldExplain;
        }
    }

    /** Copies the parsed command-line options into the corresponding global flags. */
    private static void applyCommandLineOptions(DMLOptions opts) {
        STATISTICS = opts.stats;
        STATISTICS_COUNT = opts.statsCount;
        FED_STATISTICS = opts.fedStats;
        FED_STATISTICS_COUNT = opts.fedStatsCount;
        JMLC_MEM_STATISTICS = opts.memStats;
        USE_ACCELERATOR = opts.gpu;
        FORCE_ACCELERATOR = opts.forceGPU;
        EXPLAIN = opts.explainType;
        EXEC_MODE = opts.execMode;
        LINEAGE = opts.lineage;
        LINEAGE_DEDUP = opts.lineage_dedup;
        LINEAGE_REUSE = opts.linReuseType;
        LINEAGE_POLICY = opts.linCachePolicy;
        LINEAGE_ESTIMATE = opts.lineage_estimate;
        CHECK_PRIVACY = opts.checkPrivacy;
        LINEAGE_DEBUGGER = opts.lineage_debugger;
    }

    /**
     * Returns the text of a DML script.
     *
     * @param isFile if true, {@code scriptOrFilename} is a path; paths starting with
     *        {@code hdfs:} or {@code gpfs:}, or with an object-store scheme, are read through the
     *        hadoop file system, all others from the local file system. Files are read as UTF-8
     *        and every line is terminated with {@code ProgramConverter.NEWLINE}.
     *        If false, {@code scriptOrFilename} is the script text itself and is returned as-is
     *        (an empty string yields an empty script).
     * @param scriptOrFilename script path or script text
     * @return the script text
     * @throws LanguageException if {@code scriptOrFilename} is null
     * @throws IOException if the script file cannot be read
     */
    public static String readDMLScript(boolean isFile, String scriptOrFilename) throws IOException {
        if (!isFile) {
            if (scriptOrFilename == null) {
                throw new LanguageException("DML script was not specified!");
            }
            // no charset round-trip: the given string already is the script
            return scriptOrFilename;
        }

        if (scriptOrFilename == null) {
            throw new LanguageException("DML script path was not specified!");
        }
        StringBuilder sb = new StringBuilder();
        BufferedReader in = null;
        try {
            in = openScriptReader(scriptOrFilename);
            String line;
            while ((line = in.readLine()) != null) {
                sb.append(line);
                sb.append(ProgramConverter.NEWLINE);
            }
        } catch (IOException e) {
            LOG.error("Failed to read the script from the file system", e);
            throw e;
        } finally {
            IOUtilFunctions.closeSilently(in);
        }
        return sb.toString();
    }

    /** Opens a UTF-8 reader on the given script path, via hadoop FS for distributed/object-store paths. */
    private static BufferedReader openScriptReader(String fname) throws IOException {
        Path path = new Path(fname);
        if (fname.startsWith("hdfs:") || fname.startsWith("gpfs:") || IOUtilFunctions.isObjectStoreFileScheme(path)) {
            return new BufferedReader(new InputStreamReader(
                IOUtilFunctions.getFileSystem(path).open(path), StandardCharsets.UTF_8));
        }
        return new BufferedReader(new InputStreamReader(new FileInputStream(fname), StandardCharsets.UTF_8));
    }

    /**
     * Reads the DML configuration (optionally from {@code fnameOptConfig}), installs it and the
     * derived compiler configuration globally, and applies it via {@link #setGlobalFlags(DMLConfig)}.
     */
    private static void loadConfiguration(String fnameOptConfig) throws IOException {
        DMLConfig dmlconf = DMLConfig.readConfigurationFile(fnameOptConfig);
        ConfigurationManager.setGlobalConfig(dmlconf);
        ConfigurationManager.setGlobalConfig(OptimizerUtils.constructCompilerConfig(dmlconf));
        if (LOG.isDebugEnabled()) {
            LOG.debug("\nDML config: \n" + dmlconf.getConfigInfo());
        }
        setGlobalFlags(dmlconf);
    }

    /**
     * Compiles and executes a DML script.
     *
     * <p>Compilation: parse, live-variable analysis, validation, HOP construction, HOP rewrites,
     * LOP construction, runtime program. Scratch space and working directories are cleaned up
     * and a Spark execution context (if any) is closed even if execution fails.
     */
    private static void execute(String dmlScriptStr, String fnameOptConfig, Map<String, String> argVals)
        throws IOException {
        printStartExecInfo(dmlScriptStr);
        loadConfiguration(fnameOptConfig);
        configureCodeGen();

        Statistics.startCompileTimer();
        DMLProgram prog = ParserFactory.createParser().parse(DML_FILE_PATH_ANTLR_PARSER, dmlScriptStr, argVals);
        DMLTranslator dmlt = new DMLTranslator(prog);
        dmlt.liveVariableAnalysis(prog);
        dmlt.validateParseTree(prog);
        dmlt.constructHops(prog);
        // scratch space and caching must be ready before HOP rewrites
        initHadoopExecution(ConfigurationManager.getDMLConfig());
        dmlt.rewriteHopsDAG(prog);
        dmlt.constructLops(prog);
        Program rtprog = dmlt.getRuntimeProgram(prog, ConfigurationManager.getDMLConfig());

        Explain.ExplainCounts counts = Explain.countDistributedOperations(rtprog);
        Statistics.resetNoOfCompiledJobs(counts.numJobs);
        if (EXPLAIN != Explain.ExplainType.NONE) {
            System.out.println(Explain.display(prog, rtprog, EXPLAIN, counts));
        }
        Statistics.stopCompileTimer();

        ExecutionContext ec = null;
        try {
            ec = ExecutionContextFactory.createContext(rtprog);
            ScriptExecutorUtils.executeRuntimeProgram(rtprog, ec, ConfigurationManager.getDMLConfig(),
                STATISTICS ? STATISTICS_COUNT : 0, null);
        } finally {
            try {
                cleanupHadoopExecution(ConfigurationManager.getDMLConfig());
            } finally {
                // close the spark context even if the cleanup above failed
                if (ec instanceof SparkExecutionContext) {
                    ((SparkExecutionContext) ec).close();
                }
                LOG.info("END DML run " + getDateTime());
            }
        }
    }

    /**
     * Applies the given DML configuration to the global flags (GPU, native BLAS, statistics wrap
     * length, floating point precision, eviction shadow buffer).
     *
     * @throws RuntimeException if the GPU memory utilization factor is negative, or, for
     *         non-double precision, the eviction shadow buffer fraction is outside [0, 1]
     */
    public static void setGlobalFlags(DMLConfig dmlconf) {
        GPUContextPool.AVAILABLE_GPUS = dmlconf.getTextValue(DMLConfig.AVAILABLE_GPUS);
        STATISTICS_MAX_WRAP_LEN = dmlconf.getIntValue(DMLConfig.STATS_MAX_WRAP_LEN);
        NativeHelper.initialize(dmlconf.getTextValue(DMLConfig.NATIVE_BLAS_DIR),
            dmlconf.getTextValue(DMLConfig.NATIVE_BLAS).trim());
        SYNCHRONIZE_GPU = dmlconf.getBooleanValue(DMLConfig.SYNCHRONIZE_GPU);
        EAGER_CUDA_FREE = dmlconf.getBooleanValue(DMLConfig.EAGER_CUDA_FREE);
        PRINT_GPU_MEMORY_INFO = dmlconf.getBooleanValue(DMLConfig.PRINT_GPU_MEMORY_INFO);
        GPU_MEMORY_UTILIZATION_FACTOR = dmlconf.getDoubleValue(DMLConfig.GPU_MEMORY_UTILIZATION_FACTOR);
        GPU_MEMORY_ALLOCATOR = dmlconf.getTextValue(DMLConfig.GPU_MEMORY_ALLOCATOR);
        if (GPU_MEMORY_UTILIZATION_FACTOR < 0) {
            throw new RuntimeException("Incorrect value (" + GPU_MEMORY_UTILIZATION_FACTOR
                + ") for the configuration:" + DMLConfig.GPU_MEMORY_UTILIZATION_FACTOR);
        }

        FLOATING_POINT_PRECISION = dmlconf.getTextValue(DMLConfig.FLOATING_POINT_PRECISION);
        LibMatrixCUDA.resetFloatingPointPrecision();
        if (FLOATING_POINT_PRECISION.equals(Statement.DOUBLE_VALUE_TYPE)) {
            // double precision: the eviction shadow buffer is disabled
            EVICTION_SHADOW_BUFFER_MAX_BYTES = 0L;
            return;
        }

        double shadowBufferFraction = dmlconf.getDoubleValue(DMLConfig.EVICTION_SHADOW_BUFFERSIZE);
        if (shadowBufferFraction < 0 || shadowBufferFraction > 1) {
            throw new RuntimeException("Incorrect value (" + shadowBufferFraction
                + ") for the configuration:" + DMLConfig.EVICTION_SHADOW_BUFFERSIZE);
        }
        EVICTION_SHADOW_BUFFER_MAX_BYTES = (long) (InfrastructureAnalyzer.getLocalMaxMemory() * shadowBufferFraction);
        if (EVICTION_SHADOW_BUFFER_MAX_BYTES > 0
            && EVICTION_SHADOW_BUFFER_CURR_BYTES > EVICTION_SHADOW_BUFFER_MAX_BYTES) {
            System.out.println("WARN: Cannot use the shadow buffer due to potentially cached GPU objects. "
                + "Current shadow buffer size (in bytes):" + EVICTION_SHADOW_BUFFER_CURR_BYTES
                + " > Max shadow buffer size (in bytes):" + EVICTION_SHADOW_BUFFER_MAX_BYTES);
        }
    }

    /**
     * Prepares execution: creates the scratch space, removes leftovers of a previous run,
     * creates the local working directory, initializes caching and resets job/statistics/privacy counters.
     */
    public static void initHadoopExecution(DMLConfig config) throws IOException, ParseException, DMLRuntimeException {
        HDFSTool.createDirIfNotExistOnHDFS(config.getTextValue(DMLConfig.SCRATCH_SPACE),
            DMLConfig.DEFAULT_SHARED_DIR_PERMISSION);
        cleanupHadoopExecution(config);
        LocalFileUtils.createWorkingDirectory();
        CacheableData.initCaching();
        Statistics.resetNoOfExecutedJobs();
        if (STATISTICS) {
            Statistics.reset();
        }
        if (CHECK_PRIVACY) {
            CheckedConstraintsLog.reset();
        }
    }

    /**
     * Clears federated workers, shuts down the Spark utility pool, deletes this process's
     * scratch-space directory and cleans up the cache and local working directories.
     */
    public static void cleanupHadoopExecution(DMLConfig config) throws IOException, ParseException {
        String processDir = Lop.FILE_SEPARATOR + Lop.PROCESS_PREFIX + getUUID();
        FederatedData.clearFederatedWorkers();
        SparkUtils.shutdownPool();
        HDFSTool.deleteFileIfExistOnHDFS(config.getTextValue(DMLConfig.SCRATCH_SPACE) + processDir);
        CacheableData.cleanupCacheDir();
        LocalFileUtils.cleanupWorkingDirectory();
    }

    /** Logs (at debug level) the invocation parameters and all script arguments. */
    private static void printInvocationInfo(String fnameScript, String fnameOptConfig, Map<String, String> argVals) {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        LOG.debug("****** args to DML Script ******\nUUID: " + getUUID() + "\nSCRIPT PATH: " + fnameScript
            + "\nRUNTIME: " + getGlobalExecMode() + "\nBUILTIN CONFIG: " + DMLConfig.DEFAULT_SYSTEMDS_CONFIG_FILEPATH
            + "\nOPTIONAL CONFIG: " + fnameOptConfig + ProgramConverter.NEWLINE);
        if (argVals == null || argVals.isEmpty()) {
            return;
        }
        LOG.debug("Script arguments are: \n");
        // iterate the actual entries so named arguments are logged too (positional keys are "$1", "$2", ...)
        for (Map.Entry<String, String> arg : argVals.entrySet()) {
            LOG.debug("Script argument " + arg.getKey() + " = " + arg.getValue());
        }
    }

    /** Logs the start timestamp (info) and the script text (debug). */
    private static void printStartExecInfo(String dmlScriptString) {
        LOG.info("BEGIN DML run " + getDateTime());
        if (LOG.isDebugEnabled()) {
            LOG.debug("DML script: \n" + dmlScriptString);
        }
    }

    /** Returns the current local time formatted as {@code MM/dd/yyyy HH:mm:ss}. */
    private static String getDateTime() {
        return LocalDateTime.now().format(DATE_TIME_FORMAT);
    }

    /**
     * Deletes the scratch space and cleans up the local tmp directory configured in the default
     * DML configuration.
     *
     * @throws DMLException wrapping any failure
     */
    private static void cleanSystemDSWorkspace() {
        try {
            DMLConfig conf = DMLConfig.readConfigurationFile(null);
            String scratchSpace = conf.getTextValue(DMLConfig.SCRATCH_SPACE);
            if (scratchSpace != null) {
                HDFSTool.deleteFileIfExistOnHDFS(scratchSpace);
            }
            String localTmpDir = conf.getTextValue(DMLConfig.LOCAL_TMP_DIR);
            if (localTmpDir != null) {
                LocalFileUtils.cleanupRcWorkingDirectory(localTmpDir);
            }
        } catch (Exception e) {
            throw new DMLException("Failed to run SystemDS workspace cleanup.", e);
        }
    }

    /** Returns the current global execution mode. */
    public static Types.ExecMode getGlobalExecMode() {
        return EXEC_MODE;
    }

    /** Sets the global execution mode. */
    public static void setGlobalExecMode(Types.ExecMode newMode) {
        EXEC_MODE = newMode;
    }

    /**
     * Prints the exception and its cause chain (simple class name and message per line) to
     * stdout, wrapped in ANSI red / reset escape codes.
     */
    public static void errorPrint(Exception e) {
        StringBuilder sb = new StringBuilder();
        sb.append("\u001b[31m\n");
        sb.append("An Error Occured : ");
        // identity set guards against cyclic cause chains, which would otherwise loop forever
        Set<Throwable> seen = Collections.newSetFromMap(new IdentityHashMap<Throwable, Boolean>());
        for (Throwable t = e; t != null && seen.add(t); t = t.getCause()) {
            sb.append(ProgramConverter.NEWLINE);
            sb.append(StringUtils.leftPad(t.getClass().getSimpleName(), 25));
            sb.append(" -- ");
            sb.append(t.getMessage());
        }
        sb.append("\n\u001b[0m");
        System.out.println(sb.toString());
    }

    /**
     * If codegen is enabled, loads the native code generator selected by the {@code CODEGEN_API}
     * configuration. Failures are logged, not propagated.
     */
    private static void configureCodeGen() {
        if (!ConfigurationManager.isCodegenEnabled()) {
            return;
        }
        try {
            String api = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.CODEGEN_API);
            // Locale.ROOT: enum lookup must not depend on the default locale
            SpoofCompiler.loadNativeCodeGenerator(SpoofCompiler.GeneratorAPI.valueOf(api.toUpperCase(Locale.ROOT)));
        } catch (Exception e) {
            LOG.error("Failed to load native cuda codegen library", e);
        }
    }
}
