/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.runtime.aiservice;

import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.AugmentationResult;
import io.quarkiverse.langchain4j.guardrails.Guardrail;
import io.quarkiverse.langchain4j.guardrails.GuardrailParams;
import io.quarkiverse.langchain4j.guardrails.GuardrailResult;
import io.quarkiverse.langchain4j.guardrails.InputGuardrail;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.InputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrail;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailParams;
import io.quarkiverse.langchain4j.guardrails.OutputGuardrailResult;
import io.quarkiverse.langchain4j.guardrails.OutputTokenAccumulator;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodImplementationSupport;
import io.quarkiverse.langchain4j.runtime.aiservice.GuardrailException;
import io.smallrye.mutiny.Multi;
import jakarta.enterprise.inject.spi.CDI;
import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Static helpers that run the input and output guardrails configured on an AI service method.
 * <p>
 * Guardrail implementation classes (and the output token accumulator) are resolved lazily from their
 * configured class names and cached on the {@link AiServiceMethodCreateInfo}. Populating those caches is
 * guarded by a lock on {@code AiServiceMethodImplementationSupport.class}; the guardrails themselves are
 * executed after that lock has been released.
 */
public class GuardrailsSupport {

    /**
     * Runs the input guardrails of a method against the user message that is about to be sent.
     *
     * @param methodCreateInfo the metadata of the AI service method being invoked
     * @param userMessage the user message to validate
     * @param chatMemory the chat memory handed to the guardrails
     * @param augmentationResult the augmentation result handed to the guardrails
     * @param templateVariables the user message template variables; {@code null} is treated as an empty map
     * @return {@code userMessage} unchanged, or a rewritten copy when a guardrail rewrote its text
     * @throws GuardrailException if a guardrail throws or the validation is not successful
     */
    public static UserMessage invokeInputGuardrails(AiServiceMethodCreateInfo methodCreateInfo, UserMessage userMessage,
            ChatMemory chatMemory, AugmentationResult augmentationResult, Map<String, Object> templateVariables) {
        InputGuardrailResult result;
        try {
            // Guardrails receive a read-only view so they cannot modify the caller's variables.
            Map<String, Object> variables = templateVariables == null
                    ? Collections.emptyMap()
                    : Collections.unmodifiableMap(templateVariables);
            InputGuardrailParams params = new InputGuardrailParams(userMessage, chatMemory, augmentationResult,
                    methodCreateInfo.getUserMessageTemplate(), variables);
            result = invokeInputGuardRails(methodCreateInfo, params);
        } catch (Exception e) {
            throw new GuardrailException(e.getMessage(), e);
        }
        if (!result.isSuccess()) {
            throw new GuardrailException(result.toString(), result.getFirstFailureException());
        }
        if (result.hasRewrittenResult()) {
            return InputGuardrailParams.rewriteUserMessage(userMessage, result.successfulText());
        }
        return userMessage;
    }

    /**
     * Runs the output guardrails of a method against a model response, re-executing the request when a
     * guardrail asks for a retry.
     * <p>
     * On each retry the guardrail's reprompt (if any) is appended to the chat memory, the request is
     * re-executed with the memory's messages, and the new response is added to the memory. The new response
     * is then validated again. At most {@code getGuardrailsMaxRetry()} retries are performed; a non-positive
     * setting still allows one retry.
     *
     * @param methodCreateInfo the metadata of the AI service method being invoked
     * @param chatMemory the chat memory used to build retried requests; reprompts and retried responses are added to it
     * @param chatModel the model used to re-execute the request
     * @param response the response to validate
     * @param toolSpecifications the tool specifications passed along with retried requests
     * @param output the guardrail parameters describing {@code response}
     * @return the last validated response, rewritten when a guardrail rewrote its text
     * @throws GuardrailException if a guardrail throws, fails without requesting a retry, or still requests a
     *         retry once all retries have been used
     */
    public static Response<AiMessage> invokeOutputGuardrails(AiServiceMethodCreateInfo methodCreateInfo,
            ChatMemory chatMemory, ChatLanguageModel chatModel, Response<AiMessage> response,
            List<ToolSpecification> toolSpecifications, OutputGuardrailParams output) {
        int maxRetries = Math.max(1, methodCreateInfo.getGuardrailsMaxRetry());
        int retries = 0;
        OutputGuardrailResult result;
        while (true) {
            try {
                result = invokeOutputGuardRails(methodCreateInfo, output);
            } catch (Exception e) {
                throw new GuardrailException(e.getMessage(), e);
            }
            if (result.isSuccess()) {
                break;
            }
            if (!result.isRetry()) {
                throw new GuardrailException(result.toString(), result.getFirstFailureException());
            }
            // Give up only after the last retried response has itself been validated.
            if (retries >= maxRetries) {
                throw retriesExhausted(result);
            }
            if (result.getReprompt() != null) {
                chatMemory.add(UserMessage.userMessage(result.getReprompt()));
            }
            response = AiServiceMethodImplementationSupport.executeRequest(methodCreateInfo, chatMemory.messages(),
                    chatModel, toolSpecifications);
            chatMemory.add(response.content());
            retries++;
            output = new OutputGuardrailParams(response.content(), output.memory(), output.augmentationResult(),
                    output.userMessageTemplate(), output.variables());
        }
        if (result.hasRewrittenResult()) {
            response = rewriteResponse(response, result);
        }
        return response;
    }

