/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.hook.toolcalllimit;

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.RunnableConfig;
import com.alibaba.cloud.ai.graph.agent.hook.HookPosition;
import com.alibaba.cloud.ai.graph.agent.hook.HookPositions;
import com.alibaba.cloud.ai.graph.agent.hook.JumpTo;
import com.alibaba.cloud.ai.graph.agent.hook.ModelHook;
import com.alibaba.cloud.ai.graph.agent.hook.toolcalllimit.ToolCallLimitExceededException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;

/**
 * Model hook that caps how many tool calls the model may request.
 *
 * <p>After each model response ({@link #afterModel}) the hook inspects the last entry of the
 * {@code "messages"} state value. If it is an {@link AssistantMessage}, the hook counts its tool
 * calls: all of them, or only those whose name equals the configured {@link Builder#toolName(String)}.
 * It adds that number to two counters kept in {@code config.context()}.
 *
 * <p>Before each model call ({@link #beforeModel}) the counters are compared with the configured
 * thread/run limits. Once a limit is reached the hook reacts according to {@link ExitBehavior}:
 * <ul>
 *   <li>{@link ExitBehavior#ERROR}: throws {@link ToolCallLimitExceededException}.</li>
 *   <li>{@link ExitBehavior#END}: returns an update that appends an explanatory assistant message
 *       and sets {@code "jump_to"} to {@link JumpTo#end}.</li>
 * </ul>
 *
 * <p>Instances are created through {@link #builder()}; at least one limit must be configured.
 */
@HookPositions({HookPosition.BEFORE_MODEL, HookPosition.AFTER_MODEL})
public class ToolCallLimitHook extends ModelHook {

    /** Prefix of the context key that stores the thread-level tool call count. */
    private static final String THREAD_COUNT_KEY_PREFIX = "__tool_call_limit_thread_count__";
    /** Prefix of the context key that stores the run-level tool call count. */
    private static final String RUN_COUNT_KEY_PREFIX = "__tool_call_limit_run_count__";
    /** Key suffix used when every tool call is counted (no specific tool name configured). */
    private static final String ALL_TOOLS_TRACK_KEY = "__all__";

    /** Name of the tool to count, or {@code null} to count calls to any tool. */
    private final String toolName;
    /** Maximum thread-level tool calls, or {@code null} for no thread limit. */
    private final Integer threadLimit;
    /** Maximum run-level tool calls, or {@code null} for no run limit. */
    private final Integer runLimit;
    /** What to do once a limit has been reached; never {@code null} (enforced by the builder). */
    private final ExitBehavior exitBehavior;

    private ToolCallLimitHook(Builder builder) {
        this.toolName = builder.toolName;
        this.threadLimit = builder.threadLimit;
        this.runLimit = builder.runLimit;
        this.exitBehavior = builder.exitBehavior;
    }

    /**
     * Creates a new builder for this hook.
     *
     * @return a fresh {@link Builder} whose exit behavior defaults to {@link ExitBehavior#END}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Suffix that distinguishes counters of different hook instances (per tool, or "all"). */
    private String trackKey() {
        return this.toolName != null ? this.toolName : ALL_TOOLS_TRACK_KEY;
    }

    /** Context key holding the thread-level count for this hook's tracked tool(s). */
    private String getThreadCountKey() {
        return THREAD_COUNT_KEY_PREFIX + "_" + trackKey();
    }

    /** Context key holding the run-level count for this hook's tracked tool(s). */
    private String getRunCountKey() {
        return RUN_COUNT_KEY_PREFIX + "_" + trackKey();
    }

    /**
     * Reads a counter from {@code config.context()}.
     *
     * @param config the runnable config whose context holds the counters
     * @param key the counter key
     * @return the stored count, or {@code 0} when the key is absent or mapped to {@code null}
     * @throws IllegalStateException if the stored value is not a {@link Number}
     */
    private static int readCount(RunnableConfig config, String key) {
        Object value = config.context().get(key);
        if (value == null) {
            return 0;
        }
        if (value instanceof Number) {
            return ((Number) value).intValue();
        }
        throw new IllegalStateException("Tool call counter '" + key + "' holds a non-numeric value of type "
                + value.getClass().getName());
    }

