package com.microsoft.azure.spring.autoconfigure.aad;

import com.fasterxml.jackson.databind.JsonNode;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.source.RemoteJWKSet;
import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;
import com.nimbusds.jwt.proc.DefaultJWTClaimsVerifier;
import com.nimbusds.jwt.proc.DefaultJWTProcessor;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;

/* loaded from: input_file:com/microsoft/azure/spring/autoconfigure/aad/UserPrincipal.class */
/**
 * Holds the claims of a validated Azure Active Directory JWT together with the
 * (lazily loaded) directory groups of the user the token belongs to.
 *
 * <p>An instance created with the no-argument constructor carries no token; its
 * claim accessors return {@code null}.
 *
 * <p>NOTE(review): the group cache in {@link #getGroups(String)} is not synchronized —
 * confirm whether instances are shared between threads.
 */
public class UserPrincipal {
    private static final Logger LOG = LoggerFactory.getLogger(UserPrincipal.class);

    /** Issuer prefixes accepted by the token validator. */
    private static final String STS_WINDOWS_ISSUER = "https://sts.windows.net/";
    private static final String STS_CHINA_CLOUD_API_ISSUER = "https://sts.chinacloudapi.cn/";

    /** Name of the claim returned by {@link #getClaim()}. */
    private static final String TENANT_ID_CLAIM = "tid";

    /** Value of the {@code objectType} field that marks a membership entry as a group. */
    private static final String GROUP_OBJECT_TYPE = "Group";

    /** Prefix prepended to a group display name to build an authority name. */
    private static final String ROLE_PREFIX = "ROLE_";

    private ServiceEndpoints serviceEndpoints;
    private JWKSet jwsKeySet;
    private JWSObject jwsObject;
    private JWTClaimsSet jwtClaimsSet;
    private List<UserGroup> userGroups;

    /**
     * Creates a principal without a token, using a default {@link ServiceEndpoints} instance.
     */
    public UserPrincipal() {
        this.jwsObject = null;
        this.jwtClaimsSet = null;
        this.userGroups = null;
        this.serviceEndpoints = new ServiceEndpoints();
    }

    /**
     * Creates a principal from a serialized token, validating it in the process.
     *
     * @param str the serialized JWT to process and parse
     * @param serviceEndpoints supplies the key discovery URI and the membership REST URI
     * @throws MalformedURLException if the key discovery URI is not a valid URL
     * @throws ParseException if the token cannot be parsed
     * @throws BadJOSEException if the JWT processor or claims verifier rejects the token
     * @throws JOSEException if the JWT processor fails while processing the token
     */
    public UserPrincipal(String str, ServiceEndpoints serviceEndpoints) throws MalformedURLException, ParseException, BadJOSEException, JOSEException {
        this.serviceEndpoints = serviceEndpoints;
        this.jwsKeySet = loadAadPublicKeys();
        ConfigurableJWTProcessor<SecurityContext> aadJwtTokenValidator = getAadJwtTokenValidator();
        this.jwtClaimsSet = aadJwtTokenValidator.process(str, (SecurityContext) null);
        // NOTE(review): process() presumably already runs the configured claims verifier;
        // the explicit call is kept so the validation sequence is unchanged — confirm before removing.
        aadJwtTokenValidator.getJWTClaimsSetVerifier().verify(this.jwtClaimsSet, (SecurityContext) null);
        this.jwsObject = JWSObject.parse(str);
        this.userGroups = null;
    }

    /**
     * Loads the key set published at the configured key discovery URI.
     *
     * @return the loaded key set, or {@code null} if loading or parsing failed (the failure is logged)
     */
    private JWKSet loadAadPublicKeys() {
        try {
            return JWKSet.load(new URL(this.serviceEndpoints.getAadKeyDiscoveryUri()));
        } catch (IOException | ParseException e) {
            // Best effort: key lookup via getJWKByKid() degrades to null instead of failing construction.
            // The throwable is passed as the last argument so the stack trace is not lost.
            LOG.error("Error loading AAD public keys: {}", e.getMessage(), e);
            return null;
        }
    }

    /**
     * @return the issuer claim of the token, or {@code null} if no token is present
     */
    public String getIssuer() {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getIssuer();
    }

    /**
     * @return the subject claim of the token, or {@code null} if no token is present
     */
    public String getSubject() {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getSubject();
    }

    /**
     * @return all claims of the token, or {@code null} if no token is present
     */
    public Map<String, Object> getClaims() {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getClaims();
    }

    /**
     * @return the value of the {@code tid} claim, or {@code null} if no token is present
     */
    public Object getClaim() {
        return getClaim(TENANT_ID_CLAIM);
    }

    /**
     * Looks up a single claim by name.
     *
     * @param name the claim name
     * @return the claim value as returned by the claims set, or {@code null} if no token is present
     */
    public Object getClaim(String name) {
        return this.jwtClaimsSet == null ? null : this.jwtClaimsSet.getClaim(name);
    }

    /**
     * @return the key ID from the token header, or {@code null} if no token is present
     */
    public String getKid() {
        return this.jwsObject == null ? null : this.jwsObject.getHeader().getKeyID();
    }

    /**
     * Looks up a key in the key set loaded at construction time.
     *
     * @param str the key ID to look up
     * @return the matching key as returned by the key set, or {@code null} if no key set was loaded
     */
    public JWK getJWKByKid(String str) {
        return this.jwsKeySet == null ? null : this.jwsKeySet.getKeyByKeyId(str);
    }

