package org.springframework.security.oauth2.client.filter;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Map;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.security.oauth2.client.UserRedirectRequiredException;
import org.springframework.security.oauth2.client.context.OAuth2ClientContext;
import org.springframework.security.oauth2.client.context.OAuth2ClientContextHolder;
import org.springframework.security.oauth2.client.filter.cache.AccessTokenCache;
import org.springframework.security.oauth2.client.filter.cache.HttpSessionAccessTokenCache;
import org.springframework.security.oauth2.client.filter.state.HttpSessionStatePersistenceServices;
import org.springframework.security.oauth2.client.filter.state.StatePersistenceServices;
import org.springframework.security.oauth2.client.http.AccessTokenRequiredException;
import org.springframework.security.oauth2.client.resource.OAuth2AccessDeniedException;
import org.springframework.security.oauth2.client.resource.OAuth2ProtectedResourceDetails;
import org.springframework.security.oauth2.client.token.AccessTokenProvider;
import org.springframework.security.oauth2.client.token.AccessTokenProviderChain;
import org.springframework.security.oauth2.client.token.AccessTokenRequest;
import org.springframework.security.oauth2.client.token.grant.client.ClientCredentialsAccessTokenProvider;
import org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider;
import org.springframework.security.oauth2.common.DefaultThrowableAnalyzer;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidRequestException;
import org.springframework.security.web.DefaultRedirectStrategy;
import org.springframework.security.web.PortResolver;
import org.springframework.security.web.PortResolverImpl;
import org.springframework.security.web.RedirectStrategy;
import org.springframework.security.web.util.ThrowableAnalyzer;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;

/* loaded from: input_file:org/springframework/security/oauth2/client/filter/OAuth2ClientContextFilter.class */
/**
 * Servlet filter that installs an {@link OAuth2ClientContext} for the duration of each request and
 * obtains OAuth2 access tokens on demand.
 *
 * <p>The context is seeded from the configured {@link AccessTokenCache}. If the downstream chain
 * fails with an exception whose cause chain contains an {@link AccessTokenRequiredException}, the
 * filter asks the {@link AccessTokenProvider} for a token for the affected resource. It then does
 * one of the following:
 * <ul>
 * <li>re-runs the chain;</li>
 * <li>if the response is already committed, or {@code redirectOnError} is set, redirects back to the
 * original servlet path and query string;</li>
 * <li>if the provider requires user interaction, redirects the user to the provider's URI.</li>
 * </ul>
 * Any other exception is rethrown.
 *
 * <p>Whatever the outcome, the context holder is cleared and newly obtained tokens are handed back to
 * the token cache when the request completes.
 */
public class OAuth2ClientContextFilter implements Filter, InitializingBean {
    private AccessTokenProvider accessTokenProvider = new AccessTokenProviderChain(Arrays.asList(new AuthorizationCodeAccessTokenProvider(), new ClientCredentialsAccessTokenProvider()));
    private AccessTokenCache tokenCache = new HttpSessionAccessTokenCache();
    private StatePersistenceServices statePersistenceServices = new HttpSessionStatePersistenceServices();
    private PortResolver portResolver = new PortResolverImpl();
    private ThrowableAnalyzer throwableAnalyzer = new DefaultThrowableAnalyzer();
    private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
    // When true, a successful token acquisition is followed by a redirect to the original request
    // instead of re-invoking the filter chain in place.
    private boolean redirectOnError = false;

    /**
     * Verifies that every collaborator used during request processing has been supplied.
     *
     * @throws IllegalArgumentException if a required collaborator is {@code null}
     */
    public void afterPropertiesSet() throws Exception {
        Assert.notNull(this.accessTokenProvider, "An OAuth2 access token provider must be supplied.");
        Assert.notNull(this.tokenCache, "TokenCacheServices must be supplied.");
        Assert.notNull(this.redirectStrategy, "A redirect strategy must be supplied.");
        // The following are dereferenced unconditionally on the token-acquisition path of doFilter.
        Assert.notNull(this.statePersistenceServices, "StatePersistenceServices must be supplied.");
        Assert.notNull(this.throwableAnalyzer, "A throwable analyzer must be supplied.");
        Assert.notNull(this.portResolver, "A port resolver must be supplied.");
    }

