/*
 * Decompiled with CFR 0.152.
 */
package com.alibaba.cloud.ai.graph.agent.interceptor.contextediting;

import com.alibaba.cloud.ai.graph.agent.hook.TokenCounter;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelCallHandler;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelInterceptor;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelRequest;
import com.alibaba.cloud.ai.graph.agent.interceptor.ModelResponse;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;

/**
 * Model interceptor that shrinks the conversation sent to the model once it grows past a token
 * threshold, by replacing the payload of older tool messages with a short placeholder.
 *
 * <p>When the configured {@link TokenCounter} reports more than {@code trigger} tokens for the
 * request messages, the interceptor collects the eligible tool messages in conversation order,
 * leaves the most recent {@code keep} of them untouched, and clears older ones (oldest first)
 * until at least {@code clearAtLeast} estimated tokens have been removed. A {@code clearAtLeast}
 * of zero or less clears every eligible message.
 *
 * <p>Tool results ({@link ToolResponseMessage}) are always eligible. Tool-call arguments in
 * {@link AssistantMessage}s are eligible only when {@code clearToolInputs} is enabled. A message
 * is skipped if it already carries the placeholder or references a tool in {@code excludeTools}.
 */
public class ContextEditingInterceptor extends ModelInterceptor {
    private static final Logger log = LoggerFactory.getLogger(ContextEditingInterceptor.class);
    private static final String DEFAULT_PLACEHOLDER = "[cleared]";

    /** Token count (as reported by {@link #tokenCounter}) above which clearing starts. */
    private final int trigger;
    /** Minimum estimated tokens to clear once triggered; {@code <= 0} clears all eligible messages. */
    private final int clearAtLeast;
    /** Number of most recent eligible messages that are never cleared. */
    private final int keep;
    /** Whether assistant tool-call arguments are cleared in addition to tool results. */
    private final boolean clearToolInputs;
    /** Tool names whose calls/results are never cleared. */
    private final Set<String> excludeTools;
    /** Text substituted for cleared tool results and tool-call arguments. */
    private final String placeholder;
    /** Counter used to decide whether the request exceeds {@link #trigger}. */
    private final TokenCounter tokenCounter;

    private ContextEditingInterceptor(Builder builder) {
        // Fail fast: a negative keep would otherwise surface as an IndexOutOfBoundsException
        // from subList only when the interceptor first triggers.
        if (builder.keep < 0) {
            throw new IllegalArgumentException("keep must be >= 0, got " + builder.keep);
        }
        this.trigger = builder.trigger;
        this.clearAtLeast = builder.clearAtLeast;
        this.keep = builder.keep;
        this.clearToolInputs = builder.clearToolInputs;
        this.excludeTools = builder.excludeTools != null ? new HashSet<>(builder.excludeTools) : new HashSet<>();
        this.placeholder = Objects.requireNonNull(builder.placeholder, "placeholder");
        this.tokenCounter = Objects.requireNonNull(builder.tokenCounter, "tokenCounter");
    }

    /**
     * Creates a builder pre-populated with the default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Forwards the request to {@code handler}. If the token limit is exceeded, older tool
     * payloads are first replaced by the placeholder.
     *
     * @param request the incoming model request; it is not mutated
     * @param handler the next handler in the chain
     * @return the response produced by {@code handler}
     */
    @Override
    public ModelResponse interceptModel(ModelRequest request, ModelCallHandler handler) {
        List<Message> messages = new ArrayList<>(request.getMessages());
        int tokens = this.tokenCounter.countTokens(messages);
        if (tokens <= this.trigger) {
            return handler.call(request);
        }
        log.info("Token count {} exceeds trigger {}, clearing tool results", tokens, this.trigger);
        List<ClearableToolMessage> candidates = this.findClearableCandidates(messages);
        if (candidates.isEmpty()) {
            log.debug("No tool messages to clear");
            return handler.call(request);
        }
        int clearedTokens = 0;
        Set<Integer> indicesToClear = new HashSet<>();
        for (ClearableToolMessage candidate : candidates) {
            indicesToClear.add(candidate.index);
            clearedTokens += candidate.estimatedTokens;
            // Check after adding so that a non-positive clearAtLeast means "clear all eligible"
            // instead of "clear nothing".
            if (this.clearAtLeast > 0 && clearedTokens >= this.clearAtLeast) {
                break;
            }
        }
        List<Message> updatedMessages = new ArrayList<>(messages.size());
        for (int i = 0; i < messages.size(); ++i) {
            Message msg = messages.get(i);
            updatedMessages.add(indicesToClear.contains(i) ? this.clearMessage(msg) : msg);
        }
        log.info("Cleared approximately {} tokens from {} tool messages", clearedTokens, indicesToClear.size());
        ModelRequest updatedRequest = ModelRequest.builder(request).messages(updatedMessages).build();
        return handler.call(updatedRequest);
    }

