/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook.hip;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.InterruptableAction;
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.ModelHook;
import com.alibaba.cloud.ai.graph.agent.hook.hip.ToolConfig;
import com.alibaba.cloud.ai.graph.state.RemoveByHash;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;

@HookPositions(value={HookPosition.AFTER_MODEL})
public class HumanInTheLoopHook
extends ModelHook
implements AsyncNodeActionWithConfig,
InterruptableAction {
    private static final Logger log = LoggerFactory.getLogger(HumanInTheLoopHook.class);
    private Map<String, ToolConfig> approvalOn;

    private HumanInTheLoopHook(Builder builder) {
        this.approvalOn = new HashMap<String, ToolConfig>(builder.approvalOn);
    }

    public static Builder builder() {
        return new Builder();
    }

    public CompletableFuture<Map<String, Object>> apply(OverAllState state, RunnableConfig config) {
        return this.afterModel(state, config);
    }

    @Override
    public CompletableFuture<Map<String, Object>> afterModel(OverAllState state, RunnableConfig config) {
        Optional feedback = config.metadata("HUMAN_FEEDBACK");
        InterruptionMetadata interruptionMetadata = feedback.orElse(null);
        if (interruptionMetadata == null) {
            log.info("No human feedback found in the runnable config metadata, no tool to execute or none needs feedback.");
            return CompletableFuture.completedFuture(Map.of());
        }
        List messages = state.value("messages").orElse(List.of());
        Message lastMessage = (Message)messages.get(messages.size() - 1);
        if (lastMessage instanceof AssistantMessage) {
            AssistantMessage assistantMessage = (AssistantMessage)lastMessage;
            if (!assistantMessage.hasToolCalls()) {
                log.info("Found human feedback but last AssistantMessage has no tool calls, nothing to process for human feedback.");
                return CompletableFuture.completedFuture(Map.of());
            }
            ArrayList<AssistantMessage.ToolCall> newToolCalls = new ArrayList<AssistantMessage.ToolCall>();
            ArrayList<ToolResponseMessage.ToolResponse> responses = new ArrayList<ToolResponseMessage.ToolResponse>();
            ToolResponseMessage rejectedMessage = new ToolResponseMessage(responses);
            for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                Optional<InterruptionMetadata.ToolFeedback> toolFeedbackOpt = interruptionMetadata.toolFeedbacks().stream().filter(tf -> tf.getName().equals(toolCall.name())).findFirst();
                if (toolFeedbackOpt.isPresent()) {
                    InterruptionMetadata.ToolFeedback toolFeedback = toolFeedbackOpt.get();
                    InterruptionMetadata.ToolFeedback.FeedbackResult result = toolFeedback.getResult();
                    if (result == InterruptionMetadata.ToolFeedback.FeedbackResult.APPROVED) {
                        newToolCalls.add(toolCall);
                        continue;
                    }
                    if (result == InterruptionMetadata.ToolFeedback.FeedbackResult.EDITED) {
                        AssistantMessage.ToolCall editedToolCall = new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(), toolCall.name(), toolFeedback.getArguments());
                        newToolCalls.add(editedToolCall);
                        continue;
                    }
                    if (result != InterruptionMetadata.ToolFeedback.FeedbackResult.REJECTED) continue;
                    ToolResponseMessage.ToolResponse response = new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), String.format("Tool call request for %s has been rejected by human. The reason for why this tool is rejected and the suggestion for next possible tool choose is listed as below:\n %s.", toolFeedback.getName(), toolFeedback.getDescription()));
                    responses.add(response);
                    continue;
                }
                newToolCalls.add(toolCall);
            }
            HashMap updates = new HashMap();
            ArrayList<Object> newMessages = new ArrayList<Object>();
            if (!rejectedMessage.getResponses().isEmpty()) {
                newMessages.add(rejectedMessage);
            }
            if (!newToolCalls.isEmpty()) {
                newMessages.add(new AssistantMessage(assistantMessage.getText(), assistantMessage.getMetadata(), newToolCalls, assistantMessage.getMedia()));
                newMessages.add(new RemoveByHash((Object)assistantMessage));
            }
            updates.put("messages", newMessages);
            return CompletableFuture.completedFuture(updates);
        }
        log.warn("Last message is not an AssistantMessage, cannot process human feedback.");
        return CompletableFuture.completedFuture(Map.of());
    }

    public Optional<InterruptionMetadata> interrupt(String nodeId, OverAllState state, RunnableConfig config) {
        AssistantMessage assistantMessage;
        Optional feedback = config.metadata("HUMAN_FEEDBACK");
        if (feedback.isPresent()) {
            if (!(feedback.get() instanceof InterruptionMetadata)) {
                throw new IllegalArgumentException("Human feedback metadata must be of type InterruptionMetadata.");
            }
            if (!this.validateFeedback((InterruptionMetadata)feedback.get())) {
                return Optional.of((InterruptionMetadata)feedback.get());
            }
            return Optional.empty();
        }
        List messages = state.value("messages").orElse(List.of());
        Message lastMessage = (Message)messages.get(messages.size() - 1);
        if (lastMessage instanceof AssistantMessage && (assistantMessage = (AssistantMessage)lastMessage).hasToolCalls()) {
            boolean needsInterruption = false;
            InterruptionMetadata.Builder builder = InterruptionMetadata.builder((String)this.getName(), (OverAllState)state);
            for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                if (!this.approvalOn.containsKey(toolCall.name())) continue;
                ToolConfig toolConfig = this.approvalOn.get(toolCall.name());
                String description = toolConfig.getDescription();
                String content = "The AI is requesting to use the tool: " + toolCall.name() + ".\n" + (String)(description != null ? "Description: " + description + "\n" : "") + "With the following arguments: " + toolCall.arguments() + "\nDo you approve?";
                builder.addToolFeedback(InterruptionMetadata.ToolFeedback.builder().id(toolCall.id()).name(toolCall.name()).description(content).arguments(toolCall.arguments()).build()).build();
                needsInterruption = true;
            }
            return needsInterruption ? Optional.of(builder.build()) : Optional.empty();
        }
        return Optional.empty();
    }

    private boolean validateFeedback(InterruptionMetadata feedback) {
        if (feedback == null || feedback.toolFeedbacks() == null || feedback.toolFeedbacks().isEmpty()) {
            return false;
        }
        List toolFeedbacks = feedback.toolFeedbacks();
        for (InterruptionMetadata.ToolFeedback toolFeedback : toolFeedbacks) {
            if (toolFeedback.getResult() != null) continue;
            return false;
        }
        if (toolFeedbacks.size() != this.approvalOn.size()) {
            return false;
        }
        for (InterruptionMetadata.ToolFeedback toolFeedback : toolFeedbacks) {
            if (this.approvalOn.containsKey(toolFeedback.getName())) continue;
            return false;
        }
        return true;
    }

    @Override
    public String getName() {
        return "HIP";
    }

    @Override
    public List<JumpTo> canJumpTo() {
        return List.of();
    }

    public static class Builder {
        private Map<String, ToolConfig> approvalOn = new HashMap<String, ToolConfig>();

        public Builder approvalOn(String toolName, ToolConfig toolConfig) {
            this.approvalOn.put(toolName, toolConfig);
            return this;
        }

        public Builder approvalOn(String toolName, String description) {
            ToolConfig config = new ToolConfig();
            config.setDescription(description);
            this.approvalOn.put(toolName, config);
            return this;
        }

        public Builder approvalOn(Map<String, ToolConfig> approvalOn) {
            this.approvalOn.putAll(approvalOn);
            return this;
        }

        public HumanInTheLoopHook build() {
            return new HumanInTheLoopHook(this);
        }
    }
}

