package org.apache.sysds.lops.rewrite;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/* loaded from: input_file:org/apache/sysds/lops/rewrite/RewriteAddChkpointInLoop.class */
/**
 * Lop rewrite that inserts Spark {@link Checkpoint} Lops into the body of a last-level loop.
 *
 * <p>The rule applies only when checkpointing is enabled and
 * {@link HopRewriteUtils#isLastLevelLoopStatementBlock} accepts the block. It then:
 * <ol>
 *   <li>determines the variables that the loop both reads and updates;</li>
 *   <li>collects the Spark roots of the first body block's Lop DAG;</li>
 *   <li>for each such variable, selects the Spark roots whose job contains a transient read of it
 *       and hands them to {@link OperatorOrderingUtils#markSharedSparkOps};</li>
 *   <li>places a checkpoint after every Lop whose recorded count is greater than one.</li>
 * </ol>
 * The Lop DAG is modified in place; the statement-block list itself is never restructured.
 */
public class RewriteAddChkpointInLoop extends LopRewriteRule {

    /**
     * Adds checkpoint Lops to the first body block of {@code statementBlock} where applicable.
     *
     * @param statementBlock the block to rewrite; may be {@code null}
     * @return a singleton list holding {@code statementBlock} (whose body Lop DAG may have been
     *     modified in place), or an empty list if {@code statementBlock} is {@code null}
     */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock statementBlock) {
        // List.of(null) throws NullPointerException, so the former "return List.of(statementBlock)"
        // for a null block could never return normally. A missing block has nothing to rewrite.
        if (statementBlock == null) {
            return List.of();
        }
        if (!ConfigurationManager.isCheckpointEnabled()
                || !HopRewriteUtils.isLastLevelLoopStatementBlock(statementBlock)) {
            return List.of(statementBlock);
        }

        // Variables that are read by the loop and also updated inside it.
        Set<String> loopCarriedVars = statementBlock.variablesRead().getVariableNames().stream()
            .filter(name -> statementBlock.variablesUpdated().containsVariable(name))
            .collect(Collectors.toSet());
        if (loopCarriedVars.isEmpty()) {
            return List.of(statementBlock);
        }

        StatementBlock body = getFirstBodyBlock(statementBlock);
        ArrayList<Lop> bodyLops = OperatorOrderingUtils.getLopList(body);
        HashSet<Lop> sparkRoots = findSparkRoots(bodyLops);
        if (sparkRoots.isEmpty()) {
            return List.of(statementBlock);
        }

        Map<Long, Integer> sharedOpCounts = new HashMap<>();
        findOverlappingJobs(sparkRoots, loopCarriedVars, sharedOpCounts);
        if (!sharedOpCounts.isEmpty()) {
            addChkpointLop(bodyLops, sharedOpCounts, body);
        }
        return List.of(statementBlock);
    }

    /** This rule works per statement block; the list variant returns its input unchanged. */
    @Override // org.apache.sysds.lops.rewrite.LopRewriteRule
    public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> list) {
        return list;
    }

    /**
     * Returns the first statement block in the body of a while or for loop.
     * NOTE(review): assumes isLastLevelLoopStatementBlock only accepts loops with a non-empty body,
     * and that every accepted non-while block carries a ForStatement — confirm.
     */
    private static StatementBlock getFirstBodyBlock(StatementBlock loopBlock) {
        if (loopBlock instanceof WhileStatementBlock) {
            return ((WhileStatement) loopBlock.getStatement(0)).getBody().get(0);
        }
        return ((ForStatement) loopBlock.getStatement(0)).getBody().get(0);
    }

    /** Collects the Spark roots reachable from every root Lop in {@code lops}. */
    private static HashSet<Lop> findSparkRoots(List<Lop> lops) {
        // Determine all roots first, then traverse, exactly as before.
        List<Lop> roots = lops.stream()
            .filter(OperatorOrderingUtils::isLopRoot)
            .collect(Collectors.toList());
        HashSet<Lop> sparkRoots = new HashSet<>();
        for (Lop root : roots) {
            // A fresh, discarded count map per root; only the collected roots are used here.
            OperatorOrderingUtils.collectSparkRoots(root, new HashMap<>(), sparkRoots);
        }
        return sparkRoots;
    }

    /**
     * Inserts a {@link Checkpoint} between every Lop in {@code list} whose count in {@code map}
     * exceeds one and all of that Lop's current consumers. The position is registered on
     * {@code statementBlock} via setCheckpointPosition.
     */
    private void addChkpointLop(List<Lop> list, Map<Long, Integer> map, StatementBlock statementBlock) {
        for (Lop lop : list) {
            if (map.getOrDefault(lop.getID(), 0) <= 1) {
                continue;
            }
            // Snapshot the consumers *before* creating the checkpoint, since the checkpoint is
            // built on top of `lop` and the output list is mutated in the loop below.
            ArrayList<Lop> consumers = new ArrayList<>(lop.getOutputs());
            Checkpoint checkpoint = new Checkpoint(lop, lop.getDataType(), lop.getValueType(),
                Checkpoint.getDefaultStorageLevelString(), false);
            for (Lop consumer : consumers) {
                checkpoint.addOutput(consumer);
                consumer.replaceInput(lop, checkpoint);
                lop.removeOutput(consumer);
            }
            statementBlock.setCheckpointPosition(lop, consumers);
        }
    }

    /**
     * For every variable in {@code set}, gathers the Spark roots whose job contains a transient
     * read of that variable and passes them, together with {@code map}, to
     * {@link OperatorOrderingUtils#markSharedSparkOps} (presumably accumulating per-Lop-ID job
     * counts into {@code map} — confirm against that helper).
     */
    private void findOverlappingJobs(HashSet<Lop> hashSet, Set<String> set, Map<Long, Integer> map) {
        HashSet<Lop> rootsReadingVar = new HashSet<>();
        for (String varName : set) {
            rootsReadingVar.clear();
            for (Lop root : hashSet) {
                if (ifJobContains(root, varName)) {
                    rootsReadingVar.add(root);
                }
                // Visit flags are set by ifJobContains; clear them before the next traversal.
                root.resetVisitStatus();
            }
            if (!rootsReadingVar.isEmpty()) {
                OperatorOrderingUtils.markSharedSparkOps(rootsReadingVar, map);
            }
        }
    }

    /**
     * Returns whether the job rooted at {@code lop} contains a transient read labelled {@code str}.
     * The traversal follows only Data inputs and Spark-executed inputs other than the Lop's
     * broadcast input, and marks every examined Lop as visited (already-visited Lops yield false).
     */
    private boolean ifJobContains(Lop lop, String str) {
        if (lop.isVisited()) {
            return false;
        }
        boolean found = false;
        for (Lop input : lop.getInputs()) {
            boolean followInput = (input instanceof Data)
                || (input.isExecSpark() && lop.getBroadcastInput() != input);
            if (followInput && ifJobContains(input, str)) {
                found = true;
                break;
            }
        }
        // NOTE(review): the label match is case-insensitive; if DML variable names are
        // case-sensitive, equals() may be intended here — confirm before changing.
        if (!found && (lop instanceof Data) && ((Data) lop).isTransientRead()
                && lop.getOutputParameters().getLabel().equalsIgnoreCase(str)) {
            found = true;
        }
        lop.setVisited();
        return found;
    }
}
