/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.ballerinalang.compiler.semantics.analyzer;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.ballerinalang.model.tree.NodeKind;
import org.ballerinalang.model.tree.OperatorKind;
import org.wso2.ballerinalang.compiler.semantics.analyzer.SymbolEnter;
import org.wso2.ballerinalang.compiler.semantics.analyzer.Types;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolEnv;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolTable;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BVarSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.types.BFiniteType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BUnionType;
import org.wso2.ballerinalang.compiler.tree.BLangBlockFunctionBody;
import org.wso2.ballerinalang.compiler.tree.BLangNode;
import org.wso2.ballerinalang.compiler.tree.BLangNodeVisitor;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangBinaryExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangExpression;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangGroupExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangSimpleVarRef;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangTypeTestExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangUnaryExpr;
import org.wso2.ballerinalang.compiler.tree.statements.BLangBlockStmt;
import org.wso2.ballerinalang.compiler.util.CompilerContext;

/**
 * Computes how a condition narrows the types of the variables it tests, and defines the narrowed
 * symbols in a new environment for the node guarded by that condition.
 *
 * <p>Narrowing information is cached on each analyzed expression in its {@code narrowedTypeInfo}
 * map, keyed by the original (un-narrowed) variable symbol. Each value holds the type of the
 * variable when the expression is {@code true} and when it is {@code false}. Only type-test,
 * binary ({@code AND}/{@code OR}), unary ({@code NOT}) and group expressions contribute
 * narrowing; every other expression kind gets an empty map.
 */
public class TypeNarrower extends BLangNodeVisitor {

    private static final CompilerContext.Key<TypeNarrower> TYPE_NARROWER_KEY = new CompilerContext.Key<>();

    // Type tag of finite types (the code below casts types carrying this tag to BFiniteType).
    // NOTE(review): value copied from the original literal; must equal TypeTags.FINITE -- confirm.
    private static final int FINITE_TYPE_TAG = 31;
    // Type tag of union types (the code below casts types carrying this tag to BUnionType).
    // NOTE(review): value copied from the original literal; must equal TypeTags.UNION -- confirm.
    private static final int UNION_TYPE_TAG = 20;

    /** Environment in which the expression currently being visited is analyzed. */
    private SymbolEnv env;
    private final SymbolTable symTable;
    private final Types types;
    private final SymbolEnter symbolEnter;

    private TypeNarrower(CompilerContext context) {
        context.put(TYPE_NARROWER_KEY, this);
        this.symTable = SymbolTable.getInstance(context);
        this.types = Types.getInstance(context);
        this.symbolEnter = SymbolEnter.getInstance(context);
    }

    /**
     * Returns the {@code TypeNarrower} registered in {@code context}, creating and registering one
     * if none exists yet.
     *
     * @param context compiler context holding per-compilation singletons
     * @return the context's type narrower
     */
    public static TypeNarrower getInstance(CompilerContext context) {
        TypeNarrower typeNarrower = context.get(TYPE_NARROWER_KEY);
        if (typeNarrower == null) {
            typeNarrower = new TypeNarrower(context);
        }
        return typeNarrower;
    }

    /**
     * Defines, in a new environment created for {@code targetNode}, the type each variable is
     * narrowed to when {@code expr} evaluates to {@code true}.
     *
     * @param expr       condition expression
     * @param targetNode node executed when the condition holds
     * @param env        enclosing environment
     * @return the new narrowed environment, or {@code env} itself if {@code expr} narrows nothing
     */
    public SymbolEnv evaluateTruth(BLangExpression expr, BLangNode targetNode, SymbolEnv env) {
        return defineNarrowedSymbols(expr, targetNode, env, true);
    }

    /**
     * Defines, in a new environment created for {@code targetNode}, the type each variable is
     * narrowed to when {@code expr} evaluates to {@code false}.
     *
     * @param expr       condition expression
     * @param targetNode node executed when the condition does not hold
     * @param env        enclosing environment
     * @return the new narrowed environment, or {@code env} itself if {@code expr} narrows nothing
     */
    public SymbolEnv evaluateFalsity(BLangExpression expr, BLangNode targetNode, SymbolEnv env) {
        return defineNarrowedSymbols(expr, targetNode, env, false);
    }

