/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j.agent;

import java.util.Collection;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncCommandAction;
import org.bsc.langgraph4j.action.AsyncNodeActionWithConfig;
import org.bsc.langgraph4j.action.InterruptableAction;
import org.bsc.langgraph4j.action.InterruptionMetadata;
import org.bsc.langgraph4j.prebuilt.MessagesState;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.Channel;
import org.bsc.langgraph4j.utils.EdgeMappings;

public interface AgentEx {
    /** Label that {@code shouldContinueEdge} returns to route the model output to the action dispatcher. */
    String CONTINUE_LABEL = "continue";
    /** Label that {@code shouldContinueEdge} returns to end the graph run. */
    String END_LABEL = "end";
    /** State key whose presence makes {@link ApprovalNodeAction} skip its interruption. */
    String APPROVAL_RESULT_PROPERTY = "approval_result";

    /**
     * Creates a new, empty graph builder.
     *
     * @param <M>    message type held by the state
     * @param <S>    state type
     * @param <TOOL> tool descriptor type
     * @return a fresh {@link Builder}
     */
    static <M, S extends MessagesState<M>, TOOL> Builder<M, S, TOOL> builder() {
        return new Builder<>();
    }

    /**
     * Fluent builder that assembles the agent {@link StateGraph}.
     *
     * <p>The produced topology is:
     * <ul>
     *   <li>{@code START -> model}</li>
     *   <li>{@code model --CONTINUE_LABEL--> action_dispatcher}, {@code model --END_LABEL--> END}</li>
     *   <li>{@code action_dispatcher -> model | END | <tool> | approval_<tool>}</li>
     *   <li>{@code approval_<tool> --REJECTED--> model}, {@code approval_<tool> --APPROVED--> <tool>}</li>
     *   <li>{@code <tool> -> action_dispatcher}</li>
     * </ul>
     *
     * @param <M>    message type held by the state
     * @param <S>    state type
     * @param <TOOL> tool descriptor type
     */
    class Builder<M, S extends MessagesState<M>, TOOL> {
        private static final String MODEL_NODE = "model";
        private static final String DISPATCHER_NODE = "action_dispatcher";

        private StateSerializer<S> stateSerializer;
        private AsyncNodeActionWithConfig<S> callModelAction;
        private AsyncNodeActionWithConfig<S> dispatchToolsAction;
        private AsyncCommandAction<S> dispatchActionEdge;
        private Function<String, AsyncNodeActionWithConfig<S>> executeToolFactory;
        private AsyncCommandAction<S> shouldContinueEdge;
        private AsyncCommandAction<S> approvalActionEdge;
        private Map<String, Channel<?>> schema;
        private Function<TOOL, String> toolName;

        /** Sets the serializer used by the graph (required). */
        public Builder<M, S, TOOL> stateSerializer(StateSerializer<S> stateSerializer) {
            this.stateSerializer = stateSerializer;
            return this;
        }

        /** Sets the state channel schema (required). */
        public Builder<M, S, TOOL> schema(Map<String, Channel<?>> schema) {
            this.schema = schema;
            return this;
        }

        /** Sets the action run by the {@code model} node (required). */
        public Builder<M, S, TOOL> callModelAction(AsyncNodeActionWithConfig<S> callModelAction) {
            this.callModelAction = callModelAction;
            return this;
        }

        /** Sets the factory creating one node action per tool name (required when at least one tool is given). */
        public Builder<M, S, TOOL> executeToolFactory(Function<String, AsyncNodeActionWithConfig<S>> executeToolFactory) {
            this.executeToolFactory = executeToolFactory;
            return this;
        }

        /** Sets the action run by the {@code action_dispatcher} node (required). */
        public Builder<M, S, TOOL> dispatchToolsAction(AsyncNodeActionWithConfig<S> dispatchToolsAction) {
            this.dispatchToolsAction = dispatchToolsAction;
            return this;
        }

