/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import org.bsc.langgraph4j.CompileConfig;
import org.bsc.langgraph4j.CompiledGraph;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.action.AsyncCommandAction;
import org.bsc.langgraph4j.action.AsyncEdgeAction;
import org.bsc.langgraph4j.action.AsyncNodeAction;
import org.bsc.langgraph4j.action.AsyncNodeActionWithConfig;
import org.bsc.langgraph4j.internal.edge.Edge;
import org.bsc.langgraph4j.internal.edge.EdgeCondition;
import org.bsc.langgraph4j.internal.edge.EdgeValue;
import org.bsc.langgraph4j.internal.node.Node;
import org.bsc.langgraph4j.internal.node.SubCompiledGraphNode;
import org.bsc.langgraph4j.internal.node.SubStateGraphNode;
import org.bsc.langgraph4j.serializer.StateSerializer;
import org.bsc.langgraph4j.serializer.std.ObjectStreamStateSerializer;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.AgentStateFactory;
import org.bsc.langgraph4j.state.Channel;

/**
 * Mutable description of a graph of nodes connected by edges over a shared {@code State}.
 * <p>
 * Nodes are registered through the {@code addNode} overloads and wired together with
 * {@link #addEdge(String, String)} and the {@code addConditionalEdges} overloads.
 * {@link #compile(CompileConfig)} validates the description and wraps it in a
 * {@link CompiledGraph}.
 * <p>
 * NOTE(review): nothing in this class is synchronized; presumably a graph is built on a
 * single thread before being compiled — confirm before sharing an instance across threads.
 *
 * @param <State> the agent state type processed by the graph's nodes
 */
public class StateGraph<State extends AgentState> {
    /**
     * Reserved id of the terminal pseudo-node. It is rejected both as a node id and as an
     * edge source.
     */
    public static final String END = "__END__";
    /**
     * Reserved id of the entry pseudo-node. Graph validation fails with
     * {@code missingEntryPoint} unless some edge has this id as its source.
     */
    public static final String START = "__START__";

    /** Registered nodes, kept in insertion order. */
    final Nodes<State> nodes = new Nodes<>();
    /** Registered edges, kept in insertion order (a merged edge keeps its original slot). */
    final Edges<State> edges = new Edges<>();
    /** Channel definitions supplied at construction time; never {@code null}. */
    private final Map<String, Channel<?>> channels;
    /** Serializer for the graph state; also the source of the state factory. */
    private final StateSerializer<State> stateSerializer;

    /**
     * Creates a graph with the given channels and state serializer.
     *
     * @param channels        channel definitions exposed by {@link #getChannels()};
     *                        {@code null} is treated as an empty map
     * @param stateSerializer serializer for the graph state, must not be {@code null}
     * @throws NullPointerException if {@code stateSerializer} is {@code null}
     */
    public StateGraph(Map<String, Channel<?>> channels, StateSerializer<State> stateSerializer) {
        // Normalise null here; otherwise getChannels() would fail later with an NPE.
        this.channels = (channels == null) ? Map.of() : channels;
        this.stateSerializer = Objects.requireNonNull(stateSerializer, "stateSerializer cannot be null");
    }

    /**
     * Creates a graph without channels.
     *
     * @param stateSerializer serializer for the graph state, must not be {@code null}
     */
    public StateGraph(StateSerializer<State> stateSerializer) {
        this(Map.of(), stateSerializer);
    }

    /**
     * Creates a graph without channels whose state is serialized by an
     * {@link ObjectStreamStateSerializer} built from {@code stateFactory}.
     *
     * @param stateFactory factory used to build state instances
     */
    public StateGraph(AgentStateFactory<State> stateFactory) {
        this(Map.of(), stateFactory);
    }

    /**
     * Creates a graph whose state is serialized by an {@link ObjectStreamStateSerializer}
     * built from {@code stateFactory}.
     *
     * @param channels     channel definitions; {@code null} is treated as an empty map
     * @param stateFactory factory used to build state instances
     */
    public StateGraph(Map<String, Channel<?>> channels, AgentStateFactory<State> stateFactory) {
        this(channels, new ObjectStreamStateSerializer<State>(stateFactory));
    }

    /** @return the serializer given at construction time (directly or derived from a factory) */
    public StateSerializer<State> getStateSerializer() {
        return this.stateSerializer;
    }

    /** @return the state factory exposed by the configured serializer */
    public final AgentStateFactory<State> getStateFactory() {
        return this.stateSerializer.stateFactory();
    }

