package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LeftIndexingOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/**
 * Statement-block rewrite that marks matrix loop variables as candidates for update-in-place.
 *
 * <p>For each {@code while} or {@code for} statement block, every matrix variable that is updated
 * inside the loop and live after it is checked against the loop body. A variable qualifies only if
 * every body block that reads or writes it does so safely:
 * <ul>
 *   <li>nested loops must already list it as an update-in-place variable;</li>
 *   <li>both branches of an {@code if} must qualify recursively;</li>
 *   <li>plain blocks are checked at the HOP level for a safe left-indexing update of the variable.</li>
 * </ul>
 * Qualifying names are recorded via {@code StatementBlock#setUpdateInPlaceVars}. The rule does
 * nothing in SPARK execution mode.
 */
public class RewriteMarkLoopVariablesUpdateInPlace extends StatementBlockRewriteRule {

    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Records the update-in-place candidates of a {@code while}/{@code for} statement block.
     *
     * @param sb     the statement block to inspect; non-loop blocks are returned untouched
     * @param status rewrite status (not used by this rule)
     * @return a single-element list containing {@code sb}
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) {
        // update-in-place marking is skipped entirely in SPARK execution mode
        if (DMLScript.getGlobalExecMode() == Types.ExecMode.SPARK) {
            return Arrays.asList(sb);
        }
        ArrayList<StatementBlock> body = getLoopBody(sb);
        if (body != null) {
            ArrayList<String> candidates = new ArrayList<>();
            VariableSet updated = sb.variablesUpdated();
            VariableSet liveOut = sb.liveOut();
            for (String varname : updated.getVariableNames()) {
                // only matrices that are live after the loop are considered; variables that are
                // local to the loop body are excluded
                if (updated.getVariable(varname).getDataType() == Types.DataType.MATRIX
                        && liveOut.containsVariable(varname)
                        && rIsApplicableForUpdateInPlace(body, varname)) {
                    candidates.add(varname);
                }
            }
            sb.setUpdateInPlaceVars(candidates);
        }
        return Arrays.asList(sb);
    }

    /**
     * Returns the body of a loop statement block, or {@code null} if {@code sb} is not a loop.
     * NOTE(review): subclasses of ForStatementBlock (presumably parfor) are handled as for-loops.
     */
    private static ArrayList<StatementBlock> getLoopBody(StatementBlock sb) {
        if (sb instanceof WhileStatementBlock) {
            return ((WhileStatement) sb.getStatement(0)).getBody();
        }
        if (sb instanceof ForStatementBlock) {
            return ((ForStatement) sb.getStatement(0)).getBody();
        }
        return null;
    }

    /**
     * Recursively checks whether every block in {@code sbs} that reads or writes {@code varname}
     * allows it to be updated in place.
     *
     * @param sbs     the statement blocks of a loop/if body
     * @param varname the candidate variable name
     * @return {@code true} if all relevant blocks are safe; stops at the first unsafe block
     */
    private static boolean rIsApplicableForUpdateInPlace(ArrayList<StatementBlock> sbs, String varname) {
        boolean ret = true;
        for (StatementBlock sb : sbs) {
            // blocks that neither read nor write the variable cannot invalidate it
            if (!sb.variablesRead().containsVariable(varname)
                    && !sb.variablesUpdated().containsVariable(varname)) {
                continue;
            }
            if (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
                // reuse the nested loop's own marking
                // NOTE(review): assumes nested loops were already processed by this rule
                ret &= sb.getUpdateInPlaceVars().contains(varname);
            }
            else if (sb instanceof IfStatementBlock) {
                IfStatement istmt = (IfStatement) sb.getStatement(0);
                ret &= rIsApplicableForUpdateInPlace(istmt.getIfBody(), varname);
                if (ret && istmt.getElseBody() != null) {
                    ret &= rIsApplicableForUpdateInPlace(istmt.getElseBody(), varname);
                }
            }
            else if (sb.getHops() != null && !isApplicableForUpdateInPlace(sb.getHops(), varname)) {
                // the DAG-level check failed; fall back to checking every root individually
                for (Hop hop : sb.getHops()) {
                    ret &= isApplicableForUpdateInPlace(hop, varname);
                }
            }
            if (!ret) {
                break; // early abort: one unsafe block disqualifies the variable
            }
        }
        return ret;
    }

    /**
     * Checks a single root HOP.
     *
     * <p>A root is not applicable if it is a function call writing {@code varname}. Roots whose
     * name differs from {@code varname} are accepted. A root named {@code varname} must be a
     * left-indexing update of {@code varname} whose indexed input is otherwise consumed only by
     * {@code nrow}/{@code ncol}.
     */
    private static boolean isApplicableForUpdateInPlace(Hop hop, String varname) {
        // a variable that is written by a function return value must not be updated in place
        if (hop instanceof FunctionOp && ((FunctionOp) hop).containsOutput(varname)) {
            return false;
        }
        // single-root-level check: only roots named after the variable are relevant
        if (!varname.equals(hop.getName())) {
            return true;
        }
        if (!probeLixRoot(hop, varname)) {
            return false;
        }
        // the left-indexing input (the read of varname) may only be consumed by the left-indexing
        // op itself or by nrow()/ncol()
        Hop lix = hop.getInput().get(0);
        for (Hop consumer : lix.getInput().get(0).getParent()) {
            boolean safeConsumer = consumer == lix
                    || (consumer instanceof UnaryOp && ((UnaryOp) consumer).getOp() == Types.OpOp1.NROW)
                    || (consumer instanceof UnaryOp && ((UnaryOp) consumer).getOp() == Types.OpOp1.NCOL);
            if (!safeConsumer) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks all roots of a statement block's HOP DAG at once.
     *
     * <p>The block is accepted only if all of the following hold:
     * <ul>
     *   <li>no root is a function call writing {@code varname};</li>
     *   <li>at most one root is a left-indexing update of {@code varname};</li>
     *   <li>no other root reaches a read of {@code varname} or any left-indexing op.</li>
     * </ul>
     */
    private static boolean isApplicableForUpdateInPlace(ArrayList<Hop> hops, String varname) {
        Hop lixRoot = null;
        for (Hop hop : hops) {
            // keep the function-output guard of the single-root check effective here; otherwise
            // e.g. "X = f(A)" would pass this DAG-level check and skip the per-root fallback
            if (hop instanceof FunctionOp && ((FunctionOp) hop).containsOutput(varname)) {
                return false;
            }
            if (probeLixRoot(hop, varname)) {
                if (lixRoot != null) {
                    return false; // more than one left-indexing update of varname
                }
                lixRoot = hop.getInput().get(0);
            }
        }
        boolean ret = true;
        Hop.resetVisitStatus(hops);
        for (Hop hop : hops) {
            // roots without inputs (presumably e.g. calls to argument-less functions) have no
            // first input to compare against and must not trigger an IndexOutOfBoundsException
            boolean isLixRoot = !hop.getInput().isEmpty() && hop.getInput().get(0) == lixRoot;
            if (!isLixRoot) {
                ret &= rProbeOtherRoot(hop, varname);
            }
        }
        Hop.resetVisitStatus(hops);
        return ret;
    }

    /**
     * Returns {@code true} if {@code hop} is a matrix DataOp whose input is a left-indexing op
     * that updates the DataOp read of {@code varname}.
     */
    private static boolean probeLixRoot(Hop hop, String varname) {
        if (!(hop instanceof DataOp) || !hop.isMatrix() || hop.getInput().isEmpty()) {
            return false;
        }
        Hop lix = hop.getInput().get(0);
        if (!lix.isMatrix() || !(lix instanceof LeftIndexingOp)) {
            return false;
        }
        Hop target = lix.getInput().get(0);
        return target instanceof DataOp && target.getName().equals(varname);
    }

    /**
     * Recursively verifies that the sub-DAG below {@code hop} contains neither a left-indexing op
     * nor a transient read of {@code varname}.
     *
     * <p>NOTE(review): a hop that was already visited (shared sub-DAG) yields {@code false}, which
     * conservatively rejects the DAG-level check; the caller then falls back to per-root checks.
     * All inputs are traversed without short-circuiting so that every reachable hop gets marked
     * visited.
     */
    private static boolean rProbeOtherRoot(Hop hop, String varname) {
        if (hop.isVisited()) {
            return false;
        }
        boolean ret = !(hop instanceof LeftIndexingOp)
                && !(HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD) && hop.getName().equals(varname));
        for (Hop input : hop.getInput()) {
            ret &= rProbeOtherRoot(input, varname);
        }
        hop.setVisited();
        return ret;
    }

    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus status) {
        return sbs;
    }
}
