package org.apache.sysds.runtime.instructions.cp;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriter;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DMLTranslator;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.dml.DmlSyntacticValidator;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.ProgramConverter;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/cp/EvalNaryCPInstruction.class */
/**
 * CP instruction backing the DML {@code eval} builtin.
 * <p>
 * Input 0 holds the function name, optionally namespace-qualified with {@link Program#KEY_DELIM}.
 * Inputs 1..n are the call arguments. The function is resolved at runtime. It is either a
 * function already present in the program, or a builtin that is compiled on demand into the
 * builtin namespace. It is then invoked via a {@link FunctionCallCPInstruction}. If the call
 * produces a scalar or frame instead of a matrix, the result is converted into the output
 * matrix.
 * <p>
 * A single list argument is expanded into individual arguments, reordered by name for named
 * lists. This does not happen if the callee itself takes exactly one list parameter.
 */
public class EvalNaryCPInstruction extends BuiltinNaryCPInstruction {
    /**
     * Thread id parsed in {@link #updateInstructionThreadID}. It is used to pick a thread-private
     * deep copy of the called function. It is initialized to -1.
     */
    private int _threadID;

    public EvalNaryCPInstruction(Operator op, String opcode, String istr, CPOperand output, CPOperand... inputs) {
        super(op, opcode, istr, output, inputs);
        _threadID = -1;
    }

