/*
 * Decompiled with CFR 0.152.
 */
package org.mitre.jwt.encryption.service.impl;

import com.google.common.base.Strings;
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEDecrypter;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.crypto.DirectDecrypter;
import com.nimbusds.jose.crypto.DirectEncrypter;
import com.nimbusds.jose.crypto.RSADecrypter;
import com.nimbusds.jose.crypto.RSAEncrypter;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import javax.annotation.PostConstruct;
import org.mitre.jose.keystore.JWKSetKeyStore;
import org.mitre.jwt.encryption.service.JwtEncryptionAndDecryptionService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Default {@link JwtEncryptionAndDecryptionService} backed by a map of JWKs keyed by key ID ("kid").
 *
 * <p>For every configured key an encrypter, and where possible a decrypter, is built:
 * <ul>
 *   <li>{@link RSAKey}: an {@link RSAEncrypter} from the public part, plus an {@link RSADecrypter}
 *       only when the key also carries private material (otherwise a warning is logged).</li>
 *   <li>{@link OctetSequenceKey}: a {@link DirectEncrypter} / {@link DirectDecrypter} pair over the
 *       raw key bytes.</li>
 * </ul>
 * Keys of any other type are logged and skipped.
 *
 * <p>NOTE(review): both constructors build the encrypters/decrypters, and the {@code @PostConstruct}
 * hook builds them again; the second pass only overwrites map entries for the same key IDs.
 */
