/*
 * Decompiled with CFR 0.152.
 */
package org.jgroups.protocols;

import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
import java.util.zip.Adler32;
import java.util.zip.CRC32;
import java.util.zip.Checksum;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import org.jgroups.Address;
import org.jgroups.Event;
import org.jgroups.Message;
import org.jgroups.View;
import org.jgroups.annotations.ManagedAttribute;
import org.jgroups.annotations.Property;
import org.jgroups.protocols.EncryptHeader;
import org.jgroups.stack.Protocol;
import org.jgroups.util.AsciiString;
import org.jgroups.util.Bits;
import org.jgroups.util.Buffer;
import org.jgroups.util.MessageBatch;
import org.jgroups.util.Util;

/**
 * Base class for symmetric-encryption protocols.
 * <p>
 * Outgoing messages ({@link Event#MSG}) are encrypted with a symmetric cipher borrowed from a fixed-size pool and
 * tagged with an {@link EncryptHeader} carrying the current key version ({@link #sym_version}). Incoming messages
 * whose encrypt header has type 1 are decrypted, either with a pooled cipher (current key version) or with a
 * cipher looked up in {@link #key_map} (other key versions). Messages with any other header type are handed to
 * {@link #handleUpEvent(Message, EncryptHeader)}.
 * <p>
 * If {@link #encrypt_entire_message} is true, the complete serialized message is encrypted and, when
 * {@link #sign_msgs} is also true, an encrypted checksum of the ciphertext is stored in the header. Otherwise only
 * the payload is encrypted.
 * <p>
 * While no secret key is set, outgoing messages and incoming batches are discarded.
 */
public abstract class EncryptBase extends Protocol {
    protected static final String DEFAULT_SYM_ALGO = "AES";
    @Property(description="Cryptographic Service Provider")
    protected String provider;
    @Property(description="Cipher engine transformation for asymmetric algorithm. Default is RSA")
    protected String asym_algorithm = "RSA";
    @Property(description="Cipher engine transformation for symmetric algorithm. Default is AES")
    protected String sym_algorithm = DEFAULT_SYM_ALGO;
    @Property(description="Initial public/private key length. Default is 512")
    protected int asym_keylength = 512;
    @Property(description="Initial key length for matching symmetric algorithm. Default is 128")
    protected int sym_keylength = 128;
    @Property(description="Number of ciphers in the pool to parallelize encrypt and decrypt requests", writable=false)
    protected int cipher_pool_size = 8;
    @Property(description="If true, the entire message (including payload and headers) is encrypted, else only the payload")
    protected boolean encrypt_entire_message = true;
    @Property(description="If true, all messages are digitally signed by adding an encrypted checksum of the encrypted message to the header. Ignored if encrypt_entire_message is false")
    protected boolean sign_msgs = true;
    @Property(description="When sign_msgs is true, by default CRC32 is used to create the checksum. If use_adler is true, Adler32 will be used")
    protected boolean use_adler;

    /** Address of this member, set from SET_LOCAL_ADDRESS events or {@link #localAddress(Address)}. */
    protected volatile Address local_addr;
    /** Most recently installed view (see {@link #handleView(View)}). */
    protected volatile View view;
    /** Pool of ENCRYPT_MODE ciphers: each operation takes one and offers it back afterwards. */
    protected BlockingQueue<Cipher> encoding_ciphers;
    /** Pool of DECRYPT_MODE ciphers, used with the same take/offer protocol as {@link #encoding_ciphers}. */
    protected BlockingQueue<Cipher> decoding_ciphers;
    /** Current key version; {@link #initSymCiphers} sets it to the MD5 digest of the secret key's encoding. */
    protected volatile byte[] sym_version;
    /** Current secret key; {@code null} until one is provided. */
    protected volatile SecretKey secret_key;
    /**
     * Decryption ciphers for key versions other than {@link #sym_version}, keyed by version bytes.
     * NOTE(review): WeakHashMap is not thread-safe; confirm that code populating it synchronizes with readers.
     */
    protected final Map<AsciiString, Cipher> key_map = new WeakHashMap<AsciiString, Cipher>();