    /**
     * Builds a copy of {@code msg} whose tool payloads are replaced by the placeholder.
     *
     * <p>For a tool response, each response's data is replaced and its id and name are kept. For
     * an assistant message, each tool call's arguments are replaced and the text, metadata, and
     * call ids, types, and names are kept. Any other message type is returned unchanged.
     */
    private Message clearMessage(Message msg) {
        if (msg instanceof ToolResponseMessage) {
            ToolResponseMessage toolMsg = (ToolResponseMessage) msg;
            List<ToolResponseMessage.ToolResponse> clearedResponses = new ArrayList<>();
            for (ToolResponseMessage.ToolResponse resp : toolMsg.getResponses()) {
                clearedResponses.add(new ToolResponseMessage.ToolResponse(resp.id(), resp.name(), this.placeholder));
            }
            return ToolResponseMessage.builder().responses(clearedResponses).metadata(toolMsg.getMetadata()).build();
        }
        if (msg instanceof AssistantMessage) {
            AssistantMessage assistantMsg = (AssistantMessage) msg;
            List<AssistantMessage.ToolCall> clearedToolCalls = new ArrayList<>();
            for (AssistantMessage.ToolCall toolCall : assistantMsg.getToolCalls()) {
                clearedToolCalls.add(new AssistantMessage.ToolCall(toolCall.id(), toolCall.type(), toolCall.name(), this.placeholder));
            }
            return AssistantMessage.builder()
                    .content(assistantMsg.getText())
                    .properties(assistantMsg.getMetadata())
                    .toolCalls(clearedToolCalls)
                    .build();
        }
        return msg;
    }

    /**
     * Lists the messages that may be cleared, in conversation order, excluding the most recent
     * {@link #keep} eligible ones.
     *
     * @param messages the full message list
     * @return clearable candidates (oldest first) with their estimated token counts
     */
    private List<ClearableToolMessage> findClearableCandidates(List<Message> messages) {
        // Per-message estimates use the approximate counter, matching the "approximately" log.
        TokenCounter estimator = TokenCounter.approximateMsgCounter();
        List<ClearableToolMessage> candidates = new ArrayList<>();
        for (int i = 0; i < messages.size(); ++i) {
            Message msg = messages.get(i);
            if (msg instanceof ToolResponseMessage) {
                ToolResponseMessage toolMsg = (ToolResponseMessage) msg;
                if (this.isClearableToolResponse(toolMsg)) {
                    candidates.add(new ClearableToolMessage(i, estimator.countTokens(List.of(toolMsg))));
                }
            } else if (this.clearToolInputs && msg instanceof AssistantMessage) {
                AssistantMessage assistantMsg = (AssistantMessage) msg;
                if (this.isClearableToolCallMessage(assistantMsg)) {
                    candidates.add(new ClearableToolMessage(i, estimator.countTokens(List.of(assistantMsg))));
                }
            }
        }
        if (candidates.size() <= this.keep) {
            return new ArrayList<>();
        }
        return new ArrayList<>(candidates.subList(0, candidates.size() - this.keep));
    }

