package org.openmetadata.service.security;

import com.nimbusds.jose.proc.BadJOSEException;
import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.proc.BadJWTException;
import com.nimbusds.oauth2.sdk.AuthorizationCode;
import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant;
import com.nimbusds.oauth2.sdk.AuthorizationGrant;
import com.nimbusds.oauth2.sdk.ErrorObject;
import com.nimbusds.oauth2.sdk.TokenErrorResponse;
import com.nimbusds.oauth2.sdk.TokenRequest;
import com.nimbusds.oauth2.sdk.TokenResponse;
import com.nimbusds.oauth2.sdk.auth.ClientAuthentication;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import com.nimbusds.oauth2.sdk.id.ClientID;
import com.nimbusds.oauth2.sdk.id.State;
import com.nimbusds.oauth2.sdk.pkce.CodeVerifier;
import com.nimbusds.oauth2.sdk.token.AccessToken;
import com.nimbusds.openid.connect.sdk.AuthenticationErrorResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponse;
import com.nimbusds.openid.connect.sdk.AuthenticationResponseParser;
import com.nimbusds.openid.connect.sdk.AuthenticationSuccessResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponse;
import com.nimbusds.openid.connect.sdk.OIDCTokenResponseParser;
import com.nimbusds.openid.connect.sdk.op.OIDCProviderMetadata;
import com.nimbusds.openid.connect.sdk.token.OIDCTokens;
import com.nimbusds.openid.connect.sdk.validators.BadJWTExceptions;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.text.ParseException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.pac4j.core.exception.TechnicalException;
import org.pac4j.core.util.CommonHelper;
import org.pac4j.oidc.client.OidcClient;
import org.pac4j.oidc.credentials.OidcCredentials;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@WebServlet({"/callback"})
/* loaded from: input_file:org/openmetadata/service/security/AuthCallbackServlet.class */
public class AuthCallbackServlet extends HttpServlet {
    private static final Logger LOG = LoggerFactory.getLogger(AuthCallbackServlet.class);
    private final OidcClient client;
    private final ClientAuthentication clientAuthentication;
    private final List<String> claimsOrder;
    private final String serverUrl;

    public AuthCallbackServlet(OidcClient oidcClient, String str, List<String> list) {
        CommonHelper.assertNotBlank("ServerUrl", str);
        this.client = oidcClient;
        this.claimsOrder = list;
        this.serverUrl = str;
        this.clientAuthentication = SecurityUtil.getClientAuthentication(this.client.getConfiguration());
    }

    protected void doGet(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse) {
        try {
            LOG.debug("Performing Auth Callback For User Session: {} ", httpServletRequest.getSession().getId());
            String callbackUrl = this.client.getCallbackUrl();
            AuthenticationErrorResponse parse = AuthenticationResponseParser.parse(new URI(callbackUrl), retrieveParameters(httpServletRequest));
            if (parse instanceof AuthenticationErrorResponse) {
                LOG.error("Bad authentication response, error={}", parse.getErrorObject());
                throw new TechnicalException("Bad authentication response");
            }
            LOG.debug("Authentication response successful");
            AuthenticationSuccessResponse authenticationSuccessResponse = (AuthenticationSuccessResponse) parse;
            OIDCProviderMetadata providerMetadata = this.client.getConfiguration().getProviderMetadata();
            if (providerMetadata.supportsAuthorizationResponseIssuerParam() && !providerMetadata.getIssuer().equals(authenticationSuccessResponse.getIssuer())) {
                throw new TechnicalException("Issuer mismatch, possible mix-up attack.");
            }
            validateStateIfRequired(httpServletRequest, httpServletResponse, authenticationSuccessResponse);
            OidcCredentials buildCredentials = buildCredentials(authenticationSuccessResponse);
            validateAndSendTokenRequest(httpServletRequest, buildCredentials, callbackUrl);
            if (buildCredentials.getRefreshToken() == null) {
                LOG.error("Refresh token is null for user session: {}", httpServletRequest.getSession().getId());
            }
            validateNonceIfRequired(httpServletRequest, buildCredentials.getIdToken().getJWTClaimsSet());
            httpServletRequest.getSession().setAttribute(AuthLoginServlet.OIDC_CREDENTIAL_PROFILE, buildCredentials);
            SecurityUtil.sendRedirectWithToken(httpServletResponse, buildCredentials, this.serverUrl, this.claimsOrder);
        } catch (Exception e) {
            SecurityUtil.getErrorMessage(httpServletResponse, e);
        }
    }

    private OidcCredentials buildCredentials(AuthenticationSuccessResponse authenticationSuccessResponse) {
        OidcCredentials oidcCredentials = new OidcCredentials();
        AuthorizationCode authorizationCode = authenticationSuccessResponse.getAuthorizationCode();
        if (authorizationCode != null) {
            oidcCredentials.setCode(authorizationCode);
        }
        JWT iDToken = authenticationSuccessResponse.getIDToken();
        if (iDToken != null) {
            oidcCredentials.setIdToken(iDToken);
        }
        AccessToken accessToken = authenticationSuccessResponse.getAccessToken();
        if (accessToken != null) {
            oidcCredentials.setAccessToken(accessToken);
        }
        return oidcCredentials;
    }

