package org.apache.sysds.parser;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.sysds.common.Builtins;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.LanguageException;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.wink.json4j.JSONObject;

/* loaded from: input_file:org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.class */
public class ParameterizedBuiltinFunctionExpression extends DataIdentifier {
    // Builtin opcode identifying which parameterized function this expression represents.
    private Builtins _opcode;
    // Named call arguments in insertion order (parameter name -> argument expression).
    // A null key can occur; validateCountDistinctApprox remaps it to the data parameter.
    private LinkedHashMap<String, Expression> _varParams;
    // Parameter names read by the transform*/tokenize validators below.
    public static final String TF_FN_PARAM_DATA = "target";
    public static final String TF_FN_PARAM_MTD2 = "meta";
    public static final String TF_FN_PARAM_SPEC = "spec";
    // Same string as the list parameter checked by validateAutoDiff.
    public static final String LINEAGE_TRACE = "lineage";
    // Path parameter checked by validateTransformMeta.
    public static final String TF_FN_PARAM_MTD = "transformPath";
    // Builtin -> parameterized-builtin op mapping. It is not populated in the visible portion
    // of this file; presumably a static initializer elsewhere fills it (confirm). Mutable and public.
    public static HashMap<Builtins, Types.ParamBuiltinOp> pbHopMap = new HashMap<>();

    /**
     * Builds a parameterized builtin expression from parsed call arguments.
     *
     * @param parserRuleContext parser context used for source-position information
     * @param str               builtin function name
     * @param arrayList         call arguments; each contributes a (name, expression) pair
     * @param str2              file name passed on to {@code setCtxValuesAndFilename}
     * @return the new expression. Returns {@code null} if {@code str} or {@code arrayList} is
     *         {@code null}, or if {@code Builtins.get(str, true)} finds no matching builtin.
     */
    public static ParameterizedBuiltinFunctionExpression getParamBuiltinFunctionExpression(ParserRuleContext parserRuleContext, String str, ArrayList<ParameterExpression> arrayList, String str2) {
        if (str == null || arrayList == null) {
            return null;
        }
        Builtins builtins = Builtins.get(str, true);
        if (builtins == null) {
            return null;
        }
        // Preserve argument order; later duplicates of a name overwrite earlier ones.
        LinkedHashMap<String, Expression> varParams = new LinkedHashMap<>();
        for (ParameterExpression param : arrayList) {
            varParams.put(param.getName(), param.getExpr());
        }
        return new ParameterizedBuiltinFunctionExpression(parserRuleContext, builtins, varParams, str2);
    }

    /**
     * Creates an expression for {@code opcode} with the given named arguments. Source-position
     * information comes from the parser context and the file name.
     */
    public ParameterizedBuiltinFunctionExpression(ParserRuleContext ctx, Builtins opcode, LinkedHashMap<String, Expression> varParams, String filename) {
        _opcode = opcode;
        _varParams = varParams;
        setCtxValuesAndFilename(ctx, filename);
    }

    /**
     * Creates an expression for {@code opcode} with the given named arguments. Source-position
     * information is copied from {@code info}.
     */
    public ParameterizedBuiltinFunctionExpression(Builtins opcode, LinkedHashMap<String, Expression> varParams, ParseInfo info) {
        _opcode = opcode;
        _varParams = varParams;
        setParseInfo(info);
    }

    /**
     * Returns a copy of this expression with each argument expression rewritten using
     * {@code str}. Opcode and parse info are kept, and argument order is preserved.
     */
    @Override // org.apache.sysds.parser.DataIdentifier, org.apache.sysds.parser.Expression
    public Expression rewriteExpression(String str) {
        LinkedHashMap<String, Expression> rewritten = new LinkedHashMap<>();
        this._varParams.forEach((name, expr) -> rewritten.put(name, expr.rewriteExpression(str)));
        return new ParameterizedBuiltinFunctionExpression(this._opcode, rewritten, this);
    }

    /** Replaces the builtin opcode of this expression. */
    public void setOpcode(Builtins opcode) {
        _opcode = opcode;
    }

    /** @return the builtin opcode of this expression */
    public Builtins getOpCode() {
        return _opcode;
    }

    /**
     * @return the live argument map (not a copy), so modifications are visible to this
     *         expression
     */
    public HashMap<String, Expression> getVarParams() {
        return _varParams;
    }

    /** @return the argument bound to {@code name}, or {@code null} if none */
    public Expression getVarParam(String name) {
        return _varParams.get(name);
    }

    /** Binds (or rebinds) argument {@code name} to {@code value}. */
    public void addVarParam(String name, Expression value) {
        _varParams.put(name, value);
    }

    /**
     * Validates all arguments, creates a fresh temporary output identifier, and then
     * dispatches to the opcode-specific validator. That validator fills in the output's
     * data type, value type and dimensions.
     */
    @Override // org.apache.sysds.parser.Identifier, org.apache.sysds.parser.Expression
    public void validateExpression(HashMap<String, DataIdentifier> ids, HashMap<String, ConstIdentifier> constVars, boolean conditional) {
        // Arguments first: UDF calls are rejected, everything else is validated recursively.
        for (Expression param : getVarParams().values()) {
            if (param instanceof FunctionCallIdentifier) {
                raiseValidateError("UDF function call not supported as parameter to built-in function call", false);
            }
            param.validateExpression(ids, constVars, conditional);
        }
        DataIdentifier output = new DataIdentifier(getTempName());
        setOutput(output);
        Builtins op = getOpCode();
        switch (op) {
            case GROUPEDAGG:
                validateGroupedAgg(output, conditional);
                break;
            case CDF:
            case INVCDF:
            case PNORM:
            case QNORM:
            case PT:
            case QT:
            case PF:
            case QF:
            case PCHISQ:
            case QCHISQ:
            case PEXP:
            case QEXP:
                validateDistributionFunctions(output, conditional);
                break;
            case RMEMPTY:
                validateRemoveEmpty(output, conditional);
                break;
            case REPLACE:
                validateReplace(output, conditional);
                break;
            case ORDER:
                validateOrder(output, conditional);
                break;
            case TOKENIZE:
                validateTokenize(output, conditional);
                break;
            case TRANSFORMAPPLY:
                validateTransformApply(output, conditional);
                break;
            case TRANSFORMDECODE:
                validateTransformDecode(output, conditional);
                break;
            case TRANSFORMCOLMAP:
                validateTransformColmap(output, conditional);
                break;
            case TRANSFORMMETA:
                validateTransformMeta(output, conditional);
                break;
            case LOWER_TRI:
            case UPPER_TRI:
                validateExtractTriangular(output, op, conditional);
                break;
            case TOSTRING:
                validateCastAsString(output, conditional);
                break;
            case AUTODIFF:
                validateAutoDiff(output, conditional);
                break;
            case LISTNV:
                validateNamedList(output, conditional);
                break;
            case PARAMSERV:
                validateParamserv(output, conditional);
                break;
            case COUNT_DISTINCT_APPROX:
                validateCountDistinctApprox(output, conditional);
                break;
            default: {
                // transformencode produces two outputs and is handled by the multi-assignment overload.
                String message = (op == Builtins.TRANSFORMENCODE)
                    ? "Parameterized function " + op + " requires a multi-assignment statement for data and metadata."
                    : "Unsupported parameterized function " + op;
                raiseValidateError(message, false, LanguageException.LanguageErrorCodes.UNSUPPORTED_EXPRESSION);
                break;
            }
        }
    }

