/*
 * Decompiled with CFR 0.152.
 */
package io.quarkiverse.langchain4j.runtime.devui;

import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.agent.tool.ToolSpecification;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.memory.chat.ChatMemoryProvider;
import dev.langchain4j.model.StreamingResponseHandler;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.chat.StreamingChatLanguageModel;
import dev.langchain4j.model.output.FinishReason;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.model.output.TokenUsage;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.service.tool.ToolExecutor;
import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.service.tool.ToolProviderRequest;
import dev.langchain4j.service.tool.ToolProviderResult;
import io.quarkiverse.langchain4j.runtime.ToolsRecorder;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatMessagePojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ChatResultPojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionRequestPojo;
import io.quarkiverse.langchain4j.runtime.devui.json.ToolExecutionResultPojo;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutor;
import io.quarkiverse.langchain4j.runtime.tool.QuarkusToolExecutorFactory;
import io.quarkiverse.langchain4j.runtime.tool.ToolMethodCreateInfo;
import io.quarkus.arc.All;
import io.quarkus.arc.Arc;
import io.quarkus.logging.Log;
import io.smallrye.mutiny.Multi;
import io.smallrye.mutiny.infrastructure.Infrastructure;
import io.smallrye.mutiny.subscription.MultiEmitter;
import io.vertx.core.json.JsonObject;
import jakarta.enterprise.context.control.ActivateRequestContext;
import java.lang.annotation.Annotation;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
 * JSON-RPC service that lets a chat UI (presumably the Dev UI chat page, given the {@code devui} package) talk to
 * the application's first configured {@link ChatLanguageModel} and, when present, its first
 * {@link StreamingChatLanguageModel}.
 * <p>
 * The conversation is kept in a single "current" {@link ChatMemory}. Optional features:
 * <ul>
 * <li>augments user messages through a {@link RetrievalAugmentor};</li>
 * <li>lets the model call tools, either discovered at build time through {@link ToolsRecorder} or resolved per
 * request from a {@link ToolProvider}.</li>
 * </ul>
 * If a request fails, the memory is rolled back to the snapshot taken before that request.
 */
@ActivateRequestContext
public class ChatJsonRPCService {

    /**
     * Maximum number of model calls made for a single chat request while tools are available; once exhausted, the
     * request is aborted with an error.
     */
    public static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 20;

    private static final String TOO_MANY_TOOL_EXECUTIONS = "Something is wrong, exceeded "
            + MAX_SEQUENTIAL_TOOL_EXECUTIONS + " sequential tool executions";

    private final ChatLanguageModel model;
    private final Optional<StreamingChatLanguageModel> streamingModel;
    private final ChatMemoryProvider memoryProvider;
    private final RetrievalAugmentor retrievalAugmentor;
    /** Build-time tools; populated only when no {@link ToolProvider} is available, never mutated afterwards. */
    private final List<ToolSpecification> toolSpecifications;
    /** Executors for {@link #toolSpecifications}, keyed by tool name; never mutated after construction. */
    private final Map<String, ToolExecutor> toolExecutors;
    private final ToolProvider toolProvider;
    private final AtomicReference<ChatMemory> currentMemory = new AtomicReference<>();
    private final AtomicLong currentMemoryId = new AtomicLong();