        /** Sets the conditional edge leaving {@code model}; it must return {@link #CONTINUE_LABEL} or {@link #END_LABEL} (required). */
        public Builder<M, S, TOOL> shouldContinueEdge(AsyncCommandAction<S> shouldContinueEdge) {
            this.shouldContinueEdge = shouldContinueEdge;
            return this;
        }

        /** Sets the conditional edge leaving {@code action_dispatcher} (required). */
        public Builder<M, S, TOOL> dispatchActionEdge(AsyncCommandAction<S> dispatchActionEdge) {
            this.dispatchActionEdge = dispatchActionEdge;
            return this;
        }

        /**
         * Sets the conditional edge leaving each approval node; its labels are the names of {@link ApprovalState}
         * constants. Required when at least one approval is given.
         */
        public Builder<M, S, TOOL> approvalActionEdge(AsyncCommandAction<S> approvalActionEdge) {
            this.approvalActionEdge = approvalActionEdge;
            return this;
        }

        /** Sets the function extracting the node id (name) of a tool (required). */
        public Builder<M, S, TOOL> toolName(Function<TOOL, String> toolName) {
            this.toolName = toolName;
            return this;
        }

        /**
         * Assembles the agent graph.
         *
         * @param tools     the tools; each becomes a node named by {@link #toolName(Function)}
         * @param approvals approval actions keyed by tool name; a tool listed here is reached through an
         *                  {@code approval_<tool>} node instead of directly from the dispatcher
         * @return the configured graph
         * @throws NullPointerException     if {@code tools}, {@code approvals} or a required builder setting is missing
         * @throws IllegalArgumentException if an approval key does not match any tool name
         * @throws GraphStateException      propagated from {@link StateGraph} when a node or edge cannot be added
         */
        public StateGraph<S> build(Collection<TOOL> tools, Map<String, ApprovalNodeAction<M, S>> approvals) throws GraphStateException {
            Objects.requireNonNull(this.toolName, "toolName is required!");
            Objects.requireNonNull(tools, "tools is required!");
            Objects.requireNonNull(approvals, "approvals is required!");
            validateApprovals(tools, approvals);

            StateGraph<S> graph = new StateGraph<S>(
                            Objects.requireNonNull(this.schema, "schema is required!"),
                            Objects.requireNonNull(this.stateSerializer, "stateSerializer is required!"))
                    .addNode(MODEL_NODE, Objects.requireNonNull(this.callModelAction, "callModelAction is required!"))
                    .addNode(DISPATCHER_NODE, Objects.requireNonNull(this.dispatchToolsAction, "dispatchToolsAction is required!"))
                    .addEdge(StateGraph.START, MODEL_NODE)
                    .addConditionalEdges(MODEL_NODE,
                            Objects.requireNonNull(this.shouldContinueEdge, "shouldContinueEdge is required!"),
                            EdgeMappings.builder()
                                    .to(DISPATCHER_NODE, CONTINUE_LABEL)
                                    .toEND(END_LABEL)
                                    .build());

            // The dispatcher can go back to the model, end the run, or target each tool (directly or via its approval node).
            EdgeMappings.Builder actionMappingBuilder = EdgeMappings.builder().to(MODEL_NODE).toEND();
            for (TOOL tool : tools) {
                String toolNodeId = this.toolName.apply(tool);
                if (approvals.containsKey(toolNodeId)) {
                    String approvalNodeId = String.format("approval_%s", toolNodeId);
                    ApprovalNodeAction<M, S> approvalAction = approvals.get(toolNodeId);
                    graph.addNode(approvalNodeId, approvalAction);
                    // A rejected approval returns to the model; an approved one runs the tool.
                    graph.addConditionalEdges(approvalNodeId,
                            Objects.requireNonNull(this.approvalActionEdge, "approvalActionEdge is required!"),
                            EdgeMappings.builder()
                                    .to(MODEL_NODE, ApprovalState.REJECTED.name())
                                    .to(toolNodeId, ApprovalState.APPROVED.name())
                                    .build());
                    actionMappingBuilder.to(approvalNodeId);
                } else {
                    actionMappingBuilder.to(toolNodeId);
                }
                graph.addNode(toolNodeId,
                        Objects.requireNonNull(this.executeToolFactory, "executeToolFactory is required!").apply(toolNodeId));
                graph.addEdge(toolNodeId, DISPATCHER_NODE);
            }
            return graph.addConditionalEdges(DISPATCHER_NODE,
                    Objects.requireNonNull(this.dispatchActionEdge, "dispatchActionEdge is required!"),
                    actionMappingBuilder.build());
        }

