/*
 * Decompiled with CFR 0.152.
 */
package com.fortanix.sdkms.jce.provider.ciphers;

import com.fortanix.dsm.accelerator.Algorithm;
import com.fortanix.dsm.accelerator.CipherMode;
import com.fortanix.dsm.accelerator.DSMAccelerator;
import com.fortanix.dsm.accelerator.DSMAcceleratorException;
import com.fortanix.dsm.accelerator.DecryptRequest;
import com.fortanix.dsm.accelerator.DecryptResponse;
import com.fortanix.dsm.accelerator.EncryptRequest;
import com.fortanix.dsm.accelerator.EncryptResponse;
import com.fortanix.sdkms.jce.provider.service.ISdkmsCommand;
import com.fortanix.sdkms.jce.provider.service.SDKMSLogger;
import com.fortanix.sdkms.jce.provider.valentino.DSMAcceleratorClientSetup;
import com.fortanix.sdkms.v1.ApiException;
import com.fortanix.sdkms.v1.model.CryptMode;
import com.fortanix.sdkms.v1.model.DecryptRequestEx;
import com.fortanix.sdkms.v1.model.EncryptRequestEx;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.slf4j.LoggerFactory;

/**
 * Static helpers that route symmetric encrypt/decrypt operations through the DSM accelerator
 * client obtained from {@link DSMAcceleratorClientSetup}.
 *
 * <p>For the AEAD modes recognised by {@link #isGCM(CryptMode)} (GCM and CCM) the authentication
 * tag travels appended to the ciphertext on the caller's side: {@link #encrypt} appends the tag
 * returned by the accelerator, and {@link #decrypt} splits it off before sending the request.
 *
 * <p>NOTE(review): error paths call {@code LOGGER.logAndRaiseProviderException(...)}, which by its
 * name is presumed to throw; several statements after those calls rely on that — confirm.
 */
public class DSMACipher {
    private static final SDKMSLogger LOGGER = new SDKMSLogger(LoggerFactory.getLogger(DSMACipher.class));
    // NOTE(review): private and not referenced anywhere in this class. Note that it would not be a
    // valid whitelist for CCM, which isGCM() also accepts.
    private static final List<Integer> validGcmTagLength = Arrays.asList(128, 120, 112, 104, 96, 64, 32);

    /**
     * Appends an AEAD authentication tag to a stream that already holds the ciphertext.
     *
     * <p>Does nothing when {@code mode} is not one of the modes accepted by {@link #isGCM}.
     *
     * @param cipherStream stream holding the ciphertext; the tag is written after it
     * @param mode cipher mode of the operation
     * @param gcmTag tag returned by the accelerator; must be non-null for AEAD modes
     * @param expectedTagLength expected tag length in bits
     * @throws IOException if writing to the stream fails
     */
    public static void attachGCMTag(ByteArrayOutputStream cipherStream, CryptMode mode, byte[] gcmTag, int expectedTagLength) throws IOException {
        if (!isGCM(mode)) {
            return;
        }
        if (gcmTag == null) {
            LOGGER.logAndRaiseProviderException(String.format("%s tag is missing", mode), null);
        }
        int actualTagLength = gcmTag.length * 8;
        if (actualTagLength != expectedTagLength) {
            // Both values are reported in bits so the message compares like with like.
            LOGGER.logAndRaiseProviderException(String.format("SDKMS generated tag length doesn't match, Expected: %s, Actual: %s", expectedTagLength, actualTagLength), null);
        }
        cipherStream.write(gcmTag);
    }