    /**
     * Wires the service from the CDI beans available in the application.
     *
     * @param models all chat models; the first one is used
     * @param streamingModels all streaming chat models; the first one (if any) enables streaming chat
     * @param retrievalAugmentorSuppliers suppliers consulted first for a non-null retrieval augmentor
     * @param retrievalAugmentors augmentor beans consulted when no supplier yields one
     * @param memoryProvider creates the chat memory used by {@link #reset(String)}
     * @param toolExecutorFactory creates executors for build-time tool methods
     * @param toolProvidersSuppliers suppliers consulted first for a non-null tool provider
     * @param toolProviders tool provider beans consulted when no supplier yields one
     * @throws IllegalStateException if a build-time tool class cannot be loaded
     */
    public ChatJsonRPCService(@All List<ChatLanguageModel> models,
            @All List<StreamingChatLanguageModel> streamingModels,
            @All List<Supplier<RetrievalAugmentor>> retrievalAugmentorSuppliers,
            @All List<RetrievalAugmentor> retrievalAugmentors,
            ChatMemoryProvider memoryProvider,
            QuarkusToolExecutorFactory toolExecutorFactory,
            @All List<Supplier<ToolProvider>> toolProvidersSuppliers,
            @All List<ToolProvider> toolProviders) {
        this.model = models.get(0);
        this.toolProvider = firstNonNull(toolProvidersSuppliers, toolProviders);
        this.streamingModel = streamingModels.isEmpty() ? Optional.empty() : Optional.of(streamingModels.get(0));
        this.retrievalAugmentor = firstNonNull(retrievalAugmentorSuppliers, retrievalAugmentors);
        this.memoryProvider = memoryProvider;

        Map<String, List<ToolMethodCreateInfo>> toolsMetadata = ToolsRecorder.getMetadata();
        if (this.toolProvider == null && toolsMetadata != null) {
            List<ToolSpecification> specifications = new ArrayList<>();
            Map<String, ToolExecutor> executors = new HashMap<>();
            for (Map.Entry<String, List<ToolMethodCreateInfo>> entry : toolsMetadata.entrySet()) {
                if (entry.getValue().isEmpty()) {
                    continue;
                }
                // One bean lookup per tool class, shared by all of that class's tool methods.
                Object objectWithTool = lookupToolBean(entry.getKey());
                for (ToolMethodCreateInfo methodCreateInfo : entry.getValue()) {
                    QuarkusToolExecutor.Context executorContext = new QuarkusToolExecutor.Context(objectWithTool,
                            methodCreateInfo.invokerClassName(), methodCreateInfo.methodName(),
                            methodCreateInfo.argumentMapperClassName(), methodCreateInfo.executionModel());
                    ToolSpecification specification = methodCreateInfo.toolSpecification();
                    executors.put(specification.name(), toolExecutorFactory.create(executorContext));
                    specifications.add(specification);
                }
            }
            this.toolSpecifications = specifications;
            this.toolExecutors = executors;
        } else {
            // Either there are no build-time tools, or tools are resolved per request from the ToolProvider.
            this.toolSpecifications = List.of();
            this.toolExecutors = Map.of();
        }
    }

    /**
     * Clears the current conversation and starts a new one under a fresh random memory id.
     *
     * @param systemMessage optional system message; added to the new memory when non-null and non-empty
     * @return always {@code "OK"}
     */
    public String reset(String systemMessage) {
        ChatMemory previous = this.currentMemory.get();
        if (previous != null) {
            previous.clear();
        }
        long memoryId = ThreadLocalRandom.current().nextLong();
        this.currentMemoryId.set(memoryId);
        ChatMemory memory = this.memoryProvider.get(memoryId);
        this.currentMemory.set(memory);
        if (systemMessage != null && !systemMessage.isEmpty()) {
            memory.add(new SystemMessage(systemMessage));
        }
        return "OK";
    }

    /**
     * @return {@code true} if a streaming chat model is available, i.e. {@link #streamingChat} can be used
     */
    public boolean isStreamingChatSupported() {
        return this.streamingModel.isPresent();
    }

    /**
     * Sends {@code message} to the streaming model and streams the reply.
     * <p>
     * Emitted JSON objects carry one of these keys:
     * <ul>
     * <li>{@code augmentedMessage} — the RAG-augmented prompt, when augmentation ran;</li>
     * <li>{@code token} — one streamed token;</li>
     * <li>{@code toolExecutionRequest} / {@code toolExecutionResult} — around each tool call;</li>
     * <li>{@code message} — the final answer, after which the stream completes.</li>
     * </ul>
     * On any failure, whether synchronous or raised from a model/tool callback, the memory is restored to its
     * state before this call and the stream fails with the cause.
     *
     * @param message the user's message
     * @param ragEnabled whether to run the retrieval augmentor (if one is configured)
     * @return a stream subscribed on the default worker pool
     */
    public Multi<JsonObject> streamingChat(String message, boolean ragEnabled) {
        final ChatMemory memory = this.currentOrNewMemory();
        final Object memoryId = this.currentMemoryId.get();
        final List<ChatMessage> chatMemoryBackup = new ArrayList<>(memory.messages());
        Multi<JsonObject> stream = Multi.createFrom().<JsonObject> emitter(em -> {
            // Single failure path so async errors (model callbacks, tool execution) also fail the stream.
            Consumer<Throwable> onFailure = t -> {
                restoreMemory(memory, chatMemoryBackup);
                Log.warn("Streaming chat request failed", t);
                em.fail(t);
            };
            try {
                UserMessage userMessage = UserMessage.from(message);
                ChatMessage messageToSend = this.prepareUserMessage(memory, memoryId, userMessage, ragEnabled);
                memory.add(messageToSend);
                if (this.isRagActive(ragEnabled)) {
                    em.emit(new JsonObject().put("augmentedMessage", messageToSend.text()));
                }
                StreamingChatLanguageModel streaming = this.streamingModel.orElseThrow(IllegalStateException::new);
                RequestTools tools = this.resolveTools(memoryId, userMessage);
                if (tools.specifications.isEmpty()) {
                    streaming.generate(memory.messages(), new StreamingResponseHandler<AiMessage>() {
                        @Override
                        public void onNext(String token) {
                            em.emit(new JsonObject().put("token", token));
                        }

                        @Override
                        public void onComplete(Response<AiMessage> response) {
                            try {
                                AiMessage aiMessage = response.content();
                                memory.add(aiMessage);
                                em.emit(new JsonObject().put("message", aiMessage.text()));
                                em.complete();
                            } catch (Throwable t) {
                                onFailure.accept(t);
                            }
                        }

                        @Override
                        public void onError(Throwable error) {
                            onFailure.accept(error);
                        }
                    });
                } else {
                    this.executeWithToolsAndStreaming(memory, memoryId, tools, em, onFailure,
                            MAX_SEQUENTIAL_TOOL_EXECUTIONS);
                }
            } catch (Throwable t) {
                onFailure.accept(t);
            }
        });
        return stream.runSubscriptionOn(Infrastructure.getDefaultWorkerPool());
    }

