package org.wso2.ballerinalang.compiler.semantics.analyzer;

import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.stream.Collectors;
import org.ballerinalang.model.tree.NodeKind;
import org.ballerinalang.model.tree.OperatorKind;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolEnv;
import org.wso2.ballerinalang.compiler.semantics.model.SymbolTable;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BTypeSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.symbols.BVarSymbol;
import org.wso2.ballerinalang.compiler.semantics.model.types.BFiniteType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BType;
import org.wso2.ballerinalang.compiler.semantics.model.types.BUnionType;
import org.wso2.ballerinalang.compiler.tree.BLangBlockFunctionBody;
import org.wso2.ballerinalang.compiler.tree.BLangNode;
import org.wso2.ballerinalang.compiler.tree.BLangNodeVisitor;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangBinaryExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangExpression;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangGroupExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangSimpleVarRef;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangTypeTestExpr;
import org.wso2.ballerinalang.compiler.tree.expressions.BLangUnaryExpr;
import org.wso2.ballerinalang.compiler.tree.statements.BLangBlockStmt;
import org.wso2.ballerinalang.compiler.util.CompilerContext;

/* loaded from: input_file:org/wso2/ballerinalang/compiler/semantics/analyzer/TypeNarrower.class */
/**
 * Computes flow-sensitive type narrowing information for condition expressions.
 *
 * <p>While visiting a condition, this visitor stores on every analysed expression a map
 * ({@code narrowedTypeInfo}) from a variable's original symbol to a {@link BType.NarrowedTypes}
 * pair. The pair holds the type the variable has when the expression is true and when it is
 * false. Type-test expressions whose operand is a simple variable reference produce the base
 * facts. Logical NOT, AND, OR and grouping expressions combine the facts of their operands.
 * {@link #evaluateTruth} and {@link #evaluateFalsity} then define the narrowed symbols in a
 * fresh type-narrowed environment created for a target node.
 *
 * <p>One instance is registered per {@link CompilerContext}.
 */
public class TypeNarrower extends BLangNodeVisitor {

    /*
     * Raw type-tag values checked before the casts to BFiniteType / BUnionType in
     * getMemberIntersection. They are kept numerically identical to the decompiled source.
     * NOTE(review): presumably these are TypeTags.FINITE and TypeTags.UNION. TODO confirm and
     * replace them with the named TypeTags constants.
     */
    private static final int FINITE_TYPE_TAG = 31;
    private static final int UNION_TYPE_TAG = 20;

    private static final CompilerContext.Key<TypeNarrower> TYPE_NARROWER_KEY = new CompilerContext.Key<>();

    /** Environment of the expression currently being visited; swapped in and out by analyzeExpr. */
    private SymbolEnv env;
    private final SymbolTable symTable;
    private final Types types;
    private final SymbolEnter symbolEnter;

    private TypeNarrower(CompilerContext compilerContext) {
        // Register first so that getInstance() on the same context returns this instance.
        compilerContext.put(TYPE_NARROWER_KEY, this);
        this.symTable = SymbolTable.getInstance(compilerContext);
        this.types = Types.getInstance(compilerContext);
        this.symbolEnter = SymbolEnter.getInstance(compilerContext);
    }

    /**
     * Returns the {@code TypeNarrower} registered in the given context, creating and registering
     * one if none exists yet.
     *
     * @param compilerContext the compiler context acting as the instance registry
     * @return the context's type narrower
     */
    public static TypeNarrower getInstance(CompilerContext compilerContext) {
        TypeNarrower typeNarrower = (TypeNarrower) compilerContext.get(TYPE_NARROWER_KEY);
        if (typeNarrower == null) {
            typeNarrower = new TypeNarrower(compilerContext);
        }
        return typeNarrower;
    }

    /**
     * Defines the types that variables take when {@code bLangExpression} evaluates to true.
     *
     * @param bLangExpression the condition expression
     * @param bLangNode       the node whose scope receives the narrowed symbols
     * @param symbolEnv       the enclosing environment
     * @return a new type-narrowed environment for {@code bLangNode}, or {@code symbolEnv}
     *         itself when the condition narrows nothing
     */
    public SymbolEnv evaluateTruth(BLangExpression bLangExpression, BLangNode bLangNode, SymbolEnv symbolEnv) {
        return defineNarrowedSymbols(bLangExpression, bLangNode, symbolEnv, true);
    }

