package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.VariableSet;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteHoistLoopInvariantOperations.class */
/**
 * Statement-block rewrite that pulls loop-invariant operations out of loop bodies.
 *
 * <p>For a loop statement block, operations whose inputs are only literals, variables
 * that the loop reads but never updates, or other already-invariant operations are
 * copied into a new statement block that is returned in front of the loop. Inside the
 * loop body each hoisted operation is replaced by a transient read of a freshly named
 * temporary variable.
 */
public class RewriteHoistLoopInvariantOperations extends StatementBlockRewriteRule {
    /** If {@code true}, {@link FunctionOp} hops may be tagged as loop invariant; otherwise they never are. */
    private final boolean _sideEffectFreeFuns;

    /** Creates the rewrite with function calls excluded from hoisting. */
    public RewriteHoistLoopInvariantOperations() {
        this(false);
    }

    /**
     * Creates the rewrite.
     *
     * @param z whether {@link FunctionOp} hops are allowed to be hoisted
     *          (stored as the "side-effect-free functions" flag)
     */
    public RewriteHoistLoopInvariantOperations(boolean z) {
        this._sideEffectFreeFuns = z;
    }

    /**
     * @return always {@code true}; this rewrite can return an additional statement block
     */
    @Override
    public boolean createsSplitDag() {
        return true;
    }

    /**
     * Hoists loop-invariant operations of a single loop statement block.
     *
     * @param statementBlock       block to rewrite; may be {@code null}
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return a one-element list holding the given block if it is {@code null}, not a loop,
     *         or nothing could be hoisted; otherwise a two-element list with the new
     *         block of hoisted operations followed by the (modified) loop block
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        // the rewrite only applies to loop statement blocks
        if (statementBlock == null || !HopRewriteUtils.isLoopStatementBlock(statementBlock)) {
            return Arrays.asList(statementBlock);
        }

        // step 1: variables that are read but never updated by the loop
        Set<String> readOnlyVars = statementBlock.variablesRead().getVariableNames().stream()
            .filter(name -> !statementBlock.variablesUpdated().containsVariable(name))
            .collect(Collectors.toSet());

        // step 2: collect invariant operations keyed by their temporary variable name
        Map<String, Hop> invariantOps = new HashMap<>();
        collectOperations(statementBlock, readOnlyVars, invariantOps);

        // step 3: materialize the hoisted operations in a block placed before the loop
        if (invariantOps.isEmpty()) {
            return Arrays.asList(statementBlock);
        }
        return Arrays.asList(createStatementBlock(statementBlock, invariantOps), statementBlock);
    }

    /**
     * No-op for lists of statement blocks.
     *
     * @return the given list, unchanged
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }

    /**
     * Recursively walks a statement block and hoists invariant operations from its hop DAGs.
     *
     * <p>While and for blocks are descended into; if blocks are skipped entirely; any
     * other block with a non-null hop list is tagged and rewritten in place.
     *
     * @param block        statement block to process
     * @param readOnlyVars names of variables read but not updated by the outer loop
     * @param invariantOps output map from temporary variable name to hoisted hop DAG
     */
    private void collectOperations(StatementBlock block, Set<String> readOnlyVars, Map<String, Hop> invariantOps) {
        if (block instanceof WhileStatementBlock) {
            WhileStatement whileStmt = (WhileStatement) block.getStatement(0);
            for (StatementBlock child : whileStmt.getBody()) {
                collectOperations(child, readOnlyVars, invariantOps);
            }
        } else if (block instanceof ForStatementBlock) {
            ForStatement forStmt = (ForStatement) block.getStatement(0);
            for (StatementBlock child : forStmt.getBody()) {
                collectOperations(child, readOnlyVars, invariantOps);
            }
        } else if (!(block instanceof IfStatementBlock) && block.getHops() != null) {
            // step a: bottom-up tagging of loop-invariant operations
            Hop.resetVisitStatus(block.getHops());
            Set<Long> invariantIds = new HashSet<>();
            for (Hop root : block.getHops()) {
                rTagLoopInvariantOperations(root, readOnlyVars, invariantIds);
            }

            // step b: copy tagged sub-DAGs and replace them with transient reads
            Hop.resetVisitStatus(block.getHops());
            for (Hop root : block.getHops()) {
                rCollectAndReplaceOperations(root, readOnlyVars, invariantIds, invariantOps);
            }

            if (!invariantIds.isEmpty()) {
                LOG.debug("Applied hoistLoopInvariantOperations (lines " + block.getBeginLine() + "-" + block.getEndLine() + "): " + invariantIds.size() + ".");
            }
        }
    }

