/*
 * Decompiled with CFR 0.152.
 */
package com.github.tartaricacid.touhoulittlemaid.ai.manager.entity;

import com.github.tartaricacid.touhoulittlemaid.TouhouLittleMaid;
import com.github.tartaricacid.touhoulittlemaid.ai.manager.entity.MaidAIChatManager;
import com.github.tartaricacid.touhoulittlemaid.ai.manager.response.ResponseChat;
import com.github.tartaricacid.touhoulittlemaid.ai.service.ErrorCode;
import com.github.tartaricacid.touhoulittlemaid.ai.service.ResponseCallback;
import com.github.tartaricacid.touhoulittlemaid.ai.service.ServiceType;
import com.github.tartaricacid.touhoulittlemaid.ai.service.function.FunctionCallRegister;
import com.github.tartaricacid.touhoulittlemaid.ai.service.function.IFunctionCall;
import com.github.tartaricacid.touhoulittlemaid.ai.service.function.response.ToolResponse;
import com.github.tartaricacid.touhoulittlemaid.ai.service.llm.ChatType;
import com.github.tartaricacid.touhoulittlemaid.ai.service.llm.LLMClient;
import com.github.tartaricacid.touhoulittlemaid.ai.service.llm.LLMConfig;
import com.github.tartaricacid.touhoulittlemaid.ai.service.llm.LLMMessage;
import com.github.tartaricacid.touhoulittlemaid.ai.service.llm.openai.response.FunctionToolCall;
import com.github.tartaricacid.touhoulittlemaid.ai.service.llm.openai.response.Message;
import com.github.tartaricacid.touhoulittlemaid.ai.service.llm.openai.response.ToolCall;
import com.github.tartaricacid.touhoulittlemaid.ai.service.tts.TTSSite;
import com.github.tartaricacid.touhoulittlemaid.config.subconfig.AIConfig;
import com.github.tartaricacid.touhoulittlemaid.entity.passive.EntityMaid;
import com.google.gson.JsonObject;
import com.google.gson.JsonSyntaxException;
import com.mojang.serialization.DynamicOps;
import com.mojang.serialization.JsonOps;
import java.net.http.HttpRequest;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import net.minecraft.ChatFormatting;
import net.minecraft.network.chat.Component;
import net.minecraft.network.chat.MutableComponent;
import net.minecraft.server.MinecraftServer;
import net.minecraft.server.level.ServerLevel;
import net.minecraft.server.level.ServerPlayer;
import net.minecraft.util.GsonHelper;
import net.minecraft.world.entity.LivingEntity;
import net.minecraft.world.level.Level;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.Logger;

