package org.springframework.security.oauth2.client.web.reactive.function.client;

import java.net.URI;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.reactivestreams.Subscription;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.http.HttpMethod;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.CoreSubscriber;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.core.scheduler.Schedulers;
import reactor.util.context.Context;

/* loaded from: input_file:BOOT-INF/lib/spring-security-oauth2-client-5.1.5.RELEASE.jar:org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.class */
/**
 * {@link ExchangeFilterFunction} that adds an OAuth2 bearer token to outgoing {@link WebClient}
 * requests in a Servlet environment.
 *
 * <p>The access token comes from an {@link OAuth2AuthorizedClient} stored in the request attributes.
 * It is either supplied explicitly via {@link #oauth2AuthorizedClient(OAuth2AuthorizedClient)}, or
 * resolved from the {@link OAuth2AuthorizedClientRepository} using a client registration id. When a
 * refresh token is present and the access token is about to expire, the token is refreshed against
 * the provider's token endpoint before the request is sent.
 */
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction, InitializingBean, DisposableBean {
    private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
    private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
    private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
    private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
    private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
    /** Key under which the global Reactor {@code onLastOperator} hook is registered and reset. */
    private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
    private ClientRegistrationRepository clientRegistrationRepository;
    private OAuth2AuthorizedClientRepository authorizedClientRepository;
    private boolean defaultOAuth2AuthorizedClient;
    private String defaultClientRegistrationId;
    private Clock clock = Clock.systemUTC();
    private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
    private OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient = new DefaultClientCredentialsTokenResponseClient();

    /**
     * Minimal {@link Authentication} that only exposes a principal name.
     *
     * <p>It is used when a refreshed client must be saved but the outgoing request carries no
     * {@link Authentication} attribute. Every accessor other than {@link #getName()} throws
     * {@link UnsupportedOperationException}.
     */
    private static final class PrincipalNameAuthentication implements Authentication {
        private final String username;

        private PrincipalNameAuthentication(String username) {
            this.username = username;
        }

        @Override
        public Collection<? extends GrantedAuthority> getAuthorities() {
            throw unsupported();
        }

        @Override
        public Object getCredentials() {
            throw unsupported();
        }

        @Override
        public Object getDetails() {
            throw unsupported();
        }

        @Override
        public Object getPrincipal() {
            throw unsupported();
        }

        @Override
        public boolean isAuthenticated() {
            throw unsupported();
        }

        @Override
        public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
            throw unsupported();
        }

        @Override
        public String getName() {
            return this.username;
        }

        private UnsupportedOperationException unsupported() {
            return new UnsupportedOperationException("Not Supported");
        }
    }

    /**
     * Subscriber installed by the global {@code onLastOperator} hook. It captures the servlet
     * request, the servlet response and the {@link Authentication} of the subscribing thread, and
     * exposes them through the Reactor {@link Context}. That lets
     * {@code mergeRequestAttributesFromContext} see them later, even on another thread.
     *
     * <p>The captured values are added on top of the downstream context instead of replacing it.
     * Null values are skipped because Reactor's {@link Context} rejects null values. Keys the
     * downstream context already defines are left untouched.
     */
    public static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
        private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
        private final CoreSubscriber<T> delegate;
        private final Context context;

        private RequestContextSubscriber(CoreSubscriber<T> delegate, HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
            this.delegate = delegate;
            this.context = createContext(delegate.currentContext(), request, response, authentication);
        }

        /**
         * Builds the context to expose.
         *
         * <p>If an outer subscriber already defaulted the context, the parent is returned
         * unchanged, so the outermost (request-thread) values win. Otherwise the parent is extended
         * with the marker and any non-null captured values.
         */
        private static Context createContext(Context parent, HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
            if (parent.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
                return parent;
            }
            Context result = parent.put(CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE);
            result = putIfAbsentAndNonNull(result, HTTP_SERVLET_REQUEST_ATTR_NAME, request);
            result = putIfAbsentAndNonNull(result, HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
            return putIfAbsentAndNonNull(result, AUTHENTICATION_ATTR_NAME, authentication);
        }

        /** Returns {@code context} extended with {@code key -> value} unless the value is null or the key already exists. */
        private static Context putIfAbsentAndNonNull(Context context, String key, Object value) {
            return (value == null || context.hasKey(key)) ? context : context.put(key, value);
        }

        @Override
        public Context currentContext() {
            return this.context;
        }

        @Override
        public void onSubscribe(Subscription subscription) {
            this.delegate.onSubscribe(subscription);
        }

        @Override
        public void onNext(T t) {
            this.delegate.onNext(t);
        }

        @Override
        public void onError(Throwable throwable) {
            this.delegate.onError(throwable);
        }

        @Override
        public void onComplete() {
            this.delegate.onComplete();
        }
    }

    /** Creates a filter without repositories; authorized clients must then be supplied explicitly as request attributes. */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
    }

    /**
     * Creates a filter that can resolve, create and save authorized clients.
     *
     * @param clientRegistrationRepository used to look up registrations when no saved client exists
     * @param authorizedClientRepository used to load and save {@link OAuth2AuthorizedClient}s
     */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction(ClientRegistrationRepository clientRegistrationRepository, OAuth2AuthorizedClientRepository authorizedClientRepository) {
        this.clientRegistrationRepository = clientRegistrationRepository;
        this.authorizedClientRepository = authorizedClientRepository;
    }

    /**
     * Registers a global Reactor {@code onLastOperator} hook. The hook captures the servlet
     * request/response and the current {@link Authentication} at subscription time.
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((scannable, subscriber) -> createRequestContextSubscriberIfNecessary(subscriber)));
    }

    /** Removes the hook registered by {@link #afterPropertiesSet()}. */
    @Override
    public void destroy() throws Exception {
        Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
    }

    /**
     * Sets the client used to obtain tokens for {@code client_credentials} registrations.
     *
     * @param clientCredentialsTokenResponseClient the client to use; must not be null
     */
    public void setClientCredentialsTokenResponseClient(OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
        Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
        this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
    }

    /**
     * Enables a fallback for when no registration id is set on the request and no default is
     * configured. With the fallback on, the registration id of the current
     * {@link OAuth2AuthenticationToken} (if any) is used.
     */
    public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
        this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
    }

    /** Sets the registration id to use when the request does not specify one. */
    public void setDefaultClientRegistrationId(String clientRegistrationId) {
        this.defaultClientRegistrationId = clientRegistrationId;
    }

    /**
     * Returns a builder customizer that installs {@link #defaultRequest()} and registers this filter.
     *
     * @return a {@link WebClient.Builder} consumer
     */
    public Consumer<WebClient.Builder> oauth2Configuration() {
        return builder -> builder.defaultRequest(defaultRequest()).filter(this);
    }

    /**
     * Returns a default-request customizer that fills in missing request attributes: servlet
     * request/response, {@link Authentication} and {@link OAuth2AuthorizedClient}.
     *
     * @return a request-spec consumer
     */
    public Consumer<WebClient.RequestHeadersSpec<?>> defaultRequest() {
        return spec -> spec.attributes(attrs -> {
            populateDefaultRequestResponse(attrs);
            populateDefaultAuthentication(attrs);
            populateDefaultOAuth2AuthorizedClient(attrs);
        });
    }

    /**
     * Sets the {@link OAuth2AuthorizedClient} request attribute.
     *
     * @param authorizedClient the client to use, or {@code null} to remove the attribute
     * @return an attributes consumer
     */
    public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient authorizedClient) {
        return attrs -> {
            if (authorizedClient == null) {
                attrs.remove(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
            } else {
                attrs.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
            }
        };
    }

    /** Sets the client registration id request attribute. */
    public static Consumer<Map<String, Object>> clientRegistrationId(String clientRegistrationId) {
        return attrs -> attrs.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
    }

    /** Sets the {@link Authentication} request attribute. */
    public static Consumer<Map<String, Object>> authentication(Authentication authentication) {
        return attrs -> attrs.put(AUTHENTICATION_ATTR_NAME, authentication);
    }

    /** Sets the {@link HttpServletRequest} request attribute. */
    public static Consumer<Map<String, Object>> httpServletRequest(HttpServletRequest request) {
        return attrs -> attrs.put(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
    }

    /** Sets the {@link HttpServletResponse} request attribute. */
    public static Consumer<Map<String, Object>> httpServletResponse(HttpServletResponse response) {
        return attrs -> attrs.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
    }

    /**
     * Sets how long before the access token's expiry a refresh is attempted.
     *
     * @param accessTokenExpiresSkew the skew; must not be null
     */
    public void setAccessTokenExpiresSkew(Duration accessTokenExpiresSkew) {
        Assert.notNull(accessTokenExpiresSkew, "accessTokenExpiresSkew cannot be null");
        this.accessTokenExpiresSkew = accessTokenExpiresSkew;
    }

    /**
     * Adds a bearer token when an {@link OAuth2AuthorizedClient} is available.
     *
     * <p>The client is taken from the request attributes, or, failing that, resolved after
     * merging attributes from the Reactor context. It is refreshed first if needed. Without an
     * authorized client the request is sent unchanged.
     */
    @Override
    public Mono<ClientResponse> filter(ClientRequest clientRequest, ExchangeFunction exchangeFunction) {
        return Mono.just(clientRequest)
                .filter(request -> request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
                .switchIfEmpty(mergeRequestAttributesFromContext(clientRequest))
                .filter(request -> request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
                .flatMap(request -> authorizedClient(request, exchangeFunction, getOAuth2AuthorizedClient(request.attributes())))
                .map(authorizedClient -> bearer(clientRequest, authorizedClient))
                .flatMap(exchangeFunction::exchange)
                // Deferred so the fallback exchange is only assembled when actually needed.
                .switchIfEmpty(Mono.defer(() -> exchangeFunction.exchange(clientRequest)));
    }

    /** Copies servlet/authentication values from the subscriber {@link Context} into the request attributes. */
    private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest clientRequest) {
        return Mono.just(ClientRequest.from(clientRequest))
                .flatMap(builder -> Mono.subscriberContext()
                        .map(context -> builder.attributes(attrs -> populateRequestAttributes(attrs, context))))
                .map(ClientRequest.Builder::build);
    }

    private void populateRequestAttributes(Map<String, Object> attrs, Context context) {
        if (context.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
            attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, context.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
        }
        if (context.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
            attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, context.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
        }
        if (context.hasKey(AUTHENTICATION_ATTR_NAME)) {
            attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, context.get(AUTHENTICATION_ATTR_NAME));
        }
        populateDefaultOAuth2AuthorizedClient(attrs);
    }

    private void populateDefaultRequestResponse(Map<String, Object> attrs) {
        if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
            return;
        }
        ServletRequestAttributes requestAttributes = currentServletRequestAttributes();
        HttpServletRequest request = requestAttributes != null ? requestAttributes.getRequest() : null;
        HttpServletResponse response = requestAttributes != null ? requestAttributes.getResponse() : null;
        attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
        attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
    }

    private void populateDefaultAuthentication(Map<String, Object> attrs) {
        if (!attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
            attrs.put(AUTHENTICATION_ATTR_NAME, SecurityContextHolder.getContext().getAuthentication());
        }
    }

    /**
     * Resolves a registration id and fills in the authorized-client attribute.
     *
     * <p>The id is taken, in order, from the request attribute, the configured default, or the
     * current {@link OAuth2AuthenticationToken} (only if enabled). The client is then loaded from
     * the repository, or obtained via {@code getAuthorizedClient(String, Map)} when none is saved.
     */
    private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
        if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
            return;
        }
        Authentication authentication = getAuthentication(attrs);
        String clientRegistrationId = getClientRegistrationId(attrs);
        if (clientRegistrationId == null) {
            clientRegistrationId = this.defaultClientRegistrationId;
        }
        if (clientRegistrationId == null && this.defaultOAuth2AuthorizedClient && authentication instanceof OAuth2AuthenticationToken) {
            clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
        }
        if (clientRegistrationId != null) {
            OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, getRequest(attrs));
            if (authorizedClient == null) {
                authorizedClient = getAuthorizedClient(clientRegistrationId, attrs);
            }
            oauth2AuthorizedClient(authorizedClient).accept(attrs);
        }
    }

    /**
     * Obtains a new authorized client for a registration with no saved client.
     *
     * <p>Only {@code client_credentials} registrations can be obtained this way.
     *
     * @throws IllegalArgumentException if the registration id is unknown
     * @throws ClientAuthorizationRequiredException for any other grant type
     */
    private OAuth2AuthorizedClient getAuthorizedClient(String clientRegistrationId, Map<String, Object> attrs) {
        ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
        if (clientRegistration == null) {
            throw new IllegalArgumentException("Could not find ClientRegistration with id " + clientRegistrationId);
        }
        if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
            return getAuthorizedClient(clientRegistration, attrs);
        }
        throw new ClientAuthorizationRequiredException(clientRegistrationId);
    }

    /** Performs the client-credentials grant and saves the resulting authorized client. */
    private OAuth2AuthorizedClient getAuthorizedClient(ClientRegistration clientRegistration, Map<String, Object> attrs) {
        HttpServletRequest request = getRequest(attrs);
        HttpServletResponse response = getResponse(attrs);
        OAuth2AccessTokenResponse tokenResponse = this.clientCredentialsTokenResponseClient.getTokenResponse(new OAuth2ClientCredentialsGrantRequest(clientRegistration));
        Authentication authentication = getAuthentication(attrs);
        // Without an Authentication the client is recorded under a fixed principal name.
        String principalName = authentication != null ? authentication.getName() : "anonymousUser";
        OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, principalName, tokenResponse.getAccessToken());
        this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, request, response);
        return authorizedClient;
    }

    private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest clientRequest, ExchangeFunction exchangeFunction, OAuth2AuthorizedClient authorizedClient) {
        return shouldRefresh(authorizedClient)
                ? refreshAuthorizedClient(clientRequest, exchangeFunction, authorizedClient)
                : Mono.just(authorizedClient);
    }

    /**
     * Exchanges the refresh token at the registration's token URI and saves the refreshed client.
     *
     * <p>The refresh request goes through the same {@link ExchangeFunction} as the outgoing request.
     */
    private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ClientRequest clientRequest, ExchangeFunction exchangeFunction, OAuth2AuthorizedClient authorizedClient) {
        ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
        ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(clientRegistration.getProviderDetails().getTokenUri()))
                .header("Accept", "application/json")
                .headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
                .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
                .build();
        return exchangeFunction.exchange(refreshRequest)
                .flatMap(response -> response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse()))
                .map(tokenResponse -> {
                    // Keep the current refresh token when the response does not include a new one.
                    OAuth2RefreshToken refreshToken = tokenResponse.getRefreshToken() != null
                            ? tokenResponse.getRefreshToken()
                            : authorizedClient.getRefreshToken();
                    return new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), tokenResponse.getAccessToken(), refreshToken);
                })
                .map(refreshedClient -> {
                    Authentication principal = clientRequest.attribute(AUTHENTICATION_ATTR_NAME)
                            .map(Authentication.class::cast)
                            .orElseGet(() -> new PrincipalNameAuthentication(authorizedClient.getPrincipalName()));
                    Map<String, Object> attrs = clientRequest.attributes();
                    this.authorizedClientRepository.saveAuthorizedClient(refreshedClient, principal, getRequest(attrs), getResponse(attrs));
                    return refreshedClient;
                })
                .publishOn(Schedulers.elastic());
    }

    /**
     * Decides whether the access token should be refreshed.
     *
     * <p>Returns true only when a repository is configured, a refresh token exists, the access
     * token has an expiry, and now is past {@code expiresAt - accessTokenExpiresSkew}.
     */
    private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
        if (this.authorizedClientRepository == null || authorizedClient.getRefreshToken() == null) {
            return false;
        }
        Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
        if (expiresAt == null) {
            // No expiry to compare against; previously this dereferenced null and threw NPE.
            return false;
        }
        return this.clock.instant().isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
    }

    private ClientRequest bearer(ClientRequest clientRequest, OAuth2AuthorizedClient authorizedClient) {
        return ClientRequest.from(clientRequest)
                .headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
                .build();
    }

    /**
     * Wraps the subscriber so it exposes the current thread's servlet/security state via the
     * Reactor context. If there is nothing to expose, the original subscriber is returned.
     */
    private <T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
        ServletRequestAttributes requestAttributes = currentServletRequestAttributes();
        HttpServletRequest request = requestAttributes != null ? requestAttributes.getRequest() : null;
        HttpServletResponse response = requestAttributes != null ? requestAttributes.getResponse() : null;
        Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
        if (request == null && response == null && authentication == null) {
            return delegate;
        }
        return new RequestContextSubscriber<>(delegate, request, response, authentication);
    }

    /** Returns the thread-bound request attributes if they are servlet-based, otherwise {@code null}. */
    private static ServletRequestAttributes currentServletRequestAttributes() {
        Object attributes = RequestContextHolder.getRequestAttributes();
        return attributes instanceof ServletRequestAttributes ? (ServletRequestAttributes) attributes : null;
    }

    private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
        return BodyInserters.fromFormData(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue())
                .with(OAuth2ParameterNames.REFRESH_TOKEN, refreshToken);
    }

    static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
        return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
    }

    static String getClientRegistrationId(Map<String, Object> attrs) {
        return (String) attrs.get(CLIENT_REGISTRATION_ID_ATTR_NAME);
    }

    static Authentication getAuthentication(Map<String, Object> attrs) {
        return (Authentication) attrs.get(AUTHENTICATION_ATTR_NAME);
    }

    static HttpServletRequest getRequest(Map<String, Object> attrs) {
        return (HttpServletRequest) attrs.get(HTTP_SERVLET_REQUEST_ATTR_NAME);
    }

    static HttpServletResponse getResponse(Map<String, Object> attrs) {
        return (HttpServletResponse) attrs.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
    }
}