    /**
     * Defines the types that variables take when {@code bLangExpression} evaluates to false.
     *
     * @param bLangExpression the condition expression
     * @param bLangNode       the node whose scope receives the narrowed symbols
     * @param symbolEnv       the enclosing environment
     * @return a new type-narrowed environment for {@code bLangNode}, or {@code symbolEnv}
     *         itself when the condition narrows nothing
     */
    public SymbolEnv evaluateFalsity(BLangExpression bLangExpression, BLangNode bLangNode, SymbolEnv symbolEnv) {
        return defineNarrowedSymbols(bLangExpression, bLangNode, symbolEnv, false);
    }

    /**
     * Shared implementation of evaluateTruth / evaluateFalsity.
     * {@code whenTrue} selects which side of each narrowed-type pair is defined.
     */
    private SymbolEnv defineNarrowedSymbols(BLangExpression condition, BLangNode targetNode,
                                            SymbolEnv symbolEnv, boolean whenTrue) {
        Map<BVarSymbol, BType.NarrowedTypes> narrowed = getNarrowedTypes(condition, symbolEnv);
        if (narrowed.isEmpty()) {
            return symbolEnv;
        }
        SymbolEnv targetEnv = getTargetEnv(targetNode, symbolEnv);
        narrowed.forEach((varSymbol, narrowedPair) -> this.symbolEnter.defineTypeNarrowedSymbol(
                condition.pos, targetEnv, getOriginalVarSymbol(varSymbol),
                whenTrue ? narrowedPair.trueType : narrowedPair.falseType));
        return targetEnv;
    }

    /** Logical NOT: reuses the operand's facts with the true and false types swapped. */
    @Override
    public void visit(BLangUnaryExpr bLangUnaryExpr) {
        if (bLangUnaryExpr.operator != OperatorKind.NOT) {
            return;
        }
        Map<BVarSymbol, BType.NarrowedTypes> operandFacts = getNarrowedTypes(bLangUnaryExpr.expr, this.env);
        Map<BVarSymbol, BType.NarrowedTypes> negated = new HashMap<>(operandFacts.size());
        operandFacts.forEach((varSymbol, narrowedPair) ->
                negated.put(varSymbol, new BType.NarrowedTypes(narrowedPair.falseType, narrowedPair.trueType)));
        bLangUnaryExpr.narrowedTypeInfo = negated;
    }

    /**
     * Logical AND / OR: combines the facts of both operands for every variable narrowed by
     * either side. Other binary operators record no facts, although both operands are still
     * analysed.
     */
    @Override
    public void visit(BLangBinaryExpr bLangBinaryExpr) {
        Map<BVarSymbol, BType.NarrowedTypes> lhsFacts = getNarrowedTypes(bLangBinaryExpr.lhsExpr, this.env);
        Map<BVarSymbol, BType.NarrowedTypes> rhsFacts = getNarrowedTypes(bLangBinaryExpr.rhsExpr, this.env);
        if (bLangBinaryExpr.opKind != OperatorKind.AND && bLangBinaryExpr.opKind != OperatorKind.OR) {
            return;
        }
        LinkedHashSet<BVarSymbol> narrowedSymbols = new LinkedHashSet<>(lhsFacts.keySet());
        narrowedSymbols.addAll(rhsFacts.keySet());
        for (BVarSymbol varSymbol : narrowedSymbols) {
            BVarSymbol original = getOriginalVarSymbol(varSymbol);
            bLangBinaryExpr.narrowedTypeInfo.put(original,
                    getNarrowedTypesForBinaryOp(lhsFacts, rhsFacts, original, bLangBinaryExpr.opKind));
        }
    }

    /** Grouping expression: copies the facts of the inner expression unchanged. */
    @Override
    public void visit(BLangGroupExpr bLangGroupExpr) {
        analyzeExpr(bLangGroupExpr.expression, this.env);
        bLangGroupExpr.narrowedTypeInfo.putAll(bLangGroupExpr.expression.narrowedTypeInfo);
    }

