/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.runtime.aiservice;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Exceptions;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.input.structured.StructuredPrompt;
import dev.langchain4j.model.input.structured.StructuredPromptProcessor;
import dev.langchain4j.model.moderation.Moderation;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.AiServiceContext;
import dev.langchain4j.service.AiServiceTokenStream;
import dev.langchain4j.service.AiServices;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.service.output.ServiceOutputParser;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.spi.ServiceHelper;
import io.quarkiverse.langchain4j.audit.Audit;
import io.quarkiverse.langchain4j.audit.AuditService;
import io.quarkiverse.langchain4j.runtime.QuarkusServiceOutputParser;
import io.quarkiverse.langchain4j.runtime.ResponseSchemaUtil;
import io.quarkiverse.langchain4j.runtime.aiservice.AiServiceMethodCreateInfo;
import io.quarkiverse.langchain4j.runtime.aiservice.ChatMemorySeeder;
import io.quarkiverse.langchain4j.runtime.aiservice.DefaultCommittableChatMemory;
import io.quarkiverse.langchain4j.runtime.aiservice.NoopChatMemory;
import io.quarkiverse.langchain4j.runtime.aiservice.QuarkusAiServiceContext;
import io.quarkiverse.langchain4j.spi.DefaultMemoryIdProvider;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.smallrye.mutiny.subscription.MultiEmitter;
import java.lang.reflect.Array;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Flow;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import org.jboss.logging.Logger;

public class AiServiceMethodImplementationSupport {
    /** Logger shared by all static helpers of this class. */
    private static final Logger log = Logger.getLogger(AiServiceMethodImplementationSupport.class);
    /**
     * Upper bound on consecutive tool-execution rounds.
     * NOTE(review): in this decompiled form the constant is not referenced by name; doImplement
     * uses the literal 10 instead (presumably inlined by the compiler) - keep the two in sync.
     */
    private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;
    /** Providers loaded in the static initializer, sorted by ascending {@code priority()}; consulted by memoryId. */
    private static final List<DefaultMemoryIdProvider> DEFAULT_MEMORY_ID_PROVIDERS;
    /** Parser used by doImplement to turn the final model response into the method's return type. */
    private static final ServiceOutputParser SERVICE_OUTPUT_PARSER;

    /**
     * Entry point for one AI service method invocation.
     * <p>
     * Delegates the actual work to {@code doImplement} and, when the context carries an
     * {@link AuditService}, brackets the call with an {@link Audit}: the audit is created up front,
     * told about the result or the failure, and then completed.
     *
     * @param input the context, method metadata and call arguments of the invocation
     * @return whatever {@code doImplement} produced
     */
    public Object implement(Input input) {
        final QuarkusAiServiceContext serviceContext = input.context;
        final AiServiceMethodCreateInfo methodInfo = input.createInfo;
        final Object[] args = input.methodArgs;
        final AuditService audits = serviceContext.auditService;

        // Auditing is optional: without an audit service no Audit is created at all.
        final Audit audit = audits == null
                ? null
                : audits.create(new Audit.CreateInfo(methodInfo.getInterfaceName(), methodInfo.getMethodName(), args, methodInfo.getMemoryIdParamPosition()));

        try {
            Object outcome = AiServiceMethodImplementationSupport.doImplement(methodInfo, args, serviceContext, audit);
            if (audit != null) {
                audit.onCompletion(outcome);
                audits.complete(audit);
            }
            return outcome;
        } catch (Exception failure) {
            // Log, record the failure on the audit (if any), and let the caller see the same exception.
            log.errorv(failure, "Execution of {0}#{1} failed", methodInfo.getInterfaceName(), methodInfo.getMethodName());
            if (audit != null) {
                audit.onFailure(failure);
                audits.complete(audit);
            }
            throw failure;
        }
    }