        /** Ensures every approval key names one of the given tools; tool names are computed once (O(tools + approvals)). */
        private void validateApprovals(Collection<TOOL> tools, Map<String, ApprovalNodeAction<M, S>> approvals) {
            if (approvals.isEmpty()) {
                return;
            }
            Set<String> toolNames = new HashSet<>();
            for (TOOL tool : tools) {
                toolNames.add(this.toolName.apply(tool));
            }
            for (String approval : approvals.keySet()) {
                if (!toolNames.contains(approval)) {
                    throw new IllegalArgumentException(String.format("approval action %s not found!", approval));
                }
            }
        }
    }

    /**
     * Pass-through node placed in front of a tool that needs approval.
     *
     * <p>{@link #apply} returns no state updates. {@link #interrupt} requests an interruption, using metadata from the
     * configured provider, while {@link AgentEx#APPROVAL_RESULT_PROPERTY} is absent from the state.
     *
     * @param <M>     message type held by the state
     * @param <State> state type
     */
    final class ApprovalNodeAction<M, State extends MessagesState<M>>
    implements AsyncNodeActionWithConfig<State>,
    InterruptableAction<State> {
        private final BiFunction<String, State, InterruptionMetadata<State>> interruptionMetadataProvider;

        private ApprovalNodeAction(Builder<M, State> builder) {
            this.interruptionMetadataProvider = builder.interruptionMetadataProvider;
        }

        /** Returns an already-completed future with an empty update map. */
        @Override
        public CompletableFuture<Map<String, Object>> apply(State state, RunnableConfig config) {
            return CompletableFuture.completedFuture(Map.of());
        }

        /**
         * Requests an interruption unless an approval result is already stored in the state.
         *
         * @param nodeId id of the node being evaluated
         * @param state  current state
         * @return the provider's metadata when {@link AgentEx#APPROVAL_RESULT_PROPERTY} is absent, otherwise empty
         */
        @Override
        public Optional<InterruptionMetadata<State>> interrupt(String nodeId, State state) {
            boolean alreadyDecided = ((AgentState) state).value(APPROVAL_RESULT_PROPERTY).isPresent();
            if (alreadyDecided) {
                return Optional.empty();
            }
            return Optional.of(this.interruptionMetadataProvider.apply(nodeId, state));
        }

        /** Creates a new builder for {@link ApprovalNodeAction}. */
        public static <M, State extends MessagesState<M>> Builder<M, State> builder() {
            return new Builder<>();
        }

        /** Builder for {@link ApprovalNodeAction}; {@code interruptionMetadataProvider} is mandatory. */
        public static class Builder<M, State extends MessagesState<M>> {
            private BiFunction<String, State, InterruptionMetadata<State>> interruptionMetadataProvider;

            /** Sets the function producing interruption metadata from the node id and the current state. */
            public Builder<M, State> interruptionMetadataProvider(BiFunction<String, State, InterruptionMetadata<State>> provider) {
                this.interruptionMetadataProvider = provider;
                return this;
            }

            /**
             * Builds the action.
             *
             * @throws NullPointerException if no interruption metadata provider was set
             */
            public ApprovalNodeAction<M, State> build() {
                Objects.requireNonNull(this.interruptionMetadataProvider, "interruptionMetadataProvider cannot be null!");
                return new ApprovalNodeAction<>(this);
            }
        }
    }

    /** Outcome labels for an approval node: {@code REJECTED} routes back to the model, {@code APPROVED} to the tool. */
    enum ApprovalState {
        APPROVED,
        REJECTED
    }
}