    /**
     * Encrypts via the DSM accelerator and returns the ciphertext, with the authentication tag
     * appended for AEAD modes.
     *
     * @param encryptRequest request holding key, plaintext, algorithm, mode, AD, IV and an optional
     *     tag length in bits
     * @return ciphertext (followed by the tag for AEAD modes)
     */
    public static byte[] encrypt(final EncryptRequestEx encryptRequest) {
        ByteArrayOutputStream cipherStream = new ByteArrayOutputStream();
        byte[] cipherBytes = null;
        final DSMAccelerator client = DSMAcceleratorClientSetup.getInstance().getDsmAcceleratorClient();
        try {
            Integer expectedTagLength = encryptRequest.getTagLen();
            EncryptResponse encryptResponse = (EncryptResponse) DSMAcceleratorClientSetup.getInstance().ensureValidSession(new ISdkmsCommand() {

                @Override
                public Object execute() throws ApiException {
                    try {
                        return client.encrypt(EncryptRequest.builder()
                                .setKid(encryptRequest.getKey().getKid())
                                .setPlain(encryptRequest.getPlain())
                                .setAlg(Algorithm.valueOf((String) encryptRequest.getAlg().getValue()))
                                .setMode(CipherMode.valueOf((String) encryptRequest.getMode().getValue()))
                                .setAd(encryptRequest.getAd())
                                .setIv(encryptRequest.getIv())
                                // 0 is sent when the caller did not specify a tag length.
                                .setTagLen(encryptRequest.getTagLen() != null ? encryptRequest.getTagLen() : 0)
                                .build());
                    }
                    catch (DSMAcceleratorException e) {
                        LOGGER.logAndRaiseProviderException("Error during Cipher encryption. Selected mode was " + encryptRequest.getMode(), e);
                        return null;
                    }
                }

                @Override
                public String getDescription() {
                    return "Encrypt";
                }
            });
            cipherStream.write(encryptResponse.getCipher());
            if (isGCM(encryptRequest.getMode())) {
                byte[] tag = encryptResponse.getTag();
                // With no requested tag length, accept the length the accelerator returned rather
                // than auto-unboxing a null Integer (which would throw NullPointerException).
                int expectedTagBits = expectedTagLength != null
                        ? expectedTagLength
                        : (tag != null ? tag.length * 8 : 0);
                attachGCMTag(cipherStream, encryptRequest.getMode(), tag, expectedTagBits);
            }
            cipherBytes = cipherStream.toByteArray();
        }
        catch (ApiException | IOException e) {
            LOGGER.logAndRaiseProviderException("Error during Cipher encryption. Selected mode was " + encryptRequest.getMode(), e);
        }
        return cipherBytes;
    }

    /**
     * Decrypts via the DSM accelerator.
     *
     * <p>For AEAD modes the last {@code tagLength / 8} bytes of the request's cipher are treated as
     * the authentication tag and split off before the request is sent.
     *
     * @param tagLength tag length in bits; required for AEAD modes, ignored otherwise
     * @param decryptRequest request holding key, ciphertext (plus tag for AEAD), algorithm, mode, IV
     * @return the plaintext from the accelerator response
     */
    public static byte[] decrypt(Integer tagLength, final DecryptRequestEx decryptRequest) {
        DecryptResponse decryptResponse = null;
        final DSMAccelerator client = DSMAcceleratorClientSetup.getInstance().getDsmAcceleratorClient();
        try {
            if (isGCM(decryptRequest.getMode())) {
                if (tagLength == null) {
                    LOGGER.logAndRaiseProviderException(String.format("%s tag length is missing", decryptRequest.getMode()), null);
                }
                byte[] cipherWithTag = decryptRequest.getCipher();
                // extractGCMTag already handles empty/null input; avoid dereferencing null here.
                int cipherWithTagLen = cipherWithTag != null ? cipherWithTag.length : 0;
                CipherAndTag cipherAndTag = extractGCMTag(tagLength, cipherWithTag, 0, cipherWithTagLen);
                // NOTE(review): this mutates the caller's request; decrypting the same request
                // object twice would strip a second "tag" off the ciphertext.
                decryptRequest.setCipher(cipherAndTag.cipher);
                decryptRequest.setTag(cipherAndTag.tag);
            }
            decryptResponse = (DecryptResponse) DSMAcceleratorClientSetup.getInstance().ensureValidSession(new ISdkmsCommand() {

                @Override
                public Object execute() throws ApiException {
                    try {
                        return client.decrypt(DecryptRequest.builder()
                                .setKid(decryptRequest.getKey().getKid())
                                .setCipher(decryptRequest.getCipher())
                                .setAlg(Algorithm.valueOf((String) decryptRequest.getAlg().getValue()))
                                .setMode(CipherMode.valueOf((String) decryptRequest.getMode().getValue()))
                                .setIv(decryptRequest.getIv())
                                .setTag(decryptRequest.getTag())
                                .build());
                    }
                    catch (DSMAcceleratorException e) {
                        LOGGER.logAndRaiseProviderException("Error during Cipher decryption. Selected mode was " + decryptRequest.getMode(), e);
                        return null;
                    }
                }

                @Override
                public String getDescription() {
                    return "Decrypt";
                }
            });
        }
        catch (ApiException e) {
            LOGGER.logAndRaiseProviderException("Error during Cipher decryption. Selected mode was " + decryptRequest.getMode(), e);
        }
        return decryptResponse.getPlain();
    }

