package org.apache.sysds.runtime.compress.workload;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.RewriteCompressedReblock;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.compress.workload.WTreeNode;

/* loaded from: input_file:org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.class */
public class WorkloadAnalyzer {
    /**
     * Builds one pruned workload tree for every compression candidate found in the program.
     *
     * @param dMLProgram program that is scanned for candidates and from which the trees are built
     * @return map from the candidate's hop ID to the workload tree created for that candidate
     */
    public static Map<Long, WTreeNode> getAllCandidateWorkloads(DMLProgram dMLProgram) {
        // Typed map (instead of a raw HashMap) so the return needs no unchecked conversion.
        Map<Long, WTreeNode> workloads = new HashMap<>();
        for (Hop candidate : getCandidates(dMLProgram)) {
            WTreeNode tree = createWorkloadTree(dMLProgram, candidate);
            // The boolean result ("root is empty") is ignored on purpose: the root is stored either way.
            pruneWorkloadTree(tree);
            workloads.put(Long.valueOf(candidate.getHopID()), tree);
        }
        return workloads;
    }

    /**
     * Collects all hops of the program that satisfy the compression condition.
     *
     * @param dMLProgram program whose top-level statement blocks are traversed
     * @return list of candidate hops in traversal order
     */
    public static List<Hop> getCandidates(DMLProgram dMLProgram) {
        // Typed collections replace the raw ArrayList/HashSet of the original.
        List<Hop> candidates = new ArrayList<>();
        for (StatementBlock block : dMLProgram.getStatementBlocks()) {
            // Every top-level block starts with its own, empty function-call stack.
            getCandidates(block, dMLProgram, candidates, new HashSet<String>());
        }
        return candidates;
    }

    /**
     * Creates the (unpruned) workload tree for a single candidate hop.
     *
     * @param dMLProgram program whose top-level statement blocks become children of the root
     * @param hop        candidate hop; its name seeds the set of tracked variable names
     * @return root node of type {@code MAIN} with one child per top-level statement block
     */
    public static WTreeNode createWorkloadTree(DMLProgram dMLProgram, Hop hop) {
        WTreeNode root = new WTreeNode(WTreeNode.WTNodeType.MAIN);
        // Typed set (instead of a raw HashSet); it is shared across all top-level blocks so that
        // updates made while processing one block are visible to the following ones.
        Set<String> trackedNames = new HashSet<>();
        trackedNames.add(hop.getName());
        for (StatementBlock block : dMLProgram.getStatementBlocks()) {
            // Fresh function-call stack per top-level block.
            root.addChild(createWorkloadTree(block, dMLProgram, trackedNames, new HashSet<String>()));
        }
        return root;
    }

    /**
     * Recursively removes child nodes that have neither children nor compressed operations.
     *
     * @param wTreeNode node whose subtree is pruned in place
     * @return {@code true} if, after pruning, this node has no child nodes and no compressed
     *         operations (i.e. the caller may remove it)
     */
    public static boolean pruneWorkloadTree(WTreeNode wTreeNode) {
        for (Iterator<WTreeNode> children = wTreeNode.getChildNodes().iterator(); children.hasNext();) {
            WTreeNode child = children.next();
            boolean childIsEmpty = pruneWorkloadTree(child);
            if (childIsEmpty) {
                // Remove through the iterator so the traversal stays valid.
                children.remove();
            }
        }
        if (!wTreeNode.getChildNodes().isEmpty()) {
            return false;
        }
        return wTreeNode.getCompressedOps().isEmpty();
    }

