/*
 * Decompiled with CFR 0.152.
 */
package org.kapott.cryptalgs;

import java.io.ByteArrayOutputStream;
import java.math.BigInteger;
import java.security.InvalidAlgorithmParameterException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.NoSuchProviderException;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.SignatureException;
import java.security.SignatureSpi;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.AlgorithmParameterSpec;
import java.util.Arrays;
import org.kapott.cryptalgs.RSAPrivateCrtKey2;
import org.kapott.cryptalgs.SignatureParamSpec;

/**
 * JCA {@link SignatureSpi} implementing RSASSA-PSS (PKCS #1 v2.x, RFC 8017 section 8.1).
 *
 * <p>The hash algorithm (and the optional JCA provider used to obtain it) is taken from the
 * {@link SignatureParamSpec} supplied via {@link #engineSetParameter(AlgorithmParameterSpec)}.
 * The same hash is used for the message digest and for MGF1, and the salt length always equals
 * the hash output length.
 *
 * <p>The message is buffered in memory until {@code engineSign}/{@code engineVerify} is called;
 * the buffer is then cleared so the instance can be reused for another message with the same
 * key. Instances keep mutable state without any synchronization.
 */
public class PKCS1_PSS
extends SignatureSpi {
    /** Cryptographically strong source for PSS salts. */
    private static final SecureRandom RANDOM = new SecureRandom();

    /** Key set by {@link #engineInitVerify(PublicKey)}. */
    private RSAPublicKey pubKey;
    /**
     * Key set by {@link #engineInitSign(PrivateKey)}; signing only handles {@link RSAPrivateKey}
     * and {@link RSAPrivateCrtKey2}.
     */
    private PrivateKey privKey;
    /** Hash/provider parameters; must be set before signing or verifying. */
    private SignatureParamSpec param;
    /** Accumulates the message bytes passed to {@code engineUpdate}. */
    private ByteArrayOutputStream plainmsg;

    /**
     * Ignored: string-keyed parameters are not supported by this engine.
     *
     * @deprecated use {@link #engineSetParameter(AlgorithmParameterSpec)} instead
     */
    @Override
    @Deprecated
    protected void engineSetParameter(String param1, Object value) {
        // intentionally a no-op
    }

    /**
     * Sets the PSS parameters (hash algorithm and optional provider).
     *
     * @param param1 must be a {@link SignatureParamSpec}
     * @throws InvalidAlgorithmParameterException if {@code param1} is not a {@link SignatureParamSpec}
     */
    @Override
    protected void engineSetParameter(AlgorithmParameterSpec param1) throws InvalidAlgorithmParameterException {
        if (!(param1 instanceof SignatureParamSpec)) {
            throw new InvalidAlgorithmParameterException("expected a SignatureParamSpec");
        }
        this.param = (SignatureParamSpec)param1;
    }

    /**
     * Not supported; always returns {@code null}.
     *
     * @deprecated string-keyed parameters are not supported
     */
    @Override
    @Deprecated
    protected Object engineGetParameter(String parameter) {
        return null;
    }

    /**
     * Creates a fresh {@link MessageDigest} for the hash algorithm named in {@code spec}, using
     * the provider from {@code spec} when one is given.
     *
     * @param spec hash algorithm and optional provider name
     * @return a new digest instance
     * @throws RuntimeException wrapping {@link NoSuchAlgorithmException} or
     *         {@link NoSuchProviderException} if the digest cannot be obtained
     */
    public static MessageDigest getMessageDigest(SignatureParamSpec spec) {
        MessageDigest result;
        try {
            String provider = spec.getProvider();
            result = provider != null ? MessageDigest.getInstance(spec.getHashAlg(), provider) : MessageDigest.getInstance(spec.getHashAlg());
        }
        catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
        catch (NoSuchProviderException e) {
            throw new RuntimeException(e);
        }
        return result;
    }

    /** Stores the signing key and starts a new, empty message buffer. */
    @Override
    protected void engineInitSign(PrivateKey privateKey) {
        this.privKey = privateKey;
        this.plainmsg = new ByteArrayOutputStream();
    }

    /**
     * Stores the verification key and starts a new, empty message buffer.
     *
     * @param publicKey must be an {@link RSAPublicKey} (cast without checking)
     */
    @Override
    protected void engineInitVerify(PublicKey publicKey) {
        this.pubKey = (RSAPublicKey)publicKey;
        this.plainmsg = new ByteArrayOutputStream();
    }

    /** Appends one byte to the buffered message. */
    @Override
    protected void engineUpdate(byte b) {
        this.plainmsg.write(b);
    }

    /** Appends {@code length} bytes of {@code b}, starting at {@code offset}, to the buffered message. */
    @Override
    protected void engineUpdate(byte[] b, int offset, int length) {
        this.plainmsg.write(b, offset, length);
    }

    /**
     * Signs the buffered message and copies the signature into {@code output}.
     *
     * @param output destination buffer
     * @param offset start position in {@code output}
     * @param len number of bytes available in {@code output} for the signature
     * @return the number of signature bytes written
     * @throws SignatureException if {@code len} is smaller than the signature, or the region
     *         {@code [offset, offset + len)} does not fit in {@code output}
     */
    @Override
    protected int engineSign(byte[] output, int offset, int len) throws SignatureException {
        byte[] sig = this.engineSign();
        // The caller only granted len bytes; never write beyond that region.
        if (sig.length > len) {
            throw new SignatureException("output buffer too small for signature");
        }
        if (offset + len > output.length) {
            throw new SignatureException("output result too large for buffer");
        }
        System.arraycopy(sig, 0, output, offset, sig.length);
        return sig.length;
    }

    /**
     * Signs the buffered message with RSASSA-PSS and clears the buffer.
     *
     * @return the signature, left-padded to the byte length of the modulus
     * @throws IllegalStateException if no {@link SignatureParamSpec} has been set
     */
    @Override
    protected byte[] engineSign() {
        byte[] msg = this.plainmsg.toByteArray();
        // SignatureSpi contract: after signing, the engine is ready for a new message with the same key.
        this.plainmsg.reset();
        return this.pss_sign(this.privKey, msg);
    }

    /**
     * Verifies {@code sig} against the buffered message and clears the buffer.
     *
     * @param sig the signature bytes to check
     * @return {@code true} if the signature is a valid RSASSA-PSS signature of the buffered message
     * @throws IllegalStateException if no {@link SignatureParamSpec} has been set
     */
    @Override
    protected boolean engineVerify(byte[] sig) {
        byte[] msg = this.plainmsg.toByteArray();
        // SignatureSpi contract: after verifying, the engine is ready for a new message with the same key.
        this.plainmsg.reset();
        return this.pss_verify(this.pubKey, msg, sig);
    }

    /** Returns the configured parameters, failing clearly if none were set. */
    private SignatureParamSpec requireParam() {
        if (this.param == null) {
            throw new IllegalStateException("PSS parameters not set; call setParameter(SignatureParamSpec) first");
        }
        return this.param;
    }

    /**
     * I2OSP: encodes a non-negative integer as a big-endian octet string of exactly
     * {@code outLen} bytes, left-padded with zeros.
     *
     * @throws RuntimeException if {@code x} does not fit in {@code outLen} bytes
     */
    private static byte[] i2os(BigInteger x, int outLen) {
        byte[] bytes = x.toByteArray();
        if (bytes.length > outLen) {
            // toByteArray may add a leading sign byte; only zero bytes may be dropped.
            for (int i = 0; i < bytes.length - outLen; ++i) {
                if (bytes[i] == 0) continue;
                throw new RuntimeException("value too large");
            }
            byte[] out = new byte[outLen];
            System.arraycopy(bytes, bytes.length - outLen, out, 0, outLen);
            bytes = out;
        } else if (bytes.length < outLen) {
            byte[] out = new byte[outLen];
            System.arraycopy(bytes, 0, out, outLen - bytes.length, bytes.length);
            bytes = out;
        }
        return bytes;
    }

    /** OS2IP: interprets {@code bytes} as an unsigned big-endian integer. */
    private static BigInteger os2i(byte[] bytes) {
        return new BigInteger(1, bytes);
    }

    /**
     * RSASP1: applies the private-key operation to {@code m}.
     *
     * <p>{@link RSAPrivateKey} uses {@code m^d mod n}; otherwise the key is cast to
     * {@link RSAPrivateCrtKey2} and the CRT form with p, q, dP, dQ, qInv is used.
     */
    private static BigInteger sp1(PrivateKey key, BigInteger m) {
        BigInteger result;
        if (key instanceof RSAPrivateKey) {
            BigInteger d = ((RSAPrivateKey)key).getPrivateExponent();
            BigInteger n = ((RSAPrivateKey)key).getModulus();
            result = m.modPow(d, n);
        } else {
            RSAPrivateCrtKey2 key2 = (RSAPrivateCrtKey2)key;
            BigInteger p = key2.getP();
            BigInteger q = key2.getQ();
            BigInteger dP = key2.getdP();
            BigInteger dQ = key2.getdQ();
            BigInteger qInv = key2.getQInv();
            BigInteger s1 = m.modPow(dP, p);
            BigInteger s2 = m.modPow(dQ, q);
            // Garner recombination; BigInteger.mod always yields a non-negative result.
            BigInteger h = s1.subtract(s2).multiply(qInv).mod(p);
            result = s2.add(q.multiply(h));
        }
        return result;
    }

    /** RSAVP1: applies the public-key operation {@code s^e mod n}. */
    private static BigInteger vp1(RSAPublicKey key, BigInteger s) {
        return s.modPow(key.getPublicExponent(), key.getModulus());
    }

    /** Returns a new array holding {@code x1} followed by {@code x2}. */
    private static byte[] concat(byte[] x1, byte[] x2) {
        byte[] result = new byte[x1.length + x2.length];
        System.arraycopy(x1, 0, result, 0, x1.length);
        System.arraycopy(x2, 0, result, x1.length, x2.length);
        return result;
    }

    /** Hashes {@code data} with a fresh digest for {@code spec}. */
    private static byte[] hash(SignatureParamSpec spec, byte[] data) {
        return PKCS1_PSS.getMessageDigest(spec).digest(data);
    }

    /**
     * MGF1 mask generation (RFC 8017 appendix B.2.1) using the hash from {@code spec}.
     *
     * @param mgfSeed seed octets
     * @param maskLen requested mask length in bytes
     * @return a mask of exactly {@code maskLen} bytes
     */
    private static byte[] mgf1(SignatureParamSpec spec, byte[] mgfSeed, int maskLen) {
        MessageDigest dig = PKCS1_PSS.getMessageDigest(spec);
        ByteArrayOutputStream t = new ByteArrayOutputStream();
        // Loop on produced length rather than getDigestLength(), which may legally return 0.
        for (int counter = 0; t.size() < maskLen; ++counter) {
            dig.update(mgfSeed);
            dig.update(PKCS1_PSS.i2os(BigInteger.valueOf(counter), 4));
            byte[] block = dig.digest();
            t.write(block, 0, block.length);
        }
        return Arrays.copyOf(t.toByteArray(), maskLen);
    }

    /** Returns {@code len} bytes from a cryptographically strong RNG (used for PSS salts). */
    private static byte[] random_os(int len) {
        byte[] result = new byte[len];
        RANDOM.nextBytes(result);
        return result;
    }

    /** Byte-wise XOR of two equal-length arrays. */
    private static byte[] xor_os(byte[] a1, byte[] a2) {
        if (a1.length != a2.length) {
            throw new RuntimeException("a1.len != a2.len");
        }
        byte[] result = new byte[a1.length];
        for (int i = 0; i < result.length; ++i) {
            result[i] = (byte)(a1[i] ^ a2[i]);
        }
        return result;
    }

    /**
     * EMSA-PSS-ENCODE (RFC 8017 section 9.1.1) with a random salt of hash length.
     *
     * @param spec hash algorithm used for the message hash and MGF1
     * @param msg message to encode
     * @param emBits maximal bit length of the encoded message (modulus bits - 1)
     * @return the encoded message {@code maskedDB || H || 0xBC} of {@code ceil(emBits/8)} bytes
     * @throws IllegalArgumentException if {@code emBits} is too small for the hash and salt
     */
    public static byte[] emsa_pss_encode(SignatureParamSpec spec, byte[] msg, int emBits) {
        int emLen = (emBits + 7) >> 3;
        byte[] mHash = PKCS1_PSS.hash(spec, msg);
        int hLen = mHash.length;
        int sLen = hLen;
        if (emLen < hLen + sLen + 2) {
            throw new IllegalArgumentException("encoding error: emBits " + emBits + " too small for hash length " + hLen);
        }
        byte[] salt = PKCS1_PSS.random_os(sLen);
        // M' = 0x00 * 8 || mHash || salt
        byte[] mPrime = PKCS1_PSS.concat(PKCS1_PSS.concat(new byte[8], mHash), salt);
        byte[] H = PKCS1_PSS.hash(spec, mPrime);
        // DB = PS (zeros) || 0x01 || salt
        byte[] PS = new byte[emLen - sLen - hLen - 2];
        byte[] DB = PKCS1_PSS.concat(PKCS1_PSS.concat(PS, new byte[]{1}), salt);
        byte[] dbMask = PKCS1_PSS.mgf1(spec, H, emLen - hLen - 1);
        byte[] maskedDB = PKCS1_PSS.xor_os(DB, dbMask);
        // Clear the 8*emLen - emBits leftmost bits so the encoded integer stays below 2^emBits.
        int mask = 0xFF >>> ((emLen << 3) - emBits);
        maskedDB[0] = (byte)(maskedDB[0] & mask);
        return PKCS1_PSS.concat(PKCS1_PSS.concat(maskedDB, H), new byte[]{(byte)0xBC});
    }

    /**
     * EMSA-PSS-VERIFY (RFC 8017 section 9.1.2) with salt length equal to the hash length.
     *
     * @param spec hash algorithm used for the message hash and MGF1
     * @param msg message to check
     * @param EM encoded message; must be exactly {@code ceil(emBits/8)} bytes
     * @param emBits maximal bit length of the encoded message (modulus bits - 1)
     * @return {@code true} if {@code EM} is a consistent encoding of {@code msg}
     */
    public static boolean emsa_pss_verify(SignatureParamSpec spec, byte[] msg, byte[] EM, int emBits) {
        int emLen = (emBits + 7) >> 3;
        byte[] mHash = PKCS1_PSS.hash(spec, msg);
        int hLen = mHash.length;
        int sLen = hLen;
        if (EM == null || EM.length != emLen || emLen < hLen + sLen + 2) {
            return false;
        }
        if (EM[emLen - 1] != (byte)0xBC) {
            return false;
        }
        int dbLen = emLen - hLen - 1;
        byte[] maskedDB = Arrays.copyOfRange(EM, 0, dbLen);
        byte[] H = Arrays.copyOfRange(EM, dbLen, dbLen + hLen);
        int mask = 0xFF >>> ((emLen << 3) - emBits);
        // The bits above emBits must be zero in a well-formed encoding.
        if ((maskedDB[0] & 0xFF & ~mask) != 0) {
            return false;
        }
        byte[] dbMask = PKCS1_PSS.mgf1(spec, H, dbLen);
        byte[] DB = PKCS1_PSS.xor_os(maskedDB, dbMask);
        DB[0] = (byte)(DB[0] & mask);
        // DB must be PS (all zero) || 0x01 || salt.
        int psLen = dbLen - sLen - 1;
        for (int i = 0; i < psLen; ++i) {
            if (DB[i] != 0) {
                return false;
            }
        }
        if (DB[psLen] != 0x01) {
            return false;
        }
        byte[] salt = Arrays.copyOfRange(DB, dbLen - sLen, dbLen);
        byte[] mPrime = PKCS1_PSS.concat(PKCS1_PSS.concat(new byte[8], mHash), salt);
        byte[] H2 = PKCS1_PSS.hash(spec, mPrime);
        return MessageDigest.isEqual(H, H2);
    }

    /** Returns emBits for a modulus: its bit length minus one. */
    public static int calculateEMBitLen(BigInteger modulus) {
        return modulus.bitLength() - 1;
    }

    /**
     * RSASSA-PSS-SIGN (RFC 8017 section 8.1.1).
     *
     * @return the signature as exactly k bytes, k being the byte length of the modulus
     */
    private byte[] pss_sign(PrivateKey key, byte[] msg) {
        BigInteger bModulus = key instanceof RSAPrivateKey
                ? ((RSAPrivateKey)key).getModulus()
                : ((RSAPrivateCrtKey2)key).getP().multiply(((RSAPrivateCrtKey2)key).getQ());
        int k = (bModulus.bitLength() + 7) >> 3;
        byte[] EM = PKCS1_PSS.emsa_pss_encode(this.requireParam(), msg, PKCS1_PSS.calculateEMBitLen(bModulus));
        BigInteger s = PKCS1_PSS.sp1(key, PKCS1_PSS.os2i(EM));
        return PKCS1_PSS.i2os(s, k);
    }

    /**
     * RSASSA-PSS-VERIFY (RFC 8017 section 8.1.2). Malformed or out-of-range signatures
     * yield {@code false} rather than an exception.
     */
    private boolean pss_verify(RSAPublicKey key, byte[] msg, byte[] S) {
        SignatureParamSpec spec = this.requireParam();
        if (S == null) {
            return false;
        }
        BigInteger n = key.getModulus();
        BigInteger s = PKCS1_PSS.os2i(S);
        // RSAVP1: the signature representative must lie in [0, n-1] (prevents s + n malleability).
        if (s.compareTo(n) >= 0) {
            return false;
        }
        BigInteger m = PKCS1_PSS.vp1(key, s);
        int emBits = PKCS1_PSS.calculateEMBitLen(n);
        int emLen = (emBits + 7) >> 3;
        // I2OSP would fail ("integer too large"): the signature is invalid.
        if (m.bitLength() > (emLen << 3)) {
            return false;
        }
        byte[] EM = PKCS1_PSS.i2os(m, emLen);
        return PKCS1_PSS.emsa_pss_verify(spec, msg, EM, emBits);
    }
}

