package org.apache.sysds.runtime.compress.workload;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.IndexingOp;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.ParameterizedBuiltinOp;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.RewriteCompressedReblock;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.ParForStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.workload.AWTreeNode;
import org.apache.sysds.utils.Explain;

/* loaded from: input_file:org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.class */
/**
 * Builds compression workload trees for a {@link DMLProgram}.
 *
 * <p>The analysis first collects compression candidates: hops accepted by
 * {@link RewriteCompressedReblock#satisfiesCompressionCondition(Hop)}, and optionally by the aggressive
 * condition. For every candidate it then walks the whole program, including loop bodies, branches and called
 * functions. In a {@link WTreeRoot} it records each operation that consumes the candidate, either directly
 * or through intermediates derived from it that are themselves treated as compressed. Each recorded op is
 * marked as decompressing, overlapping or densifying where the code below decides so.
 */
public class WorkloadAnalyzer {
    private static final Log LOG = LogFactory.getLog(WorkloadAnalyzer.class.getName());

    /** When true, hops accepted by the aggressive compression condition also become candidates. */
    public static boolean ALLOW_INTERMEDIATE_CANDIDATES = false;

    /**
     * When true, a candidate is skipped if an earlier candidate's analysis already marked it as compressed.
     */
    public static boolean PRUNE_COMPRESSED_INTERMEDIATES = true;

    /** Hops already processed by this analyzer instance (each analyzed function body gets a fresh set). */
    private final Set<Hop> visited;
    /** IDs of hops whose output is treated as compressed. */
    private final Set<Long> compressed;
    /** IDs of transpose hops applied to compressed inputs. */
    private final Set<Long> transposed;
    /** Variable name to the ID of the compressed hop recorded for that name. */
    private final Map<String, Long> transientCompressed;
    /** IDs of hops whose output is marked overlapping. */
    private final Set<Long> overlapping;
    /** The program being analyzed. */
    private final DMLProgram prog;
    /** Hop ID to the workload op created for that hop. */
    private final Map<Long, Op> treeLookup;

    /**
     * Builds one workload tree per compression candidate found in the program.
     *
     * @param program the program to analyze
     * @return map from candidate hop ID to its pruned workload tree
     */
    public static Map<Long, WTreeRoot> getAllCandidateWorkloads(DMLProgram program) {
        final List<Hop> candidates = getCandidates(program);
        final List<WorkloadAnalyzer> analyzers = new ArrayList<>();
        final Map<Long, WTreeRoot> workloads = new HashMap<>();
        for (Hop candidate : candidates) {
            // An earlier candidate's analysis already marked this hop as compressed: skip it.
            if (PRUNE_COMPRESSED_INTERMEDIATES
                && analyzers.stream().anyMatch(previous -> previous.containsCompressed(candidate))) {
                continue;
            }
            final WorkloadAnalyzer analyzer = new WorkloadAnalyzer(program);
            workloads.put(candidate.getHopID(), analyzer.createWorkloadTree(candidate));
            analyzers.add(analyzer);
        }
        return workloads;
    }

    /** Creates a top-level analyzer with empty state. */
    private WorkloadAnalyzer(DMLProgram program) {
        this(program, new HashSet<>(), new HashMap<>(), new HashSet<>(), new HashSet<>(), new HashMap<>());
    }

    /**
     * Creates an analyzer that shares the given compression state. It gets its own visited set and uses
     * {@code transientCompressed} as its variable-name mapping (used for analyzing function bodies).
     */
    private WorkloadAnalyzer(DMLProgram program, Set<Long> compressed, Map<String, Long> transientCompressed,
        Set<Long> transposed, Set<Long> overlapping, Map<Long, Op> treeLookup) {
        this.prog = program;
        this.visited = new HashSet<>();
        this.compressed = compressed;
        this.transposed = transposed;
        this.transientCompressed = transientCompressed;
        this.overlapping = overlapping;
        this.treeLookup = treeLookup;
    }

    /**
     * Builds and prunes the workload tree for a single candidate.
     *
     * @param candidate the hop assumed to be compressed
     * @return the workload tree rooted at the candidate
     */
    private WTreeRoot createWorkloadTree(Hop candidate) {
        final WTreeRoot root = new WTreeRoot(candidate);
        compressed.add(candidate.getHopID());
        for (StatementBlock sb : prog.getStatementBlocks()) {
            createWorkloadTree(root, sb, prog, new HashSet<>());
        }
        pruneWorkloadTree(root);
        return root;
    }

