package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/* loaded from: input_file:org/apache/sysds/hops/ipa/FunctionCallGraph.class */
/**
 * Call graph over the functions invoked by a DML program or by a single
 * statement block.
 *
 * <p>The graph is built once, in the constructor, by recursively walking
 * statement blocks and following every {@link FunctionOp} found among the
 * root hops of a block. Calls made from outside any function are attributed
 * to the synthetic caller key {@code "_main"}. Alongside the caller/callee
 * edges the class records the call sites per function, the functions found
 * to take part in a call cycle, and the functions classified as free of
 * side effects.
 *
 * <p>The mutating methods ({@code removeFunctionCalls},
 * {@code removeFunctionCall}, {@code replaceFunctionCalls}) update this
 * bookkeeping only; they do not touch the underlying program.
 */
public class FunctionCallGraph {
    /** Caller key used for calls that do not originate from a function body. */
    private static final String MAIN_FUNCTION_KEY = "_main";

    /** Caller key -> keys of the functions called from it. */
    private final HashMap<String, HashSet<String>> _fGraph = new HashMap<>();

    /** Function key -> every FunctionOp that calls this function. */
    private final HashMap<String, ArrayList<FunctionOp>> _fCalls = new HashMap<>();

    /** Function key -> statement block of each call; filled in lockstep with {@link #_fCalls}. */
    private final HashMap<String, ArrayList<StatementBlock>> _fCallsSB = new HashMap<>();

    /** Keys of functions that were found on a call cycle. */
    private final HashSet<String> _fRecursive = new HashSet<>();

    /** Keys of called functions for which {@link #isSideEffectFree} returned true. */
    private final HashSet<String> _fSideEffectFree = new HashSet<>();

    /** Whether any analyzed block contained a second-order builtin. */
    private final boolean _containsSecondOrder;

    /**
     * Builds the call graph for an entire program, starting from its
     * top-level statement blocks.
     *
     * @param dMLProgram program to analyze
     */
    public FunctionCallGraph(DMLProgram dMLProgram) {
        this._containsSecondOrder = constructFunctionCallGraph(dMLProgram);
    }

    /**
     * Builds the call graph for the calls reachable from one statement block.
     * Unlike the program-level constructor, this variant does not populate
     * the side-effect-free classification.
     *
     * @param statementBlock block to analyze
     */
    public FunctionCallGraph(StatementBlock statementBlock) {
        this._containsSecondOrder = constructFunctionCallGraph(statementBlock);
    }

    /**
     * Returns the functions called by the function identified by the key
     * {@code DMLProgram.constructFunctionKey(str, str2)}.
     *
     * @param str first argument of {@code DMLProgram.constructFunctionKey}
     * @param str2 second argument of {@code DMLProgram.constructFunctionKey}
     * @return the live set of callee keys, or {@code null} if the key is not in the graph
     */
    public Set<String> getCalledFunctions(String str, String str2) {
        return getCalledFunctions(DMLProgram.constructFunctionKey(str, str2));
    }

    /**
     * Returns the functions called by the given caller.
     *
     * @param str caller function key; {@code null} selects the main program
     * @return the live (not copied) set of callee keys, or {@code null} if
     *         the key is not in the graph
     */
    public Set<String> getCalledFunctions(String str) {
        return this._fGraph.get(str == null ? MAIN_FUNCTION_KEY : str);
    }

    /**
     * Returns all recorded call operators of a function.
     *
     * @param str function key; {@code null} yields an empty list
     * @return the live list of call operators, or {@code null} if no call to
     *         the function has been recorded
     */
    public List<FunctionOp> getFunctionCalls(String str) {
        if (str == null) {
            return Collections.emptyList();
        }
        return this._fCalls.get(str);
    }

    /**
     * Returns the statement blocks containing the recorded calls of a
     * function, index-aligned with {@link #getFunctionCalls(String)}.
     *
     * @param str function key; {@code null} yields an empty list
     * @return the live list of statement blocks, or {@code null} if no call
     *         to the function has been recorded
     */
    public List<StatementBlock> getFunctionCallsSB(String str) {
        if (str == null) {
            return Collections.emptyList();
        }
        return this._fCallsSB.get(str);
    }