    /** Shared implementation of {@link #evaluateTruth} and {@link #evaluateFalsity}. */
    private SymbolEnv defineNarrowedSymbols(BLangExpression expr, BLangNode targetNode, SymbolEnv env,
                                            boolean whenTrue) {
        Map<BVarSymbol, BType.NarrowedTypes> narrowedTypes = getNarrowedTypes(expr, env);
        if (narrowedTypes.isEmpty()) {
            return env;
        }
        SymbolEnv targetEnv = getTargetEnv(targetNode, env);
        narrowedTypes.forEach((symbol, typeInfo) -> symbolEnter.defineTypeNarrowedSymbol(expr.pos, targetEnv,
                getOriginalVarSymbol(symbol), whenTrue ? typeInfo.trueType : typeInfo.falseType));
        return targetEnv;
    }

    /**
     * For {@code !expr}, the narrowing is that of {@code expr} with the true and false types
     * swapped. Other unary operators contribute no narrowing.
     */
    @Override
    public void visit(BLangUnaryExpr unaryExpr) {
        if (unaryExpr.operator != OperatorKind.NOT) {
            return;
        }
        Map<BVarSymbol, BType.NarrowedTypes> operandTypes = getNarrowedTypes(unaryExpr.expr, this.env);
        Map<BVarSymbol, BType.NarrowedTypes> negated = new HashMap<>(operandTypes.size());
        for (Map.Entry<BVarSymbol, BType.NarrowedTypes> entry : operandTypes.entrySet()) {
            BType.NarrowedTypes types = entry.getValue();
            negated.put(entry.getKey(), new BType.NarrowedTypes(types.falseType, types.trueType));
        }
        unaryExpr.narrowedTypeInfo = negated;
    }

    /**
     * Combines the narrowing of both operands of a logical {@code AND}/{@code OR} for every symbol
     * narrowed by either side. Both operands are always analyzed (so their narrowing is cached),
     * but other binary operators contribute no narrowing themselves.
     */
    @Override
    public void visit(BLangBinaryExpr binaryExpr) {
        Map<BVarSymbol, BType.NarrowedTypes> lhsTypes = getNarrowedTypes(binaryExpr.lhsExpr, this.env);
        Map<BVarSymbol, BType.NarrowedTypes> rhsTypes = getNarrowedTypes(binaryExpr.rhsExpr, this.env);
        OperatorKind opKind = binaryExpr.opKind;
        if (opKind != OperatorKind.AND && opKind != OperatorKind.OR) {
            return;
        }
        LinkedHashSet<BVarSymbol> updatedSymbols = new LinkedHashSet<>(lhsTypes.keySet());
        updatedSymbols.addAll(rhsTypes.keySet());
        // Several symbols may resolve to the same original symbol; their combined narrowing is
        // computed from that original symbol alone, so keeping the first value is lossless.
        Map<BVarSymbol, BType.NarrowedTypes> combined = updatedSymbols.stream()
                .map(this::getOriginalVarSymbol)
                .collect(Collectors.toMap(
                        original -> original,
                        original -> getNarrowedTypesForBinaryOp(lhsTypes, rhsTypes, original, opKind),
                        (first, ignored) -> first));
        binaryExpr.narrowedTypeInfo.putAll(combined);
    }

    /** A parenthesized expression narrows exactly like its inner expression. */
    @Override
    public void visit(BLangGroupExpr groupExpr) {
        analyzeExpr(groupExpr.expression, this.env);
        groupExpr.narrowedTypeInfo.putAll(groupExpr.expression.narrowedTypeInfo);
    }

