/*
 * Decompiled with CFR 0.152.
 */
package com.sap.cloud.security.xsuaa.token.authentication;

import com.github.benmanes.caffeine.cache.Caffeine;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTParser;
import com.sap.cloud.security.token.ProviderNotFoundException;
import com.sap.cloud.security.token.validation.XsuaaJkuFactory;
import com.sap.cloud.security.xsuaa.XsuaaServiceConfiguration;
import com.sap.cloud.security.xsuaa.token.authentication.PostValidationAction;
import com.sap.cloud.security.xsuaa.token.authentication.TokenInfoExtractor;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.ProviderException;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.X509EncodedKeySpec;
import java.text.ParseException;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Base64;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.ServiceConfigurationError;
import java.util.ServiceLoader;
import javax.annotation.Nullable;
import org.json.JSONObject;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.cache.Cache;
import org.springframework.cache.concurrent.ConcurrentMapCache;
import org.springframework.security.oauth2.core.OAuth2TokenValidator;
import org.springframework.security.oauth2.jwt.BadJwtException;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtDecoder;
import org.springframework.security.oauth2.jwt.JwtException;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;

/**
 * {@link JwtDecoder} that parses and verifies XSUAA access tokens.
 *
 * <p>Verification is done online against the JSON Web Key Set at {@code <uaaDomain>/token_keys}
 * (optionally tenant-scoped via {@code ?zid=}). If a custom {@link XsuaaJkuFactory} service
 * provider is registered, the JKU it creates is used instead. Per-JKU decoders are cached. If
 * the key set cannot be retrieved, or the uaa domain is missing, verification falls back to the
 * verification key of the {@link XsuaaServiceConfiguration}, when one is configured. Every
 * {@link PostValidationAction} is performed on a successfully verified token.
 */
public class XsuaaJwtDecoder implements JwtDecoder {
    private static final Logger logger = LoggerFactory.getLogger(XsuaaJwtDecoder.class);

    /** Custom JKU factories found via {@link ServiceLoader}; only the first one is consulted. */
    List<XsuaaJkuFactory> jkuFactories = loadJkuFactories();

    private final XsuaaServiceConfiguration xsuaaServiceConfiguration;
    private final Duration cacheValidityInSeconds;
    private final int cacheSize;
    /** Online decoders keyed by JKU + kid; entries expire after {@code cacheValidityInSeconds}. */
    final com.github.benmanes.caffeine.cache.Cache<String, JwtDecoder> cache;
    private final OAuth2TokenValidator<Jwt> tokenValidators;
    private final Collection<PostValidationAction> postValidationActions;
    private TokenInfoExtractor tokenInfoExtractor;
    private RestOperations restOperations;

    /**
     * Creates the decoder.
     *
     * @param xsuaaServiceConfiguration source of the uaa domain and the fallback verification key
     * @param cacheValidityInSeconds expiry (after write) of cached decoders and JWK sets
     * @param cacheSize maximum number of entries per cache
     * @param tokenValidators validators applied to every decoded token
     * @param postValidationActions actions run after successful validation; {@code null} means none
     */
    XsuaaJwtDecoder(final XsuaaServiceConfiguration xsuaaServiceConfiguration, int cacheValidityInSeconds, int cacheSize, OAuth2TokenValidator<Jwt> tokenValidators, Collection<PostValidationAction> postValidationActions) {
        this.cacheValidityInSeconds = Duration.ofSeconds(cacheValidityInSeconds);
        this.cacheSize = cacheSize;
        this.cache = Caffeine.newBuilder()
                .expireAfterWrite(this.cacheValidityInSeconds)
                .maximumSize(this.cacheSize)
                .build();
        this.tokenValidators = tokenValidators;
        this.xsuaaServiceConfiguration = xsuaaServiceConfiguration;
        this.tokenInfoExtractor = new TokenInfoExtractor() {

            @Override
            public String getJku(JWT jwt) {
                return getHeaderValue(jwt, "jku");
            }

            @Override
            public String getKid(JWT jwt) {
                return getHeaderValue(jwt, "kid");
            }

            @Override
            public String getUaaDomain(JWT jwt) {
                // the uaa domain comes from the service configuration, not from the token
                return xsuaaServiceConfiguration.getUaaDomain();
            }
        };
        this.postValidationActions = postValidationActions != null ? postValidationActions : Collections.emptyList();
    }

    /**
     * Loads the {@link XsuaaJkuFactory} service providers. Loading is best effort: on failure a
     * warning is logged and whatever was loaded before the failure is kept.
     */
    private static List<XsuaaJkuFactory> loadJkuFactories() {
        List<XsuaaJkuFactory> factories = new ArrayList<>();
        try {
            ServiceLoader.load(XsuaaJkuFactory.class).forEach(factories::add);
            logger.debug("loaded XsuaaJkuFactory service providers: {}", factories);
        } catch (Exception | ServiceConfigurationError e) {
            logger.warn("Unexpected failure while loading XsuaaJkuFactory service providers: {}", e.getMessage());
        }
        return factories;
    }

