package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryBranches.class */
/**
 * Statement-block rewrite that removes if-statement branches whose predicate is a literal.
 *
 * <p>When the first input of an {@link IfStatementBlock}'s predicate hop is a {@link LiteralOp},
 * the branch that will be taken is known at rewrite time. The whole if-block is then replaced by
 * the statement blocks of the selected branch: the if-body when the literal evaluates to
 * {@code true} via {@code HopRewriteUtils.getBooleanValue}, otherwise the else-body. If the
 * selected branch is empty, the if-block is replaced by nothing. Every other statement block is
 * passed through unchanged.
 */
public class RewriteRemoveUnnecessaryBranches extends StatementBlockRewriteRule {

    /**
     * Reports whether this rule creates a split DAG.
     *
     * @return always {@code false}
     */
    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Replaces an if-statement block that has a literal predicate with the statement blocks of
     * the branch selected by that literal.
     *
     * @param statementBlock the statement block to inspect
     * @param programRewriteStatus rewrite status; {@code setRemovedBranches()} is called on it
     *     whenever a branch is removed
     * @return a new mutable list containing either the original block (when the rewrite does not
     *     apply) or the statement blocks of the selected branch, which may be empty
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        List<StatementBlock> result = new ArrayList<>();

        if (!(statementBlock instanceof IfStatementBlock)) {
            // Not an if-block: nothing to rewrite, keep it as is.
            result.add(statementBlock);
            return result;
        }

        IfStatementBlock ifStatementBlock = (IfStatementBlock) statementBlock;
        Hop predicate = ifStatementBlock.getPredicateHops().getInput().get(0);
        if (!(predicate instanceof LiteralOp)) {
            // Predicate is not a literal, so the taken branch is unknown here: keep the block.
            result.add(statementBlock);
            return result;
        }

        IfStatement ifStatement = (IfStatement) ifStatementBlock.getStatement(0);
        // addAll of an empty body is a no-op, so an empty selected branch yields an empty result,
        // i.e. the if-block is dropped entirely.
        if (HopRewriteUtils.getBooleanValue((LiteralOp) predicate)) {
            result.addAll(ifStatement.getIfBody());
        } else {
            result.addAll(ifStatement.getElseBody());
        }

        programRewriteStatus.setRemovedBranches();
        LOG.debug("Applied removeUnnecessaryBranches (lines " + statementBlock.getBeginLine() + "-" + statementBlock.getEndLine() + ").");
        return result;
    }

    /**
     * Multi-block rewrite hook. This rule only rewrites individual blocks, so it returns the list
     * it was given.
     *
     * @param list the statement blocks to rewrite
     * @param programRewriteStatus rewrite status (unused)
     * @return {@code list}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }
}
