package org.apache.sysds.hops.cost;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.sysds.api.DMLException;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.hops.fedplanner.MemoTable;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;

/* loaded from: input_file:org/apache/sysds/hops/cost/HopRel.class */
/**
 * Pairs a {@link Hop} with a chosen {@link FEDInstruction.FederatedOutput} setting, the
 * {@link HopRel}s selected for each of its inputs, and the {@link FederatedCost} estimated for
 * that combination.
 *
 * <p>The public constructors resolve the input dependencies, compute the cost through
 * {@link FederatedCostEstimator#costEstimate} and derive the execution type before returning.
 */
public class HopRel {
    /** The hop this relation describes. */
    protected final Hop hopRef;
    /** Output setting chosen for {@link #hopRef}. */
    protected final FEDInstruction.FederatedOutput fedOut;
    /** Set to {@code FED} by {@link #setExecType()} when required; otherwise left {@code null}. */
    protected Types.ExecType execType;
    /** FType of this relation; may be {@code null} (the short constructors pass {@code null}). */
    protected FTypes.FType fType;
    /** Estimated cost; {@code null} only while a constructor is still running. */
    protected FederatedCost cost;
    /** Pointer ids registered through {@link #addCostPointer(long)}. */
    protected final Set<Long> costPointerSet;
    /** Input hops used to build {@link #inputDependency}; may be {@code null}. */
    protected List<Hop> inputHops;
    /** Selected HopRel per entry of {@link #inputHops}, in the same order. */
    protected List<HopRel> inputDependency;

    /**
     * Creates a relation for {@code hop} using the hop's own inputs and no FType.
     *
     * @param hop             hop to describe
     * @param federatedOutput output setting for the hop
     * @param memoTable       memo table used to look up input alternatives
     * @throws DMLException if an input alternative cannot be resolved
     */
    public HopRel(Hop hop, FEDInstruction.FederatedOutput federatedOutput, MemoTable memoTable) {
        this(hop, federatedOutput, null, memoTable, hop.getInput());
    }

    /**
     * Creates a relation for {@code hop} using an explicit input list and no FType.
     *
     * @param hop             hop to describe
     * @param federatedOutput output setting for the hop
     * @param memoTable       memo table used to look up input alternatives
     * @param inputs          inputs to resolve instead of {@code hop.getInput()}
     * @throws DMLException if an input alternative cannot be resolved
     */
    public HopRel(Hop hop, FEDInstruction.FederatedOutput federatedOutput, MemoTable memoTable, ArrayList<Hop> inputs) {
        this(hop, federatedOutput, null, memoTable, inputs);
    }

    /** Initializes the fields shared by all constructors; does not resolve inputs or cost. */
    private HopRel(Hop hop, FEDInstruction.FederatedOutput federatedOutput, FTypes.FType fType, List<Hop> inputs) {
        this.costPointerSet = new HashSet<>();
        this.inputDependency = new ArrayList<>();
        this.hopRef = hop;
        this.fedOut = federatedOutput;
        this.fType = fType;
        this.inputHops = inputs;
    }

    /**
     * Creates a relation whose input alternatives are chosen from the memo table by cost
     * (see {@link #setInputDependency(MemoTable)}).
     *
     * @param hop             hop to describe
     * @param federatedOutput output setting for the hop
     * @param fType           FType of this relation; may be {@code null}
     * @param memoTable       memo table used to look up input alternatives
     * @param inputs          inputs to resolve
     * @throws DMLException if an input alternative cannot be resolved
     */
    public HopRel(Hop hop, FEDInstruction.FederatedOutput federatedOutput, FTypes.FType fType, MemoTable memoTable, ArrayList<Hop> inputs) {
        this(hop, federatedOutput, fType, inputs);
        setInputDependency(memoTable);
        this.cost = FederatedCostEstimator.costEstimate(this, memoTable);
        setExecType();
    }