    /**
     * Resolves the target function, invokes it with the bound arguments, and binds the
     * (matrix-converted) result to the output variable.
     *
     * @param ec execution context holding the symbol table and the program
     * @throws DMLRuntimeException if the function does not exist, a list argument does not
     *         match the function signature, or the function returns an unsupported type
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        // 1. split an optional namespace prefix off the function name
        String funcName = ec.getScalarInput(inputs[0]).getStringValue();
        String nsName = null; // null: no explicit namespace given
        if (funcName.contains(Program.KEY_DELIM)) {
            String[] parts = DMLProgram.splitFunctionKey(funcName);
            nsName = parts[0];
            funcName = parts[1];
        }
        CPOperand[] boundInputs = Arrays.copyOfRange(inputs, 1, inputs.length);
        ArrayList<String> boundOutputNames = new ArrayList<>();
        boundOutputNames.add(output.getName());

        // 2. copy the pre-created output matrix before the call rebinds the output variable;
        //    the copy receives the converted result if the function returns a scalar or frame
        MatrixObject outputMO = new MatrixObject(ec.getMatrixObject(output.getName()));

        // 3. resolve the function, falling back to builtins that are compiled on demand;
        //    without arguments (or with a list argument) the matrix variant is selected
        Types.DataType firstArgType = (boundInputs.length == 0 || boundInputs[0].getDataType().isList())
            ? Types.DataType.MATRIX : boundInputs[0].getDataType();
        String internalFName = Builtins.getInternalFName(funcName, firstArgType);
        Program prog = ec.getProgram();
        if (!prog.containsFunctionProgramBlock(nsName, funcName)) {
            if (!Builtins.contains(funcName, true, false)
                && !prog.containsFunctionProgramBlock(DMLProgram.BUILTIN_NAMESPACE, internalFName)) {
                String ns = (nsName == null) ? DMLProgram.DEFAULT_NAMESPACE : nsName;
                throw new DMLRuntimeException("Function '" + DMLProgram.constructFunctionKey(ns, funcName)
                    + "' (called through eval) is non-existing.");
            }
            nsName = DMLProgram.BUILTIN_NAMESPACE;
            // re-check under the program monitor so concurrent evals compile the builtin only once
            synchronized (prog) {
                if (!prog.containsFunctionProgramBlock(nsName, internalFName)) {
                    compileFunctionProgramBlock(funcName, firstArgType, prog);
                }
            }
            funcName = internalFName;
        }
        FunctionProgramBlock fpb = prog.getFunctionProgramBlock(nsName, funcName, false);

        // 4. in a child thread, call a lazily created thread-private deep copy of the function
        if (ProgramBlock.isThreadID(_threadID)) {
            String threadFuncName = funcName + Lop.CP_CHILD_THREAD + _threadID;
            if (!prog.containsFunctionProgramBlock(nsName, threadFuncName, false)) {
                FunctionProgramBlock copy = ProgramConverter.createDeepCopyFunctionProgramBlock(
                    fpb, new HashSet<>(), new HashSet<>(), _threadID);
                prog.addFunctionProgramBlock(nsName, threadFuncName, copy, false);
            }
            fpb = prog.getFunctionProgramBlock(nsName, threadFuncName, false);
            funcName = threadFuncName;
        }

        // 5. expand a single list argument into individual arguments, unless the callee
        //    itself expects exactly one list parameter
        List<String> paramNames = fpb.getInputParamNames();
        CPOperand[] tmpInputs = null;
        LineageItem[] lineageInputs = null;
        if (boundInputs.length == 1 && boundInputs[0].getDataType().isList()
            && !(fpb.getInputParams().size() == 1 && fpb.getInputParams().get(0).getDataType().isList())) {
            ListObject args = ec.getListObject(boundInputs[0]);
            checkValidArguments(args.getData(), args.getNames(), paramNames);
            if (args.isNamedList()) {
                args = reorderNamedListForFunctionCall(args, paramNames);
            }
            // bind each list element to a fresh temporary variable; removed again in step 8
            tmpInputs = new CPOperand[args.getLength()];
            for (int i = 0; i < args.getLength(); i++) {
                Data item = args.getData(i);
                String varName = Dag.getNextUniqueVarname(item.getDataType());
                ec.getVariables().put(varName, item);
                tmpInputs[i] = new CPOperand(varName, item);
            }
            boundInputs = tmpInputs;
            lineageInputs = DMLScript.LINEAGE
                ? args.getLineageItems().toArray(new LineageItem[args.getLength()]) : null;
        }

        // 6. call the function (the 'false' flag selects the same function variant looked up above)
        new FunctionCallCPInstruction(nsName, funcName, false, boundInputs, lineageInputs,
            paramNames, boundOutputNames, "eval func").processInstruction(ec);

        // 7. convert a scalar or frame result into the output matrix
        Data newOutput = ec.getVariable(output);
        if (!(newOutput instanceof MatrixObject)) {
            MatrixBlock mb;
            if (newOutput instanceof ScalarObject) {
                mb = new MatrixBlock(((ScalarObject) newOutput).getDoubleValue());
            } else if (newOutput instanceof FrameObject) {
                FrameObject fo = (FrameObject) newOutput;
                mb = DataConverter.convertToMatrixBlock(fo.acquireRead());
                // release the read lock; otherwise the frame stays locked and cannot be cleaned up
                fo.release();
                ec.cleanupCacheableData(fo);
            } else {
                // an unbound output previously surfaced as a NullPointerException here
                String typeName = (newOutput == null) ? "null" : newOutput.getDataType().name();
                throw new DMLRuntimeException("Invalid eval return type: " + typeName
                    + " (valid: matrix/frame/scalar; where frames or scalars are converted to output matrices)");
            }
            outputMO.acquireModify(mb);
            outputMO.release();
            ec.setVariable(output.getName(), outputMO);
        }

        // 8. remove the temporary variables created for expanded list arguments
        if (tmpInputs != null) {
            for (CPOperand tmp : tmpInputs) {
                VariableCPInstruction.processRmvarInstruction(ec, tmp.getName());
            }
        }
    }

    /**
     * Records the thread id encoded in {@code replace}. {@code replace} is expected to be
     * {@link Lop#CP_CHILD_THREAD} followed by a decimal integer.
     */
    @Override
    public void updateInstructionThreadID(String pattern, String replace) {
        _threadID = Integer.parseInt(replace.substring(Lop.CP_CHILD_THREAD.length()));
    }

