package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Compression;
import org.apache.sysds.lops.MMTSJ;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteCompressedReblock.class */
/**
 * Statement-block rewrite that flags matrix reads for compressed linear algebra.
 *
 * <p>The mode is read from the {@code COMPRESSED_LINALG} DML configuration entry. If the mode is
 * enabled, every hop DAG of a last-level statement block is traversed bottom-up, and qualifying
 * hops are marked via {@code Hop.setRequiresCompression(true)}:
 * <ul>
 *   <li>{@code TRUE}: every persistent read with more than one row and more than one column.</li>
 *   <li>{@code AUTO}: such reads only if, in addition, a program-wide analysis concludes that
 *       the data is out-of-core, not ultra-sparse, used inside a loop, never updated under a
 *       condition, and only consumed by operations that can work on compressed data.</li>
 * </ul>
 * Other enabled modes pass through this rewrite without marking any hop.
 */
public class RewriteCompressedReblock extends StatementBlockRewriteRule {
    /** Prefix of synthetic names that stand for hop outputs (suffixed with the hop ID). */
    private static final String TMP_PREFIX = "__cmtx";

    /**
     * Mutable state of the AUTO-mode program analysis for a single candidate hop.
     *
     * <p>{@code compMtx} holds the names known to carry compressed data: variable names
     * (propagated via transient writes/reads and function parameters) plus synthetic
     * {@code TMP_PREFIX + hopID} names for hop outputs inside the current DAG.
     */
    public static class ProbeStatus {
        /** ID of the hop whose compression is being evaluated. */
        private final long startHopID;
        /** Program used to resolve function calls. */
        private final DMLProgram prog;
        /** True once the candidate hop was encountered during the traversal. */
        private boolean foundStart = false;
        /** True if a loop reads a variable carrying compressed data. */
        private boolean usedInLoop = false;
        /** True if an if-statement updates a variable carrying compressed data. */
        private boolean condUpdate = false;
        /** True if compressed data reaches an operation not supported on compressed input. */
        private boolean nonApplicable = false;
        /** Function keys already analyzed (memoization; also guards against recursion). */
        private final HashSet<String> procFn = new HashSet<>();
        /** Names currently known to carry compressed data. */
        private final HashSet<String> compMtx = new HashSet<>();

        /**
         * Creates a fresh status for the analysis of one candidate hop.
         *
         * @param startHopID ID of the candidate hop
         * @param prog program to analyze
         */
        public ProbeStatus(long startHopID, DMLProgram prog) {
            this.startHopID = startHopID;
            this.prog = prog;
        }

        /**
         * Creates a status for analyzing a called function body. Flags and the set of processed
         * functions are copied; {@code compMtx} intentionally starts empty so the caller can
         * seed it with the function's own parameter names.
         *
         * @param status status of the calling context
         */
        public ProbeStatus(ProbeStatus status) {
            this.startHopID = status.startHopID;
            this.prog = status.prog;
            this.foundStart = status.foundStart;
            this.usedInLoop = status.usedInLoop;
            this.condUpdate = status.condUpdate;
            this.nonApplicable = status.nonApplicable;
            this.procFn.addAll(status.procFn);
        }
    }