    /** Returns the header parameter {@code name} as a string, or {@code null} if absent. */
    private static String getHeaderValue(JWT jwt, String name) {
        return new JSONObject(jwt.getHeader().toString()).optString(name, null);
    }

    /**
     * Parses and verifies {@code token}, then runs all post-validation actions on it.
     *
     * @param token the encoded JWT; must not be {@code null}
     * @return the verified token
     * @throws BadJwtException if the token cannot be parsed or fails verification
     */
    @Override
    public Jwt decode(String token) throws BadJwtException {
        Assert.notNull(token, "token is required");
        JWT jwt;
        try {
            jwt = JWTParser.parse(token);
        } catch (ParseException ex) {
            throw new BadJwtException("Error initializing JWT decoder: " + ex.getMessage(), ex);
        }
        Jwt verifiedToken = verifyToken(jwt);
        postValidationActions.forEach(action -> action.perform(verifiedToken));
        return verifiedToken;
    }

    public void setTokenInfoExtractor(TokenInfoExtractor tokenInfoExtractor) {
        this.tokenInfoExtractor = tokenInfoExtractor;
    }

    /** Sets the {@link RestOperations} used by decoders created afterwards to fetch JWK sets. */
    public void setRestOperations(RestOperations restOperations) {
        this.restOperations = restOperations;
    }

    /**
     * Verifies online. Falls back to the configured verification key when the online key set is
     * unavailable.
     */
    private Jwt verifyToken(JWT jwt) {
        try {
            String kid = tokenInfoExtractor.getKid(jwt);
            String uaaDomain = tokenInfoExtractor.getUaaDomain(jwt);
            validateJwksParameters(kid, uaaDomain);
            return verifyToken(jwt.getParsedString(), kid, uaaDomain, getZid(jwt));
        } catch (JwtException e) {
            if (isOnlineVerificationUnavailable(e)) {
                logger.error(e.getMessage());
                return tryToVerifyWithVerificationKey(jwt.getParsedString(), e);
            }
            throw e;
        }
    }

    /**
     * Tells whether {@code e} means that online verification was impossible, which makes the
     * offline verification key fallback eligible. Null-safe for exceptions without a message.
     *
     * <p>NOTE(review): the message produced by {@link #validateJwksParameters} when both kid and
     * uaadomain are null ("kid, uaadomain is null") does not match, so no fallback happens in that
     * case — confirm whether this is intended.
     */
    private static boolean isOnlineVerificationUnavailable(JwtException e) {
        String message = e.getMessage();
        return message != null
                && (message.contains("Couldn't retrieve remote JWK set")
                        || message.contains("Cannot verify with online token key, uaadomain is"));
    }

    /** Returns the (unverified) {@code zid} claim, or {@code null} if absent, blank or unparsable. */
    @Nullable
    private static String getZid(JWT jwt) {
        String zid;
        try {
            zid = jwt.getJWTClaimsSet().getStringClaim("zid");
        } catch (ParseException e) {
            zid = null;
        }
        return zid == null || zid.isBlank() ? null : zid;
    }

    /**
     * Determines the JKU, either from the first custom factory or composed from the uaa domain,
     * and verifies the token with the (cached) decoder for it.
     */
    private Jwt verifyToken(String token, String kid, String uaaDomain, String zid) {
        String jku;
        if (jkuFactories.isEmpty()) {
            jku = composeJku(uaaDomain, zid);
        } else {
            // executed for every token, so keep it at debug level to avoid flooding the logs
            logger.debug("Loaded custom JKU factory");
            try {
                jku = jkuFactories.get(0).create(token);
            } catch (ProviderNotFoundException | IllegalArgumentException | ProviderException e) {
                throw new BadJwtException("JKU validation failed: " + e.getMessage(), e);
            }
        }
        return verifyWithKey(token, jku, kid);
    }

    /** Throws a {@link BadJwtException} naming every parameter that is {@code null}. */
    private void validateJwksParameters(String kid, String uaadomain) {
        if (kid != null && uaadomain != null) {
            return;
        }
        List<String> nullParams = new ArrayList<>();
        if (kid == null) {
            nullParams.add("kid");
        }
        if (uaadomain == null) {
            nullParams.add("uaadomain");
        }
        throw new BadJwtException(String.format("Cannot verify with online token key, %s is null", String.join(", ", nullParams)));
    }

