package org.pac4j.saml.credentials;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.Conditions;
import org.opensaml.saml.saml2.core.NameIDType;
import org.pac4j.core.credentials.Credentials;
import org.pac4j.core.profile.converter.AttributeConverter;

import java.io.Serial;
import java.io.Serializable;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

/**
 * Credentials containing the nameId of the SAML subject and all of its attributes.
 *
 * <p>Accessors, {@code equals}/{@code hashCode} and {@code toString} are generated by Lombok.
 * Equality is computed from the fields declared in this class only
 * ({@code callSuper = false}).</p>
 *
 * @author Michael Remond
 * @since 1.5.0
 */
@Slf4j
@Getter
@Setter
@EqualsAndHashCode(callSuper = false)
@ToString
public class SAML2AuthenticationCredentials extends Credentials {

    @Serial
    private static final long serialVersionUID = 5040516205957826527L;

    /** Name identifier of the SAML subject. */
    private SAMLNameID nameId;

    /** Session index value supplied to the constructor or setter. */
    private String sessionIndex;

    /** Attributes of the SAML subject. */
    private List<SAMLAttribute> attributes;

    /** Validity window copied from the SAML conditions; {@code null} when none were supplied. */
    private SAMLConditions conditions;

    /** Identifier of the issuer. */
    private String issuerId;

    /** Authentication contexts. */
    private List<String> authnContexts;

    /** Authentication context authorities. */
    private List<String> authnContextAuthorities;

    /** The {@code inResponseTo} value supplied to the constructor or setter. */
    private String inResponseTo;

    /**
     * Creates empty credentials; every field is left unset and can be populated
     * through the generated setters.
     */
    public SAML2AuthenticationCredentials() {}

    /**
     * <p>Constructor for SAML2AuthenticationCredentials.</p>
     *
     * <p>The supplied lists are stored by reference (no defensive copy). When
     * {@code conditions} is non-null, its {@code notBefore} and {@code notOnOrAfter}
     * instants are converted to UTC {@link ZonedDateTime} values; bounds that are
     * {@code null} are left unset.</p>
     *
     * @param nameId a {@link SAML2AuthenticationCredentials.SAMLNameID} object
     * @param issuerId a {@link String} object
     * @param samlAttributes a {@link List} object
     * @param conditions a {@link Conditions} object, may be {@code null}
     * @param sessionIndex a {@link String} object
     * @param authnContexts a {@link List} object
     * @param authnContextAuthorities a {@link List} object
     * @param inResponseTo a {@link String} object
     */
    public SAML2AuthenticationCredentials(final SAMLNameID nameId, final String issuerId,
                                          final List<SAMLAttribute> samlAttributes, final Conditions conditions,
                                          final String sessionIndex, final List<String> authnContexts,
                                          final List<String> authnContextAuthorities,
                                          final String inResponseTo) {
        this.nameId = nameId;
        this.issuerId = issuerId;
        this.sessionIndex = sessionIndex;
        this.attributes = samlAttributes;
        this.inResponseTo = inResponseTo;
        this.authnContexts = authnContexts;
        this.authnContextAuthorities = authnContextAuthorities;

        // The conditions field keeps its default null value when no conditions are supplied.
        if (conditions != null) {
            this.conditions = new SAMLConditions();

            if (conditions.getNotBefore() != null) {
                this.conditions.setNotBefore(ZonedDateTime.ofInstant(conditions.getNotBefore(), ZoneOffset.UTC));
            }

            if (conditions.getNotOnOrAfter() != null) {
                this.conditions.setNotOnOrAfter(ZonedDateTime.ofInstant(conditions.getNotOnOrAfter(), ZoneOffset.UTC));
            }
        }

        // NOTE(review): this writes the full generated toString() (name ID and attribute
        // values included) at INFO level - confirm that is acceptable for production logs.
        LOGGER.info("Constructed SAML2 credentials: {}", this);
    }

