package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.FunctionStatement;

/* loaded from: input_file:org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.class */
public class IPAPassForwardFunctionCalls extends IPAPass {
    @Override // org.apache.sysds.hops.ipa.IPAPass
    public boolean isApplicable(FunctionCallGraph functionCallGraph) {
        return true;
    }

    @Override // org.apache.sysds.hops.ipa.IPAPass
    public boolean rewriteProgram(DMLProgram dMLProgram, FunctionCallGraph functionCallGraph, FunctionCallSizeInfo functionCallSizeInfo) {
        for (String str : functionCallGraph.getReachableFunctions()) {
            FunctionStatement functionStatement = (FunctionStatement) dMLProgram.getFunctionStatementBlock(str).getStatement(0);
            if (functionStatement.getBody().size() == 1 && singleFunctionOp(functionStatement.getBody().get(0).getHops()) && hasOnlySimplyArguments((FunctionOp) functionStatement.getBody().get(0).getHops().get(0))) {
                if (LOG.isDebugEnabled()) {
                    LOG.debug("IPA: Forward-function-call candidate L1: '" + str + "'");
                }
                FunctionOp functionOp = (FunctionOp) functionStatement.getBody().get(0).getHops().get(0);
                if (hasConsistentOutputOrdering(functionStatement, functionOp) && functionCallGraph.getFunctionCalls(str).size() <= 1) {
                    if (LOG.isDebugEnabled()) {
                        LOG.debug("IPA: Forward-function-call candidate L2: '" + str + "'");
                    }
                    FunctionOp functionOp2 = functionCallGraph.getFunctionCalls(str).get(0);
                    if (hasValidVariableNames(functionOp2) && hasValidVariableNames(functionOp) && isFirstSubsetOfSecond(functionOp.getInputVariableNames(), functionOp2.getInputVariableNames())) {
                        functionOp2.setFunctionName(functionOp.getFunctionName());
                        functionOp2.setFunctionNamespace(functionOp.getFunctionNamespace());
                        reconcileFunctionInputsInPlace(functionOp2, functionOp);
                        functionCallGraph.replaceFunctionCalls(str, functionOp.getFunctionKey());
                        if (!functionCallGraph.containsSecondOrderCall()) {
                            dMLProgram.removeFunctionStatementBlock(str);
                        }
                        if (LOG.isDebugEnabled()) {
                            LOG.debug("IPA: Forward-function-call: replaced '" + str + "' with '" + functionOp.getFunctionKey() + "'");
                        }
                    }
                }
            }
        }
        return false;
    }

    private static boolean singleFunctionOp(ArrayList<Hop> arrayList) {
        if (arrayList == null || arrayList.isEmpty() || arrayList.size() != 1) {
            return false;
        }
        return arrayList.get(0) instanceof FunctionOp;
    }

    private static boolean hasOnlySimplyArguments(FunctionOp functionOp) {
        return functionOp.getInput().stream().allMatch(hop -> {
            return (hop instanceof LiteralOp) || HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD);
        });
    }

    private static boolean hasConsistentOutputOrdering(FunctionStatement functionStatement, FunctionOp functionOp) {
        return IntStream.range(0, Math.min(functionStatement.getOutputParams().size(), functionOp.getOutputVariableNames().length)).allMatch(i -> {
            return functionStatement.getOutputParams().get(i).getName().equals(functionOp.getOutputVariableNames()[i]);
        });
    }

    private static boolean hasValidVariableNames(FunctionOp functionOp) {
        return functionOp.getInputVariableNames() != null && Arrays.stream(functionOp.getInputVariableNames()).allMatch(str -> {
            return str != null;
        });
    }

    private static boolean isFirstSubsetOfSecond(String[] strArr, String[] strArr2) {
        HashSet hashSet = new HashSet();
        for (String str : strArr2) {
            hashSet.add(str);
        }
        return Arrays.stream(strArr).allMatch(str2 -> {
            return hashSet.contains(str2);
        });
    }

    private static void reconcileFunctionInputsInPlace(FunctionOp functionOp, FunctionOp functionOp2) {
        HashMap hashMap = new HashMap();
        for (int i = 0; i < functionOp2.getInput().size(); i++) {
            hashMap.put(functionOp2.getInputVariableNames()[i], functionOp2.getInput().get(i));
        }
        ArrayList<Hop> arrayList = new ArrayList<>();
        for (int i2 = 0; i2 < functionOp.getInput().size(); i2++) {
            if (hashMap.containsKey(functionOp.getInputVariableNames()[i2])) {
                arrayList.add(hashMap.get(functionOp.getInputVariableNames()[i2]) instanceof LiteralOp ? (Hop) hashMap.get(functionOp.getInputVariableNames()[i2]) : functionOp.getInput().get(i2));
            }
        }
        HopRewriteUtils.removeAllChildReferences(functionOp);
        functionOp.addAllInputs(arrayList);
        functionOp.setInputVariableNames(functionOp2.getInputVariableNames());
    }
}
