/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook.hip;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.action.AsyncNodeActionWithConfig;
import com.alibaba.cloud.ai.graph.action.InterruptableAction;
import com.alibaba.cloud.ai.graph.action.InterruptionMetadata;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.ModelHook;
import com.alibaba.cloud.ai.graph.agent.hook.hip.ToolConfig;
import com.alibaba.cloud.ai.graph.state.RemoveByHash;
import com.alibaba.cloud.ai.graph.utils.TypeRef;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;

@HookPositions(value={HookPosition.AFTER_MODEL})
public class HumanInTheLoopHook
extends ModelHook
implements AsyncNodeActionWithConfig,
InterruptableAction {
    private static final Logger log = LoggerFactory.getLogger(HumanInTheLoopHook.class);
    public static final String HITL_NODE_NAME = "HITL";
    private Map<String, ToolConfig> approvalOn;

    private HumanInTheLoopHook(Builder builder) {
        this.approvalOn = new HashMap<String, ToolConfig>(builder.approvalOn);
    }

    public static Builder builder() {
        return new Builder();
    }

    public CompletableFuture<Map<String, Object>> apply(OverAllState state, RunnableConfig config) {
        return this.afterModel(state, config);
    }

    @Override
    public CompletableFuture<Map<String, Object>> afterModel(OverAllState state, RunnableConfig config) {
        Optional feedback = config.getMetadataAndRemove("HUMAN_FEEDBACK", (TypeRef)new TypeRef<InterruptionMetadata>(){});
        InterruptionMetadata interruptionMetadata = feedback.orElse(null);
        if (interruptionMetadata == null) {
            log.debug("No human feedback found in the runnable config metadata, no tool to execute or none needs feedback.");
            return CompletableFuture.completedFuture(Map.of());
        }
        AssistantMessage assistantMessage = HumanInTheLoopHook.getLastAssistantMessage(state);
        if (assistantMessage != null) {
            if (!assistantMessage.hasToolCalls()) {
                log.info("Found human feedback but last AssistantMessage has no tool calls, nothing to process for human feedback.");
                return CompletableFuture.completedFuture(Map.of());
            }
            ArrayList<AssistantMessage.ToolCall> newToolCalls = new ArrayList<AssistantMessage.ToolCall>();
            ArrayList<ToolResponseMessage.ToolResponse> responses = new ArrayList<ToolResponseMessage.ToolResponse>();
            ToolResponseMessage rejectedMessage = ToolResponseMessage.builder().responses(responses).build();
            for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
                Optional<InterruptionMetadata.ToolFeedback> toolFeedbackOpt = interruptionMetadata.toolFeedbacks().stream().filter(tf -> tf.getName().equals(toolCall.name())).findFirst();
                if (toolFeedbackOpt.isPresent()) {
                    InterruptionMetadata.ToolFeedback toolFeedback = toolFeedbackOpt.get();
                    InterruptionMetadata.ToolFeedback.FeedbackResult result = toolFeedback.getResult();
                    if (result == InterruptionMetadata.ToolFeedback.FeedbackResult.APPROVED) {
                        newToolCalls.add(toolCall);
                        continue;
                    }
                    if (result == InterruptionMetadata.ToolFeedback.FeedbackResult.EDITED) {
                        AssistantMessage.ToolCall editedToolCall = new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(), toolCall.name(), toolFeedback.getArguments());
                        newToolCalls.add(editedToolCall);
                        continue;
                    }
                    if (result != InterruptionMetadata.ToolFeedback.FeedbackResult.REJECTED) continue;
                    newToolCalls.add(toolCall);
                    ToolResponseMessage.ToolResponse response = new ToolResponseMessage.ToolResponse(toolCall.id(), toolCall.name(), String.format("Tool call request for %s has been rejected by human. The reason for why this tool is rejected and the suggestion for next possible tool choose is listed as below:\n %s.", toolFeedback.getName(), toolFeedback.getDescription()));
                    responses.add(response);
                    continue;
                }
                newToolCalls.add(toolCall);
            }
            HashMap updates = new HashMap();
            ArrayList<Object> newMessages = new ArrayList<Object>();
            if (!newToolCalls.isEmpty()) {
                newMessages.add(AssistantMessage.builder().content(assistantMessage.getText()).properties(assistantMessage.getMetadata()).toolCalls(newToolCalls).media(assistantMessage.getMedia()).build());
                newMessages.add(new RemoveByHash((Object)assistantMessage));
            }
            if (!rejectedMessage.getResponses().isEmpty()) {
                newMessages.add(rejectedMessage);
            }
            updates.put("messages", newMessages);
            return CompletableFuture.completedFuture(updates);
        }
        log.warn("Last message is not an AssistantMessage, cannot process human feedback.");
        return CompletableFuture.completedFuture(Map.of());
    }

    public Optional<InterruptionMetadata> interrupt(String nodeId, OverAllState state, RunnableConfig config) {
        AssistantMessage lastMessage = HumanInTheLoopHook.getLastAssistantMessage(state);
        if (lastMessage == null || !lastMessage.hasToolCalls()) {
            return Optional.empty();
        }
        Optional feedback = config.metadata("HUMAN_FEEDBACK");
        if (feedback.isPresent()) {
            if (!(feedback.get() instanceof InterruptionMetadata)) {
                throw new IllegalArgumentException("Human feedback metadata must be of type InterruptionMetadata.");
            }
            if (!this.validateFeedback((InterruptionMetadata)feedback.get(), lastMessage.getToolCalls())) {
                return this.buildInterruptionMetadata(state, lastMessage);
            }
            return Optional.empty();
        }
        return this.buildInterruptionMetadata(state, lastMessage);
    }

    private static AssistantMessage getLastAssistantMessage(OverAllState state) {
        List messages = state.value("messages").orElse(List.of());
        AssistantMessage lastMessage = null;
        for (int i = messages.size() - 1; i >= 0; --i) {
            Message msg = (Message)messages.get(i);
            if (!(msg instanceof AssistantMessage)) continue;
            AssistantMessage assistantMessage = (AssistantMessage)msg;
            if (i + 1 < messages.size() && messages.get(i + 1) instanceof ToolResponseMessage) break;
            lastMessage = assistantMessage;
            break;
        }
        return lastMessage;
    }

    private Optional<InterruptionMetadata> buildInterruptionMetadata(OverAllState state, AssistantMessage lastMessage) {
        boolean needsInterruption = false;
        InterruptionMetadata.Builder builder = InterruptionMetadata.builder((String)this.getName(), (OverAllState)state);
        for (AssistantMessage.ToolCall toolCall : lastMessage.getToolCalls()) {
            if (this.approvalOn.containsKey(toolCall.name())) {
                ToolConfig toolConfig = this.approvalOn.get(toolCall.name());
                String description = toolConfig.getDescription();
                String content = "The AI is requesting to use the tool: " + toolCall.name() + ".\n" + (String)(description != null ? "Description: " + description + "\n" : "") + "With the following arguments: " + toolCall.arguments() + "\nDo you approve?";
                builder.addToolFeedback(InterruptionMetadata.ToolFeedback.builder().id(toolCall.id()).name(toolCall.name()).description(content).arguments(toolCall.arguments()).build()).build();
                needsInterruption = true;
                continue;
            }
            builder.addToolsAutomaticallyApproved(toolCall);
        }
        return needsInterruption ? Optional.of(builder.build()) : Optional.empty();
    }

    private boolean validateFeedback(InterruptionMetadata feedback, List<AssistantMessage.ToolCall> toolCalls) {
        if (feedback == null || feedback.toolFeedbacks() == null || feedback.toolFeedbacks().isEmpty()) {
            return false;
        }
        List toolFeedbacks = feedback.toolFeedbacks();
        List<AssistantMessage.ToolCall> toolCallsNeedingApproval = toolCalls.stream().filter(tc -> this.approvalOn.containsKey(tc.name())).toList();
        if (toolCallsNeedingApproval.isEmpty()) {
            return true;
        }
        for (AssistantMessage.ToolCall call2 : toolCallsNeedingApproval) {
            InterruptionMetadata.ToolFeedback matchedFeedback = toolFeedbacks.stream().filter(tf -> tf.getName().equals(call2.name()) && call2.id().equals(tf.getId())).findFirst().orElse(null);
            if (matchedFeedback == null) {
                log.warn("Missing feedback for tool {} (id={}); waiting for human input.", (Object)call2.name(), (Object)call2.id());
                return false;
            }
            if (matchedFeedback.getResult() != null) continue;
            log.warn("Feedback result for tool {} (id={}) is null; waiting for human input.", (Object)call2.name(), (Object)call2.id());
            return false;
        }
        for (InterruptionMetadata.ToolFeedback tf2 : toolFeedbacks) {
            boolean matched = toolCallsNeedingApproval.stream().anyMatch(call -> call.name().equals(tf2.getName()) && call.id().equals(tf2.getId()));
            if (matched) continue;
            log.warn("Ignoring unexpected tool feedback: name={}, id={}", (Object)tf2.getName(), (Object)tf2.getId());
        }
        return true;
    }

    @Override
    public String getName() {
        return HITL_NODE_NAME;
    }

    @Override
    public List<JumpTo> canJumpTo() {
        return List.of();
    }

    public static class Builder {
        private Map<String, ToolConfig> approvalOn = new HashMap<String, ToolConfig>();

        public Builder approvalOn(String toolName, ToolConfig toolConfig) {
            this.approvalOn.put(toolName, toolConfig);
            return this;
        }

        public Builder approvalOn(String toolName, String description) {
            ToolConfig config = new ToolConfig();
            config.setDescription(description);
            this.approvalOn.put(toolName, config);
            return this;
        }

        public Builder approvalOn(Map<String, ToolConfig> approvalOn) {
            this.approvalOn.putAll(approvalOn);
            return this;
        }

        public HumanInTheLoopHook build() {
            return new HumanInTheLoopHook(this);
        }
    }
}

