package org.apache.sysds.hops.codegen.template;

import java.util.ArrayList;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysds.hops.codegen.cplan.CNodeOuterProduct;
import org.apache.sysds.hops.codegen.cplan.CNodeRow;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;
import org.apache.sysds.hops.codegen.cplan.CNodeUnary;

/* loaded from: input_file:org/apache/sysds/hops/codegen/template/CPlanOpRewriter.class */
public class CPlanOpRewriter {

    /**
     * Applies simplification rewrites to a code plan template.
     *
     * <p>First applies template-specific rewrites. Currently this is only, for outer-product
     * templates, replacing {@code (X != 0)} on the main input by the literal {@code 1}. Then
     * every output DAG is simplified bottom-up with node-level rewrites. For multi-aggregate
     * templates each output is simplified and written back into the outputs list; for all other
     * templates the single output is replaced.
     *
     * @param cNodeTpl the template to simplify; it is modified in place
     * @return the simplified template (the same instance that was passed in)
     */
    public CNodeTpl simplifyCPlan(CNodeTpl cNodeTpl) {
        // template-specific rewrites: Outer(X != 0) -> Outer(1)
        CNodeTpl tpl = rewriteRemoveOuterNeq0(cNodeTpl);

        // operation-specific rewrites on all output DAGs
        if (tpl instanceof CNodeMultiAgg) {
            ArrayList<CNode> outputs = ((CNodeMultiAgg) tpl).getOutputs();
            for (int i = 0; i < outputs.size(); i++) {
                outputs.set(i, rSimplifyCNode(outputs.get(i)));
            }
        } else {
            tpl.setOutput(rSimplifyCNode(tpl.getOutput()));
            // NOTE(review): presumably compensates for vector intermediates that are no longer
            // needed once a row-vector aggregate was fused (e.g. ROWMAXS_VECTMULT below); the
            // constant 2 is not justified here -- confirm against
            // TemplateUtils.containsFusedRowVecAgg.
            if (TemplateUtils.containsFusedRowVecAgg(tpl)) {
                CNodeRow rowTpl = (CNodeRow) tpl;
                rowTpl.setNumVectorIntermediates(rowTpl.getNumVectorIntermediates() - 2);
            }
        }
        return tpl;
    }

    /**
     * Recursively simplifies the DAG rooted at {@code node}.
     *
     * <p>Children are simplified first and written back into the node's input list. The
     * node-level rewrites are then applied to the node itself in a fixed order. Parent-level
     * patterns therefore always see already-simplified children.
     *
     * @param node root of the (sub-)DAG to simplify; its input list is modified in place
     * @return the simplified node, either {@code node} itself or a replacement node
     */
    private static CNode rSimplifyCNode(CNode node) {
        // process children recursively (bottom-up)
        for (int i = 0; i < node.getInput().size(); i++) {
            node.getInput().set(i, rSimplifyCNode(node.getInput().get(i)));
        }

        // apply all node-specific simplification rewrites, in order
        CNode current = node;
        current = rewriteRowCountNnz(current);      // rowSums(X != 0) -> rowCountNnzs(X)
        current = rewriteRowSumSq(current);         // rowSums(X ^ 2)  -> rowSumSqs(X)
        current = rewriteBinaryPow2(current);       // x ^ 2 -> pow2(x)
        current = rewriteBinaryPow2Vect(current);   // X ^ 2 -> vectPow2(X)
        current = rewriteBinaryMult2(current);      // x * 2 -> mult2(x)
        current = rewriteBinaryMult2Vect(current);  // X * 2 -> vectMult2(X)
        current = rewriteRowMaxsVectMult(current);  // rowMaxs(X * Y) -> rowMaxsVectMult(X, Y)
        return current;
    }

