package org.apache.sysds.hops.codegen.template;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import org.apache.sysds.hops.codegen.cplan.CNode;
import org.apache.sysds.hops.codegen.cplan.CNodeData;
import org.apache.sysds.hops.codegen.cplan.CNodeMultiAgg;
import org.apache.sysds.hops.codegen.cplan.CNodeTpl;

/* loaded from: input_file:org/apache/sysds/hops/codegen/template/CPlanCSERewriter.class */
/**
 * Common-subexpression elimination (CSE) over the CNode plan(s) of a code-generation template.
 *
 * <p>Nodes reachable from the template outputs that compare equal via {@code equals}/{@code
 * hashCode} are collapsed, so that parents reference one shared instance. The rewrite mutates the
 * input lists of the plan nodes in place.
 */
public class CPlanCSERewriter {

    /**
     * Eliminates common subexpressions in the output plan(s) of the given template, in place.
     *
     * <p>The rewrite runs three traversals, each preceded by a reset of the visit status.
     * <ol>
     *   <li>Put every reachable {@link CNodeData} into strict-equals mode.</li>
     *   <li>Merge equal subexpressions.</li>
     *   <li>Switch data nodes back out of strict-equals mode.</li>
     * </ol>
     * The visit status is reset once more before returning.
     *
     * @param cNodeTpl the template to rewrite; for a {@link CNodeMultiAgg} all outputs are
     *     processed, otherwise only {@code getOutput()}
     * @return the same (mutated) template instance
     */
    public CNodeTpl eliminateCommonSubexpressions(CNodeTpl cNodeTpl) {
        List<CNode> outputs = (cNodeTpl instanceof CNodeMultiAgg)
            ? ((CNodeMultiAgg) cNodeTpl).getOutputs()
            : Collections.singletonList(cNodeTpl.getOutput());

        // Pass 1: enable strict data-node equality for the duration of the rewrite
        // (presumably so that only truly identical data inputs get merged -- see CNodeData).
        cNodeTpl.resetVisitStatusOutputs();
        for (CNode output : outputs) {
            rSetStrictDataNodeComparison(output, true);
        }

        // Pass 2: replace inputs by previously seen equal nodes.
        HashMap<CNode, CNode> cseSet = new HashMap<>();
        cNodeTpl.resetVisitStatusOutputs();
        for (CNode output : outputs) {
            rEliminateCommonSubexpression(output, cseSet);
        }

        // Pass 3: restore non-strict data-node equality.
        cNodeTpl.resetVisitStatusOutputs();
        for (CNode output : outputs) {
            rSetStrictDataNodeComparison(output, false);
        }

        // Leave the plan with a clean visit status for subsequent traversals.
        cNodeTpl.resetVisitStatusOutputs();
        return cNodeTpl;
    }

    /**
     * Recursively deduplicates the sub-DAG rooted at {@code cNode}.
     *
     * <p>The method proceeds in three steps.
     * <ol>
     *   <li>Any input already registered in {@code cseSet} (by equality) is replaced with the
     *       registered instance.</li>
     *   <li>The method recurses into the (possibly replaced) inputs.</li>
     *   <li>{@code cNode} itself is registered.</li>
     * </ol>
     * Already-visited nodes are skipped.
     *
     * @param cNode the current plan node
     * @param cseSet the nodes seen so far, keyed and valued by the node itself
     */
    private void rEliminateCommonSubexpression(CNode cNode, HashMap<CNode, CNode> cseSet) {
        if (cNode.isVisited()) {
            return;
        }

        List<CNode> inputs = cNode.getInput();

        // Substitute inputs that equal an already-registered node. Values in cseSet are never
        // null (only put(node, node) is used), so a null result means "not present".
        for (int i = 0; i < inputs.size(); i++) {
            CNode existing = cseSet.get(inputs.get(i));
            if (existing != null) {
                inputs.set(i, existing);
            }
        }

        for (CNode input : inputs) {
            rEliminateCommonSubexpression(input, cseSet);
        }

        cseSet.put(cNode, cNode);
        cNode.setVisited();
    }

    /**
     * Recursively sets the strict-equals flag on every {@link CNodeData} reachable from
     * {@code cNode}. It also resets the hash of every visited node, including {@code cNode}
     * itself.
     *
     * <p>The hash of each node is reset after the node's inputs (and, for data nodes, the node
     * itself) have switched mode. This way any hash that depends on the comparison mode is
     * recomputed. The original version reset only the inputs' hashes, so the output roots kept a
     * stale one.
     *
     * @param cNode the current plan node
     * @param strict the strict-equals flag to apply to data nodes
     */
    private void rSetStrictDataNodeComparison(CNode cNode, boolean strict) {
        if (cNode.isVisited()) {
            return;
        }

        for (CNode input : cNode.getInput()) {
            rSetStrictDataNodeComparison(input, strict);
        }

        if (cNode instanceof CNodeData) {
            ((CNodeData) cNode).setStrictEquals(strict);
        }

        // NOTE(review): assumes resetHash() only clears a cached hash code; recomputation is lazy.
        cNode.resetHash();
        cNode.setVisited();
    }
}