    /**
     * Tells whether this analyzer marked the given hop as compressed.
     *
     * @param hop the hop to check
     * @return true if the hop's ID is in this analyzer's compressed set
     */
    public boolean containsCompressed(Hop hop) {
        return isCompressed(hop);
    }

    /** Collects all compression candidates of the program in traversal order. */
    private static List<Hop> getCandidates(DMLProgram program) {
        final List<Hop> candidates = new ArrayList<>();
        for (StatementBlock sb : program.getStatementBlocks()) {
            getCandidates(sb, program, candidates, new HashSet<>());
        }
        return candidates;
    }

    /**
     * Recursively removes child nodes for which {@code isEmpty()} holds after their own pruning.
     *
     * @param node the node to prune
     * @return true if {@code node} itself is empty after pruning
     */
    private static boolean pruneWorkloadTree(AWTreeNode node) {
        final Iterator<WTreeNode> children = node.getChildNodes().iterator();
        while (children.hasNext()) {
            if (pruneWorkloadTree(children.next())) {
                children.remove();
            }
        }
        return node.isEmpty();
    }

    /**
     * Collects candidates from a statement block, descending into function, while, if and for bodies.
     * For the hop DAGs of a basic block, visit flags are reset before and after the traversal.
     */
    private static void getCandidates(StatementBlock sb, DMLProgram program, List<Hop> candidates,
        Set<String> fStack) {
        if (sb == null) {
            return;
        }
        if (sb instanceof FunctionStatementBlock) {
            final FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
            for (StatementBlock child : fstmt.getBody()) {
                getCandidates(child, program, candidates, fStack);
            }
        }
        else if (sb instanceof WhileStatementBlock) {
            final WhileStatement wstmt = (WhileStatement) ((WhileStatementBlock) sb).getStatement(0);
            for (StatementBlock child : wstmt.getBody()) {
                getCandidates(child, program, candidates, fStack);
            }
        }
        else if (sb instanceof IfStatementBlock) {
            final IfStatement istmt = (IfStatement) ((IfStatementBlock) sb).getStatement(0);
            for (StatementBlock child : istmt.getIfBody()) {
                getCandidates(child, program, candidates, fStack);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                getCandidates(child, program, candidates, fStack);
            }
        }
        else if (sb instanceof ForStatementBlock) {
            final ForStatement fstmt = (ForStatement) ((ForStatementBlock) sb).getStatement(0);
            for (StatementBlock child : fstmt.getBody()) {
                getCandidates(child, program, candidates, fStack);
            }
        }
        else {
            final ArrayList<Hop> hops = sb.getHops();
            if (hops != null) {
                Hop.resetVisitStatus(hops);
                for (Hop hop : hops) {
                    getCandidates(hop, program, candidates, fStack);
                }
                Hop.resetVisitStatus(hops);
            }
        }
    }

    /**
     * Collects candidates from a hop DAG in pre-order. Called functions are followed once per call
     * stack; {@code fStack} prevents infinite recursion.
     */
    private static void getCandidates(Hop hop, DMLProgram program, List<Hop> candidates, Set<String> fStack) {
        if (hop.isVisited()) {
            return;
        }
        if ((ALLOW_INTERMEDIATE_CANDIDATES && RewriteCompressedReblock.satisfiesAggressiveCompressionCondition(hop))
            || RewriteCompressedReblock.satisfiesCompressionCondition(hop)) {
            candidates.add(hop);
        }
        for (Hop in : hop.getInput()) {
            getCandidates(in, program, candidates, fStack);
        }
        if (hop instanceof FunctionOp) {
            final String fkey = ((FunctionOp) hop).getFunctionKey();
            if (fStack.add(fkey)) {
                getCandidates(program.getFunctionStatementBlock(fkey), program, candidates, fStack);
                fStack.remove(fkey);
            }
        }
        hop.setVisited();
    }

