/*
 * Decompiled with CFR 0.152.
 */
package com.impossibl.postgres.protocol.sasl.scram.client;

import com.impossibl.postgres.jdbc.xa.Base64;
import com.impossibl.postgres.protocol.sasl.scram.ScramFunctions;
import com.impossibl.postgres.protocol.sasl.scram.ScramMechanism;
import com.impossibl.postgres.protocol.sasl.scram.exception.ScramException;
import com.impossibl.postgres.protocol.sasl.scram.exception.ScramInvalidServerSignatureException;
import com.impossibl.postgres.protocol.sasl.scram.exception.ScramParseException;
import com.impossibl.postgres.protocol.sasl.scram.exception.ScramServerErrorException;
import com.impossibl.postgres.protocol.sasl.scram.gssapi.Gs2CbindFlag;
import com.impossibl.postgres.protocol.sasl.scram.message.ClientFinalMessage;
import com.impossibl.postgres.protocol.sasl.scram.message.ClientFirstMessage;
import com.impossibl.postgres.protocol.sasl.scram.message.ServerFinalMessage;
import com.impossibl.postgres.protocol.sasl.scram.message.ServerFirstMessage;
import com.impossibl.postgres.protocol.sasl.scram.stringprep.StringPreparation;
import com.impossibl.postgres.protocol.sasl.scram.util.Preconditions;
import java.nio.charset.StandardCharsets;

/**
 * Client-side state for a single SCRAM authentication exchange.
 *
 * <p>The public methods are meant to be called in protocol order:
 * {@link #clientFirstMessage(String)}, then one of the {@code receiveServerFirstMessage}
 * overloads, then {@link #receiveServerFinalMessage(String)}. Skipping a step results in an
 * {@link IllegalStateException}. Instances are stateful and belong to one exchange.
 */
public class ScramSession {
    private final ScramMechanism scramMechanism;
    // Channel-binding method offered by the client; null when the client offers none.
    private final String channelBindMethod;
    private final boolean serverSupportsChannelBinding;
    private final StringPreparation stringPreparation;
    private final String user;
    // Client nonce; the server-first message is parsed against it.
    private final String nonce;
    // Set by clientFirstMessage(); required to compute the AuthMessage.
    private ClientFirstMessage clientFirstMessage;
    // Set by receiveServerFirstMessage(); required by receiveServerFinalMessage().
    private ClientFinalProcessor clientFinalProcessor;
    // Raw server-first message text, kept verbatim because it is part of the AuthMessage.
    private String serverFirstMessageString;

    /**
     * Package-private constructor.
     *
     * @param scramMechanism mechanism to use; must not be {@code null}
     * @param channelBindMethod channel-binding method name, or {@code null} if the client offers none
     * @param serverSupportsChannelBinding whether the server advertised channel-binding support
     * @param stringPreparation preparation applied to the password when salting; must not be {@code null}
     * @param user user name sent in the client-first message; must not be {@code null}
     * @param nonce client nonce; must not be empty
     */
    ScramSession(ScramMechanism scramMechanism, String channelBindMethod, boolean serverSupportsChannelBinding,
                 StringPreparation stringPreparation, String user, String nonce) {
        this.scramMechanism = Preconditions.checkNotNull(scramMechanism, "scramMechanism");
        this.channelBindMethod = channelBindMethod;
        this.serverSupportsChannelBinding = serverSupportsChannelBinding;
        this.stringPreparation = Preconditions.checkNotNull(stringPreparation, "stringPreparation");
        this.user = Preconditions.checkNotNull(user, "user");
        this.nonce = Preconditions.checkNotEmpty(nonce, "nonce");
    }

    /**
     * Returns the name of the SCRAM mechanism used by this session.
     *
     * @return the mechanism name
     */
    public String getScramMechanismName() {
        return this.scramMechanism.getName();
    }

