package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/RewriteConstantFolding.class */
/**
 * Hop rewrite rule that folds constant scalar sub-expressions into literals.
 *
 * <p>Each hop DAG is walked bottom-up. A scalar hop is replaced by a
 * {@link LiteralOp} when it is
 * <ul>
 *   <li>an eligible unary, binary, ternary or n-ary operation whose relevant
 *       inputs are all literals (the operation is compiled to instructions and
 *       executed once to obtain the value), or</li>
 *   <li>a scalar {@code AND} with at least one literal {@code false} input, or</li>
 *   <li>a scalar {@code OR} with at least one literal {@code true} input.</li>
 * </ul>
 *
 * <p>The program block and execution context used for evaluation are created
 * lazily and reused between evaluations; the instance keeps mutable state in
 * them while an evaluation is in progress.
 */
public class RewriteConstantFolding extends HopRewriteRule {
    /** Name of the temporary variable that receives the evaluated scalar. */
    private static final String TMP_VARNAME = "__cf_tmp";

    // Lazily initialized and reused by evalScalarOperation.
    private BasicProgramBlock _tmpPB = null;
    private ExecutionContext _tmpEC = null;

    /**
     * Applies constant folding to every root in the given list, storing the
     * (possibly replaced) root back at the same index.
     *
     * @param arrayList list of DAG roots; may be {@code null}
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return the same list instance, or {@code null} if the input was {@code null}
     */
    @Override // org.apache.sysds.hops.rewrite.HopRewriteRule
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> arrayList, ProgramRewriteStatus programRewriteStatus) {
        if (arrayList == null) {
            return null;
        }
        for (int i = 0; i < arrayList.size(); i++) {
            arrayList.set(i, rule_ConstantFolding(arrayList.get(i)));
        }
        return arrayList;
    }

    /**
     * Applies constant folding to a single DAG.
     *
     * @param hop DAG root; may be {@code null}
     * @param programRewriteStatus rewrite status (not used by this rule)
     * @return the root after folding (a new literal if the root itself was
     *         folded), or {@code null} if the input was {@code null}
     */
    @Override // org.apache.sysds.hops.rewrite.HopRewriteRule
    public Hop rewriteHopDAG(Hop hop, ProgramRewriteStatus programRewriteStatus) {
        if (hop == null) {
            return null;
        }
        return rule_ConstantFolding(hop);
    }

    /** Entry point of the rule for one DAG root; delegates to the recursive walk. */
    private Hop rule_ConstantFolding(Hop hop) {
        return rConstantFoldingExpression(hop);
    }

    /**
     * Recursively folds the inputs of {@code hop} and then {@code hop} itself.
     *
     * @param hop current hop
     * @return the folded literal if {@code hop} was foldable and has no
     *         parents; otherwise {@code hop} itself (foldable hops with
     *         parents are replaced inside those parents instead)
     */
    private Hop rConstantFoldingExpression(Hop hop) {
        if (hop.isVisited()) {
            return hop;
        }

        // Children first. Index-based loop on purpose: a folded child is swapped
        // out through replaceChildReference on its parents, which may touch this
        // hop's input list while we walk it.
        for (int i = 0; i < hop.getInput().size(); i++) {
            rConstantFoldingExpression(hop.getInput().get(i));
        }

        LiteralOp folded = null;
        if (hop.getDataType() == Types.DataType.SCALAR && isApplicableScalarOperation(hop)) {
            folded = evalScalarOperation(hop);
        } else if (isApplicableFalseConjunctivePredicate(hop)) {
            folded = new LiteralOp(false);
        } else if (isApplicableTrueDisjunctivePredicate(hop)) {
            folded = new LiteralOp(true);
        }

        if (folded != null) {
            if (hop.getParent().isEmpty()) {
                // No parent to rewire: the literal becomes the new root.
                hop = folded;
            } else {
                // Iterate over a snapshot, since the replacement is invoked on the
                // very parents contained in hop's parent list.
                ArrayList<Hop> parents = new ArrayList<Hop>(hop.getParent());
                for (Hop parent : parents) {
                    HopRewriteUtils.replaceChildReference(parent, hop, folded);
                }
            }
        }

        hop.setVisited();
        return hop;
    }

    /** Returns whether any of the operation-specific literal-input checks accepts {@code hop}. */
    private static boolean isApplicableScalarOperation(Hop hop) {
        return isApplicableUnaryOp(hop)
            || isApplicableBinaryOp(hop)
            || isApplicableTernaryOp(hop)
            || isApplicableNaryOp(hop);
    }

    /**
     * Evaluates {@code hop} by attaching a temporary transient write to it,
     * compiling that to instructions and executing them, then wrapping the
     * resulting scalar in a literal.
     *
     * <p>The temporary write, the instructions and the variables of the reused
     * execution context are always cleaned up, even if compilation or
     * execution throws, so a failed evaluation does not leave the temporary
     * write attached to {@code hop} or stale state in the reused runtime
     * objects.
     *
     * @param hop scalar operation to evaluate
     * @return literal created from the scalar produced at runtime
     */
    private LiteralOp evalScalarOperation(Hop hop) {
        ExecutionContext ec = getExecutionContext();
        BasicProgramBlock pb = getProgramBlock();

        // The DataOp is created with hop as its input; it is detached again in
        // the finally block below.
        DataOp tmpWrite = new DataOp(TMP_VARNAME, hop.getDataType(), hop.getValueType(), hop, Types.OpOpData.TRANSIENTWRITE, TMP_VARNAME);

        LiteralOp literal;
        try {
            // Build the instructions for the temporary write.
            Dag<Lop> dag = new Dag<>();
            Recompiler.rClearLops(tmpWrite);
            tmpWrite.constructLops().addToDag(dag);
            ArrayList<Instruction> instructions = dag.getJobs(null, ConfigurationManager.getDMLConfig());

            // Execute and read back the scalar stored under TMP_VARNAME.
            pb.setInstructions(instructions);
            pb.execute(ec);
            ScalarObject result = (ScalarObject) ec.getVariable(TMP_VARNAME);
            literal = ScalarObjectFactory.createLiteralOp(result);
        } finally {
            // Detach the temporary write and reset the reused runtime objects.
            tmpWrite.getInput().clear();
            hop.getParent().remove(tmpWrite);
            pb.setInstructions(null);
            ec.getVariables().removeAll();
        }

        HopRewriteUtils.setOutputParametersForScalar(literal);
        return literal;
    }

    /** Returns the reusable program block, creating it on first use. */
    private BasicProgramBlock getProgramBlock() {
        if (this._tmpPB == null) {
            this._tmpPB = new BasicProgramBlock(new Program());
        }
        return this._tmpPB;
    }

    /** Returns the reusable execution context, creating it on first use. */
    private ExecutionContext getExecutionContext() {
        if (this._tmpEC == null) {
            this._tmpEC = ExecutionContextFactory.createContext();
        }
        return this._tmpEC;
    }

    /**
     * A binary op is foldable if both inputs are literals and the operation
     * is neither {@code CBIND} nor {@code RBIND}.
     */
    private static boolean isApplicableBinaryOp(Hop hop) {
        ArrayList<Hop> inputs = hop.getInput();
        if (!(hop instanceof BinaryOp)) {
            return false;
        }
        if (!(inputs.get(0) instanceof LiteralOp) || !(inputs.get(1) instanceof LiteralOp)) {
            return false;
        }
        Types.OpOp2 op = ((BinaryOp) hop).getOp();
        return op != Types.OpOp2.CBIND && op != Types.OpOp2.RBIND;
    }

    /**
     * A unary op is foldable if its input is a literal, its output is scalar,
     * and the operation is none of {@code EXISTS}, {@code PRINT},
     * {@code ASSERT} or {@code STOP}.
     */
    private static boolean isApplicableUnaryOp(Hop hop) {
        if (!(hop instanceof UnaryOp) || !(hop.getInput().get(0) instanceof LiteralOp)) {
            return false;
        }
        Types.OpOp1 op = ((UnaryOp) hop).getOp();
        return op != Types.OpOp1.EXISTS
            && op != Types.OpOp1.PRINT
            && op != Types.OpOp1.ASSERT
            && op != Types.OpOp1.STOP
            && hop.getDataType() == Types.DataType.SCALAR;
    }

    /**
     * A ternary op is foldable if it is {@code IFELSE}, {@code MINUS_MULT} or
     * {@code PLUS_MULT} and all of its inputs are literals.
     */
    private static boolean isApplicableTernaryOp(Hop hop) {
        return HopRewriteUtils.isTernary(hop, Types.OpOp3.IFELSE, Types.OpOp3.MINUS_MULT, Types.OpOp3.PLUS_MULT)
            && hasOnlyLiteralInputs(hop);
    }

    /**
     * An n-ary op is foldable if it is {@code MIN}, {@code MAX} or
     * {@code PLUS} and all of its inputs are literals.
     */
    private static boolean isApplicableNaryOp(Hop hop) {
        return HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS)
            && hasOnlyLiteralInputs(hop);
    }

    /** Returns {@code true} if every input of {@code hop} is a literal (also for no inputs). */
    private static boolean hasOnlyLiteralInputs(Hop hop) {
        for (Hop input : hop.getInput()) {
            if (!(input instanceof LiteralOp)) {
                return false;
            }
        }
        return true;
    }

    /** Returns {@code true} if {@code input} is a literal whose boolean value equals {@code value}. */
    private static boolean isBooleanLiteral(Hop input, boolean value) {
        return (input instanceof LiteralOp) && ((LiteralOp) input).getBooleanValue() == value;
    }

    /** Scalar {@code AND} with at least one literal {@code false} input. */
    private static boolean isApplicableFalseConjunctivePredicate(Hop hop) {
        ArrayList<Hop> inputs = hop.getInput();
        return HopRewriteUtils.isBinary(hop, Types.OpOp2.AND)
            && hop.getDataType().isScalar()
            && (isBooleanLiteral(inputs.get(0), false) || isBooleanLiteral(inputs.get(1), false));
    }

    /** Scalar {@code OR} with at least one literal {@code true} input. */
    private static boolean isApplicableTrueDisjunctivePredicate(Hop hop) {
        ArrayList<Hop> inputs = hop.getInput();
        return HopRewriteUtils.isBinary(hop, Types.OpOp2.OR)
            && hop.getDataType().isScalar()
            && (isBooleanLiteral(inputs.get(0), true) || isBooleanLiteral(inputs.get(1), true));
    }
}