    /**
     * Adds the workload of a statement block to {@code parent}.
     *
     * <p>Control-flow blocks get their own child node. Function bodies get an FCALL node with weight 1.
     * While loops get a WHILE node with weight 10. If blocks get an IF node with weight 1. For and parfor
     * loops get a node weighted by the estimated number of iterations. Basic blocks add their ops
     * directly to {@code parent}.
     */
    private void createWorkloadTree(AWTreeNode parent, StatementBlock sb, DMLProgram program, Set<String> fStack) {
        final WTreeNode node;
        if (sb instanceof FunctionStatementBlock) {
            final FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
            node = new WTreeNode(AWTreeNode.WTNodeType.FCALL, 1);
            for (StatementBlock child : fstmt.getBody()) {
                createWorkloadTree(node, child, program, fStack);
            }
        }
        else if (sb instanceof WhileStatementBlock) {
            final WhileStatementBlock wsb = (WhileStatementBlock) sb;
            final WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            node = new WTreeNode(AWTreeNode.WTNodeType.WHILE, 10);
            createWorkloadTree(wsb.getPredicateHops(), program, node, fStack);
            for (StatementBlock child : wstmt.getBody()) {
                createWorkloadTree(node, child, program, fStack);
            }
        }
        else if (sb instanceof IfStatementBlock) {
            final IfStatementBlock isb = (IfStatementBlock) sb;
            final IfStatement istmt = (IfStatement) isb.getStatement(0);
            node = new WTreeNode(AWTreeNode.WTNodeType.IF, 1);
            createWorkloadTree(isb.getPredicateHops(), program, node, fStack);
            for (StatementBlock child : istmt.getIfBody()) {
                createWorkloadTree(node, child, program, fStack);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                createWorkloadTree(node, child, program, fStack);
            }
        }
        else if (sb instanceof ForStatementBlock) {
            final ForStatementBlock fsb = (ForStatementBlock) sb;
            final ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            final AWTreeNode.WTNodeType type = sb instanceof ParForStatementBlock
                ? AWTreeNode.WTNodeType.PARFOR : AWTreeNode.WTNodeType.FOR;
            node = new WTreeNode(type, fsb.getEstimateReps());
            createWorkloadTree(fsb.getFromHops(), program, node, fStack);
            createWorkloadTree(fsb.getToHops(), program, node, fStack);
            createWorkloadTree(fsb.getIncrementHops(), program, node, fStack);
            for (StatementBlock child : fstmt.getBody()) {
                createWorkloadTree(node, child, program, fStack);
            }
        }
        else {
            createBasicBlockWorkload(parent, sb, program, fStack);
            return;
        }
        parent.addChild(node);
    }

    /**
     * Adds the ops of a basic block's hop DAGs to {@code parent}, then analyzes functions called from it.
     */
    private void createBasicBlockWorkload(AWTreeNode parent, StatementBlock sb, DMLProgram program,
        Set<String> fStack) {
        final ArrayList<Hop> hops = sb.getHops();
        if (hops == null) {
            return;
        }
        for (Hop hop : hops) {
            createWorkloadTree(hop, program, parent, fStack);
        }
        // Function calls are handled only after all DAG roots of this block have been processed.
        for (Hop hop : hops) {
            if (hop instanceof FunctionOp) {
                createFunctionCallWorkload((FunctionOp) hop, program, parent, fStack);
            }
        }
    }

    /**
     * Analyzes the body of the function called by {@code fop} with a nested analyzer.
     *
     * <p>The nested analyzer shares this analyzer's compression state and sees the function parameters
     * whose argument hops are compressed. Afterwards, output variable names found in the nested
     * analyzer's name mapping are copied into this analyzer's mapping. Recursive calls (already on
     * {@code fStack}) and unknown functions are skipped.
     */
    private void createFunctionCallWorkload(FunctionOp fop, DMLProgram program, AWTreeNode parent,
        Set<String> fStack) {
        final String fkey = fop.getFunctionKey();
        if (fStack.contains(fkey)) {
            return;
        }
        final FunctionStatementBlock fsb = program.getFunctionStatementBlock(fkey);
        if (fsb == null) {
            return;
        }

        final Map<String, Long> functionCompressed = new HashMap<>();
        final String[] inNames = fop.getInputVariableNames();
        for (int i = 0; i < inNames.length; i++) {
            final long inId = fop.getInput(i).getHopID();
            if (compressed.contains(inId)) {
                functionCompressed.put(inNames[i], inId);
            }
        }

        fStack.add(fkey);
        new WorkloadAnalyzer(program, compressed, functionCompressed, transposed, overlapping, treeLookup)
            .createWorkloadTree(parent, fsb, program, fStack);
        fStack.remove(fkey);

        // NOTE(review): outputs are looked up by the FunctionOp's output variable names inside the
        // function's name mapping; confirm these names coincide with the function's own output params.
        for (String outName : fop.getOutputVariableNames()) {
            final Long id = functionCompressed.get(outName);
            if (id != null) {
                transientCompressed.put(outName, id);
            }
        }
    }