    /**
     * Builds the client-first message and remembers it for the later AuthMessage computation.
     *
     * @param authzid authorization identity, passed unchanged to {@link ClientFirstMessage}
     * @return the client-first message encoded as UTF-8
     */
    public byte[] clientFirstMessage(String authzid) {
        Gs2CbindFlag gs2CbindFlag = this.selectGs2CbindFlag();
        // The channel-binding method name is only announced when binding is actually in use.
        String cbindName = gs2CbindFlag == Gs2CbindFlag.ENABLED ? this.channelBindMethod : null;
        this.clientFirstMessage = new ClientFirstMessage(gs2CbindFlag, authzid, cbindName, this.user, this.nonce);
        return this.clientFirstMessage.toString().getBytes(StandardCharsets.UTF_8);
    }

    /**
     * Chooses the GS2 channel-binding flag for the client-first message.
     *
     * <ul>
     *   <li>no channel-binding method offered: {@code DISABLED}</li>
     *   <li>mechanism requires channel binding: {@code ENABLED}</li>
     *   <li>server supports channel binding: {@code DISABLED}</li>
     *   <li>otherwise: {@code NO_SERVER_SUPPORT}</li>
     * </ul>
     */
    private Gs2CbindFlag selectGs2CbindFlag() {
        if (this.channelBindMethod == null) {
            // NOTE(review): a mechanism that requires channel binding still gets DISABLED here
            // when no bind method was supplied; confirm callers never construct that combination.
            return Gs2CbindFlag.DISABLED;
        }
        if (this.scramMechanism.requiresChannelBinding()) {
            return Gs2CbindFlag.ENABLED;
        }
        if (this.serverSupportsChannelBinding) {
            return Gs2CbindFlag.DISABLED;
        }
        return Gs2CbindFlag.NO_SERVER_SUPPORT;
    }

    /**
     * Reports whether the selected mechanism requires channel-binding data.
     *
     * @return {@code true} if {@code receiveServerFirstMessage} must be given non-null channel-bind data
     */
    public boolean requiresChannelBindData() {
        return this.scramMechanism.requiresChannelBinding();
    }

    /**
     * Returns the channel-binding method this session was created with.
     *
     * @return the method name, or {@code null} if none was configured
     */
    public String getChannelBindMethod() {
        return this.channelBindMethod;
    }

    /**
     * Processes the server-first message and builds the client-final message from a password.
     *
     * @param serverFirstMessage the server-first message text; must not be empty
     * @param channelBindData channel-binding data; required when {@link #requiresChannelBindData()} is true
     * @param password the user's password; must not be empty
     * @return the client-final message encoded as UTF-8
     * @throws ScramException if required channel-bind data is missing or the server message is invalid
     * @throws IllegalStateException if {@link #clientFirstMessage(String)} has not been called
     */
    public byte[] receiveServerFirstMessage(String serverFirstMessage, byte[] channelBindData, String password) throws ScramException {
        this.checkReadyForServerFirstMessage(channelBindData);
        ClientFinalProcessor processor = new ServerFirstProcessor(
                Preconditions.checkNotEmpty(serverFirstMessage, "serverFirstMessage")).clientFinalProcessor(password);
        return this.completeClientFinalMessage(processor, channelBindData);
    }

    /**
     * Processes the server-first message and builds the client-final message from precomputed keys.
     *
     * <p>The StoredKey is derived from {@code clientKey}. The second key is used as the ServerKey
     * when the server's final signature is verified.
     *
     * @param serverFirstMessage the server-first message text; must not be empty
     * @param channelBindData channel-binding data; required when {@link #requiresChannelBindData()} is true
     * @param clientKey the SCRAM ClientKey; must not be {@code null}
     * @param serverKey the SCRAM ServerKey; must not be {@code null}
     * @return the client-final message encoded as UTF-8
     * @throws ScramException if required channel-bind data is missing or the server message is invalid
     * @throws IllegalStateException if {@link #clientFirstMessage(String)} has not been called
     */
    public byte[] receiveServerFirstMessage(String serverFirstMessage, byte[] channelBindData, byte[] clientKey, byte[] serverKey) throws ScramException {
        this.checkReadyForServerFirstMessage(channelBindData);
        ClientFinalProcessor processor = new ServerFirstProcessor(
                Preconditions.checkNotEmpty(serverFirstMessage, "serverFirstMessage")).clientFinalProcessor(clientKey, serverKey);
        return this.completeClientFinalMessage(processor, channelBindData);
    }