    /**
     * For {@code x is T} where {@code x} is a simple variable reference with a resolved symbol, the
     * true type is the intersection of {@code x}'s type with {@code T} and the false type is the
     * remaining type computed by {@code Types#getRemainingType}.
     */
    @Override
    public void visit(BLangTypeTestExpr typeTestExpr) {
        analyzeExpr(typeTestExpr.expr, this.env);
        if (typeTestExpr.expr.getKind() != NodeKind.SIMPLE_VARIABLE_REF) {
            return;
        }
        BVarSymbol varSymbol = (BVarSymbol) ((BLangSimpleVarRef) typeTestExpr.expr).symbol;
        if (varSymbol == null) {
            return;
        }
        BType testedType = typeTestExpr.typeNode.type;
        BType trueType = getTypeIntersection(varSymbol.type, testedType);
        BType falseType = types.getRemainingType(varSymbol.type, testedType);
        typeTestExpr.narrowedTypeInfo.put(getOriginalVarSymbol(varSymbol),
                new BType.NarrowedTypes(trueType, falseType));
    }

    /** Analyzes {@code expr} (if not done already) and returns its cached narrowing map. */
    private Map<BVarSymbol, BType.NarrowedTypes> getNarrowedTypes(BLangExpression expr, SymbolEnv env) {
        analyzeExpr(expr, env);
        return expr.narrowedTypeInfo;
    }

    /**
     * Ensures {@code expr.narrowedTypeInfo} is populated. Expressions that cannot narrow get an
     * empty map; narrowing expressions are visited once, in {@code env}, and the result is cached.
     */
    private void analyzeExpr(BLangExpression expr, SymbolEnv env) {
        switch (expr.getKind()) {
            case BINARY_EXPR:
            case TYPE_TEST_EXPR:
            case GROUP_EXPR:
            case UNARY_EXPR:
                break;
            default:
                if (expr.narrowedTypeInfo == null) {
                    expr.narrowedTypeInfo = new HashMap<>();
                }
                return;
        }
        if (expr.narrowedTypeInfo != null) {
            // Already analyzed; reuse the cached result.
            return;
        }
        SymbolEnv prevEnv = this.env;
        this.env = env;
        try {
            expr.narrowedTypeInfo = new HashMap<>();
            expr.accept(this);
        } finally {
            this.env = prevEnv;
        }
    }

    /**
     * Combines the narrowing of {@code symbol} from both operands of a logical operator. A side
     * that does not narrow the symbol contributes its declared type for both outcomes.
     *
     * <ul>
     * <li>{@code AND}: true = lhsTrue &#8745; rhsTrue; false = lhsFalse &#8746; (lhsTrue &#8745; rhsFalse)</li>
     * <li>{@code OR}: true = lhsTrue &#8746; (lhsFalse &#8745; rhsTrue); false = lhsFalse &#8745; rhsFalse</li>
     * </ul>
     */
    private BType.NarrowedTypes getNarrowedTypesForBinaryOp(Map<BVarSymbol, BType.NarrowedTypes> lhsTypes,
                                                            Map<BVarSymbol, BType.NarrowedTypes> rhsTypes,
                                                            BVarSymbol symbol, OperatorKind operator) {
        BType lhsTrueType = symbol.type;
        BType lhsFalseType = symbol.type;
        if (lhsTypes.containsKey(symbol)) {
            BType.NarrowedTypes lhs = lhsTypes.get(symbol);
            lhsTrueType = lhs.trueType;
            lhsFalseType = lhs.falseType;
        }
        BType rhsTrueType = symbol.type;
        BType rhsFalseType = symbol.type;
        if (rhsTypes.containsKey(symbol)) {
            BType.NarrowedTypes rhs = rhsTypes.get(symbol);
            rhsTrueType = rhs.trueType;
            rhsFalseType = rhs.falseType;
        }

        BType trueType;
        BType falseType;
        if (operator == OperatorKind.AND) {
            trueType = getTypeIntersection(lhsTrueType, rhsTrueType);
            falseType = getTypeUnion(lhsFalseType, getTypeIntersection(lhsTrueType, rhsFalseType));
        } else {
            trueType = getTypeUnion(lhsTrueType, getTypeIntersection(lhsFalseType, rhsTrueType));
            falseType = getTypeIntersection(lhsFalseType, rhsFalseType);
        }
        return new BType.NarrowedTypes(trueType, falseType);
    }

