package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Compression;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;

/**
 * Statement-block rewrite that marks hops for compression ({@code setRequiresCompression}) depending
 * on the configured compression mode:
 * <ul>
 * <li>{@code TRUE}: persistent reads that satisfy the size constraints.</li>
 * <li>{@code AUTO}: the same candidates, but only in Spark execution mode, when the hop's memory
 * estimate reaches the local budget, its dimensions and sparsity are known, it exceeds the Spark data
 * memory budget, it is not ultra-sparse, and a whole-program probe finds it read inside a loop without
 * conditional updates or unsupported operations.</li>
 * <li>{@code COST}: a wider set of candidate operations with known dimensions, accepted when a
 * whole-program probe finds no decompressing operations, more supported than inefficient operations,
 * and either a use inside a loop or more than three supported operations.</li>
 * </ul>
 * The rule never splits DAGs and returns lists of statement blocks unchanged.
 */
public class RewriteCompressedReblock extends StatementBlockRewriteRule {
    private static final Log LOG = LogFactory.getLog(RewriteCompressedReblock.class.getName());

    /** Prefix of the synthetic per-hop names used to track compressed intermediates during probing. */
    private static final String TMP_PREFIX = "__cmtx";

    /** floor(sqrt(Long.MAX_VALUE)): the largest value whose square still fits into a {@code long}. */
    private static final long MAX_SQUARABLE_LONG = 3037000499L;

    /**
     * Outcome of probing the whole program for the uses of one candidate hop. The probe follows the
     * (hypothetically) compressed result through transient writes/reads and function calls and
     * classifies every operation that consumes it.
     */
    public static class ProbeStatus {
        private final long startHopID;
        private final DMLProgram prog;
        /** Weighted count of operations that can consume compressed input. */
        private int numberCompressedOpsExecuted = 0;
        /** Count of operations that would require decompression. */
        private int numberDecompressedOpsExecuted = 0;
        /** Count of supported aggregations with a single-column output of at least 1000 rows. */
        private int inefficientSupportedOpsExecuted = 0;
        private boolean foundStart = false;
        private boolean usedInLoop = false;
        private boolean condUpdate = false;
        private boolean nonApplicable = false;
        /** Keys of functions already analyzed (memoization; also terminates recursive calls). */
        private final HashSet<String> procFn = new HashSet<>();
        /** Variable names and synthetic hop names that currently hold compressed data. */
        private final HashSet<String> compMtx = new HashSet<>();

        private ProbeStatus(long startHopID, DMLProgram prog) {
            this.startHopID = startHopID;
            this.prog = prog;
        }

        /**
         * Creates a probe for a function body: it inherits the start hop, program, flags and processed
         * functions of {@code parent}, while counters and compressed names start empty.
         */
        private ProbeStatus(ProbeStatus parent) {
            this.startHopID = parent.startHopID;
            this.prog = parent.prog;
            this.foundStart = parent.foundStart;
            this.usedInLoop = parent.usedInLoop;
            this.condUpdate = parent.condUpdate;
            this.nonApplicable = parent.nonApplicable;
            this.procFn.addAll(parent.procFn);
        }

        /** Recursively analyzes a statement block and all of its nested blocks. */
        private void rAnalyzeProgram(StatementBlock sb) {
            if (sb instanceof FunctionStatementBlock) {
                FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
                for (StatementBlock child : fstmt.getBody()) {
                    rAnalyzeProgram(child);
                }
            } else if (sb instanceof WhileStatementBlock) {
                WhileStatementBlock wsb = (WhileStatementBlock) sb;
                for (StatementBlock child : ((WhileStatement) wsb.getStatement(0)).getBody()) {
                    rAnalyzeProgram(child);
                }
                // checked after the body, so names that became compressed inside the loop count too
                if (wsb.variablesRead().containsAnyName(compMtx)) {
                    usedInLoop = true;
                }
            } else if (sb instanceof IfStatementBlock) {
                IfStatementBlock isb = (IfStatementBlock) sb;
                IfStatement istmt = (IfStatement) isb.getStatement(0);
                for (StatementBlock child : istmt.getIfBody()) {
                    rAnalyzeProgram(child);
                }
                for (StatementBlock child : istmt.getElseBody()) {
                    rAnalyzeProgram(child);
                }
                if (isb.variablesUpdated().containsAnyName(compMtx)) {
                    condUpdate = true;
                }
            } else if (sb instanceof ForStatementBlock) {
                ForStatementBlock fsb = (ForStatementBlock) sb;
                for (StatementBlock child : ((ForStatement) fsb.getStatement(0)).getBody()) {
                    rAnalyzeProgram(child);
                }
                if (fsb.variablesRead().containsAnyName(compMtx)) {
                    usedInLoop = true;
                }
            } else if (sb.getHops() != null) {
                ArrayList<Hop> hops = sb.getHops();
                Hop.resetVisitStatus(hops);
                for (Hop root : hops) {
                    rAnalyzeHopDag(root);
                }
                // synthetic per-hop names are only tracked within one DAG; compression is carried
                // across blocks through the variable names added for transient writes
                compMtx.removeIf(name -> name.startsWith(TMP_PREFIX));
                Hop.resetVisitStatus(hops);
            }
        }

