package org.apache.nifi.web.security.oidc.client.web;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Consumer;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.nifi.components.state.Scope;
import org.apache.nifi.components.state.StateManager;
import org.apache.nifi.components.state.StateMap;
import org.apache.nifi.web.security.oidc.client.web.converter.AuthorizedClientConverter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.user.OidcUser;

/* loaded from: input_file:org/apache/nifi/web/security/oidc/client/web/StandardOidcAuthorizedClientRepository.class */
/**
 * OAuth2 Authorized Client Repository that stores OpenID Connect Authorized Clients in component state.
 * <p>
 * Entries are kept in {@link Scope#LOCAL} state, keyed by {@link Authentication#getName()} and holding the
 * encoded form produced by the configured {@link AuthorizedClientConverter}. Because the key is the principal
 * name only, each principal has at most one stored Authorized Client regardless of the client registration.
 */
public class StandardOidcAuthorizedClientRepository implements OAuth2AuthorizedClientRepository, TrackedAuthorizedClientRepository {
    private static final Logger logger = LoggerFactory.getLogger(StandardOidcAuthorizedClientRepository.class);

    private static final Scope SCOPE = Scope.LOCAL;

    private final StateManager stateManager;

    private final AuthorizedClientConverter authorizedClientConverter;

    /**
     * Standard OIDC Authorized Client Repository constructor
     *
     * @param stateManager State Manager used to persist encoded Authorized Clients
     * @param authorizedClientConverter Converter for encoding and decoding Authorized Clients
     * @throws NullPointerException when either argument is null
     */
    public StandardOidcAuthorizedClientRepository(final StateManager stateManager, final AuthorizedClientConverter authorizedClientConverter) {
        this.stateManager = Objects.requireNonNull(stateManager, "State Manager required");
        this.authorizedClientConverter = Objects.requireNonNull(authorizedClientConverter, "Authorized Client Converter required");
    }

    /**
     * Load the Authorized Client stored for the principal.
     * <p>
     * The client registration identifier is not used, because entries are keyed by principal name only.
     * If a stored entry cannot be decoded (the converter returns null), the entry is removed and null is returned.
     *
     * @param clientRegistrationId Client Registration Identifier (ignored)
     * @param principal Authenticated principal whose name is used as the state key
     * @param request HTTP Servlet Request (ignored)
     * @return Decoded Authorized Client or null when not found or not decodable
     * @param <T> Authorized Client Type
     */
    @Override
    @SuppressWarnings("unchecked")
    public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(final String clientRegistrationId, final Authentication principal, final HttpServletRequest request) {
        final OidcAuthorizedClient authorizedClient;

        final String encoded = findEncoded(principal);
        final String principalId = getPrincipalId(principal);
        if (encoded == null) {
            logger.debug("Identity [{}] OIDC Authorized Client not found", principalId);
            authorizedClient = null;
        } else {
            authorizedClient = authorizedClientConverter.getDecoded(encoded);
            if (authorizedClient == null) {
                logger.warn("Identity [{}] Removing OIDC Authorized Client after decoding failed", principalId);
                removeAuthorizedClient(principal);
            }
        }

        // Unchecked: the repository only stores OidcAuthorizedClient instances
        return (T) authorizedClient;
    }

    /**
     * Encode the Authorized Client together with the principal's OIDC ID Token and store it under the principal name,
     * replacing any existing entry for that principal.
     *
     * @param authorizedClient Authorized Client to be saved
     * @param principal Authentication that must be an OAuth2AuthenticationToken with an OidcUser principal
     * @param request HTTP Servlet Request (ignored)
     * @param response HTTP Servlet Response (ignored)
     * @throws IllegalArgumentException when the principal does not carry an OIDC ID Token
     */
    @Override
    public void saveAuthorizedClient(final OAuth2AuthorizedClient authorizedClient, final Authentication principal, final HttpServletRequest request, final HttpServletResponse response) {
        final OidcAuthorizedClient oidcAuthorizedClient = getOidcAuthorizedClient(authorizedClient, principal);
        final String encoded = authorizedClientConverter.getEncoded(oidcAuthorizedClient);
        final String principalId = getPrincipalId(principal);
        updateState(principalId, state -> state.put(principalId, encoded));
    }

    /**
     * Remove the Authorized Client stored for the principal.
     *
     * @param clientRegistrationId Client Registration Identifier (ignored, entries are keyed by principal name)
     * @param principal Authenticated principal whose entry is removed
     * @param request HTTP Servlet Request (ignored)
     * @param response HTTP Servlet Response (ignored)
     */
    @Override
    public void removeAuthorizedClient(final String clientRegistrationId, final Authentication principal, final HttpServletRequest request, final HttpServletResponse response) {
        removeAuthorizedClient(principal);
    }