    /**
     * Returns the user's groups, loading them on the first call and caching the result.
     *
     * @param str value forwarded to the membership lookup on the first call
     *            (presumably a Graph API access token — verify against caller)
     * @return the cached list of groups
     * @throws IOException if the membership lookup or the parsing of its response fails
     */
    public List<UserGroup> getGroups(String str) throws IOException {
        if (this.userGroups == null) {
            this.userGroups = loadUserGroups(str);
        }
        return this.userGroups;
    }

    /**
     * Checks membership against the cached groups only; it does not trigger a lookup, so the
     * result is {@code false} until {@link #getGroups(String)} has populated the cache.
     *
     * @param userGroup the group to test
     * @return {@code true} if the cached group list contains {@code userGroup}
     */
    public boolean isMemberOf(UserGroup userGroup) {
        return this.userGroups != null && this.userGroups.contains(userGroup);
    }

    /**
     * Builds authorities named {@code "ROLE_" + displayName} for every group whose display name
     * is contained in the given list of names.
     *
     * @param list the groups to convert
     * @param list2 the display names that are eligible to become authorities
     * @return the resulting authorities; an empty list if either argument is {@code null} or empty
     */
    public List<GrantedAuthority> getAuthoritiesByUserGroups(List<UserGroup> list, List<String> list2) {
        if (list == null || list2 == null || list.isEmpty() || list2.isEmpty()) {
            return Collections.emptyList();
        }
        return list.stream()
                .map(UserGroup::getDisplayName)
                .filter(list2::contains)
                .map(displayName -> ROLE_PREFIX + displayName)
                .map(SimpleGrantedAuthority::new)
                .collect(Collectors.<GrantedAuthority>toList());
    }

    /**
     * @return the authorities of the authentication held by the current security context, or an
     *         empty list if the context holds no authentication
     */
    public Collection<? extends GrantedAuthority> getAuthorities() {
        Authentication authentication = getAuthentication();
        if (authentication == null) {
            return Collections.emptyList();
        }
        return authentication.getAuthorities();
    }

    /**
     * @return the authentication held by the current security context, exactly as the context returns it
     */
    public Authentication getAuthentication() {
        return SecurityContextHolder.getContext().getAuthentication();
    }

    /**
     * Builds a JWT processor that selects RS256 verification keys from the configured key
     * discovery URI and additionally requires the issuer to start with one of the accepted
     * issuer prefixes.
     *
     * @return the configured processor
     * @throws MalformedURLException if the key discovery URI is not a valid URL
     */
    private ConfigurableJWTProcessor<SecurityContext> getAadJwtTokenValidator() throws MalformedURLException {
        DefaultJWTProcessor<SecurityContext> jwtProcessor = new DefaultJWTProcessor<>();
        RemoteJWKSet<SecurityContext> keySource =
                new RemoteJWKSet<>(new URL(this.serviceEndpoints.getAadKeyDiscoveryUri()));
        jwtProcessor.setJWSKeySelector(new JWSVerificationKeySelector<>(JWSAlgorithm.RS256, keySource));
        jwtProcessor.setJWTClaimsSetVerifier(new DefaultJWTClaimsVerifier<SecurityContext>() {
            @Override
            public void verify(JWTClaimsSet jWTClaimsSet, SecurityContext securityContext) throws BadJWTException {
                super.verify(jWTClaimsSet, securityContext);
                String issuer = jWTClaimsSet.getIssuer();
                // Prefix match, not substring match: an issuer that merely embeds an accepted
                // URL somewhere in its value must not be accepted.
                if (issuer == null
                        || !(issuer.startsWith(STS_WINDOWS_ISSUER) || issuer.startsWith(STS_CHINA_CLOUD_API_ISSUER))) {
                    throw new BadJWTException("Invalid token issuer");
                }
            }
        });
        return jwtProcessor;
    }

    /**
     * Fetches the user's memberships and converts the entries whose {@code objectType} is
     * {@code "Group"} into {@link UserGroup} instances.
     *
     * @param str value forwarded to the membership lookup
     * @return the groups found; empty if the response has no {@code value} array
     * @throws IOException if the lookup or the parsing of its response fails
     */
    private List<UserGroup> loadUserGroups(String str) throws IOException {
        String membershipsJson = AzureADGraphClient.getUserMembershipsV1(str, this.serviceEndpoints.getAadMembershipRestUri());
        List<UserGroup> groups = new ArrayList<>();
        JsonNode root = JacksonObjectMapperFactory.getInstance().readValue(membershipsJson, JsonNode.class);
        JsonNode values = root == null ? null : root.get("value");
        // Only an array is iterated; anything else yields no groups.
        if (values == null || !values.isArray()) {
            return groups;
        }
        for (JsonNode entry : values) {
            // path() never returns null, so an entry without objectType is simply not a group.
            if (!GROUP_OBJECT_TYPE.equals(entry.path("objectType").asText())) {
                continue;
            }
            JsonNode objectId = entry.get("objectId");
            JsonNode displayName = entry.get("displayName");
            if (objectId == null || displayName == null) {
                LOG.warn("Skipping group membership entry without objectId or displayName");
                continue;
            }
            groups.add(new UserGroup(objectId.asText(), displayName.asText()));
        }
        return groups;
    }
}