    /**
     * Post-order walk of a hop DAG. Reads of variables known to hold compressed data are bound to their
     * source op, and every hop with at least one compressed input gets a workload op.
     */
    private void createWorkloadTree(Hop hop, DMLProgram program, AWTreeNode parent, Set<String> fStack) {
        if (hop == null || visited.contains(hop) || isNoOp(hop)) {
            return;
        }
        for (Hop in : hop.getInput()) {
            createWorkloadTree(in, program, parent, fStack);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD, Types.OpOpData.TRANSIENTREAD)
            && transientCompressed.containsKey(hop.getName())) {
            compressed.add(hop.getHopID());
            treeLookup.put(hop.getHopID(), treeLookup.get(transientCompressed.get(hop.getName())));
        }
        if (hop.getInput().stream().anyMatch(this::isCompressed)) {
            createOp(hop, parent);
        }
        visited.add(hop);
    }

    /**
     * Creates and registers the workload op for a hop that has at least one compressed input.
     *
     * <p>If the hop cannot operate on its compressed input(s), the producing ops are marked as
     * decompressing instead and nothing is registered for this hop.
     *
     * @throws DMLCompressionException if the matrix hop type is not handled
     */
    private void createOp(Hop hop, AWTreeNode parent) {
        if (!hop.getDataType().isMatrix()) {
            // Non-matrix outputs are recorded, but not registered in treeLookup or the compressed set.
            parent.addOp(new OpNormal(hop, false));
            return;
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD, Types.OpOpData.TRANSIENTREAD)) {
            return;
        }

        final Op op;
        if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE, Types.OpOpData.PERSISTENTWRITE)) {
            op = createWriteOp(hop);
        }
        else if (hop instanceof ReorgOp && ((ReorgOp) hop).getOp() == Types.ReOrgOp.TRANS) {
            op = createTransposeOp(hop);
        }
        else if (hop instanceof AggUnaryOp) {
            op = createAggUnaryOp(hop, parent);
        }
        else if (hop instanceof UnaryOp) {
            op = createUnaryOp(hop);
        }
        else if (hop instanceof AggBinaryOp) {
            op = createAggBinaryOp(hop);
        }
        else if (hop instanceof BinaryOp) {
            op = createBinaryOp(hop, parent);
        }
        else if (hop instanceof IndexingOp) {
            op = createIndexingOp(hop);
        }
        else if (HopRewriteUtils.isTernary(hop, Types.OpOp3.MINUS_MULT, Types.OpOp3.PLUS_MULT,
            Types.OpOp3.QUANTILE, Types.OpOp3.CTABLE)) {
            setDecompressionOnAllInputs(hop, parent);
            op = null;
        }
        else if (HopRewriteUtils.isTernary(hop, Types.OpOp3.IFELSE)) {
            op = createIfElseOp(hop, parent);
        }
        else if (hop instanceof ParameterizedBuiltinOp) {
            setDecompressionOnAllInputs(hop, parent);
            op = null;
        }
        else {
            throw new DMLCompressionException("Unknown Hop: " + Explain.explain(hop));
        }

        if (op == null) {
            // Decompression of the input(s) was requested instead of registering an op for this hop.
            return;
        }
        treeLookup.put(hop.getHopID(), op);
        parent.addOp(op);
        if (op.isCompressedOutput()) {
            compressed.add(hop.getHopID());
        }
    }

    /** Write: the written variable name now refers to the compressed input. */
    private Op createWriteOp(Hop hop) {
        transientCompressed.put(hop.getName(), hop.getInput(0).getHopID());
        compressed.add(hop.getHopID());
        return createMetadataOp(hop, hop.getInput(0));
    }

    /** Transpose of a compressed input: kept compressed and tracked as transposed. */
    private Op createTransposeOp(Hop hop) {
        transposed.add(hop.getHopID());
        compressed.add(hop.getHopID());
        // Registered under the hop's name so that setDecompressionOnAllInputs can mark the producer.
        transientCompressed.put(hop.getName(), hop.getHopID());
        return createMetadataOp(hop, hop.getInput(0));
    }

    /**
     * Aggregate: decompress for TRACE, and for any aggregate other than SUM/MEAN over an overlapping
     * input.
     */
    private Op createAggUnaryOp(Hop hop, AWTreeNode parent) {
        final boolean unsupportedOnOverlap = isOverlapping(hop.getInput(0))
            && !HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.SUM, Types.AggOp.MEAN);
        if (unsupportedOnOverlap || HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.TRACE)) {
            setDecompressionOnAllInputs(hop, parent);
            return null;
        }
        return new OpNormal(hop, false);
    }

    /**
     * Unary op. Except for the listed ops, an overlapping input is decompressed. Otherwise a default op
     * is created, with compressed output if the size constraints allow it.
     */
    private Op createUnaryOp(Hop hop) {
        final boolean exempt = HopRewriteUtils.isUnary(hop, Types.OpOp1.MULT2, Types.OpOp1.MINUS1_MULT,
            Types.OpOp1.MINUS_RIGHT, Types.OpOp1.CAST_AS_MATRIX);
        if (!exempt && decompressIfOverlapping(hop.getInput(0))) {
            return null;
        }
        return new OpNormal(hop, RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop));
    }

    /** Matrix multiplication: right multiplications produce overlapping output. */
    private Op createAggBinaryOp(Hop hop) {
        final List<Hop> in = hop.getInput();
        final Hop left = in.get(0);
        final Hop right = in.get(1);
        final OpSided op = new OpSided(hop, isCompressedOrNamedCompressed(left),
            isCompressedOrNamedCompressed(right), transposed.contains(left.getHopID()),
            transposed.contains(right.getHopID()));
        if (op.isRightMM()) {
            overlapping.add(hop.getHopID());
            op.setOverlapping();
            if (!op.isCompressedOutput()) {
                op.setDecompressing();
            }
        }
        return op;
    }

    /** True if the hop is compressed or its name is mapped to a compressed hop. */
    private boolean isCompressedOrNamedCompressed(Hop hop) {
        return compressed.contains(hop.getHopID()) || transientCompressed.containsKey(hop.getName());
    }

    /**
     * Binary op. Returns null when decompression was requested instead. That happens for rbind, for
     * matrix-matrix/column-vector shapes, and for overlapping inputs of ops other than +, *, / and -.
     */
    private Op createBinaryOp(Hop hop, AWTreeNode parent) {
        final List<Hop> in = hop.getInput();
        final Hop left = in.get(0);
        final Hop right = in.get(1);

        if (HopRewriteUtils.isBinary(hop, Types.OpOp2.CBIND)) {
            final Op op = new OpNormal(hop, true);
            if (isOverlapping(left) || isOverlapping(right)) {
                overlapping.add(hop.getHopID());
                op.setOverlapping();
            }
            op.setDecompressing();
            return op;
        }
        if (HopRewriteUtils.isBinary(hop, Types.OpOp2.RBIND)) {
            setDecompressionOnAllInputs(hop, parent);
            return null;
        }

        final boolean anyOverlapping = isOverlapping(left) || isOverlapping(right);

        // Only scalar or row-vector right-hand sides (and scalar left-hand sides) stay compressed.
        if (right.getDim1() != 1 && !right.isScalar() && !left.isScalar()) {
            final boolean knownShape = HopRewriteUtils.isBinaryMatrixMatrixOperation(hop)
                || HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
                || HopRewriteUtils.isBinaryMatrixMatrixOperationWithSharedInput(hop);
            if (!knownShape && !anyOverlapping) {
                LOG.warn("Setting decompressed because input Binary Op is unknown, please add the case to WorkloadAnalyzer:\n"
                    + Explain.explain(hop));
            }
            setDecompressionOnAllInputs(hop, parent);
            return null;
        }

        final Op op;
        if (anyOverlapping) {
            if (!HopRewriteUtils.isBinary(hop, Types.OpOp2.PLUS, Types.OpOp2.MULT, Types.OpOp2.DIV,
                Types.OpOp2.MINUS)) {
                // Unsupported on overlapping data: decompress every overlapping input. Both calls run,
                // because either input (not just the left one) may be the overlapping one.
                decompressIfOverlapping(left);
                decompressIfOverlapping(right);
                return null;
            }
            overlapping.add(hop.getHopID());
            op = new OpNormal(hop, true);
            op.setOverlapping();
        }
        else {
            op = new OpNormal(hop, true);
        }
        if (!HopRewriteUtils.isBinarySparseSafe(hop)) {
            op.setDensifying();
        }
        return op;
    }

    /** Indexing: full-column indexing may stay compressed; any other indexing decompresses. */
    private Op createIndexingOp(Hop hop) {
        final boolean inOverlapping = isOverlapping(hop.getInput(0));
        final Op op;
        if (HopRewriteUtils.isFullColumnIndexing((IndexingOp) hop)) {
            op = new OpNormal(hop, RewriteCompressedReblock.satisfiesSizeConstraintsForCompression(hop));
            if (inOverlapping) {
                overlapping.add(hop.getHopID());
                op.setOverlapping();
            }
        }
        else {
            op = new OpNormal(hop, false);
            op.setDecompressing();
        }
        return op;
    }

    /**
     * ifelse(cond, a, b): metadata op over a compressed value input. If neither value input is
     * compressed, all inputs are marked for decompression and null is returned, so nothing is
     * registered for this hop.
     */
    private Op createIfElseOp(Hop hop, AWTreeNode parent) {
        final Hop thenIn = hop.getInput(1);
        final Hop elseIn = hop.getInput(2);
        final boolean thenCompressed = isCompressed(thenIn);
        final boolean elseCompressed = isCompressed(elseIn);
        if (thenCompressed && elseCompressed) {
            final Op op = new OpMetadata(hop, thenIn);
            if (isOverlapping(thenIn) || isOverlapping(elseIn)) {
                op.setOverlapping();
            }
            return op;
        }
        if (thenCompressed) {
            return createMetadataOp(hop, thenIn);
        }
        if (elseCompressed) {
            return createMetadataOp(hop, elseIn);
        }
        setDecompressionOnAllInputs(hop, parent);
        return null;
    }

    /** Metadata op referring to {@code source}, inheriting its overlapping flag. */
    private Op createMetadataOp(Hop hop, Hop source) {
        final Op op = new OpMetadata(hop, source);
        if (isOverlapping(source)) {
            op.setOverlapping();
        }
        return op;
    }

    /** True if the hop's ID is in the compressed set. */
    private boolean isCompressed(Hop hop) {
        return compressed.contains(hop.getHopID());
    }

    /**
     * Marks the producers of all inputs of {@code hop} as decompressing. If {@code node} is the tree
     * root, the root is marked as well. Metadata ops are followed through {@code getParent()} to the op
     * they refer to, as long as that op is registered.
     */
    private void setDecompressionOnAllInputs(Hop hop, AWTreeNode node) {
        if (node instanceof WTreeRoot) {
            ((WTreeRoot) node).setDecompressing();
        }
        for (Hop in : hop.getInput()) {
            Op op = treeLookup.get(in.getHopID());
            if (op == null) {
                continue;
            }
            while (op instanceof OpMetadata) {
                final Op referenced = treeLookup.get(((OpMetadata) op).getParent().getHopID());
                if (referenced == null) {
                    break;
                }
                op = referenced;
            }
            op.setDecompressing();
        }
    }

    /** True if the op registered for the hop exists and is overlapping. */
    private boolean isOverlapping(Hop hop) {
        final Op op = treeLookup.get(hop.getHopID());
        return op != null && op.isOverlapping();
    }

    /**
     * Marks the registered op of {@code in} as decompressing if it is overlapping.
     *
     * @return true if the op was marked
     */
    private boolean decompressIfOverlapping(Hop in) {
        final Op op = treeLookup.get(in.getHopID());
        if (op == null || !op.isOverlapping()) {
            return false;
        }
        op.setDecompressing();
        return true;
    }

    /** Literals and nrow/ncol are not analyzed. */
    private static boolean isNoOp(Hop hop) {
        return hop instanceof LiteralOp || HopRewriteUtils.isUnary(hop, Types.OpOp1.NROW, Types.OpOp1.NCOL);
    }
}