    /**
     * Builds the exception thrown when the output guardrails still fail after the last allowed retry.
     *
     * @param result the last failed guardrail result; its failure messages are included in the exception message
     * @return the exception to throw
     */
    private static GuardrailException retriesExhausted(OutputGuardrailResult result) {
        String failureMessages = Optional.ofNullable(result.failures())
                .orElseGet(List::of)
                .stream()
                .map(OutputGuardrailResult.Failure::message)
                .collect(Collectors.joining(System.lineSeparator()));
        return new GuardrailException("Output validation failed. The guardrails have reached the maximum number of retries. Guardrail messages:" + System.lineSeparator() + failureMessages);
    }

    /**
     * Returns a copy of {@code response} whose message text is replaced by the guardrail's rewritten text.
     * <p>
     * Tool execution requests of the original message are kept. When the result carries a successful result
     * object, it is stored in the copy's metadata under the {@link OutputGuardrailResult} class name. The
     * metadata of {@code response} itself is not modified.
     *
     * @param response the response to rewrite
     * @param result the guardrail result holding the rewritten text
     * @return the rewritten response, with the original token usage and finish reason
     */
    public static Response<AiMessage> rewriteResponse(Response<AiMessage> response, OutputGuardrailResult result) {
        AiMessage original = response.content();
        AiMessage content = original.hasToolExecutionRequests()
                ? new AiMessage(result.successfulText(), original.toolExecutionRequests())
                : new AiMessage(result.successfulText());
        // Work on a copy: the input's metadata map may be shared or unmodifiable.
        Map<String, Object> metadata = response.metadata() == null
                ? new HashMap<>()
                : new HashMap<>(response.metadata());
        if (result.successfulResult() != null) {
            metadata.put(OutputGuardrailResult.class.getName(), result.successfulResult());
        }
        return new Response<>(content, response.tokenUsage(), response.finishReason(), metadata);
    }

