/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nifi.web.security.oidc.web.authentication;

import java.net.URI;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.nifi.admin.service.IdpUserGroupService;
import org.apache.nifi.authorization.util.IdentityMapping;
import org.apache.nifi.authorization.util.IdentityMappingUtil;
import org.apache.nifi.idp.IdpType;
import org.apache.nifi.web.security.cookie.ApplicationCookieName;
import org.apache.nifi.web.security.cookie.ApplicationCookieService;
import org.apache.nifi.web.security.cookie.StandardApplicationCookieService;
import org.apache.nifi.web.security.jwt.provider.BearerTokenProvider;
import org.apache.nifi.web.security.oidc.OidcConfigurationException;
import org.apache.nifi.web.security.token.LoginAuthenticationToken;
import org.apache.nifi.web.util.RequestUriBuilder;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;

/**
 * Authentication success handler for OpenID Connect logins.
 *
 * <p>On success the handler resolves the user identity and groups from the OIDC user claims,
 * replaces the stored user groups through the {@link IdpUserGroupService}, generates an application
 * bearer token, writes it to the response as a session cookie, and returns the user interface URL
 * as the redirect target.
 */
public class OidcAuthenticationSuccessHandler extends SimpleUrlAuthenticationSuccessHandler {
    private static final String UI_PATH = "/nifi/";
    private static final String ROOT_PATH = "/";
    private final ApplicationCookieService applicationCookieService = new StandardApplicationCookieService();
    private final BearerTokenProvider bearerTokenProvider;
    private final IdpUserGroupService idpUserGroupService;
    private final List<IdentityMapping> userIdentityMappings;
    private final List<IdentityMapping> groupIdentityMappings;
    private final List<String> userClaimNames;
    private final String groupsClaimName;

    /**
     * Creates the handler with its required collaborators and configuration.
     *
     * @param bearerTokenProvider provider used to generate the application bearer token; required
     * @param idpUserGroupService service used to replace the groups stored for the user; required
     * @param userIdentityMappings mappings applied to the user identity found in the claims; required
     * @param groupIdentityMappings mappings applied to each group name found in the claims; required
     * @param userClaimNames claim names searched in order for the user identity; required
     * @param groupsClaimName claim name holding group names; {@code null} or empty disables group lookup
     * @throws NullPointerException when any required argument is {@code null}
     */
    public OidcAuthenticationSuccessHandler(BearerTokenProvider bearerTokenProvider, IdpUserGroupService idpUserGroupService, List<IdentityMapping> userIdentityMappings, List<IdentityMapping> groupIdentityMappings, List<String> userClaimNames, String groupsClaimName) {
        this.bearerTokenProvider = Objects.requireNonNull(bearerTokenProvider, "Bearer Token Provider required");
        this.idpUserGroupService = Objects.requireNonNull(idpUserGroupService, "User Group Service required");
        this.userIdentityMappings = Objects.requireNonNull(userIdentityMappings, "User Identity Mappings required");
        this.groupIdentityMappings = Objects.requireNonNull(groupIdentityMappings, "Group Identity Mappings required");
        this.userClaimNames = Objects.requireNonNull(userClaimNames, "User Claim Names required");
        this.groupsClaimName = groupsClaimName;
    }

    /**
     * Processes the successful authentication and returns the redirect target.
     *
     * @param request HTTP request used to build the cookie resource URI and the target URI
     * @param response HTTP response that receives the bearer token session cookie
     * @param authentication successful authentication, expected to be an {@link OAuth2AuthenticationToken}
     * @return string form of the request-derived URI with the path set to {@value #UI_PATH}
     * @throws IllegalArgumentException when the authentication, principal, or credentials are not of the expected type
     * @throws OidcConfigurationException when none of the configured user claims is present
     */
    @Override
    public String determineTargetUrl(HttpServletRequest request, HttpServletResponse response, Authentication authentication) {
        // The session cookie is registered against the root path rather than the user interface path
        URI resourceUri = RequestUriBuilder.fromHttpServletRequest(request).path(ROOT_PATH).build();
        this.processAuthentication(response, authentication, resourceUri);
        URI targetUri = RequestUriBuilder.fromHttpServletRequest(request).path(UI_PATH).build();
        return targetUri.toString();
    }

    /**
     * Replaces the stored user groups and adds the bearer token session cookie to the response.
     *
     * @param response HTTP response that receives the session cookie
     * @param authentication successful authentication to be processed
     * @param resourceUri resource URI passed to the cookie service when adding the cookie
     */
    private void processAuthentication(HttpServletResponse response, Authentication authentication, URI resourceUri) {
        OAuth2AuthenticationToken authenticationToken = this.getAuthenticationToken(authentication);
        OidcUser oidcUser = this.getOidcUser(authenticationToken);
        String identity = this.getIdentity(oidcUser);
        Set<String> groups = this.getGroups(oidcUser);
        this.idpUserGroupService.replaceUserGroups(identity, IdpType.OIDC, groups);
        OAuth2AccessToken accessToken = this.getAccessToken(authenticationToken);
        String bearerToken = this.getBearerToken(identity, oidcUser, accessToken);
        this.applicationCookieService.addSessionCookie(resourceUri, response, ApplicationCookieName.AUTHORIZATION_BEARER, bearerToken);
    }