    /**
     * Creates a relation whose input alternatives are looked up by an explicit FType per input.
     *
     * @param hop             hop to describe
     * @param federatedOutput output setting for the hop
     * @param fType           FType of this relation
     * @param memoTable       memo table used to look up input alternatives
     * @param inputs          inputs to resolve
     * @param inputFTypes     FType to request for each input, positionally aligned with {@code inputs}
     * @throws DMLException if fewer FTypes than inputs are given or an alternative is missing
     */
    public HopRel(Hop hop, FEDInstruction.FederatedOutput federatedOutput, FTypes.FType fType, MemoTable memoTable, List<Hop> inputs, List<FTypes.FType> inputFTypes) {
        this(hop, federatedOutput, fType, inputs);
        setInputFTypeDependency(inputs, inputFTypes, memoTable);
        this.cost = FederatedCostEstimator.costEstimate(this, memoTable);
        setExecType();
    }

    /** Resolves each input via {@code memoTable.getHopRel(input, fType)} using the paired FType. */
    private void setInputFTypeDependency(List<Hop> inputs, List<FTypes.FType> inputFTypes, MemoTable memoTable) {
        // Fail with a descriptive error instead of an IndexOutOfBoundsException from get(i).
        if (inputFTypes.size() < inputs.size()) {
            throw new DMLException("HopRel for hop " + this.hopRef.getHopID() + " received " + inputs.size()
                + " inputs but only " + inputFTypes.size() + " input FTypes");
        }
        for (int idx = 0; idx < inputs.size(); idx++) {
            this.inputDependency.add(memoTable.getHopRel(inputs.get(idx), inputFTypes.get(idx)));
        }
        validateInputDependency();
    }

    /** Marks this relation as FED if any input has FOUT or the hop is a FEDERATED data op. */
    private void setExecType() {
        boolean federatedInput = this.inputDependency.stream().anyMatch(HopRel::hasFederatedOutput);
        if (federatedInput || HopRewriteUtils.isData(this.hopRef, Types.OpOpData.FEDERATED)) {
            this.execType = Types.ExecType.FED;
        }
    }

    /** Registers {@code costPointer} in the cost-pointer set. */
    public void addCostPointer(long costPointer) {
        this.costPointerSet.add(costPointer);
    }

    /**
     * Checks whether any cost pointer other than {@code costPointer} has been registered.
     *
     * @param costPointer pointer id to exclude from the check
     * @return {@code true} if the set contains at least one id different from {@code costPointer}
     */
    public boolean existingCostPointer(long costPointer) {
        int others = this.costPointerSet.contains(costPointer)
            ? this.costPointerSet.size() - 1
            : this.costPointerSet.size();
        return others > 0;
    }

    /** @return {@code true} if the output setting is {@code LOUT} */
    public boolean hasLocalOutput() {
        return this.fedOut == FEDInstruction.FederatedOutput.LOUT;
    }

    /** @return {@code true} if the output setting is {@code FOUT} */
    public boolean hasFederatedOutput() {
        return this.fedOut == FEDInstruction.FederatedOutput.FOUT;
    }

    /** @return the output setting of this relation */
    public FEDInstruction.FederatedOutput getFederatedOutput() {
        return this.fedOut;
    }

    /** @return the live (mutable) list of selected input relations */
    public List<HopRel> getInputDependency() {
        return this.inputDependency;
    }

    /** @return the hop this relation describes */
    public Hop getHopRef() {
        return this.hopRef;
    }

    /** @return the FType of this relation; may be {@code null} */
    public FTypes.FType getFType() {
        return this.fType;
    }

    /** Replaces the FType of this relation. */
    public void setFType(FTypes.FType fType) {
        this.fType = fType;
    }

    /** @return {@code FED} if federated execution was derived, otherwise {@code null} */
    public Types.ExecType getExecType() {
        return this.execType;
    }