    /**
     * This rewrite only sets flags on existing hops and never splits a hop DAG.
     *
     * @return always {@code false}
     */
    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Injects compression directives into the hop DAGs of a last-level statement block.
     * Blocks that are not last-level or have no hops are returned unchanged.
     *
     * @param statementBlock the statement block to rewrite
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return a singleton list containing the (possibly modified) input block
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        if (!HopRewriteUtils.isLastLevelStatementBlock(statementBlock) || statementBlock.getHops() == null) {
            return Arrays.asList(statementBlock);
        }
        // Locale.ROOT keeps the enum lookup independent of the JVM default locale.
        String mode = ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.COMPRESSED_LINALG);
        Compression.CompressConfig compress = Compression.CompressConfig.valueOf(mode.toUpperCase(Locale.ROOT));
        if (compress.isEnabled()) {
            ArrayList<Hop> roots = statementBlock.getHops();
            Hop.resetVisitStatus(roots);
            for (Hop root : roots) {
                injectCompressionDirective(root, compress, statementBlock.getDMLProg());
            }
            Hop.resetVisitStatus(roots);
        }
        return Arrays.asList(statementBlock);
    }

    /**
     * Returns the given list unchanged; this rule works per statement block only.
     *
     * @param list statement blocks
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return {@code list}
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }

    /**
     * Recursively (inputs first) marks hops that should be compressed according to the mode.
     * Already visited hops and hops already flagged for compression are skipped.
     */
    private static void injectCompressionDirective(Hop hop, Compression.CompressConfig compressConfig, DMLProgram dMLProgram) {
        if (hop.isVisited() || hop.requiresCompression()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            injectCompressionDirective(input, compressConfig, dMLProgram);
        }
        boolean compressAlways = compressConfig == Compression.CompressConfig.TRUE
            && satisfiesCompressionCondition(hop);
        if (compressAlways || (compressConfig == Compression.CompressConfig.AUTO
            && satisfiesAutoCompressionCondition(hop, dMLProgram))) {
            hop.setRequiresCompression(true);
        }
        hop.setVisited();
    }

    /** Basic eligibility: a persistent read with more than one row and more than one column. */
    private static boolean satisfiesCompressionCondition(Hop hop) {
        return HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD) && hop.getDim1() > 1 && hop.getDim2() > 1;
    }

    /**
     * AUTO-mode eligibility. Requires the basic condition, a memory estimate at or above the
     * local memory budget, and Spark execution mode. Then requires known dimensions, an estimated
     * partitioned size above the Spark data memory budget ("out-of-core"), and a sparsity of
     * at least 4e-5 (not "ultra-sparse"). Finally, the whole program is analyzed to check that the
     * read is used inside a loop, never updated conditionally, and only consumed by supported ops.
     */
    private static boolean satisfiesAutoCompressionCondition(Hop hop, DMLProgram dMLProgram) {
        if (!satisfiesCompressionCondition(hop) || hop.getMemEstimate() < OptimizerUtils.getLocalMemBudget() || !OptimizerUtils.isSparkExecutionMode()) {
            return false;
        }
        long partitionedSize = OptimizerUtils.estimatePartitionedSizeExactSparsity(hop.getDim1(), hop.getDim2(), (long) hop.getBlocksize(), hop.getNnz());
        boolean outOfCore = (double) partitionedSize > SparkExecutionContext.getDataMemoryBudget(true, true);
        boolean ultraSparse = OptimizerUtils.getSparsity(hop.getDim1(), hop.getDim2(), hop.getNnz()) < 4.0E-5d;
        if (!hop.dimsKnown(true) || !outOfCore || ultraSparse) {
            if (LOG.isDebugEnabled()) {
                LOG.debug("Auto compression: false (dimsKnown=" + hop.dimsKnown(true) + ", outOfCore=" + outOfCore + ", !ultraSparse=" + (!ultraSparse) + ")");
            }
            return false;
        }

        ProbeStatus status = new ProbeStatus(hop.getHopID(), dMLProgram);
        for (StatementBlock sb : dMLProgram.getStatementBlocks()) {
            rAnalyzeProgram(sb, status);
        }
        boolean ret = status.foundStart && status.usedInLoop && !status.condUpdate && !status.nonApplicable;
        if (LOG.isDebugEnabled()) {
            // Report the actual usedInLoop flag (previously foundStart was printed by mistake).
            LOG.debug("Auto compression: " + ret + " (dimsKnown=" + hop.dimsKnown(true) + ", outOfCore=" + outOfCore + ", !ultraSparse=" + (!ultraSparse) + ", foundStart=" + status.foundStart + ", usedInLoop=" + status.usedInLoop + ", !condUpdate=" + (!status.condUpdate) + ", !nonApplicable=" + (!status.nonApplicable) + ")");
        }
        return ret;
    }

    /**
     * Recursively walks a statement block. Bodies are analyzed before the enclosing
     * control-flow block is inspected, so compressed names produced inside a body are visible to
     * the loop-read / conditional-update checks.
     */
    private static void rAnalyzeProgram(StatementBlock statementBlock, ProbeStatus probeStatus) {
        if (statementBlock instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) statementBlock.getStatement(0);
            for (StatementBlock child : fstmt.getBody()) {
                rAnalyzeProgram(child, probeStatus);
            }
        }
        else if (statementBlock instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) statementBlock;
            for (StatementBlock child : ((WhileStatement) wsb.getStatement(0)).getBody()) {
                rAnalyzeProgram(child, probeStatus);
            }
            if (wsb.variablesRead().containsAnyName(probeStatus.compMtx)) {
                probeStatus.usedInLoop = true;
            }
        }
        else if (statementBlock instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) statementBlock;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            for (StatementBlock child : istmt.getIfBody()) {
                rAnalyzeProgram(child, probeStatus);
            }
            for (StatementBlock child : istmt.getElseBody()) {
                rAnalyzeProgram(child, probeStatus);
            }
            if (isb.variablesUpdated().containsAnyName(probeStatus.compMtx)) {
                probeStatus.condUpdate = true;
            }
        }
        else if (statementBlock instanceof ForStatementBlock) {
            ForStatementBlock fsb = (ForStatementBlock) statementBlock;
            for (StatementBlock child : ((ForStatement) fsb.getStatement(0)).getBody()) {
                rAnalyzeProgram(child, probeStatus);
            }
            if (fsb.variablesRead().containsAnyName(probeStatus.compMtx)) {
                probeStatus.usedInLoop = true;
            }
        }
        else {
            ArrayList<Hop> hops = statementBlock.getHops();
            if (hops != null) {
                Hop.resetVisitStatus(hops);
                for (Hop root : hops) {
                    rAnalyzeHopDag(root, probeStatus);
                }
                // Hop-ID based names are only meaningful within this DAG; across blocks,
                // compressed data is tracked through variable names.
                probeStatus.compMtx.removeIf(name -> name.startsWith(TMP_PREFIX));
                Hop.resetVisitStatus(hops);
            }
        }
    }

    /**
     * Recursively (inputs first) propagates compressed-data names through a hop DAG and records
     * whether compressed data reaches unsupported operations.
     */
    private static void rAnalyzeHopDag(Hop hop, ProbeStatus probeStatus) {
        if (hop.isVisited()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            rAnalyzeHopDag(input, probeStatus);
        }
        if (hop.getHopID() == probeStatus.startHopID) {
            probeStatus.compMtx.add(getTmpName(hop));
            probeStatus.foundStart = true;
        }

        if (hop instanceof FunctionOp && hasCompressedInput(hop, probeStatus)) {
            analyzeFunctionCall((FunctionOp) hop, probeStatus);
        }
        else if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE)
            && probeStatus.compMtx.contains(getTmpName(hop.getInput().get(0)))) {
            // compressed intermediate assigned to a variable
            probeStatus.compMtx.add(hop.getName());
        }
        else if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)
            && probeStatus.compMtx.contains(hop.getName())) {
            // compressed variable read back into this DAG
            probeStatus.compMtx.add(getTmpName(hop));
        }
        else if (hasCompressedInput(hop, probeStatus)) {
            boolean supported = isSupportedCompressedOp(hop);
            // These ops are treated as producing compressed output again.
            boolean keepsCompression = HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
                || HopRewriteUtils.isBinary(hop, Types.OpOp2.CBIND);
            if (!(supported || keepsCompression || HopRewriteUtils.isUnary(hop, Types.OpOp1.NROW, Types.OpOp1.NCOL))) {
                probeStatus.nonApplicable = true;
            }
            if (keepsCompression) {
                probeStatus.compMtx.add(getTmpName(hop));
            }
        }
        hop.setVisited();
    }

    /**
     * Analyzes a function call that receives compressed data. Each function key is analyzed at
     * most once; compressed arguments are mapped to the callee's parameters, the body is
     * analyzed, the callee's flags are OR-ed into the caller's status, and compressed output
     * parameters are mapped back to the call's output variable names.
     */
    private static void analyzeFunctionCall(FunctionOp fop, ProbeStatus status) {
        String fkey = fop.getFunctionKey();
        if (!status.procFn.add(fkey)) {
            return; // already analyzed (memoization, also prevents infinite recursion)
        }
        FunctionStatementBlock fsb = status.prog.getFunctionStatementBlock(fkey);
        if (fsb == null) {
            // Defensive: the data flow through a call without a DML function body cannot be
            // traced (presumably e.g. internal builtins), so conservatively reject compression
            // instead of failing with an NPE.
            status.nonApplicable = true;
            return;
        }
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        ProbeStatus fstatus = new ProbeStatus(status);
        List<Hop> args = fop.getInput();
        for (int i = 0; i < args.size(); i++) {
            if (status.compMtx.contains(getTmpName(args.get(i)))) {
                fstatus.compMtx.add(fstmt.getInputParams().get(i).getName());
            }
        }
        rAnalyzeProgram(fsb, fstatus);
        status.foundStart |= fstatus.foundStart;
        status.usedInLoop |= fstatus.usedInLoop;
        status.condUpdate |= fstatus.condUpdate;
        status.nonApplicable |= fstatus.nonApplicable;
        String[] outNames = fop.getOutputVariableNames();
        for (int i = 0; i < outNames.length; i++) {
            if (fstatus.compMtx.contains(fstmt.getOutputParams().get(i).getName())) {
                status.compMtx.add(outNames[i]);
            }
        }
    }

    /**
     * Indicates whether an operation consuming compressed input is supported:
     * <ul>
     *   <li>a matrix multiply that {@code checkTransposeSelf()} classifies as {@code LEFT}
     *       (presumably t(X)%*%X), with at most one block of output columns;</li>
     *   <li>a matrix multiply with a single-row or single-column output;</li>
     *   <li>a transpose whose only consumer is such a vector-output matrix multiply;</li>
     *   <li>a sum, sum-of-squares, min, or max aggregate.</li>
     * </ul>
     */
    private static boolean isSupportedCompressedOp(Hop hop) {
        if (hop instanceof AggBinaryOp) {
            if (hop.getDim2() <= hop.getBlocksize()
                && ((AggBinaryOp) hop).checkTransposeSelf() == MMTSJ.MMTSJType.LEFT) {
                return true;
            }
            if (hop.getDim1() == 1 || hop.getDim2() == 1) {
                return true;
            }
        }
        if (HopRewriteUtils.isTransposeOperation(hop) && hop.getParent().size() == 1) {
            Hop consumer = hop.getParent().get(0);
            if (consumer instanceof AggBinaryOp && (consumer.getDim1() == 1 || consumer.getDim2() == 1)) {
                return true;
            }
        }
        return HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.SUM, Types.AggOp.SUM_SQ, Types.AggOp.MIN, Types.AggOp.MAX);
    }

    /** Synthetic name that stands for the output of the given hop within its DAG. */
    private static String getTmpName(Hop hop) {
        return TMP_PREFIX + hop.getHopID();
    }

    /** Indicates whether any direct input of the hop is known to carry compressed data. */
    private static boolean hasCompressedInput(Hop hop, ProbeStatus probeStatus) {
        if (probeStatus.compMtx.isEmpty()) {
            return false;
        }
        for (Hop input : hop.getInput()) {
            if (probeStatus.compMtx.contains(getTmpName(input))) {
                return true;
            }
        }
        return false;
    }
}
