package org.apache.spark.network.crypto;

import com.google.crypto.tink.subtle.AesGcmHkdfStreaming;
import com.google.crypto.tink.subtle.StreamSegmentDecrypter;
import com.google.crypto.tink.subtle.StreamSegmentEncrypter;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.FileRegion;
import io.netty.util.ReferenceCounted;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.security.GeneralSecurityException;
import java.security.InvalidAlgorithmParameterException;
import javax.crypto.spec.SecretKeySpec;
import org.apache.spark.network.util.AbstractFileRegion;
import org.apache.spark.network.util.ByteBufferWriteableChannel;
import org.sparkproject.guava.annotations.VisibleForTesting;
import org.sparkproject.guava.base.Preconditions;
import org.sparkproject.guava.primitives.Longs;

/* loaded from: input_file:org/apache/spark/network/crypto/GcmTransportCipher.class */
public class GcmTransportCipher implements TransportCipher {
    private static final String HKDF_ALG = "HmacSha256";
    private static final int LENGTH_HEADER_BYTES = 8;

    @VisibleForTesting
    static final int CIPHERTEXT_BUFFER_SIZE = 32768;
    private final SecretKeySpec aesKey;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Inbound handler that reassembles and decrypts a length-prefixed AES-GCM-HKDF stream.
     *
     * <p>Incoming bytes are consumed in three stages: an 8-byte total length, the stream header
     * used to initialise the decrypter, and then ciphertext segments. The total length counts the
     * 8-byte prefix, the header and all ciphertext, and is also handed to the decrypter as
     * associated data. Every decrypted segment is forwarded as its own wrapped buffer.
     *
     * <p>Nothing in this class resets the per-stream state, so once {@code completed} is set any
     * further readable bytes are released without being decrypted.
     */
    @VisibleForTesting
    public class DecryptionHandler extends ChannelInboundHandlerAdapter {
        private final ByteBuffer headerBuffer;
        private final ByteBuffer ciphertextBuffer;
        private final AesGcmHkdfStreaming aesGcmHkdfStreaming;
        private final StreamSegmentDecrypter decrypter;
        private final int plaintextSegmentSize;
        private boolean decrypterInit = false;
        private boolean completed = false;
        private int segmentNumber = 0;
        // -1 means "length prefix not fully received yet".
        private long expectedLength = -1;
        // Total bytes of the framed message consumed so far (prefix + header + ciphertext).
        private long ciphertextRead = 0;
        private final ByteBuffer expectedLengthBuffer = ByteBuffer.allocate(Long.BYTES);

        /**
         * Allocates the header and ciphertext staging buffers and a fresh segment decrypter.
         *
         * @throws GeneralSecurityException if the streaming primitive or decrypter cannot be created
         */
        DecryptionHandler() throws GeneralSecurityException {
            this.aesGcmHkdfStreaming = GcmTransportCipher.this.getAesGcmHkdfStreaming();
            this.headerBuffer = ByteBuffer.allocate(this.aesGcmHkdfStreaming.getHeaderLength());
            this.ciphertextBuffer = ByteBuffer.allocate(this.aesGcmHkdfStreaming.getCiphertextSegmentSize());
            this.decrypter = this.aesGcmHkdfStreaming.newStreamSegmentDecrypter();
            this.plaintextSegmentSize = this.aesGcmHkdfStreaming.getPlaintextSegmentSize();
        }

        /**
         * Copies as many bytes as are currently available from {@code source} into {@code target},
         * never more than {@code target} can still hold.
         *
         * <p>{@code ByteBuf.readBytes(ByteBuffer)} fails with {@code IndexOutOfBoundsException} when
         * the destination has more room than the source has readable bytes, so the destination's
         * limit is narrowed for the duration of the copy and restored afterwards.
         */
        private void readAvailable(ByteBuf source, ByteBuffer target) {
            int toCopy = Math.min(source.readableBytes(), target.remaining());
            int originalLimit = target.limit();
            target.limit(target.position() + toCopy);
            source.readBytes(target);
            target.limit(originalLimit);
        }

        /**
         * Accumulates the 8-byte length prefix, possibly across several reads.
         *
         * @return {@code true} once the expected length is known, {@code false} if more bytes are needed
         * @throws IllegalStateException if the decoded length is negative
         */
        private boolean initializeExpectedLength(ByteBuf byteBuf) {
            if (this.expectedLength >= 0) {
                return true;
            }
            readAvailable(byteBuf, this.expectedLengthBuffer);
            if (this.expectedLengthBuffer.hasRemaining()) {
                return false;
            }
            this.expectedLengthBuffer.flip();
            this.expectedLength = this.expectedLengthBuffer.getLong();
            if (this.expectedLength < 0) {
                throw new IllegalStateException("Invalid expected ciphertext length.");
            }
            this.ciphertextRead += Long.BYTES;
            return true;
        }