    /**
     * Delete stored Authorized Clients whose Access Token is expired or has no expiration.
     * <p>
     * Entries that cannot be decoded are also dropped from state, but they are not included in the returned list.
     *
     * @return Expired Authorized Clients that were deleted
     */
    @Override
    public synchronized List<OidcAuthorizedClient> deleteExpired() {
        final Map<String, String> currentState = getStateMap().toMap();
        final Map<String, String> retainedState = new LinkedHashMap<>();
        final List<OidcAuthorizedClient> deletedAuthorizedClients = new ArrayList<>();

        for (final Map.Entry<String, String> entry : currentState.entrySet()) {
            final String identity = entry.getKey();
            final String encoded = entry.getValue();
            try {
                final OidcAuthorizedClient authorizedClient = authorizedClientConverter.getDecoded(encoded);
                if (authorizedClient == null) {
                    // Null indicates decoding failure; the entry is not retained
                    logger.warn("Decoding OIDC Authorized Client [{}] failed", identity);
                } else if (isExpired(authorizedClient)) {
                    deletedAuthorizedClients.add(authorizedClient);
                } else {
                    retainedState.put(identity, encoded);
                }
            } catch (final Exception e) {
                logger.warn("Decoding OIDC Authorized Client [{}] failed", identity, e);
            }
        }

        // Retained entries are a subset of current entries, so a smaller size means something was dropped
        if (retainedState.size() < currentState.size()) {
            setStateMap(retainedState);
        }

        if (deletedAuthorizedClients.isEmpty()) {
            logger.debug("Expired Authorized Clients not found");
        } else {
            logger.debug("Deleted Expired Authorized Clients: State before contained [{}] and after [{}]", currentState.size(), retainedState.size());
        }

        return deletedAuthorizedClients;
    }

    /**
     * Determine whether the Access Token is expired, treating a missing expiration as expired.
     *
     * @param authorizedClient Authorized Client to evaluate
     * @return true when the Access Token expiration is null or in the past
     */
    private boolean isExpired(final OidcAuthorizedClient authorizedClient) {
        final Instant expiresAt = authorizedClient.getAccessToken().getExpiresAt();
        return expiresAt == null || Instant.now().isAfter(expiresAt);
    }

    /**
     * Remove the state entry keyed by the principal name.
     *
     * @param principal Authenticated principal
     */
    private void removeAuthorizedClient(final Authentication principal) {
        final String principalId = getPrincipalId(principal);
        updateState(principalId, state -> state.remove(principalId));
    }

    private OidcAuthorizedClient getOidcAuthorizedClient(final OAuth2AuthorizedClient authorizedClient, final Authentication authentication) {
        final OidcIdToken oidcIdToken = getOidcIdToken(authentication);
        return new OidcAuthorizedClient(
                authorizedClient.getClientRegistration(),
                authorizedClient.getPrincipalName(),
                authorizedClient.getAccessToken(),
                authorizedClient.getRefreshToken(),
                oidcIdToken
        );
    }

    /**
     * Get the OIDC ID Token from an OAuth2 Authentication Token whose principal is an OidcUser.
     *
     * @param authentication Authentication to inspect
     * @return OIDC ID Token
     * @throws IllegalArgumentException when the authentication or its principal is not of the required type
     */
    private OidcIdToken getOidcIdToken(final Authentication authentication) {
        if (!(authentication instanceof OAuth2AuthenticationToken)) {
            throw new IllegalArgumentException(String.format("OpenID Connect Authentication Token not found [%s]", getType(authentication)));
        }

        // getPrincipal() is declared as OAuth2User; only OidcUser principals carry an ID Token
        final Object principal = ((OAuth2AuthenticationToken) authentication).getPrincipal();
        if (principal instanceof OidcUser) {
            return ((OidcUser) principal).getIdToken();
        }
        throw new IllegalArgumentException(String.format("OpenID Connect User not found [%s]", getType(principal)));
    }

    /** Null-safe class lookup used for error messages. */
    private static Class<?> getType(final Object object) {
        return object == null ? null : object.getClass();
    }

    private String findEncoded(final Authentication authentication) {
        return getStateMap().get(getPrincipalId(authentication));
    }

    private String getPrincipalId(final Authentication authentication) {
        return authentication.getName();
    }

    /**
     * Apply a mutation to a copy of current state and persist it.
     * <p>
     * When current state is empty, the copy is written with setState. Otherwise it is written with replace against
     * the StateMap read here. Failures are logged and not propagated to callers.
     *
     * @param identity Principal identity used in log messages
     * @param stateConsumer Mutation applied to a mutable copy of current state
     */
    private synchronized void updateState(final String identity, final Consumer<Map<String, String>> stateConsumer) {
        try {
            final StateMap stateMap = getStateMap();
            final Map<String, String> currentState = stateMap.toMap();
            final Map<String, String> updatedState = new LinkedHashMap<>(currentState);
            stateConsumer.accept(updatedState);

            final boolean completed;
            if (currentState.isEmpty()) {
                stateManager.setState(updatedState, SCOPE);
                completed = true;
            } else {
                completed = stateManager.replace(stateMap, updatedState, SCOPE);
            }

            if (completed) {
                logger.info("Identity [{}] OIDC Authorized Client update completed", identity);
            } else {
                // A rejected replace means the change was not persisted
                logger.warn("Identity [{}] OIDC Authorized Client update failed", identity);
            }
        } catch (final Exception e) {
            logger.warn("Identity [{}] OIDC Authorized Client update processing failed", identity, e);
        }
    }

    private void setStateMap(final Map<String, String> state) {
        try {
            stateManager.setState(state, SCOPE);
        } catch (final IOException e) {
            throw new UncheckedIOException("Set State for OIDC Authorized Clients failed", e);
        }
    }

    private StateMap getStateMap() {
        try {
            return stateManager.getState(SCOPE);
        } catch (final IOException e) {
            throw new UncheckedIOException("Get State for OIDC Authorized Clients failed", e);
        }
    }
}