    /**
     * Type test on a simple variable reference. When true, the variable has the intersection of
     * its type with the tested type. When false, it has the remaining type computed by
     * {@code Types.getRemainingType}.
     */
    @Override
    public void visit(BLangTypeTestExpr bLangTypeTestExpr) {
        analyzeExpr(bLangTypeTestExpr.expr, this.env);
        if (bLangTypeTestExpr.expr.getKind() != NodeKind.SIMPLE_VARIABLE_REF) {
            return;
        }
        // instanceof rejects both null and non-variable symbols. The previous unchecked cast
        // would throw ClassCastException for the latter.
        if (!(((BLangSimpleVarRef) bLangTypeTestExpr.expr).symbol instanceof BVarSymbol)) {
            return;
        }
        BVarSymbol varSymbol = (BVarSymbol) ((BLangSimpleVarRef) bLangTypeTestExpr.expr).symbol;
        BType testedType = bLangTypeTestExpr.typeNode.type;
        BType trueType = getTypeIntersection(varSymbol.type, testedType);
        BType falseType = this.types.getRemainingType(varSymbol.type, testedType);
        bLangTypeTestExpr.narrowedTypeInfo.put(getOriginalVarSymbol(varSymbol),
                new BType.NarrowedTypes(trueType, falseType));
    }

    /** Analyses {@code expression} if needed and returns its (never null) narrowing facts. */
    private Map<BVarSymbol, BType.NarrowedTypes> getNarrowedTypes(BLangExpression expression, SymbolEnv symbolEnv) {
        analyzeExpr(expression, symbolEnv);
        return expression.narrowedTypeInfo;
    }

    /**
     * Ensures {@code expression.narrowedTypeInfo} is initialised.
     *
     * <p>Only binary, type-test, group and unary expressions are visited. Their result is cached:
     * an expression whose map is already non-null is not re-visited. Every other expression
     * kind just gets an empty map.
     */
    private void analyzeExpr(BLangExpression expression, SymbolEnv symbolEnv) {
        switch (expression.getKind()) {
            case BINARY_EXPR:
            case TYPE_TEST_EXPR:
            case GROUP_EXPR:
            case UNARY_EXPR:
                SymbolEnv previousEnv = this.env;
                this.env = symbolEnv;
                try {
                    if (expression.narrowedTypeInfo == null) {
                        expression.narrowedTypeInfo = new HashMap<>();
                        expression.accept(this);
                    }
                } finally {
                    // Restore the caller's env even if the visit throws.
                    this.env = previousEnv;
                }
                break;
            default:
                if (expression.narrowedTypeInfo == null) {
                    expression.narrowedTypeInfo = new HashMap<>();
                }
                break;
        }
    }

    /**
     * Combines the per-operand facts for {@code varSymbol} under AND or OR. An operand that does
     * not narrow the variable contributes the variable's declared type on both sides.
     *
     * <ul>
     *   <li>AND: true = lhsTrue ∩ rhsTrue; false = lhsFalse ∪ (lhsTrue ∩ rhsFalse)</li>
     *   <li>OR: true = lhsTrue ∪ (lhsFalse ∩ rhsTrue); false = lhsFalse ∩ rhsFalse</li>
     * </ul>
     */
    private BType.NarrowedTypes getNarrowedTypesForBinaryOp(Map<BVarSymbol, BType.NarrowedTypes> lhsFacts,
                                                            Map<BVarSymbol, BType.NarrowedTypes> rhsFacts,
                                                            BVarSymbol varSymbol, OperatorKind operatorKind) {
        BType lhsTrue = varSymbol.type;
        BType lhsFalse = varSymbol.type;
        BType.NarrowedTypes lhs = lhsFacts.get(varSymbol);
        if (lhs != null) {
            lhsTrue = lhs.trueType;
            lhsFalse = lhs.falseType;
        }

        BType rhsTrue = varSymbol.type;
        BType rhsFalse = varSymbol.type;
        BType.NarrowedTypes rhs = rhsFacts.get(varSymbol);
        if (rhs != null) {
            rhsTrue = rhs.trueType;
            rhsFalse = rhs.falseType;
        }

        BType trueType;
        BType falseType;
        if (operatorKind == OperatorKind.AND) {
            trueType = getTypeIntersection(lhsTrue, rhsTrue);
            falseType = getTypeUnion(lhsFalse, getTypeIntersection(lhsTrue, rhsFalse));
        } else {
            trueType = getTypeUnion(lhsTrue, getTypeIntersection(lhsFalse, rhsTrue));
            falseType = getTypeIntersection(lhsFalse, rhsFalse);
        }
        return new BType.NarrowedTypes(trueType, falseType);
    }