    /**
     * Returns true unless any response already holds the placeholder or belongs to an excluded
     * tool.
     */
    private boolean isClearableToolResponse(ToolResponseMessage toolMsg) {
        for (ToolResponseMessage.ToolResponse resp : toolMsg.getResponses()) {
            if (this.placeholder.equals(resp.responseData()) || this.excludeTools.contains(resp.name())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns true when the assistant message has tool calls, none of which already holds the
     * placeholder or targets an excluded tool.
     */
    private boolean isClearableToolCallMessage(AssistantMessage assistantMsg) {
        List<AssistantMessage.ToolCall> toolCalls = assistantMsg.getToolCalls();
        if (toolCalls == null || toolCalls.isEmpty()) {
            return false;
        }
        for (AssistantMessage.ToolCall toolCall : toolCalls) {
            if (this.placeholder.equals(toolCall.arguments()) || this.excludeTools.contains(toolCall.name())) {
                return false;
            }
        }
        return true;
    }

    @Override
    public String getName() {
        return "ContextEditing";
    }

    /**
     * Builder for {@link ContextEditingInterceptor}.
     *
     * <p>Defaults are as follows:
     * <ul>
     *   <li>{@code trigger = 100000}</li>
     *   <li>{@code clearAtLeast = 0} (clear all eligible messages)</li>
     *   <li>{@code keep = 3}</li>
     *   <li>{@code clearToolInputs = false}</li>
     *   <li>no excluded tools</li>
     *   <li>placeholder {@code "[cleared]"}</li>
     *   <li>the approximate message token counter</li>
     * </ul>
     */
    public static class Builder {
        private int trigger = 100000;
        private int clearAtLeast = 0;
        private int keep = 3;
        private boolean clearToolInputs = false;
        private Set<String> excludeTools;
        private String placeholder = DEFAULT_PLACEHOLDER;
        private TokenCounter tokenCounter = TokenCounter.approximateMsgCounter();

        /** Sets the token count above which clearing starts. */
        public Builder trigger(int trigger) {
            this.trigger = trigger;
            return this;
        }

        /** Sets the minimum estimated tokens to clear; {@code <= 0} clears all eligible messages. */
        public Builder clearAtLeast(int clearAtLeast) {
            this.clearAtLeast = clearAtLeast;
            return this;
        }

        /** Sets how many of the most recent eligible messages are preserved; must be {@code >= 0}. */
        public Builder keep(int keep) {
            this.keep = keep;
            return this;
        }

        /** Enables clearing of assistant tool-call arguments in addition to tool results. */
        public Builder clearToolInputs(boolean clearToolInputs) {
            this.clearToolInputs = clearToolInputs;
            return this;
        }

        /** Sets tool names whose calls and results are never cleared; the set is copied at build time. */
        public Builder excludeTools(Set<String> excludeTools) {
            this.excludeTools = excludeTools;
            return this;
        }

        /** Sets tool names whose calls and results are never cleared. */
        public Builder excludeTools(String ... toolNames) {
            this.excludeTools = new HashSet<>(Arrays.asList(toolNames));
            return this;
        }

        /** Sets the replacement text for cleared payloads; must not be null. */
        public Builder placeholder(String placeholder) {
            this.placeholder = placeholder;
            return this;
        }

        /** Sets the counter used to compare the request size against the trigger; must not be null. */
        public Builder tokenCounter(TokenCounter tokenCounter) {
            this.tokenCounter = tokenCounter;
            return this;
        }

        /**
         * Builds the interceptor.
         *
         * @throws IllegalArgumentException if {@code keep} is negative
         * @throws NullPointerException if {@code placeholder} or {@code tokenCounter} is null
         */
        public ContextEditingInterceptor build() {
            return new ContextEditingInterceptor(this);
        }
    }

    /** Index of a clearable message in the request list plus its estimated token cost. */
    private static class ClearableToolMessage {
        final int index;
        final int estimatedTokens;

        ClearableToolMessage(int index, int estimatedTokens) {
            this.index = index;
            this.estimatedTokens = estimatedTokens;
        }
    }
}