    /**
     * Performs one AI service method invocation.
     * <p>
     * Steps, in order: build the optional system message and the user message, resolve the memory id,
     * optionally run the retrieval augmentor on the user message, assemble the list of messages to send,
     * and then either return a streaming handle (for {@code TokenStream} / {@code Multi} return types)
     * or call the chat model synchronously, executing requested tools in a loop until the model answers
     * without tool requests.
     *
     * @param methodCreateInfo build-time metadata of the invoked method
     * @param methodArgs arguments of the invocation
     * @param context the AI service context (chat model, memory, tools, augmentor, ...)
     * @param audit audit to report to, or {@code null} when auditing is off
     * @return an {@code AiServiceTokenStream}, a {@code Multi}, or the value produced by
     *         {@code SERVICE_OUTPUT_PARSER} for the method's return type
     * @throws RuntimeException (via {@code Exceptions.runtime}) when the tool loop is entered more than
     *         10 times or when a requested tool has no executor
     */
    private static Object doImplement(final AiServiceMethodCreateInfo methodCreateInfo, Object[] methodArgs, final QuarkusAiServiceContext context, Audit audit) {
        List<ChatMessage> messagesToSend;
        // NOTE(review): typed as Object by the decompiler; it is used below both as a List of messages
        // and as a memory object with add/messages/commit - the original declared type is not visible here.
        Object chatMemory;
        final Optional<SystemMessage> systemMessage = AiServiceMethodImplementationSupport.prepareSystemMessage(methodCreateInfo, methodArgs);
        UserMessage userMessage = AiServiceMethodImplementationSupport.prepareUserMessage(context, methodCreateInfo, methodArgs);
        if (audit != null) {
            audit.initialMessages(systemMessage, userMessage);
        }
        final Object memoryId = AiServiceMethodImplementationSupport.memoryId(methodCreateInfo, methodArgs, context.chatMemoryProvider != null);
        final boolean needsMemorySeed = AiServiceMethodImplementationSupport.needsMemorySeed(context, memoryId);
        Type returnType = methodCreateInfo.getReturnType();
        if (context.retrievalAugmentor != null) {
            // The augmentor gets the current memory content (or null when the service has no memory) as metadata.
            chatMemory = context.hasChatMemory() ? context.chatMemory(memoryId).messages() : null;
            Metadata metadata = Metadata.from((UserMessage)userMessage, (Object)memoryId, (List)chatMemory);
            final AugmentationRequest augmentationRequest = new AugmentationRequest((ChatMessage)userMessage, metadata);
            if (!AiServiceMethodImplementationSupport.isMulti(returnType)) {
                // Non-Multi return types: augment on the calling thread and continue below with the augmented message.
                AugmentationResult augmentationResult = context.retrievalAugmentor.augment(augmentationRequest);
                userMessage = (UserMessage)augmentationResult.chatMessage();
            } else {
                // Multi return type: run the augmentation on Mutiny's default worker pool and build the
                // message list only once the result is available. This branch returns immediately, so the
                // moderation, tool loop and response auditing further down do not apply to it.
                CompletableFuture<AugmentationResult> augmentationResultCF = CompletableFuture.supplyAsync(new Supplier<AugmentationResult>(){

                    @Override
                    public AugmentationResult get() {
                        return context.retrievalAugmentor.augment(augmentationRequest);
                    }
                }, Infrastructure.getDefaultWorkerPool());
                return Multi.createFrom().completionStage(augmentationResultCF).flatMap((Function)new Function<AugmentationResult, Flow.Publisher<? extends Object>>(){

                    @Override
                    public Flow.Publisher<?> apply(AugmentationResult ar) {
                        ChatMessage augmentedUserMessage = ar.chatMessage();
                        List<ChatMessage> messagesToSend = this.messagesToSend(augmentedUserMessage, needsMemorySeed);
                        return Multi.createFrom().emitter((Consumer)new MultiEmitterConsumer(messagesToSend, context, memoryId));
                    }

                    // Same message assembly as the synchronous path below, except that the memory returned
                    // by context.chatMemory(memoryId) is used directly rather than through a committable wrapper.
                    private List<ChatMessage> messagesToSend(ChatMessage augmentedUserMessage, boolean needsMemorySeed2) {
                        List<ChatMessage> messagesToSend;
                        if (context.hasChatMemory()) {
                            ChatMemory chatMemory = context.chatMemory(memoryId);
                            messagesToSend = AiServiceMethodImplementationSupport.createMessagesToSendForExistingMemory(systemMessage, augmentedUserMessage, chatMemory, needsMemorySeed2, context, methodCreateInfo);
                        } else {
                            messagesToSend = AiServiceMethodImplementationSupport.createMessagesToSendForNoMemory(systemMessage, augmentedUserMessage, needsMemorySeed2, context, methodCreateInfo);
                        }
                        return messagesToSend;
                    }
                });
            }
        }
        // Assemble the messages. With memory, the real memory is wrapped in DefaultCommittableChatMemory and
        // commit() is called explicitly later; presumably writes are deferred until then - confirm against
        // DefaultCommittableChatMemory. Without memory a NoopChatMemory stands in.
        if (context.hasChatMemory()) {
            chatMemory = new DefaultCommittableChatMemory(context.chatMemory(memoryId));
            messagesToSend = AiServiceMethodImplementationSupport.createMessagesToSendForExistingMemory(systemMessage, (ChatMessage)userMessage, (ChatMemory)chatMemory, needsMemorySeed, context, methodCreateInfo);
        } else {
            chatMemory = new NoopChatMemory();
            messagesToSend = AiServiceMethodImplementationSupport.createMessagesToSendForNoMemory(systemMessage, (ChatMessage)userMessage, needsMemorySeed, context, methodCreateInfo);
        }
        // Streaming return types: commit what was added so far and hand the messages to the stream.
        if (AiServiceMethodImplementationSupport.isTokenStream(returnType)) {
            chatMemory.commit();
            return new AiServiceTokenStream(messagesToSend, (AiServiceContext)context, memoryId);
        }
        if (AiServiceMethodImplementationSupport.isMulti(returnType)) {
            chatMemory.commit();
            return Multi.createFrom().emitter((Consumer)new MultiEmitterConsumer(messagesToSend, context, memoryId));
        }
        // Blocking path. Moderation (if required) is started in the background and checked after the first response.
        Future<Moderation> moderationFuture = AiServiceMethodImplementationSupport.triggerModerationIfNeeded(context, methodCreateInfo, messagesToSend);
        log.debug((Object)"Attempting to obtain AI response");
        List<ToolSpecification> toolSpecifications = context.toolSpecifications;
        Map<String, ToolExecutor> toolExecutors = context.toolExecutors;
        // Tools declared for this particular method replace the context-level tools.
        if (methodCreateInfo.getToolClassNames() != null && !methodCreateInfo.getToolClassNames().isEmpty()) {
            toolSpecifications = methodCreateInfo.getToolSpecifications();
            toolExecutors = methodCreateInfo.getToolExecutors();
        }
        Response response = toolSpecifications == null ? context.chatModel.generate(messagesToSend) : context.chatModel.generate(messagesToSend, toolSpecifications);
        log.debug((Object)"AI response obtained");
        if (audit != null) {
            audit.addLLMToApplicationMessage((Response<AiMessage>)response);
        }
        TokenUsage tokenUsageAccumulator = response.tokenUsage();
        AiServices.verifyModerationIfNeeded(moderationFuture);
        // NOTE(review): the literal 10 here and in the error message below mirrors
        // MAX_SEQUENTIAL_TOOL_EXECUTIONS, which is not referenced by name in this decompiled body.
        int executionsLeft = 10;
        while (true) {
            if (executionsLeft-- == 0) {
                throw Exceptions.runtime((String)"Something is wrong, exceeded %s sequential tool executions", (Object[])new Object[]{10});
            }
            AiMessage aiMessage = (AiMessage)response.content();
            chatMemory.add((ChatMessage)aiMessage);
            // No tool requests means this is the final answer.
            if (!aiMessage.hasToolExecutionRequests()) break;
            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                log.debugv("Attempting to execute tool {0}", (Object)toolExecutionRequest);
                ToolExecutor toolExecutor = toolExecutors.get(toolExecutionRequest.name());
                if (toolExecutor == null) {
                    throw Exceptions.runtime((String)"Tool executor %s not found", (Object[])new Object[]{toolExecutionRequest.name()});
                }
                String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
                log.debugv("Result of {0} is '{1}'", (Object)toolExecutionRequest, (Object)toolExecutionResult);
                ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from((ToolExecutionRequest)toolExecutionRequest, (String)toolExecutionResult);
                if (audit != null) {
                    audit.addApplicationToLLMMessage(toolExecutionResultMessage);
                }
                chatMemory.add((ChatMessage)toolExecutionResultMessage);
            }
            // Ask the model again with everything accumulated in memory, including the tool results.
            // Unlike the first call, this one always passes toolSpecifications.
            log.debug((Object)"Attempting to obtain AI response");
            response = context.chatModel.generate(chatMemory.messages(), toolSpecifications);
            log.debug((Object)"AI response obtained");
            if (audit != null) {
                audit.addLLMToApplicationMessage((Response<AiMessage>)response);
            }
            tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
        }
        chatMemory.commit();
        // Re-wrap the final message so the returned response carries the token usage summed over all model calls.
        response = Response.from((Object)((AiMessage)response.content()), (TokenUsage)tokenUsageAccumulator, (FinishReason)response.finishReason());
        return SERVICE_OUTPUT_PARSER.parse(response, returnType);
    }

    /**
     * Tells whether the chat memory for {@code memoryId} should be seeded before use.
     * <p>
     * Seeding is needed only when a seeder is configured, the service has chat memory, and that
     * memory currently holds no messages.
     *
     * @param context the AI service context
     * @param memoryId id of the memory to inspect
     * @return {@code true} when seed messages should be added
     */
    private static boolean needsMemorySeed(QuarkusAiServiceContext context, Object memoryId) {
        boolean seedingPossible = context.chatMemorySeeder != null && context.hasChatMemory();
        // The memory is only looked up once both preconditions hold.
        return seedingPossible && context.chatMemory(memoryId).messages().isEmpty();
    }

    /**
     * Adds the messages of the current invocation to {@code chatMemory} and returns the memory's
     * resulting message list.
     * <p>
     * Insertion order: system message (if present), seed messages (if requested), user message.
     *
     * @param systemMessage optional system message
     * @param userMessage the user message of this invocation
     * @param chatMemory memory the messages are added to
     * @param needsMemorySeed whether seed messages must be obtained from the context's seeder
     * @param context the AI service context providing the seeder
     * @param methodCreateInfo metadata of the invoked method (its name is passed to the seeder)
     * @return {@code chatMemory.messages()} after all additions
     */
    private static List<ChatMessage> createMessagesToSendForExistingMemory(Optional<SystemMessage> systemMessage, ChatMessage userMessage, ChatMemory chatMemory, boolean needsMemorySeed, QuarkusAiServiceContext context, AiServiceMethodCreateInfo methodCreateInfo) {
        systemMessage.ifPresent(message -> chatMemory.add(message));

        if (needsMemorySeed) {
            ChatMemorySeeder.Context seedContext = new ChatMemorySeeder.Context(methodCreateInfo.getMethodName());
            context.chatMemorySeeder.seed(seedContext).forEach(seed -> chatMemory.add(seed));
        }

        chatMemory.add(userMessage);
        return chatMemory.messages();
    }

    /**
     * Builds a fresh message list for an invocation that has no chat memory to add to.
     * <p>
     * Order: system message (if present), seed messages (if requested), user message.
     *
     * @param systemMessage optional system message
     * @param userMessage the user message of this invocation
     * @param needsMemorySeed whether seed messages must be obtained from the context's seeder
     * @param context the AI service context providing the seeder
     * @param methodCreateInfo metadata of the invoked method (its name is passed to the seeder)
     * @return a new mutable list with the messages to send
     */
    private static List<ChatMessage> createMessagesToSendForNoMemory(Optional<SystemMessage> systemMessage, ChatMessage userMessage, boolean needsMemorySeed, QuarkusAiServiceContext context, AiServiceMethodCreateInfo methodCreateInfo) {
        List<ChatMessage> messages = new ArrayList<>();
        systemMessage.ifPresent(messages::add);

        if (needsMemorySeed) {
            ChatMemorySeeder.Context seedContext = new ChatMemorySeeder.Context(methodCreateInfo.getMethodName());
            messages.addAll(context.chatMemorySeeder.seed(seedContext));
        }

        messages.add(userMessage);
        return messages;
    }

    /**
     * @param returnType the declared return type of the AI service method
     * @return {@code true} when the (raw) return type is exactly {@link TokenStream}
     */
    private static boolean isTokenStream(Type returnType) {
        final Class<?> streamingType = TokenStream.class;
        return AiServiceMethodImplementationSupport.isTypeOf(returnType, streamingType);
    }

    /**
     * @param returnType the declared return type of the AI service method
     * @return {@code true} when the (raw) return type is exactly {@link Multi}
     */
    private static boolean isMulti(Type returnType) {
        final Class<?> reactiveType = Multi.class;
        return AiServiceMethodImplementationSupport.isTypeOf(returnType, reactiveType);
    }

    /**
     * Checks whether {@code type}, reduced to its raw class, is exactly {@code clazz}.
     * <p>
     * Parameterized types are unwrapped to their raw type; the comparison itself is an equality
     * check, so subtypes of {@code clazz} do not match.
     *
     * @param type the type to inspect
     * @param clazz the class to compare against
     * @return {@code true} when the raw class equals {@code clazz}
     * @throws IllegalStateException when the type is neither a {@link Class} nor a {@link ParameterizedType}
     */
    private static boolean isTypeOf(Type type, Class<?> clazz) {
        Type current = type;
        // Peel off generic parameters until a non-parameterized type remains.
        while (current instanceof ParameterizedType) {
            current = ((ParameterizedType) current).getRawType();
        }
        if (current instanceof Class) {
            return current.equals(clazz);
        }
        throw new IllegalStateException("Unsupported return type " + current);
    }

    /**
     * Starts moderation of {@code messages} in the background when the method requires it.
     * <p>
     * The task is submitted to Mutiny's default executor; tool messages are removed before the
     * remaining messages are handed to the context's moderation model.
     *
     * @param context the AI service context providing the moderation model
     * @param createInfo metadata of the invoked method
     * @param messages the messages about to be sent to the chat model
     * @return the future of the moderation result, or {@code null} when no moderation is required
     */
    private static Future<Moderation> triggerModerationIfNeeded(final AiServiceContext context, AiServiceMethodCreateInfo createInfo, final List<ChatMessage> messages) {
        if (!createInfo.isRequiresModeration()) {
            return null;
        }

        log.debug("Moderation is required and it will be executed in the background");
        ExecutorService executor = (ExecutorService) Infrastructure.getDefaultExecutor();
        Callable<Moderation> moderationTask = () -> {
            List<ChatMessage> messagesToModerate = AiServices.removeToolMessages(messages);
            log.debug("Attempting to moderate messages");
            Moderation outcome = context.moderationModel.moderate(messagesToModerate).content();
            log.debug("Moderation completed");
            return outcome;
        };
        return executor.submit(moderationTask);
    }

    /**
     * Renders the system message of the invoked method, if it declares one.
     * <p>
     * Template variables are the method arguments mapped by name, plus the response-schema
     * parameter filled with the method's output format instructions.
     *
     * @param createInfo metadata of the invoked method
     * @param methodArgs arguments of the invocation
     * @return the rendered system message, or an empty optional when none is declared
     */
    private static Optional<SystemMessage> prepareSystemMessage(AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
        if (createInfo.getSystemMessageInfo().isEmpty()) {
            return Optional.empty();
        }
        AiServiceMethodCreateInfo.TemplateInfo template = createInfo.getSystemMessageInfo().get();

        Map<String, Object> variables = new HashMap<>();
        // Arguments are passed to the template as-is, looked up by their parameter position.
        template.nameToParamPosition().forEach((name, position) -> variables.put(name, methodArgs[position]));
        variables.put(ResponseSchemaUtil.templateParam(), createInfo.getResponseSchemaInfo().outputFormatInstructions());

        String templateText = template.text().get();
        SystemMessage rendered = PromptTemplate.from(templateText).apply(variables).toSystemMessage();
        return Optional.of(rendered);
    }

    /**
     * Builds the user message for the invocation.
     * <p>
     * Two sources are supported, checked in this order:
     * <ol>
     * <li>a template (fixed text, or text taken from a method argument) rendered with the method
     * arguments and the response-schema parameter;</li>
     * <li>a single method argument converted to text, with the output format instructions appended.</li>
     * </ol>
     * When a user-name parameter is declared, its {@code toString()} value becomes the message's name.
     *
     * @param context the AI service context (used for error messages)
     * @param createInfo metadata of the invoked method
     * @param methodArgs arguments of the invocation
     * @return the user message to send
     * @throws RuntimeException when the response-schema placeholder is used although the schema is disabled
     * @throws IllegalArgumentException when the argument holding the user message is {@code null}
     * @throws IllegalStateException when neither a template nor a parameter position is available
     */
    private static UserMessage prepareUserMessage(AiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
        AiServiceMethodCreateInfo.UserMessageInfo messageInfo = createInfo.getUserMessageInfo();
        String userName = messageInfo.userNameParamPosition().isPresent()
                ? methodArgs[messageInfo.userNameParamPosition().get()].toString()
                : null;

        if (messageInfo.template().isPresent()) {
            AiServiceMethodCreateInfo.TemplateInfo template = messageInfo.template().get();

            Map<String, Object> variables = new HashMap<>();
            Map<String, Integer> positionsByName = template.nameToParamPosition();
            for (Map.Entry<String, Integer> binding : positionsByName.entrySet()) {
                Object argument = methodArgs[binding.getValue()];
                variables.put(binding.getKey(), AiServiceMethodImplementationSupport.transformTemplateParamValue(argument));
            }

            // The template text is either fixed or supplied through one of the method arguments.
            String templateText;
            if (template.text().isPresent()) {
                templateText = template.text().get();
            } else {
                templateText = (String) methodArgs[template.methodParamPosition().get()];
            }

            boolean hasResponseSchema = createInfo.getResponseSchemaInfo().isInUserMessage().orElse(false)
                    || ResponseSchemaUtil.hasResponseSchema(templateText);
            if (hasResponseSchema && !createInfo.getResponseSchemaInfo().enabled()) {
                throw new RuntimeException("The %s placeholder cannot be used if the property quarkus.langchain4j.response-schema is set to false. Found in: %s".formatted(ResponseSchemaUtil.placeholder(), createInfo.getInterfaceName()));
            }
            // If the schema placeholder appears in neither message, append it to the user template.
            if (!(createInfo.getResponseSchemaInfo().isInSystemMessage() || hasResponseSchema)) {
                templateText = templateText.concat(ResponseSchemaUtil.placeholder());
            }
            variables.put(ResponseSchemaUtil.templateParam(), createInfo.getResponseSchemaInfo().outputFormatInstructions());

            String renderedText = PromptTemplate.from(templateText).apply(variables).text();
            return AiServiceMethodImplementationSupport.createUserMessage(userName, renderedText);
        }

        if (!messageInfo.paramPosition().isPresent()) {
            throw new IllegalStateException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName() + "'. Please contact the maintainers");
        }

        Integer paramIndex = messageInfo.paramPosition().get();
        Object argument = methodArgs[paramIndex];
        if (argument == null) {
            throw new IllegalArgumentException("Unable to construct UserMessage for class '" + context.aiServiceClass.getName() + "' because parameter with index " + paramIndex + " is null");
        }
        String messageText = AiServiceMethodImplementationSupport.toString(argument)
                .concat(createInfo.getResponseSchemaInfo().outputFormatInstructions());
        return AiServiceMethodImplementationSupport.createUserMessage(userName, messageText);
    }

    /**
     * Creates a user message, attaching the user name only when one is given.
     *
     * @param name user name, or {@code null} for an anonymous message
     * @param text message text
     * @return the user message
     */
    private static UserMessage createUserMessage(String name, String text) {
        return name == null
                ? UserMessage.userMessage(text)
                : UserMessage.userMessage(name, text);
    }

    /**
     * Converts a method argument into the value stored in the user-message template parameters.
     * <p>
     * Arrays are rendered as text in the {@code [e1, e2, e3]} form; every other value is returned
     * unchanged. Primitive arrays (e.g. {@code int[]}) are supported as well: they cannot be cast to
     * {@code Object[]}, so they are read reflectively and rendered in the same format.
     *
     * @param value the method argument, may be {@code null}
     * @return the textual form for arrays, otherwise {@code value} itself; {@code null} is passed
     *         through unchanged, the same way prepareSystemMessage hands arguments to its template
     */
    private static Object transformTemplateParamValue(Object value) {
        if (value == null) {
            return null;
        }
        if (value instanceof Object[]) {
            return Arrays.toString((Object[]) value);
        }
        if (!value.getClass().isArray()) {
            return value;
        }
        // Primitive array: Array.get boxes each element, append(Object) renders it via String.valueOf.
        int length = Array.getLength(value);
        StringBuilder rendered = new StringBuilder("[");
        for (int i = 0; i < length; i++) {
            if (i > 0) {
                rendered.append(", ");
            }
            rendered.append(Array.get(value, i));
        }
        return rendered.append(']').toString();
    }

    /**
     * Resolves the chat memory id for the invocation.
     * <p>
     * Resolution order: the method argument marked as memory id; otherwise, when a chat memory
     * provider exists, the first non-null id offered by the registered default providers, suffixed
     * with {@code #<interface>.<method>}; otherwise the literal {@code "default"}.
     *
     * @param createInfo metadata of the invoked method
     * @param methodArgs arguments of the invocation
     * @param hasChatMemoryProvider whether the context has a chat memory provider
     * @return the memory id to use
     */
    private static Object memoryId(AiServiceMethodCreateInfo createInfo, Object[] methodArgs, boolean hasChatMemoryProvider) {
        if (createInfo.getMemoryIdParamPosition().isPresent()) {
            return methodArgs[createInfo.getMemoryIdParamPosition().get()];
        }
        if (!hasChatMemoryProvider) {
            return "default";
        }
        for (DefaultMemoryIdProvider candidateProvider : DEFAULT_MEMORY_ID_PROVIDERS) {
            Object providedId = candidateProvider.getMemoryId();
            if (providedId != null) {
                // Scope the provided id to this service method.
                String scope = "#" + createInfo.getInterfaceName() + "." + createInfo.getMethodName();
                return providedId + scope;
            }
        }
        return "default";
    }

    /**
     * Converts a user-message argument to text.
     * <p>
     * Arrays go through {@code arrayToString}, objects whose class carries {@link StructuredPrompt}
     * are rendered by {@link StructuredPromptProcessor}, everything else uses {@code toString()}.
     *
     * @param arg the argument, must not be {@code null}
     * @return the textual form of the argument
     */
    private static String toString(Object arg) {
        Class<?> argType = arg.getClass();
        if (argType.isArray()) {
            return AiServiceMethodImplementationSupport.arrayToString(arg);
        }
        return argType.isAnnotationPresent(StructuredPrompt.class)
                ? StructuredPromptProcessor.toPrompt(arg).text()
                : arg.toString();
    }

    /**
     * Renders an array (of objects or primitives) as {@code [e1, e2, e3]}.
     * <p>
     * Each non-null element is converted with {@code toString(Object)}, so nested arrays and
     * structured prompts are rendered recursively. A {@code null} element is rendered as the text
     * {@code null}, the same way {@code java.util.Arrays.toString} does, instead of failing with a
     * NullPointerException.
     *
     * @param arg the array to render; must be an array instance
     * @return the textual form of the array
     */
    private static String arrayToString(Object arg) {
        StringBuilder sb = new StringBuilder("[");
        int length = Array.getLength(arg);
        for (int i = 0; i < length; ++i) {
            if (i > 0) {
                sb.append(", ");
            }
            Object element = Array.get(arg, i);
            // toString(Object) dereferences its argument, so null elements are handled here.
            sb.append(element == null ? "null" : AiServiceMethodImplementationSupport.toString(element));
        }
        sb.append("]");
        return sb.toString();
    }

    // Creates the output parser and loads the default memory id providers, ordered by ascending priority().
    static {
        SERVICE_OUTPUT_PARSER = new QuarkusServiceOutputParser();

        Collection<DefaultMemoryIdProvider> loadedProviders = ServiceHelper.loadFactories(DefaultMemoryIdProvider.class);
        if (loadedProviders.isEmpty()) {
            DEFAULT_MEMORY_ID_PROVIDERS = Collections.emptyList();
        } else {
            List<DefaultMemoryIdProvider> ordered = new ArrayList<>(loadedProviders);
            ordered.sort(Comparator.comparingInt(DefaultMemoryIdProvider::priority));
            DEFAULT_MEMORY_ID_PROVIDERS = ordered;
        }
    }

    /**
     * Bundles everything {@code implement} needs for one invocation: the service context, the
     * metadata of the invoked method, and the call arguments.
     */
    public static class Input {
        // Package-private so that the enclosing class can read them directly.
        final QuarkusAiServiceContext context;
        final AiServiceMethodCreateInfo createInfo;
        final Object[] methodArgs;

        /**
         * @param context the AI service context
         * @param createInfo metadata of the invoked method
         * @param methodArgs arguments of the invocation; the array is stored as given, not copied
         */
        public Input(QuarkusAiServiceContext context, AiServiceMethodCreateInfo createInfo, Object[] methodArgs) {
            this.context = context;
            this.createInfo = createInfo;
            this.methodArgs = methodArgs;
        }
    }

    /**
     * Bridges an {@link AiServiceTokenStream} to a Mutiny {@link MultiEmitter}: every token is
     * emitted as an item, stream completion completes the emitter, and a stream error fails it.
     */
    private static class MultiEmitterConsumer
    implements Consumer<MultiEmitter<? super String>> {
        private final List<ChatMessage> messagesToSend;
        private final QuarkusAiServiceContext context;
        private final Object memoryId;

        /**
         * @param messagesToSend messages handed to the token stream
         * @param context the AI service context
         * @param memoryId id of the chat memory used by the stream
         */
        public MultiEmitterConsumer(List<ChatMessage> messagesToSend, QuarkusAiServiceContext context, Object memoryId) {
            this.messagesToSend = messagesToSend;
            this.context = context;
            this.memoryId = memoryId;
        }

        /** Creates the token stream, wires it to {@code em}, and starts it. */
        @Override
        public void accept(final MultiEmitter<? super String> em) {
            AiServiceTokenStream tokenStream = new AiServiceTokenStream(this.messagesToSend, (AiServiceContext) this.context, this.memoryId);
            tokenStream
                    .onNext(em::emit)
                    // The final response itself is not forwarded; completion is all the subscriber sees.
                    .onComplete(finalResponse -> em.complete())
                    .onError(em::fail)
                    .start();
        }
    }

    /**
     * Hook that receives an {@link Input} together with a function from {@link Input} to a result
     * and returns the result of the wrapped call.
     * NOTE(review): presumably the function is the actual implementation (e.g. {@code implement})
     * and implementors decide how and when to invoke it - confirm against the implementations.
     */
    public static interface Wrapper {
        /**
         * @param var1 the invocation input
         * @param var2 function producing a result for an input
         * @return the result to hand back to the caller
         */
        public Object wrap(Input var1, Function<Input, Object> var2);
    }
}

