package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.Program;

/**
 * IPA pass that replaces second-order {@code eval} calls with direct DML function calls,
 * provided the target function name is a string literal.
 *
 * <p>A candidate is a transient or persistent write whose first input is an
 * {@code eval(<literal name>, args...)} n-ary hop with exactly one parent. Such a write root
 * is replaced by a {@link FunctionOp} that calls the resolved function directly and binds its
 * single output to the write's variable name.
 *
 * <p>Leaving an eval in place is always a valid fallback. Replacement is therefore skipped
 * in these cases:
 * <ul>
 *   <li>list arguments are passed to a function that does not take a list;</li>
 *   <li>the target function block cannot be found;</li>
 *   <li>a builtin is called without arguments;</li>
 *   <li>the target does not return exactly one matrix.</li>
 * </ul>
 */
public class IPAPassReplaceEvalFunctionCalls extends IPAPass {

    /**
     * Returns whether this pass should run.
     *
     * @param fgraph the program's function call graph
     * @return true if the program contains second-order calls and eval replacement is enabled
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return fgraph.containsSecondOrderCall() && OptimizerUtils.ALLOW_EVAL_FCALL_REPLACEMENT;
    }

    /**
     * Rewrites eligible eval calls in all function bodies (of every namespace) and in the
     * main program.
     *
     * @param prog       the DML program to rewrite in place
     * @param fgraph     the function call graph (passed through, not modified here)
     * @param fcallSizes function call size information (unused by this pass)
     * @return true if at least one eval call was replaced
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        boolean changed = false;
        for (String namespace : prog.getNamespaces().keySet()) {
            for (String fname : prog.getFunctionStatementBlocks(namespace).keySet()) {
                changed |= rewriteStatementBlock(prog, prog.getFunctionStatementBlock(namespace, fname), fgraph);
            }
        }
        changed |= rewriteStatementBlocks(prog, prog.getStatementBlocks(), fgraph);
        return changed;
    }

    /**
     * Recursively rewrites a sequence of statement blocks.
     *
     * @return true if any block in the sequence was changed
     */
    private static boolean rewriteStatementBlocks(DMLProgram prog, Iterable<StatementBlock> sbs, FunctionCallGraph fgraph) {
        boolean changed = false;
        for (StatementBlock sb : sbs) {
            // non-short-circuit: every block must be visited
            changed |= rewriteStatementBlock(prog, sb, fgraph);
        }
        return changed;
    }

