package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.MatMultCP;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.runtime.matrix.data.LibMatrixNative;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;

/* loaded from: input_file:org/apache/sysds/lops/rewrite/RewriteUpdateGPUPlacements.class */
/**
 * LOP rewrite rule that revisits operators currently marked for GPU execution and
 * switches some of them to {@link Types.ExecType#CP}.
 *
 * <p>The rule only runs when {@code ConfigurationManager.isRuleBasedGPUPlacement()} is
 * enabled and at least one lop of the statement block is marked for GPU execution.
 * The statement block itself is never replaced; only exec types of its lops change.
 */
public class RewriteUpdateGPUPlacements extends LopRewriteRule {

    /**
     * Applies the GPU-to-CP placement rules to every lop reachable from the lops of
     * the given statement block.
     *
     * @param statementBlock the statement block whose lops are inspected
     * @return a single-element list holding the same {@code statementBlock}
     */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock statementBlock) {
        // Nothing to do when rule-based placement is switched off.
        if (!ConfigurationManager.isRuleBasedGPUPlacement()) {
            return List.of(statementBlock);
        }

        // Only walk the DAG when at least one lop is actually placed on the GPU.
        List<Lop> allLops = OperatorOrderingUtils.getLopList(statementBlock);
        boolean anyOnGpu = allLops != null && allLops.stream().anyMatch(Lop::isExecGPU);
        if (anyOnGpu) {
            List<Lop> startLops = statementBlock.getLops();
            // First pass: apply the rules bottom-up (inputs before consumers).
            for (Lop start : startLops) {
                rUpdateExecType(start);
            }
            // Second pass: clear the visit markers set by the traversal above.
            for (Lop start : startLops) {
                start.resetVisitStatus();
            }
        }
        return List.of(statementBlock);
    }

    /**
     * Multi-block variant of the rewrite; this rule does nothing here.
     *
     * @param list the statement blocks
     * @return the given {@code list}, unchanged
     */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> list) {
        return list;
    }

    /**
     * Decides whether a single GPU lop should be moved to CP and, if so, updates its
     * exec type. Lops that are not marked for GPU execution are left untouched.
     *
     * <p>Checks, in order:
     * <ol>
     *   <li>any input for which {@code MatrixBlock.evalSparseFormatInMemory} returns
     *       true (with known nnz) moves the lop to CP;</li>
     *   <li>a {@link MatMultCP} that {@code LibMatrixNative.isMatMultMemoryBound}
     *       reports as not memory bound stays on the GPU;</li>
     *   <li>otherwise the decision depends on the number of inputs and on whether the
     *       relevant inputs are non-{@link Data} lops that are not on the GPU, and on
     *       {@code lop.isAllOutputsGPU()} (presumably to limit host/device transfers
     *       - confirm against the project documentation).</li>
     * </ol>
     *
     * @param lop the lop to inspect and possibly re-place
     */
    private void updateExecTypeGPU2CP(Lop lop) {
        if (!lop.isExecGPU()) {
            return;
        }

        // Check 1: an input evaluated as sparse sends the operator to CP right away.
        for (Lop input : lop.getInputs()) {
            if (isEvaluatedSparse(input)) {
                lop.setExecType(Types.ExecType.CP);
                return;
            }
        }

        // Check 2: matrix multiplications that are not memory bound keep their GPU placement.
        if (lop instanceof MatMultCP
                && !LibMatrixNative.isMatMultMemoryBound(
                        (int) lop.getInput(0).getNumRows(),
                        (int) lop.getInput(0).getNumCols(),
                        (int) lop.getInput(1).getNumCols())) {
            return;
        }

        // Check 3: decide by input count.
        int numInputs = lop.getInputs().size();
        if (numInputs == 1) {
            if (isProducedOffGpu(lop.getInput(0)) && !lop.isAllOutputsGPU()) {
                lop.setExecType(Types.ExecType.CP);
            }
        } else if (numInputs == 2) {
            Lop left = lop.getInput(0);
            Lop right = lop.getInput(1);
            long leftSize = estimatedSize(left);
            long rightSize = estimatedSize(right);

            boolean moveToCp;
            if (leftSize > rightSize) {
                // The larger input decides.
                moveToCp = isProducedOffGpu(left) && !lop.isAllOutputsGPU();
            } else if (rightSize > leftSize) {
                moveToCp = isProducedOffGpu(right) && !lop.isAllOutputsGPU();
            } else {
                // Equal estimates: both inputs have to qualify.
                moveToCp = !(left instanceof Data)
                        && !(right instanceof Data)
                        && !left.isExecGPU()
                        && !right.isExecGPU()
                        && !lop.isAllOutputsGPU();
            }
            if (moveToCp) {
                lop.setExecType(Types.ExecType.CP);
            }
        } else if (numInputs > 2) {
            // Majority vote among the non-Data inputs.
            int gpuInputs = 0;
            int cpInputs = 0;
            for (Lop input : lop.getInputs()) {
                if (input instanceof Data) {
                    continue;
                }
                if (input.isExecGPU()) {
                    gpuInputs++;
                }
                if (input.isExecCP()) {
                    cpInputs++;
                }
            }
            if (cpInputs > gpuInputs && !lop.isAllOutputsGPU()) {
                lop.setExecType(Types.ExecType.CP);
            }
        }
    }

    /**
     * Post-order traversal: processes all non-{@link Data} inputs of {@code lop}
     * before applying {@link #updateExecTypeGPU2CP(Lop)} to {@code lop} itself, then
     * marks it visited so shared sub-DAGs are handled once.
     *
     * @param lop the lop to start from
     */
    private void rUpdateExecType(Lop lop) {
        if (!lop.isVisited()) {
            for (Lop input : lop.getInputs()) {
                if (input instanceof Data) {
                    continue; // Data lops are not descended into
                }
                rUpdateExecType(input);
            }
            updateExecTypeGPU2CP(lop);
            lop.setVisited();
        }
    }

    /** Returns true when nnz is known (non-negative) and the sparse-format evaluation holds. */
    private static boolean isEvaluatedSparse(Lop input) {
        return input.getNnz() >= 0
                && MatrixBlock.evalSparseFormatInMemory(input.getNumRows(), input.getNumCols(), input.getNnz());
    }

    /** Returns the in-memory size estimate computed from the lop's rows, columns and nnz. */
    private static long estimatedSize(Lop input) {
        return MatrixBlock.estimateSizeInMemory(input.getNumRows(), input.getNumCols(), input.getNnz());
    }

    /** Returns true for a non-{@link Data} lop that is not marked for GPU execution. */
    private static boolean isProducedOffGpu(Lop input) {
        return !(input instanceof Data) && !input.isExecGPU();
    }
}