    /**
     * Drops every record of the given function: its call sites, its
     * recursion and side-effect-free flags, its own node in the graph, and
     * every edge pointing at it.
     *
     * @param str key of the function to forget
     */
    public void removeFunctionCalls(String str) {
        this._fCalls.remove(str);
        this._fCallsSB.remove(str);
        this._fRecursive.remove(str);
        // Keep the classification sets consistent with replaceFunctionCalls:
        // a forgotten function must not stay flagged as side-effect free.
        this._fSideEffectFree.remove(str);
        this._fGraph.remove(str);
        for (HashSet<String> callees : this._fGraph.values()) {
            callees.remove(str);
        }
    }

    /**
     * Removes a single recorded call site of a function. Unknown keys are
     * ignored.
     *
     * @param str key of the called function
     * @param functionOp call operator to remove from the call list
     * @param statementBlock statement block to remove from the block list
     */
    public void removeFunctionCall(String str, FunctionOp functionOp, StatementBlock statementBlock) {
        List<FunctionOp> calls = this._fCalls.get(str);
        if (calls != null) {
            calls.remove(functionOp);
        }
        List<StatementBlock> callBlocks = this._fCallsSB.get(str);
        if (callBlocks != null) {
            callBlocks.remove(statementBlock);
        }
    }

    /**
     * Re-registers the call sites recorded for {@code str} under
     * {@code str2} and forgets {@code str}. Afterwards the call sites of
     * {@code str2} are exactly those previously recorded for {@code str};
     * anything recorded for {@code str2} before is replaced.
     *
     * @param str key of the function whose calls are being replaced
     * @param str2 key of the function that now receives those calls
     */
    public void replaceFunctionCalls(String str, String str2) {
        ArrayList<FunctionOp> movedCalls = this._fCalls.remove(str);
        ArrayList<StatementBlock> movedBlocks = this._fCallsSB.remove(str);

        // Never store null lists: a key mapped to null would make
        // removeFunctionCall and the graph construction fail with a
        // NullPointerException. A missing entry reads back as null as well.
        if (movedCalls != null) {
            this._fCalls.put(str2, movedCalls);
        } else {
            this._fCalls.remove(str2);
        }
        if (movedBlocks != null) {
            this._fCallsSB.put(str2, movedBlocks);
        } else {
            this._fCallsSB.remove(str2);
        }

        // The replaced function is no longer part of the graph.
        this._fRecursive.remove(str);
        this._fSideEffectFree.remove(str);
        this._fGraph.remove(str);
        for (HashSet<String> callees : this._fGraph.values()) {
            callees.remove(str);
        }
    }

    /**
     * Tells whether the function with key
     * {@code DMLProgram.constructFunctionKey(str, str2)} is recursive.
     *
     * @param str first argument of {@code DMLProgram.constructFunctionKey}
     * @param str2 second argument of {@code DMLProgram.constructFunctionKey}
     * @return true if the function was found on a call cycle
     */
    public boolean isRecursiveFunction(String str, String str2) {
        return isRecursiveFunction(DMLProgram.constructFunctionKey(str, str2));
    }

    /**
     * Tells whether the given function was found on a call cycle.
     *
     * @param str function key; {@code null} selects the main program
     * @return true if the key is flagged as recursive
     */
    public boolean isRecursiveFunction(String str) {
        return this._fRecursive.contains(str == null ? MAIN_FUNCTION_KEY : str);
    }

    /**
     * Tells whether the function with key
     * {@code DMLProgram.constructFunctionKey(str, str2)} is side-effect free.
     *
     * @param str first argument of {@code DMLProgram.constructFunctionKey}
     * @param str2 second argument of {@code DMLProgram.constructFunctionKey}
     * @return true if the function is flagged as side-effect free
     */
    public boolean isSideEffectFreeFunction(String str, String str2) {
        return isSideEffectFreeFunction(DMLProgram.constructFunctionKey(str, str2));
    }

    /**
     * Tells whether the given function was classified as side-effect free.
     *
     * @param str function key; {@code null} selects the main program
     * @return true if the key is flagged as side-effect free
     */
    public boolean isSideEffectFreeFunction(String str) {
        return this._fSideEffectFree.contains(str == null ? MAIN_FUNCTION_KEY : str);
    }

