package org.apache.sysds.hops.fedplanner;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/**
 * Federated planner that forces federated execution ({@link Types.ExecType#FED}) on every hop
 * for which {@link #allowsFederated} permits it.
 *
 * <p>The planner walks the statement-block hierarchy of a program. Within each hop DAG it
 * propagates the {@link FTypes.FType} of federated intermediates bottom-up. Across basic
 * statement blocks it carries the federation type of named variables via transient writes
 * (recorded) and transient reads (looked up).
 */
public class FederatedPlannerFedAll extends AFederatedPlanner {

    /**
     * Rewrites all statement blocks returned by {@code prog.getStatementBlocks()}.
     *
     * @param prog       the program whose hops are rewritten in place
     * @param fgraph     function call graph (not used by this planner)
     * @param fcallSizes function call size information (not used by this planner)
     */
    @Override // org.apache.sysds.hops.fedplanner.AFederatedPlanner
    public void rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) {
        // Variable name -> federation type of its current value (null = not federated);
        // shared across all blocks so later blocks see earlier transient writes.
        Map<String, FTypes.FType> fedVars = new HashMap<>();
        for (StatementBlock sb : prog.getStatementBlocks()) {
            rRewriteStatementBlock(sb, fedVars);
        }
    }

    /**
     * Recursively rewrites a statement block and all of its nested statement blocks.
     *
     * <p>Control-flow predicates (while/if conditions, for-loop from/to/increment) are rewritten
     * with a fresh hop memo and an empty variable map, so transient reads inside predicates are
     * treated as not federated. Basic blocks share one memo across all their hop DAG roots.
     *
     * @param sb      the statement block to rewrite
     * @param fedVars variable name to federation type; updated with the transient writes of
     *                basic blocks
     */
    private void rRewriteStatementBlock(StatementBlock sb, Map<String, FTypes.FType> fedVars) {
        if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fstmt = (FunctionStatement) ((FunctionStatementBlock) sb).getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                rRewriteStatementBlock(csb, fedVars);
            }
        }
        else if (sb instanceof WhileStatementBlock) {
            WhileStatementBlock wsb = (WhileStatementBlock) sb;
            WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
            rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), Collections.emptyMap());
            for (StatementBlock csb : wstmt.getBody()) {
                rRewriteStatementBlock(csb, fedVars);
            }
        }
        else if (sb instanceof IfStatementBlock) {
            IfStatementBlock isb = (IfStatementBlock) sb;
            IfStatement istmt = (IfStatement) isb.getStatement(0);
            rRewriteHop(isb.getPredicateHops(), new HashMap<>(), Collections.emptyMap());
            for (StatementBlock csb : istmt.getIfBody()) {
                rRewriteStatementBlock(csb, fedVars);
            }
            for (StatementBlock csb : istmt.getElseBody()) {
                rRewriteStatementBlock(csb, fedVars);
            }
        }
        else if (sb instanceof ForStatementBlock) { // also matches subclasses, if any
            ForStatementBlock fsb = (ForStatementBlock) sb;
            ForStatement fstmt = (ForStatement) fsb.getStatement(0);
            rRewriteHop(fsb.getFromHops(), new HashMap<>(), Collections.emptyMap());
            rRewriteHop(fsb.getToHops(), new HashMap<>(), Collections.emptyMap());
            rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), Collections.emptyMap());
            for (StatementBlock csb : fstmt.getBody()) {
                rRewriteStatementBlock(csb, fedVars);
            }
        }
        else {
            rRewriteBasicBlock(sb, fedVars);
        }
    }

    /**
     * Rewrites the hop DAGs of a basic statement block, then records the federation type of
     * every variable written via a transient write.
     *
     * @param sb      the basic statement block; its hop list may be null
     * @param fedVars variable name to federation type; read by transient reads, updated by
     *                transient writes
     */
    private void rRewriteBasicBlock(StatementBlock sb, Map<String, FTypes.FType> fedVars) {
        if (sb.getHops() == null) {
            return;
        }
        // One memo for all roots of this block so shared sub-DAGs are visited only once.
        Map<Long, FTypes.FType> fedHops = new HashMap<>();
        for (Hop root : sb.getHops()) {
            rRewriteHop(root, fedHops, fedVars);
        }
        for (Hop root : sb.getHops()) {
            if (HopRewriteUtils.isData(root, Types.OpOpData.TRANSIENTWRITE)) {
                // A null entry is stored on purpose: a variable reassigned to a
                // non-federated value must no longer be treated as federated.
                fedVars.put(root.getName(), fedHops.get(root.getInput(0).getHopID()));
            }
        }
    }

    /**
     * Post-order traversal of a hop DAG that forces federated execution where allowed and
     * memoizes each hop's federated output type.
     *
     * @param hop     the DAG root to process; a null root is ignored
     * @param memo    hop ID to federated output type (null = not federated); also serves as
     *                the visited set
     * @param fedVars variable name to federation type, consulted for transient reads
     */
    private void rRewriteHop(Hop hop, Map<Long, FTypes.FType> memo, Map<String, FTypes.FType> fedVars) {
        // Defensive: optional DAG roots (presumably e.g. a for-loop increment when none is
        // given) may be null; skip them instead of failing with a NullPointerException.
        if (hop == null || memo.containsKey(hop.getHopID())) {
            return;
        }
        // Inputs first, so allowsFederated/getFederatedOut can consult their memo entries.
        for (Hop input : hop.getInput()) {
            rRewriteHop(input, memo, fedVars);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED)) {
            memo.put(hop.getHopID(), deriveFType((DataOp) hop));
        }
        else if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)) {
            memo.put(hop.getHopID(), fedVars.get(hop.getName()));
        }
        else if (!allowsFederated(hop, memo)) {
            memo.put(hop.getHopID(), null);
        }
        else {
            hop.setForcedExecType(Types.ExecType.FED);
            FTypes.FType fedOut = getFederatedOut(hop, memo);
            memo.put(hop.getHopID(), fedOut);
            if (fedOut != null) {
                hop.setFederatedOutput(FEDInstruction.FederatedOutput.FOUT);
            }
        }
    }
}
