/*
 * Decompiled with CFR 0.152.
 */
package fun.fengwk.chatjava.core.client.tool;

import fun.fengwk.chatjava.core.client.ChatClient;
import fun.fengwk.chatjava.core.client.ChatClientOptions;
import fun.fengwk.chatjava.core.client.ChatCompletionsResponse;
import fun.fengwk.chatjava.core.client.StreamChatListener;
import fun.fengwk.chatjava.core.client.request.ChatMessage;
import fun.fengwk.chatjava.core.client.request.ChatRequest;
import fun.fengwk.chatjava.core.client.response.ChatToolCall;
import fun.fengwk.chatjava.core.client.response.ChatToolCallFunction;
import fun.fengwk.chatjava.core.client.tool.ToolFunctionHandler;
import fun.fengwk.chatjava.core.client.tool.ToolFunctionHandlerRegistryView;
import fun.fengwk.chatjava.core.client.tool.ToolStreamChatListener;
import fun.fengwk.chatjava.core.client.util.ChatUtils;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ChatClient} decorator that transparently executes tool (function) calls requested by
 * the model.
 *
 * <p>Each incoming request is copied, and the copy is populated with the tools exposed by the
 * {@link ToolFunctionHandlerRegistryView}. Whenever a delegate response asks for tool calls, the
 * assistant message plus one tool-result message per function call are appended to the copied
 * request, and the delegate is invoked again.
 *
 * <p>This repeats until one of two things happens:
 * <ul>
 *   <li>the model returns a response that is not a tool call, or</li>
 *   <li>the number of round trips exceeds {@link ChatClientOptions#getMaxFunctionCallTimes()}.
 *       In that case an unsuccessful {@link ChatCompletionsResponse} wrapping an
 *       {@link IllegalStateException} is returned instead.</li>
 * </ul>
 *
 * <p>The synchronous and streaming paths share the same counting semantics. With a limit of
 * {@code N}, at most {@code N} rounds of tool execution are performed.
 */
public class ToolChatClient implements ChatClient {

    private static final Logger log = LoggerFactory.getLogger(ToolChatClient.class);

    private final ChatClient delegate;
    private final ToolFunctionHandlerRegistryView registryView;

    /**
     * Creates a tool-executing client.
     *
     * @param delegate     the underlying client that actually talks to the model; must not be null
     * @param registryView source of the tool definitions sent with each request and of the
     *                     handlers that execute them
     * @throws NullPointerException if {@code delegate} is null
     */
    public ToolChatClient(ChatClient delegate, ToolFunctionHandlerRegistryView registryView) {
        this.delegate = Objects.requireNonNull(delegate);
        this.registryView = registryView;
    }

    /** Returns the delegate's client options. */
    @Override
    public ChatClientOptions getClientOptions() {
        return delegate.getClientOptions();
    }

    /** Sets the client options on the delegate. */
    @Override
    public void setClientOptions(ChatClientOptions clientOptions) {
        delegate.setClientOptions(clientOptions);
    }

    /**
     * Performs a (blocking) chat completion, executing tool calls until the model produces a
     * non-tool-call answer or the call limit is exceeded.
     *
     * @param chatRequest   the request; it is copied and never mutated
     * @param clientOptions options passed to every delegate call; also supplies
     *                      {@code maxFunctionCallTimes}
     * @return the final delegate response, or an unsuccessful response if the limit was exceeded
     */
    @Override
    public ChatCompletionsResponse chatCompletions(ChatRequest chatRequest, ChatClientOptions clientOptions) {
        // Work on a copy so the caller's request is not polluted with tools and intermediate messages.
        ChatRequest request = chatRequest.copy();
        setTools(request);

        ChatCompletionsResponse response = delegate.chatCompletions(request, clientOptions);
        // Number of delegate responses received so far. This mirrors the counter that the
        // streaming path passes through its recursion.
        int callTimes = 1;
        while (ChatUtils.isToolCalls(response)) {
            int maxCallTimes = clientOptions.getMaxFunctionCallTimes();
            if (callTimes > maxCallTimes) {
                return handleExceedsCallTimes(response, maxCallTimes);
            }
            ChatMessage message = response.getChatResponse().getChoices().get(0).getMessage();
            appendToolCallRound(request, message);
            response = delegate.chatCompletions(request, clientOptions);
            // Without this increment the limit above is never enforced and a model that keeps
            // requesting tools would loop forever.
            callTimes++;
        }
        return response;
    }

    /**
     * Performs a streaming chat completion, executing tool calls between streaming rounds.
     *
     * @param chatRequest   the request; it is copied and never mutated
     * @param chatListener  receives streamed events; it is wrapped in a
     *                      {@link ToolStreamChatListener} for each round
     * @param clientOptions options passed to every delegate call; also supplies
     *                      {@code maxFunctionCallTimes}
     * @return a future completed with the final response, or with an unsuccessful response if
     *         the limit was exceeded
     */
    @Override
    public CompletableFuture<ChatCompletionsResponse> streamChatCompletions(ChatRequest chatRequest, StreamChatListener chatListener, ChatClientOptions clientOptions) {
        ChatRequest request = chatRequest.copy();
        setTools(request);
        return doStreamChatCompletions(request, chatListener, clientOptions, 1);
    }

    /**
     * Runs one streaming round and, if the model asked for tools, executes them and recurses.
     *
     * @param callTimes 1-based index of the delegate round being started
     */
    private CompletableFuture<ChatCompletionsResponse> doStreamChatCompletions(ChatRequest chatRequest, StreamChatListener chatListener, ChatClientOptions clientOptions, int callTimes) {
        return delegate
                .streamChatCompletions(chatRequest, new ToolStreamChatListener(chatListener), clientOptions)
                .thenCompose(resp -> {
                    if (!ChatUtils.isToolCalls(resp)) {
                        return CompletableFuture.completedStage(resp);
                    }
                    if (!resp.isSuccess()) {
                        chatListener.onError(resp.getError());
                        return CompletableFuture.completedStage(resp);
                    }
                    int maxCallTimes = clientOptions.getMaxFunctionCallTimes();
                    if (callTimes > maxCallTimes) {
                        ChatCompletionsResponse exceeded = handleExceedsCallTimes(resp, maxCallTimes);
                        chatListener.onError(exceeded.getError());
                        return CompletableFuture.completedStage(exceeded);
                    }
                    // In streaming mode the accumulated assistant message is read from the delta.
                    ChatMessage delta = resp.getChatResponse().getChoices().get(0).getDelta();
                    appendToolCallRound(chatRequest, delta);
                    return doStreamChatCompletions(chatRequest, chatListener, clientOptions, callTimes + 1);
                });
    }

    /**
     * Appends the assistant message that requested tool calls to {@code chatRequest}. It then
     * invokes the registered handler for every function-type tool call and appends each result
     * as a tool message tied to the originating tool call id. Tool calls that are not function
     * calls are skipped.
     *
     * <p>NOTE(review): {@code getHandlerRequired} presumably throws when no handler is
     * registered for the function name; that exception propagates to the caller.
     */
    private void appendToolCallRound(ChatRequest chatRequest, ChatMessage assistantMessage) {
        chatRequest.getMessages().add(assistantMessage);
        List<ChatToolCall> toolCalls = assistantMessage.getTool_calls();
        for (ChatToolCall toolCall : toolCalls) {
            if (!ChatUtils.isFunctionCall(toolCall)) {
                continue;
            }
            ChatToolCallFunction function = toolCall.getFunction();
            ToolFunctionHandler handler = registryView.getHandlerRequired(function.getName());
            String result = handler.call(function.getArguments());
            chatRequest.getMessages().add(ChatMessage.newToolMessage(toolCall.getId(), result));
        }
    }

    /** Replaces the request's tool list with the tools currently exposed by the registry view. */
    private void setTools(ChatRequest chatRequest) {
        chatRequest.setTools(registryView.getTools());
    }

    /**
     * Builds an unsuccessful response signalling that the tool-call limit was exceeded.
     *
     * <p>The original response is logged at WARN level. Its chat payload is kept, and an
     * {@link IllegalStateException} is attached as the error.
     */
    private ChatCompletionsResponse handleExceedsCallTimes(ChatCompletionsResponse resp, int maxFunctionCallTimes) {
        IllegalStateException ex = new IllegalStateException(String.format("The number of calls exceeds %d times", maxFunctionCallTimes));
        log.warn("{}, resp: {}", ex.getMessage(), resp);
        return new ChatCompletionsResponse(false, resp.getChatResponse(), ex.getMessage(), ex);
    }
}