public class LLMCallback
implements ResponseCallback<ResponseChat> {
    private static final int MAX_CALL_COUNT = 3;
    protected final EntityMaid maid;
    protected final MaidAIChatManager chatManager;
    protected int callCount = 0;
    protected long waitingChatBubbleId;
    protected String message;

    public LLMCallback(MaidAIChatManager chatManager, String message, long waitingChatBubbleId) {
        this.maid = chatManager.getMaid();
        this.chatManager = chatManager;
        this.message = message;
        this.waitingChatBubbleId = waitingChatBubbleId;
    }

    @Override
    public void onFailure(HttpRequest request, Throwable throwable, int errorCode) {
        Level level = this.maid.f_19853_;
        if (level instanceof ServerLevel) {
            ServerLevel serverLevel = (ServerLevel)level;
            MinecraftServer server = serverLevel.m_7654_();
            server.m_18707_(() -> {
                LivingEntity patt3043$temp = this.maid.m_269323_();
                if (patt3043$temp instanceof ServerPlayer) {
                    ServerPlayer player = (ServerPlayer)patt3043$temp;
                    String cause = throwable.getLocalizedMessage();
                    MutableComponent errorMessage = ErrorCode.getErrorMessage(ServiceType.LLM, errorCode, cause);
                    player.m_213846_((Component)errorMessage.m_130940_(ChatFormatting.RED));
                }
                this.maid.getChatBubbleManager().removeChatBubble(this.waitingChatBubbleId);
            });
        }
        if (errorCode == 4) {
            TouhouLittleMaid.LOGGER.error("LLM return field is empty, error is {}", (Object)throwable.getMessage());
        } else if (errorCode == 2) {
            TouhouLittleMaid.LOGGER.error("Error in parsing LLM return JSON string, error is {}", (Object)throwable.getMessage());
        } else {
            TouhouLittleMaid.LOGGER.error("LLM request failed: {}, error is {}", (Object)request, (Object)throwable.getMessage());
        }
    }

    @Override
    public void onSuccess(ResponseChat responseChat) {
        String chatText = responseChat.getChatText();
        String ttsText = responseChat.getTtsText();
        if (chatText.isBlank() || ttsText.isBlank()) {
            String message = "Error in Response Chat: %s".formatted(responseChat);
            this.onFailure(null, new Throwable(message), 4);
        } else {
            Level level;
            if (this.callCount == 0) {
                this.chatManager.addUserHistory(this.message);
            }
            this.chatManager.addAssistantHistory(responseChat.toString());
            TTSSite site = this.chatManager.getTTSSite();
            if (((Boolean)AIConfig.TTS_ENABLED.get()).booleanValue() && site != null && site.enabled()) {
                this.chatManager.tts(site, chatText, ttsText, this.waitingChatBubbleId);
            } else if (StringUtils.isNotBlank((CharSequence)chatText) && (level = this.maid.f_19853_) instanceof ServerLevel) {
                ServerLevel serverLevel = (ServerLevel)level;
                MinecraftServer server = serverLevel.m_7654_();
                server.m_18707_(() -> this.maid.getChatBubbleManager().addLLMChatText(chatText, this.waitingChatBubbleId));
            }
        }
    }

    public void onFunctionCall(Message choice, List<LLMMessage> messages, LLMConfig config, LLMClient client) {
        if (this.callCount == 0) {
            this.chatManager.addUserHistory(this.message);
        }
        this.chatManager.addAssistantHistory("", choice.getToolCalls());
        messages.add(LLMMessage.assistantChat(this.maid, choice.getContent(), choice.getToolCalls()));
        choice.getToolCalls().forEach(toolCall -> {
            try {
                this.onSingleCall(messages, config, client, (ToolCall)toolCall);
            }
            catch (JsonSyntaxException exception) {
                String message = "Exception %s, JSON is: %s".formatted(exception.getLocalizedMessage(), toolCall.getFunction().getArguments());
                this.onFailure(null, new Throwable(message), 2);
            }
        });
    }

    private void onSingleCall(List<LLMMessage> messages, LLMConfig config, LLMClient client, ToolCall toolCall) throws JsonSyntaxException {
        FunctionToolCall function = toolCall.getFunction();
        String name = function.getName();
        String arguments = function.getArguments();
        IFunctionCall<?> functionCall = FunctionCallRegister.getFunctionCall(name);
        if (functionCall == null) {
            return;
        }
        Object result = null;
        try {
            JsonObject parse = GsonHelper.m_13864_((String)arguments);
            Optional optional = functionCall.codec().parse((DynamicOps)JsonOps.INSTANCE, (Object)parse).resultOrPartial(arg_0 -> ((Logger)TouhouLittleMaid.LOGGER).error(arg_0));
            if (optional.isEmpty()) {
                return;
            }
            result = optional.get();
        }
        catch (Exception exception) {
            String message = "Exception %s, JSON is: %s".formatted(exception.getLocalizedMessage(), arguments);
            this.onFailure(null, new Throwable(message), 2);
            return;
        }
        TouhouLittleMaid.LOGGER.debug("Use function call: {}, arguments is {}", (Object)functionCall.getId(), (Object)arguments);
        EntityMaid maid = config.maid();
        Level level = maid.f_19853_;
        if (!(level instanceof ServerLevel)) {
            return;
        }
        ServerLevel serverLevel = (ServerLevel)level;
        Object finalResult = result;
        serverLevel.m_7654_().m_18707_(() -> {
            ToolResponse toolResponse = functionCall.onToolCall(finalResult, maid);
            ++this.callCount;
            String response = toolResponse.message();
            this.chatManager.addToolHistory(response, toolCall.getId());
            messages.add(LLMMessage.toolChat(maid, response, toolCall.getId()));
            if (this.callCount >= 3) {
                TouhouLittleMaid.LOGGER.error("Function call count exceed max count: {}", (Object)3);
            } else {
                LLMConfig keepConfig = new LLMConfig(config.model(), config.maid(), ChatType.MULTI_FUNCTION_CALL);
                client.chat(messages, keepConfig, this);
            }
        });
    }
}