    /**
     * Depth-first tagging of loop-invariant hops.
     *
     * <p>A hop is tagged when it is not a RAND data generator, not a transient read or
     * write, not a function call (unless functions are flagged as side-effect free), and
     * every input is either named like a read-only variable, already tagged, or a literal.
     *
     * <p>NOTE(review): nothing here excludes input-less hops such as {@link LiteralOp},
     * so they appear to be tagged as well - confirm this is intended.
     *
     * @param hop          current hop
     * @param readOnlyVars names of variables read but not updated by the loop
     * @param invariantIds output set of hop IDs tagged as loop invariant
     */
    private void rTagLoopInvariantOperations(Hop hop, Set<String> readOnlyVars, Set<Long> invariantIds) {
        if (hop.isVisited()) {
            return;
        }

        // inputs first, so their tags are available when checking this hop
        for (Hop input : hop.getInput()) {
            rTagLoopInvariantOperations(input, readOnlyVars, invariantIds);
        }

        // the operation itself must be eligible ...
        boolean invariant = !HopRewriteUtils.isDataGenOp(hop, Types.OpOpDG.RAND)
            && (!(hop instanceof FunctionOp) || this._sideEffectFreeFuns)
            && !HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)
            && !HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE);

        // ... and all of its inputs must be loop invariant
        for (Hop input : hop.getInput()) {
            invariant &= readOnlyVars.contains(input.getName())
                || invariantIds.contains(input.getHopID())
                || input instanceof LiteralOp;
        }

        if (invariant) {
            invariantIds.add(hop.getHopID());
        }
        hop.setVisited();
    }

    /**
     * Replaces tagged inputs by transient reads and records deep copies of them for hoisting.
     *
     * @param hop          current hop whose inputs are inspected
     * @param readOnlyVars names of read-only variables (only passed through on recursion)
     * @param invariantIds hop IDs tagged as loop invariant
     * @param invariantOps output map from temporary variable name to copied hop DAG
     */
    private void rCollectAndReplaceOperations(Hop hop, Set<String> readOnlyVars, Set<Long> invariantIds, Map<String, Hop> invariantOps) {
        if (hop.isVisited()) {
            return;
        }

        // index-based loop on purpose: the input list is modified while traversing
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop input = hop.getInput().get(i);
            if (invariantIds.contains(input.getHopID())) {
                // detached copy of the invariant sub-DAG, bound to a new temporary name
                String tmpName = createCutVarName(false);
                Hop copy = Recompiler.deepCopyHopsDag(input);
                copy.getParent().clear();
                invariantOps.put(tmpName, copy);

                // rewire all parents of the original to a transient read of the temporary;
                // iterate over a snapshot because replacing references changes the parent list
                DataOp tread = HopRewriteUtils.createTransientRead(tmpName, input);
                List<Hop> parents = new ArrayList<>(input.getParent());
                for (Hop parent : parents) {
                    HopRewriteUtils.replaceChildReference(parent, input, tread);
                }
            } else {
                rCollectAndReplaceOperations(input, readOnlyVars, invariantIds, invariantOps);
            }
        }
        hop.setVisited();
    }

    /**
     * Builds the statement block that computes all hoisted operations.
     *
     * <p>The new block takes its DML program and parse info from the loop block, and both
     * its live-in and live-out sets start as copies of the loop's live-in set. Every
     * hoisted hop is wrapped in a transient write, and a matching data identifier is added
     * to the new block's live-out set and to the loop block's live-in set (the loop block
     * is modified).
     *
     * @param loopBlock    loop statement block the operations were hoisted from
     * @param invariantOps map from temporary variable name to hoisted hop DAG
     * @return the new statement block holding one transient write per hoisted operation
     */
    private static StatementBlock createStatementBlock(StatementBlock loopBlock, Map<String, Hop> invariantOps) {
        StatementBlock hoisted = new StatementBlock();
        hoisted.setDMLProg(loopBlock.getDMLProg());
        hoisted.setParseInfo(loopBlock);
        hoisted.setLiveIn(new VariableSet(loopBlock.liveIn()));
        hoisted.setLiveOut(new VariableSet(loopBlock.liveIn()));

        ArrayList<Hop> roots = new ArrayList<>();
        for (Map.Entry<String, Hop> entry : invariantOps.entrySet()) {
            String tmpName = entry.getKey();
            Hop op = entry.getValue();
            roots.add(HopRewriteUtils.createTransientWrite(tmpName, op));

            // describe the temporary so that it is live between the new block and the loop
            DataIdentifier id = new DataIdentifier(tmpName);
            id.setDimensions(op.getDim1(), op.getDim2());
            id.setBlocksize(op.getBlocksize());
            id.setDataType(op.getDataType());
            id.setValueType(op.getValueType());
            hoisted.liveOut().addVariable(tmpName, id);
            loopBlock.liveIn().addVariable(tmpName, id);
        }
        hoisted.setHops(roots);
        return hoisted;
    }
}
