package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.ReorgOp;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.cost.HopRel;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.ipa.IPAPass;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/* loaded from: input_file:org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.class */
/**
 * Inter-procedural analysis pass that selects a federated execution plan for the hop DAGs of a
 * {@link DMLProgram}.
 *
 * <p>For every DAG root the pass works in two phases:
 * <ol>
 *   <li>{@code visitFedPlanHop} enumerates, bottom-up, the supported
 *       {@link FEDInstruction.FederatedOutput} alternatives of each hop as {@link HopRel} objects
 *       and memoizes them by hop ID. A hop without any supported federated output gets a single
 *       {@code NONE} alternative.</li>
 *   <li>{@code setFinalFedout} picks the root alternative with the lowest cost and walks its
 *       input dependencies top-down. It writes the chosen federated output and cost into every
 *       hop it reaches.</li>
 * </ol>
 *
 * <p>The pass only runs when {@link OptimizerUtils#FEDERATED_COMPILATION} is set. It starts from
 * {@link DMLProgram#getStatementBlocks()}. NOTE(review): function bodies are only visited if
 * their {@link FunctionStatementBlock}s appear in that list; confirm this is intended.
 *
 * <p>Instances are not thread-safe.
 */
public class IPAPassRewriteFederatedPlan extends IPAPass {
    /**
     * Memo of enumerated plan alternatives, keyed by hop ID.
     *
     * <p>This field is deliberately per instance rather than static, so unrelated pass instances
     * never share this unsynchronized map. It is cleared around every {@link #rewriteProgram}
     * call for two reasons. First, alternatives computed in an earlier invocation, possibly for a
     * since-rewritten DAG, must not be reused. Second, old hop DAGs should not stay reachable
     * after the pass finishes.
     */
    private final Map<Long, List<HopRel>> hopRelMemo = new HashMap<>();

    /**
     * Reports whether this pass should run.
     *
     * @param functionCallGraph unused
     * @return the value of {@link OptimizerUtils#FEDERATED_COMPILATION}
     */
    @Override // org.apache.sysds.hops.ipa.IPAPass
    public boolean isApplicable(FunctionCallGraph functionCallGraph) {
        return OptimizerUtils.FEDERATED_COMPILATION;
    }

    /**
     * Selects federated execution plans for all statement blocks of the given program.
     *
     * <p>The hops are annotated in place with their chosen federated output and cost.
     *
     * @param dMLProgram program whose statement blocks are processed
     * @param functionCallGraph unused by this pass
     * @param functionCallSizeInfo unused by this pass
     * @return always {@code false}
     */
    @Override // org.apache.sysds.hops.ipa.IPAPass
    public boolean rewriteProgram(DMLProgram dMLProgram, FunctionCallGraph functionCallGraph, FunctionCallSizeInfo functionCallSizeInfo) {
        // Start from an empty memo. Entries left by a previous invocation may describe hops whose
        // inputs have since been rewritten, and visitFedPlanHop would otherwise reuse them.
        hopRelMemo.clear();
        try {
            rewriteStatementBlocks(dMLProgram.getStatementBlocks());
        } finally {
            // The plan has been written back into the hops; drop the references to this DAG.
            hopRelMemo.clear();
        }
        return false;
    }

    /**
     * Selects federated plans for each block in {@code list}, recursing into nested bodies.
     *
     * @param list statement blocks to process; {@code null} is treated as an empty list
     * @return a new list holding the processed blocks in their original order
     */
    public ArrayList<StatementBlock> rewriteStatementBlocks(List<StatementBlock> list) {
        ArrayList<StatementBlock> rewritten = new ArrayList<>();
        if (list == null) {
            return rewritten;
        }
        for (StatementBlock sb : list) {
            rewritten.addAll(rewriteStatementBlock(sb));
        }
        return rewritten;
    }