    /**
     * Collects compression candidates from a statement block, descending into the bodies of
     * function, while, if/else and for blocks, and into the hop DAGs of all other blocks.
     *
     * @param statementBlock block to traverse
     * @param dMLProgram     program, passed through for resolving function calls
     * @param list           output list that receives the candidate hops
     * @param set            function keys currently being traversed (passed through)
     */
    private static void getCandidates(StatementBlock statementBlock, DMLProgram dMLProgram, List<Hop> list, Set<String> set) {
        if (statementBlock instanceof FunctionStatementBlock) {
            FunctionStatement function = (FunctionStatement) ((FunctionStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock nested : function.getBody()) {
                getCandidates(nested, dMLProgram, list, set);
            }
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatement loop = (WhileStatement) ((WhileStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock nested : loop.getBody()) {
                getCandidates(nested, dMLProgram, list, set);
            }
        } else if (statementBlock instanceof IfStatementBlock) {
            IfStatement branch = (IfStatement) ((IfStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock nested : branch.getIfBody()) {
                getCandidates(nested, dMLProgram, list, set);
            }
            for (StatementBlock nested : branch.getElseBody()) {
                getCandidates(nested, dMLProgram, list, set);
            }
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatement loop = (ForStatement) ((ForStatementBlock) statementBlock).getStatement(0);
            for (StatementBlock nested : loop.getBody()) {
                getCandidates(nested, dMLProgram, list, set);
            }
        } else if (statementBlock.getHops() != null) {
            // Last-level block: walk its hop DAGs, clearing the visit flags before and after.
            Hop.resetVisitStatus(statementBlock.getHops());
            for (Hop root : statementBlock.getHops()) {
                getCandidates(root, dMLProgram, list, set);
            }
            Hop.resetVisitStatus(statementBlock.getHops());
        }
    }

    /**
     * Collects compression candidates from a hop DAG and from the functions it calls.
     *
     * <p>A hop is added to {@code list} when it satisfies
     * {@code RewriteCompressedReblock.satisfiesCompressionCondition}. Inputs are processed
     * recursively; already visited hops are skipped, and the hop is marked visited at the end.
     *
     * @param hop        hop to inspect
     * @param dMLProgram program used to look up the body of called functions
     * @param list       output list that receives the candidate hops
     * @param set        function keys currently being traversed; a function whose key is already
     *                   in the set is not entered again, which stops recursion on cyclic calls
     */
    private static void getCandidates(Hop hop, DMLProgram dMLProgram, List<Hop> list, Set<String> set) {
        if (hop.isVisited()) {
            return;
        }
        if (RewriteCompressedReblock.satisfiesCompressionCondition(hop)) {
            list.add(hop);
        }
        for (Hop input : hop.getInput()) {
            getCandidates(input, dMLProgram, list, set);
        }
        if (hop instanceof FunctionOp) {
            String functionKey = ((FunctionOp) hop).getFunctionKey();
            if (!set.contains(functionKey)) {
                FunctionStatementBlock functionBlock = dMLProgram.getFunctionStatementBlock(functionKey);
                // NOTE(review): the lookup presumably returns null for keys without a statement
                // block (e.g. unresolved functions) - confirm. Without this guard a null block
                // fails with a NullPointerException in the statement-block traversal.
                if (functionBlock != null) {
                    set.add(functionKey);
                    getCandidates(functionBlock, dMLProgram, list, set);
                    // Pop the key again so other call sites can enter the function.
                    set.remove(functionKey);
                }
            }
        }
        hop.setVisited();
    }

    /**
     * Creates the workload tree node for a statement block, including the subtrees of all nested
     * blocks and of functions called from the block.
     *
     * <p>The node type follows the block type: {@code FCALL}, {@code WHILE}, {@code IF},
     * {@code PARFOR}/{@code FOR}, or {@code BASIC_BLOCK} for every other block. Predicate and
     * from/to/increment hops of control-flow blocks are analyzed into the node of that block.
     *
     * @param statementBlock block to convert
     * @param dMLProgram     program used to look up the body of called functions
     * @param set            tracked variable names; reads of these names are treated as
     *                       compressed. Basic blocks update this set in place on transient writes
     * @param set2           function keys currently being expanded; guards against cyclic calls
     * @return the node for this block, with line numbers taken from the block
     */
    private static WTreeNode createWorkloadTree(StatementBlock statementBlock, DMLProgram dMLProgram, Set<String> set, Set<String> set2) {
        WTreeNode node;
        if (statementBlock instanceof FunctionStatementBlock) {
            FunctionStatement function = (FunctionStatement) ((FunctionStatementBlock) statementBlock).getStatement(0);
            node = new WTreeNode(WTreeNode.WTNodeType.FCALL);
            addChildTrees(node, function.getBody(), dMLProgram, set, set2);
        } else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatementBlock whileBlock = (WhileStatementBlock) statementBlock;
            WhileStatement loop = (WhileStatement) whileBlock.getStatement(0);
            node = new WTreeNode(WTreeNode.WTNodeType.WHILE);
            createWorkloadTree(whileBlock.getPredicateHops(), dMLProgram, node, set, set2);
            addChildTrees(node, loop.getBody(), dMLProgram, set, set2);
        } else if (statementBlock instanceof IfStatementBlock) {
            IfStatementBlock ifBlock = (IfStatementBlock) statementBlock;
            IfStatement branch = (IfStatement) ifBlock.getStatement(0);
            node = new WTreeNode(WTreeNode.WTNodeType.IF);
            createWorkloadTree(ifBlock.getPredicateHops(), dMLProgram, node, set, set2);
            addChildTrees(node, branch.getIfBody(), dMLProgram, set, set2);
            addChildTrees(node, branch.getElseBody(), dMLProgram, set, set2);
        } else if (statementBlock instanceof ForStatementBlock) {
            ForStatementBlock forBlock = (ForStatementBlock) statementBlock;
            ForStatement loop = (ForStatement) forBlock.getStatement(0);
            // A ParForStatementBlock also passes the ForStatementBlock check above.
            node = new WTreeNode(statementBlock instanceof ParForStatementBlock ? WTreeNode.WTNodeType.PARFOR : WTreeNode.WTNodeType.FOR);
            createWorkloadTree(forBlock.getFromHops(), dMLProgram, node, set, set2);
            createWorkloadTree(forBlock.getToHops(), dMLProgram, node, set, set2);
            createWorkloadTree(forBlock.getIncrementHops(), dMLProgram, node, set, set2);
            addChildTrees(node, loop.getBody(), dMLProgram, set, set2);
        } else {
            node = new WTreeNode(WTreeNode.WTNodeType.BASIC_BLOCK);
            analyzeBasicBlock(statementBlock, dMLProgram, node, set, set2);
        }
        node.setLineNumbers(statementBlock.getBeginLine(), statementBlock.getEndLine());
        return node;
    }

    /** Appends one workload subtree per nested statement block to {@code parent}. */
    private static void addChildTrees(WTreeNode parent, Iterable<StatementBlock> blocks, DMLProgram dMLProgram, Set<String> set, Set<String> set2) {
        for (StatementBlock nested : blocks) {
            parent.addChild(createWorkloadTree(nested, dMLProgram, set, set2));
        }
    }

    /**
     * Analyzes the hop DAGs of a last-level block into {@code node}, expands called functions,
     * and updates the tracked variable names according to the block's transient writes.
     */
    private static void analyzeBasicBlock(StatementBlock statementBlock, DMLProgram dMLProgram, WTreeNode node, Set<String> set, Set<String> set2) {
        if (statementBlock.getHops() == null) {
            return;
        }
        Hop.resetVisitStatus(statementBlock.getHops());
        // IDs of hops in this block that were marked as compressed by the hop-level analysis.
        Set<Long> compressedHopIds = new HashSet<>();
        for (Hop root : statementBlock.getHops()) {
            createWorkloadTree(root, dMLProgram, node, set, compressedHopIds, set2);
        }
        // Second pass over the DAG roots, after all compressed hop IDs of the block are known.
        for (Hop root : statementBlock.getHops()) {
            if (root instanceof FunctionOp) {
                addFunctionCallTree((FunctionOp) root, dMLProgram, node, compressedHopIds, set2);
            } else if (HopRewriteUtils.isData(root, Types.OpOpData.TRANSIENTWRITE)) {
                String name = root.getName();
                boolean tracked = set.contains(name);
                boolean writesCompressed = compressedHopIds.contains(Long.valueOf(root.getHopID()));
                if (tracked && !writesCompressed) {
                    // The variable is overwritten by a hop that is not marked compressed.
                    set.remove(name);
                } else if (!tracked && writesCompressed) {
                    // The variable now holds a hop that is marked compressed.
                    set.add(name);
                }
            }
        }
        Hop.resetVisitStatus(statementBlock.getHops());
    }

    /**
     * Adds the workload subtree of a called function to {@code node}, tracking those function
     * parameters whose positional argument hop is marked compressed in the calling block.
     */
    private static void addFunctionCallTree(FunctionOp call, DMLProgram dMLProgram, WTreeNode node, Set<Long> compressedHopIds, Set<String> set2) {
        String functionKey = call.getFunctionKey();
        if (set2.contains(functionKey)) {
            // Already being expanded further up the call chain: do not recurse again.
            return;
        }
        FunctionStatementBlock functionBlock = dMLProgram.getFunctionStatementBlock(functionKey);
        // NOTE(review): the lookup presumably returns null for keys without a statement block
        // - confirm. The original dereferenced the result unconditionally.
        if (functionBlock == null) {
            return;
        }
        set2.add(functionKey);
        FunctionStatement function = (FunctionStatement) functionBlock.getStatement(0);
        Set<String> compressedParams = new HashSet<>();
        List<DataIdentifier> params = function.getInputParams();
        // Bound by both sizes: indexing the call inputs up to params.size() goes out of range
        // when the call supplies fewer inputs than the function declares parameters.
        // NOTE(review): arguments are matched by position only; verify for named arguments.
        int matched = Math.min(params.size(), call.getInput().size());
        for (int i = 0; i < matched; i++) {
            if (compressedHopIds.contains(Long.valueOf(call.getInput(i).getHopID()))) {
                compressedParams.add(params.get(i).getName());
            }
        }
        node.addChild(createWorkloadTree(functionBlock, dMLProgram, compressedParams, set2));
        set2.remove(functionKey);
    }

    /**
     * Analyzes a stand-alone hop DAG (such as a predicate) into {@code wTreeNode}, using a fresh
     * set of compressed hop IDs and resetting the DAG's visit status before and after.
     *
     * @param hop        root of the DAG; {@code null} is ignored
     * @param dMLProgram program, passed through to the hop-level analysis
     * @param wTreeNode  node that receives the compressed operations
     * @param set        tracked variable names
     * @param set2       function keys currently being expanded (passed through)
     */
    private static void createWorkloadTree(Hop hop, DMLProgram dMLProgram, WTreeNode wTreeNode, Set<String> set, Set<String> set2) {
        if (hop == null) {
            return;
        }
        hop.resetVisitStatus();
        // Typed set instead of a raw HashSet passed where Set<Long> is expected.
        createWorkloadTree(hop, dMLProgram, wTreeNode, set, new HashSet<Long>(), set2);
        hop.resetVisitStatus();
    }

    /**
     * Post-order analysis of a hop DAG that records operations consuming compressed inputs.
     *
     * <p>Inputs are processed first. A persistent or transient read whose name is in {@code set}
     * is marked compressed. A hop with at least one compressed input is added to the node as a
     * compressed operation unless it is a read or transient write, and is itself marked
     * compressed when it is a matrix that satisfies the size constraints for compression.
     *
     * @param hop        hop to analyze; {@code null} and already visited hops are skipped
     * @param dMLProgram program (only passed through the recursion)
     * @param wTreeNode  node that receives the compressed operations
     * @param set        tracked variable names
     * @param set2       IDs of hops marked compressed so far; extended in place
     * @param set3       function keys currently being expanded (only passed through)
     */
    private static void createWorkloadTree(Hop hop, DMLProgram dMLProgram, WTreeNode wTreeNode, Set<String> set, Set<Long> set2, Set<String> set3) {
        if (hop == null || hop.isVisited()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            createWorkloadTree(input, dMLProgram, wTreeNode, set, set2, set3);
        }

        boolean isRead = HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD, Types.OpOpData.TRANSIENTREAD);
        if (isRead && set.contains(hop.getName())) {
            set2.add(Long.valueOf(hop.getHopID()));
        }

        // Stop at the first input that is marked compressed.
        boolean hasCompressedInput = false;
        for (Hop input : hop.getInput()) {
            if (set2.contains(Long.valueOf(input.getHopID()))) {
                hasCompressedInput = true;
                break;
            }
        }

        if (hasCompressedInput) {
            boolean isReadOrWrite = HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD, Types.OpOpData.TRANSIENTREAD, Types.OpOpData.TRANSIENTWRITE);
            if (!isReadOrWrite) {
                wTreeNode.addCompressedOp(hop);
            }
            if (RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop) && hop.getDataType().isMatrix()) {
                set2.add(Long.valueOf(hop.getHopID()));
            }
        }
        hop.setVisited();
    }
}