    /**
     * Verifies the server-final message.
     *
     * @param serverFinalMessage the server-final message text; must not be empty
     * @throws ScramException if the server reported an error, the message cannot be parsed,
     *     or the server signature does not verify
     * @throws IllegalStateException if no server-first message has been processed successfully
     */
    public void receiveServerFinalMessage(String serverFinalMessage) throws ScramException {
        if (this.clientFinalProcessor == null) {
            throw new IllegalStateException("No ClientFinalProcessor selected. Ensure receiveServerFirstMessage has been called.");
        }
        this.clientFinalProcessor.receiveServerFinalMessage(serverFinalMessage);
    }

    /** Validates the preconditions shared by both {@code receiveServerFirstMessage} overloads. */
    private void checkReadyForServerFirstMessage(byte[] channelBindData) throws ScramException {
        if (this.requiresChannelBindData() && channelBindData == null) {
            throw new ScramException("Missing required channel-bind data");
        }
        // The AuthMessage embeds the client-first message; without it we would fail later with an NPE.
        if (this.clientFirstMessage == null) {
            throw new IllegalStateException("No ClientFirstMessage generated. Ensure clientFirstMessage has been called.");
        }
    }

    /** Builds the client-final message and only then publishes the processor as session state. */
    private byte[] completeClientFinalMessage(ClientFinalProcessor processor, byte[] channelBindData) {
        String clientFinal = processor.clientFinalMessage(channelBindData);
        // Publishing after success means a failure cannot leave a processor without an AuthMessage.
        this.clientFinalProcessor = processor;
        return clientFinal.getBytes(StandardCharsets.UTF_8);
    }

    /** Parses the server-first message and creates the matching client-final processor. */
    private class ServerFirstProcessor {
        private final ServerFirstMessage serverFirstMessage;

        private ServerFirstProcessor(String receivedServerFirstMessage) throws ScramParseException {
            this.serverFirstMessage = ServerFirstMessage.parseFrom(receivedServerFirstMessage, ScramSession.this.nonce);
            // Record the raw text only once it parsed; it is needed verbatim for the AuthMessage.
            ScramSession.this.serverFirstMessageString = receivedServerFirstMessage;
        }

        public String getSalt() {
            return this.serverFirstMessage.getSalt();
        }

        public int getIteration() {
            return this.serverFirstMessage.getIteration();
        }

        /** Creates a processor whose keys are derived from {@code password} and the server's salt/iterations. */
        public ClientFinalProcessor clientFinalProcessor(String password) throws IllegalArgumentException {
            return new ClientFinalProcessor(this.serverFirstMessage.getNonce(),
                    Preconditions.checkNotEmpty(password, "password"), this.getSalt(), this.getIteration());
        }

        /** Creates a processor from precomputed ClientKey and ServerKey. */
        public ClientFinalProcessor clientFinalProcessor(byte[] clientKey, byte[] serverKey) throws IllegalArgumentException {
            return new ClientFinalProcessor(this.serverFirstMessage.getNonce(),
                    Preconditions.checkNotNull(clientKey, "clientKey"), Preconditions.checkNotNull(serverKey, "serverKey"));
        }
    }

    /** Produces the client-final message and verifies the server-final message. */
    private class ClientFinalProcessor {
        // Combined nonce taken from the server-first message.
        private final String nonce;
        private final byte[] clientKey;
        private final byte[] storedKey;
        private final byte[] serverKey;
        // Computed lazily on the first clientFinalMessage() call and reused afterwards.
        private String authMessage;