    /**
     * Selects federated plans for one statement block.
     *
     * <p>Control-flow blocks are dispatched to their specific handler. Any other block has its
     * own hop DAG roots planned. Blocks are never replaced or split; the result is always a
     * mutable singleton list containing {@code statementBlock}.
     *
     * @param statementBlock block to process
     * @return a list containing just {@code statementBlock}
     */
    public ArrayList<StatementBlock> rewriteStatementBlock(StatementBlock statementBlock) {
        if (statementBlock instanceof WhileStatementBlock) {
            return rewriteWhileStatementBlock((WhileStatementBlock) statementBlock);
        }
        if (statementBlock instanceof IfStatementBlock) {
            return rewriteIfStatementBlock((IfStatementBlock) statementBlock);
        }
        if (statementBlock instanceof ForStatementBlock) {
            return rewriteForStatementBlock((ForStatementBlock) statementBlock);
        }
        if (statementBlock instanceof FunctionStatementBlock) {
            return rewriteFunctionStatementBlock((FunctionStatementBlock) statementBlock);
        }
        selectFederatedExecutionPlan(statementBlock.getHops());
        return new ArrayList<>(Collections.singletonList(statementBlock));
    }

    /** Plans the loop predicate, then recurses into every while-body. */
    private ArrayList<StatementBlock> rewriteWhileStatementBlock(WhileStatementBlock whileStatementBlock) {
        selectFederatedExecutionPlan(whileStatementBlock.getPredicateHops());
        for (Statement stmt : whileStatementBlock.getStatements()) {
            WhileStatement whileStmt = (WhileStatement) stmt;
            whileStmt.setBody(rewriteStatementBlocks(whileStmt.getBody()));
        }
        return new ArrayList<>(Collections.singletonList(whileStatementBlock));
    }

    /** Plans the if-predicate, then recurses into both the if-body and the else-body. */
    private ArrayList<StatementBlock> rewriteIfStatementBlock(IfStatementBlock ifStatementBlock) {
        selectFederatedExecutionPlan(ifStatementBlock.getPredicateHops());
        for (Statement stmt : ifStatementBlock.getStatements()) {
            IfStatement ifStmt = (IfStatement) stmt;
            ifStmt.setIfBody(rewriteStatementBlocks(ifStmt.getIfBody()));
            ifStmt.setElseBody(rewriteStatementBlocks(ifStmt.getElseBody()));
        }
        return new ArrayList<>(Collections.singletonList(ifStatementBlock));
    }

    /**
     * Plans the from/to/increment predicates, then recurses into every for-body.
     *
     * <p>Any of the predicate hops may be {@code null}. For example, a loop without an explicit
     * increment may have no increment hops. A null predicate is skipped.
     */
    private ArrayList<StatementBlock> rewriteForStatementBlock(ForStatementBlock forStatementBlock) {
        selectFederatedExecutionPlan(forStatementBlock.getFromHops());
        selectFederatedExecutionPlan(forStatementBlock.getToHops());
        selectFederatedExecutionPlan(forStatementBlock.getIncrementHops());
        for (Statement stmt : forStatementBlock.getStatements()) {
            ForStatement forStmt = (ForStatement) stmt;
            forStmt.setBody(rewriteStatementBlocks(forStmt.getBody()));
        }
        return new ArrayList<>(Collections.singletonList(forStatementBlock));
    }

    /** Recurses into the body of every function statement in the block. */
    private ArrayList<StatementBlock> rewriteFunctionStatementBlock(FunctionStatementBlock functionStatementBlock) {
        for (Statement stmt : functionStatementBlock.getStatements()) {
            FunctionStatement funcStmt = (FunctionStatement) stmt;
            funcStmt.setBody(rewriteStatementBlocks(funcStmt.getBody()));
        }
        return new ArrayList<>(Collections.singletonList(functionStatementBlock));
    }

    /**
     * Applies the cheapest memoized alternative of {@code hop}, then propagates it downward.
     *
     * @param hop a DAG root that has already been visited by {@code visitFedPlanHop}
     * @throws DMLException if the memo holds no alternative for {@code hop}
     */
    private void setFinalFedout(Hop hop) {
        HopRel cheapest = hopRelMemo.get(hop.getHopID()).stream()
            .min(Comparator.comparingDouble(HopRel::getCost))
            .orElseThrow(() -> new DMLException("Hop root " + hop + " has no feasible federated output alternatives"));
        setFinalFedout(hop, cheapest);
    }

    /**
     * Writes {@code hopRel}'s decision into {@code hop} and recurses into its input dependencies.
     *
     * <p>NOTE(review): a hop shared by several parents or roots is re-walked once per path, and
     * the last visit wins. A visited set would bound the walk but could change which alternative
     * is kept, so confirm before changing this.
     */
    private void setFinalFedout(Hop hop, HopRel hopRel) {
        updateFederatedOutput(hop, hopRel);
        visitInputDependency(hopRel);
    }