    /**
     * Descends into the bodies of function, while, if/else and for blocks.
     *
     * <p>Only the bodies are rewritten. Predicate hops of control-flow blocks are not inspected.
     * Generic (leaf) statement blocks are handed to {@link #checkAndReplaceEvalFunctionCall}.
     *
     * @return true if the block or any nested block was changed
     */
    private static boolean rewriteStatementBlock(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) sb.getStatement(0);
            return rewriteStatementBlocks(prog, fstmt.getBody(), fgraph);
        } else if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
            return rewriteStatementBlocks(prog, wstmt.getBody(), fgraph);
        } else if (sb instanceof IfStatementBlock) {
            IfStatement istmt = (IfStatement) sb.getStatement(0);
            boolean changed = rewriteStatementBlocks(prog, istmt.getIfBody(), fgraph);
            changed |= rewriteStatementBlocks(prog, istmt.getElseBody(), fgraph);
            return changed;
        } else if (sb instanceof ForStatementBlock) {
            ForStatement fstmt = (ForStatement) sb.getStatement(0);
            return rewriteStatementBlocks(prog, fstmt.getBody(), fgraph);
        }
        return checkAndReplaceEvalFunctionCall(prog, sb, fgraph);
    }

    /**
     * Replaces eligible {@code eval} writes among the root hops of a generic statement block.
     *
     * <p>Each matching root hop is swapped in place for a {@link FunctionOp}.
     *
     * @param prog   the program, used to resolve the target function
     * @param sb     a generic (non-control-flow) statement block
     * @param fgraph the function call graph (unused here)
     * @return true if at least one root hop was replaced
     */
    private static boolean checkAndReplaceEvalFunctionCall(DMLProgram prog, StatementBlock sb, FunctionCallGraph fgraph) {
        ArrayList<Hop> roots = sb.getHops();
        if (roots == null) {
            return false;
        }
        boolean changed = false;
        for (int i = 0; i < roots.size(); i++) {
            Hop root = roots.get(i);
            if (!isReplaceableEvalWrite(root)) {
                continue;
            }
            Hop eval = root.getInput(0);
            String outName = ((DataOp) root).getName();
            boolean hasArgs = eval.getInput().size() > 1;

            // Resolve the namespace. An explicit "ns::fname" key overrides the
            // default-dictionary/builtin lookup.
            String fname = ((LiteralOp) eval.getInput(0)).getStringValue();
            String fnamespace = prog.getDefaultFunctionDictionary().containsFunction(fname)
                ? DMLProgram.DEFAULT_NAMESPACE : DMLProgram.BUILTIN_NAMESPACE;
            if (fname.contains(Program.KEY_DELIM)) {
                String[] parts = DMLProgram.splitFunctionKey(fname);
                fnamespace = parts[0];
                fname = parts[1];
            }
            if (fnamespace.equals(DMLProgram.BUILTIN_NAMESPACE)) {
                // The internal builtin name is derived from the first argument's data type.
                // Without arguments it cannot be resolved, so keep the eval.
                if (!hasArgs) {
                    continue;
                }
                fname = Builtins.getInternalFName(fname, eval.getInput(1).getDataType());
            }

            // The function block may be absent (presumably not loaded yet), so fstmt can be null.
            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fnamespace, fname);
            FunctionStatement fstmt = (fsb != null) ? (FunctionStatement) fsb.getStatement(0) : null;

            if (hasArgs && eval.getInput(1).getDataType().isList() && !acceptsListInput(fstmt)) {
                LOG.warn("IPA: eval(" + fnamespace + Program.KEY_DELIM + fname + ") applicable for replacement, but list inputs not yet supported.");
                continue;
            }
            if (fstmt == null) {
                // Without the function signature, outputs cannot be validated; keep the eval.
                continue;
            }
            if (fstmt.getOutputParams().size() != 1 || !fstmt.getOutputParams().get(0).getDataType().isMatrix()) {
                LOG.warn("IPA: eval(" + fnamespace + Program.KEY_DELIM + fname + ") applicable for replacement, but function output is not a matrix.");
                continue;
            }

            // All arguments after the function name become inputs of the direct call.
            FunctionOp fop = new FunctionOp(FunctionOp.FunctionType.DML, fnamespace, fname,
                fstmt.getInputParamNames(), eval.getInput().subList(1, eval.getInput().size()),
                new String[] {outName}, true);
            HopRewriteUtils.copyLineNumbers(eval, fop);
            // Detach the eval from its inputs. This runs after fop was constructed from the sub-list.
            HopRewriteUtils.removeAllChildReferences(eval);
            roots.set(i, fop);
            changed = true;
        }
        return changed;
    }

    /**
     * Tests whether a root hop is a candidate for replacement.
     *
     * <p>A candidate is a transient/persistent write whose first input is an eval hop. That
     * eval must have a literal function name and exactly one parent.
     */
    private static boolean isReplaceableEvalWrite(Hop root) {
        if (!HopRewriteUtils.isData(root, Types.OpOpData.TRANSIENTWRITE, Types.OpOpData.PERSISTENTWRITE)) {
            return false;
        }
        Hop in = root.getInput(0);
        return HopRewriteUtils.isNary(in, Types.OpOpN.EVAL)
            && in.getInput(0) instanceof LiteralOp
            && in.getParent().size() == 1;
    }

    /**
     * Returns true only if the function is known and its first input parameter is a list.
     * A null or parameter-less function yields false.
     */
    private static boolean acceptsListInput(FunctionStatement fstmt) {
        return fstmt != null
            && !fstmt.getInputParams().isEmpty()
            && fstmt.getInputParams().get(0).getDataType().isList();
    }
}