    /**
     * Validates autoDiff arguments: a mandatory "lineage" list. The output is a list with one
     * row per argument.
     */
    private void validateAutoDiff(DataIdentifier dataIdentifier, boolean conditional) {
        checkDataType(false, LINEAGE_TRACE, LINEAGE_TRACE, Types.DataType.LIST, conditional);
        checkDataValueType(false, LINEAGE_TRACE, LINEAGE_TRACE, Types.DataType.LIST, Types.ValueType.UNKNOWN, conditional);
        int numArgs = getVarParams().size();
        dataIdentifier.setDataType(Types.DataType.LIST);
        dataIdentifier.setValueType(Types.ValueType.UNKNOWN);
        dataIdentifier.setDimensions(numArgs, 1L);
        dataIdentifier.setBlocksize(-1);
    }

    /**
     * Validates a multi-output call such as {@code [X, M] = transformencode(...)}.
     * All arguments are validated, one output identifier is created per assignment target,
     * and the opcode-specific validator is then applied. Only TRANSFORMENCODE is supported;
     * it needs at least two targets (data and metadata).
     */
    @Override // org.apache.sysds.parser.Expression
    public void validateExpression(MultiAssignmentStatement multiAssignmentStatement, HashMap<String, DataIdentifier> hashMap, HashMap<String, ConstIdentifier> hashMap2, boolean z) {
        for (Expression param : getVarParams().values()) {
            if (param instanceof FunctionCallIdentifier) {
                raiseValidateError("UDF function call not supported as parameter to built-in function call", false);
            }
            param.validateExpression(hashMap, hashMap2, z);
        }
        // One fresh output per target, carrying this expression's parse info.
        int numTargets = multiAssignmentStatement.getTargetList().size();
        this._outputs = new Identifier[numTargets];
        int i = 0;
        for (DataIdentifier target : multiAssignmentStatement.getTargetList()) {
            DataIdentifier out = new DataIdentifier(target);
            out.setParseInfo(this);
            this._outputs[i++] = out;
        }
        switch (getOpCode()) {
            case TRANSFORMENCODE:
                // Guard the [0]/[1] accesses below: too few targets previously surfaced as an
                // ArrayIndexOutOfBoundsException instead of a validation error.
                if (numTargets < 2) {
                    raiseValidateError("Parameterized function " + getOpCode() + " requires two outputs (data and metadata), but "
                        + numTargets + " were provided.", false, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
                    return;
                }
                validateTransformEncode((DataIdentifier) getOutputs()[0], (DataIdentifier) getOutputs()[1], z);
                return;
            default:
                raiseValidateError("Unsupported parameterized function " + getOpCode(), false, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
                return;
        }
    }

    /**
     * Validates paramserv arguments: the allowed parameter names, then the type of each
     * mandatory and optional parameter. The output is a list with as many rows as the input
     * model list.
     */
    private void validateParamserv(DataIdentifier dataIdentifier, boolean conditional) {
        Builtins op = getOpCode();
        String fname = op.name();
        HashMap<String, Expression> params = getVarParams();
        if (params.size() < 1) {
            raiseValidateError("Should provide more arguments for function " + fname, false, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }
        checkInvalidParameters(op, params, CollectionUtils.asSet(
            Statement.PS_MODEL, Statement.PS_FEATURES, Statement.PS_LABELS, Statement.PS_VAL_FEATURES,
            Statement.PS_VAL_LABELS, Statement.PS_UPDATE_FUN, Statement.PS_AGGREGATION_FUN, Statement.PS_VAL_FUN,
            "mode", Statement.PS_UPDATE_TYPE, Statement.PS_FREQUENCY, Statement.PS_EPOCHS, Statement.PS_BATCH_SIZE,
            Statement.PS_PARALLELISM, Statement.PS_SCHEME, Statement.PS_FED_RUNTIME_BALANCING,
            Statement.PS_FED_WEIGHTING, Statement.PS_HYPER_PARAMS, Statement.PS_CHECKPOINTING, "seed",
            Statement.PS_NBATCHES, Statement.PS_MODELAVG, Statement.PS_HE));
        // Mandatory data inputs.
        checkDataType(false, fname, Statement.PS_MODEL, Types.DataType.LIST, conditional);
        checkDataType(false, fname, Statement.PS_FEATURES, Types.DataType.MATRIX, conditional);
        checkDataType(false, fname, Statement.PS_LABELS, Types.DataType.MATRIX, conditional);
        // Optional validation data.
        checkDataValueType(true, fname, Statement.PS_VAL_FEATURES, Types.DataType.MATRIX, Types.ValueType.FP64, conditional);
        checkDataValueType(true, fname, Statement.PS_VAL_LABELS, Types.DataType.MATRIX, Types.ValueType.FP64, conditional);
        // Function names.
        checkDataValueType(false, fname, Statement.PS_UPDATE_FUN, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        checkDataValueType(false, fname, Statement.PS_AGGREGATION_FUN, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        checkDataValueType(true, fname, Statement.PS_VAL_FUN, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        // Execution configuration.
        checkStringParam(true, fname, "mode", conditional);
        checkStringParam(true, fname, Statement.PS_UPDATE_TYPE, conditional);
        checkStringParam(true, fname, Statement.PS_FREQUENCY, conditional);
        checkDataValueType(false, fname, Statement.PS_EPOCHS, Types.DataType.SCALAR, Types.ValueType.INT64, conditional);
        checkDataValueType(true, fname, Statement.PS_BATCH_SIZE, Types.DataType.SCALAR, Types.ValueType.INT64, conditional);
        checkDataValueType(true, fname, Statement.PS_PARALLELISM, Types.DataType.SCALAR, Types.ValueType.INT64, conditional);
        checkStringParam(true, fname, Statement.PS_SCHEME, conditional);
        checkStringParam(true, fname, Statement.PS_FED_RUNTIME_BALANCING, conditional);
        checkStringParam(true, fname, Statement.PS_FED_WEIGHTING, conditional);
        checkDataValueType(true, fname, Statement.PS_HYPER_PARAMS, Types.DataType.LIST, Types.ValueType.UNKNOWN, conditional);
        checkStringParam(true, fname, Statement.PS_CHECKPOINTING, conditional);
        checkDataValueType(true, fname, "seed", Types.DataType.SCALAR, Types.ValueType.INT64, conditional);
        long modelRows = getVarParam(Statement.PS_MODEL).getOutput().getDim1();
        dataIdentifier.setDataType(Types.DataType.LIST);
        dataIdentifier.setValueType(Types.ValueType.UNKNOWN);
        dataIdentifier.setDimensions(modelRows, 1L);
        dataIdentifier.setBlocksize(-1);
    }

    /**
     * Validates countDistinctApprox(data, type, dir).
     * A null-keyed argument is treated as "data". "type" defaults to KMV. The output is an
     * INT64 scalar when "dir" is absent or RowCol, and an INT64 column/row vector for
     * Row/Col respectively.
     */
    private void validateCountDistinctApprox(DataIdentifier dataIdentifier, boolean conditional) {
        Set<String> supportedTypes = CollectionUtils.asSet("KMV");
        HashMap<String, Expression> params = getVarParams();
        if (params.containsKey(null)) {
            params.put(DataExpression.RAND_DATA, params.remove(null));
        }
        String fname = getOpCode().getName();
        String arityHint = "function " + fname + " takes at least 1 and at most 3 parameters";
        int numParams = params.size();
        if (numParams < 1) {
            raiseValidateError("Too few parameters: " + arityHint, conditional);
        }
        if (numParams > 3) {
            raiseValidateError("Too many parameters: " + arityHint, conditional);
        }
        checkInvalidParameters(getOpCode(), params, CollectionUtils.asSet(DataExpression.RAND_DATA, DataExpression.FED_TYPE, "dir"));
        checkDataType(false, fname, DataExpression.RAND_DATA, Types.DataType.MATRIX, conditional);
        checkDataValueType(false, fname, DataExpression.RAND_DATA, Types.DataType.MATRIX, Types.ValueType.FP64, conditional);
        Identifier input = params.get(DataExpression.RAND_DATA).getOutput();
        if (input == null) {
            raiseValidateError("Cannot parse input parameter \"data\" to function " + fname, conditional);
        }

        checkStringParam(true, fname, DataExpression.FED_TYPE, conditional);
        if (!params.containsKey(DataExpression.FED_TYPE)) {
            addVarParam(DataExpression.FED_TYPE, new StringIdentifier("KMV", this));
        } else {
            String type = params.get(DataExpression.FED_TYPE).toString().toUpperCase();
            if (!supportedTypes.contains(type)) {
                raiseValidateError("Unrecognized type for optional parameter " + type, conditional);
            }
        }

        checkStringParam(true, fname, "dir", conditional);
        boolean hasDir = params.containsKey("dir");
        String dir = hasDir ? params.get("dir").toString().toUpperCase() : null;
        if (hasDir && dir.equals(Types.Direction.Row.toString())) {
            // One distinct count per row.
            dataIdentifier.setDataType(Types.DataType.MATRIX);
            dataIdentifier.setDimensions(input.getDim1(), 1L);
            dataIdentifier.setBlocksize(input.getBlocksize());
            dataIdentifier.setValueType(Types.ValueType.INT64);
            dataIdentifier.setNnz(input.getDim1());
            return;
        }
        if (hasDir && dir.equals(Types.Direction.Col.toString())) {
            // One distinct count per column.
            dataIdentifier.setDataType(Types.DataType.MATRIX);
            dataIdentifier.setDimensions(1L, input.getDim2());
            dataIdentifier.setBlocksize(input.getBlocksize());
            dataIdentifier.setValueType(Types.ValueType.INT64);
            dataIdentifier.setNnz(input.getDim2());
            return;
        }
        if (hasDir && !dir.equals(Types.Direction.RowCol.toString())) {
            raiseValidateError("Invalid argument: " + dir + " is not recognized");
            return;
        }
        // No direction or RowCol: a single scalar count.
        dataIdentifier.setDataType(Types.DataType.SCALAR);
        dataIdentifier.setDimensions(0L, 0L);
        dataIdentifier.setBlocksize(0);
        dataIdentifier.setValueType(Types.ValueType.INT64);
        dataIdentifier.setNnz(1L);
    }

    /**
     * Checks that parameter {@code str2} of function {@code str} is a scalar string.
     *
     * @param z  whether the parameter is optional (an absent optional parameter is accepted)
     * @param str  function name used in error messages
     * @param str2 parameter name
     * @param z2 conditional-validation flag forwarded to {@code raiseValidateError}
     */
    private void checkStringParam(boolean z, String str, String str2, boolean z2) {
        Expression varParam = getVarParam(str2);
        if (varParam == null) {
            if (!z) {
                raiseValidateError(String.format("Function %s should provide parameter '%s'", str, str2), z2);
            }
            // Nothing to type-check for an absent parameter. Previously, a missing required
            // parameter under conditional validation fell through and threw an NPE below.
            return;
        }
        Identifier out = varParam.getOutput();
        if (!out.getDataType().isScalar() || !out.getValueType().equals(Types.ValueType.STRING)) {
            raiseValidateError(String.format("Function %s should provide a string value for %s parameter.", str, str2), z2);
        }
    }

    /** Validates tokenize(target=frame, spec=string). The output is a string frame of unknown size. */
    private void validateTokenize(DataIdentifier dataIdentifier, boolean conditional) {
        final String fname = "tokenize";
        checkDataType(false, fname, "target", Types.DataType.FRAME, conditional);
        checkDataValueType(false, fname, TF_FN_PARAM_SPEC, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
        dataIdentifier.setDataType(Types.DataType.FRAME);
        dataIdentifier.setValueType(Types.ValueType.STRING);
        dataIdentifier.setDimensions(-1L, -1L);
    }

    /**
     * Validates transformapply(target=frame, meta=frame, spec=string). The output is an FP64
     * matrix of unknown size.
     */
    private void validateTransformApply(DataIdentifier dataIdentifier, boolean conditional) {
        final String fname = "transformapply";
        checkDataType(false, fname, "target", Types.DataType.FRAME, conditional);
        checkDataType(false, fname, TF_FN_PARAM_MTD2, Types.DataType.FRAME, conditional);
        checkDataValueType(false, fname, TF_FN_PARAM_SPEC, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
        dataIdentifier.setDataType(Types.DataType.MATRIX);
        dataIdentifier.setValueType(Types.ValueType.FP64);
        dataIdentifier.setDimensions(-1L, -1L);
    }

    /**
     * Validates transformdecode(target=matrix, meta=frame, spec=string). The output is a
     * string frame of unknown size.
     */
    private void validateTransformDecode(DataIdentifier dataIdentifier, boolean conditional) {
        final String fname = "transformdecode";
        checkDataType(false, fname, "target", Types.DataType.MATRIX, conditional);
        checkDataType(false, fname, TF_FN_PARAM_MTD2, Types.DataType.FRAME, conditional);
        checkDataValueType(false, fname, TF_FN_PARAM_SPEC, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
        dataIdentifier.setDataType(Types.DataType.FRAME);
        dataIdentifier.setValueType(Types.ValueType.STRING);
        dataIdentifier.setDimensions(-1L, -1L);
    }

    /**
     * Validates transformcolmap(target=frame, spec=string). The output is an FP64 matrix with
     * one row per target column and 3 columns.
     */
    private void validateTransformColmap(DataIdentifier dataIdentifier, boolean conditional) {
        final String fname = "transformcolmap";
        Expression target = getVarParam("target");
        checkDataType(false, fname, "target", Types.DataType.FRAME, conditional);
        checkDataValueType(false, fname, TF_FN_PARAM_SPEC, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
        long targetCols = target.getOutput().getDim2();
        dataIdentifier.setDataType(Types.DataType.MATRIX);
        dataIdentifier.setValueType(Types.ValueType.FP64);
        dataIdentifier.setDimensions(targetCols, 3L);
    }

    /**
     * Validates transformmeta(spec=string, transformPath=string). The output is a string frame
     * of unknown size.
     */
    private void validateTransformMeta(DataIdentifier dataIdentifier, boolean conditional) {
        final String fname = "transformmeta";
        checkDataValueType(false, fname, TF_FN_PARAM_SPEC, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
        checkDataValueType(false, fname, TF_FN_PARAM_MTD, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        dataIdentifier.setDataType(Types.DataType.FRAME);
        dataIdentifier.setValueType(Types.ValueType.STRING);
        dataIdentifier.setDimensions(-1L, -1L);
    }

    /**
     * Validates transformencode(target=frame, spec=string). The first output is an FP64 matrix
     * and the second a string frame, both of unknown size.
     */
    private void validateTransformEncode(DataIdentifier dataIdentifier, DataIdentifier dataIdentifier2, boolean conditional) {
        final String fname = "transformencode";
        checkDataType(false, fname, "target", Types.DataType.FRAME, conditional);
        checkDataValueType(false, fname, TF_FN_PARAM_SPEC, Types.DataType.SCALAR, Types.ValueType.STRING, conditional);
        validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
        DataIdentifier encoded = dataIdentifier;
        DataIdentifier meta = dataIdentifier2;
        encoded.setDataType(Types.DataType.MATRIX);
        encoded.setValueType(Types.ValueType.FP64);
        encoded.setDimensions(-1L, -1L);
        meta.setDataType(Types.DataType.FRAME);
        meta.setValueType(Types.ValueType.STRING);
        meta.setDimensions(-1L, -1L);
    }

    /**
     * If parameter {@code str} is a string literal, checks that it parses as a JSON object.
     * Non-literal specs (e.g. variables) cannot be checked at this point and are skipped.
     */
    private void validateTransformSpec(String str, boolean z) {
        Expression varParam = getVarParam(str);
        if (varParam instanceof StringIdentifier) {
            try {
                new JSONObject(((StringIdentifier) varParam).getValue());
            } catch (Exception e) {
                // The parser message belongs in the user-facing text. The third argument of
                // raiseValidateError is the error code, as at every other call site; previously
                // the message was passed there, leaving the text ending in a bare colon.
                raiseValidateError("Transform specification parsing issue: " + e.getMessage(), z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
            }
        }
    }

    /**
     * Validates lower.tri/upper.tri(target=matrix, diag=boolean, values=boolean).
     * Missing flags default to false. The output has the same dimensions as the target.
     */
    private void validateExtractTriangular(DataIdentifier dataIdentifier, Builtins builtins, boolean conditional) {
        checkInvalidParameters(builtins, getVarParams(), CollectionUtils.asSet("target", "diag", "values"));
        checkTargetParam(getVarParam("target"), conditional);
        for (String flag : new String[] {"diag", "values"}) {
            checkOptionalBooleanParam(getVarParam(flag), flag, conditional);
        }
        // Defaults are applied only when the key is absent or bound to null.
        this._varParams.putIfAbsent("diag", new BooleanIdentifier(false));
        this._varParams.putIfAbsent("values", new BooleanIdentifier(false));
        Identifier target = getVarParam("target").getOutput();
        dataIdentifier.setDataType(Types.DataType.MATRIX);
        dataIdentifier.setValueType(Types.ValueType.FP64);
        dataIdentifier.setDimensions(target.getDim1(), target.getDim2());
    }

    /**
     * Validates replace(target, pattern, replacement). The target is a frame or matrix;
     * pattern and replacement are scalars. The output keeps the target's data type and
     * dimensions, with STRING values for frames and FP64 otherwise.
     */
    private void validateReplace(DataIdentifier dataIdentifier, boolean z) {
        Expression varParam = getVarParam("target");
        // A missing target goes through checkTargetParam (which reports it). Previously it
        // was dereferenced first, producing an NPE instead of the intended message.
        if (varParam == null || varParam.getOutput().getDataType() != Types.DataType.FRAME) {
            checkTargetParam(varParam, z);
        }
        Expression varParam2 = getVarParam("pattern");
        if (varParam2 == null) {
            raiseValidateError("Named parameter 'pattern' missing. Please specify the replacement pattern.", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        } else if (varParam2.getOutput().getDataType() != Types.DataType.SCALAR) {
            raiseValidateError("Replacement pattern 'pattern' is of type '" + varParam2.getOutput().getDataType() + "'. Please, specify a scalar replacement pattern.", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }
        Expression varParam3 = getVarParam("replacement");
        if (varParam3 == null) {
            raiseValidateError("Named parameter 'replacement' missing. Please specify the replacement value.", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        } else if (varParam3.getOutput().getDataType() != Types.DataType.SCALAR) {
            raiseValidateError("Replacement value 'replacement' is of type '" + varParam3.getOutput().getDataType() + "'. Please, specify a scalar replacement value.", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }
        if (varParam == null) {
            // If validation did not abort above, there is no target to derive the output from.
            return;
        }
        Identifier target = varParam.getOutput();
        dataIdentifier.setDataType(target.getDataType());
        if (target.getDataType() == Types.DataType.FRAME) {
            dataIdentifier.setValueType(Types.ValueType.STRING);
        } else {
            dataIdentifier.setValueType(Types.ValueType.FP64);
        }
        dataIdentifier.setDimensions(target.getDim1(), target.getDim2());
    }

    /**
     * Validates order(target, by, decreasing, index.return).
     * Missing arguments get defaults: by=1, decreasing=false, index.return=false.
     * The output has the target's row count. Its column count is 1 when index.return is the
     * literal true, the target's column count when it is the literal false, and unknown (-1)
     * for non-literal values.
     */
    private void validateOrder(DataIdentifier dataIdentifier, boolean conditional) {
        Expression target = getVarParam("target");
        checkTargetParam(target, conditional);
        for (String paramName : getVarParams().keySet()) {
            switch (paramName) {
                case "target":
                case "by":
                case "decreasing":
                case "index.return":
                    break;
                default:
                    raiseValidateError("Unsupported order parameter: '" + paramName + "'", false);
                    break;
            }
        }

        Expression by = getVarParam("by");
        if (by == null) {
            addVarParam("by", new IntIdentifier(1L));
        } else {
            Types.DataType byType = by.getOutput().getDataType();
            if (!byType.isScalar() && !byType.isMatrix()) {
                raiseValidateError("Orderby column 'by' is of type '" + byType + "'. Please, use a scalar or row vector to specify column indexes.", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
            }
        }

        Expression decreasing = getVarParam("decreasing");
        if (decreasing == null) {
            addVarParam("decreasing", new BooleanIdentifier(false));
        } else if (decreasing.getOutput().getDataType() != Types.DataType.SCALAR) {
            raiseValidateError("Ordering 'decreasing' is of type '" + decreasing.getOutput().getDataType() + "', '" + decreasing.getOutput().getValueType() + "'. Please, specify 'decreasing' as a scalar boolean.", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }

        Expression indexReturn = getVarParam("index.return");
        if (indexReturn == null) {
            indexReturn = new BooleanIdentifier(false);
            addVarParam("index.return", indexReturn);
        } else if (indexReturn.getOutput().getDataType() != Types.DataType.SCALAR) {
            raiseValidateError("Return type 'index.return' is of type '" + indexReturn.getOutput().getDataType() + "', '" + indexReturn.getOutput().getValueType() + "'. Please, specify 'indexreturn' as a scalar boolean.", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }

        long outCols;
        if (!(indexReturn instanceof BooleanIdentifier)) {
            outCols = -1L;
        } else if (((BooleanIdentifier) indexReturn).getValue()) {
            outCols = 1L;
        } else {
            outCols = target.getOutput().getDim2();
        }
        dataIdentifier.setDataType(Types.DataType.MATRIX);
        dataIdentifier.setValueType(Types.ValueType.FP64);
        dataIdentifier.setDimensions(target.getOutput().getDim1(), outCols);
    }

    /**
     * Validates removeEmpty(target, margin, select, empty.return).
     * margin must be "rows"/"cols" (or a variable). select, if given, must be a matrix.
     * empty.return must be a scalar boolean and defaults to true. The output keeps the
     * target's data type (STRING values for frames, FP64 otherwise) with unknown dimensions.
     */
    private void validateRemoveEmpty(DataIdentifier dataIdentifier, boolean conditional) {
        Set<String> allowed = CollectionUtils.asSet("target", "margin", "select", "empty.return");
        Set<String> unknown = this._varParams.keySet().stream()
            .filter(name -> !allowed.contains(name))
            .collect(Collectors.toSet());
        if (!unknown.isEmpty()) {
            raiseValidateError("Invalid parameters for removeEmpty: " + Arrays.toString(unknown.toArray(new String[0])), false);
        }

        Expression target = getVarParam("target");
        checkEmptyTargetParam(target, conditional);

        Expression margin = getVarParam("margin");
        if (margin == null) {
            raiseValidateError("Named parameter 'margin' missing. Please specify 'rows' or 'cols'.", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        } else if (!(margin instanceof DataIdentifier)) {
            String marginValue = margin.toString();
            if (!marginValue.equals("rows") && !marginValue.equals("cols")) {
                raiseValidateError("Named parameter 'margin' has an invalid value '" + marginValue + "'. Please specify 'rows' or 'cols'.", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
            }
        }

        Expression select = getVarParam("select");
        if (select != null && select.getOutput().getDataType() != Types.DataType.MATRIX) {
            raiseValidateError("Index matrix 'select' is of type '" + select.getOutput().getDataType() + "'. Please specify the select matrix.", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }

        Expression emptyReturn = getVarParam("empty.return");
        if (emptyReturn == null) {
            this._varParams.put("empty.return", new BooleanIdentifier(true));
        } else if (!emptyReturn.getOutput().getDataType().isScalar() || emptyReturn.getOutput().getValueType() != Types.ValueType.BOOLEAN) {
            raiseValidateError("Boolean parameter 'empty.return' is of type " + emptyReturn.getOutput().getDataType() + "[" + emptyReturn.getOutput().getValueType() + "].", conditional, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }

        Types.DataType targetType = target.getOutput().getDataType();
        dataIdentifier.setDataType(targetType);
        dataIdentifier.setValueType(targetType == Types.DataType.FRAME ? Types.ValueType.STRING : Types.ValueType.FP64);
        dataIdentifier.setDimensions(-1L, -1L);
    }

    /**
     * Validates aggregate(target, groups, fn, ngroups, weights, order).
     * It checks that target/groups shapes agree when both are known, that fn names a supported
     * aggregate, and that centralmoment has order 2, 3 or 4. The output is an FP64 matrix;
     * its dimensions are derived from a constant ngroups, otherwise unknown.
     */
    private void validateGroupedAgg(DataIdentifier dataIdentifier, boolean z) {
        Expression varParam = getVarParam("target");
        Expression varParam2 = getVarParam(Statement.GAGG_GROUPS);
        Expression varParam3 = getVarParam(Statement.GAGG_NUM_GROUPS);
        if (varParam == null || varParam2 == null) {
            raiseValidateError("Must define both target and groups.", z);
        }
        // NOTE(review): colwise is never cleared in the visible code (the decompiled row-vector
        // branch assigned true again), so the (1 x ngroups) output shape below is unreachable.
        // Confirm against the runtime before changing it.
        boolean colwise = true;
        boolean matrixInput = false;
        // Shape checks only when both inputs exist; this avoids an NPE under conditional validation.
        if (varParam != null && varParam2 != null
                && varParam2.getOutput().dimsKnown() && varParam.getOutput().dimsKnown()) {
            Identifier target = varParam.getOutput();
            Identifier groups = varParam2.getOutput();
            if (groups.getDim2() == 1 && target.getDim2() > 1) {
                // Matrix target with a column-vector of group ids.
                if (getVarParam(Statement.GAGG_WEIGHTS) != null) {
                    raiseValidateError("Matrix input not supported with weights.", z);
                }
                if (varParam3 == null) {
                    raiseValidateError("Matrix input not supported without specified numgroups.", z);
                }
                if (groups.getDim1() != target.getDim1()) {
                    raiseValidateError("Target and groups must have same dimensions --  target dims: " + target.getDim1() + " x " + target.getDim2()
                        + ", groups dims: " + groups.getDim1() + " x 1.", z);
                }
                matrixInput = true;
            } else if (groups.getDim2() == 1 && target.getDim2() == 1) {
                // Column vectors.
                if (groups.getDim1() != target.getDim1()) {
                    raiseValidateError("Target and groups must have same dimensions --  target dims: " + target.getDim1()
                        + " x 1, groups dims: " + groups.getDim1() + " x 1.", z);
                }
            } else if (groups.getDim1() == 1 && target.getDim1() == 1) {
                // Row vectors.
                if (groups.getDim2() != target.getDim2()) {
                    raiseValidateError("Target and groups must have same dimensions --  target dims: 1 x " + target.getDim2()
                        + ", groups dims: 1 x " + groups.getDim2() + ".", z);
                }
            } else {
                raiseValidateError("Invalid target and groups inputs - dimension mismatch.", z);
            }
        }

        Expression fnParam = getVarParam(Statement.GAGG_FN);
        if (fnParam == null) {
            raiseValidateError("must define function name (fn=<function name>) for aggregate()", z);
        } else if (fnParam instanceof Identifier) {
            String fn = fnParam.toString();
            if (fn.equals(Statement.GAGG_FN_CM)) {
                Expression orderParam = getVarParam(Statement.GAGG_FN_CM_ORDER);
                String order = orderParam == null ? null : orderParam.toString();
                if (order == null || (!order.equals("2") && !order.equals("3") && !order.equals("4"))) {
                    raiseValidateError("for centralmoment, must define order.  Order must be equal to 2,3, or 4", z);
                }
            } else if (!fn.equals(Statement.GAGG_FN_COUNT) && !fn.equals(Statement.GAGG_FN_SUM) && !fn.equals(Statement.GAGG_FN_MEAN)
                    && !fn.equals(Statement.GAGG_FN_VARIANCE) && !fn.equals("min") && !fn.equals("max")) {
                raiseValidateError("fname is " + fn + " but must be either centralmoment, count, sum, mean, variance, min, or max", z);
            }
        }

        // Output size is known only for a constant number of groups.
        long rows = -1;
        long cols = -1;
        if (varParam3 instanceof ConstIdentifier) {
            long numGroups = ((ConstIdentifier) varParam3).getLongValue();
            if (colwise) {
                rows = numGroups;
                cols = matrixInput ? varParam.getOutput().getDim2() : 1L;
            } else {
                rows = 1;
                cols = numGroups;
            }
        }
        dataIdentifier.setDataType(Types.DataType.MATRIX);
        dataIdentifier.setValueType(Types.ValueType.FP64);
        dataIdentifier.setDimensions(rows, cols);
    }

    /**
     * Validates that the 'target' parameter is present and is a matrix.
     *
     * @param expression the expression bound to 'target', may be null
     * @param z          whether the error is raised conditionally
     */
    private void checkTargetParam(Expression expression, boolean z) {
        if (expression == null) {
            raiseValidateError("Named parameter 'target' missing. Please specify the input matrix.", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
            return;
        }
        Types.DataType targetType = expression.getOutput().getDataType();
        if (targetType != Types.DataType.MATRIX) {
            raiseValidateError("Input matrix 'target' is of type '" + targetType + "'. Please specify the input matrix.", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }
    }

    /**
     * Validates only that the 'target' parameter is present; its data type is not checked.
     *
     * @param expression the expression bound to 'target', may be null
     * @param z          whether the error is raised conditionally
     */
    private void checkEmptyTargetParam(Expression expression, boolean z) {
        if (expression != null) {
            return;
        }
        raiseValidateError("Named parameter 'target' missing. Please specify the input matrix.", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
    }

    /**
     * Validates an optional parameter: if given, it must be a scalar of value type BOOLEAN.
     *
     * @param expression the parameter expression, or null when the parameter was omitted
     * @param str        the parameter name, used in the error message
     * @param z          whether the error is raised conditionally
     */
    private void checkOptionalBooleanParam(Expression expression, String str, boolean z) {
        if (expression == null) {
            return; // optional: absence is fine
        }
        Types.DataType paramDataType = expression.getOutput().getDataType();
        Types.ValueType paramValueType = expression.getOutput().getValueType();
        boolean isScalarBoolean = paramDataType.isScalar() && paramValueType == Types.ValueType.BOOLEAN;
        if (!isScalarBoolean) {
            raiseValidateError("Boolean parameter '" + str + "' is of type " + paramDataType + "[" + paramValueType + "].", z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
        }
    }

    /**
     * Raises a (non-conditional) validation error listing every supplied parameter whose name
     * is not in the set of valid names for the given builtin.
     *
     * <p>Invalid parameters are reported in the iteration order of {@code hashMap}. An unnamed
     * parameter (null key) is shown as its expression text only; named ones as {@code name=text}.
     *
     * @param builtins the builtin being validated, used in the error message
     * @param hashMap  the supplied parameters keyed by name (a null key denotes an unnamed one)
     * @param set      the valid parameter names
     */
    private void checkInvalidParameters(Builtins builtins, HashMap<String, Expression> hashMap, Set<String> set) {
        // Typed stream over entries: avoids the raw Set/List casts and a second map lookup per key.
        List<String> invalid = hashMap.entrySet().stream()
            .filter(entry -> !set.contains(entry.getKey()))
            .map(entry -> {
                String text = entry.getValue().getText();
                return entry.getKey() == null ? text : entry.getKey() + "=" + text;
            })
            .collect(Collectors.toList());
        if (invalid.isEmpty()) {
            return;
        }
        raiseValidateError(String.format("Invalid parameters for %s: %s", builtins.name(), invalid), false);
    }

    /**
     * Validates the parameters of the distribution builtins (cdf/icdf and the p*, q* families)
     * and types the output as a scalar FP64.
     *
     * <p>Checks, in order: the scalar 'target'; the distribution-specific mandatory parameters;
     * and that inverse (quantile) functions were not given 'lower.tail'.
     *
     * @param dataIdentifier the output identifier to type
     * @param z              whether errors are raised conditionally
     */
    private void validateDistributionFunctions(DataIdentifier dataIdentifier, boolean z) {
        Builtins op = getOpCode();
        LanguageException.LanguageErrorCodes errorCode = LanguageException.LanguageErrorCodes.INVALID_PARAMETERS;

        Expression target = getVarParam("target");
        boolean hasScalarTarget = target != null && target.getOutput().getDataType() == Types.DataType.SCALAR;
        if (!hasScalarTarget) {
            raiseValidateError("target must be provided for distribution functions, and it must be a scalar value.", z, errorCode);
        }

        // Determine which (if any) mandatory distribution parameter is missing.
        String missingParamMessage = null;
        switch (op) {
            case CDF:
            case INVCDF:
                if (getVarParam("dist") == null) {
                    missingParamMessage = "For cdf() and icdf(), a distribution function must be specified (as a string).";
                }
                break;
            case PT:
            case QT:
                if (getVarParam("df") == null) {
                    missingParamMessage = "Degrees of freedom df must be provided for t-distribution.";
                }
                break;
            case PF:
            case QF:
                if (getVarParam("df1") == null || getVarParam("df2") == null) {
                    missingParamMessage = "Two degrees of freedom df1 and df2 must be provided for F-distribution.";
                }
                break;
            case PCHISQ:
            case QCHISQ:
                if (getVarParam("df") == null) {
                    missingParamMessage = "Degrees of freedom df must be provided for chi-squared-distribution.";
                }
                break;
            default:
                break;
        }
        if (missingParamMessage != null) {
            raiseValidateError(missingParamMessage, z, errorCode);
        }

        // Inverse (quantile) functions reject a lower.tail argument.
        boolean isInverse;
        switch (op) {
            case INVCDF:
            case QNORM:
            case QT:
            case QF:
            case QCHISQ:
            case QEXP:
                isInverse = true;
                break;
            default:
                isInverse = false;
                break;
        }
        if (isInverse && getVarParam("lower.tail") != null) {
            raiseValidateError("Lower tail argument is invalid while computing inverse cumulative probabilities.", z, errorCode);
        }

        dataIdentifier.setDataType(Types.DataType.SCALAR);
        dataIdentifier.setValueType(Types.ValueType.FP64);
        dataIdentifier.setDimensions(0L, 0L);
    }

    /**
     * Validates the parameters of toString() and types the output as a scalar STRING.
     *
     * <p>An unnamed (null-key) parameter is re-keyed as 'target' in the parameter map before
     * validation. Every remaining parameter name must be one of the allowed names.
     *
     * @param dataIdentifier the output identifier to type
     * @param z              whether errors are raised conditionally
     */
    private void validateCastAsString(DataIdentifier dataIdentifier, boolean z) {
        HashMap<String, Expression> varParams = getVarParams();
        if (varParams.containsKey(null)) {
            // Mutates the shared parameter map: the positional argument becomes 'target'.
            varParams.put("target", varParams.remove(null));
        }
        String[] validNames = {"target", "rows", "cols", "decimal", DataExpression.DELIM_SPARSE, DataExpression.DELIM_DELIMITER, "linesep"};
        // Typed set (the original used a raw HashSet, producing an unchecked warning).
        Set<String> validNameSet = new HashSet<>(Arrays.asList(validNames));
        for (String name : varParams.keySet()) {
            if (!validNameSet.contains(name)) {
                raiseValidateError("Invalid parameter " + name + " for toString, valid parameters are " + Arrays.toString(validNames), z, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
            }
        }
        dataIdentifier.setDataType(Types.DataType.SCALAR);
        dataIdentifier.setValueType(Types.ValueType.STRING);
        dataIdentifier.setDimensions(0L, 0L);
    }

    /**
     * Types the output of a named-list construction as a LIST with one row per parameter.
     *
     * @param dataIdentifier the output identifier to type
     * @param z              unused here
     */
    private void validateNamedList(DataIdentifier dataIdentifier, boolean z) {
        int entryCount = getVarParams().size();
        dataIdentifier.setDataType(Types.DataType.LIST);
        dataIdentifier.setValueType(Types.ValueType.UNKNOWN);
        dataIdentifier.setDimensions(entryCount, 1L);
        dataIdentifier.setBlocksize(-1);
    }

    /**
     * Validates that the named parameter has the expected data type.
     *
     * @param z        true if the parameter is optional (absence is then accepted)
     * @param str      the function name, used in the error message
     * @param str2     the parameter name
     * @param dataType the expected data type
     * @param z2       whether errors are raised conditionally
     */
    private void checkDataType(boolean z, String str, String str2, Types.DataType dataType, boolean z2) {
        Expression param = getVarParam(str2);
        if (param == null) {
            if (!z) {
                raiseValidateError("Named parameter '" + str2 + "' missing. Please specify the input.", z2, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
            }
            return;
        }
        Types.DataType actualType = param.getOutput().getDataType();
        if (actualType == dataType) {
            return;
        }
        raiseValidateError("Input to " + str + "::" + str2 + " must be of type '" + dataType.toString() + "'. It should not be of type '" + actualType + "'.", z2, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
    }

    /**
     * Validates that the named parameter has the expected data type and value type.
     *
     * @param z         true if the parameter is optional (absence is then accepted)
     * @param str       the function name, used in the error message
     * @param str2      the parameter name
     * @param dataType  the expected data type
     * @param valueType the expected value type
     * @param z2        whether errors are raised conditionally
     */
    private void checkDataValueType(boolean z, String str, String str2, Types.DataType dataType, Types.ValueType valueType, boolean z2) {
        Expression varParam = getVarParam(str2);
        if (varParam == null) {
            if (!z) {
                raiseValidateError(String.format("Named parameter '%s' is missing. Please specify the input.", str2), z2, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
            }
            return;
        }
        Types.DataType actualDataType = varParam.getOutput().getDataType();
        Types.ValueType actualValueType = varParam.getOutput().getValueType();
        if (actualDataType == dataType && actualValueType == valueType) {
            return;
        }
        // Fixed missing space after the first sentence; enums are passed to %s directly so a
        // null type is rendered as "null" instead of throwing from toString().
        raiseValidateError(String.format("Input to %s::%s must be of type '%s', '%s'. It should not be of type '%s', '%s'.", str, str2, dataType, valueType, actualDataType, actualValueType), z2, LanguageException.LanguageErrorCodes.INVALID_PARAMETERS);
    }

    /**
     * Renders the call as {@code OPCODE(name1=value1,name2=value2)}.
     *
     * <p>Parameters are comma-separated without a leading separator. An unnamed parameter
     * (null key) is rendered as its value only.
     */
    @Override // org.apache.sysds.parser.DataIdentifier
    public String toString() {
        StringBuilder sb = new StringBuilder(this._opcode.toString()).append('(');
        String separator = "";
        for (String key : this._varParams.keySet()) {
            sb.append(separator);
            if (key != null) {
                sb.append(key).append('=');
            }
            sb.append(this._varParams.get(key));
            separator = ",";
        }
        sb.append(')');
        return sb.toString();
    }

    /**
     * Collects the variables read by every parameter expression, in parameter order.
     *
     * @return the union of all parameters' read sets
     */
    @Override // org.apache.sysds.parser.DataIdentifier, org.apache.sysds.parser.Expression
    public VariableSet variablesRead() {
        VariableSet readSet = new VariableSet();
        for (Expression param : this._varParams.values()) {
            readSet.addVariables(param.variablesRead());
        }
        return readSet;
    }

    /**
     * Collects the variables updated by every parameter expression, plus this expression's output.
     *
     * @return the union of all parameters' update sets and the output identifier
     */
    @Override // org.apache.sysds.parser.DataIdentifier, org.apache.sysds.parser.Expression
    public VariableSet variablesUpdated() {
        VariableSet updatedSet = new VariableSet();
        for (Expression param : this._varParams.values()) {
            updatedSet.addVariables(param.variablesUpdated());
        }
        DataIdentifier output = (DataIdentifier) getOutput();
        updatedSet.addVariable(output.getName(), output);
        return updatedSet;
    }

    /** Returns true only for transformencode, the one builtin here with multiple outputs. */
    @Override // org.apache.sysds.parser.DataIdentifier
    public boolean multipleReturns() {
        return Builtins.TRANSFORMENCODE == this._opcode;
    }

    static {
        pbHopMap.put(Builtins.AUTODIFF, Types.ParamBuiltinOp.AUTODIFF);
        pbHopMap.put(Builtins.GROUPEDAGG, Types.ParamBuiltinOp.GROUPEDAGG);
        pbHopMap.put(Builtins.RMEMPTY, Types.ParamBuiltinOp.RMEMPTY);
        pbHopMap.put(Builtins.REPLACE, Types.ParamBuiltinOp.REPLACE);
        pbHopMap.put(Builtins.LOWER_TRI, Types.ParamBuiltinOp.LOWER_TRI);
        pbHopMap.put(Builtins.UPPER_TRI, Types.ParamBuiltinOp.UPPER_TRI);
        pbHopMap.put(Builtins.ORDER, Types.ParamBuiltinOp.INVALID);
        pbHopMap.put(Builtins.CDF, Types.ParamBuiltinOp.CDF);
        pbHopMap.put(Builtins.PNORM, Types.ParamBuiltinOp.CDF);
        pbHopMap.put(Builtins.PT, Types.ParamBuiltinOp.CDF);
        pbHopMap.put(Builtins.PF, Types.ParamBuiltinOp.CDF);
        pbHopMap.put(Builtins.PCHISQ, Types.ParamBuiltinOp.CDF);
        pbHopMap.put(Builtins.PEXP, Types.ParamBuiltinOp.CDF);
        pbHopMap.put(Builtins.INVCDF, Types.ParamBuiltinOp.INVCDF);
        pbHopMap.put(Builtins.QNORM, Types.ParamBuiltinOp.INVCDF);
        pbHopMap.put(Builtins.QT, Types.ParamBuiltinOp.INVCDF);
        pbHopMap.put(Builtins.QF, Types.ParamBuiltinOp.INVCDF);
        pbHopMap.put(Builtins.QCHISQ, Types.ParamBuiltinOp.INVCDF);
        pbHopMap.put(Builtins.QEXP, Types.ParamBuiltinOp.INVCDF);
        pbHopMap.put(Builtins.TOSTRING, Types.ParamBuiltinOp.TOSTRING);
    }
}