    /** @return a read-only view of the channel definitions (empty if none were given) */
    public Map<String, Channel<?>> getChannels() {
        return Collections.unmodifiableMap(this.channels);
    }

    /**
     * Adds a node whose action ignores the run configuration.
     *
     * @param id     node id, must not be {@link #END}
     * @param action action executed by the node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is reserved or already registered
     */
    public StateGraph<State> addNode(String id, AsyncNodeAction<State> action) throws GraphStateException {
        return this.addNode(id, AsyncNodeActionWithConfig.of(action));
    }

    /**
     * Adds a node whose action also receives the run configuration.
     *
     * @param id     node id, must not be {@link #END}
     * @param action action executed by the node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is reserved or already registered
     */
    public StateGraph<State> addNode(String id, AsyncNodeActionWithConfig<State> action) throws GraphStateException {
        checkNodeId(id);
        return registerNode(id, new Node<State>(id, config -> action));
    }

    /**
     * Adds a routing node. The node's own action returns an empty update, and the outgoing
     * routing is a conditional edge driven by {@code action} and {@code mappings}.
     *
     * @param id       node id, must not be {@link #END}
     * @param action   command deciding where to go next
     * @param mappings routing table handed to the conditional edge; must be non-empty
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is reserved or already registered, if
     *                             {@code mappings} is null/empty, or if a conditional edge from
     *                             {@code id} already exists
     */
    public StateGraph<State> addNode(String id, AsyncCommandAction<State> action, Map<String, String> mappings) throws GraphStateException {
        checkNodeId(id);
        // Check the mappings before mutating the graph: otherwise an empty mapping would leave
        // the pass-through node registered without the edge that routes out of it.
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(id);
        }
        return this.addNode(id, (State state, RunnableConfig config) -> CompletableFuture.completedFuture(Map.of()))
                .addConditionalEdges(id, action, mappings);
    }

    /**
     * Adds an already compiled graph as a node of this graph.
     *
     * @param id       node id, must not be {@link #END}
     * @param subGraph compiled sub-graph executed by the node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is reserved or already registered
     */
    public StateGraph<State> addNode(String id, CompiledGraph<State> subGraph) throws GraphStateException {
        checkNodeId(id);
        return registerNode(id, new SubCompiledGraphNode<State>(id, subGraph));
    }

    /**
     * @deprecated use {@link #addNode(String, CompiledGraph)} instead.
     */
    @Deprecated(forRemoval=true)
    public StateGraph<State> addSubgraph(String id, CompiledGraph<State> subGraph) throws GraphStateException {
        return this.addNode(id, subGraph);
    }

    /**
     * Adds an uncompiled graph as a node of this graph. The sub-graph is validated
     * immediately, so structural errors in it are reported at registration time.
     *
     * @param id       node id, must not be {@link #END}
     * @param subGraph sub-graph executed by the node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code id} is reserved or already registered, or if
     *                             {@code subGraph} fails validation
     */
    public StateGraph<State> addNode(String id, StateGraph<State> subGraph) throws GraphStateException {
        checkNodeId(id);
        subGraph.validateGraph();
        return registerNode(id, new SubStateGraphNode<State>(id, subGraph));
    }

    /**
     * @deprecated use {@link #addNode(String, StateGraph)} instead.
     */
    @Deprecated(forRemoval=true)
    public StateGraph<State> addSubgraph(String id, StateGraph<State> subGraph) throws GraphStateException {
        return this.addNode(id, subGraph);
    }

    /**
     * Rejects ids reserved by the graph that therefore cannot name a node.
     */
    private static void checkNodeId(String id) throws GraphStateException {
        if (Objects.equals(id, END)) {
            throw Errors.invalidNodeIdentifier.exception(END);
        }
    }

    /**
     * Adds {@code node} to the node set, failing if an equal node is already registered.
     */
    private StateGraph<State> registerNode(String id, Node<State> node) throws GraphStateException {
        // Set.add returns false (and leaves the set untouched) when an equal node exists.
        if (!this.nodes.elements.add(node)) {
            throw Errors.duplicateNodeError.exception(id);
        }
        return this;
    }

    /**
     * Adds an unconditional edge from {@code sourceId} to {@code targetId}.
     * <p>
     * If an edge equal to the new one is already registered (equality is defined by
     * {@code Edge#equals}; presumably the source id — confirm), the target is appended to
     * that edge's targets in place instead of adding a second edge.
     *
     * @param sourceId id of the source node, must not be {@link #END}
     * @param targetId id of the target node
     * @return this graph, for chaining
     * @throws GraphStateException if {@code sourceId} is {@link #END}
     */
    public StateGraph<State> addEdge(String sourceId, String targetId) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        var newEdge = new Edge<State>(sourceId, new EdgeValue<State>(targetId));
        int index = this.edges.elements.indexOf(newEdge);
        if (index < 0) {
            this.edges.elements.add(newEdge);
            return this;
        }
        // Merge into the existing edge so it fans out to every registered target.
        var mergedTargets = new ArrayList<>(this.edges.elements.get(index).targets());
        mergedTargets.add(newEdge.target());
        this.edges.elements.set(index, new Edge<State>(sourceId, mergedTargets));
        return this;
    }

    /**
     * Adds a conditional edge leaving {@code sourceId}.
     *
     * @param sourceId  id of the source node, must not be {@link #END}
     * @param condition command whose result is resolved through {@code mappings}
     * @param mappings  routing table; must be non-empty
     * @return this graph, for chaining
     * @throws GraphStateException if {@code sourceId} is {@link #END}, if {@code mappings} is
     *                             null/empty, or if an equal conditional edge already exists
     */
    public StateGraph<State> addConditionalEdges(String sourceId, AsyncCommandAction<State> condition, Map<String, String> mappings) throws GraphStateException {
        if (Objects.equals(sourceId, END)) {
            throw Errors.invalidEdgeIdentifier.exception(END);
        }
        if (mappings == null || mappings.isEmpty()) {
            throw Errors.edgeMappingIsEmpty.exception(sourceId);
        }
        var newEdge = new Edge<State>(sourceId, new EdgeValue<State>(new EdgeCondition<State>(condition, mappings)));
        if (this.edges.elements.contains(newEdge)) {
            throw Errors.duplicateConditionalEdgeError.exception(sourceId);
        }
        this.edges.elements.add(newEdge);
        return this;
    }

    /**
     * Adds a conditional edge driven by an edge action, adapted via
     * {@link AsyncCommandAction#of}.
     *
     * @see #addConditionalEdges(String, AsyncCommandAction, Map)
     */
    public StateGraph<State> addConditionalEdges(String sourceId, AsyncEdgeAction<State> condition, Map<String, String> mappings) throws GraphStateException {
        return this.addConditionalEdges(sourceId, AsyncCommandAction.of(condition), mappings);
    }

    /**
     * Checks that an entry edge from {@link #START} exists, then validates every edge
     * against the registered nodes.
     *
     * @throws GraphStateException if the entry edge is missing or any edge is invalid
     */
    void validateGraph() throws GraphStateException {
        Edge<State> edgeStart = this.edges.edgeBySourceId(START)
                .orElseThrow(() -> Errors.missingEntryPoint.exception());
        // Validated first so entry-point problems are reported before other edge errors;
        // the loop below checks it again along with all the others.
        edgeStart.validate(this.nodes);
        for (Edge<State> edge : this.edges.elements) {
            edge.validate(this.nodes);
        }
    }

    /**
     * Validates this graph and compiles it with the given configuration.
     *
     * @param config compile configuration, must not be {@code null}
     * @return the compiled graph
     * @throws GraphStateException  if the graph is invalid
     * @throws NullPointerException if {@code config} is {@code null}
     */
    public CompiledGraph<State> compile(CompileConfig config) throws GraphStateException {
        Objects.requireNonNull(config, "config cannot be null");
        this.validateGraph();
        return new CompiledGraph<>(this, config);
    }

    /**
     * Compiles this graph with a default {@link CompileConfig}.
     *
     * @return the compiled graph
     * @throws GraphStateException if the graph is invalid
     */
    public CompiledGraph<State> compile() throws GraphStateException {
        return this.compile(CompileConfig.builder().build());
    }

    /**
     * Renders this graph with the generator associated with {@code type}.
     *
     * @param type                  output representation
     * @param title                 title passed to the generator
     * @param printConditionalEdges whether the generator should include conditional edges
     * @return the rendered representation
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
        String content = type.generator.generate(this.nodes, this.edges, title, printConditionalEdges);
        return new GraphRepresentation(type, content);
    }

    /**
     * Renders this graph including conditional edges.
     *
     * @see #getGraph(GraphRepresentation.Type, String, boolean)
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        return this.getGraph(type, title, true);
    }

    /**
     * Insertion-ordered collection of graph nodes.
     *
     * @param <State> the agent state type
     */
    public static class Nodes<State extends AgentState> {
        /** Registered nodes; a {@link LinkedHashSet}, so iteration follows insertion order. */
        public final Set<Node<State>> elements;

        /** Creates a node collection holding a copy of {@code elements}. */
        public Nodes(Collection<Node<State>> elements) {
            this.elements = new LinkedHashSet<>(elements);
        }

        /** Creates an empty node collection. */
        public Nodes() {
            this.elements = new LinkedHashSet<>();
        }

        /** @return {@code true} if some node's {@code id()} equals {@code id} */
        public boolean anyMatchById(String id) {
            return this.elements.stream().anyMatch(n -> Objects.equals(n.id(), id));
        }

        /** @return the nodes that are {@link SubStateGraphNode}s, in insertion order (unmodifiable) */
        public List<SubStateGraphNode<State>> onlySubStateGraphNodes() {
            return this.elements.stream()
                    .filter(n -> n instanceof SubStateGraphNode)
                    .map(n -> (SubStateGraphNode<State>) n)
                    .toList();
        }

        /** @return the nodes that are not {@link SubStateGraphNode}s, in insertion order (unmodifiable) */
        public List<Node<State>> exceptSubStateGraphNodes() {
            return this.elements.stream()
                    .filter(n -> !(n instanceof SubStateGraphNode))
                    .toList();
        }
    }

    /**
     * Ordered collection of graph edges.
     *
     * @param <State> the agent state type
     */
    public static class Edges<State extends AgentState> {
        /** Registered edges in insertion order; random access is used by {@code addEdge}. */
        public final List<Edge<State>> elements;

        /** Creates an edge collection holding a copy of {@code elements}. */
        public Edges(Collection<Edge<State>> elements) {
            this.elements = new ArrayList<>(elements);
        }

        /** Creates an empty edge collection. */
        public Edges() {
            this.elements = new ArrayList<>();
        }

        /** @return the first edge whose {@code sourceId()} equals {@code sourceId}, if any */
        public Optional<Edge<State>> edgeBySourceId(String sourceId) {
            return this.elements.stream().filter(e -> Objects.equals(e.sourceId(), sourceId)).findFirst();
        }

        /** @return the edges for which {@code anyMatchByTargetId(targetId)} holds (unmodifiable) */
        public List<Edge<State>> edgesByTargetId(String targetId) {
            return this.elements.stream().filter(e -> e.anyMatchByTargetId(targetId)).toList();
        }
    }

    /**
     * Error catalogue for graph construction and validation. Each constant holds a
     * {@link String#format} template that {@link #exception(Object...)} fills in.
     */
    public static enum Errors {
        invalidNodeIdentifier("END is not a valid node id!"),
        invalidEdgeIdentifier("END is not a valid edge sourceId!"),
        duplicateNodeError("node with id: %s already exist!"),
        duplicateEdgeError("edge with id: %s already exist!"),
        duplicateConditionalEdgeError("conditional edge from '%s' already exist!"),
        edgeMappingIsEmpty("edge mapping is empty!"),
        missingEntryPoint("missing Entry Point"),
        entryPointNotExist("entryPoint: %s doesn't exist!"),
        finishPointNotExist("finishPoint: %s doesn't exist!"),
        missingNodeReferencedByEdge("edge sourceId '%s' refers to undefined node!"),
        missingNodeInEdgeMapping("edge mapping for sourceId: %s contains a not existent nodeId %s!"),
        invalidEdgeTarget("edge sourceId: %s has an initialized target value!"),
        duplicateEdgeTargetError("edge [%s] has duplicate targets %s!"),
        unsupportedConditionalEdgeOnParallelNode("parallel node doesn't support conditional branch, but on [%s] a conditional branch on %s have been found!"),
        illegalMultipleTargetsOnParallelNode("parallel node [%s] must have only one target, but %s have been found!"),
        interruptionNodeNotExist("node '%s' configured as interruption doesn't exist!");

        /** {@link String#format} template for this error. */
        private final String errorMessage;

        Errors(String errorMessage) {
            this.errorMessage = errorMessage;
        }

        /**
         * Builds an exception whose message is this error's template formatted with
         * {@code args}; arguments beyond the template's placeholders are ignored.
         *
         * @param args values substituted into the template
         * @return a new, not yet thrown, {@link GraphStateException}
         */
        public GraphStateException exception(Object ... args) {
            return new GraphStateException(String.format(this.errorMessage, args));
        }
    }
}

