package org.mitre.oauth2.repository.impl;

import com.nimbusds.jwt.JWT;
import com.nimbusds.jwt.JWTParser;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import javax.persistence.EntityManager;
import javax.persistence.PersistenceContext;
import javax.persistence.TypedQuery;
import javax.persistence.criteria.CriteriaDelete;
import org.mitre.data.DefaultPageCriteria;
import org.mitre.data.PageCriteria;
import org.mitre.oauth2.model.ClientDetailsEntity;
import org.mitre.oauth2.model.OAuth2AccessTokenEntity;
import org.mitre.oauth2.model.OAuth2RefreshTokenEntity;
import org.mitre.oauth2.repository.OAuth2TokenRepository;
import org.mitre.openid.connect.model.ApprovedSite;
import org.mitre.openid.connect.view.UserInfoJWTView;
import org.mitre.uma.model.ResourceSet;
import org.mitre.util.jpa.JpaUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;

/**
 * JPA-backed implementation of {@link OAuth2TokenRepository}.
 *
 * <p>Reads go through named queries declared for the token entity classes. Mutating operations
 * run under the {@code defaultTransactionManager} transaction manager. Lookups by token value parse
 * the incoming string into a {@link JWT} first and bind that object as the query parameter.
 */
@Repository
public class JpaOAuth2TokenRepository implements OAuth2TokenRepository {

    /** Page size used by the no-argument expired-token lookups. */
    private static final int MAX_EXPIRED_RESULTS = 1000;

    private static final Logger logger = LoggerFactory.getLogger(JpaOAuth2TokenRepository.class);

    @PersistenceContext(unitName = "defaultPersistenceUnit")
    private EntityManager manager;

    /**
     * Returns every stored access token.
     *
     * @return all access tokens, in query result order
     */
    public Set<OAuth2AccessTokenEntity> getAllAccessTokens() {
        TypedQuery<OAuth2AccessTokenEntity> query =
                manager.createNamedQuery("OAuth2AccessTokenEntity.getAll", OAuth2AccessTokenEntity.class);
        return new LinkedHashSet<>(query.getResultList());
    }

    /**
     * Returns every stored refresh token.
     *
     * @return all refresh tokens, in query result order
     */
    public Set<OAuth2RefreshTokenEntity> getAllRefreshTokens() {
        TypedQuery<OAuth2RefreshTokenEntity> query =
                manager.createNamedQuery("OAuth2RefreshTokenEntity.getAll", OAuth2RefreshTokenEntity.class);
        return new LinkedHashSet<>(query.getResultList());
    }

    /**
     * Looks up an access token by its serialized value.
     *
     * @param tokenValue the serialized token; may be {@code null}
     * @return {@code null} when the value is {@code null} or cannot be parsed as a JWT; otherwise
     *     whatever {@code JpaUtil.getSingleResult} yields for the matching rows (presumably
     *     {@code null} when nothing matches — confirm against JpaUtil)
     */
    public OAuth2AccessTokenEntity getAccessTokenByValue(String tokenValue) {
        JWT jwt = parseTokenValue(tokenValue);
        if (jwt == null) {
            return null;
        }
        TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(
                "OAuth2AccessTokenEntity.getByTokenValue", OAuth2AccessTokenEntity.class);
        query.setParameter("tokenValue", jwt);
        return (OAuth2AccessTokenEntity) JpaUtil.getSingleResult(query.getResultList());
    }

    /**
     * Looks up an access token by its primary key.
     *
     * @param id the entity id
     * @return the token, or {@code null} if no entity has that id
     */
    public OAuth2AccessTokenEntity getAccessTokenById(Long id) {
        return manager.find(OAuth2AccessTokenEntity.class, id);
    }

    /**
     * Inserts or updates an access token.
     *
     * <p>Delegates to {@code JpaUtil.saveOrUpdate}, which is handed the token's id. How it chooses
     * between insert and merge is decided there.
     *
     * @param token the token to persist
     * @return the entity returned by {@code JpaUtil.saveOrUpdate}
     */
    @Transactional("defaultTransactionManager")
    public OAuth2AccessTokenEntity saveAccessToken(OAuth2AccessTokenEntity token) {
        return (OAuth2AccessTokenEntity) JpaUtil.saveOrUpdate(token.getId(), manager, token);
    }