        /**
         * Accumulates the stream header, possibly across several reads, and initialises the
         * decrypter with it, using the expected length as associated data.
         *
         * @return {@code true} once the decrypter is initialised, {@code false} if more bytes are needed
         * @throws GeneralSecurityException if the decrypter rejects the header
         * @throws IllegalStateException if the announced length is shorter than prefix plus header
         */
        private boolean initializeDecrypter(ByteBuf byteBuf) throws GeneralSecurityException {
            if (this.decrypterInit) {
                return true;
            }
            readAvailable(byteBuf, this.headerBuffer);
            if (this.headerBuffer.hasRemaining()) {
                return false;
            }
            this.headerBuffer.flip();
            this.decrypter.init(this.headerBuffer, Longs.toByteArray(this.expectedLength));
            this.decrypterInit = true;
            this.ciphertextRead += this.aesGcmHkdfStreaming.getHeaderLength();
            if (this.ciphertextRead > this.expectedLength) {
                // The peer announced fewer bytes than the framing alone occupies.
                throw new IllegalStateException("Invalid expected ciphertext length.");
            }
            // A message consisting of nothing but prefix and header is already complete.
            this.completed = this.expectedLength == this.ciphertextRead;
            return true;
        }

        /**
         * Consumes one inbound {@link ByteBuf}, decrypting and forwarding every segment that
         * becomes complete. The inbound buffer is released exactly once, on every exit path.
         *
         * @param channelHandlerContext context used to forward decrypted segments
         * @param obj the inbound message; must be a {@link ByteBuf}
         * @throws GeneralSecurityException if header initialisation or segment decryption fails
         * @throws IllegalArgumentException if {@code obj} is not a {@link ByteBuf}
         * @throws IllegalStateException if the length prefix is invalid
         */
        public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) throws GeneralSecurityException {
            Preconditions.checkArgument(obj instanceof ByteBuf, "Unrecognized message type: %s", obj.getClass().getName());
            ByteBuf byteBuf = (ByteBuf) obj;
            try {
                if (!initializeExpectedLength(byteBuf) || !initializeDecrypter(byteBuf)) {
                    // Framing is still incomplete; wait for the next read.
                    return;
                }
                while (byteBuf.readableBytes() > 0 && !this.completed) {
                    // Compare in long arithmetic so a large remaining length cannot wrap negative.
                    long remainingInMessage = this.expectedLength - this.ciphertextRead;
                    int bufferable = Math.min(byteBuf.readableBytes(), this.ciphertextBuffer.remaining());
                    int chunk = (int) Math.min((long) bufferable, remainingInMessage);
                    this.ciphertextBuffer.limit(this.ciphertextBuffer.position() + chunk);
                    byteBuf.readBytes(this.ciphertextBuffer);
                    this.ciphertextRead += chunk;
                    if (this.ciphertextRead == this.expectedLength) {
                        this.completed = true;
                    } else if (this.ciphertextRead > this.expectedLength) {
                        throw new IllegalStateException("Read more ciphertext than expected.");
                    }
                    boolean segmentFull = this.ciphertextBuffer.limit() == this.ciphertextBuffer.capacity();
                    if (segmentFull || this.completed) {
                        // Decrypt a full segment, or the final (possibly short) one.
                        ByteBuffer plaintext = ByteBuffer.allocate(this.plaintextSegmentSize);
                        this.ciphertextBuffer.flip();
                        this.decrypter.decryptSegment(this.ciphertextBuffer, this.segmentNumber, this.completed, plaintext);
                        this.segmentNumber++;
                        this.ciphertextBuffer.clear();
                        plaintext.flip();
                        channelHandlerContext.fireChannelRead(Unpooled.wrappedBuffer(plaintext));
                    } else {
                        // Segment still partial: reopen the buffer so the next read can append.
                        this.ciphertextBuffer.limit(this.ciphertextBuffer.capacity());
                    }
                }
            } finally {
                byteBuf.release();
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Outbound handler that wraps every written message in a {@link GcmEncryptedMessage}.
     *
     * <p>The plaintext and ciphertext staging buffers are allocated once per handler and the same
     * two instances are handed to every message this handler creates.
     */
    @VisibleForTesting
    public class EncryptionHandler extends ChannelOutboundHandlerAdapter {
        private final ByteBuffer plaintextBuffer;
        private final ByteBuffer ciphertextBuffer;
        private final AesGcmHkdfStreaming aesGcmHkdfStreaming;

        /**
         * Creates the streaming primitive and sizes the staging buffers from its segment sizes.
         *
         * @throws InvalidAlgorithmParameterException if the streaming primitive cannot be created
         */
        EncryptionHandler() throws InvalidAlgorithmParameterException {
            AesGcmHkdfStreaming streaming = GcmTransportCipher.this.getAesGcmHkdfStreaming();
            this.aesGcmHkdfStreaming = streaming;
            this.plaintextBuffer = ByteBuffer.allocate(streaming.getPlaintextSegmentSize());
            this.ciphertextBuffer = ByteBuffer.allocate(streaming.getCiphertextSegmentSize());
        }

        /**
         * Replaces the outgoing message with its encrypting wrapper and passes it on.
         *
         * @param channelHandlerContext context the wrapped message is written to
         * @param obj the plaintext message to wrap
         * @param channelPromise promise forwarded unchanged with the wrapped message
         * @throws Exception if the wrapper cannot be constructed
         */
        public void write(ChannelHandlerContext channelHandlerContext, Object obj, ChannelPromise channelPromise) throws Exception {
            GcmEncryptedMessage encrypted =
                    new GcmEncryptedMessage(this.aesGcmHkdfStreaming, obj, this.plaintextBuffer, this.ciphertextBuffer);
            channelHandlerContext.write(encrypted, channelPromise);
        }
    }

    /* loaded from: input_file:org/apache/spark/network/crypto/GcmTransportCipher$GcmEncryptedMessage.class */
    /**
     * File region that encrypts a wrapped {@link ByteBuf} or {@link FileRegion} segment by segment
     * while it is being transferred.
     *
     * <p>The bytes written to the target channel are an 8-byte total length, then the encrypter's
     * stream header, then the encrypted segments. The staging buffers are supplied by the caller
     * and used directly, not copied.
     */
    static class GcmEncryptedMessage extends AbstractFileRegion {
        private final Object plaintextMessage;
        private final ByteBuffer plaintextBuffer;
        private final ByteBuffer ciphertextBuffer;
        private final ByteBuffer headerByteBuffer;
        // Plaintext bytes available when the message was created.
        private final long bytesToRead;
        private final StreamSegmentEncrypter encrypter;
        // Value reported by count(): the 8-byte prefix plus the expected ciphertext size.
        private final long encryptedCount;
        // Plaintext bytes handed to the encrypter so far.
        private long bytesRead = 0;
        // Output bytes accepted by the target channel so far.
        private long transferred = 0;

        /**
         * @param aesGcmHkdfStreaming primitive used to size the output and create the encrypter
         * @param obj plaintext message; must be a {@link ByteBuf} or a {@link FileRegion}
         * @param byteBuffer staging buffer for plaintext segments
         * @param byteBuffer2 staging buffer for ciphertext segments; its limit is set to 0 here
         * @throws GeneralSecurityException if the segment encrypter cannot be created
         * @throws IllegalArgumentException if {@code obj} has an unsupported type
         */
        GcmEncryptedMessage(AesGcmHkdfStreaming aesGcmHkdfStreaming, Object obj, ByteBuffer byteBuffer, ByteBuffer byteBuffer2) throws GeneralSecurityException {
            boolean supported = (obj instanceof ByteBuf) || (obj instanceof FileRegion);
            Preconditions.checkArgument(supported, "Unrecognized message type: %s", obj.getClass().getName());
            this.plaintextMessage = obj;
            this.plaintextBuffer = byteBuffer;
            this.ciphertextBuffer = byteBuffer2;
            // Mark the ciphertext buffer as holding nothing pending.
            this.ciphertextBuffer.limit(0);
            this.bytesToRead = getReadableBytes();
            this.encryptedCount = 8 + aesGcmHkdfStreaming.expectedCiphertextSize(this.bytesToRead);
            // The total length is also bound into the stream as associated data.
            byte[] lengthAad = Longs.toByteArray(this.encryptedCount);
            this.encrypter = aesGcmHkdfStreaming.newStreamSegmentEncrypter(lengthAad);
            this.headerByteBuffer = createHeaderByteBuffer();
        }

        /** Builds the ready-to-read prefix: the 8-byte total length followed by the stream header. */
        private ByteBuffer createHeaderByteBuffer() {
            ByteBuffer streamHeader = this.encrypter.getHeader();
            ByteBuffer framed = ByteBuffer.allocate(streamHeader.remaining() + 8);
            framed.putLong(this.encryptedCount);
            framed.put(streamHeader);
            framed.flip();
            return framed;
        }

        /** Always 0. */
        public long position() {
            return 0L;
        }

        /** Returns the number of output bytes written to the target channel so far. */
        public long transferred() {
            return this.transferred;
        }

        /** Returns the total number of output bytes this message announces. */
        public long count() {
            return this.encryptedCount;
        }

        /** Records the hint on this region and on the wrapped plaintext message. */
        @Override // org.apache.spark.network.util.AbstractFileRegion
        /* renamed from: touch */
        public GcmEncryptedMessage mo7touch(Object obj) {
            super.mo7touch(obj);
            // The constructor only admits ByteBuf or FileRegion, both of which are ReferenceCounted.
            if (this.plaintextMessage instanceof ReferenceCounted) {
                ((ReferenceCounted) this.plaintextMessage).touch(obj);
            }
            return this;
        }

        /** Retains this region and the wrapped plaintext message by {@code i}. */
        @Override // org.apache.spark.network.util.AbstractFileRegion
        /* renamed from: retain */
        public GcmEncryptedMessage mo8retain(int i) {
            super.mo8retain(i);
            if (this.plaintextMessage instanceof ReferenceCounted) {
                ((ReferenceCounted) this.plaintextMessage).retain(i);
            }
            return this;
        }

        /**
         * Releases the wrapped plaintext message by {@code i}, then this region by {@code i}.
         *
         * <p>NOTE(review): {@link #deallocate()} also releases the wrapped message once; if
         * {@code super.release(i)} invokes it, the wrapped message sees {@code i + 1} releases from
         * this call. Confirm this matches the intended ownership.
         *
         * @return the result of releasing this region
         */
        public boolean release(int i) {
            if (this.plaintextMessage instanceof ReferenceCounted) {
                ((ReferenceCounted) this.plaintextMessage).release(i);
            }
            return super.release(i);
        }

        /** Writes {@code pending} to {@code sink}, adds the accepted bytes to the running total and returns them. */
        private int writeAndCount(WritableByteChannel sink, ByteBuffer pending) throws IOException {
            int written = sink.write(pending);
            this.transferred += written;
            return written;
        }

        /**
         * Writes as much encrypted output as the channel accepts.
         *
         * <p>Pending header bytes go first, then any ciphertext left over from an earlier call, then
         * newly encrypted segments. The method returns early whenever the channel accepts only part
         * of a buffer, or when a wrapped {@link FileRegion} supplies fewer bytes than requested.
         *
         * @param writableByteChannel destination for the encrypted bytes
         * @param j ignored by this implementation
         * @return number of bytes written to the channel during this call
         * @throws IOException if the channel or the wrapped region fails
         * @throws IllegalStateException if the encrypter fails
         */
        public long transferTo(WritableByteChannel writableByteChannel, long j) throws IOException {
            int sentThisCall = 0;
            if (this.headerByteBuffer.hasRemaining()) {
                sentThisCall += writeAndCount(writableByteChannel, this.headerByteBuffer);
                if (this.headerByteBuffer.hasRemaining()) {
                    return sentThisCall;
                }
            }
            if (this.ciphertextBuffer.hasRemaining()) {
                sentThisCall += writeAndCount(writableByteChannel, this.ciphertextBuffer);
                if (this.ciphertextBuffer.hasRemaining()) {
                    return sentThisCall;
                }
            }
            while (this.bytesRead < this.bytesToRead) {
                int readLimit = (int) Math.min(getReadableBytes(), this.plaintextBuffer.remaining());
                if (this.plaintextMessage instanceof ByteBuf) {
                    ByteBuf source = (ByteBuf) this.plaintextMessage;
                    Preconditions.checkState(0 == this.plaintextBuffer.position());
                    this.plaintextBuffer.limit(readLimit);
                    source.readBytes(this.plaintextBuffer);
                    Preconditions.checkState(readLimit == this.plaintextBuffer.position());
                } else if (this.plaintextMessage instanceof FileRegion) {
                    FileRegion region = (FileRegion) this.plaintextMessage;
                    ByteBufferWriteableChannel staging = new ByteBufferWriteableChannel(this.plaintextBuffer);
                    long copied = region.transferTo(staging, region.transferred());
                    if (copied < readLimit) {
                        // Not enough plaintext yet; keep what was staged and retry on the next call.
                        return sentThisCall;
                    }
                }
                boolean lastSegment = getReadableBytes() == 0;
                this.plaintextBuffer.flip();
                this.bytesRead += this.plaintextBuffer.remaining();
                this.ciphertextBuffer.clear();
                try {
                    this.encrypter.encryptSegment(this.plaintextBuffer, lastSegment, this.ciphertextBuffer);
                } catch (GeneralSecurityException e) {
                    throw new IllegalStateException("GeneralSecurityException from encrypter", e);
                }
                this.plaintextBuffer.clear();
                this.ciphertextBuffer.flip();
                sentThisCall += writeAndCount(writableByteChannel, this.ciphertextBuffer);
                if (this.ciphertextBuffer.hasRemaining()) {
                    // Channel is full; the remainder is flushed at the start of the next call.
                    return sentThisCall;
                }
            }
            return sentThisCall;
        }

        /**
         * Returns how many plaintext bytes the wrapped message can still supply.
         *
         * @throws IllegalArgumentException if the wrapped message has an unsupported type
         */
        private long getReadableBytes() {
            if (this.plaintextMessage instanceof ByteBuf) {
                return ((ByteBuf) this.plaintextMessage).readableBytes();
            }
            if (this.plaintextMessage instanceof FileRegion) {
                FileRegion region = (FileRegion) this.plaintextMessage;
                return region.count() - region.transferred();
            }
            throw new IllegalArgumentException("Unsupported message type: " + this.plaintextMessage.getClass().getName());
        }

        /** Releases the wrapped message once and resets both staging buffers. */
        protected void deallocate() {
            if (this.plaintextMessage instanceof ReferenceCounted) {
                ReferenceCounted wrapped = (ReferenceCounted) this.plaintextMessage;
                wrapped.release();
            }
            this.plaintextBuffer.clear();
            this.ciphertextBuffer.clear();
        }
    }

    /**
     * Creates a cipher that uses the given AES key.
     *
     * @param secretKeySpec key stored as-is and used for every handler this cipher creates
     */
    public GcmTransportCipher(SecretKeySpec secretKeySpec) {
        aesKey = secretKeySpec;
    }

    /**
     * Builds a new streaming primitive from this cipher's key.
     *
     * <p>The encoded key is passed as the key material and its length as the key size, together
     * with {@code HKDF_ALG}, {@code CIPHERTEXT_BUFFER_SIZE} and a trailing argument of 0.
     *
     * @throws InvalidAlgorithmParameterException if the primitive rejects these parameters
     */
    AesGcmHkdfStreaming getAesGcmHkdfStreaming() throws InvalidAlgorithmParameterException {
        byte[] keyMaterial = this.aesKey.getEncoded();
        int keySizeInBytes = keyMaterial.length;
        return new AesGcmHkdfStreaming(keyMaterial, HKDF_ALG, keySizeInBytes, CIPHERTEXT_BUFFER_SIZE, 0);
    }

    /**
     * Returns the identifier that {@code TransportCipherUtil.getKeyId} computes for this cipher's key.
     *
     * @throws GeneralSecurityException if the identifier cannot be computed
     */
    @Override // org.apache.spark.network.crypto.TransportCipher
    @VisibleForTesting
    public String getKeyId() throws GeneralSecurityException {
        final String keyId = TransportCipherUtil.getKeyId(this.aesKey);
        return keyId;
    }

    /**
     * Creates a new outbound encryption handler bound to this cipher.
     *
     * @throws GeneralSecurityException if the handler cannot be constructed
     */
    @VisibleForTesting
    EncryptionHandler getEncryptionHandler() throws GeneralSecurityException {
        final EncryptionHandler handler = new EncryptionHandler();
        return handler;
    }

    /**
     * Creates a new inbound decryption handler bound to this cipher.
     *
     * @throws GeneralSecurityException if the handler cannot be constructed
     */
    @VisibleForTesting
    DecryptionHandler getDecryptionHandler() throws GeneralSecurityException {
        final DecryptionHandler handler = new DecryptionHandler();
        return handler;
    }

    /**
     * Installs the encryption and decryption handlers at the head of the channel's pipeline.
     *
     * <p>Both handlers are constructed before the pipeline is touched, so a failure while building
     * either one leaves the pipeline unchanged instead of installing only the encryption handler.
     * The decryption handler is added last and therefore ends up first in the pipeline.
     *
     * @param channel channel whose pipeline receives the two handlers
     * @throws GeneralSecurityException if either handler cannot be constructed
     */
    @Override // org.apache.spark.network.crypto.TransportCipher
    public void addToChannel(Channel channel) throws GeneralSecurityException {
        EncryptionHandler encryptionHandler = getEncryptionHandler();
        DecryptionHandler decryptionHandler = getDecryptionHandler();
        channel.pipeline()
                .addFirst("GcmTransportEncryption", encryptionHandler)
                .addFirst("GcmTransportDecryption", decryptionHandler);
    }
}