    /**
     * Checks the accumulated tool call counts against the configured limits before the model runs.
     *
     * @param state the graph state (not inspected by this hook)
     * @param config the runnable config whose context holds the counters
     * @return an empty update when no limit is reached; with {@link ExitBehavior#END}, an update
     *     containing an explanatory {@code "messages"} entry and {@code "jump_to" = JumpTo.end}
     * @throws ToolCallLimitExceededException if a limit is reached and the exit behavior is
     *     {@link ExitBehavior#ERROR}
     * @throws IllegalStateException if a counter entry in the context is not numeric
     */
    @Override
    public CompletableFuture<Map<String, Object>> beforeModel(OverAllState state, RunnableConfig config) {
        int threadCount = readCount(config, getThreadCountKey());
        int runCount = readCount(config, getRunCountKey());
        boolean threadLimitExceeded = this.threadLimit != null && threadCount >= this.threadLimit;
        boolean runLimitExceeded = this.runLimit != null && runCount >= this.runLimit;
        if (!threadLimitExceeded && !runLimitExceeded) {
            return CompletableFuture.completedFuture(Map.of());
        }
        if (this.exitBehavior == ExitBehavior.ERROR) {
            throw new ToolCallLimitExceededException(threadCount, runCount, this.threadLimit, this.runLimit, this.toolName);
        }
        // ExitBehavior.END (the builder guarantees exitBehavior is non-null): explain why the run
        // stops via an assistant message and request a jump to the end.
        String message = buildLimitExceededMessage(threadCount, runCount, this.threadLimit, this.runLimit, this.toolName);
        List<AssistantMessage> messages = new ArrayList<>();
        messages.add(new AssistantMessage(message));
        Map<String, Object> updates = new HashMap<>();
        updates.put("messages", messages);
        updates.put("jump_to", JumpTo.end);
        return CompletableFuture.completedFuture(updates);
    }

    /**
     * Counts the tool calls requested by the latest model response and adds them to the counters.
     *
     * <p>NOTE(review): both the thread and the run counter are read from and written to the same
     * {@code config.context()} map and are incremented together here. As far as this class is
     * concerned they therefore always hold the same value. If the thread-level count is meant to
     * survive across runs, it presumably needs to live in persisted state instead. Confirm against
     * the framework's context lifecycle.
     *
     * @param state the graph state; its {@code "messages"} entry is expected to be a list of messages
     * @param config the runnable config whose context holds the counters
     * @return always an empty update; the counters are updated in place on the context
     * @throws IllegalStateException if a counter entry in the context is not numeric
     */
    @Override
    public CompletableFuture<Map<String, Object>> afterModel(OverAllState state, RunnableConfig config) {
        List<?> messages = (List<?>) state.value("messages").orElse(List.of());
        if (messages.isEmpty()) {
            return CompletableFuture.completedFuture(Map.of());
        }
        Object lastMessage = messages.get(messages.size() - 1);
        int newCalls = lastMessage instanceof AssistantMessage
                ? countTrackedToolCalls((AssistantMessage) lastMessage)
                : 0;
        if (newCalls > 0) {
            String threadKey = getThreadCountKey();
            String runKey = getRunCountKey();
            config.context().put(threadKey, readCount(config, threadKey) + newCalls);
            config.context().put(runKey, readCount(config, runKey) + newCalls);
        }
        return CompletableFuture.completedFuture(Map.of());
    }

    /**
     * Counts the tool calls in {@code aiMessage} that this hook tracks.
     *
     * @param aiMessage the assistant message to inspect
     * @return all tool calls when no tool name is configured, otherwise only those whose name
     *     equals {@link #toolName}; {@code 0} when the message carries no tool call list
     */
    private int countTrackedToolCalls(AssistantMessage aiMessage) {
        List<AssistantMessage.ToolCall> toolCalls = aiMessage.getToolCalls();
        if (toolCalls == null) {
            return 0;
        }
        if (this.toolName == null) {
            return toolCalls.size();
        }
        int count = 0;
        for (AssistantMessage.ToolCall toolCall : toolCalls) {
            if (this.toolName.equals(toolCall.name())) {
                count++;
            }
        }
        return count;
    }