    private void validateNonceIfRequired(HttpServletRequest httpServletRequest, JWTClaimsSet jWTClaimsSet) throws BadJOSEException {
        if (this.client.getConfiguration().isUseNonce()) {
            String str = (String) httpServletRequest.getSession().getAttribute(this.client.getNonceSessionAttributeName());
            if (!CommonHelper.isNotBlank(str)) {
                throw new TechnicalException("Missing nonce parameter from Session.");
            }
            try {
                String stringClaim = jWTClaimsSet.getStringClaim("nonce");
                if (stringClaim == null) {
                    throw BadJWTExceptions.MISSING_NONCE_CLAIM_EXCEPTION;
                }
                if (!str.equals(stringClaim)) {
                    throw new BadJWTException("Unexpected JWT nonce (nonce) claim: " + stringClaim);
                }
            } catch (ParseException e) {
                throw new BadJWTException("Invalid JWT nonce (nonce) claim: " + e.getMessage());
            }
        }
    }

    private void validateStateIfRequired(HttpServletRequest httpServletRequest, HttpServletResponse httpServletResponse, AuthenticationSuccessResponse authenticationSuccessResponse) {
        if (this.client.getConfiguration().isWithState()) {
            State state = (State) httpServletRequest.getSession().getAttribute(this.client.getStateSessionAttributeName());
            if (state == null || CommonHelper.isBlank(state.getValue())) {
                SecurityUtil.getErrorMessage(httpServletResponse, new TechnicalException("Missing state parameter"));
                return;
            }
            State state2 = authenticationSuccessResponse.getState();
            if (state2 == null) {
                throw new TechnicalException("Missing state parameter");
            }
            LOG.debug("Request state: {}/response state: {}", state, state2);
            if (!state.equals(state2)) {
                throw new TechnicalException("State parameter is different from the one sent in authentication request.");
            }
        }
    }

    private void validateAndSendTokenRequest(HttpServletRequest httpServletRequest, OidcCredentials oidcCredentials, String str) throws IOException, com.nimbusds.oauth2.sdk.ParseException, URISyntaxException {
        if (oidcCredentials.getCode() != null) {
            LOG.debug("Initiating Token Request for User Session: {} ", httpServletRequest.getSession().getId());
            executeTokenRequest(createTokenRequest(new AuthorizationCodeGrant(oidcCredentials.getCode(), new URI(str), (CodeVerifier) httpServletRequest.getSession().getAttribute(this.client.getCodeVerifierSessionAttributeName()))), oidcCredentials);
        }
    }

    protected Map<String, List<String>> retrieveParameters(HttpServletRequest httpServletRequest) {
        Map parameterMap = httpServletRequest.getParameterMap();
        HashMap hashMap = new HashMap();
        for (Map.Entry entry : parameterMap.entrySet()) {
            hashMap.put((String) entry.getKey(), Arrays.asList((String[]) entry.getValue()));
        }
        return hashMap;
    }

    protected TokenRequest createTokenRequest(AuthorizationGrant authorizationGrant) {
        return this.client.getConfiguration().getClientAuthenticationMethod() != null ? new TokenRequest(this.client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), this.clientAuthentication, authorizationGrant) : new TokenRequest(this.client.getConfiguration().findProviderMetadata().getTokenEndpointURI(), new ClientID(this.client.getConfiguration().getClientId()), authorizationGrant);
    }

    private void executeTokenRequest(TokenRequest tokenRequest, OidcCredentials oidcCredentials) throws IOException, com.nimbusds.oauth2.sdk.ParseException {
        HTTPRequest hTTPRequest = tokenRequest.toHTTPRequest();
        this.client.getConfiguration().configureHttpRequest(hTTPRequest);
        HTTPResponse send = hTTPRequest.send();
        LOG.debug("Token response: status={}, content={}", Integer.valueOf(send.getStatusCode()), send.getContent());
        TokenErrorResponse parse = OIDCTokenResponseParser.parse(send);
        if (parse instanceof TokenErrorResponse) {
            ErrorObject errorObject = parse.getErrorObject();
            throw new TechnicalException("Bad token response, error=" + errorObject.getCode() + ", description=" + errorObject.getDescription());
        }
        LOG.debug("Token response successful");
        OIDCTokens oIDCTokens = ((OIDCTokenResponse) parse).getOIDCTokens();
        oidcCredentials.setAccessToken(oIDCTokens.getAccessToken());
        oidcCredentials.setRefreshToken(oIDCTokens.getRefreshToken());
        if (oIDCTokens.getIDToken() != null) {
            oidcCredentials.setIdToken(oIDCTokens.getIDToken());
        }
    }
}
