package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.StatementBlock;

/* loaded from: input_file:org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.class */
/**
 * Lop rewrite that inserts an asynchronous {@code BROADCAST} {@link UnaryCP} lop
 * directly behind every lop for which {@link #isBroadcastNeeded(Lop)} holds.
 *
 * <p>The rewrite is applied only when
 * {@link ConfigurationManager#isBroadcastEnabled()} returns {@code true}.
 */
public class RewriteAddBroadcastLop extends LopRewriteRule {

    /**
     * Rewires the lop DAG of a single statement block so that each qualifying lop
     * {@code l} with consumers {@code o1..on} becomes {@code l -> broadcast -> o1..on}.
     *
     * @param statementBlock the statement block whose lops are inspected and rewired in place
     * @return a single-element list holding the same {@code statementBlock} instance
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock statementBlock) {
        if (!ConfigurationManager.isBroadcastEnabled()) {
            return List.of(statementBlock);
        }

        ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(statementBlock);
        if (lops == null) {
            // No lop list available for this block; nothing to rewrite.
            return List.of(statementBlock);
        }

        for (Lop producer : lops) {
            if (!isBroadcastNeeded(producer)) {
                continue;
            }

            // Snapshot the consumers BEFORE the broadcast lop is constructed and before
            // any rewiring: the loop below removes entries from the producer's output
            // list, so iterating over the live list would be unsafe.
            List<Lop> consumers = new ArrayList<>(producer.getOutputs());

            UnaryCP broadcast = new UnaryCP(producer, Types.OpOp1.BROADCAST,
                    producer.getDataType(), producer.getValueType(), Types.ExecType.CP);
            broadcast.setAsynchronous(true);

            // Every previous consumer of the producer is moved behind the broadcast lop,
            // not only the ones that actually use the producer as their broadcast input.
            for (Lop consumer : consumers) {
                broadcast.addOutput(consumer);
                consumer.replaceInput(producer, broadcast);
                producer.removeOutput(consumer);
            }
        }
        return Arrays.asList(statementBlock);
    }

    /**
     * Leaves the given list of statement blocks untouched.
     *
     * @param list the statement blocks
     * @return the same {@code list} instance
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> list) {
        return list;
    }

    /**
     * Decides whether a broadcast lop has to be placed behind {@code lop}.
     *
     * @param lop the candidate producer lop
     * @return {@code true} if {@code lop} has execution type CP, at least one of its
     *         outputs reports this very instance (identity comparison) as its
     *         broadcast input, and {@code lop} has data type MATRIX
     */
    private static boolean isBroadcastNeeded(Lop lop) {
        if (lop.getExecType() != Types.ExecType.CP) {
            return false;
        }
        boolean usedAsBroadcastInput = lop.getOutputs().stream()
                .anyMatch(out -> out.getBroadcastInput() == lop);
        return usedAsBroadcastInput && lop.getDataType() == Types.DataType.MATRIX;
    }
}
