package org.apache.sysds.hops.fedplanner;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.cost.HopRel;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.IfStatement;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.utils.Explain;

/* loaded from: input_file:org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.class */
/**
 * Federated planner that walks every statement block of a DML program, records the
 * candidate {@link HopRel} alternatives of each visited hop in a memo table, and finally
 * applies the minimum-cost alternative (as reported by the memo table) to the hops,
 * starting from the collected terminal hops.
 *
 * <p>NOTE(review): the memo table, the set of finalized hop IDs, the terminal hop list and
 * the transient-write index are all {@code static}; nothing in this class clears them, so
 * they are shared by every planner instance and accumulate across calls. Confirm that this
 * is intended before reusing the planner for more than one program.
 */
public class FederatedPlannerCostbased extends AFederatedPlanner {
    private static final Log LOG = LogFactory.getLog(FederatedPlannerCostbased.class.getName());

    /** Candidate alternatives per hop, filled by {@link #visitFedPlanHop(Hop, Map)}. */
    private static final MemoTable hopRelMemo = new MemoTable();

    /** IDs of hops whose federated output has already been written by {@link #updateFederatedOutput(Hop, HopRel)}. */
    private static final Set<Long> hopRelUpdatedFinal = new HashSet<>();

    /** Hops for which {@code HopRewriteUtils.isTerminalHop} returned true during plan selection. */
    private static final List<Hop> terminalHops = new ArrayList<>();

    /** Transient-write hops seen so far, keyed by hop name; used to resolve transient reads. */
    private static final Map<String, Hop> transientWrites = new HashMap<>();

    /**
     * Returns the terminal hops collected so far.
     *
     * @return the internal (mutable, static) list of terminal hops; not a copy
     */
    public List<Hop> getTerminalHops() {
        return terminalHops;
    }

    /**
     * Updates the repetition estimates of the program, selects plan alternatives for all
     * statement blocks, applies the final federated outputs and publishes the memo table
     * for explain output when HOPS-level explain is enabled.
     *
     * @param dMLProgram program to rewrite
     * @param functionCallGraph not used by this implementation
     * @param functionCallSizeInfo not used by this implementation
     */
    @Override // org.apache.sysds.hops.fedplanner.AFederatedPlanner
    public void rewriteProgram(DMLProgram dMLProgram, FunctionCallGraph functionCallGraph, FunctionCallSizeInfo functionCallSizeInfo) {
        dMLProgram.updateRepetitionEstimates();
        rewriteStatementBlocks(dMLProgram, dMLProgram.getStatementBlocks(), null);
        setFinalFedouts();
        updateExplain();
    }

    /**
     * Rewrites each statement block in order and concatenates the results.
     *
     * @param dMLProgram enclosing program
     * @param list statement blocks to rewrite
     * @param map function parameter bindings (name to hop), may be {@code null}
     * @return the rewritten statement blocks in their original order
     */
    private ArrayList<StatementBlock> rewriteStatementBlocks(DMLProgram dMLProgram, List<StatementBlock> list, Map<String, Hop> map) {
        ArrayList<StatementBlock> rewritten = new ArrayList<>();
        for (StatementBlock block : list) {
            rewritten.addAll(rewriteStatementBlock(dMLProgram, block, map));
        }
        return rewritten;
    }

    /**
     * Dispatches to the rewrite that matches the concrete statement block type; every block
     * that is not a while/if/for/function block is handled as a default block.
     *
     * @param dMLProgram enclosing program
     * @param statementBlock block to rewrite
     * @param map function parameter bindings (name to hop), may be {@code null}
     * @return a list holding the rewritten block
     */
    public ArrayList<StatementBlock> rewriteStatementBlock(DMLProgram dMLProgram, StatementBlock statementBlock, Map<String, Hop> map) {
        if (statementBlock instanceof WhileStatementBlock) {
            return rewriteWhileStatementBlock(dMLProgram, (WhileStatementBlock) statementBlock, map);
        }
        if (statementBlock instanceof IfStatementBlock) {
            return rewriteIfStatementBlock(dMLProgram, (IfStatementBlock) statementBlock, map);
        }
        if (statementBlock instanceof ForStatementBlock) {
            return rewriteForStatementBlock(dMLProgram, (ForStatementBlock) statementBlock, map);
        }
        if (statementBlock instanceof FunctionStatementBlock) {
            return rewriteFunctionStatementBlock(dMLProgram, (FunctionStatementBlock) statementBlock, map);
        }
        return rewriteDefaultStatementBlock(dMLProgram, statementBlock, map);
    }