    /**
     * Returns whether {@code mode} is an AEAD mode whose tag is appended to the ciphertext.
     *
     * <p>Despite the name, this returns {@code true} for both {@link CryptMode#GCM} and
     * {@link CryptMode#CCM}.
     */
    public static boolean isGCM(CryptMode mode) {
        return mode == CryptMode.GCM || mode == CryptMode.CCM;
    }

    /**
     * Splits a "ciphertext || tag" byte range into its two parts.
     *
     * @param tagLength tag length in bits (integer-divided by 8)
     * @param cipherTagBytes source array
     * @param cipherTagOffset offset of the range within {@code cipherTagBytes}
     * @param cipherTagLen length of the range in bytes, including the tag
     * @return copies of the cipher and tag parts; both are {@code null} when the input is null or
     *     empty
     */
    public static CipherAndTag extractGCMTag(int tagLength, byte[] cipherTagBytes, int cipherTagOffset, int cipherTagLen) {
        CipherAndTag cipherAndTag = new CipherAndTag();
        if (cipherTagBytes == null || cipherTagBytes.length == 0 || cipherTagLen == 0) {
            return cipherAndTag;
        }
        int tagLengthInByte = tagLength / 8;
        if (tagLengthInByte < 0 || tagLengthInByte > cipherTagLen) {
            // Otherwise the cipher array size below goes negative with an opaque exception.
            LOGGER.logAndRaiseProviderException(String.format("Invalid tag length %s bits for input of %s bytes", tagLength, cipherTagLen), null);
        }
        int cipherLengthInByte = cipherTagLen - tagLengthInByte;
        cipherAndTag.tag = new byte[tagLengthInByte];
        cipherAndTag.cipher = new byte[cipherLengthInByte];
        cipherAndTag.cipherOffset = 0;
        cipherAndTag.cipherLen = cipherLengthInByte;
        // The tag occupies the trailing bytes of the range; the ciphertext precedes it.
        System.arraycopy(cipherTagBytes, cipherTagOffset + cipherLengthInByte, cipherAndTag.tag, 0, tagLengthInByte);
        System.arraycopy(cipherTagBytes, cipherTagOffset, cipherAndTag.cipher, 0, cipherLengthInByte);
        return cipherAndTag;
    }

    /** Mutable holder for a ciphertext slice and its separated authentication tag. */
    public static class CipherAndTag {
        public byte[] cipher;
        public int cipherOffset;
        public int cipherLen;
        public byte[] tag;

        public CipherAndTag() {
        }

        public CipherAndTag(byte[] cipher, int cipherOffset, int cipherLen, byte[] tag) {
            this.cipher = cipher;
            this.cipherOffset = cipherOffset;
            this.cipherLen = cipherLen;
            this.tag = tag;
        }
    }
}