    /**
     * Runs the method's output guardrails once, with no retry handling.
     *
     * @return a success result when the method declares no output guardrails, otherwise the composed result
     */
    private static OutputGuardrailResult invokeOutputGuardRails(AiServiceMethodCreateInfo methodCreateInfo,
            OutputGuardrailParams params) {
        if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
            return OutputGuardrailResult.success();
        }
        List<Class<? extends Guardrail>> classes;
        // Only the lazy population of the shared class cache needs the lock; the guardrails run outside it.
        synchronized (AiServiceMethodImplementationSupport.class) {
            classes = methodCreateInfo.getOutputGuardrailsClasses();
            if (classes.isEmpty()) {
                classes.addAll(loadGuardrailClasses(methodCreateInfo.getOutputGuardrailsClassNames(),
                        OutputGuardrail.class));
            }
        }
        return guardrailResult(params, classes, OutputGuardrailResult.success(), OutputGuardrailResult::failure);
    }

    /**
     * Runs the method's input guardrails once.
     *
     * @return a success result when the method declares no input guardrails, otherwise the composed result
     */
    private static InputGuardrailResult invokeInputGuardRails(AiServiceMethodCreateInfo methodCreateInfo,
            InputGuardrailParams params) {
        if (methodCreateInfo.getInputGuardrailsClassNames().isEmpty()) {
            return InputGuardrailResult.success();
        }
        List<Class<? extends Guardrail>> classes;
        // Only the lazy population of the shared class cache needs the lock; the guardrails run outside it.
        synchronized (AiServiceMethodImplementationSupport.class) {
            classes = methodCreateInfo.getInputGuardrailsClasses();
            if (classes.isEmpty()) {
                classes.addAll(loadGuardrailClasses(methodCreateInfo.getInputGuardrailsClassNames(),
                        InputGuardrail.class));
            }
        }
        return guardrailResult(params, classes, InputGuardrailResult.success(), InputGuardrailResult::failure);
    }

    /**
     * Loads the named guardrail classes with the current thread's context class loader.
     * <p>
     * The classes are collected into a new list, so a failure part-way through never leaves a partially
     * populated cache behind.
     *
     * @param classNames fully qualified names of the guardrail implementations
     * @param guardrailKind the guardrail interface, used only in the error message
     * @return the loaded classes, in the order of {@code classNames}
     * @throws RuntimeException if a class cannot be loaded or does not implement {@link Guardrail}
     */
    private static List<Class<? extends Guardrail>> loadGuardrailClasses(Iterable<String> classNames,
            Class<?> guardrailKind) {
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        List<Class<? extends Guardrail>> loaded = new ArrayList<>();
        for (String className : classNames) {
            try {
                loaded.add(Class.forName(className, true, classLoader).asSubclass(Guardrail.class));
            } catch (Exception e) {
                throw new RuntimeException("Could not find " + guardrailKind.getSimpleName() + " implementation class: " + className, e);
            }
        }
        return loaded;
    }

    /**
     * Validates {@code params} with each guardrail bean (looked up through CDI) in order and composes the results.
     * <p>
     * If a guardrail rewrites the text, the following guardrails see the rewritten text. A fatal result stops
     * the chain immediately. That result is returned as-is, or with retries blocked if an earlier guardrail had
     * already rewritten the text.
     *
     * @param params the parameters handed to the first guardrail
     * @param classes the guardrail bean classes, in execution order
     * @param accumulatedResults the initial (success) result to compose onto
     * @param producer builds a failure result from a list of failures
     * @return the composed result of all executed guardrails
     */
    @SuppressWarnings("unchecked") // Guardrail/GuardrailResult are used raw; results are of the caller's GR type.
    private static <GR extends GuardrailResult> GR guardrailResult(GuardrailParams params,
            List<Class<? extends Guardrail>> classes, GR accumulatedResults,
            Function<List<? extends GuardrailResult.Failure>, GR> producer) {
        for (Class<? extends Guardrail> bean : classes) {
            Guardrail guardrail = CDI.current().select(bean).get();
            GR result = (GR) guardrail.validate(params).validatedBy(bean);
            if (result.isFatal()) {
                return accumulatedResults.hasRewrittenResult() ? (GR) result.blockRetry() : result;
            }
            if (result.hasRewrittenResult()) {
                params = params.withText(result.successfulText());
            }
            accumulatedResults = compose(accumulatedResults, result, producer);
        }
        return accumulatedResults;
    }

    /**
     * Combines two guardrail results.
     * <p>
     * If either result is successful, the other one is returned (the new result when the old one succeeded).
     * Otherwise both failure lists are concatenated into a new failure result.
     */
    private static <GR extends GuardrailResult> GR compose(GR oldResult, GR newResult,
            Function<List<? extends GuardrailResult.Failure>, GR> producer) {
        if (oldResult.isSuccess()) {
            return newResult;
        }
        if (newResult.isSuccess()) {
            return oldResult;
        }
        List<GuardrailResult.Failure> failures = new ArrayList<>();
        failures.addAll(oldResult.failures());
        failures.addAll(newResult.failures());
        return producer.apply(failures);
    }

    /**
     * Prepares a token stream for output guardrail validation.
     * <p>
     * Returns {@code upstream} unchanged when the method has no output guardrails. Otherwise the configured
     * {@link OutputTokenAccumulator} is resolved through CDI and cached on {@code methodCreateInfo}, then applied
     * to the stream. When no accumulator class is configured, the whole stream is concatenated into a single item.
     *
     * @throws RuntimeException if the configured accumulator class cannot be loaded or looked up
     */
    public static Multi<String> accumulate(Multi<String> upstream, AiServiceMethodCreateInfo methodCreateInfo) {
        if (methodCreateInfo.getOutputGuardrailsClassNames().isEmpty()) {
            return upstream;
        }
        OutputTokenAccumulator accumulator;
        synchronized (AiServiceMethodImplementationSupport.class) {
            accumulator = methodCreateInfo.getOutputTokenAccumulator();
            if (accumulator == null) {
                String cn = methodCreateInfo.getOutputTokenAccumulatorClassName();
                if (cn != null) {
                    try {
                        Class<? extends OutputTokenAccumulator> type = Class
                                .forName(cn, true, Thread.currentThread().getContextClassLoader())
                                .asSubclass(OutputTokenAccumulator.class);
                        accumulator = CDI.current().select(type).get();
                        methodCreateInfo.setOutputTokenAccumulator(accumulator);
                    } catch (Exception e) {
                        throw new RuntimeException("Could not find " + OutputTokenAccumulator.class.getSimpleName() + " implementation class: " + cn, e);
                    }
                }
            }
        }
        if (accumulator == null) {
            return upstream.collect().in(StringBuilder::new, StringBuilder::append).map(StringBuilder::toString).toMulti();
        }
        return upstream.plug(accumulator::accumulate);
    }

    /**
     * Runs the method's output guardrails once, without any retry handling, for the streaming code path.
     *
     * @return a success result when the method declares no output guardrails, otherwise the composed result
     */
    public static OutputGuardrailResult invokeOutputGuardrailsForStream(AiServiceMethodCreateInfo methodCreateInfo,
            OutputGuardrailParams outputGuardrailParams) {
        return invokeOutputGuardRails(methodCreateInfo, outputGuardrailParams);
    }

    /**
     * Marker exception with no message. It is not referenced in this file; presumably it is used elsewhere in
     * the package to signal a guardrail retry (confirm).
     */
    static class GuardrailRetryException extends RuntimeException {
        GuardrailRetryException() {
        }
    }
}