public class DefaultJwtEncryptionAndDecryptionService
implements JwtEncryptionAndDecryptionService {
    private static final Logger logger = LoggerFactory.getLogger(DefaultJwtEncryptionAndDecryptionService.class);

    /** Encrypters keyed by the kid of the JWK they were built from. */
    private final Map<String, JWEEncrypter> encrypters = new HashMap<String, JWEEncrypter>();
    /** Decrypters keyed by kid; RSA keys without private material have no entry here. */
    private final Map<String, JWEDecrypter> decrypters = new HashMap<String, JWEDecrypter>();
    private String defaultEncryptionKeyId;
    private String defaultDecryptionKeyId;
    private JWEAlgorithm defaultAlgorithm;
    /** Configured keys by kid. The map-based constructor stores the caller's map by reference. */
    private Map<String, JWK> keys = new HashMap<String, JWK>();

    /**
     * Creates the service from an explicit kid-to-JWK map and builds the encrypters/decrypters.
     *
     * @param keys keys indexed by kid; must not be {@code null}
     * @throws IllegalArgumentException if {@code keys} is {@code null}
     * @throws NoSuchAlgorithmException if a key's algorithm is unavailable
     * @throws InvalidKeySpecException if a key cannot be converted to a Java key
     * @throws JOSEException if an encrypter/decrypter cannot be created
     */
    public DefaultJwtEncryptionAndDecryptionService(Map<String, JWK> keys) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
        if (keys == null) {
            // Fail with the same message afterPropertiesSet() uses instead of an NPE during the build.
            throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured.");
        }
        this.keys = keys;
        this.buildEncryptersAndDecrypters();
    }

    /**
     * Creates the service from every key in a {@link JWKSetKeyStore} and builds the
     * encrypters/decrypters.
     *
     * @param keyStore source of the keys; each key must carry a non-empty kid
     * @throws IllegalArgumentException if a key has no kid
     * @throws NoSuchAlgorithmException if a key's algorithm is unavailable
     * @throws InvalidKeySpecException if a key cannot be converted to a Java key
     * @throws JOSEException if an encrypter/decrypter cannot be created
     */
    public DefaultJwtEncryptionAndDecryptionService(JWKSetKeyStore keyStore) throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
        for (JWK key : keyStore.getKeys()) {
            if (Strings.isNullOrEmpty(key.getKeyID())) {
                // Report only the key type: the JWK's string form would include private/secret material.
                throw new IllegalArgumentException("Tried to load a key from a keystore without a 'kid' field (key type: " + key.getKeyType() + ")");
            }
            this.keys.put(key.getKeyID(), key);
        }
        this.buildEncryptersAndDecrypters();
    }

    /**
     * Lifecycle hook that (re)builds the encrypters/decrypters, converting checked failures into
     * {@link IllegalArgumentException}s with the original exception preserved as the cause.
     *
     * @throws IllegalArgumentException if no key map is configured or a key cannot be processed
     */
    @PostConstruct
    public void afterPropertiesSet() {
        if (this.keys == null) {
            throw new IllegalArgumentException("Encryption and decryption service must have at least one key configured.");
        }
        try {
            this.buildEncryptersAndDecrypters();
        }
        catch (NoSuchAlgorithmException e) {
            throw new IllegalArgumentException("Encryption and decryption service could not find given algorithm.", e);
        }
        catch (InvalidKeySpecException e) {
            throw new IllegalArgumentException("Encryption and decryption service saw an invalid key specification.", e);
        }
        catch (JOSEException e) {
            throw new IllegalArgumentException("Encryption and decryption service was unable to process JOSE object.", e);
        }
    }

    /**
     * Returns the kid used by {@link #encryptJwt(JWEObject)}: the explicitly configured one, or the
     * sole key's kid when exactly one key is configured, otherwise {@code null}.
     */
    public String getDefaultEncryptionKeyId() {
        return this.resolveDefaultKeyId(this.defaultEncryptionKeyId);
    }

    public void setDefaultEncryptionKeyId(String defaultEncryptionKeyId) {
        this.defaultEncryptionKeyId = defaultEncryptionKeyId;
    }

    /**
     * Returns the kid used by {@link #decryptJwt(JWEObject)}: the explicitly configured one, or the
     * sole key's kid when exactly one key is configured, otherwise {@code null}.
     */
    public String getDefaultDecryptionKeyId() {
        return this.resolveDefaultKeyId(this.defaultDecryptionKeyId);
    }

    public void setDefaultDecryptionKeyId(String defaultDecryptionKeyId) {
        this.defaultDecryptionKeyId = defaultDecryptionKeyId;
    }

    /** Returns the configured default algorithm; it is not consulted by the methods of this class. */
    public JWEAlgorithm getDefaultAlgorithm() {
        return this.defaultAlgorithm;
    }

    public void setDefaultAlgorithm(JWEAlgorithm defaultAlgorithm) {
        this.defaultAlgorithm = defaultAlgorithm;
    }

    /**
     * Encrypts {@code jwt} in place with the encrypter for the default encryption kid.
     * A {@link JOSEException} during encryption is logged, not rethrown.
     *
     * @throws IllegalStateException if no default kid resolves, or no encrypter exists for it
     */
    @Override
    public void encryptJwt(JWEObject jwt) {
        String keyId = this.getDefaultEncryptionKeyId();
        if (keyId == null) {
            throw new IllegalStateException("Tried to call default encryption with no default encrypter ID set");
        }
        JWEEncrypter encrypter = this.encrypters.get(keyId);
        if (encrypter == null) {
            // E.g. the kid is unknown or belonged to an unsupported key type; don't hand null to nimbus.
            throw new IllegalStateException("No encrypter available for default encryption key ID: " + keyId);
        }
        try {
            jwt.encrypt(encrypter);
        }
        catch (JOSEException e) {
            logger.error("Failed to encrypt JWT, error was: ", e);
        }
    }

    /**
     * Decrypts {@code jwt} in place with the decrypter for the default decryption kid.
     * A {@link JOSEException} during decryption is logged, not rethrown.
     *
     * @throws IllegalStateException if no default kid resolves, or no decrypter exists for it
     */
    @Override
    public void decryptJwt(JWEObject jwt) {
        String keyId = this.getDefaultDecryptionKeyId();
        if (keyId == null) {
            throw new IllegalStateException("Tried to call default decryption with no default decrypter ID set");
        }
        JWEDecrypter decrypter = this.decrypters.get(keyId);
        if (decrypter == null) {
            // E.g. an RSA key configured without its private part has an encrypter but no decrypter.
            throw new IllegalStateException("No decrypter available for default decryption key ID: " + keyId);
        }
        try {
            jwt.decrypt(decrypter);
        }
        catch (JOSEException e) {
            logger.error("Failed to decrypt JWT, error was: ", e);
        }
    }

    /** Resolves a default kid: the configured value, else the only key's kid, else {@code null}. */
    private String resolveDefaultKeyId(String configuredKeyId) {
        if (configuredKeyId != null) {
            return configuredKeyId;
        }
        if (this.keys.size() == 1) {
            return this.keys.keySet().iterator().next();
        }
        return null;
    }

    /**
     * Populates {@link #encrypters} and {@link #decrypters} from {@link #keys}, keyed by kid.
     * Unsupported key types are logged and skipped.
     */
    private void buildEncryptersAndDecrypters() throws NoSuchAlgorithmException, InvalidKeySpecException, JOSEException {
        for (Map.Entry<String, JWK> jwkEntry : this.keys.entrySet()) {
            String id = jwkEntry.getKey();
            JWK jwk = jwkEntry.getValue();
            if (jwk instanceof RSAKey) {
                RSAKey rsaKey = (RSAKey) jwk;
                this.encrypters.put(id, new RSAEncrypter(rsaKey.toRSAPublicKey()));
                if (rsaKey.isPrivate()) {
                    this.decrypters.put(id, new RSADecrypter(rsaKey.toRSAPrivateKey()));
                }
                else {
                    logger.warn("No private key for key #{}", jwk.getKeyID());
                }
            }
            else if (jwk instanceof OctetSequenceKey) {
                OctetSequenceKey octetKey = (OctetSequenceKey) jwk;
                this.encrypters.put(id, new DirectEncrypter(octetKey.toByteArray()));
                this.decrypters.put(id, new DirectDecrypter(octetKey.toByteArray()));
            }
            else {
                // Log type and kid only; the JWK's string form could expose private key material.
                logger.warn("Unknown key type: {} (kid: {})", jwk.getKeyType(), id);
            }
        }
    }

    /**
     * Returns the public form of every configured key, keyed by kid. Keys whose
     * {@code toPublicJWK()} is {@code null} are omitted.
     */
    @Override
    public Map<String, JWK> getAllPublicKeys() {
        HashMap<String, JWK> pubKeys = new HashMap<String, JWK>();
        for (Map.Entry<String, JWK> entry : this.keys.entrySet()) {
            JWK pub = entry.getValue().toPublicJWK();
            if (pub != null) {
                pubKeys.put(entry.getKey(), pub);
            }
        }
        return pubKeys;
    }

    /** Returns the union of the JWE algorithms supported by all built encrypters and decrypters. */
    @Override
    public Collection<JWEAlgorithm> getAllEncryptionAlgsSupported() {
        HashSet<JWEAlgorithm> algs = new HashSet<JWEAlgorithm>();
        for (JWEEncrypter encrypter : this.encrypters.values()) {
            algs.addAll(encrypter.supportedAlgorithms());
        }
        for (JWEDecrypter decrypter : this.decrypters.values()) {
            algs.addAll(decrypter.supportedAlgorithms());
        }
        return algs;
    }

    /** Returns the union of the content-encryption methods supported by all built encrypters and decrypters. */
    @Override
    public Collection<EncryptionMethod> getAllEncryptionEncsSupported() {
        HashSet<EncryptionMethod> encs = new HashSet<EncryptionMethod>();
        for (JWEEncrypter encrypter : this.encrypters.values()) {
            encs.addAll(encrypter.supportedEncryptionMethods());
        }
        for (JWEDecrypter decrypter : this.decrypters.values()) {
            encs.addAll(decrypter.supportedEncryptionMethods());
        }
        return encs;
    }
}

