package org.apache.sysds.lops.rewrite;

import java.util.Iterator;
import java.util.List;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/* loaded from: input_file:org/apache/sysds/lops/rewrite/RewriteFixIDs.class */
public class RewriteFixIDs extends LopRewriteRule {

    /**
     * Gives the Lops of one statement block fresh IDs in bottom-up order, so that every
     * input Lop is renumbered before the Lops that consume it.
     *
     * <p>Nothing happens unless prefetch, broadcast or checkpointing is enabled in the
     * configuration. Presumably these are the rewrites that introduce new Lops needing
     * refreshed IDs; confirm this against the rule pipeline.
     *
     * <p>When the block is a last-level loop, only the first statement block of the loop
     * body is renumbered. Otherwise the given block itself is renumbered.
     *
     * <p>NOTE(review): this assumes {@code isLastLevelLoopStatementBlock} only accepts
     * while/for blocks whose body holds at least one statement block. That is not
     * verified here.
     *
     * @param statementBlock the statement block to process
     * @return a singleton list holding {@code statementBlock} itself (never a new block)
     */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock statementBlock) {
        boolean idRefreshNeeded = ConfigurationManager.isPrefetchEnabled()
            || ConfigurationManager.isBroadcastEnabled()
            || ConfigurationManager.isCheckpointEnabled();

        if (idRefreshNeeded) {
            StatementBlock target = statementBlock;
            if (HopRewriteUtils.isLastLevelLoopStatementBlock(statementBlock)) {
                // Descend into the loop: the Lops live in the body's first statement block.
                if (statementBlock instanceof WhileStatementBlock) {
                    WhileStatement whileStmt = (WhileStatement) statementBlock.getStatement(0);
                    target = whileStmt.getBody().get(0);
                } else {
                    ForStatement forStmt = (ForStatement) statementBlock.getStatement(0);
                    target = forStmt.getBody().get(0);
                }
            }
            renumberBlock(target);
        }
        return List.of(statementBlock);
    }

    /**
     * Returns the given list of statement blocks unchanged; this rule works only at the
     * single-block level.
     *
     * @param list the statement blocks to process
     * @return {@code list}, the same instance that was passed in
     */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> list) {
        return list;
    }

    /**
     * Renumbers every Lop reachable from the root Lops of {@code block}, then clears the
     * visit markers on those roots.
     *
     * <p>Blocks with a null or empty Lop list are skipped.
     *
     * @param block the statement block whose Lop DAG is renumbered
     */
    private void renumberBlock(StatementBlock block) {
        if (block.getLops() != null && !block.getLops().isEmpty()) {
            for (Lop root : block.getLops()) {
                renumberBottomUp(root);
            }
            // Clear the markers set during traversal. Presumably resetVisitStatus also walks
            // the inputs recursively; confirm in Lop.
            for (Lop root : block.getLops()) {
                root.resetVisitStatus();
            }
        }
    }

    /**
     * Post-order traversal that assigns a new ID to {@code lop} after all of its inputs.
     *
     * <p>The visited flag ensures that a Lop shared by several consumers is renumbered
     * only once per pass.
     *
     * @param lop the Lop to renumber, together with its not-yet-visited inputs
     */
    private void renumberBottomUp(Lop lop) {
        if (!lop.isVisited()) {
            // Leaf Lops have no inputs, so this loop is simply skipped for them.
            for (Lop input : lop.getInputs()) {
                renumberBottomUp(input);
            }
            lop.setNewID();
            lop.setVisited();
        }
    }
}