        private ClientFinalProcessor(String nonce, byte[] clientKey, byte[] storedKey, byte[] serverKey) {
            assert (null != clientKey) : "clientKey";
            assert (null != storedKey) : "storedKey";
            assert (null != serverKey) : "serverKey";
            this.nonce = nonce;
            this.clientKey = clientKey;
            this.storedKey = storedKey;
            this.serverKey = serverKey;
        }

        private ClientFinalProcessor(String nonce, byte[] clientKey, byte[] serverKey) {
            // The StoredKey is derived from the ClientKey, so it is not passed in separately.
            this(nonce, clientKey, ScramFunctions.storedKey(ScramSession.this.scramMechanism, clientKey), serverKey);
        }

        private ClientFinalProcessor(String nonce, byte[] saltedPassword) {
            this(nonce,
                    ScramFunctions.clientKey(ScramSession.this.scramMechanism, saltedPassword),
                    ScramFunctions.serverKey(ScramSession.this.scramMechanism, saltedPassword));
        }

        private ClientFinalProcessor(String nonce, String password, String salt, int iteration) {
            // The salt arrives Base64-encoded in the server-first message.
            this(nonce, ScramFunctions.saltedPassword(ScramSession.this.scramMechanism,
                    ScramSession.this.stringPreparation, password, Base64.decode(salt), iteration));
        }

        /**
         * Computes the AuthMessage once. It is the client-first message without its GS2 header,
         * followed by the raw server-first message and the client-final message without proof,
         * all comma-joined.
         */
        private synchronized void generateAndCacheAuthMessage(byte[] cbindData) {
            if (null != this.authMessage) {
                return;
            }
            this.authMessage = ScramSession.this.clientFirstMessage.writeToWithoutGs2Header(new StringBuffer())
                    .append(",")
                    .append(ScramSession.this.serverFirstMessageString)
                    .append(",")
                    .append(ClientFinalMessage.writeToWithoutProof(
                            ScramSession.this.clientFirstMessage.getGs2Header(), cbindData, this.nonce))
                    .toString();
        }

        /**
         * Builds the client-final message, including the client proof.
         *
         * @param cbindData channel-binding data (may be {@code null} when not required)
         * @return the client-final message text
         */
        public String clientFinalMessage(byte[] cbindData) throws IllegalArgumentException {
            if (null == this.authMessage) {
                this.generateAndCacheAuthMessage(cbindData);
            }
            ClientFinalMessage clientFinalMessage = new ClientFinalMessage(
                    ScramSession.this.clientFirstMessage.getGs2Header(),
                    cbindData,
                    this.nonce,
                    ScramFunctions.clientProof(this.clientKey,
                            ScramFunctions.clientSignature(ScramSession.this.scramMechanism, this.storedKey, this.authMessage)));
            return clientFinalMessage.toString();
        }

        /**
         * Parses the server-final message and checks the server signature against the cached AuthMessage.
         *
         * @param serverFinalMessage the server-final message text; must not be empty
         * @throws ScramServerErrorException if the server reported an error
         * @throws ScramInvalidServerSignatureException if the server signature does not verify
         */
        public void receiveServerFinalMessage(String serverFinalMessage) throws ScramParseException, ScramServerErrorException, ScramInvalidServerSignatureException, IllegalArgumentException {
            Preconditions.checkNotEmpty(serverFinalMessage, "serverFinalMessage");
            ServerFinalMessage message = ServerFinalMessage.parseFrom(serverFinalMessage);
            if (message.isError()) {
                throw new ScramServerErrorException(message.getError());
            }
            if (!ScramFunctions.verifyServerSignature(ScramSession.this.scramMechanism, this.serverKey, this.authMessage, message.getVerifier())) {
                throw new ScramInvalidServerSignatureException("Invalid server SCRAM signature");
            }
        }
    }
}