        /** Post-order traversal of a hop DAG that propagates and classifies compressed data. */
        private void rAnalyzeHopDag(Hop hop) {
            if (hop.isVisited()) {
                return;
            }
            for (Hop input : hop.getInput()) {
                rAnalyzeHopDag(input);
            }
            if (hop.getHopID() == startHopID) {
                compMtx.add(getTmpName(hop));
                foundStart = true;
            }
            if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE) && isCompressed(hop.getInput().get(0))) {
                // writing a compressed intermediate makes the variable compressed
                compMtx.add(hop.getName());
            } else if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD) && compMtx.contains(hop.getName())) {
                // reading a compressed variable yields a compressed intermediate
                compMtx.add(getTmpName(hop));
            } else if (hasCompressedInput(hop)) {
                if (hop instanceof FunctionOp) {
                    handleFunctionOps(hop);
                } else {
                    handleApplicableOps(hop);
                }
            }
            hop.setVisited();
        }

        /** Returns true if any direct input of {@code hop} is tracked as compressed. */
        private boolean hasCompressedInput(Hop hop) {
            if (compMtx.isEmpty()) {
                return false;
            }
            for (Hop input : hop.getInput()) {
                if (isCompressed(input)) {
                    return true;
                }
            }
            return false;
        }

        private static String getTmpName(Hop hop) {
            return TMP_PREFIX + hop.getHopID();
        }

        private boolean isCompressed(Hop hop) {
            return compMtx.contains(getTmpName(hop));
        }

        /**
         * Analyzes the body of a called function with compressed arguments mapped to its parameters,
         * merges the callee's findings and maps compressed function outputs back to the caller.
         */
        private void handleFunctionOps(Hop hop) {
            FunctionOp fop = (FunctionOp) hop;
            String fkey = fop.getFunctionKey();
            // memoization: each function is analyzed once. NOTE(review): later calls with compressed
            // inputs at other parameter positions are therefore not re-analyzed.
            if (!procFn.add(fkey)) {
                return;
            }
            FunctionStatementBlock fsb = prog.getFunctionStatementBlock(fkey);
            if (fsb == null) {
                // No DML body to analyze (presumably a builtin/external function); conservatively
                // treat the call as an operation that requires decompression instead of failing.
                numberDecompressedOpsExecuted++;
                nonApplicable = true;
                return;
            }
            FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
            ProbeStatus callee = new ProbeStatus(this);
            for (int i = 0; i < fop.getInput().size(); i++) {
                if (isCompressed(fop.getInput().get(i))) {
                    callee.compMtx.add(fstmt.getInputParams().get(i).getName());
                }
            }
            callee.rAnalyzeProgram(fsb);

            foundStart |= callee.foundStart;
            usedInLoop |= callee.usedInLoop;
            condUpdate |= callee.condUpdate;
            nonApplicable |= callee.nonApplicable;
            numberCompressedOpsExecuted += callee.numberCompressedOpsExecuted;
            numberDecompressedOpsExecuted += callee.numberDecompressedOpsExecuted;
            // keep the inefficient count consistent with the aggregated compressed-op count
            inefficientSupportedOpsExecuted += callee.inefficientSupportedOpsExecuted;

            String[] outputs = fop.getOutputVariableNames();
            for (int i = 0; i < outputs.length; i++) {
                if (callee.compMtx.contains(fstmt.getOutputParams().get(i).getName())) {
                    compMtx.add(outputs[i]);
                }
            }
        }

        /**
         * Classifies a non-function operation with compressed input. It counts the operation as
         * supported or decompressing, and marks the result compressed if the operation keeps its
         * output compressed.
         */
        private void handleApplicableOps(Hop hop) {
            // supported operations producing uncompressed output
            boolean isAggregate = HopRewriteUtils.isAggUnaryOp(hop, Types.AggOp.SUM, Types.AggOp.SUM_SQ,
                Types.AggOp.MIN, Types.AggOp.MAX, Types.AggOp.MEAN);
            if (isAggregate && hop.getDim2() < 2 && hop.getDim1() >= 1000) {
                inefficientSupportedOpsExecuted++;
            }
            boolean uncompressedOut = hop instanceof AggBinaryOp
                || HopRewriteUtils.isBinaryMatrixColVectorOperation(hop)
                || isAggregate;

            // supported operations keeping the output compressed
            boolean compressedOut = HopRewriteUtils.isBinaryMatrixScalarOperation(hop)
                || HopRewriteUtils.isBinaryMatrixRowVectorOperation(hop)
                || (hop instanceof AggBinaryOp && isCompressed(hop.getInput().get(0)))
                || HopRewriteUtils.isBinary(hop, Types.OpOp2.CBIND);
            if (HopRewriteUtils.isTernary(hop, Types.OpOp3.CTABLE)) {
                numberCompressedOpsExecuted += 4;
                compressedOut = true;
            }
            boolean metaOp = HopRewriteUtils.isUnary(hop, Types.OpOp1.NROW, Types.OpOp1.NCOL);

            boolean supported = uncompressedOut || compressedOut || metaOp;
            if (supported) {
                numberCompressedOpsExecuted++;
            } else {
                LOG.warn("Decompession op: " + hop);
                numberDecompressedOpsExecuted++;
            }
            nonApplicable |= !supported;
            if (compressedOut) {
                compMtx.add(getTmpName(hop));
            }
        }

        private boolean isValidAutoCompression() {
            return foundStart && usedInLoop && !condUpdate && !nonApplicable;
        }

        private boolean isValidAggressiveCompression() {
            if (LOG.isDebugEnabled()) {
                LOG.debug(toString());
            }
            return inefficientSupportedOpsExecuted < numberCompressedOpsExecuted
                && (usedInLoop || numberCompressedOpsExecuted > 3)
                && numberDecompressedOpsExecuted < 1;
        }

        @Override
        public String toString() {
            return ("Compressed ProbeStatus : hopID =" + this.startHopID)
                + ("\n CLA Ops         : " + this.numberCompressedOpsExecuted)
                + ("\n Decompress Ops  : " + this.numberDecompressedOpsExecuted)
                + ("\n Inefficient Ops : " + this.inefficientSupportedOpsExecuted)
                + ("\n foundStart " + this.foundStart + " , inLoop :" + this.usedInLoop + " , condUpdate : "
                    + this.condUpdate + " , nonApplicable : " + this.nonApplicable)
                + ("\n compressed Matrix: " + this.compMtx)
                + ("\n Prog Fn " + this.procFn);
        }
    }

    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Injects compression directives into the hops of a last-level statement block when compression
     * is enabled; other blocks are returned unchanged.
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock, ProgramRewriteStatus programRewriteStatus) {
        if (!HopRewriteUtils.isLastLevelStatementBlock(statementBlock) || statementBlock.getHops() == null) {
            return Arrays.asList(statementBlock);
        }
        Compression.CompressConfig compressConfig = ConfigurationManager.getCompressConfig();
        if (compressConfig.isEnabled()) {
            Hop.resetVisitStatus(statementBlock.getHops());
            for (Hop root : statementBlock.getHops()) {
                injectCompressionDirective(root, compressConfig, statementBlock.getDMLProg());
            }
            Hop.resetVisitStatus(statementBlock.getHops());
        }
        return Arrays.asList(statementBlock);
    }

    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list, ProgramRewriteStatus programRewriteStatus) {
        return list;
    }

    /**
     * Post-order traversal that marks hops for compression according to {@code compressConfig}.
     * Hops already marked for compression, or with compressed input, are skipped together with
     * their inputs.
     */
    private static void injectCompressionDirective(Hop hop, Compression.CompressConfig compressConfig, DMLProgram dMLProgram) {
        if (hop.isVisited() || hop.requiresCompression() || hop.hasCompressedInput()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            injectCompressionDirective(input, compressConfig, dMLProgram);
        }
        // NOTE(review): analyseProgram resets the visit status of every basic block it walks. If this
        // block is among them, memoization of this traversal is lost after each probe, which is why
        // the probes are only run for actual candidates.
        switch (compressConfig) {
            case TRUE:
                if (satisfiesCompressionCondition(hop)) {
                    hop.setRequiresCompression();
                }
                break;
            case AUTO:
                if (OptimizerUtils.isSparkExecutionMode() && satisfiesAutoCompressionCondition(hop, dMLProgram)) {
                    hop.setRequiresCompression();
                }
                break;
            case COST:
                if (satisfiesCostCompressionCondition(hop, dMLProgram)) {
                    hop.setRequiresCompression();
                }
                break;
        }
        if (satisfiesDeCompressionCondition(hop)) {
            hop.setRequiresDeCompression();
        }
        hop.setVisited();
    }

    /**
     * Returns true if the hop has a positive column count and either dim1 * dim1 >= 1024 * dim2,
     * or sparsity below 1e-4 with more than 100 columns.
     */
    public static boolean satisfiesSizeConstraintsForCompression(Hop hop) {
        if (hop.getDim2() < 1) {
            return false;
        }
        long rows = hop.getDim1();
        long cols = hop.getDim2();
        // rows * rows overflows a long beyond MAX_SQUARABLE_LONG; there the square certainly exceeds
        // cols * 1024, so the comparison holds without computing it
        boolean squareCondition = rows > MAX_SQUARABLE_LONG || (cols << 10) <= rows * rows;
        return squareCondition || (hop.getSparsity() < 1.0E-4d && cols > 100);
    }

    /** Returns true for persistent reads that satisfy the size constraints. */
    public static boolean satisfiesCompressionCondition(Hop hop) {
        return satisfiesSizeConstraintsForCompression(hop)
            && HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD);
    }

    /**
     * Returns true for a ctable over two matrix inputs. Hops that satisfy the size constraints also
     * qualify if they are a persistent read, a round/floor/not/ceil, a comparison, logical or
     * modulus binary, or any ctable.
     */
    public static boolean satisfiesAggressiveCompressionCondition(Hop hop) {
        boolean z = HopRewriteUtils.isTernary(hop, Types.OpOp3.CTABLE) && hop.getInput(0).getDataType().isMatrix() && hop.getInput(1).getDataType().isMatrix();
        if (satisfiesSizeConstraintsForCompression(hop)) {
            z = z | HopRewriteUtils.isData(hop, Types.OpOpData.PERSISTENTREAD)
                | HopRewriteUtils.isUnary(hop, Types.OpOp1.ROUND, Types.OpOp1.FLOOR, Types.OpOp1.NOT, Types.OpOp1.CEIL)
                | HopRewriteUtils.isBinary(hop, Types.OpOp2.EQUAL, Types.OpOp2.NOTEQUAL, Types.OpOp2.LESS,
                    Types.OpOp2.LESSEQUAL, Types.OpOp2.GREATER, Types.OpOp2.GREATEREQUAL, Types.OpOp2.AND,
                    Types.OpOp2.OR, Types.OpOp2.MODULUS)
                | HopRewriteUtils.isTernary(hop, Types.OpOp3.CTABLE);
        }
        if (LOG.isDebugEnabled() && z) {
            LOG.debug("Operation Satisfies: " + hop);
        }
        return z;
    }

    /** Decompression directives are currently never injected. */
    private static boolean satisfiesDeCompressionCondition(Hop hop) {
        return false;
    }

    private static boolean outOfCore(Hop hop) {
        return ((double) OptimizerUtils.estimatePartitionedSizeExactSparsity(hop)) > SparkExecutionContext.getDataMemoryBudget(true, true);
    }

    private static boolean ultraSparse(Hop hop) {
        return OptimizerUtils.getSparsity(hop) < 4.0E-5d;
    }

    private static boolean satisfiesAutoCompressionCondition(Hop hop, DMLProgram dMLProgram) {
        if (satisfiesCompressionCondition(hop) && hop.getMemEstimate() >= OptimizerUtils.getLocalMemBudget() && hop.dimsKnown(true) && outOfCore(hop) && !ultraSparse(hop)) {
            return analyseProgram(hop, dMLProgram).isValidAutoCompression();
        }
        return false;
    }

    /**
     * Returns true if the hop is an aggressive-compression candidate with known dimensions and the
     * whole-program probe accepts it. The short-circuit order ensures the expensive probe, which also
     * resets hop visit status, only runs for actual candidates.
     */
    private static boolean satisfiesCostCompressionCondition(Hop hop, DMLProgram dMLProgram) {
        return satisfiesAggressiveCompressionCondition(hop)
            && hop.dimsKnown(false)
            && analyseProgram(hop, dMLProgram).isValidAggressiveCompression();
    }

    /** Probes all top-level statement blocks of the program for the uses of {@code hop}. */
    private static ProbeStatus analyseProgram(Hop hop, DMLProgram dMLProgram) {
        ProbeStatus probeStatus = new ProbeStatus(hop.getHopID(), dMLProgram);
        for (StatementBlock sb : dMLProgram.getStatementBlocks()) {
            probeStatus.rAnalyzeProgram(sb);
        }
        return probeStatus;
    }
}