    /**
     * Deletes an access token, re-reading it by id so the managed instance is the one removed.
     *
     * @param token the token to delete
     * @throws IllegalArgumentException if no stored token has the given token's id
     */
    @Transactional("defaultTransactionManager")
    public void removeAccessToken(OAuth2AccessTokenEntity token) {
        OAuth2AccessTokenEntity managed = getAccessTokenById(token.getId());
        if (managed == null) {
            throw new IllegalArgumentException("Access token not found: " + token);
        }
        manager.remove(managed);
    }

    /**
     * Deletes every access token associated with the given refresh token.
     *
     * @param refreshToken the refresh token whose access tokens should be removed
     */
    @Transactional("defaultTransactionManager")
    public void clearAccessTokensForRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
        TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(
                "OAuth2AccessTokenEntity.getByRefreshToken", OAuth2AccessTokenEntity.class);
        query.setParameter("refreshToken", refreshToken);
        for (OAuth2AccessTokenEntity accessToken : query.getResultList()) {
            removeAccessToken(accessToken);
        }
    }

    /**
     * Looks up a refresh token by its serialized value.
     *
     * @param tokenValue the serialized token; may be {@code null}
     * @return {@code null} when the value is {@code null} or cannot be parsed as a JWT; otherwise
     *     whatever {@code JpaUtil.getSingleResult} yields for the matching rows (presumably
     *     {@code null} when nothing matches — confirm against JpaUtil)
     */
    public OAuth2RefreshTokenEntity getRefreshTokenByValue(String tokenValue) {
        JWT jwt = parseTokenValue(tokenValue);
        if (jwt == null) {
            return null;
        }
        TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(
                "OAuth2RefreshTokenEntity.getByTokenValue", OAuth2RefreshTokenEntity.class);
        query.setParameter("tokenValue", jwt);
        return (OAuth2RefreshTokenEntity) JpaUtil.getSingleResult(query.getResultList());
    }

    /**
     * Looks up a refresh token by its primary key.
     *
     * @param id the entity id
     * @return the token, or {@code null} if no entity has that id
     */
    public OAuth2RefreshTokenEntity getRefreshTokenById(Long id) {
        return manager.find(OAuth2RefreshTokenEntity.class, id);
    }

    /**
     * Inserts or updates a refresh token via {@code JpaUtil.saveOrUpdate}, keyed on the token's id.
     *
     * @param refreshToken the token to persist
     * @return the entity returned by {@code JpaUtil.saveOrUpdate}
     */
    @Transactional("defaultTransactionManager")
    public OAuth2RefreshTokenEntity saveRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
        return (OAuth2RefreshTokenEntity) JpaUtil.saveOrUpdate(refreshToken.getId(), manager, refreshToken);
    }

    /**
     * Deletes a refresh token, re-reading it by id so the managed instance is the one removed.
     *
     * @param refreshToken the token to delete
     * @throws IllegalArgumentException if no stored token has the given token's id
     */
    @Transactional("defaultTransactionManager")
    public void removeRefreshToken(OAuth2RefreshTokenEntity refreshToken) {
        OAuth2RefreshTokenEntity managed = getRefreshTokenById(refreshToken.getId());
        if (managed == null) {
            throw new IllegalArgumentException("Refresh token not found: " + refreshToken);
        }
        manager.remove(managed);
    }

    /**
     * Deletes all access tokens and then all refresh tokens issued to a client.
     *
     * @param client the client whose tokens should be removed
     */
    @Transactional("defaultTransactionManager")
    public void clearTokensForClient(ClientDetailsEntity client) {
        // Access tokens go first, matching the original ordering; both lookups reuse the public
        // per-client queries below so the query/parameter wiring lives in one place.
        for (OAuth2AccessTokenEntity accessToken : getAccessTokensForClient(client)) {
            removeAccessToken(accessToken);
        }
        for (OAuth2RefreshTokenEntity refreshToken : getRefreshTokensForClient(client)) {
            removeRefreshToken(refreshToken);
        }
    }

    /**
     * Returns all access tokens issued to a client.
     *
     * @param client the owning client
     * @return the client's access tokens
     */
    public List<OAuth2AccessTokenEntity> getAccessTokensForClient(ClientDetailsEntity client) {
        TypedQuery<OAuth2AccessTokenEntity> query =
                manager.createNamedQuery("OAuth2AccessTokenEntity.getByClient", OAuth2AccessTokenEntity.class);
        // NOTE(review): the parameter name is borrowed from a view class; presumably it equals the
        // named query's client parameter — a constant owned by the entity would be a better source.
        query.setParameter(UserInfoJWTView.CLIENT, client);
        return query.getResultList();
    }

    /**
     * Returns all refresh tokens issued to a client.
     *
     * @param client the owning client
     * @return the client's refresh tokens
     */
    public List<OAuth2RefreshTokenEntity> getRefreshTokensForClient(ClientDetailsEntity client) {
        TypedQuery<OAuth2RefreshTokenEntity> query =
                manager.createNamedQuery("OAuth2RefreshTokenEntity.getByClient", OAuth2RefreshTokenEntity.class);
        // NOTE(review): see getAccessTokensForClient regarding this borrowed parameter name.
        query.setParameter(UserInfoJWTView.CLIENT, client);
        return query.getResultList();
    }

    /**
     * Returns all access tokens matched by the {@code getByName} named query for a user name.
     *
     * @param name the user name
     * @return a mutable set of matching tokens; empty if none match
     */
    public Set<OAuth2AccessTokenEntity> getAccessTokensByUserName(String name) {
        TypedQuery<OAuth2AccessTokenEntity> query =
                manager.createNamedQuery("OAuth2AccessTokenEntity.getByName", OAuth2AccessTokenEntity.class);
        query.setParameter("name", name);
        List<OAuth2AccessTokenEntity> results = query.getResultList();
        // Defensive: keep returning an empty set rather than null even if a provider returns null.
        if (results == null) {
            return new HashSet<>();
        }
        return new HashSet<>(results);
    }

    /**
     * Returns all refresh tokens matched by the {@code getByName} named query for a user name.
     *
     * @param name the user name
     * @return a mutable set of matching tokens; empty if none match
     */
    public Set<OAuth2RefreshTokenEntity> getRefreshTokensByUserName(String name) {
        TypedQuery<OAuth2RefreshTokenEntity> query =
                manager.createNamedQuery("OAuth2RefreshTokenEntity.getByName", OAuth2RefreshTokenEntity.class);
        query.setParameter("name", name);
        List<OAuth2RefreshTokenEntity> results = query.getResultList();
        if (results == null) {
            return new HashSet<>();
        }
        return new HashSet<>(results);
    }

    /**
     * Returns the first page (at most {@value #MAX_EXPIRED_RESULTS} rows) of expired access tokens.
     *
     * @return expired access tokens, in query result order
     */
    public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens() {
        return getAllExpiredAccessTokens(new DefaultPageCriteria(0, MAX_EXPIRED_RESULTS));
    }

    /**
     * Returns one page of access tokens selected by the {@code getAllExpiredByDate} named query,
     * evaluated against the current time.
     *
     * @param pageCriteria which page to fetch
     * @return the requested page, in query result order
     */
    public Set<OAuth2AccessTokenEntity> getAllExpiredAccessTokens(PageCriteria pageCriteria) {
        TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(
                "OAuth2AccessTokenEntity.getAllExpiredByDate", OAuth2AccessTokenEntity.class);
        query.setParameter("date", new Date());
        return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria));
    }

    /**
     * Returns the first page (at most {@value #MAX_EXPIRED_RESULTS} rows) of expired refresh tokens.
     *
     * @return expired refresh tokens, in query result order
     */
    public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens() {
        return getAllExpiredRefreshTokens(new DefaultPageCriteria(0, MAX_EXPIRED_RESULTS));
    }

    /**
     * Returns one page of refresh tokens selected by the {@code getAllExpiredByDate} named query,
     * evaluated against the current time.
     *
     * @param pageCriteria which page to fetch
     * @return the requested page, in query result order
     */
    public Set<OAuth2RefreshTokenEntity> getAllExpiredRefreshTokens(PageCriteria pageCriteria) {
        TypedQuery<OAuth2RefreshTokenEntity> query = manager.createNamedQuery(
                "OAuth2RefreshTokenEntity.getAllExpiredByDate", OAuth2RefreshTokenEntity.class);
        query.setParameter("date", new Date());
        return new LinkedHashSet<>(JpaUtil.getResultPage(query, pageCriteria));
    }

    /**
     * Returns access tokens linked to a UMA resource set.
     *
     * @param resourceSet the resource set; must not be {@code null}
     * @return matching access tokens, in query result order
     */
    public Set<OAuth2AccessTokenEntity> getAccessTokensForResourceSet(ResourceSet resourceSet) {
        TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(
                "OAuth2AccessTokenEntity.getByResourceSet", OAuth2AccessTokenEntity.class);
        query.setParameter("rsid", resourceSet.getId());
        return new LinkedHashSet<>(query.getResultList());
    }

    /**
     * Finds access-token rows that share a JWT value and bulk-deletes every row with such a value.
     * This includes the first copy, not just the extras.
     */
    @Transactional("defaultTransactionManager")
    public void clearDuplicateAccessTokens() {
        clearDuplicateTokens(OAuth2AccessTokenEntity.class,
                "select a.jwt, count(1) as c from OAuth2AccessTokenEntity a GROUP BY a.jwt HAVING count(1) > 1",
                "access");
    }

    /**
     * Finds refresh-token rows that share a JWT value and bulk-deletes every row with such a value.
     * This includes the first copy, not just the extras.
     */
    @Transactional("defaultTransactionManager")
    public void clearDuplicateRefreshTokens() {
        clearDuplicateTokens(OAuth2RefreshTokenEntity.class,
                "select a.jwt, count(1) as c from OAuth2RefreshTokenEntity a GROUP BY a.jwt HAVING count(1) > 1",
                "refresh");
    }

    /**
     * Returns access tokens associated with an approved site.
     *
     * @param approvedSite the approved site
     * @return matching access tokens
     */
    public List<OAuth2AccessTokenEntity> getAccessTokensForApprovedSite(ApprovedSite approvedSite) {
        TypedQuery<OAuth2AccessTokenEntity> query = manager.createNamedQuery(
                "OAuth2AccessTokenEntity.getByApprovedSite", OAuth2AccessTokenEntity.class);
        query.setParameter("approvedSite", approvedSite);
        return query.getResultList();
    }

    /**
     * Parses a serialized token value into a JWT.
     *
     * @param value the serialized token; may be {@code null}
     * @return the parsed JWT, or {@code null} when the value is {@code null} or unparseable
     */
    private static JWT parseTokenValue(String value) {
        if (value == null) {
            // JWTParser.parse dereferences its argument; treat null like any other non-matching value.
            return null;
        }
        try {
            return JWTParser.parse(value);
        } catch (ParseException e) {
            // Deliberately not an error: an unparseable value cannot match any stored token.
            // The token itself is not logged, only the parser's message.
            logger.debug("Unable to parse token value as a JWT: {}", e.getMessage());
            return null;
        }
    }

    /**
     * Shared implementation of the duplicate-token cleanup.
     * Runs inside the caller's transaction.
     *
     * @param entityClass the token entity type to delete from
     * @param duplicateQuery JPQL selecting {@code (jwt, count)} rows for values stored more than once
     * @param kind "access" or "refresh", used only in log messages
     */
    private <T> void clearDuplicateTokens(Class<T> entityClass, String duplicateQuery, String kind) {
        // Each row is {jwt, count} as selected by duplicateQuery; the untyped Query API returns a raw list.
        @SuppressWarnings("unchecked")
        List<Object[]> rows = manager.createQuery(duplicateQuery).getResultList();

        List<JWT> duplicatedValues = new ArrayList<>(rows.size());
        for (Object[] row : rows) {
            JWT value = (JWT) row[0];
            logger.warn("Found duplicate {} tokens: {}, {}", kind, value.serialize(), row[1]);
            duplicatedValues.add(value);
        }
        if (duplicatedValues.isEmpty()) {
            return;
        }

        CriteriaDelete<T> delete = manager.getCriteriaBuilder().createCriteriaDelete(entityClass);
        delete.where(delete.from(entityClass).get("jwt").in(duplicatedValues));
        int deleted = manager.createQuery(delete).executeUpdate();
        logger.warn("Deleted {} duplicate {} tokens", deleted, kind);
    }
}