    /**
     * Fuses {@code ROW_MAXS(VECT_MULT(X, Y))} into a single {@code ROWMAXS_VECTMULT(X, Y)}
     * binary node.
     *
     * @param node candidate node
     * @return the fused node, or {@code node} unchanged if the pattern does not match
     */
    private static CNode rewriteRowMaxsVectMult(CNode node) {
        if (TemplateUtils.isUnary(node, CNodeUnary.UnaryType.ROW_MAXS)) {
            CNode child = node.getInput().get(0);
            if (TemplateUtils.isBinary(child, CNodeBinary.BinType.VECT_MULT)) {
                return new CNodeBinary(child.getInput().get(0), child.getInput().get(1),
                    CNodeBinary.BinType.ROWMAXS_VECTMULT);
            }
        }
        return node;
    }

    /**
     * Rewrites {@code ROW_SUMS(VECT_NOTEQUAL_SCALAR(X, 0))} into {@code ROW_COUNTNNZS(X)}.
     *
     * @param node candidate node
     * @return the rewritten node, or {@code node} unchanged if the pattern does not match
     */
    private static CNode rewriteRowCountNnz(CNode node) {
        if (!TemplateUtils.isUnary(node, CNodeUnary.UnaryType.ROW_SUMS)) {
            return node;
        }
        CNode child = node.getInput().get(0);
        if (TemplateUtils.isBinary(child, CNodeBinary.BinType.VECT_NOTEQUAL_SCALAR)
            && isLiteralWithValue(child.getInput().get(1), "0")) {
            return new CNodeUnary(child.getInput().get(0), CNodeUnary.UnaryType.ROW_COUNTNNZS);
        }
        return node;
    }

    /**
     * Rewrites {@code rowSums(X^2)} into {@code ROW_SUMSQS(X)}.
     *
     * <p>{@link #rSimplifyCNode} simplifies children before parents. A child
     * {@code VECT_POW_SCALAR(X, 2)} has therefore already been turned into {@code VECT_POW2(X)}
     * by {@link #rewriteBinaryPow2Vect} when this rewrite sees the enclosing {@code ROW_SUMS}.
     * Both the unary and the binary form are matched, because matching only the binary form
     * made this rewrite unreachable in that pipeline.
     *
     * @param node candidate node
     * @return the rewritten node, or {@code node} unchanged if the pattern does not match
     */
    private static CNode rewriteRowSumSq(CNode node) {
        if (!TemplateUtils.isUnary(node, CNodeUnary.UnaryType.ROW_SUMS)) {
            return node;
        }
        CNode child = node.getInput().get(0);
        boolean isSquare = TemplateUtils.isUnary(child, CNodeUnary.UnaryType.VECT_POW2)
            || (TemplateUtils.isBinary(child, CNodeBinary.BinType.VECT_POW_SCALAR)
                && isLiteralWithValue(child.getInput().get(1), "2"));
        if (isSquare) {
            return new CNodeUnary(child.getInput().get(0), CNodeUnary.UnaryType.ROW_SUMSQS);
        }
        return node;
    }