    /** Selects a plan for the loop predicate, then rewrites the body of every while statement. */
    private ArrayList<StatementBlock> rewriteWhileStatementBlock(DMLProgram dMLProgram, WhileStatementBlock whileStatementBlock, Map<String, Hop> map) {
        selectFederatedExecutionPlan(whileStatementBlock.getPredicateHops(), map);
        for (Statement statement : whileStatementBlock.getStatements()) {
            WhileStatement whileStatement = (WhileStatement) statement;
            whileStatement.setBody(rewriteStatementBlocks(dMLProgram, whileStatement.getBody(), map));
        }
        return new ArrayList<>(Collections.singletonList(whileStatementBlock));
    }

    /** Selects a plan for the branch predicate, then rewrites the if-body followed by the else-body. */
    private ArrayList<StatementBlock> rewriteIfStatementBlock(DMLProgram dMLProgram, IfStatementBlock ifStatementBlock, Map<String, Hop> map) {
        selectFederatedExecutionPlan(ifStatementBlock.getPredicateHops(), map);
        for (Statement statement : ifStatementBlock.getStatements()) {
            IfStatement ifStatement = (IfStatement) statement;
            ifStatement.setIfBody(rewriteStatementBlocks(dMLProgram, ifStatement.getIfBody(), map));
            ifStatement.setElseBody(rewriteStatementBlocks(dMLProgram, ifStatement.getElseBody(), map));
        }
        return new ArrayList<>(Collections.singletonList(ifStatementBlock));
    }

    /** Selects plans for the from/to/increment hops (in that order), then rewrites the loop bodies. */
    private ArrayList<StatementBlock> rewriteForStatementBlock(DMLProgram dMLProgram, ForStatementBlock forStatementBlock, Map<String, Hop> map) {
        selectFederatedExecutionPlan(forStatementBlock.getFromHops(), map);
        selectFederatedExecutionPlan(forStatementBlock.getToHops(), map);
        selectFederatedExecutionPlan(forStatementBlock.getIncrementHops(), map);
        for (Statement statement : forStatementBlock.getStatements()) {
            ForStatement forStatement = (ForStatement) statement;
            forStatement.setBody(rewriteStatementBlocks(dMLProgram, forStatement.getBody(), map));
        }
        return new ArrayList<>(Collections.singletonList(forStatementBlock));
    }

    /** Rewrites the body of every function statement contained in the function block. */
    private ArrayList<StatementBlock> rewriteFunctionStatementBlock(DMLProgram dMLProgram, FunctionStatementBlock functionStatementBlock, Map<String, Hop> map) {
        for (Statement statement : functionStatementBlock.getStatements()) {
            FunctionStatement functionStatement = (FunctionStatement) statement;
            functionStatement.setBody(rewriteStatementBlocks(dMLProgram, functionStatement.getBody(), map));
        }
        return new ArrayList<>(Collections.singletonList(functionStatementBlock));
    }

