/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package net.shibboleth.shared.spring.expression;

import java.util.function.BiFunction;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.slf4j.Logger;

import org.springframework.expression.EvaluationContext;

import net.shibboleth.shared.annotation.ParameterName;
import net.shibboleth.shared.annotation.constraint.NotEmpty;
import net.shibboleth.shared.collection.Pair;
import net.shibboleth.shared.primitive.LoggerFactory;

/**
 * {@link BiFunction} whose result is computed by a Spring EL expression.
 *
 * <p>The two inputs are exposed to the expression as the context variables named <code>input1</code>
 * and <code>input2</code>. Optionally, the runtime types of the inputs can be enforced before the
 * expression is evaluated.</p>
 * 
 * @param <T> first input type
 * @param <U> second input type
 * @param <V> return type
 * 
 * @since 6.1.0
 */
public class SpringExpressionBiFunction<T,U,V> extends AbstractSpringExpressionEvaluator 
            implements BiFunction<T,U,V> {
    
    /** Class logger. */
    @Nonnull private final Logger log = LoggerFactory.getLogger(SpringExpressionBiFunction.class);

    /** Input types to enforce, or null if no enforcement is configured. */
    @Nullable private Pair<Class<T>,Class<U>> inputTypes;

    /**
     * Constructor.
     *
     * @param expression the expression to evaluate
     */
    public SpringExpressionBiFunction(@Nonnull @NotEmpty @ParameterName(name="expression") final String expression) {
        super(expression);
    }

    /**
     * Get the input types to be enforced.
     *
     * @return the input types, or null if none are being enforced
     */
    @Nullable public Pair<Class<T>,Class<U>> getInputTypes() {
        return inputTypes;
    }

    /**
     * Set the input types to be enforced.
     *
     * <p>The pair is only retained if it and both of its members are non-null; otherwise
     * any previously set types are cleared.</p>
     *
     * @param types the input types
     */
    public void setInputTypes(@Nullable final Pair<Class<T>,Class<U>> types) {
        final boolean complete = types != null && types.getFirst() != null && types.getSecond() != null;
        inputTypes = complete ? types : null;
    }

    /**
     * Set the output type to be enforced.
     *
     * <p>Delegates to the superclass; redeclared here as a public method.</p>
     *
     * @param type output type
     */
    @Override public void setOutputType(@Nullable final Class<?> type) {
        super.setOutputType(type);
    }

    /**
     * Set value to return if an error occurs.
     *
     * <p>Delegates to the superclass; redeclared here as a public method.</p>
     *
     * @param value value to return
     */
    @Override public void setReturnOnError(@Nullable final Object value) {
        super.setReturnOnError(value);
    }

    /**
     * Evaluate the expression against the two inputs.
     *
     * <p>If input types have been set, each non-null input is checked against its corresponding
     * type first. On the first mismatch an error is logged and the value from
     * {@link #getReturnOnError()} is returned without evaluating the expression.</p>
     *
     * @param first first input
     * @param second second input
     *
     * @return the result of evaluation, or the return-on-error value if an input type check fails
     */
    @Override
    @SuppressWarnings("unchecked")
    @Nullable public V apply(@Nullable final T first, @Nullable final U second) {
        final Pair<Class<T>,Class<U>> types = getInputTypes();
        // Short-circuit keeps the original ordering: the second input is only checked if the first passed.
        if (types != null
                && (isTypeMismatch(first, types.getFirst()) || isTypeMismatch(second, types.getSecond()))) {
            return (V) getReturnOnError();
        }

        return (V) evaluate(first, second);
    }

    /**
     * Populate the evaluation context with the two inputs under the names <code>input1</code> and
     * <code>input2</code>.
     *
     * <p>A null or too-short input array results in the missing variable(s) being set to null
     * rather than an {@link ArrayIndexOutOfBoundsException}.</p>
     *
     * @param context the context to populate
     * @param input the inputs; the first two elements are used
     */
    @Override
    protected void prepareContext(@Nonnull final EvaluationContext context, @Nullable final Object... input) {
        context.setVariable("input1", input != null && input.length > 0 ? input[0] : null);
        context.setVariable("input2", input != null && input.length > 1 ? input[1] : null);
    }

    /**
     * Check whether a value fails to match an expected type, logging an error if so.
     *
     * <p>A null value or a null expected type is never treated as a mismatch.</p>
     *
     * @param value the value to check
     * @param expected the type the value is expected to be an instance of
     *
     * @return true iff the value is non-null, a type is given, and the value is not an instance of it
     */
    private boolean isTypeMismatch(@Nullable final Object value, @Nullable final Class<?> expected) {
        if (value == null || expected == null || expected.isInstance(value)) {
            return false;
        }
        log.error("Input of type {} was not of type {}", value.getClass(), expected);
        return true;
    }
    
}