    /**
     * Rewrites scalar {@code POW(x, 2)} into the unary {@code POW2(x)}.
     *
     * @param node candidate node
     * @return the rewritten node, or {@code node} unchanged if the pattern does not match
     */
    private static CNode rewriteBinaryPow2(CNode node) {
        if (TemplateUtils.isBinary(node, CNodeBinary.BinType.POW)
            && isLiteralWithValue(node.getInput().get(1), "2")) {
            return new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.POW2);
        }
        return node;
    }

    /**
     * Rewrites {@code VECT_POW_SCALAR(X, 2)} into the unary {@code VECT_POW2(X)}.
     *
     * @param node candidate node
     * @return the rewritten node, or {@code node} unchanged if the pattern does not match
     */
    private static CNode rewriteBinaryPow2Vect(CNode node) {
        if (TemplateUtils.isBinary(node, CNodeBinary.BinType.VECT_POW_SCALAR)
            && isLiteralWithValue(node.getInput().get(1), "2")) {
            return new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.VECT_POW2);
        }
        return node;
    }

    /**
     * Rewrites scalar {@code MULT(x, 2)} into the unary {@code MULT2(x)}.
     *
     * @param node candidate node
     * @return the rewritten node, or {@code node} unchanged if the pattern does not match
     */
    private static CNode rewriteBinaryMult2(CNode node) {
        if (TemplateUtils.isBinary(node, CNodeBinary.BinType.MULT)
            && isLiteralWithValue(node.getInput().get(1), "2")) {
            return new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.MULT2);
        }
        return node;
    }

    /**
     * Rewrites {@code VECT_MULT(X, 2)} into the unary {@code VECT_MULT2(X)}.
     *
     * <p>NOTE(review): unlike {@link #rewriteBinaryPow2Vect}, which matches the
     * {@code VECT_POW_SCALAR} type, this matches {@code VECT_MULT} rather than a
     * {@code *_SCALAR} type, even though it requires a literal second operand. Confirm that
     * vector-times-literal products are really typed {@code VECT_MULT}; otherwise this rewrite
     * may never fire.
     *
     * @param node candidate node
     * @return the rewritten node, or {@code node} unchanged if the pattern does not match
     */
    private static CNode rewriteBinaryMult2Vect(CNode node) {
        if (TemplateUtils.isBinary(node, CNodeBinary.BinType.VECT_MULT)
            && isLiteralWithValue(node.getInput().get(1), "2")) {
            return new CNodeUnary(node.getInput().get(0), CNodeUnary.UnaryType.VECT_MULT2);
        }
        return node;
    }

    /**
     * For outer-product templates, replaces {@code NOTEQUAL(X, 0)} by the literal {@code 1}.
     * Here {@code X} is the template's first input, matched by hop ID. Other template types are
     * returned untouched.
     *
     * @param tpl template to rewrite; its output DAG is modified in place
     * @return {@code tpl} itself
     */
    private static CNodeTpl rewriteRemoveOuterNeq0(CNodeTpl tpl) {
        if (tpl instanceof CNodeOuterProduct) {
            // NOTE(review): assumes the first template input is always a CNodeData
            rFindAndRemoveBinaryMS(tpl.getOutput(), (CNodeData) tpl.getInput().get(0),
                CNodeBinary.BinType.NOTEQUAL, "0", "1");
        }
        return tpl;
    }

    /**
     * Recursively searches the inputs below {@code node} for binary nodes of the given type. A
     * match must have as first input a {@link CNodeData} with the same hop ID as
     * {@code mainInput}, and as second input the literal {@code lit}. Each match is replaced, in
     * its parent's input list, by a new literal node for {@code replace}. Matched nodes are not
     * descended into, and {@code node} itself is never checked.
     *
     * @param node      root whose descendants are searched; modified in place
     * @param mainInput data node identifying the expected first operand (compared by hop ID)
     * @param type      binary operation type to match
     * @param lit       literal value the second operand must have
     * @param replace   literal value that replaces each matched node
     */
    private static void rFindAndRemoveBinaryMS(CNode node, CNodeData mainInput,
            CNodeBinary.BinType type, String lit, String replace) {
        for (int i = 0; i < node.getInput().size(); i++) {
            CNode child = node.getInput().get(i);
            boolean matches = TemplateUtils.isBinary(child, type)
                && isLiteralWithValue(child.getInput().get(1), lit)
                && child.getInput().get(0) instanceof CNodeData
                && ((CNodeData) child.getInput().get(0)).getHopID() == mainInput.getHopID();
            if (matches) {
                CNodeData literal = new CNodeData(new LiteralOp(replace));
                literal.setLiteral(true);
                node.getInput().set(i, literal);
            } else {
                rFindAndRemoveBinaryMS(child, mainInput, type, lit, replace);
            }
        }
    }

    /**
     * Returns true if {@code node} is a literal whose variable name equals {@code value}.
     * This is the literal test shared by all pattern rewrites above.
     */
    private static boolean isLiteralWithValue(CNode node, String value) {
        return node.isLiteral() && node.getVarname().equals(value);
    }
}