    /**
     * Intersects {@code currentType} with each member of {@code targetType} and collects the
     * non-empty results.
     *
     * @return the semantic-error type if nothing remains or an error member is produced, the
     *         single member if exactly one remains, otherwise a new union of the members
     */
    private BType getTypeIntersection(BType currentType, BType targetType) {
        LinkedHashSet<BType> intersection = this.types.getAllTypes(targetType).stream()
                .map(targetMember -> getMemberIntersection(currentType, targetMember))
                .filter(member -> member != null)
                .collect(Collectors.toCollection(LinkedHashSet::new));
        if (intersection.isEmpty() || intersection.contains(this.symTable.semanticError)) {
            return this.symTable.semanticError;
        }
        if (intersection.size() == 1) {
            return intersection.iterator().next();
        }
        return BUnionType.create((BTypeSymbol) null, intersection);
    }

    /**
     * Intersection of {@code currentType} with one member of the target type.
     *
     * @return the intersecting type, or {@code null} when the two do not overlap
     */
    private BType getMemberIntersection(BType currentType, BType targetMember) {
        if (this.types.isAssignable(targetMember, currentType)) {
            return targetMember;
        }
        if (this.types.isAssignable(currentType, targetMember)) {
            return currentType;
        }
        if (currentType.tag == FINITE_TYPE_TAG) {
            return nullIfError(this.types.getTypeForFiniteTypeValuesAssignableToType(
                    (BFiniteType) currentType, targetMember));
        }
        if (targetMember.tag == FINITE_TYPE_TAG) {
            return nullIfError(this.types.getTypeForFiniteTypeValuesAssignableToType(
                    (BFiniteType) targetMember, currentType));
        }
        if (currentType.tag == UNION_TYPE_TAG) {
            return nullIfError(this.types.getTypeForUnionTypeMembersAssignableToType(
                    (BUnionType) currentType, targetMember));
        }
        if (targetMember.tag == UNION_TYPE_TAG) {
            return nullIfError(this.types.getTypeForUnionTypeMembersAssignableToType(
                    (BUnionType) targetMember, currentType));
        }
        return null;
    }

    /** Maps the semantic-error sentinel to {@code null} so it is dropped from an intersection. */
    private BType nullIfError(BType type) {
        return type == this.symTable.semanticError ? null : type;
    }

    /**
     * Union of the members of {@code currentType} and {@code targetType}. Members of the target
     * that are already covered by (assignable to) an existing member are skipped.
     *
     * @return the semantic-error type if any member is that error, the single member if exactly
     *         one results, otherwise a new union of the members
     */
    private BType getTypeUnion(BType currentType, BType targetType) {
        LinkedHashSet<BType> union = new LinkedHashSet<>(this.types.getAllTypes(currentType));
        for (BType newMember : this.types.getAllTypes(targetType)) {
            // Add the member only if no existing member already covers it. The previous
            // anyMatch(!isAssignable) test added it unless it was assignable to *every* existing
            // member, which produced redundant union members.
            boolean covered = union.stream().anyMatch(existing -> this.types.isAssignable(newMember, existing));
            if (!covered) {
                union.add(newMember);
            }
        }
        if (union.contains(this.symTable.semanticError)) {
            return this.symTable.semanticError;
        }
        if (union.size() == 1) {
            return union.iterator().next();
        }
        return BUnionType.create((BTypeSymbol) null, union);
    }

    /**
     * Follows the {@code originalSymbol} chain to the symbol that the narrowed symbols were
     * derived from.
     *
     * @param bVarSymbol a variable symbol, possibly itself a narrowed copy
     * @return the root symbol of the chain, or {@code bVarSymbol} if it has no original
     */
    public BVarSymbol getOriginalVarSymbol(BVarSymbol bVarSymbol) {
        BVarSymbol current = bVarSymbol;
        while (current.originalSymbol != null) {
            current = current.originalSymbol;
        }
        return current;
    }

    /**
     * Creates the type-narrowed environment for {@code bLangNode}. For block statements and
     * block function bodies, the node's scope is also pointed at the new environment's scope.
     */
    private SymbolEnv getTargetEnv(BLangNode bLangNode, SymbolEnv symbolEnv) {
        SymbolEnv narrowedEnv = SymbolEnv.createTypeNarrowedEnv(bLangNode, symbolEnv);
        NodeKind kind = bLangNode.getKind();
        if (kind == NodeKind.BLOCK) {
            ((BLangBlockStmt) bLangNode).scope = narrowedEnv.scope;
        } else if (kind == NodeKind.BLOCK_FUNCTION_BODY) {
            ((BLangBlockFunctionBody) bLangNode).scope = narrowedEnv.scope;
        }
        return narrowedEnv;
    }
}