    /**
     * Builds the human-readable message used by {@link ExitBehavior#END}.
     *
     * @return e.g. {@code "'search' tool call limits exceeded: run limit (3/3)"}, listing every
     *     configured limit that has been reached
     */
    private String buildLimitExceededMessage(int threadCount, int runCount, Integer threadLimit, Integer runLimit, String toolName) {
        String toolDesc = toolName != null ? "'" + toolName + "' tool call" : "Tool call";
        List<String> exceededLimits = new ArrayList<>();
        if (threadLimit != null && threadCount >= threadLimit) {
            exceededLimits.add(String.format("thread limit (%d/%d)", threadCount, threadLimit));
        }
        if (runLimit != null && runCount >= runLimit) {
            exceededLimits.add(String.format("run limit (%d/%d)", runCount, runLimit));
        }
        return toolDesc + " limits exceeded: " + String.join(", ", exceededLimits);
    }

    /**
     * Returns the hook name.
     *
     * @return {@code "ToolCallLimit[<toolName>]"}, or {@code "ToolCallLimit[all]"} when every tool is counted
     */
    @Override
    public String getName() {
        return this.toolName != null ? "ToolCallLimit[" + this.toolName + "]" : "ToolCallLimit[all]";
    }

    /**
     * Declares the jump targets this hook may request.
     *
     * @return {@code [JumpTo.end]} for {@link ExitBehavior#END}, otherwise an empty list
     */
    @Override
    public List<JumpTo> canJumpTo() {
        if (this.exitBehavior == ExitBehavior.END) {
            return List.of(JumpTo.end);
        }
        return List.of();
    }

    /** Builder for {@link ToolCallLimitHook}. */
    public static class Builder {
        private String toolName;
        private Integer threadLimit;
        private Integer runLimit;
        private ExitBehavior exitBehavior = ExitBehavior.END;

        /**
         * Restricts counting to calls of a single tool.
         *
         * @param toolName the tool name to count, or {@code null} to count calls to any tool
         * @return this builder
         */
        public Builder toolName(String toolName) {
            this.toolName = toolName;
            return this;
        }

        /**
         * Sets the thread-level limit.
         *
         * @param threadLimit the maximum thread-level count, or {@code null} for no limit
         * @return this builder
         */
        public Builder threadLimit(Integer threadLimit) {
            this.threadLimit = threadLimit;
            return this;
        }

        /**
         * Sets the run-level limit.
         *
         * @param runLimit the maximum run-level count, or {@code null} for no limit
         * @return this builder
         */
        public Builder runLimit(Integer runLimit) {
            this.runLimit = runLimit;
            return this;
        }

        /**
         * Sets what happens once a limit is reached (default {@link ExitBehavior#END}).
         *
         * @param exitBehavior the exit behavior; must not be {@code null} by the time {@link #build()} runs
         * @return this builder
         */
        public Builder exitBehavior(ExitBehavior exitBehavior) {
            this.exitBehavior = exitBehavior;
            return this;
        }

        /**
         * Creates the hook.
         *
         * @return a new {@link ToolCallLimitHook}
         * @throws IllegalArgumentException if neither limit is set, or if the exit behavior is
         *     {@code null} (a null behavior would make the hook silently ignore exceeded limits)
         */
        public ToolCallLimitHook build() {
            if (this.threadLimit == null && this.runLimit == null) {
                throw new IllegalArgumentException("At least one limit must be specified (threadLimit or runLimit)");
            }
            if (this.exitBehavior == null) {
                throw new IllegalArgumentException("exitBehavior must not be null");
            }
            return new ToolCallLimitHook(this);
        }
    }

    /** Reaction of the hook once a tool call limit has been reached. */
    public enum ExitBehavior {
        /** Append an explanatory assistant message and request a jump to {@link JumpTo#end}. */
        END,
        /** Throw {@link ToolCallLimitExceededException}. */
        ERROR
    }
}