    /** Recursively applies each input dependency recorded in {@code hopRel}. */
    private void visitInputDependency(HopRel hopRel) {
        for (HopRel input : hopRel.getInputDependency()) {
            setFinalFedout(input.getHopRef(), input);
        }
    }

    /** Copies the federated output choice and its cost object from {@code hopRel} into {@code hop}. */
    private void updateFederatedOutput(Hop hop, HopRel hopRel) {
        hop.setFederatedOutput(hopRel.getFederatedOutput());
        hop.setFederatedCost(hopRel.getCostObject());
    }

    /** Plans every root in {@code roots}; a {@code null} list, e.g. a block without hops, is ignored. */
    private void selectFederatedExecutionPlan(ArrayList<Hop> roots) {
        if (roots == null) {
            return;
        }
        for (Hop root : roots) {
            selectFederatedExecutionPlan(root);
        }
    }

    /** Enumerates alternatives for the DAG under {@code root} and applies the cheapest plan; ignores {@code null}. */
    private void selectFederatedExecutionPlan(Hop root) {
        if (root == null) {
            return;
        }
        visitFedPlanHop(root);
        setFinalFedout(root);
    }

    /**
     * Memoizes the plan alternatives of {@code hop}, after first visiting all of its inputs.
     *
     * <p>Inputs are visited first so that {@code isFOUTSupported} and the {@link HopRel}
     * constructor can look up the inputs' alternatives in the memo. If {@code hop} is not a
     * supported federated operator, or supports no federated output, it is memoized with a
     * single {@code NONE} alternative.
     */
    private void visitFedPlanHop(Hop hop) {
        if (hopRelMemo.containsKey(hop.getHopID())) {
            return;
        }
        List<Hop> inputs = hop.getInput();
        if (inputs != null) {
            for (Hop input : inputs) {
                visitFedPlanHop(input);
            }
        }
        List<HopRel> alternatives = new ArrayList<>();
        if (isFedInstSupportedHop(hop)) {
            for (FEDInstruction.FederatedOutput fedOut : FEDInstruction.FederatedOutput.values()) {
                if (isFedOutSupported(hop, fedOut)) {
                    alternatives.add(new HopRel(hop, fedOut, hopRelMemo));
                }
            }
        }
        if (alternatives.isEmpty()) {
            alternatives.add(new HopRel(hop, FEDInstruction.FederatedOutput.NONE, hopRelMemo));
        }
        hopRelMemo.put(hop.getHopID(), alternatives);
    }

    /** Returns whether {@code hop} is one of the operator types considered for federated instructions. */
    private boolean isFedInstSupportedHop(Hop hop) {
        return (hop instanceof AggBinaryOp) || (hop instanceof BinaryOp) || (hop instanceof ReorgOp) || (hop instanceof AggUnaryOp) || (hop instanceof TernaryOp) || (hop instanceof DataOp);
    }

    /**
     * Returns whether {@code hop} may produce the given federated output.
     *
     * <p>{@code NONE} is rejected here because {@code visitFedPlanHop} adds it only as the fallback.
     * Any constant not listed below is accepted.
     */
    private boolean isFedOutSupported(Hop hop, FEDInstruction.FederatedOutput federatedOutput) {
        switch (federatedOutput) {
            case FOUT:
                return isFOUTSupported(hop);
            case LOUT:
                return isLOUTSupported(hop);
            case NONE:
                return false;
            default:
                return true;
        }
    }

    /**
     * Returns whether {@code hop} may keep its output federated.
     *
     * <p>A scalar aggregate cannot be federated. Otherwise the hop qualifies if at least one input
     * has a federated-output alternative in the memo, or if it is itself a federated data op.
     */
    private boolean isFOUTSupported(Hop hop) {
        if ((hop instanceof AggUnaryOp) && hop.isScalar()) {
            return false;
        }
        List<Hop> inputs = hop.getInput();
        boolean anyFederatedInput = inputs != null && inputs.stream()
            .anyMatch(in -> hopRelMemo.get(in.getHopID()).stream().anyMatch(HopRel::hasFederatedOutput));
        return anyFederatedInput || hop.isFederatedDataOp();
    }

    /** Returns whether {@code hop} may produce local output: only if it carries no privacy constraints. */
    private boolean isLOUTSupported(Hop hop) {
        return hop.getPrivacy() == null || !hop.getPrivacy().hasConstraints();
    }
}
