package org.apache.sysds.lops;

import java.util.HashMap;
import java.util.Map;
import org.apache.sysds.common.Types;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.LopProperties;
import org.apache.sysds.parser.Statement;

/* loaded from: input_file:org/apache/sysds/lops/GroupedAggregate.class */
/**
 * Lop emitting the {@code groupedagg} instruction.
 *
 * <p>Inputs are passed as a map from parameter name to the lop producing that parameter.
 * The {@code "target"} and {@code Statement.GAGG_GROUPS} entries are mandatory and are
 * registered as the first two inputs. Every other entry (for example weights or the
 * aggregation function) is registered afterwards, in the map's iteration order.
 */
public class GroupedAggregate extends Lop {
    /** Parameter name of the primary input; emitted literally into the instruction. */
    private static final String TARGET = "target";
    /** Separator between instruction operands. */
    private static final String DELIM = "°";
    /** Separator between a parameter name and its value within one operand. */
    private static final String PARAM_DELIM = "=";
    private static final String OPCODE = "groupedagg";
    public static final String COMBINEDINPUT = "combinedinput";

    /** Parameter map as passed by the caller (kept by reference, not copied). */
    private HashMap<String, Lop> _inputParams;
    /** Emitted as the {@code broadcast} parameter for SPARK execution only. */
    private boolean _broadcastGroups = false;
    /** Emitted under {@code Statement.PS_PARALLELISM} for CP execution only. */
    private int _numThreads = 1;

    /**
     * Creates the lop with default settings (no group broadcast, one thread).
     *
     * @param inputParameterLops parameter name to input lop; must contain "target" and groups
     * @param dataType output data type
     * @param valueType output value type
     * @param execType execution type
     * @throws LopsException if the "target" or groups input is missing
     */
    public GroupedAggregate(HashMap<String, Lop> inputParameterLops, Types.DataType dataType,
            Types.ValueType valueType, LopProperties.ExecType execType) {
        super(Lop.Type.GroupedAgg, dataType, valueType);
        init(inputParameterLops, dataType, valueType, execType);
    }

    /**
     * Creates the lop and sets the broadcast flag emitted for SPARK execution.
     *
     * @param broadcastGroups value emitted as the {@code broadcast} parameter
     * @throws LopsException if the "target" or groups input is missing
     */
    public GroupedAggregate(HashMap<String, Lop> inputParameterLops, Types.DataType dataType,
            Types.ValueType valueType, LopProperties.ExecType execType, boolean broadcastGroups) {
        this(inputParameterLops, dataType, valueType, execType);
        _broadcastGroups = broadcastGroups;
    }

    /**
     * Creates the lop and sets the thread count emitted for CP execution.
     *
     * @param numThreads value emitted under {@code Statement.PS_PARALLELISM}
     * @throws LopsException if the "target" or groups input is missing
     */
    public GroupedAggregate(HashMap<String, Lop> inputParameterLops, Types.DataType dataType,
            Types.ValueType valueType, LopProperties.ExecType execType, int numThreads) {
        this(inputParameterLops, dataType, valueType, execType);
        _numThreads = numThreads;
    }

    /**
     * Wires every parameter lop as an input of this lop (and this lop as their output),
     * then derives the lop properties from the inputs and the execution type.
     */
    private void init(HashMap<String, Lop> inputParameterLops, Types.DataType dataType,
            Types.ValueType valueType, LopProperties.ExecType execType) {
        // Validate before mutating anything, so a bad map never leaves a half-linked graph.
        Lop target = requireInput(inputParameterLops, TARGET);
        Lop groups = requireInput(inputParameterLops, Statement.GAGG_GROUPS);

        // target and groups always occupy the first two input slots.
        addInput(target);
        target.addOutput(this);
        addInput(groups);
        groups.addOutput(this);

        for (Map.Entry<String, Lop> param : inputParameterLops.entrySet()) {
            String name = param.getKey();
            if (!name.equalsIgnoreCase(TARGET) && !name.equalsIgnoreCase(Statement.GAGG_GROUPS)) {
                Lop lop = param.getValue();
                addInput(lop);
                lop.addOutput(this);
            }
        }
        _inputParams = inputParameterLops;
        lps.setProperties(inputs, execType);
    }

    /**
     * Returns the lop registered under {@code name}.
     *
     * @throws LopsException if the map is null or has no non-null entry for {@code name}
     */
    private static Lop requireInput(Map<String, Lop> params, String name) {
        Lop lop = (params == null) ? null : params.get(name);
        if (lop == null) {
            throw new LopsException(
                "Invalid parameters to groupedAggregate -- \"" + name + "\" must be provided");
        }
        return lop;
    }

    @Override // org.apache.sysds.lops.Lop
    public String toString() {
        return "Operation = GroupedAggregate";
    }

    /**
     * Builds the instruction string.
     *
     * <p>The layout is {@code exectype°groupedagg°target=..°groups=..[°weights=..]°<other params>°<mode param>°<output>}.
     * The mode param is the thread count for CP and the broadcast flag for SPARK. Any other
     * parameters appear in the parameter map's iteration order.
     *
     * @param output output variable name passed to {@code prepOutputOperand}
     * @throws LopsException if "target", groups or fn is missing from the parameter map
     */
    @Override // org.apache.sysds.lops.Lop
    public String getInstructions(String output) {
        Lop target = _inputParams.get(TARGET);
        Lop groups = _inputParams.get(Statement.GAGG_GROUPS);
        if (target == null || groups == null || _inputParams.get(Statement.GAGG_FN) == null) {
            throw new LopsException(printErrorLocation() + "Invalid parameters to groupedAggregate -- \"target\", \"groups\", \"fn\" must be provided");
        }

        StringBuilder sb = new StringBuilder();
        sb.append(getExecType()).append(DELIM).append(OPCODE).append(DELIM);
        // The first operand has no leading delimiter; later ones go through appendParam.
        sb.append(TARGET).append(PARAM_DELIM).append(target.getOutputParameters().getLabel());
        appendParam(sb, Statement.GAGG_GROUPS, groups.getOutputParameters().getLabel());

        Lop weights = _inputParams.get(Statement.GAGG_WEIGHTS);
        if (weights != null) {
            appendParam(sb, Statement.GAGG_WEIGHTS, weights.getOutputParameters().getLabel());
        }

        // Remaining parameters (e.g. fn) are emitted via their scalar labels.
        for (Map.Entry<String, Lop> param : _inputParams.entrySet()) {
            String name = param.getKey();
            if (!name.equalsIgnoreCase(TARGET)
                    && !name.equalsIgnoreCase(Statement.GAGG_GROUPS)
                    && !name.equalsIgnoreCase(Statement.GAGG_WEIGHTS)) {
                appendParam(sb, name, param.getValue().prepScalarLabel());
            }
        }

        if (getExecType() == LopProperties.ExecType.CP) {
            appendParam(sb, Statement.PS_PARALLELISM, String.valueOf(_numThreads));
        } else if (getExecType() == LopProperties.ExecType.SPARK) {
            appendParam(sb, "broadcast", String.valueOf(_broadcastGroups));
        }

        sb.append(DELIM).append(prepOutputOperand(output));
        return sb.toString();
    }

    /** Appends one {@code °name=value} operand to {@code sb}. */
    private static void appendParam(StringBuilder sb, String name, String value) {
        sb.append(DELIM).append(name).append(PARAM_DELIM).append(value);
    }
}
