package org.apache.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.LanguageException;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

/**
 * IPA pass that flags functions and statement blocks as non-deterministic.
 *
 * <p>A function is non-deterministic if its body contains a hop for which
 * {@link HopRewriteUtils#isDataGenOpWithNonDeterminism(Hop)} holds, or if it
 * (transitively) calls such a function. Afterwards, every last-level statement
 * block of the main program and of all reachable functions that calls a
 * non-deterministic function is flagged as well.
 *
 * <p>The pass is a no-op unless multi-level lineage reuse or lineage estimation
 * is enabled; presumably the flags are consumed by lineage-based reuse
 * (TODO confirm against the readers of {@code setNondeterministic}).
 */
public class IPAPassFlagNonDeterminism extends IPAPass {

    /**
     * The pass is only applicable if the program contains no second-order
     * function calls (presumably because their callees are not statically
     * known to the call graph).
     */
    @Override
    public boolean isApplicable(FunctionCallGraph fgraph) {
        return !fgraph.containsSecondOrderCall();
    }

    /**
     * Detects non-deterministic functions, propagates the property to their
     * callers, and flags the affected function and statement blocks.
     *
     * @param prog        the DML program whose blocks are flagged in place
     * @param fgraph      the function call graph of {@code prog}
     * @param fcallSizes  call-site size information (unused by this pass)
     * @return always {@code false}
     * @throws HopsException wrapping any {@link LanguageException} raised while
     *                       resolving function statement blocks
     */
    @Override
    public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        if (!LineageCacheConfig.isMultiLevelReuse() && !DMLScript.LINEAGE_ESTIMATE) {
            return false;
        }
        try {
            // 1) functions whose own body contains a non-deterministic operation
            HashSet<String> ndfncs = new HashSet<>();
            for (String fkey : fgraph.getReachableFunctions()) {
                String fname = DMLProgram.splitFunctionKey(fkey)[1];
                if (rIsNonDeterministicFnc(fname, getFunctionBody(prog, fkey))) {
                    ndfncs.add(fkey);
                }
            }

            // 2) propagate to all (transitive) callers. A single DFS pass is exact for
            // acyclic call graphs but can miss callers inside recursive cycles, because
            // back-edges are cut before the recursive callee's status is final.
            // Propagation only ever adds keys, so iterating until the set stops growing
            // terminates and yields the complete closure.
            int numBefore;
            do {
                numBefore = ndfncs.size();
                propagate2Callers(fgraph, ndfncs, new HashSet<>(), null);
            } while (ndfncs.size() != numBefore);

            // 3) flag the function statement blocks themselves
            for (String fkey : ndfncs) {
                prog.getFunctionStatementBlock(fkey).setNondeterministic(true);
            }

            // 4) flag statement blocks (main program and reachable functions)
            // that call a non-deterministic function
            rMarkNondeterministicSBs(prog.getStatementBlocks(), ndfncs);
            for (String fkey : fgraph.getReachableFunctions()) {
                rMarkNondeterministicSBs(getFunctionBody(prog, fkey), ndfncs);
            }
            return false;
        } catch (LanguageException ex) {
            throw new HopsException(ex);
        }
    }

    /** Returns the body statement blocks of the function identified by {@code fkey}. */
    private static ArrayList<StatementBlock> getFunctionBody(DMLProgram prog, String fkey) throws LanguageException {
        return ((FunctionStatement) prog.getFunctionStatementBlock(fkey).getStatement(0)).getBody();
    }

    /**
     * Recursively checks whether the given statement blocks contain a
     * non-deterministic operation. As a side effect, each visited last-level
     * statement block is flagged with its own result.
     *
     * <p>Stops at the first non-deterministic block of {@code sbs}; blocks after
     * it are not visited. NOTE(review): only loop/if bodies are inspected, not
     * predicate hops (e.g. for-loop bounds or if-conditions).
     *
     * @param fname the function name (only passed down the recursion)
     * @param sbs   the statement blocks to inspect
     * @return {@code true} if any inspected block is non-deterministic
     */
    private boolean rIsNonDeterministicFnc(String fname, ArrayList<StatementBlock> sbs) {
        boolean isND = false;
        for (StatementBlock sb : sbs) {
            if (isND) {
                break;
            }
            if (sb instanceof ForStatementBlock) {
                isND = rIsNonDeterministicFnc(fname, ((ForStatement) sb.getStatement(0)).getBody());
            } else if (sb instanceof WhileStatementBlock) {
                isND = rIsNonDeterministicFnc(fname, ((WhileStatement) sb.getStatement(0)).getBody());
            } else if (sb instanceof IfStatementBlock) {
                IfStatement istmt = (IfStatement) sb.getStatement(0);
                isND = rIsNonDeterministicFnc(fname, istmt.getIfBody());
                if (istmt.getElseBody() != null) {
                    // OR (not assign): a deterministic else-branch must not mask a
                    // non-deterministic if-branch. Non-short-circuit |= still visits
                    // (and flags) the else-branch blocks.
                    isND |= rIsNonDeterministicFnc(fname, istmt.getElseBody());
                }
            } else if (sb.getHops() != null) {
                Hop.resetVisitStatus(sb.getHops());
                for (Hop root : sb.getHops()) {
                    isND |= rIsNonDeterministicHop(root);
                }
                Hop.resetVisitStatus(sb.getHops());
                sb.setNondeterministic(isND);
            }
        }
        return isND;
    }

    /**
     * Recursively flags every last-level statement block that calls one of the
     * given non-deterministic functions. Blocks are only ever set to
     * non-deterministic, never reset.
     *
     * @param sbs     the statement blocks to traverse
     * @param ndfncs  keys of non-deterministic functions
     */
    private void rMarkNondeterministicSBs(ArrayList<StatementBlock> sbs, HashSet<String> ndfncs) {
        for (StatementBlock sb : sbs) {
            if (sb instanceof ForStatementBlock) {
                rMarkNondeterministicSBs(((ForStatement) sb.getStatement(0)).getBody(), ndfncs);
            } else if (sb instanceof WhileStatementBlock) {
                rMarkNondeterministicSBs(((WhileStatement) sb.getStatement(0)).getBody(), ndfncs);
            } else if (sb instanceof IfStatementBlock) {
                IfStatement istmt = (IfStatement) sb.getStatement(0);
                rMarkNondeterministicSBs(istmt.getIfBody(), ndfncs);
                if (istmt.getElseBody() != null) {
                    rMarkNondeterministicSBs(istmt.getElseBody(), ndfncs);
                }
            } else if (sb.getHops() != null) {
                boolean callsND = false;
                Hop.resetVisitStatus(sb.getHops());
                for (Hop root : sb.getHops()) {
                    callsND |= rMarkNondeterministicHop(root, ndfncs);
                }
                Hop.resetVisitStatus(sb.getHops());
                if (callsND) {
                    sb.setNondeterministic(true);
                }
            }
        }
    }

    /**
     * Returns whether the DAG rooted at {@code hop} contains a {@link FunctionOp}
     * whose name is in {@code ndfncs}. Already visited hops contribute
     * {@code false}; callers reset the visit status around the traversal.
     * NOTE(review): relies on {@code FunctionOp.getName()} matching the function
     * keys in {@code ndfncs} — confirm.
     */
    private boolean rMarkNondeterministicHop(Hop hop, HashSet<String> ndfncs) {
        if (hop.isVisited()) {
            return false;
        }
        boolean callsND = (hop instanceof FunctionOp) && ndfncs.contains(hop.getName());
        if (!callsND) {
            for (Hop input : hop.getInput()) {
                callsND |= rMarkNondeterministicHop(input, ndfncs);
            }
        }
        hop.setVisited();
        return callsND;
    }

    /**
     * Returns whether the DAG rooted at {@code hop} contains a data-generation op
     * with non-determinism (per {@link HopRewriteUtils#isDataGenOpWithNonDeterminism}).
     * Already visited hops contribute {@code false}.
     */
    private boolean rIsNonDeterministicHop(Hop hop) {
        if (hop.isVisited()) {
            return false;
        }
        boolean isND = HopRewriteUtils.isDataGenOpWithNonDeterminism(hop);
        if (!isND) {
            for (Hop input : hop.getInput()) {
                isND |= rIsNonDeterministicHop(input);
            }
        }
        hop.setVisited();
        return isND;
    }

    /**
     * One DFS pass over the call graph that adds a function to {@code ndfncs}
     * if any of its callees is already in {@code ndfncs} (evaluated post-order,
     * after descending into the callee).
     *
     * @param fgraph  the function call graph
     * @param ndfncs  keys of non-deterministic functions, extended in place
     * @param fstack  keys of functions on the current DFS path
     * @param fkey    the current function; {@code null} is the traversal root
     *                (presumably the main program), which is never added
     */
    private void propagate2Callers(FunctionCallGraph fgraph, HashSet<String> ndfncs, HashSet<String> fstack, String fkey) {
        Set<String> cfkeys = fgraph.getCalledFunctions(fkey);
        if (cfkeys == null) {
            return;
        }
        for (String cfkey : cfkeys) {
            // descend unless this is a back-edge into a recursive function on the current path
            if (!(fstack.contains(cfkey) && fgraph.isRecursiveFunction(cfkey))) {
                fstack.add(cfkey);
                propagate2Callers(fgraph, ndfncs, fstack, cfkey);
                fstack.remove(cfkey);
            }
            if (fkey != null && ndfncs.contains(cfkey)) {
                ndfncs.add(fkey);
            }
        }
    }
}