    /**
     * Parses the builtin script for {@code name} and compiles its functions into {@code prog}.
     * <p>
     * Functions already contained in the DML program's builtin function dictionary are
     * skipped. The remaining functions are added to the builtin namespace of the DML program,
     * validated, rewritten, and lowered to LOPs. Each resulting function program block is
     * registered in {@code prog} under both values of the boolean flag of
     * {@code addFunctionProgramBlock}.
     *
     * @param name builtin function name (without the internal type prefix)
     * @param dt data type used to select the internal builtin variant
     * @param prog runtime program that receives the compiled function program blocks
     * @throws DMLRuntimeException if parsing yields no functions
     */
    private static void compileFunctionProgramBlock(String name, Types.DataType dt, Program prog) {
        Map<String, FunctionStatementBlock> parsed =
            DmlSyntacticValidator.loadAndParseBuiltinFunction(name, DMLProgram.BUILTIN_NAMESPACE);
        if (parsed.isEmpty()) {
            throw new DMLRuntimeException("Failed to compile function '" + name + "'.");
        }

        // prefer the program's existing DML program, otherwise use the parsed function's one
        DMLProgram dmlProg = (prog.getDMLProg() != null) ? prog.getDMLProg()
            : parsed.get(Builtins.getInternalFName(name, dt)).getDMLProg();

        // skip functions that the builtin dictionary already contains
        Map<String, FunctionStatementBlock> fsbs = parsed;
        if (dmlProg.getBuiltinFunctionDictionary() != null) {
            fsbs = parsed.entrySet().stream()
                .filter(e -> !dmlProg.getBuiltinFunctionDictionary().containsFunction(e.getKey()))
                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        }

        // move the new functions into the builtin namespace of the DML program
        for (Map.Entry<String, FunctionStatementBlock> e : fsbs.entrySet()) {
            dmlProg.createNamespace(DMLProgram.BUILTIN_NAMESPACE);
            dmlProg.addFunctionStatementBlock(DMLProgram.BUILTIN_NAMESPACE, e.getKey(), e.getValue());
            e.getValue().setDMLProg(dmlProg);
        }

        DMLTranslator translator = new DMLTranslator(dmlProg);
        ProgramRewriter rewriter1 = new ProgramRewriter(true, false);
        ProgramRewriter rewriter2 = new ProgramRewriter(false, true);

        // validate all functions first, so cross-function references are known
        for (FunctionStatementBlock fsb : fsbs.values()) {
            translator.liveVariableAnalysisFunction(dmlProg, fsb);
            translator.validateFunction(dmlProg, fsb);
        }

        // construct HOPs, apply rewrites, and lower to LOPs
        for (FunctionStatementBlock fsb : fsbs.values()) {
            translator.constructHops(fsb);
            rewriter1.rewriteHopDAGsFunction(fsb, false);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            rewriter1.rewriteHopDAGsFunction(fsb, true);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            rewriter2.rewriteHopDAGsFunction(fsb, true);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            HopRewriteUtils.setUnoptimizedFunctionCalls(fsb);
            DMLTranslator.resetHopsDAGVisitStatus(fsb);
            DMLTranslator.refreshMemEstimates(fsb);
            translator.constructLops(fsb);
        }

        // create runtime function program blocks and register them in the program
        for (Map.Entry<String, FunctionStatementBlock> e : fsbs.entrySet()) {
            FunctionProgramBlock fpb = (FunctionProgramBlock) translator.createRuntimeProgramBlock(
                prog, e.getValue(), ConfigurationManager.getDMLConfig());
            prog.addFunctionProgramBlock(DMLProgram.BUILTIN_NAMESPACE, e.getKey(), fpb, true);
            prog.addFunctionProgramBlock(DMLProgram.BUILTIN_NAMESPACE, e.getKey(), fpb, false);
        }
    }

    /**
     * Checks that list-expanded arguments fit the function signature.
     *
     * @param data list elements
     * @param names list element names, or {@code null} for an unnamed list
     * @param paramNames input parameter names of the called function
     * @throws DMLRuntimeException if the argument count differs from the parameter count, or a
     *         named element does not match any parameter
     */
    private static void checkValidArguments(List<Data> data, List<String> names, List<String> paramNames) {
        int numArgs = (names != null) ? names.size() : data.size();
        if (numArgs != paramNames.size()) {
            throw new DMLRuntimeException("Failed to expand list for function call (mismatching number of arguments: "
                + numArgs + " vs. " + paramNames.size() + ").");
        }
        if (names != null) {
            HashSet<String> validNames = new HashSet<>(paramNames);
            for (String argName : names) {
                if (!validNames.contains(argName)) {
                    throw new DMLRuntimeException("List argument named '" + argName + "' not in function signature.");
                }
            }
        }
    }

    /**
     * Builds a new named list whose elements, and lineage items if lineage is enabled, follow
     * the order of {@code paramNames}.
     *
     * @param list named list to reorder
     * @param paramNames parameter names in signature order
     * @return reordered list named by {@code paramNames}
     */
    private static ListObject reorderNamedListForFunctionCall(ListObject list, List<String> paramNames) {
        List<Data> data = new ArrayList<>(paramNames.size());
        List<LineageItem> lineage = DMLScript.LINEAGE ? new ArrayList<LineageItem>(paramNames.size()) : null;
        for (String paramName : paramNames) {
            data.add(list.getData(paramName));
            if (lineage != null) {
                lineage.add(list.getLineageItem(paramName));
            }
        }
        return new ListObject(data, new ArrayList<>(paramNames), lineage);
    }
}