    /**
     * Returns the keys of all functions in the graph, excluding the main
     * program.
     *
     * @return a new set of function keys
     */
    public Set<String> getReachableFunctions() {
        return getReachableFunctions(Collections.emptySet());
    }

    /**
     * Returns the keys of all functions in the graph, excluding the main
     * program and the given keys.
     *
     * @param set keys to leave out of the result
     * @return a new set of function keys
     */
    public Set<String> getReachableFunctions(Set<String> set) {
        return this._fGraph.keySet().stream()
            .filter(key -> !set.contains(key) && !MAIN_FUNCTION_KEY.equals(key))
            .collect(Collectors.toSet());
    }

    /**
     * Tells whether the function with key
     * {@code DMLProgram.constructFunctionKey(str, str2)} is in the graph.
     *
     * @param str first argument of {@code DMLProgram.constructFunctionKey}
     * @param str2 second argument of {@code DMLProgram.constructFunctionKey}
     * @return true if the function is a node of the graph
     */
    public boolean isReachableFunction(String str, String str2) {
        return isReachableFunction(DMLProgram.constructFunctionKey(str, str2));
    }

    /**
     * Tells whether the given function is a node of the graph.
     *
     * @param str function key; {@code null} selects the main program
     * @return true if the key is a node of the graph
     */
    public boolean isReachableFunction(String str) {
        return isReachableFunction(str, false);
    }