    /** Looks up the FOUT alternative of {@code hop}; callers here handle a {@code null} result. */
    private HopRel getFOUTHopRel(Hop hop, MemoTable memoTable) {
        return memoTable.getFederatedOutputAlternativeOrNull(hop);
    }

    /**
     * Selects one HopRel per input hop.
     *
     * <p>Unless this relation is FOUT on a non-federated-data hop, every input gets its
     * min-cost alternative. Otherwise the input with the cheapest FOUT alternative keeps that
     * FOUT alternative and all other inputs get their min-cost alternative.
     *
     * @throws DMLException if a selected input relation is {@code null}
     */
    private void setInputDependency(MemoTable memoTable) {
        if (this.inputHops != null && !this.inputHops.isEmpty()) {
            if (this.fedOut != FEDInstruction.FederatedOutput.FOUT || this.hopRef.isFederatedDataOp()) {
                this.inputDependency.addAll(this.inputHops.stream()
                    .map(memoTable::getMinCostAlternative)
                    .collect(Collectors.toList()));
            } else {
                addInputsWithCheapestFOUT(memoTable);
            }
        }
        validateInputDependency();
    }

    /**
     * Adds input relations where exactly one slot holds the cheapest FOUT alternative.
     * If no input has an FOUT alternative, the last slot ends up {@code null}, which
     * {@link #validateInputDependency()} then rejects.
     */
    private void addInputsWithCheapestFOUT(MemoTable memoTable) {
        int foutIndex = 0;
        HopRel cheapestFout = getFOUTHopRel(this.inputHops.get(0), memoTable);
        for (int idx = 1; idx < this.inputHops.size(); idx++) {
            HopRel candidate = getFOUTHopRel(this.inputHops.get(idx), memoTable);
            // A missing best is always replaced (even by null); otherwise only by a strictly cheaper one.
            if (cheapestFout == null || (candidate != null && candidate.getCost() < cheapestFout.getCost())) {
                cheapestFout = candidate;
                foutIndex = idx;
            }
        }
        for (int idx = 0; idx < this.inputHops.size(); idx++) {
            this.inputDependency.add(idx == foutIndex
                ? cheapestFout
                : memoTable.getMinCostAlternative(this.inputHops.get(idx)));
        }
    }

    /**
     * Ensures no selected input relation is {@code null}.
     *
     * @throws DMLException naming the first missing input
     */
    private void validateInputDependency() {
        for (int idx = 0; idx < this.inputDependency.size(); idx++) {
            if (this.inputDependency.get(idx) == null) {
                // inputDependency is built positionally from inputHops, so inputHops is non-null here
                // and names the actual input (which may differ from hopRef.getInput()).
                throw new DMLException("HopRel input number " + idx + " (" + this.inputHops.get(idx)
                    + ") is null for root: \n" + this);
            }
        }
    }

    /** @return the total estimated cost; only valid once construction has finished */
    public double getCost() {
        return this.cost.getTotal();
    }

    /** @return the cost object, or {@code null} while a constructor is still running */
    public FederatedCost getCostObject() {
        return this.cost;
    }

    /**
     * Describes this relation and its inputs as {@code {hopId, federatedOutput}} pairs.
     * Safe to call during construction (null cost and null inputs are rendered as "null"),
     * which {@link #validateInputDependency()} relies on when building its error message.
     */
    @Override
    public String toString() {
        String costText = this.cost == null ? "null" : String.valueOf(this.cost.getTotal());
        List<String> inputs = this.inputDependency.stream()
            .map(in -> in == null
                ? "null"
                : "{" + in.getHopRef().getHopID() + ", " + in.getFederatedOutput() + "}")
            .collect(Collectors.toList());
        return getClass().getSimpleName() + " {HopID: " + this.hopRef.getHopID()
            + ", Opcode: " + this.hopRef.getOpString()
            + ", FedOut: " + this.fedOut
            + ", Cost: " + costText
            + ", Inputs: " + inputs + "}";
    }
}