    /**
     * Intersects {@code currentType} with each member of {@code targetType}. Returns the single
     * resulting type, a union of the results, or {@code semanticError} when nothing intersects or
     * any result is {@code semanticError}.
     */
    private BType getTypeIntersection(BType currentType, BType targetType) {
        List<BType> narrowingTypes = types.getAllTypes(targetType);
        LinkedHashSet<BType> intersection = new LinkedHashSet<>();
        for (BType narrowingType : narrowingTypes) {
            BType memberIntersection = getMemberIntersection(currentType, narrowingType);
            if (memberIntersection != null) {
                intersection.add(memberIntersection);
            }
        }
        if (intersection.isEmpty() || intersection.contains(symTable.semanticError)) {
            return symTable.semanticError;
        }
        if (intersection.size() == 1) {
            return intersection.iterator().next();
        }
        return BUnionType.create(null, intersection);
    }

    /**
     * Intersection of {@code currentType} with one member {@code type}, or {@code null} when none
     * is found. Assignability is checked first in both directions; otherwise the finite/union
     * helpers of {@code Types} are consulted (first applicable branch only).
     */
    private BType getMemberIntersection(BType currentType, BType type) {
        if (types.isAssignable(type, currentType)) {
            return type;
        }
        if (types.isAssignable(currentType, type)) {
            return currentType;
        }
        BType result;
        if (currentType.tag == FINITE_TYPE_TAG) {
            result = types.getTypeForFiniteTypeValuesAssignableToType((BFiniteType) currentType, type);
        } else if (type.tag == FINITE_TYPE_TAG) {
            result = types.getTypeForFiniteTypeValuesAssignableToType((BFiniteType) type, currentType);
        } else if (currentType.tag == UNION_TYPE_TAG) {
            result = types.getTypeForUnionTypeMembersAssignableToType((BUnionType) currentType, type);
        } else if (type.tag == UNION_TYPE_TAG) {
            result = types.getTypeForUnionTypeMembersAssignableToType((BUnionType) type, currentType);
        } else {
            return null;
        }
        return result == symTable.semanticError ? null : result;
    }

    /**
     * Unions the members of {@code currentType} with those members of {@code targetType} that are
     * not already assignable to some member of the union. Returns {@code semanticError} if the
     * resulting member set contains it.
     */
    private BType getTypeUnion(BType currentType, BType targetType) {
        LinkedHashSet<BType> union = new LinkedHashSet<>(types.getAllTypes(currentType));
        for (BType newType : types.getAllTypes(targetType)) {
            // Skip members already covered by an existing member (previously a member was only
            // skipped if assignable to *every* existing member, leaving redundant subtypes).
            boolean covered = union.stream().anyMatch(existingType -> types.isAssignable(newType, existingType));
            if (!covered) {
                union.add(newType);
            }
        }
        if (union.contains(symTable.semanticError)) {
            return symTable.semanticError;
        }
        if (union.size() == 1) {
            return union.iterator().next();
        }
        return BUnionType.create(null, union);
    }

    /** Follows the {@code originalSymbol} chain back to the symbol that is not itself narrowed. */
    BVarSymbol getOriginalVarSymbol(BVarSymbol varSymbol) {
        BVarSymbol original = varSymbol;
        while (original.originalSymbol != null) {
            original = original.originalSymbol;
        }
        return original;
    }

    /**
     * Creates a type-narrowed environment for {@code targetNode} and, when the node is a block
     * statement or block function body, assigns the new scope to it.
     */
    private SymbolEnv getTargetEnv(BLangNode targetNode, SymbolEnv env) {
        SymbolEnv targetEnv = SymbolEnv.createTypeNarrowedEnv(targetNode, env);
        NodeKind kind = targetNode.getKind();
        if (kind == NodeKind.BLOCK) {
            ((BLangBlockStmt) targetNode).scope = targetEnv.scope;
        } else if (kind == NodeKind.BLOCK_FUNCTION_BODY) {
            ((BLangBlockFunctionBody) targetNode).scope = targetEnv.scope;
        }
        return targetEnv;
    }
}