    /**
     * Serializable snapshot of a SAML name identifier.
     */
    @Getter
    @Setter
    @ToString
    public static class SAMLNameID implements Serializable {
        @Serial
        private static final long serialVersionUID = -7913473743778305079L;
        private String format;
        private String nameQualifier;
        private String spNameQualifier;
        private String spProviderId;
        private String value;

        /**
         * Copies the qualifiers, format, SP provided id and value of an OpenSAML name identifier.
         *
         * @param nameId the OpenSAML name identifier to copy from, must not be {@code null}
         * @return a new {@link SAMLNameID} holding the copied values
         */
        public static SAMLNameID from(final NameIDType nameId) {
            val result = new SAMLNameID();
            result.setNameQualifier(nameId.getNameQualifier());
            result.setFormat(nameId.getFormat());
            result.setSpNameQualifier(nameId.getSPNameQualifier());
            result.setSpProviderId(nameId.getSPProvidedID());
            result.setValue(nameId.getValue());
            return result;
        }

        /**
         * Builds a name identifier from an attribute: the first attribute value becomes the
         * value, the name format becomes the format, the attribute name becomes the name
         * qualifier and the friendly name becomes the SP name qualifier. The SP provider id
         * is left unset.
         *
         * @param attribute the attribute to read from, must not be {@code null}
         * @return a new {@link SAMLNameID} derived from the attribute
         * @throws IndexOutOfBoundsException if the attribute has no values
         * @throws NullPointerException if the attribute or its value list is {@code null}
         */
        public static SAMLNameID from(final SAMLAttribute attribute) {
            val result = new SAMLNameID();
            // Deliberately fails when there is no value rather than yielding a name ID without one.
            result.setValue(attribute.getAttributeValues().get(0));
            result.setFormat(attribute.getNameFormat());
            result.setNameQualifier(attribute.getName());
            result.setSpNameQualifier(attribute.getFriendlyName());
            return result;
        }
    }

    /**
     * Serializable snapshot of a SAML attribute and its string values.
     */
    @Getter
    @Setter
    @ToString
    public static class SAMLAttribute implements Serializable {
        @Serial
        private static final long serialVersionUID = 2532838901563948260L;
        private String friendlyName;
        private String name;
        private String nameFormat;
        private List<String> attributeValues = new ArrayList<>();

        /**
         * Converts OpenSAML attributes through the given converter and flattens the results
         * into a single list, in iteration order.
         *
         * <p>A converter result may be a single {@link SAMLAttribute} or a {@link Collection}
         * of them. {@code null} results and {@code null} collection elements are skipped so
         * the returned list never contains {@code null}.</p>
         *
         * @param samlAttributeConverter the converter applied to every attribute
         * @param samlAttributes the attributes to convert; {@code null} yields an empty list
         * @return a new mutable list of the converted attributes
         * @throws ClassCastException if the converter produces anything other than a
         *         {@link SAMLAttribute} or a collection of {@link SAMLAttribute}
         */
        public static List<SAMLAttribute> from(final AttributeConverter samlAttributeConverter, final Iterable<Attribute> samlAttributes) {
            final List<SAMLAttribute> attributes = new ArrayList<>();
            if (samlAttributes == null) {
                return attributes;
            }

            for (final Attribute attribute : samlAttributes) {
                final Object result = samlAttributeConverter.convert(attribute);
                if (result instanceof Collection) {
                    // Cast element by element so a wrongly-typed entry fails here instead of
                    // polluting the list and failing later at the point of use.
                    for (final Object element : (Collection<?>) result) {
                        if (element != null) {
                            attributes.add((SAMLAttribute) element);
                        }
                    }
                } else if (result != null) {
                    attributes.add((SAMLAttribute) result);
                }
            }

            return attributes;
        }
    }

    /**
     * Serializable snapshot of the validity window of a SAML assertion's conditions.
     */
    @Getter
    @Setter
    @ToString
    public static class SAMLConditions implements Serializable {
        @Serial
        private static final long serialVersionUID = -8966585574672014553L;
        private ZonedDateTime notBefore;
        private ZonedDateTime notOnOrAfter;
    }
}
