/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.mpp.plan.analyze;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.exception.sql.StatementAnalyzeException;
import org.apache.iotdb.db.metadata.path.MeasurementPath;
import org.apache.iotdb.db.mpp.plan.analyze.TypeProvider;
import org.apache.iotdb.db.mpp.plan.expression.Expression;
import org.apache.iotdb.db.mpp.plan.expression.leaf.TimeSeriesOperand;
import org.apache.iotdb.db.mpp.plan.expression.multi.FunctionExpression;
import org.apache.iotdb.db.query.aggregation.AggregationType;
import org.apache.iotdb.db.utils.SchemaUtils;
import org.apache.iotdb.tsfile.file.metadata.enums.TSDataType;

/**
 * Collects the aggregation expressions of a GROUP BY LEVEL query and maps each one onto an
 * aggregation over a "level-grouped" path, in which intermediate path nodes not listed in
 * {@code levels} are replaced by {@code "*"}.
 *
 * <p>While doing so it registers output types in the supplied {@link TypeProvider}, rejects
 * grouped columns whose inputs have inconsistent data types (for value-preserving aggregations),
 * and keeps a one-to-one mapping between grouped result columns and user aliases.
 *
 * <p>Instances hold mutable state and are not synchronized.
 */
public class GroupByLevelController {
    /** Path-node indexes that are kept verbatim in grouped paths; other intermediate nodes become "*". */
    private final int[] levels;
    /** Grouped aggregation expression -> original aggregation expressions merged into it (insertion-ordered). */
    private final Map<Expression, Set<Expression>> groupedPathMap;
    /** Raw series operand -> grouped series operand; the first mapping registered for a raw path is kept. */
    private final Map<Expression, Expression> rawPathToGroupedPathMap;
    /** Grouped result column string -> user alias. */
    private final Map<String, String> columnToAliasMap;
    /** User alias -> grouped result column string (inverse of {@link #columnToAliasMap}). */
    private final Map<String, String> aliasToColumnMap;
    /** Receives the data types of grouped columns and of their split partial aggregations. */
    private final TypeProvider typeProvider;

    /**
     * Creates a controller for one GROUP BY LEVEL clause.
     *
     * @param levels path-node indexes to preserve when grouping; the array is stored as-is
     * @param typeProvider type registry that this controller populates
     */
    public GroupByLevelController(int[] levels, TypeProvider typeProvider) {
        this.levels = levels;
        this.groupedPathMap = new LinkedHashMap<>();
        this.rawPathToGroupedPathMap = new HashMap<>();
        this.columnToAliasMap = new HashMap<>();
        this.aliasToColumnMap = new HashMap<>();
        this.typeProvider = typeProvider;
    }

    /**
     * Registers one aggregation expression of the query with this controller.
     *
     * <p>The expression's series argument is rewritten to its level-grouped path. The grouped
     * expression is recorded together with the original one, and output types are registered.
     * If {@code alias} is non-null, it is bound to the grouped result column.
     *
     * @param isCountStar if true, the measurement node of the grouped path is replaced by "*"
     * @param expression a built-in aggregation function applied to a single time series
     * @param alias the user alias for this result column, or {@code null}
     * @throws SemanticException if {@code expression} is not a built-in aggregation over a time
     *     series, or if value-preserving aggregations grouped together have differing input types
     * @throws StatementAnalyzeException if the alias conflicts with an existing alias mapping
     */
    public void control(boolean isCountStar, Expression expression, String alias) {
        if (!(expression instanceof FunctionExpression) || !expression.isBuiltInAggregationFunctionExpression()) {
            throw new SemanticException(expression + " can't be used in group by level.");
        }
        List<Expression> arguments = expression.getExpressions();
        // Without this guard a non-series or missing argument would surface as a ClassCastException /
        // IndexOutOfBoundsException instead of a semantic error.
        if (arguments.isEmpty() || !(arguments.get(0) instanceof TimeSeriesOperand)) {
            throw new SemanticException(expression + " can't be used in group by level.");
        }
        FunctionExpression functionExpression = (FunctionExpression) expression;
        String functionName = functionExpression.getFunctionName();

        PartialPath rawPath = ((TimeSeriesOperand) arguments.get(0)).getPath();
        PartialPath groupedPath = generatePartialPathByLevel(isCountStar, rawPath, levels);
        String groupedFullPath = groupedPath.getFullPath();
        checkDatatypeConsistency(groupedFullPath, functionName, rawPath);
        updateTypeProvider(functionName, groupedFullPath, rawPath);

        TimeSeriesOperand rawPathExpression = new TimeSeriesOperand(rawPath);
        TimeSeriesOperand groupedPathExpression = new TimeSeriesOperand(groupedPath);
        rawPathToGroupedPathMap.putIfAbsent(rawPathExpression, groupedPathExpression);

        FunctionExpression groupedExpression = new FunctionExpression(
                functionName,
                functionExpression.getFunctionAttributes(),
                Collections.singletonList(groupedPathExpression));
        groupedPathMap.computeIfAbsent(groupedExpression, key -> new HashSet<>()).add(expression);

        if (alias != null) {
            checkAliasAndUpdateAliasMap(alias, groupedExpression.getExpressionString());
        }
    }