    /**
     * Selects a plan for every root hop of a basic block. For each root that is a
     * {@link FunctionOp}, the called function is looked up in the builtin function dictionary
     * and rewritten with the call's parameter bindings.
     *
     * @param dMLProgram enclosing program
     * @param statementBlock basic block to process
     * @param map function parameter bindings (name to hop), may be {@code null}
     * @return a list holding {@code statementBlock}
     */
    private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram dMLProgram, StatementBlock statementBlock, Map<String, Hop> map) {
        if (statementBlock.hasHops()) {
            for (Hop root : statementBlock.getHops()) {
                selectFederatedExecutionPlan(root, map);
                if (root instanceof FunctionOp) {
                    FunctionOp call = (FunctionOp) root;
                    String functionName = call.getFunctionName();
                    Map<String, Hop> callParams = getParamMap(call);
                    // Bindings inherited from the caller are copied on top of the call's own
                    // bindings, so an inherited entry wins when both define the same name.
                    if (map != null && callParams != null) {
                        callParams.putAll(map);
                    }
                    // The merged bindings stay in effect for the remaining roots of this block.
                    map = callParams;
                    // NOTE(review): only the builtin function dictionary is consulted here;
                    // what getFunction yields for other functions is not visible in this file.
                    rewriteStatementBlock(dMLProgram, dMLProgram.getBuiltinFunctionDictionary().getFunction(functionName), map);
                }
            }
        }
        return new ArrayList<>(Collections.singletonList(statementBlock));
    }

    /**
     * Maps the input variable names of a function call to the corresponding input hops.
     *
     * @param functionOp function call hop
     * @return a new mutable map; empty when the call has no input variable names
     */
    private Map<String, Hop> getParamMap(FunctionOp functionOp) {
        String[] inputVariableNames = functionOp.getInputVariableNames();
        Map<String, Hop> bindings = new HashMap<>();
        if (inputVariableNames != null) {
            // Indexes the name array by input position; assumes at least one name per
            // input hop -- TODO confirm against FunctionOp.
            for (int i = 0; i < functionOp.getInput().size(); i++) {
                bindings.put(inputVariableNames[i], functionOp.getInput(i));
            }
        }
        return bindings;
    }

    /** Applies the final federated output to every collected terminal hop. */
    private void setFinalFedouts() {
        for (Hop terminal : terminalHops) {
            setFinalFedout(terminal);
        }
    }

    /** Applies the minimum-cost alternative stored in the memo table to {@code hop}. */
    private void setFinalFedout(Hop hop) {
        setFinalFedout(hop, hopRelMemo.getMinCostAlternative(hop));
    }

    /**
     * Writes the chosen alternative to the hop. A hop seen for the first time is updated and
     * its input dependencies are processed recursively. A hop that was already finalized is
     * updated again without recursion; if the new alternative and the hop disagree on having
     * a local output and the memo table offers a federated-output alternative, that
     * alternative is applied instead and prefetch is activated on the hop.
     *
     * @param hop hop to update
     * @param hopRel alternative selected for {@code hop}
     */
    private void setFinalFedout(Hop hop, HopRel hopRel) {
        if (!hopRelUpdatedFinal.contains(hop.getHopID())) {
            updateFederatedOutput(hop, hopRel);
            visitInputDependency(hopRel);
            return;
        }
        boolean sameLocality = hopRel.hasLocalOutput() == hop.hasLocalOutput();
        if (sameLocality || !hopRelMemo.hasFederatedOutputAlternative(hop)) {
            updateFederatedOutput(hop, hopRel);
        } else {
            updateFederatedOutput(hop, hopRelMemo.getFederatedOutputAlternative(hop));
            hop.activatePrefetch();
        }
    }

    /** Finalizes every input dependency of {@code hopRel} with the alternative it refers to. */
    private void visitInputDependency(HopRel hopRel) {
        for (HopRel dependency : hopRel.getInputDependency()) {
            setFinalFedout(dependency.getHopRef(), dependency);
        }
    }

    /**
     * Copies federated output, cost object and exec type from the alternative to the hop,
     * applies any forced federated output, and marks the hop as finalized.
     *
     * @param hop hop to update
     * @param hopRel alternative whose settings are copied
     */
    private void updateFederatedOutput(Hop hop, HopRel hopRel) {
        hop.setFederatedOutput(hopRel.getFederatedOutput());
        hop.setFederatedCost(hopRel.getCostObject());
        hop.setForcedExecType(hopRel.getExecType());
        forceFixedFedOut(hop);
        if (LOG.isTraceEnabled()) {
            // Report the hop's opcode string (previously the logger object itself was appended).
            LOG.trace("Updated fedOut to " + hopRel.getFederatedOutput() + " for hop " + hop.getHopID() + " opcode: " + hop.getOpString());
        }
        hopRelUpdatedFinal.add(hop.getHopID());
    }

    /**
     * Overrides the hop's federated output with the entry registered in
     * {@code OptimizerUtils.FEDERATED_SPECS} for the hop's begin line, if there is one, and
     * deactivates prefetch when that entry reports a forced federated output.
     */
    private void forceFixedFedOut(Hop hop) {
        if (!OptimizerUtils.FEDERATED_SPECS.containsKey(hop.getBeginLine())) {
            return;
        }
        FEDInstruction.FederatedOutput forced = OptimizerUtils.FEDERATED_SPECS.get(hop.getBeginLine());
        hop.setFederatedOutput(forced);
        if (forced.isForcedFederated()) {
            hop.deactivatePrefetch();
        }
    }

    /** Selects a plan for each root hop in the given list. */
    private void selectFederatedExecutionPlan(ArrayList<Hop> arrayList, Map<String, Hop> map) {
        for (Hop root : arrayList) {
            selectFederatedExecutionPlan(root, map);
        }
    }

    /**
     * Populates the memo table for the DAG below {@code hop} and records the hop as terminal
     * when {@code HopRewriteUtils.isTerminalHop} says so. A {@code null} hop is ignored.
     */
    private void selectFederatedExecutionPlan(Hop hop, Map<String, Hop> map) {
        if (hop == null) {
            return;
        }
        visitFedPlanHop(hop, map);
        if (HopRewriteUtils.isTerminalHop(hop)) {
            terminalHops.add(hop);
        }
    }

    /**
     * Visits the inputs of {@code hop} first and then stores the hop's alternatives in the
     * memo table. Hops already present in the memo table are skipped. When no alternative
     * could be generated, a single NONE alternative is stored instead.
     */
    private void visitFedPlanHop(Hop hop, Map<String, Hop> map) {
        if (hopRelMemo.containsHop(hop)) {
            return;
        }
        debugLog(hop);
        for (Hop input : hop.getInput()) {
            visitFedPlanHop(input, map);
        }
        ArrayList<HopRel> fedPlans = getFedPlans(hop, map);
        if (fedPlans.isEmpty()) {
            fedPlans.add(getNONEHopRel(hop));
        }
        addTrace(fedPlans);
        hopRelMemo.put(hop, fedPlans);
    }

    /**
     * Builds the fallback alternative with federated output NONE; its FType is derived from
     * the FTypes of the alternative's input dependencies.
     */
    private HopRel getNONEHopRel(Hop hop) {
        HopRel noneRel = new HopRel(hop, FEDInstruction.FederatedOutput.NONE, hopRelMemo);
        FTypes.FType[] inputFTypes = noneRel.getInputDependency().stream()
            .map(HopRel::getFType)
            .toArray(FTypes.FType[]::new);
        noneRel.setFType(getFederatedOut(hop, inputFTypes));
        return noneRel;
    }

    /**
     * Generates the alternatives for a hop. Transient reads use the hop resolved by
     * {@link #getTransientInputs(Hop, Map)} as their input; transient writes are indexed by
     * name; federated data ops get exactly one FOUT alternative; all other hops get the
     * alternatives from {@link #generateHopRels(Hop, List)}. An LOUT alternative is appended
     * when {@link #isLOUTSupported(Hop)} allows it.
     *
     * @param hop hop to generate alternatives for
     * @param map function parameter bindings (name to hop), may be {@code null}
     * @return the generated alternatives, possibly empty
     */
    private ArrayList<HopRel> getFedPlans(Hop hop, Map<String, Hop> map) {
        ArrayList<HopRel> plans = new ArrayList<>();
        ArrayList<Hop> input = hop.getInput();
        if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)) {
            input = getTransientInputs(hop, map);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE)) {
            transientWrites.put(hop.getName(), hop);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.FEDERATED)) {
            plans.add(new HopRel(hop, FEDInstruction.FederatedOutput.FOUT, deriveFType((DataOp) hop), hopRelMemo, input));
        } else {
            plans.addAll(generateHopRels(hop, input));
        }
        if (isLOUTSupported(hop)) {
            plans.add(new HopRel(hop, FEDInstruction.FederatedOutput.LOUT, hopRelMemo, input));
        }
        return plans;
    }

    /**
     * Resolves the hop that feeds a transient read: the function parameter bindings are
     * consulted first, then the transient writes recorded so far.
     *
     * @param hop transient read hop
     * @param map function parameter bindings (name to hop), may be {@code null}
     * @return a single-element list holding the resolved hop
     * @throws DMLRuntimeException if no hop is registered under the read's name
     */
    private ArrayList<Hop> getTransientInputs(Hop hop, Map<String, Hop> map) {
        Hop source = (map != null) ? map.get(hop.getName()) : null;
        if (source == null) {
            source = transientWrites.get(hop.getName());
        }
        if (source == null) {
            throw new DMLRuntimeException("Transient write not found for " + hop);
        }
        return new ArrayList<>(Collections.singletonList(source));
    }

    /**
     * Generates FOUT alternatives for every combination of input FTypes that the hop accepts
     * and that yields a federated output type. Per output FType only one alternative is
     * kept: the stored one if its cost is strictly lower, otherwise the newer one.
     *
     * @param hop hop to generate alternatives for
     * @param list input hops of {@code hop}
     * @return the retained alternatives, one per output FType
     */
    private Collection<HopRel> generateHopRels(Hop hop, List<Hop> list) {
        List<List<FTypes.FType>> allCombinations = getAllCombinations(getValidFTypes(list));
        Map<FTypes.FType, HopRel> bestByOutput = new HashMap<>();
        for (List<FTypes.FType> combination : allCombinations) {
            if (!allowsFederated(hop, combination.toArray(FTypes.FType[]::new))) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Does not allow federated: " + hop + " input FTypes: " + combination);
                }
                continue;
            }
            FTypes.FType federatedOut = getFederatedOut(hop, combination.toArray(new FTypes.FType[0]));
            if (federatedOut == null) {
                if (LOG.isTraceEnabled()) {
                    LOG.trace("Allows federated, but FOUT is not allowed: " + hop + " input FTypes: " + combination);
                }
                continue;
            }
            HopRel candidate = new HopRel(hop, FEDInstruction.FederatedOutput.FOUT, federatedOut, hopRelMemo, list, combination);
            if (bestByOutput.containsKey(candidate.getFType())) {
                bestByOutput.computeIfPresent(candidate.getFType(),
                    (fType, incumbent) -> incumbent.getCost() < candidate.getCost() ? incumbent : candidate);
            } else {
                bestByOutput.put(federatedOut, candidate);
            }
        }
        return bestByOutput.values();
    }

    /** Collects, per input hop and in input order, the FTypes the memo table holds for it. */
    private List<List<FTypes.FType>> getValidFTypes(List<Hop> list) {
        List<List<FTypes.FType>> validFTypes = new ArrayList<>();
        for (Hop input : list) {
            validFTypes.add(hopRelMemo.getFTypes(input));
        }
        return validFTypes;
    }

    /**
     * Computes the cartesian product of the given option lists.
     *
     * @param list one list of options per position
     * @return every combination that picks one option per position; a single empty
     *         combination when {@code list} is empty, and no combination at all when any
     *         position has no options
     */
    public List<List<FTypes.FType>> getAllCombinations(List<List<FTypes.FType>> list) {
        List<List<FTypes.FType>> combinations = new ArrayList<>();
        buildCombinations(list, combinations, 0, new ArrayList<>());
        return combinations;
    }

    /**
     * Recursive helper of {@link #getAllCombinations(List)}.
     *
     * @param list one list of options per position
     * @param list2 accumulator receiving every completed combination
     * @param i position to fill next
     * @param list3 options chosen for the positions before {@code i}; not modified
     */
    public void buildCombinations(List<List<FTypes.FType>> list, List<List<FTypes.FType>> list2, int i, List<FTypes.FType> list3) {
        if (i == list.size()) {
            list2.add(list3);
            return;
        }
        for (FTypes.FType option : list.get(i)) {
            // Copy the prefix so that sibling branches do not share one mutable list.
            List<FTypes.FType> extended = new ArrayList<>(list3);
            extended.add(option);
            buildCombinations(list, list2, i + 1, extended);
        }
    }

    /** Hands the memo table to {@link Explain} when HOPS-level explain output is selected. */
    private void updateExplain() {
        if (DMLScript.EXPLAIN == Explain.ExplainType.HOPS) {
            Explain.setMemo(hopRelMemo);
        }
    }

    /** Logs the visited hop and each of its inputs (flagging null inputs) at debug level. */
    private void debugLog(Hop hop) {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        LOG.debug("Visiting HOP: " + hop + " Input size: " + hop.getInput().size());
        int index = 0;
        for (Hop input : hop.getInput()) {
            if (input == null) {
                LOG.debug("Input at index is null: " + index);
            } else {
                LOG.debug("HOP input: " + input + " at index " + index + " of " + hop);
            }
            index++;
        }
    }

    /** Logs every alternative about to be stored in the memo table at trace level. */
    private void addTrace(ArrayList<HopRel> arrayList) {
        if (!LOG.isTraceEnabled()) {
            return;
        }
        for (HopRel plan : arrayList) {
            LOG.trace("Adding to memo: " + plan);
        }
    }

    /** Returns true when the hop carries no privacy object or one without constraints. */
    private boolean isLOUTSupported(Hop hop) {
        return hop.getPrivacy() == null || !hop.getPrivacy().hasConstraints();
    }
}