    /**
     * Sends {@code message} to the (blocking) chat model, running tool calls as requested by the model.
     *
     * @param message the user's message
     * @param ragEnabled whether to run the retrieval augmentor (if one is configured)
     * @return the whole conversation on success; on failure, a result with no messages and the error message, after
     *         restoring the memory to its state before this call
     */
    public ChatResultPojo chat(String message, boolean ragEnabled) {
        final ChatMemory memory = this.currentOrNewMemory();
        final Object memoryId = this.currentMemoryId.get();
        final List<ChatMessage> chatMemoryBackup = new ArrayList<>(memory.messages());
        try {
            UserMessage userMessage = UserMessage.from(message);
            memory.add(this.prepareUserMessage(memory, memoryId, userMessage, ragEnabled));
            RequestTools tools = this.resolveTools(memoryId, userMessage);
            if (tools.specifications.isEmpty()) {
                Response<AiMessage> modelResponse = this.model.generate(memory.messages());
                memory.add(modelResponse.content());
            } else {
                this.executeWithTools(memory, memoryId, tools);
            }
            List<ChatMessagePojo> response = ChatMessagePojo.listFromMemory(memory);
            return new ChatResultPojo(response, null);
        } catch (Throwable t) {
            restoreMemory(memory, chatMemoryBackup);
            Log.warn("Chat request failed", t);
            return new ChatResultPojo(null, t.getMessage());
        }
    }