    /**
     * Generates the application bearer token for the mapped identity.
     *
     * @param identity mapped user identity, passed as both leading arguments of the login token
     * @param oidcUser OIDC user supplying the issuer
     * @param accessToken access token supplying the issued and expiration instants
     * @return bearer token returned by the {@link BearerTokenProvider}
     */
    private String getBearerToken(String identity, OidcUser oidcUser, OAuth2AccessToken accessToken) {
        long sessionExpiration = this.getSessionExpiration(accessToken);
        String issuer = oidcUser.getIssuer().toString();
        LoginAuthenticationToken loginAuthenticationToken = new LoginAuthenticationToken(identity, identity, sessionExpiration, issuer);
        return this.bearerTokenProvider.getBearerToken(loginAuthenticationToken);
    }

    /**
     * Calculates the token lifetime as a duration, not an absolute timestamp.
     *
     * @param token token supplying the issued and expiration instants
     * @return milliseconds between the issued instant and the expiration instant
     * @throws IllegalArgumentException when either instant is missing from the token
     */
    private long getSessionExpiration(OAuth2Token token) {
        Instant tokenExpiration = token.getExpiresAt();
        if (tokenExpiration == null) {
            throw new IllegalArgumentException("Token expiration claim not found");
        }
        Instant tokenIssued = token.getIssuedAt();
        if (tokenIssued == null) {
            throw new IllegalArgumentException("Token issued claim not found");
        }
        return Duration.between(tokenIssued, tokenExpiration).toMillis();
    }

    /**
     * Narrows the authentication to an {@link OAuth2AuthenticationToken}.
     *
     * @param authentication authentication to be checked; may be {@code null}
     * @return the same object cast to {@link OAuth2AuthenticationToken}
     * @throws IllegalArgumentException when the authentication is {@code null} or of another type
     */
    private OAuth2AuthenticationToken getAuthenticationToken(Authentication authentication) {
        if (authentication instanceof OAuth2AuthenticationToken) {
            return (OAuth2AuthenticationToken) authentication;
        }
        String message = String.format("OAuth2AuthenticationToken not found [%s]", typeOf(authentication));
        throw new IllegalArgumentException(message);
    }

    /**
     * Reads the access token from the credentials of the authentication token.
     *
     * @param authenticationToken authentication token holding the credentials
     * @return credentials cast to {@link OAuth2AccessToken}
     * @throws IllegalArgumentException when the credentials are {@code null} or of another type
     */
    private OAuth2AccessToken getAccessToken(OAuth2AuthenticationToken authenticationToken) {
        Object credentials = authenticationToken.getCredentials();
        if (credentials instanceof OAuth2AccessToken) {
            return (OAuth2AccessToken) credentials;
        }
        String message = String.format("OAuth2AccessToken not found in credentials [%s]", typeOf(credentials));
        throw new IllegalArgumentException(message);
    }

    /**
     * Reads the OpenID Connect user from the principal of the authentication token.
     *
     * @param authenticationToken authentication token holding the principal
     * @return principal cast to {@link OidcUser}
     * @throws IllegalArgumentException when the principal is {@code null} or of another type
     */
    private OidcUser getOidcUser(OAuth2AuthenticationToken authenticationToken) {
        OAuth2User principalUser = authenticationToken.getPrincipal();
        if (principalUser instanceof OidcUser) {
            return (OidcUser) principalUser;
        }
        String message = String.format("OpenID Connect User not found [%s]", typeOf(principalUser));
        throw new IllegalArgumentException(message);
    }

    /**
     * Returns the runtime class of the value for use in error messages.
     *
     * <p>A failed {@code instanceof} check also covers {@code null}, so calling {@code getClass()}
     * directly would replace the intended exception with a {@link NullPointerException}.
     *
     * @param value value to describe; may be {@code null}
     * @return runtime class of the value or {@code null} when the value is {@code null}
     */
    private static Class<?> typeOf(Object value) {
        return value == null ? null : value.getClass();
    }

    /**
     * Finds the user identity in the configured claims and applies the user identity mappings.
     *
     * @param oidcUser OIDC user supplying the claims
     * @return first non-null claim value in configured order after identity mapping
     * @throws OidcConfigurationException when none of the configured claims has a value
     */
    private String getIdentity(OidcUser oidcUser) {
        String identity = this.userClaimNames.stream()
                .map(oidcUser::getClaimAsString)
                .filter(Objects::nonNull)
                .findFirst()
                .orElseThrow(() -> {
                    String message = String.format("User Identity not found in Token Claims %s", this.userClaimNames);
                    return new OidcConfigurationException(message);
                });
        return IdentityMappingUtil.mapIdentity(identity, this.userIdentityMappings);
    }

    /**
     * Reads group names from the configured groups claim and applies the group identity mappings.
     *
     * @param oidcUser OIDC user supplying the claims
     * @return mapped group names; empty when no groups claim is configured or the claim is absent
     */
    private Set<String> getGroups(OidcUser oidcUser) {
        if (this.groupsClaimName == null || this.groupsClaimName.isEmpty()) {
            return Collections.emptySet();
        }
        List<String> groupsFound = oidcUser.getClaimAsStringList(this.groupsClaimName);
        // An absent claim is handled as an empty list instead of a failure
        List<String> claimGroups = groupsFound == null ? Collections.<String>emptyList() : groupsFound;
        return claimGroups.stream()
                .map(group -> IdentityMappingUtil.mapIdentity(group, this.groupIdentityMappings))
                .collect(Collectors.toSet());
    }
}