    /**
     * Tells whether the given function is part of the graph.
     *
     * <p>NOTE(review): the decompiler reported that this method's access
     * modifier was changed from {@code protected}; it is kept public here so
     * existing callers keep compiling.
     *
     * @param str function key; {@code null} selects the main program
     * @param z if false, test whether the key is a node of the graph; if
     *        true, test whether any node lists the key among its callees
     * @return the result of the selected test
     */
    public boolean isReachableFunction(String str, boolean z) {
        final String key = (str == null) ? MAIN_FUNCTION_KEY : str;
        if (!z) {
            return this._fGraph.containsKey(key);
        }
        for (HashSet<String> callees : this._fGraph.values()) {
            if (callees.contains(key)) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return true if any analyzed statement block contained a second-order
     *         builtin according to {@code HopRewriteUtils.containsSecondOrderBuiltin}
     */
    public boolean containsSecondOrderCall() {
        return this._containsSecondOrder;
    }

    /**
     * Builds the graph for a whole program and classifies the called
     * functions as side-effect free where applicable.
     *
     * @param prog program to analyze
     * @return true if a second-order builtin was found
     */
    private boolean constructFunctionCallGraph(DMLProgram prog) {
        // Without functions there is no graph to build; only the
        // second-order flag has to be determined.
        if (!prog.hasFunctionStatementBlocks()) {
            boolean secondOrder = false;
            for (StatementBlock block : prog.getStatementBlocks()) {
                secondOrder |= rAnalyzeSecondOrderCall(block);
            }
            return secondOrder;
        }

        boolean secondOrder = false;
        try {
            Stack<String> callStack = new Stack<>();
            HashSet<String> visited = new HashSet<>();
            this._fGraph.put(MAIN_FUNCTION_KEY, new HashSet<>());
            for (StatementBlock block : prog.getStatementBlocks()) {
                secondOrder |= rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, block, callStack, visited);
            }

            // Classify every called function outside the internal namespace.
            for (String calleeKey : this._fCalls.keySet()) {
                if (!calleeKey.startsWith(DMLProgram.INTERNAL_NAMESPACE)
                    && isSideEffectFree(prog.getFunctionStatementBlock(calleeKey))) {
                    this._fSideEffectFree.add(calleeKey);
                }
            }
            return secondOrder;
        } catch (HopsException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the graph for the calls reachable from a single statement block.
     *
     * @param root block to start from
     * @return true if a second-order builtin was found; false if the owning
     *         program has no functions at all
     */
    private boolean constructFunctionCallGraph(StatementBlock root) {
        if (!root.getDMLProg().hasFunctionStatementBlocks()) {
            return false;
        }
        try {
            Stack<String> callStack = new Stack<>();
            HashSet<String> visited = new HashSet<>();
            this._fGraph.put(MAIN_FUNCTION_KEY, new HashSet<>());
            return rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, root, callStack, visited);
        } catch (HopsException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Recursively records the function calls made from a statement block on
     * behalf of the given caller.
     *
     * @param callerKey key of the function (or main) the block belongs to
     * @param sb statement block to analyze
     * @param callStack keys of the functions currently being expanded, outermost first
     * @param visited callee keys already expanded within the current caller scope
     * @return true if a second-order builtin was found in the block or below
     */
    private boolean rConstructFunctionCallGraph(String callerKey, StatementBlock sb, Stack<String> callStack, HashSet<String> visited) {
        if (sb instanceof WhileStatementBlock) {
            WhileStatement whileStmt = (WhileStatement) sb.getStatement(0);
            return rConstructForBlocks(callerKey, whileStmt.getBody(), callStack, visited);
        }
        if (sb instanceof IfStatementBlock) {
            IfStatement ifStmt = (IfStatement) sb.getStatement(0);
            boolean found = rConstructForBlocks(callerKey, ifStmt.getIfBody(), callStack, visited);
            // Guard the else branch like rHasSideEffects does.
            if (ifStmt.getElseBody() != null) {
                found |= rConstructForBlocks(callerKey, ifStmt.getElseBody(), callStack, visited);
            }
            return found;
        }
        if (sb instanceof ForStatementBlock) {
            ForStatement forStmt = (ForStatement) sb.getStatement(0);
            return rConstructForBlocks(callerKey, forStmt.getBody(), callStack, visited);
        }
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement funStmt = (FunctionStatement) sb.getStatement(0);
            return rConstructForBlocks(callerKey, funStmt.getBody(), callStack, visited);
        }
        return registerFunctionCalls(callerKey, sb, callStack, visited);
    }

    /** Applies {@link #rConstructFunctionCallGraph} to every block of a body and ORs the results. */
    private boolean rConstructForBlocks(String callerKey, Iterable<StatementBlock> blocks, Stack<String> callStack, HashSet<String> visited) {
        boolean found = false;
        for (StatementBlock block : blocks) {
            // Non-short-circuit on purpose: every block must be visited.
            found |= rConstructFunctionCallGraph(callerKey, block, callStack, visited);
        }
        return found;
    }

    /** Records the function calls among the root hops of a basic block and expands each new callee. */
    private boolean registerFunctionCalls(String callerKey, StatementBlock sb, Stack<String> callStack, HashSet<String> visited) {
        ArrayList<Hop> roots = sb.getHops();
        if (roots == null || roots.isEmpty()) {
            return false;
        }
        boolean found = HopRewriteUtils.containsSecondOrderBuiltin(roots);
        for (Hop root : roots) {
            if (!(root instanceof FunctionOp)) {
                continue;
            }
            FunctionOp call = (FunctionOp) root;
            String calleeKey = call.getFunctionKey();

            // Every call site is recorded, even for callees already expanded.
            this._fCalls.computeIfAbsent(calleeKey, k -> new ArrayList<>()).add(call);
            this._fCallsSB.computeIfAbsent(calleeKey, k -> new ArrayList<>()).add(sb);

            // Expand each callee once per caller scope; functions of the
            // internal namespace are not added to the graph.
            if (visited.contains(calleeKey)
                || call.getFunctionNamespace().equals(DMLProgram.INTERNAL_NAMESPACE)) {
                continue;
            }

            this._fGraph.computeIfAbsent(calleeKey, k -> new HashSet<>());
            this._fGraph.get(callerKey).add(calleeKey);

            if (callStack.contains(calleeKey)) {
                // Cycle: the callee and everything expanded after it on the
                // stack is recursive.
                this._fRecursive.add(calleeKey);
                for (int i = callStack.indexOf(calleeKey) + 1; i < callStack.size(); i++) {
                    this._fRecursive.add(callStack.get(i));
                }
            } else {
                callStack.push(calleeKey);
                FunctionStatement calleeStmt = (FunctionStatement) sb.getDMLProg()
                    .getFunctionStatementBlock(call.getFunctionNamespace(), call.getFunctionName())
                    .getStatement(0);
                for (StatementBlock bodyBlock : calleeStmt.getBody()) {
                    // Each body block of the callee starts with a fresh visited set.
                    found |= rConstructFunctionCallGraph(calleeKey, bodyBlock, callStack, new HashSet<>());
                }
                callStack.pop();
            }
            visited.add(calleeKey);
        }
        return found;
    }

    /**
     * Recursively checks a statement block for second-order builtins without
     * building any graph.
     *
     * @param sb statement block to analyze
     * @return true if a second-order builtin was found in the block or below
     */
    private boolean rAnalyzeSecondOrderCall(StatementBlock sb) {
        if (sb instanceof WhileStatementBlock) {
            return rAnalyzeSecondOrderCalls(((WhileStatement) sb.getStatement(0)).getBody());
        }
        if (sb instanceof IfStatementBlock) {
            IfStatement ifStmt = (IfStatement) sb.getStatement(0);
            boolean found = rAnalyzeSecondOrderCalls(ifStmt.getIfBody());
            // Guard the else branch like rHasSideEffects does.
            if (ifStmt.getElseBody() != null) {
                found |= rAnalyzeSecondOrderCalls(ifStmt.getElseBody());
            }
            return found;
        }
        if (sb instanceof ForStatementBlock) {
            return rAnalyzeSecondOrderCalls(((ForStatement) sb.getStatement(0)).getBody());
        }
        ArrayList<Hop> roots = sb.getHops();
        if (roots == null || roots.isEmpty()) {
            return false;
        }
        return HopRewriteUtils.containsSecondOrderBuiltin(roots);
    }

    /** Applies {@link #rAnalyzeSecondOrderCall} to every block of a body and ORs the results. */
    private boolean rAnalyzeSecondOrderCalls(Iterable<StatementBlock> blocks) {
        boolean found = false;
        for (StatementBlock block : blocks) {
            found |= rAnalyzeSecondOrderCall(block);
        }
        return found;
    }

    /**
     * Tells whether a function body is free of side effects. A body counts
     * as having side effects if any root hop of any nested block is a print,
     * a printf, a persistent write, or a function call.
     *
     * <p>NOTE(review): the decompiler reported that this method's access
     * modifier was changed from {@code private}; it is kept public here so
     * existing callers keep compiling.
     *
     * @param functionStatementBlock function to inspect
     * @return true if no block of the function body has side effects
     */
    public static boolean isSideEffectFree(FunctionStatementBlock functionStatementBlock) {
        FunctionStatement funStmt = (FunctionStatement) functionStatementBlock.getStatement(0);
        return !anyHasSideEffects(funStmt.getBody());
    }

    /**
     * Recursively checks a statement block for side effects.
     *
     * @param sb statement block to inspect
     * @return true as soon as a side effect is found in the block or below
     */
    private static boolean rHasSideEffects(StatementBlock sb) {
        if (sb instanceof ForStatementBlock) {
            return anyHasSideEffects(((ForStatement) sb.getStatement(0)).getBody());
        }
        if (sb instanceof WhileStatementBlock) {
            return anyHasSideEffects(((WhileStatement) sb.getStatement(0)).getBody());
        }
        if (sb instanceof IfStatementBlock) {
            IfStatement ifStmt = (IfStatement) sb.getStatement(0);
            if (anyHasSideEffects(ifStmt.getIfBody())) {
                return true;
            }
            return ifStmt.getElseBody() != null && anyHasSideEffects(ifStmt.getElseBody());
        }
        ArrayList<Hop> roots = sb.getHops();
        if (roots == null) {
            return false;
        }
        for (Hop root : roots) {
            if (HopRewriteUtils.isUnary(root, Types.OpOp1.PRINT)
                || HopRewriteUtils.isNary(root, Types.OpOpN.PRINTF)
                || HopRewriteUtils.isData(root, Types.OpOpData.PERSISTENTWRITE)
                || root instanceof FunctionOp) {
                return true;
            }
        }
        return false;
    }

    /** Returns true as soon as one block of the body has side effects. */
    private static boolean anyHasSideEffects(Iterable<StatementBlock> blocks) {
        for (StatementBlock block : blocks) {
            if (rHasSideEffects(block)) {
                return true;
            }
        }
        return false;
    }
}
