package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/MarkForLineageReuse.class */
/**
 * Statement-block rewrite that inspects for/while loop blocks and clears the
 * lineage-caching flag (via {@code setRequiresLineageCaching(false)}) on every hop
 * in the loop body that depends, directly or transitively, on a loop-dependent
 * variable.
 *
 * <p>The dependency roots are the iteration variable for a for-loop, and, for a
 * while-loop, the variables that the block updates and its predicate reads.
 */
public class MarkForLineageReuse extends StatementBlockRewriteRule {

    /**
     * @return always {@code false}
     */
    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Unmarks loop-dependent hops in the body of the given loop block.
     *
     * <p>Nothing is done when the block is not a loop statement block or when
     * {@code LineageCacheConfig.ReuseCacheType.isNone()} holds.
     *
     * @param statementBlock the statement block to inspect
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return a single-element list holding {@code statementBlock}
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        boolean applicable = HopRewriteUtils.isLoopStatementBlock(statementBlock)
            && !LineageCacheConfig.ReuseCacheType.isNone();
        if (applicable) {
            if (statementBlock instanceof ForStatementBlock) {
                ForStatement forLoop = (ForStatement) statementBlock.getStatement(0);
                ArrayList<StatementBlock> body = forLoop.getBody();
                HashSet<String> loopVars = new HashSet<>();
                // The only dependency root of a for-loop is its iteration variable.
                Set<String> depRoots = new HashSet<>();
                depRoots.add(forLoop.getIterablePredicate().getIterVar().getName());
                rUnmarkLoopDepVarsSB(body, loopVars, depRoots);
            }
            if (statementBlock instanceof WhileStatementBlock) {
                WhileStatement whileLoop = (WhileStatement) statementBlock.getStatement(0);
                ArrayList<StatementBlock> body = whileLoop.getBody();
                HashSet<String> loopVars = new HashSet<>();
                // Roots: variables updated by the block that the loop predicate also reads.
                Set<String> depRoots = statementBlock.variablesUpdated().getVariableNames().stream()
                    .filter(name -> whileLoop.getConditionalPredicate().variablesRead().containsVariable(name))
                    .collect(Collectors.toSet());
                rUnmarkLoopDepVarsSB(body, loopVars, depRoots);
            }
        }
        return Arrays.asList(statementBlock);
    }

    /**
     * Sweeps the given statement blocks, recursing into nested control-flow bodies
     * and unmarking loop-dependent hops in blocks that carry hop DAGs.
     *
     * <p>The sweep is repeated until either the number of passes reaches the number
     * of blocks, or the accumulated set is non-empty and equal to the working set
     * of the pass that just finished.
     *
     * @param blocks the statement blocks to process
     * @param loopVars names of loop-dependent variables collected so far; grown in place
     * @param depRoots names that seed the dependency analysis
     */
    private void rUnmarkLoopDepVarsSB(ArrayList<StatementBlock> blocks, HashSet<String> loopVars, Set<String> depRoots) {
        HashSet<String> passVars = new HashSet<>();
        int passes = 0;
        do {
            // Each pass starts from a snapshot of the accumulated set; nested bodies
            // extend the snapshot, basic blocks extend the accumulated set directly.
            passVars.clear();
            passVars.addAll(loopVars);
            for (StatementBlock block : blocks) {
                if (block instanceof ForStatementBlock) {
                    ForStatement nestedFor = (ForStatement) block.getStatement(0);
                    rUnmarkLoopDepVarsSB(nestedFor.getBody(), passVars, depRoots);
                } else if (block instanceof WhileStatementBlock) {
                    WhileStatement nestedWhile = (WhileStatement) block.getStatement(0);
                    rUnmarkLoopDepVarsSB(nestedWhile.getBody(), passVars, depRoots);
                } else if (block instanceof IfStatementBlock) {
                    IfStatement branch = (IfStatement) block.getStatement(0);
                    rUnmarkLoopDepVarsSB(branch.getIfBody(), passVars, depRoots);
                    if (branch.getElseBody() != null) {
                        rUnmarkLoopDepVarsSB(branch.getElseBody(), passVars, depRoots);
                    }
                } else if (block instanceof FunctionStatementBlock) {
                    FunctionStatement function = (FunctionStatement) block.getStatement(0);
                    rUnmarkLoopDepVarsSB(function.getBody(), passVars, depRoots);
                } else if (block.getHops() != null) {
                    unmarkInHopBlock(block, loopVars, depRoots);
                }
            }
            loopVars.addAll(passVars);
            passes++;
        } while (passes < blocks.size() && (loopVars.isEmpty() || !loopVars.equals(passVars)));
    }

    /**
     * Processes one statement block that owns hop DAGs: traverses every root hop,
     * once per updated variable of the block, and merges newly found
     * loop-dependent variable names into {@code loopVars}.
     */
    private void unmarkInHopBlock(StatementBlock block, HashSet<String> loopVars, Set<String> depRoots) {
        for (int round = 0; round < block.variablesUpdated().getSize(); round++) {
            HashSet<String> roundVars = new HashSet<>(loopVars);
            for (Hop root : block.getHops()) {
                // Visit flags are cleared before each root so every DAG is fully re-walked.
                Hop.resetVisitStatus(block.getHops());
                HashSet<Long> dependentIds = new HashSet<>();
                rUnmarkLoopDepVars(root, depRoots, roundVars, dependentIds);
            }
            boolean unchanged = !loopVars.isEmpty() && loopVars.equals(roundVars);
            if (!unchanged) {
                loopVars.addAll(roundVars);
            }
        }
    }

    /**
     * Post-order traversal of a hop DAG that clears the lineage-caching flag on
     * every loop-dependent hop.
     *
     * <p>A hop is loop-dependent when its name is a dependency root, when it is a
     * transient read of a known loop-dependent variable, or when any of its inputs
     * has already been recorded as dependent.
     *
     * @param hop the hop to process
     * @param depRoots names that seed the dependency analysis
     * @param loopVars loop-dependent variable names; extended with the name of a
     *     transient write once any dependent hop has been recorded
     * @param dependentIds ids of the hops found dependent during this traversal
     */
    private void rUnmarkLoopDepVars(Hop hop, Set<String> depRoots, HashSet<String> loopVars, HashSet<Long> dependentIds) {
        if (hop.isVisited()) {
            return;
        }
        // Inputs first, so their dependency status is known before this hop is classified.
        for (Hop input : hop.getInput()) {
            rUnmarkLoopDepVars(input, depRoots, loopVars, dependentIds);
        }

        boolean dependent = depRoots.contains(hop.getName())
            || (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD) && loopVars.contains(hop.getName()));
        for (Hop input : hop.getInput()) {
            dependent |= dependentIds.contains(input.getHopID());
        }

        if (dependent) {
            dependentIds.add(hop.getHopID());
            hop.setRequiresLineageCaching(false);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE) && !dependentIds.isEmpty()) {
            loopVars.add(hop.getName());
        }
        hop.setVisited();
    }

    /**
     * Leaves the list of statement blocks untouched.
     *
     * @param list the statement blocks
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return {@code list}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }
}