    /**
     * Records the data type of a grouped column. For value-preserving aggregations, it also
     * verifies that every raw series grouped into that column has the same data type.
     *
     * @param groupedPath full string of the grouped path
     * @param functionName aggregation function name (case-insensitive)
     * @param rawPath the raw series being grouped
     * @throws SemanticException if a value-preserving aggregation sees a differing series type
     * @throws IllegalArgumentException if {@code functionName} is not a supported aggregation
     */
    private void checkDatatypeConsistency(String groupedPath, String functionName, PartialPath rawPath) {
        // Locale.ROOT: the default locale may map 'I' to a dotless 'ı' (e.g. Turkish), breaking the match.
        switch (functionName.toLowerCase(Locale.ROOT)) {
            case "min_time":
            case "max_time":
            case "count":
            case "avg":
            case "sum":
                // Mixed input types are tolerated for these functions; only the first type seen is recorded.
                if (!typeProvider.containsTypeInfoOf(groupedPath)) {
                    typeProvider.setType(groupedPath, rawPath.getSeriesType());
                }
                return;
            case "min_value":
            case "last_value":
            case "first_value":
            case "max_value":
            case "extreme":
                if (!typeProvider.containsTypeInfoOf(groupedPath)) {
                    typeProvider.setType(groupedPath, rawPath.getSeriesType());
                } else {
                    TSDataType tsDataType = typeProvider.getType(groupedPath);
                    if (tsDataType != rawPath.getSeriesType()) {
                        throw new SemanticException(String.format("GROUP BY LEVEL: the data types of the same output column[%s] should be the same.", groupedPath));
                    }
                }
                return;
            default:
                throw new IllegalArgumentException("Invalid Aggregation function: " + functionName);
        }
    }

    /**
     * Binds {@code alias} to the grouped result column, enforcing a one-to-one mapping.
     *
     * @throws StatementAnalyzeException if the alias is already bound to another column, or the
     *     column already has a different alias
     */
    private void checkAliasAndUpdateAliasMap(String alias, String groupedExpressionString) throws StatementAnalyzeException {
        String existingAlias = columnToAliasMap.get(groupedExpressionString);
        if (existingAlias == null) {
            if (aliasToColumnMap.get(alias) != null) {
                throw new StatementAnalyzeException(String.format("alias '%s' can only be matched with one result column", alias));
            }
            columnToAliasMap.put(groupedExpressionString, alias);
            aliasToColumnMap.put(alias, groupedExpressionString);
        } else if (!existingAlias.equals(alias)) {
            throw new StatementAnalyzeException(String.format("Result column %s with more than one alias[%s, %s]", groupedExpressionString, existingAlias, alias));
        }
    }

    /**
     * Builds the level-grouped path for {@code rawPath}.
     *
     * <p>Node 0 is always kept. Each intermediate node is kept if its index is in
     * {@code pathLevels}, otherwise it is replaced by "*". The last node is kept unless
     * {@code isCountStar}, in which case it is also "*". The measurement schema and, if present,
     * the measurement alias are copied from {@code rawPath}.
     *
     * @param isCountStar whether the measurement node should be replaced by "*"
     * @param rawPath the source path; must be a {@link MeasurementPath} (it is cast unconditionally)
     * @param pathLevels node indexes to preserve
     * @return the grouped path, as a {@link MeasurementPath}
     */
    public PartialPath generatePartialPathByLevel(boolean isCountStar, PartialPath rawPath, int[] pathLevels) {
        String[] nodes = rawPath.getNodes();
        Set<Integer> levelSet = new HashSet<>();
        for (int level : pathLevels) {
            levelSet.add(level);
        }
        List<String> transformedNodes = new ArrayList<>(nodes.length);
        transformedNodes.add(nodes[0]);
        for (int k = 1; k < nodes.length - 1; ++k) {
            transformedNodes.add(levelSet.contains(k) ? nodes[k] : "*");
        }
        transformedNodes.add(isCountStar ? "*" : nodes[nodes.length - 1]);

        MeasurementPath groupedPath = new MeasurementPath(
                new PartialPath(transformedNodes.toArray(new String[0])),
                ((MeasurementPath) rawPath).getMeasurementSchema());
        if (rawPath.isMeasurementAliasExists()) {
            groupedPath.setMeasurementAlias(rawPath.getMeasurementAlias());
        }
        return groupedPath;
    }

    /** Returns the live (mutable) grouped-expression map, in insertion order. */
    public Map<Expression, Set<Expression>> getGroupedPathMap() {
        return groupedPathMap;
    }

    /**
     * Returns the alias bound to a grouped result column.
     *
     * @param columnName grouped expression string
     * @return the alias, or {@code null} if none was registered
     */
    public String getAlias(String columnName) {
        return columnToAliasMap.get(columnName);
    }

    /** Returns the live (mutable) map from raw series operand to grouped series operand. */
    public Map<Expression, Expression> getRawPathToGroupedPathMap() {
        return rawPathToGroupedPathMap;
    }

    /**
     * Registers, for every partial aggregation that {@code functionName} splits into, the type of
     * {@code "<partial>(<groupedPath>)"} as computed by {@link SchemaUtils#getSeriesTypeByPath}.
     *
     * @throws IllegalArgumentException if {@code functionName} is not an {@link AggregationType}
     */
    private void updateTypeProvider(String functionName, String groupedPath, PartialPath rawPath) {
        // Locale.ROOT so that e.g. "min_time" does not become "MİN_TİME" under a Turkish default locale.
        AggregationType aggregationType = AggregationType.valueOf(functionName.toUpperCase(Locale.ROOT));
        List<AggregationType> splitAggregations = SchemaUtils.splitPartialAggregation(aggregationType);
        for (AggregationType splitAggregation : splitAggregations) {
            String splitFunctionName = splitAggregation.toString().toLowerCase(Locale.ROOT);
            typeProvider.setType(String.format("%s(%s)", splitFunctionName, groupedPath), SchemaUtils.getSeriesTypeByPath(rawPath, splitFunctionName));
        }
    }
}