    /**
     * Runs the blocking model/tool loop: call the model, execute any requested tools, feed their results back, and
     * repeat until the model answers without tool requests.
     *
     * @throws RuntimeException if the model still requests tools after {@link #MAX_SEQUENTIAL_TOOL_EXECUTIONS} calls
     */
    private void executeWithTools(ChatMemory memory, Object memoryId, RequestTools tools) {
        for (int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS; executionsLeft > 0; executionsLeft--) {
            AiMessage aiMessage = this.model.generate(memory.messages(), tools.specifications).content();
            memory.add(aiMessage);
            if (!aiMessage.hasToolExecutionRequests()) {
                return;
            }
            for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                memory.add(executeTool(toolExecutionRequest, tools.executors, memoryId));
            }
        }
        throw new RuntimeException(TOO_MANY_TOOL_EXECUTIONS);
    }

    /**
     * Streaming counterpart of {@link #executeWithTools}: one model call per invocation, recursing after tool results
     * are added. Every error is routed to {@code onFailure} so the emitter is always terminated.
     *
     * @param executionsLeft remaining model calls allowed for this request
     * @throws RuntimeException (synchronously) if no calls are left; callers route it to {@code onFailure}
     */
    private void executeWithToolsAndStreaming(final ChatMemory memory, final Object memoryId, final RequestTools tools,
            final MultiEmitter<? super JsonObject> em, final Consumer<Throwable> onFailure, final int executionsLeft) {
        if (executionsLeft == 0) {
            throw new RuntimeException(TOO_MANY_TOOL_EXECUTIONS);
        }
        this.streamingModel.get().generate(memory.messages(), tools.specifications,
                new StreamingResponseHandler<AiMessage>() {
                    @Override
                    public void onNext(String token) {
                        em.emit(new JsonObject().put("token", token));
                    }

                    @Override
                    public void onComplete(Response<AiMessage> response) {
                        // Tools may block, so run them off the model's callback thread.
                        Infrastructure.getDefaultExecutor().execute(() -> {
                            try {
                                AiMessage aiMessage = response.content();
                                memory.add(aiMessage);
                                if (!aiMessage.hasToolExecutionRequests()) {
                                    em.emit(new JsonObject().put("message", aiMessage.text()));
                                    em.complete();
                                    return;
                                }
                                for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
                                    ToolExecutionRequestPojo requestPojo = new ToolExecutionRequestPojo(
                                            toolExecutionRequest.id(), toolExecutionRequest.name(),
                                            toolExecutionRequest.arguments());
                                    em.emit(new JsonObject().put("toolExecutionRequest", requestPojo));
                                    ToolExecutionResultMessage resultMessage = executeTool(toolExecutionRequest,
                                            tools.executors, memoryId);
                                    memory.add(resultMessage);
                                    ToolExecutionResultPojo resultPojo = new ToolExecutionResultPojo(
                                            resultMessage.id(), resultMessage.toolName(), resultMessage.text());
                                    em.emit(new JsonObject().put("toolExecutionResult", resultPojo));
                                }
                                executeWithToolsAndStreaming(memory, memoryId, tools, em, onFailure,
                                        executionsLeft - 1);
                            } catch (Throwable t) {
                                onFailure.accept(t);
                            }
                        });
                    }

                    @Override
                    public void onError(Throwable error) {
                        onFailure.accept(error);
                    }
                });
    }

    /**
     * Executes one tool call and wraps its result as a memory message.
     *
     * @throws IllegalStateException if the model requested a tool that has no registered executor
     */
    private static ToolExecutionResultMessage executeTool(ToolExecutionRequest request,
            Map<String, ToolExecutor> executors, Object memoryId) {
        ToolExecutor toolExecutor = executors.get(request.name());
        if (toolExecutor == null) {
            throw new IllegalStateException("No executor registered for tool '" + request.name() + "'");
        }
        String toolExecutionResult = toolExecutor.execute(request, memoryId);
        return ToolExecutionResultMessage.from(request, toolExecutionResult);
    }

    /**
     * Returns the tools for one request: the build-time tools, or, when a {@link ToolProvider} exists, a fresh set
     * obtained from it. Building a new set per request keeps provider tools from leaking between requests.
     */
    private RequestTools resolveTools(Object memoryId, UserMessage userMessage) {
        if (this.toolProvider == null) {
            return new RequestTools(this.toolSpecifications, this.toolExecutors);
        }
        ToolProviderResult toolsResult = this.toolProvider.provideTools(new ToolProviderRequest(memoryId, userMessage));
        List<ToolSpecification> specifications = new ArrayList<>();
        Map<String, ToolExecutor> executors = new HashMap<>();
        if (toolsResult != null) {
            for (Map.Entry<ToolSpecification, ToolExecutor> tool : toolsResult.tools().entrySet()) {
                specifications.add(tool.getKey());
                executors.put(tool.getKey().name(), tool.getValue());
            }
        }
        return new RequestTools(specifications, executors);
    }

    /** @return the message to store for this turn: the RAG-augmented message when augmentation is active. */
    private ChatMessage prepareUserMessage(ChatMemory memory, Object memoryId, UserMessage userMessage,
            boolean ragEnabled) {
        if (!this.isRagActive(ragEnabled)) {
            return userMessage;
        }
        Metadata metadata = Metadata.from(userMessage, memoryId, memory.messages());
        return this.retrievalAugmentor.augment(new AugmentationRequest(userMessage, metadata)).chatMessage();
    }

    /** @return whether the caller asked for RAG and an augmentor is configured */
    private boolean isRagActive(boolean ragEnabled) {
        return ragEnabled && this.retrievalAugmentor != null;
    }

    /** @return the current memory, creating one via {@link #reset(String)} with no system message if needed */
    private ChatMemory currentOrNewMemory() {
        ChatMemory memory = this.currentMemory.get();
        if (memory == null) {
            this.reset("");
            memory = this.currentMemory.get();
        }
        return memory;
    }

    /** Replaces the memory's contents with {@code snapshot}. */
    private static void restoreMemory(ChatMemory memory, List<ChatMessage> snapshot) {
        memory.clear();
        snapshot.forEach(memory::add);
    }

    /** Loads {@code className} with the context class loader and returns its CDI bean instance. */
    private static Object lookupToolBean(String className) {
        Class<?> toolClass;
        try {
            toolClass = Thread.currentThread().getContextClassLoader().loadClass(className);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("Unable to load tool class " + className, e);
        }
        return Arc.container().select(toolClass).get();
    }

    /**
     * @return the first non-null value produced by {@code suppliers} (each called at most once), else the first
     *         non-null element of {@code instances}, else {@code null}
     */
    private static <T> T firstNonNull(List<Supplier<T>> suppliers, List<T> instances) {
        for (Supplier<T> supplier : suppliers) {
            T candidate = supplier.get();
            if (candidate != null) {
                return candidate;
            }
        }
        for (T candidate : instances) {
            if (candidate != null) {
                return candidate;
            }
        }
        return null;
    }

    /** Immutable tool set used for a single chat request. */
    private static final class RequestTools {
        final List<ToolSpecification> specifications;
        final Map<String, ToolExecutor> executors;

        RequestTools(List<ToolSpecification> specifications, Map<String, ToolExecutor> executors) {
            this.specifications = specifications;
            this.executors = executors;
        }
    }
}