    public int asymKeylength() {
        return this.asym_keylength;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T asymKeylength(int len) {
        this.asym_keylength = len;
        return (T)this;
    }

    public int symKeylength() {
        return this.sym_keylength;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T symKeylength(int len) {
        this.sym_keylength = len;
        return (T)this;
    }

    public SecretKey secretKey() {
        return this.secret_key;
    }

    /** Sets the secret key. Note that this does not (re)create the cipher pools. */
    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T secretKey(SecretKey key) {
        this.secret_key = key;
        return (T)this;
    }

    public String symAlgorithm() {
        return this.sym_algorithm;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T symAlgorithm(String alg) {
        this.sym_algorithm = alg;
        return (T)this;
    }

    public String asymAlgorithm() {
        return this.asym_algorithm;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T asymAlgorithm(String alg) {
        this.asym_algorithm = alg;
        return (T)this;
    }

    /** Returns the current key version (the internal array, not a copy). */
    public byte[] symVersion() {
        return this.sym_version;
    }

    /**
     * Sets the current key version to a defensive copy of {@code v}.
     *
     * @param v the version bytes; {@code null} clears the version
     */
    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T symVersion(byte[] v) {
        this.sym_version = v == null ? null : Arrays.copyOf(v, v.length);
        return (T)this;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T localAddress(Address addr) {
        this.local_addr = addr;
        return (T)this;
    }

    public boolean encryptEntireMessage() {
        return this.encrypt_entire_message;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T encryptEntireMessage(boolean b) {
        this.encrypt_entire_message = b;
        return (T)this;
    }

    public boolean signMessages() {
        return this.sign_msgs;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T signMessages(boolean flag) {
        this.sign_msgs = flag;
        return (T)this;
    }

    public boolean adler() {
        return this.use_adler;
    }

    @SuppressWarnings("unchecked")
    public <T extends EncryptBase> T adler(boolean flag) {
        this.use_adler = flag;
        return (T)this;
    }

    /** Current key version rendered as a hex string (exposed as a managed attribute). */
    @ManagedAttribute
    public String version() {
        return Util.byteArrayToHexString(this.sym_version);
    }

    /**
     * Adjusts {@link #cipher_pool_size} to a power of two (via {@code Util.getNextHigherPowerOfTwo}), creates
     * the cipher pools and fills them if a secret key is already set.
     *
     * @throws Exception if a cipher cannot be created
     */
    @Override
    public void init() throws Exception {
        int pool_size = Util.getNextHigherPowerOfTwo(this.cipher_pool_size);
        if (pool_size != this.cipher_pool_size) {
            this.log.warn("%s: setting cipher_pool_size (%d) to %d (power of 2) for faster modulo operation", this.local_addr, this.cipher_pool_size, pool_size);
            this.cipher_pool_size = pool_size;
        }
        this.encoding_ciphers = new ArrayBlockingQueue<Cipher>(this.cipher_pool_size);
        this.decoding_ciphers = new ArrayBlockingQueue<Cipher>(this.cipher_pool_size);
        this.initSymCiphers(this.sym_algorithm, this.secret_key);
    }

    /**
     * Encrypts outgoing messages (dropping them while no secret key is set), records view changes and the
     * local address, and passes every non-message event further down.
     *
     * @return {@code null} for messages (they are sent by {@link #encryptAndSend(Message)}), otherwise the
     *         result of the protocol below
     */
    @Override
    public Object down(Event evt) {
        switch (evt.getType()) {
            case Event.MSG: {
                Message msg = (Message)evt.arg();
                try {
                    if (this.secret_key == null) {
                        this.log.trace("%s: discarded %s message to %s as secret key is null, hdrs: %s", this.local_addr, msg.dest() == null ? "mcast" : "unicast", msg.dest(), msg.printHeaders());
                        return null;
                    }
                    this.encryptAndSend(msg);
                }
                catch (Exception e) {
                    restoreInterruptStatus(e);
                    this.log.warn("%s: unable to send message down", this.local_addr, e);
                }
                return null;
            }
            case Event.VIEW_CHANGE:
                this.handleView((View)evt.getArg());
                break;
            case Event.SET_LOCAL_ADDRESS:
                this.local_addr = (Address)evt.arg();
                break;
            default:
                break;
        }
        return this.down_prot.down(evt);
    }

    /**
     * Records view changes and decrypts incoming messages; other events are passed up unchanged.
     * Exceptions thrown while handling a message are logged and the message is dropped.
     */
    @Override
    public Object up(Event evt) {
        switch (evt.getType()) {
            case Event.VIEW_CHANGE:
                this.handleView((View)evt.getArg());
                break;
            case Event.MSG: {
                Message msg = (Message)evt.arg();
                try {
                    return this.handleUpMessage(msg);
                }
                catch (Exception e) {
                    restoreInterruptStatus(e);
                    this.log.warn("%s: exception occurred decrypting message", this.local_addr, e);
                    return null;
                }
            }
            default:
                break;
        }
        return this.up_prot.up(evt);
    }

    /**
     * Decrypts all messages of a batch with a single pooled cipher and passes the remaining (decrypted)
     * messages up. The whole batch is discarded if no secret key is set or if the thread is interrupted while
     * waiting for a cipher. The borrowed cipher is always returned to the pool, even if decryption throws.
     */
    @Override
    public void up(MessageBatch batch) {
        if (this.secret_key == null) {
            this.log.trace("%s: discarded %s batch from %s as secret key is null", this.local_addr, batch.dest() == null ? "mcast" : "unicast", batch.sender());
            return;
        }
        Cipher cipher = null;
        try {
            cipher = this.decoding_ciphers.take();
            batch.map(new Decrypter(cipher));
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            // The batch must not go up: it was not (fully) decrypted, so encrypted messages would bypass decryption.
            this.log.error("%s: failed processing batch; discarding batch", this.local_addr, e);
            return;
        }
        finally {
            if (cipher != null) {
                this.decoding_ciphers.offer(cipher);
            }
        }
        if (!batch.isEmpty()) {
            this.up_prot.up(batch);
        }
    }

    /**
     * Refills both cipher pools with {@link #cipher_pool_size} ciphers for {@code secret} and sets
     * {@link #sym_version} to the MD5 digest of the key's encoding. Does nothing if {@code secret} is null.
     * NOTE(review): a cipher taken by another thread before {@code clear()} and offered back during the refill
     * can make a later {@code add()} fail with IllegalStateException (queue full); confirm traffic is quiesced
     * when re-keying.
     *
     * @param algorithm the cipher transformation
     * @param secret    the new secret key, may be {@code null}
     * @throws Exception if a cipher or the MD5 digest cannot be created
     */
    protected synchronized void initSymCiphers(String algorithm, SecretKey secret) throws Exception {
        if (secret == null) {
            return;
        }
        this.encoding_ciphers.clear();
        this.decoding_ciphers.clear();
        for (int i = 0; i < this.cipher_pool_size; ++i) {
            this.encoding_ciphers.add(this.createCipher(Cipher.ENCRYPT_MODE, secret, algorithm));
            this.decoding_ciphers.add(this.createCipher(Cipher.DECRYPT_MODE, secret, algorithm));
        }
        byte[] encoded_key = secret.getEncoded();
        MessageDigest digest = MessageDigest.getInstance("MD5");
        // digest(byte[]) returns a fresh array, so no defensive copy is needed
        this.sym_version = digest.digest(encoded_key);
        this.log.debug("%s: created %d symmetric ciphers with secret key (%d bytes)", this.local_addr, this.cipher_pool_size, encoded_key.length);
    }

    /**
     * Creates and initializes a cipher, using {@link #provider} if it is set to a non-blank value.
     *
     * @param mode       {@link Cipher#ENCRYPT_MODE} or {@link Cipher#DECRYPT_MODE}
     * @param secret_key the key to initialize the cipher with
     * @param algorithm  the cipher transformation
     */
    protected Cipher createCipher(int mode, SecretKey secret_key, String algorithm) throws Exception {
        Cipher cipher = this.provider != null && !this.provider.trim().isEmpty() ? Cipher.getInstance(algorithm, this.provider) : Cipher.getInstance(algorithm);
        cipher.init(mode, secret_key);
        return cipher;
    }

    /**
     * Dispatches an incoming message by its encrypt header. Messages without a header are dropped; type 1
     * (encrypted) goes to {@link #handleEncryptedMessage(Message)}; any other type goes to
     * {@link #handleUpEvent(Message, EncryptHeader)}.
     */
    protected Object handleUpMessage(Message msg) throws Exception {
        EncryptHeader hdr = (EncryptHeader)msg.getHeader(this.id);
        if (hdr == null) {
            this.log.error("%s: received message without encrypt header from %s; dropping it", this.local_addr, msg.src());
            return null;
        }
        if (hdr.type() == 1) { // 1: encrypted message (see encryptAndSend)
            return this.handleEncryptedMessage(msg);
        }
        return this.handleUpEvent(msg, hdr);
    }

    /**
     * Decrypts a copy of {@code msg} (using the pooled ciphers for the current version) and passes the result
     * up. The message is dropped if {@link #process(Message)} rejects it or decryption yields {@code null}.
     */
    protected Object handleEncryptedMessage(Message msg) throws Exception {
        if (!this.process(msg)) {
            return null;
        }
        Message decrypted = this.decryptMessage(null, msg.copy());
        if (decrypted != null) {
            return this.up_prot.up(new Event(Event.MSG, decrypted));
        }
        this.log.warn("%s: unrecognized cipher; discarding message from %s", this.local_addr, msg.src());
        return null;
    }

    /** Handles a message whose encrypt header is not of the encrypted type. Default: drops it and returns null. */
    protected Object handleUpEvent(Message msg, EncryptHeader hdr) {
        return null;
    }

    /** Decides whether an encrypted message should be decrypted at all. Default: always true. */
    protected boolean process(Message msg) {
        return true;
    }

    /** Records {@code view} as the current view. */
    protected void handleView(View view) {
        this.view = view;
    }

    /**
     * Returns true if no view is installed yet or {@code sender} is a member of the current view. Otherwise logs
     * {@code error_msg} (formatted with the sender and the view) and returns false.
     */
    protected boolean inView(Address sender, String error_msg) {
        View curr_view = this.view;
        if (curr_view == null || curr_view.containsMember(sender)) {
            return true;
        }
        this.log.error(error_msg, sender, curr_view);
        return false;
    }

    /** Creates a fresh Adler32 (if {@link #use_adler}) or CRC32 checksummer. */
    protected Checksum createChecksummer() {
        return this.use_adler ? new Adler32() : new CRC32();
    }

    /**
     * Decrypts {@code msg}, choosing the cipher by the version in its header. If the version differs from
     * {@link #sym_version}, a cipher is looked up in {@link #key_map}. If none is found,
     * {@link #handleUnknownVersion()} is called and {@code null} is returned.
     *
     * @param cipher the cipher to use for the current version, or {@code null} to borrow one from the pool
     * @param msg    the message to decrypt (must carry an encrypt header)
     * @return the decrypted message, or {@code null} if it was dropped
     */
    protected Message decryptMessage(Cipher cipher, Message msg) throws Exception {
        EncryptHeader hdr = (EncryptHeader)msg.getHeader(this.id);
        if (Arrays.equals(hdr.version(), this.sym_version)) {
            return this._decrypt(cipher, msg, hdr);
        }
        Cipher previous = this.key_map.get(new AsciiString(hdr.version()));
        if (previous == null) {
            this.handleUnknownVersion();
            return null;
        }
        this.log.trace("%s: decrypting msg from %s using previous cipher version", this.local_addr, msg.src());
        // key_map ciphers are shared between threads (unlike pooled ones), so serialize their use
        synchronized (previous) {
            return this._decrypt(previous, msg, hdr);
        }
    }

    /**
     * Verifies the signature (if {@link #encrypt_entire_message} and {@link #sign_msgs} are set) and decrypts
     * {@code msg}. With entire-message encryption, the decrypted bytes are deserialized into a new message that
     * inherits dest/src from {@code msg} when missing. Otherwise the payload is replaced in place.
     *
     * @param cipher the cipher to use, or {@code null} to borrow one from the decoding pool
     * @return the decrypted message, or {@code null} if the signature is missing or does not match
     */
    protected Message _decrypt(Cipher cipher, Message msg, EncryptHeader hdr) throws Exception {
        if (!this.encrypt_entire_message && msg.getLength() == 0) {
            return msg;
        }
        if (this.encrypt_entire_message && this.sign_msgs) {
            byte[] signature = hdr.signature();
            if (signature == null) {
                this.log.error("%s: dropped message from %s as the header did not have a checksum", this.local_addr, msg.src());
                return null;
            }
            long msg_checksum = this.decryptChecksum(cipher, signature, 0, signature.length);
            // the checksum covers the ciphertext, matching what encryptAndSend() signs
            long actual_checksum = this.computeChecksum(msg.getRawBuffer(), msg.getOffset(), msg.getLength());
            if (actual_checksum != msg_checksum) {
                this.log.error("%s: dropped message from %s as the message's checksum (%d) did not match the computed checksum (%d)", this.local_addr, msg.src(), msg_checksum, actual_checksum);
                return null;
            }
        }
        byte[] decrypted_msg = cipher == null ? this.code(msg.getRawBuffer(), msg.getOffset(), msg.getLength(), true) : cipher.doFinal(msg.getRawBuffer(), msg.getOffset(), msg.getLength());
        if (!this.encrypt_entire_message) {
            msg.setBuffer(decrypted_msg);
            return msg;
        }
        Message ret = Util.streamableFromBuffer(Message.class, decrypted_msg, 0, decrypted_msg.length);
        if (ret.getDest() == null) {
            ret.setDest(msg.getDest());
        }
        if (ret.getSrc() == null) {
            ret.setSrc(msg.getSrc());
        }
        return ret;
    }

    /**
     * Encrypts {@code msg} and sends the result down with an encrypt header (type 1, current version).
     * With entire-message encryption, the serialized message is encrypted into the payload of a copy and
     * optionally signed. Otherwise only the payload (if non-empty) is encrypted.
     */
    protected void encryptAndSend(Message msg) throws Exception {
        EncryptHeader hdr = new EncryptHeader(1, this.symVersion()); // 1: encrypted message
        if (this.encrypt_entire_message) {
            if (msg.getSrc() == null) {
                msg.setSrc(this.local_addr);
            }
            Buffer serialized_msg = Util.streamableToBuffer(msg);
            byte[] encrypted_msg = this.code(serialized_msg.getBuf(), serialized_msg.getOffset(), serialized_msg.getLength(), false);
            if (this.sign_msgs) {
                long checksum = this.computeChecksum(encrypted_msg, 0, encrypted_msg.length);
                hdr.signature(this.encryptChecksum(checksum));
            }
            Message tmp = msg.copy(false, false).setBuffer(encrypted_msg).putHeader(this.id, hdr);
            this.down_prot.down(new Event(Event.MSG, tmp));
            return;
        }
        Message msgEncrypted = msg.copy(false).putHeader(this.id, hdr);
        if (msg.getLength() > 0) {
            msgEncrypted.setBuffer(this.code(msg.getRawBuffer(), msg.getOffset(), msg.getLength(), false));
        }
        this.down_prot.down(new Event(Event.MSG, msgEncrypted));
    }

    /**
     * Encrypts or decrypts a byte range with a cipher borrowed from the matching pool. Blocks until a cipher is
     * available and always returns it afterwards.
     *
     * @param decode true to use the decoding pool, false for the encoding pool
     * @throws InterruptedException if interrupted while waiting for a cipher
     */
    protected byte[] code(byte[] buf, int offset, int length, boolean decode) throws Exception {
        BlockingQueue<Cipher> queue = decode ? this.decoding_ciphers : this.encoding_ciphers;
        Cipher cipher = queue.take();
        try {
            return cipher.doFinal(buf, offset, length);
        }
        finally {
            queue.offer(cipher);
        }
    }

    /** Computes the checksum of a byte range using {@link #createChecksummer()}. */
    protected long computeChecksum(byte[] input, int offset, int length) {
        Checksum checksummer = this.createChecksummer();
        checksummer.update(input, offset, length);
        return checksummer.getValue();
    }

    /** Encodes {@code checksum} as 8 bytes and encrypts them with an encoding cipher. */
    protected byte[] encryptChecksum(long checksum) throws Exception {
        byte[] checksum_array = new byte[8];
        Bits.writeLong(checksum, checksum_array, 0);
        return this.code(checksum_array, 0, checksum_array.length, false);
    }

    /**
     * Decrypts an encrypted checksum and reads it back as a long.
     *
     * @param cipher the cipher to use, or {@code null} to borrow one from the decoding pool
     */
    protected long decryptChecksum(Cipher cipher, byte[] input, int offset, int length) throws Exception {
        byte[] decrypted_checksum = cipher == null ? this.code(input, offset, length, true) : cipher.doFinal(input, offset, length);
        return Bits.readLong(decrypted_checksum, 0);
    }

    /** Returns the algorithm part of a transformation string, e.g. "AES" for "AES/CBC/PKCS5Padding". */
    protected static String getAlgorithm(String s) {
        int index = s.indexOf('/');
        return index == -1 ? s : s.substring(0, index);
    }

    /** Called when a message carries an unknown key version. Default: no-op. */
    protected void handleUnknownVersion() {
    }

    /** Re-asserts the thread's interrupt status if {@code t} is an InterruptedException, so it is not lost. */
    private static void restoreInterruptStatus(Throwable t) {
        if (t instanceof InterruptedException) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Batch visitor that decrypts the encrypted messages of a batch in place with one borrowed cipher.
     * Messages without an encrypt header, rejected by {@link EncryptBase#process(Message)}, or failing decryption
     * are removed. Messages with a non-encrypted header type are removed and passed to
     * {@link EncryptBase#handleUpEvent(Message, EncryptHeader)}.
     */
    protected class Decrypter
    implements MessageBatch.Visitor<Message> {
        /** Cipher for current-version messages; exclusively held by the calling thread while the visitor runs. */
        protected final Cipher cipher;

        public Decrypter(Cipher cipher) {
            this.cipher = cipher;
        }

        /**
         * Processes one message of the batch, replacing or removing it in {@code batch}.
         *
         * @return always {@code null}
         */
        @Override
        public Message visit(Message msg, MessageBatch batch) {
            EncryptHeader hdr = (EncryptHeader)msg.getHeader(EncryptBase.this.id);
            if (hdr == null) {
                EncryptBase.this.log.error("%s: received message without encrypt header from %s; dropping it", EncryptBase.this.local_addr, batch.sender());
                batch.remove(msg);
                return null;
            }
            if (hdr.type() != 1) { // not an encrypted message
                batch.remove(msg);
                EncryptBase.this.handleUpEvent(msg, hdr);
                return null;
            }
            try {
                if (!EncryptBase.this.process(msg)) {
                    batch.remove(msg);
                    return null;
                }
                Message decrypted = EncryptBase.this.decryptMessage(this.cipher, msg.copy());
                if (decrypted != null) {
                    batch.replace(msg, decrypted);
                } else {
                    batch.remove(msg);
                }
            }
            catch (Exception e) {
                // messages without a payload have no raw buffer; don't let the error report itself throw an NPE
                byte[] raw = msg.getRawBuffer();
                EncryptBase.this.log.error("%s: failed decrypting message from %s (offset=%d, length=%d, buf.length=%d): %s, headers are %s", EncryptBase.this.local_addr, msg.getSrc(), msg.getOffset(), msg.getLength(), raw == null ? 0 : raw.length, e, msg.printHeaders());
                batch.remove(msg);
            }
            return null;
        }
    }
}