    /**
     * Runs the filter chain inside an OAuth2 client context.
     *
     * <p>If the chain fails because a resource needs an access token, the filter acquires one and
     * then retries the chain or redirects. It keeps going until the context holds a token for the
     * resource. If the retry fails because another resource needs a token, the loop moves on to that
     * resource.
     *
     * @throws IOException if the chain or a redirect fails with an I/O error
     * @throws ServletException if the chain fails for a reason unrelated to a missing access token
     */
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;
        OAuth2ClientContext context = new OAuth2ClientContext(this.tokenCache.loadRememberedTokens(request, response));
        OAuth2ClientContextHolder.setContext(context);
        try {
            try {
                filterChain.doFilter(servletRequest, servletResponse);
            } catch (Exception ex) {
                // Rethrows ex unless it was caused by a missing access token.
                OAuth2ProtectedResourceDetails resource = checkForResourceThatNeedsAuthorization(ex);
                // Discard whatever token the failing resource currently has.
                context.removeAccessToken(resource);
                AccessTokenRequest accessTokenRequest = createAccessTokenRequest(request, response);
                while (!context.containsResource(resource)) {
                    OAuth2AccessToken existingToken = context.getAccessToken(resource);
                    if (existingToken != null) {
                        accessTokenRequest.setExistingToken(existingToken);
                    }
                    OAuth2AccessToken accessToken;
                    try {
                        accessToken = this.accessTokenProvider.obtainAccessToken(resource, accessTokenRequest);
                    } catch (UserRedirectRequiredException redirect) {
                        // The provider needs user interaction; the finally block still runs.
                        redirectUser(resource, redirect, request, response);
                        return;
                    }
                    if (accessToken == null) {
                        throw new IllegalStateException("Access token manager returned a null access token, which is illegal according to the contract.");
                    }
                    context.addAccessToken(resource, accessToken);
                    try {
                        if (response.isCommitted() || this.redirectOnError) {
                            // The chain cannot be re-run on a committed response, so send the
                            // user back to the original request instead.
                            String target = request.getServletPath();
                            if (request.getQueryString() != null) {
                                target = target + "?" + request.getQueryString();
                            }
                            this.redirectStrategy.sendRedirect(request, response, target);
                        } else {
                            filterChain.doFilter(request, response);
                        }
                    } catch (Exception retryFailure) {
                        // Another resource (or the same one again) needs a token; loop again.
                        resource = checkForResourceThatNeedsAuthorization(retryFailure);
                        context.removeAccessToken(resource);
                    }
                }
            }
        } finally {
            OAuth2ClientContextHolder.clearContext();
            this.tokenCache.rememberTokens(context.getNewAccessTokens(), request, response);
        }
    }

    /**
     * Builds the token request from the current request.
     *
     * <p>The token request contains the current request's parameters and its URI (as computed by
     * {@link #calculateCurrentUri}). It also contains any state preserved under the request's
     * {@code state} parameter.
     *
     * @throws InvalidRequestException if a {@code state} parameter is present but no preserved state
     *     exists for it (possible CSRF)
     */
    private AccessTokenRequest createAccessTokenRequest(HttpServletRequest request, HttpServletResponse response) throws IOException {
        // getParameterMap() is raw-typed in older servlet APIs.
        @SuppressWarnings("unchecked")
        Map<String, String[]> parameters = (Map<String, String[]>) request.getParameterMap();
        AccessTokenRequest accessTokenRequest = new AccessTokenRequest(parameters);
        accessTokenRequest.setCurrentUri(calculateCurrentUri(request));
        String stateKey = request.getParameter("state");
        if (stateKey != null) {
            Object preservedState = this.statePersistenceServices.loadPreservedState(stateKey, request, response);
            if (preservedState == null) {
                throw new InvalidRequestException("Possible CSRF detected - state parameter was present but no state could be found");
            }
            accessTokenRequest.setPreservedState(preservedState);
        }
        return accessTokenRequest;
    }

    /**
     * Redirects the user to the URI carried by {@code userRedirectRequiredException}.
     *
     * <p>The exception's request parameters are appended to that URI as a URL-encoded query string.
     * When the exception carries a state key, the key is also appended as the {@code state}
     * parameter. The associated state is preserved under that key; if the exception has no state,
     * the placeholder {@code "state"} is stored instead.
     *
     * @throws IOException if the redirect cannot be sent
     */
    protected void redirectUser(OAuth2ProtectedResourceDetails oAuth2ProtectedResourceDetails, UserRedirectRequiredException userRedirectRequiredException, HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) throws IOException {
        String redirectUri = userRedirectRequiredException.getRedirectUri();
        StringBuilder url = new StringBuilder(redirectUri);
        // Continue an existing query string if the redirect URI already has one.
        char separator = redirectUri.indexOf('?') < 0 ? '?' : '&';
        for (Map.Entry<String, String> param : userRedirectRequiredException.getRequestParams().entrySet()) {
            url.append(separator).append(encode(param.getKey())).append('=').append(encode(param.getValue()));
            separator = '&';
        }
        String stateKey = userRedirectRequiredException.getStateKey();
        if (stateKey != null) {
            url.append(separator).append("state").append('=').append(encode(stateKey));
            Object stateToPreserve = userRedirectRequiredException.getStateToPreserve();
            if (stateToPreserve == null) {
                stateToPreserve = "state";
            }
            this.statePersistenceServices.preserveState(stateKey, stateToPreserve, httpServletRequest, httpServletResponse);
        }
        this.redirectStrategy.sendRedirect(httpServletRequest, httpServletResponse, url.toString());
    }

    /**
     * Finds the resource that needs an access token, based on the exception's cause chain.
     *
     * @param exc the exception thrown by the filter chain
     * @return the resource named by the first {@link AccessTokenRequiredException} in the cause chain
     * @throws OAuth2AccessDeniedException if that exception names no resource
     * @throws ServletException if {@code exc} is a {@code ServletException} unrelated to a missing token
     * @throws IOException if {@code exc} is an {@code IOException} unrelated to a missing token
     */
    protected OAuth2ProtectedResourceDetails checkForResourceThatNeedsAuthorization(Exception exc) throws ServletException, IOException {
        Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(exc);
        AccessTokenRequiredException tokenRequired = (AccessTokenRequiredException) this.throwableAnalyzer.getFirstThrowableOfType(AccessTokenRequiredException.class, causeChain);
        if (tokenRequired == null) {
            // Not an OAuth2 problem: propagate the original exception unchanged where possible.
            if (exc instanceof ServletException) {
                throw (ServletException) exc;
            }
            if (exc instanceof IOException) {
                throw (IOException) exc;
            }
            if (exc instanceof RuntimeException) {
                throw (RuntimeException) exc;
            }
            throw new RuntimeException(exc);
        }
        OAuth2ProtectedResourceDetails resource = tokenRequired.getResource();
        if (resource == null) {
            throw new OAuth2AccessDeniedException(tokenRequired.getMessage());
        }
        return resource;
    }

    /**
     * Reconstructs the full URL of the current request, dropping the {@code code} parameter.
     *
     * <p>Parameter names and values are URL-encoded and joined with single {@code '&'} separators.
     * A parameter with no values is emitted as its name alone.
     *
     * @return the absolute request URL, with a query string only if at least one parameter remains
     */
    protected String calculateCurrentUri(HttpServletRequest httpServletRequest) throws UnsupportedEncodingException {
        StringBuilder query = new StringBuilder();
        Enumeration<?> parameterNames = httpServletRequest.getParameterNames();
        while (parameterNames.hasMoreElements()) {
            String name = (String) parameterNames.nextElement();
            if ("code".equals(name)) {
                continue;
            }
            String encodedName = encode(name);
            String[] values = httpServletRequest.getParameterValues(name);
            if (values == null || values.length == 0) {
                appendQueryPart(query, encodedName);
            } else {
                for (String value : values) {
                    appendQueryPart(query, encodedName + "=" + encode(value));
                }
            }
        }
        return UrlUtils.buildFullRequestUrl(httpServletRequest.getScheme(), httpServletRequest.getServerName(), this.portResolver.getServerPort(httpServletRequest), httpServletRequest.getRequestURI(), query.length() > 0 ? query.toString() : null);
    }

    /** Appends {@code part} to {@code query}, preceded by '&' unless it is the first part. */
    private static void appendQueryPart(StringBuilder query, String part) {
        if (query.length() > 0) {
            query.append('&');
        }
        query.append(part);
    }

    /** URL-encodes {@code value} as UTF-8; UTF-8 is always supported, so failure is a JVM defect. */
    private static String encode(String value) {
        try {
            return URLEncoder.encode(value, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException(e);
        }
    }

    /** No initialization needed; the filter is configured through its setters. */
    public void init(FilterConfig filterConfig) throws ServletException {
    }

    /** No resources to release. */
    public void destroy() {
    }

    /** Sets the provider used to obtain access tokens for protected resources. */
    public void setAccessTokenProvider(AccessTokenProvider accessTokenProvider) {
        this.accessTokenProvider = accessTokenProvider;
    }

    /** Sets the cache from which tokens are loaded and into which new tokens are remembered. */
    public void setClientTokenCache(AccessTokenCache accessTokenCache) {
        this.tokenCache = accessTokenCache;
    }

    /** Sets the services that preserve and restore state across the user redirect. */
    public void setStatePersistenceServices(StatePersistenceServices statePersistenceServices) {
        this.statePersistenceServices = statePersistenceServices;
    }

    /** Sets the analyzer used to search an exception's cause chain for AccessTokenRequiredException. */
    public void setThrowableAnalyzer(ThrowableAnalyzer throwableAnalyzer) {
        this.throwableAnalyzer = throwableAnalyzer;
    }

    /** Sets the resolver used to determine the server port when computing the current URI. */
    public void setPortResolver(PortResolver portResolver) {
        this.portResolver = portResolver;
    }

    /** Sets the strategy used to send all redirects issued by this filter. */
    public void setRedirectStrategy(RedirectStrategy redirectStrategy) {
        this.redirectStrategy = redirectStrategy;
    }

    /** If true, redirect to the original request after obtaining a token instead of re-running the chain. */
    public void setRedirectOnError(boolean z) {
        this.redirectOnError = z;
    }
}
