package io.micronaut.security.oauth2.endpoint.authorization.response;

import com.nimbusds.jwt.JWT;
import io.micronaut.context.annotation.Requirements;
import io.micronaut.context.annotation.Requires;
import io.micronaut.core.annotation.Nullable;
import io.micronaut.security.authentication.AuthenticationFailed;
import io.micronaut.security.authentication.AuthenticationResponse;
import io.micronaut.security.oauth2.client.OpenIdProviderMetadata;
import io.micronaut.security.oauth2.configuration.OauthClientConfiguration;
import io.micronaut.security.oauth2.endpoint.SecureEndpoint;
import io.micronaut.security.oauth2.endpoint.authorization.pkce.persistence.PkcePersistence;
import io.micronaut.security.oauth2.endpoint.authorization.state.InvalidStateException;
import io.micronaut.security.oauth2.endpoint.authorization.state.State;
import io.micronaut.security.oauth2.endpoint.authorization.state.validation.StateValidator;
import io.micronaut.security.oauth2.endpoint.token.request.TokenEndpointClient;
import io.micronaut.security.oauth2.endpoint.token.request.context.OpenIdCodeTokenRequestContext;
import io.micronaut.security.oauth2.endpoint.token.response.JWTOpenIdClaims;
import io.micronaut.security.oauth2.endpoint.token.response.OpenIdAuthenticationMapper;
import io.micronaut.security.oauth2.endpoint.token.response.OpenIdTokenResponse;
import io.micronaut.security.oauth2.endpoint.token.response.validation.OpenIdTokenResponseValidator;
import io.micronaut.security.oauth2.url.OauthRouteUrlBuilder;
import jakarta.inject.Singleton;
import java.text.ParseException;
import java.util.Objects;
import java.util.Optional;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

/**
 * Default {@link OpenIdAuthorizationResponseHandler}.
 *
 * <p>Handling an authorization response runs these steps in order:
 * <ol>
 *     <li>Validate the response's state. This is skipped when no {@link StateValidator} is present.</li>
 *     <li>Exchange the authorization code through the {@link TokenEndpointClient}.</li>
 *     <li>Validate the returned {@link OpenIdTokenResponse} with the {@link OpenIdTokenResponseValidator}.</li>
 *     <li>Map the validated JWT claims to an {@link AuthenticationResponse} with an {@link OpenIdAuthenticationMapper}.</li>
 * </ol>
 *
 * <p>When no {@link PkcePersistence} is present, a {@code null} code verifier is passed to the
 * {@link OpenIdCodeTokenRequestContext}.
 *
 * @param <T> the type parameter of the {@link OauthRouteUrlBuilder} in use
 */
@Requirements({
    @Requires(beans = {OpenIdTokenResponseValidator.class, OpenIdAuthenticationMapper.class, TokenEndpointClient.class, OauthRouteUrlBuilder.class}),
    @Requires(configuration = "io.micronaut.security.token.jwt")
})
@Singleton
public class DefaultOpenIdAuthorizationResponseHandler<T> implements OpenIdAuthorizationResponseHandler {
    private static final Logger LOG = LoggerFactory.getLogger(DefaultOpenIdAuthorizationResponseHandler.class);

    private final OpenIdTokenResponseValidator tokenResponseValidator;
    private final OpenIdAuthenticationMapper defaultAuthenticationMapper;
    private final TokenEndpointClient tokenEndpointClient;
    private final OauthRouteUrlBuilder<T> oauthRouteUrlBuilder;

    @Nullable
    private final StateValidator stateValidator;

    @Nullable
    private final PkcePersistence pkcePersistence;

    /**
     * @param openIdTokenResponseValidator validator applied to the token endpoint response
     * @param openIdAuthenticationMapper   mapper used when {@link #handle} is not given one
     * @param tokenEndpointClient          client used to call the token endpoint
     * @param oauthRouteUrlBuilder         passed into the {@link OpenIdCodeTokenRequestContext}
     * @param stateValidator               state validator, or {@code null} to skip state validation
     * @param pkcePersistence              source of the PKCE code verifier, or {@code null} if PKCE is not used
     */
    public DefaultOpenIdAuthorizationResponseHandler(OpenIdTokenResponseValidator openIdTokenResponseValidator,
                                                     OpenIdAuthenticationMapper openIdAuthenticationMapper,
                                                     TokenEndpointClient tokenEndpointClient,
                                                     OauthRouteUrlBuilder<T> oauthRouteUrlBuilder,
                                                     @Nullable StateValidator stateValidator,
                                                     @Nullable PkcePersistence pkcePersistence) {
        this.tokenResponseValidator = openIdTokenResponseValidator;
        this.defaultAuthenticationMapper = openIdAuthenticationMapper;
        this.tokenEndpointClient = tokenEndpointClient;
        this.oauthRouteUrlBuilder = oauthRouteUrlBuilder;
        this.stateValidator = stateValidator;
        this.pkcePersistence = pkcePersistence;
    }

