package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.lops.BinaryScalar;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.RightIndex;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;

/**
 * Lop rewrite that places a statement block containing a single {@code _EVICT} instruction
 * directly in front of a for-loop whose body looks like mini-batch slicing.
 *
 * <p>The rewrite only fires when all of the following hold:
 * <ul>
 *   <li>auto-eviction is enabled;</li>
 *   <li>an accelerator is in use ({@code DMLScript.USE_ACCELERATOR});</li>
 *   <li>the lineage reuse cache is active.</li>
 * </ul>
 * Only the first statement block of the loop body is inspected.
 */
public class RewriteAddGPUEvictLop extends LopRewriteRule {

    /**
     * Scalar argument attached to the inserted {@code _EVICT} lop and hop.
     * NOTE(review): presumably the share (percent) of cached GPU data to evict; confirm against
     * the {@code _EVICT} instruction implementation.
     */
    private static final int EVICT_ARGUMENT = 100;

    /**
     * Returns the given block, preceded by a freshly built eviction block when the block is a
     * for-loop whose first body block performs mini-batch slicing.
     *
     * @param statementBlock the block to rewrite; may be {@code null}, in which case it is
     *                       passed through unchanged
     * @return a list containing {@code statementBlock}, optionally preceded by an evict block
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock statementBlock) {
        // Every bail-out passes the block through unchanged. Collections.singletonList is used
        // instead of List.of because List.of throws NullPointerException on a null element.
        // instanceof is false for null, so a null block also takes this early return.
        if (!ConfigurationManager.isAutoEvictionEnabled()
            || !(statementBlock instanceof ForStatementBlock)
            || !DMLScript.USE_ACCELERATOR
            || LineageCacheConfig.ReuseCacheType.isNone()) {
            return Collections.singletonList(statementBlock);
        }

        List<StatementBlock> body = ((ForStatement) statementBlock.getStatement(0)).getBody();
        List<StatementBlock> result = new ArrayList<>(2);
        // Only the first block of the loop body is inspected; an empty body cannot slice.
        if (body != null && !body.isEmpty()
            && findMiniBatchSlicing(OperatorOrderingUtils.getLopList(body.get(0)))) {
            result.add(createEvictBlock(statementBlock));
        }
        result.add(statementBlock);
        return result;
    }

    /**
     * This rule operates on individual statement blocks only; the list is returned unchanged.
     *
     * @param list the statement blocks
     * @return {@code list} itself
     */
    @Override
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> list) {
        return list;
    }

    /**
     * Builds a standalone statement block that holds one {@code _EVICT} lop (executed in CP)
     * and a matching {@code _EVICT} hop, both fed by the literal {@link #EVICT_ARGUMENT}.
     *
     * @param loopBlock the loop block whose DML program and parse info the new block inherits
     * @return the new eviction block, with empty live-in / live-out sets
     */
    private static StatementBlock createEvictBlock(StatementBlock loopBlock) {
        StatementBlock evictBlock = new StatementBlock();
        evictBlock.setDMLProg(loopBlock.getDMLProg());
        evictBlock.setParseInfo(loopBlock);
        evictBlock.setLiveIn(new VariableSet());
        evictBlock.setLiveOut(new VariableSet());

        Data literal = Data.createLiteralLop(Types.ValueType.INT64, Integer.toString(EVICT_ARGUMENT));
        literal.getOutputParameters().setDimensions(0L, 0L, 0L, -1L);
        UnaryCP evictLop = new UnaryCP(literal, Types.OpOp1._EVICT,
            literal.getDataType(), literal.getValueType(), Types.ExecType.CP);
        UnaryOp evictHop = new UnaryOp("tmp", Types.DataType.SCALAR, Types.ValueType.INT64,
            Types.OpOp1._EVICT, new LiteralOp(EVICT_ARGUMENT));

        ArrayList<Lop> lops = new ArrayList<>();
        lops.add(evictLop);
        ArrayList<Hop> hops = new ArrayList<>();
        hops.add(evictHop);
        evictBlock.setLops(lops);
        evictBlock.setHops(hops);
        return evictBlock;
    }

    /**
     * Reports whether any lop in {@code lops} is a right-indexing lop matching
     * {@link #isMiniBatchSlice(List)}.
     *
     * @param lops the lops of a statement block
     * @return {@code true} on the first matching right-index lop
     */
    private boolean findMiniBatchSlicing(List<Lop> lops) {
        for (Lop lop : lops) {
            if (lop instanceof RightIndex && isMiniBatchSlice(lop.getInputs())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Checks the input shape of a right-index lop. The lop matches when:
     * <ul>
     *   <li>input 0 is a leaf transient read;</li>
     *   <li>input 1 is a scalar {@code PLUS};</li>
     *   <li>input 2 is a scalar {@code MIN}.</li>
     * </ul>
     * NOTE(review): inputs 1 and 2 are presumably the row lower/upper bounds.
     */
    private static boolean isMiniBatchSlice(List<Lop> inputs) {
        Lop source = inputs.get(0);
        return source instanceof Data
            && ((Data) source).isTransientRead()
            && source.getInputs().isEmpty()
            && isBinaryScalarOp(inputs.get(1), Types.OpOp2.PLUS)
            && isBinaryScalarOp(inputs.get(2), Types.OpOp2.MIN);
    }

    /** Returns whether {@code lop} is a {@link BinaryScalar} performing {@code op}. */
    private static boolean isBinaryScalarOp(Lop lop, Types.OpOp2 op) {
        return lop instanceof BinaryScalar && ((BinaryScalar) lop).getOperationType() == op;
    }
}