    /**
     * Builds {@code https://<uaaDomain>/token_keys[?zid=<zid>]}. A uaa domain that already starts
     * with {@code http://} is used as is.
     */
    private String composeJku(String uaaDomain, String zid) {
        // zid is read from the token before its signature has been checked; encode it so that it
        // cannot inject additional query parameters into the key set URL
        String zidQueryParam = zid != null ? "?zid=" + URLEncoder.encode(zid, StandardCharsets.UTF_8) : "";
        if (uaaDomain.startsWith("http://")) {
            return uaaDomain + "/token_keys" + zidQueryParam;
        }
        return "https://" + uaaDomain + "/token_keys" + zidQueryParam;
    }

    /** Verifies {@code token} with the cached decoder for {@code jku}/{@code kid}, creating it if needed. */
    private Jwt verifyWithKey(String token, String jku, String kid) {
        // The decoder depends only on the jku. Including the kid presumably yields a fresh decoder
        // (and key-set fetch) for unseen key ids. NOTE(review): plain concatenation is not
        // injective ("...zid=a" + "bc" equals "...zid=ab" + "c"); consider adding a delimiter.
        String cacheKey = jku + kid;
        JwtDecoder decoder = cache.get(cacheKey, k -> getDecoder(jku));
        return decoder.decode(token);
    }

    /** Creates a JWK-set based decoder for {@code jku} with its own expiring key-set cache. */
    private JwtDecoder getDecoder(String jku) {
        Cache jwkSetCache = new ConcurrentMapCache("jwkSetCache",
                Caffeine.newBuilder()
                        .expireAfterWrite(cacheValidityInSeconds)
                        .maximumSize(cacheSize)
                        .build()
                        .asMap(),
                false);
        NimbusJwtDecoder.JwkSetUriJwtDecoderBuilder builder = NimbusJwtDecoder.withJwkSetUri(jku).cache(jwkSetCache);
        if (restOperations != null) {
            builder.restOperations(restOperations);
        }
        NimbusJwtDecoder jwtDecoder = builder.build();
        jwtDecoder.setJwtValidator(tokenValidators);
        return jwtDecoder;
    }

    /**
     * Falls back to the configured verification key. Rethrows {@code verificationException} if
     * no verification key is configured.
     */
    private Jwt tryToVerifyWithVerificationKey(String token, JwtException verificationException) {
        logger.debug("Falling back to token validation with verificationkey");
        String verificationKey = xsuaaServiceConfiguration.getVerificationKey();
        if (!StringUtils.hasText(verificationKey)) {
            throw verificationException;
        }
        return verifyWithVerificationKey(token, verificationKey, verificationException);
    }

    /**
     * Verifies {@code token} with the PEM encoded RSA {@code verificationKey}.
     *
     * @throws JwtException caused by {@code onlineVerificationException} if the key cannot be
     *     parsed; the key-parsing failure is attached as a suppressed exception
     */
    private Jwt verifyWithVerificationKey(String token, String verificationKey, JwtException onlineVerificationException) {
        try {
            RSAPublicKey rsaPublicKey = createPublicKey(verificationKey);
            NimbusJwtDecoder decoder = NimbusJwtDecoder.withPublicKey(rsaPublicKey).build();
            decoder.setJwtValidator(tokenValidators);
            return decoder.decode(token);
        } catch (IllegalArgumentException | NoSuchAlgorithmException | InvalidKeySpecException e) {
            logger.error("Jwt signature validation with fallback verificationkey failed: {}", e.getMessage());
            JwtException failure = new JwtException("Jwt validation with fallback verificationkey failed", onlineVerificationException);
            // keep the fallback failure reachable for diagnostics; the online failure stays the cause
            failure.addSuppressed(e);
            throw failure;
        }
    }

    /** Strips (escaped) line breaks and the PEM public key armor, leaving the base64 body. */
    private static String extractKey(String pemEncodedKey) {
        return pemEncodedKey.replace("\n", "")
                .replace("\\n", "")
                .replace("\r", "")
                .replace("\\r", "")
                .replace("-----BEGIN PUBLIC KEY-----", "")
                .replace("-----END PUBLIC KEY-----", "");
    }

    /** Parses a PEM encoded X.509 RSA public key. */
    private RSAPublicKey createPublicKey(String pemEncodedPublicKey) throws NoSuchAlgorithmException, InvalidKeySpecException {
        logger.debug("verificationkey={}", pemEncodedPublicKey);
        String key = extractKey(pemEncodedPublicKey);
        logger.debug("RSA public key n+e={}", key);
        byte[] decodedKey = Base64.getDecoder().decode(key);
        X509EncodedKeySpec specX509 = new X509EncodedKeySpec(decodedKey);
        KeyFactory keyFactory = KeyFactory.getInstance("RSA");
        RSAPublicKey rsaPublicKeyX509 = (RSAPublicKey) keyFactory.generatePublic(specX509);
        logger.debug("parsed RSA e={}, n={}", rsaPublicKeyX509.getPublicExponent(), rsaPublicKeyX509.getModulus());
        return rsaPublicKeyX509;
    }
}