    /**
     * Validates the state, exchanges the authorization code and maps the result.
     *
     * @return a publisher emitting {@link AuthenticationFailed} if state validation fails; otherwise the
     *         authentication response(s) derived from the token endpoint response
     */
    @Override
    public Publisher<AuthenticationResponse> handle(OpenIdAuthorizationResponse openIdAuthorizationResponse,
                                                    OauthClientConfiguration oauthClientConfiguration,
                                                    OpenIdProviderMetadata openIdProviderMetadata,
                                                    @Nullable OpenIdAuthenticationMapper openIdAuthenticationMapper,
                                                    SecureEndpoint secureEndpoint) {
        // Only state validation throws InvalidStateException; keep the try block narrow.
        try {
            validateState(openIdAuthorizationResponse, oauthClientConfiguration);
        } catch (InvalidStateException e) {
            return Flux.just(new AuthenticationFailed("State validation failed: " + e.getMessage()));
        }
        return Flux.from(sendRequest(openIdAuthorizationResponse, oauthClientConfiguration, secureEndpoint))
                .switchMap(tokenResponse -> createAuthenticationResponse(
                        openIdAuthorizationResponse.getNonce(),
                        oauthClientConfiguration,
                        openIdProviderMetadata,
                        tokenResponse,
                        openIdAuthenticationMapper,
                        openIdAuthorizationResponse.getState()));
    }

    /**
     * Validates the state of the response with the configured {@link StateValidator}, if any.
     *
     * @throws InvalidStateException if the state validator rejects the state
     */
    private void validateState(OpenIdAuthorizationResponse openIdAuthorizationResponse,
                               OauthClientConfiguration oauthClientConfiguration) throws InvalidStateException {
        if (stateValidator == null) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("Skipping state validation, no state validator found");
            }
            return;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Validating state found in the authorization response from provider [{}]", oauthClientConfiguration.getName());
        }
        stateValidator.validate(openIdAuthorizationResponse.getCallbackRequest(), openIdAuthorizationResponse.getState());
    }

    /**
     * Sends the authorization-code token request.
     *
     * <p>The PKCE code verifier is included when a {@link PkcePersistence} is present and holds one
     * for the callback request.
     */
    private Publisher<OpenIdTokenResponse> sendRequest(OpenIdAuthorizationResponse openIdAuthorizationResponse,
                                                       OauthClientConfiguration oauthClientConfiguration,
                                                       SecureEndpoint secureEndpoint) {
        String codeVerifier = pkcePersistence == null
                ? null
                : pkcePersistence.retrieveCodeVerifier(openIdAuthorizationResponse.getCallbackRequest()).orElse(null);
        return tokenEndpointClient.sendRequest(new OpenIdCodeTokenRequestContext(
                openIdAuthorizationResponse, oauthRouteUrlBuilder, secureEndpoint, oauthClientConfiguration, codeVerifier));
    }

    /**
     * Validates the token response and maps it to authentication responses.
     *
     * @return the mapped authentication response(s). Errors with {@code AuthenticationResponse.exception(...)}
     *         when the token validator returns no JWT, or with the {@link ParseException} if the JWT claims
     *         cannot be parsed.
     */
    private Flux<AuthenticationResponse> createAuthenticationResponse(String nonce,
                                                                      OauthClientConfiguration oauthClientConfiguration,
                                                                      OpenIdProviderMetadata openIdProviderMetadata,
                                                                      OpenIdTokenResponse openIdTokenResponse,
                                                                      @Nullable OpenIdAuthenticationMapper openIdAuthenticationMapper,
                                                                      @Nullable State state) {
        try {
            Optional<Publisher<AuthenticationResponse>> authenticationResponse = validateOpenIdTokenResponse(
                    nonce, oauthClientConfiguration, openIdProviderMetadata, openIdTokenResponse, openIdAuthenticationMapper, state);
            if (authenticationResponse.isPresent()) {
                return Flux.from(authenticationResponse.get());
            }
            if (LOG.isTraceEnabled()) {
                LOG.trace("Token validation failed. Failing authentication");
            }
            return Flux.error(AuthenticationResponse.exception("JWT validation failed"));
        } catch (ParseException e) {
            return Flux.error(e);
        }
    }

    /**
     * Validates the token response and, on success, invokes the authentication mapper.
     *
     * <p>The mapper passed to {@link #handle} is used when non-null; otherwise the default mapper
     * injected in the constructor is used.
     *
     * @return the mapper's publisher, or {@link Optional#empty()} if the validator returned no JWT
     * @throws ParseException if the validated JWT's claims set cannot be parsed
     */
    private Optional<Publisher<AuthenticationResponse>> validateOpenIdTokenResponse(String nonce,
                                                                                   OauthClientConfiguration oauthClientConfiguration,
                                                                                   OpenIdProviderMetadata openIdProviderMetadata,
                                                                                   OpenIdTokenResponse openIdTokenResponse,
                                                                                   @Nullable OpenIdAuthenticationMapper openIdAuthenticationMapper,
                                                                                   @Nullable State state) throws ParseException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("Token endpoint returned a success response. Validating the JWT");
        }
        Optional<JWT> jwt = tokenResponseValidator.validate(oauthClientConfiguration, openIdProviderMetadata, openIdTokenResponse, nonce);
        if (!jwt.isPresent()) {
            return Optional.empty();
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace("Token validation succeeded. Creating a user details");
        }
        OpenIdAuthenticationMapper mapper = openIdAuthenticationMapper != null ? openIdAuthenticationMapper : defaultAuthenticationMapper;
        JWTOpenIdClaims claims = new JWTOpenIdClaims(jwt.get().getJWTClaimsSet());
        // The cast pins the element type to AuthenticationResponse whatever the mapper's declared publisher element type is.
        Flux<AuthenticationResponse> mapped = Flux.from(
                        mapper.createAuthenticationResponse(oauthClientConfiguration.getName(), openIdTokenResponse, claims, state))
                .map(AuthenticationResponse.class::cast);
        return Optional.of(mapped);
    }
}
