/*
 * Decompiled with CFR 0.152.
 */
package org.bsc.langgraph4j;

import java.io.IOException;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.bsc.async.AsyncGenerator;
import org.bsc.langgraph4j.CompileConfig;
import org.bsc.langgraph4j.GraphArgs;
import org.bsc.langgraph4j.GraphInput;
import org.bsc.langgraph4j.GraphRepresentation;
import org.bsc.langgraph4j.GraphResume;
import org.bsc.langgraph4j.GraphRunnerException;
import org.bsc.langgraph4j.GraphStateException;
import org.bsc.langgraph4j.NodeOutput;
import org.bsc.langgraph4j.ProcessedNodesEdgesAndConfig;
import org.bsc.langgraph4j.RunnableConfig;
import org.bsc.langgraph4j.StateGraph;
import org.bsc.langgraph4j.action.AsyncNodeActionWithConfig;
import org.bsc.langgraph4j.action.Command;
import org.bsc.langgraph4j.action.InterruptableAction;
import org.bsc.langgraph4j.action.InterruptionMetadata;
import org.bsc.langgraph4j.checkpoint.BaseCheckpointSaver;
import org.bsc.langgraph4j.checkpoint.Checkpoint;
import org.bsc.langgraph4j.internal.edge.Edge;
import org.bsc.langgraph4j.internal.edge.EdgeValue;
import org.bsc.langgraph4j.internal.node.Node;
import org.bsc.langgraph4j.internal.node.ParallelNode;
import org.bsc.langgraph4j.state.AgentState;
import org.bsc.langgraph4j.state.StateSnapshot;
import org.bsc.langgraph4j.utils.TryFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Executable form of a {@link StateGraph}, produced by compiling the graph definition with a
 * {@link CompileConfig}.
 *
 * <p>Compilation resolves every node id to its {@link AsyncNodeActionWithConfig} and every source
 * node id to its outgoing {@link EdgeValue}. Fan-out edges (one source, several targets) are
 * replaced by a synthetic {@link ParallelNode}. Execution is driven by {@link AsyncNodeGenerator}
 * through the {@code stream*}/{@code invoke} entry points. When the compile configuration provides
 * a {@link BaseCheckpointSaver}, a checkpoint is written for each emitted step, and the saved
 * state can be inspected ({@link #getState}, {@link #getStateHistory}) or forked
 * ({@link #updateState}).
 *
 * @param <State> the agent state type flowing through the graph
 */
public class CompiledGraph<State extends AgentState> {
    private static final Logger log = LoggerFactory.getLogger(CompiledGraph.class);
    /** The graph definition this instance was compiled from. */
    public final StateGraph<State> stateGraph;
    /** Node actions keyed by node id, including synthetic parallel nodes created at compile time. */
    final Map<String, AsyncNodeActionWithConfig<State>> nodes = new LinkedHashMap<String, AsyncNodeActionWithConfig<State>>();
    /** Outgoing edge of each node, keyed by source node id ({@link StateGraph#START} included). */
    final Map<String, EdgeValue<State>> edges = new LinkedHashMap<String, EdgeValue<State>>();
    private final ProcessedNodesEdgesAndConfig<State> processedData;
    /** Upper bound on generator steps per run; see {@link #setMaxIterations(int)}. */
    private int maxIterations = 25;
    /** Effective configuration: the caller's config with the processed interruption lists applied. */
    public final CompileConfig compileConfig;

    /**
     * Compiles {@code stateGraph}.
     *
     * <p>Compilation validates interruption points, instantiates node actions and resolves edges.
     * A fan-out edge is rewritten as {@code source -> ParallelNode -> joinTarget}.
     *
     * @param stateGraph    the graph definition
     * @param compileConfig the caller-supplied compile configuration
     * @throws GraphStateException if an interrupt-before/after entry names an unknown node, or if a
     *                             fan-out's parallel branches do not continue to one single target
     */
    protected CompiledGraph(StateGraph<State> stateGraph, CompileConfig compileConfig) throws GraphStateException {
        this.stateGraph = stateGraph;
        this.processedData = ProcessedNodesEdgesAndConfig.process(stateGraph, compileConfig);
        // Every interruption point, whether "before" or "after", must name an existing node.
        for (String nodeId : this.processedData.interruptsBefore()) {
            if (!this.processedData.nodes().anyMatchById(nodeId)) {
                throw StateGraph.Errors.interruptionNodeNotExist.exception(nodeId);
            }
        }
        for (String nodeId : this.processedData.interruptsAfter()) {
            if (!this.processedData.nodes().anyMatchById(nodeId)) {
                throw StateGraph.Errors.interruptionNodeNotExist.exception(nodeId);
            }
        }
        this.compileConfig = CompileConfig.builder(compileConfig).interruptsBefore(this.processedData.interruptsBefore()).interruptsAfter(this.processedData.interruptsAfter()).build();
        for (Node node : this.processedData.nodes().elements) {
            Node.ActionFactory factory = node.actionFactory();
            Objects.requireNonNull(factory, String.format("action factory for node id '%s' is null!", node.id()));
            this.nodes.put(node.id(), factory.apply(compileConfig));
        }
        for (Edge edge : this.processedData.edges().elements) {
            List<EdgeValue<State>> targets = edge.targets();
            if (targets.size() == 1) {
                // Plain edge: exactly one (possibly conditional) target.
                this.edges.put(edge.sourceId(), targets.get(0));
                continue;
            }
            // Fan-out: only targets that are real nodes take part in the parallel step.
            Supplier<Stream<EdgeValue<State>>> parallelNodeStream = () -> targets.stream().filter(target -> this.nodes.containsKey(target.id()));
            // Look up the outgoing edge of each parallel branch (an Edge probe built from the id).
            List<Edge> parallelNodeEdges = parallelNodeStream.get()
                    .map(target -> new Edge(target.id()))
                    .map(probe -> this.processedData.edges().elements.indexOf(probe))
                    .filter(index -> index >= 0)
                    .map(index -> (Edge) this.processedData.edges().elements.get(index))
                    .toList();
            Set<String> parallelNodeTargets = parallelNodeEdges.stream().map(ee -> ee.target().id()).collect(Collectors.toSet());
            // All branches must converge on one target so the synthetic node has a single successor.
            if (parallelNodeTargets.size() > 1) {
                List<Edge> conditionalEdges = parallelNodeEdges.stream().filter(ee -> ee.target().value() != null).toList();
                if (!conditionalEdges.isEmpty()) {
                    throw StateGraph.Errors.unsupportedConditionalEdgeOnParallelNode.exception(edge.sourceId(), conditionalEdges.stream().map(Edge::sourceId).toList());
                }
                throw StateGraph.Errors.illegalMultipleTargetsOnParallelNode.exception(edge.sourceId(), parallelNodeTargets);
            }
            List<AsyncNodeActionWithConfig<State>> actions = parallelNodeStream.get().map(target -> this.nodes.get(target.id())).toList();
            ParallelNode parallelNode = new ParallelNode(edge.sourceId(), actions, stateGraph.getChannels());
            this.nodes.put(parallelNode.id(), parallelNode.actionFactory().apply(compileConfig));
            this.edges.put(edge.sourceId(), new EdgeValue(parallelNode.id()));
            this.edges.put(parallelNode.id(), new EdgeValue(parallelNodeTargets.iterator().next()));
        }
    }

    /**
     * Returns every checkpoint stored for {@code config} as a state snapshot, in the order the
     * saver lists them.
     *
     * @throws IllegalStateException if no checkpoint saver is configured
     */
    public Collection<StateSnapshot<State>> getStateHistory(RunnableConfig config) {
        BaseCheckpointSaver saver = this.compileConfig.checkpointSaver().orElseThrow(() -> new IllegalStateException("Missing CheckpointSaver!"));
        return saver.list(config).stream().map(checkpoint -> StateSnapshot.of(checkpoint, config, this.stateGraph.getStateFactory())).collect(Collectors.toList());
    }

    /**
     * Returns the snapshot of the checkpoint the saver resolves for {@code config}.
     *
     * @throws IllegalStateException if no checkpoint saver is configured or no checkpoint exists
     */
    public StateSnapshot<State> getState(RunnableConfig config) {
        return this.stateOf(config).orElseThrow(() -> new IllegalStateException("Missing Checkpoint!"));
    }

    /**
     * Returns the snapshot of the checkpoint the saver resolves for {@code config}, if any.
     *
     * @throws IllegalStateException if no checkpoint saver is configured
     */
    public Optional<StateSnapshot<State>> stateOf(RunnableConfig config) {
        BaseCheckpointSaver saver = this.compileConfig.checkpointSaver().orElseThrow(() -> new IllegalStateException("Missing CheckpointSaver!"));
        return saver.get(config).map(checkpoint -> StateSnapshot.of(checkpoint, config, this.stateGraph.getStateFactory()));
    }

    /**
     * Returns the first entry of {@link #getStateHistory(RunnableConfig)}. Presumably this is the
     * most recent checkpoint, but that depends on the saver's listing order.
     */
    public Optional<StateSnapshot<State>> lastStateOf(RunnableConfig config) {
        return this.getStateHistory(config).stream().findFirst();
    }

    /**
     * Forks the current checkpoint of {@code config} with {@code values} merged into its state.
     *
     * <p>If {@code asNode} is given, the update is treated as if produced by that node. The node's
     * outgoing edge is then evaluated, and both its routing update and its target become part of
     * the fork.
     *
     * @param config the run whose current checkpoint is forked
     * @param values partial state merged through the graph's channels
     * @param asNode node id to route from, or {@code null} to leave the next node unset
     * @return a config pointing at the new checkpoint (and next node, if computed)
     * @throws IllegalStateException if no checkpoint saver is configured or no checkpoint exists
     * @throws Exception             if edge evaluation or persistence fails
     */
    public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values, String asNode) throws Exception {
        BaseCheckpointSaver saver = this.compileConfig.checkpointSaver().orElseThrow(() -> new IllegalStateException("Missing CheckpointSaver!"));
        Checkpoint branchCheckpoint = saver.get(config).map(Checkpoint::copyOf).map(cp -> cp.updateState(values, this.stateGraph.getChannels())).orElseThrow(() -> new IllegalStateException("Missing Checkpoint!"));
        String nextNodeId = null;
        if (asNode != null) {
            Command nextNodeCommand = this.nextNodeId(asNode, branchCheckpoint.getState(), config);
            nextNodeId = nextNodeCommand.gotoNode();
            branchCheckpoint = branchCheckpoint.updateState(nextNodeCommand.update(), this.stateGraph.getChannels());
        }
        RunnableConfig newConfig = saver.put(config, branchCheckpoint);
        return RunnableConfig.builder(newConfig).checkPointId(branchCheckpoint.getId()).nextNode(nextNodeId).build();
    }

    /** Same as {@link #updateState(RunnableConfig, Map, String)} without recomputing the next node. */
    public RunnableConfig updateState(RunnableConfig config, Map<String, Object> values) throws Exception {
        return this.updateState(config, values, null);
    }

    /**
     * Sets the maximum number of generator steps allowed per run (default 25).
     *
     * @throws IllegalArgumentException if {@code maxIterations <= 0}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations <= 0) {
            throw new IllegalArgumentException("maxIterations must be > 0!");
        }
        this.maxIterations = maxIterations;
    }

    /**
     * Resolves where to go after {@code nodeId} by following {@code route}.
     *
     * <p>A fixed edge yields its target with {@code state} unchanged. A conditional edge runs its
     * action (blocking on the future) and maps the returned label through the edge mappings. The
     * command's update is also merged into the state.
     *
     * @throws GraphRunnerException if the edge is missing or invalid, or the label has no mapping
     */
    private Command nextNodeId(EdgeValue<State> route, Map<String, Object> state, String nodeId, RunnableConfig config) throws Exception {
        if (route == null) {
            throw RunnableErrors.missingEdge.exception(nodeId);
        }
        if (route.id() != null) {
            return new Command(route.id(), state);
        }
        if (route.value() != null) {
            AgentState derefState = (AgentState)this.stateGraph.getStateFactory().apply(state);
            Command command = (Command)((CompletableFuture)route.value().action().apply(derefState, config)).get();
            String newRoute = command.gotoNode();
            String result = route.value().mappings().get(newRoute);
            if (result == null) {
                // Message format is (mapping id, conditional-edge source id).
                throw RunnableErrors.missingNodeInEdgeMapping.exception(newRoute, nodeId);
            }
            Map<String, Object> currentState = AgentState.updateState(state, command.update(), this.stateGraph.getChannels());
            return new Command(result, currentState);
        }
        throw RunnableErrors.executionError.exception(String.format("invalid edge value for nodeId: [%s] !", nodeId));
    }

    /** Resolves the successor of {@code nodeId} using its registered outgoing edge. */
    private Command nextNodeId(String nodeId, Map<String, Object> state, RunnableConfig config) throws Exception {
        return this.nextNodeId(this.edges.get(nodeId), state, nodeId, config);
    }

    /** Resolves the first node to run by following the edge out of {@link StateGraph#START}. */
    private Command getEntryPoint(Map<String, Object> state, RunnableConfig config) throws Exception {
        EdgeValue<State> entryPoint = this.edges.get(StateGraph.START);
        return this.nextNodeId(entryPoint, state, "entryPoint", config);
    }

    /**
     * Tells whether execution must stop before entering {@code nodeId}. This never applies when
     * {@code previousNodeId} is {@code null}, which is the case right after a resume.
     */
    private boolean shouldInterruptBefore(String nodeId, String previousNodeId) {
        Objects.requireNonNull(nodeId, "nodeId cannot be null");
        if (previousNodeId == null) {
            return false;
        }
        return this.compileConfig.interruptsBefore().contains(nodeId);
    }

    /**
     * Tells whether execution must stop after {@code nodeId}. This is {@code false} for a
     * {@code null} node or when both ids are equal.
     */
    private boolean shouldInterruptAfter(String nodeId, String previousNodeId) {
        if (nodeId == null || Objects.equals(nodeId, previousNodeId)) {
            return false;
        }
        return this.compileConfig.interruptsAfter().contains(nodeId);
    }

    /**
     * Persists a checkpoint holding a clone of {@code state}, if a saver is configured.
     *
     * @return the stored checkpoint, or empty when no saver is configured
     */
    private Optional<Checkpoint> addCheckpoint(RunnableConfig config, String nodeId, Map<String, Object> state, String nextNodeId) throws Exception {
        if (this.compileConfig.checkpointSaver().isPresent()) {
            Checkpoint cp = Checkpoint.builder().nodeId(nodeId).state((AgentState)this.cloneState(state)).nextNodeId(nextNodeId).build();
            this.compileConfig.checkpointSaver().get().put(config, cp);
            return Optional.of(cp);
        }
        return Optional.empty();
    }

    /** Builds the initial state data declared by the graph's channel schema. */
    Map<String, Object> getInitialStateFromSchema() {
        return this.stateGraph.getStateFactory().initialDataFromSchema(this.stateGraph.getChannels());
    }

    /**
     * Merges {@code inputs} into the saved checkpoint state for {@code config} when one exists.
     * Otherwise the inputs are merged into the schema's initial state.
     */
    Map<String, Object> getInitialState(Map<String, Object> inputs, RunnableConfig config) {
        return this.compileConfig.checkpointSaver().flatMap(saver -> saver.get(config)).map(cp -> AgentState.updateState(cp.getState(), inputs, this.stateGraph.getChannels())).orElseGet(() -> AgentState.updateState(this.getInitialStateFromSchema(), inputs, this.stateGraph.getChannels()));
    }

    /** Deep-copies {@code data} through the graph's state serializer. */
    State cloneState(Map<String, Object> data) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException {
        return this.stateGraph.getStateSerializer().cloneObject(data);
    }

    /**
     * Streams the node outputs of a run.
     *
     * @param input {@link GraphArgs} to start a run, or {@link GraphResume} to resume from a checkpoint
     * @param config run configuration
     * @throws NullPointerException if either argument is {@code null}
     */
    public AsyncGenerator<NodeOutput<State>> stream(GraphInput input, RunnableConfig config) {
        Objects.requireNonNull(config, "config cannot be null");
        Objects.requireNonNull(input, "input cannot be null");
        AsyncNodeGenerator generator = new AsyncNodeGenerator(input, config);
        return new AsyncGenerator.WithEmbed(generator);
    }

    /** Streams a run; {@code null} inputs mean "resume from the saved checkpoint". */
    public AsyncGenerator<NodeOutput<State>> stream(Map<String, Object> inputs, RunnableConfig config) {
        return this.stream((GraphInput)((Object)(inputs == null ? new GraphResume() : new GraphArgs(inputs))), config);
    }

    /** Streams a fresh run with a default {@link RunnableConfig}. */
    public AsyncGenerator<NodeOutput<State>> stream(Map<String, Object> inputs) {
        return this.stream(GraphInput.args(inputs), RunnableConfig.builder().build());
    }

    /** Runs the graph to completion and returns the state of the last emitted output, if any. */
    public Optional<State> invoke(GraphInput input, RunnableConfig config) {
        return this.stream(input, config).stream().reduce((a, b) -> b).map(NodeOutput::state);
    }

    /** Runs the graph to completion on {@code inputs} and returns the final state, if any. */
    public Optional<State> invoke(Map<String, Object> inputs, RunnableConfig config) {
        return this.invoke(GraphInput.args(inputs), config);
    }

    /** Runs the graph to completion with a default {@link RunnableConfig}. */
    public Optional<State> invoke(Map<String, Object> inputs) {
        return this.invoke(GraphInput.args(inputs), RunnableConfig.builder().build());
    }

    /**
     * Like {@link #stream(GraphInput, RunnableConfig)}, but in {@link StreamMode#SNAPSHOTS} mode.
     * Steps that were checkpointed are emitted as {@link StateSnapshot}s.
     *
     * @throws NullPointerException if either argument is {@code null}
     */
    public AsyncGenerator<NodeOutput<State>> streamSnapshots(GraphInput input, RunnableConfig config) {
        Objects.requireNonNull(config, "config cannot be null");
        Objects.requireNonNull(input, "input cannot be null");
        AsyncNodeGenerator generator = new AsyncNodeGenerator(input, config.withStreamMode(StreamMode.SNAPSHOTS));
        return new AsyncGenerator.WithEmbed(generator);
    }

    /** Snapshot stream; {@code null} inputs mean "resume from the saved checkpoint". */
    public AsyncGenerator<NodeOutput<State>> streamSnapshots(Map<String, Object> inputs, RunnableConfig config) {
        return this.streamSnapshots((GraphInput)((Object)(inputs == null ? new GraphResume() : new GraphArgs(inputs))), config);
    }

    /**
     * Renders the compiled graph's structure.
     *
     * @param printConditionalEdges whether conditional edges are included in the output
     */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title, boolean printConditionalEdges) {
        String content = type.generator.generate(this.processedData.nodes(), this.processedData.edges(), title, printConditionalEdges);
        return new GraphRepresentation(type, content);
    }

    /** Renders the graph, conditional edges included. */
    public GraphRepresentation getGraph(GraphRepresentation.Type type, String title) {
        return this.getGraph(type, title, true);
    }

    /** Renders the graph titled "Graph Diagram", conditional edges included. */
    public GraphRepresentation getGraph(GraphRepresentation.Type type) {
        return this.getGraph(type, "Graph Diagram", true);
    }

    /** Runtime error catalogue; each constant carries a {@link String#format} template. */
    enum RunnableErrors {
        missingNodeInEdgeMapping("cannot find edge mapping for id: '%s' in conditional edge with sourceId: '%s' "),
        missingNode("node with id: '%s' doesn't exist!"),
        missingEdge("edge with sourceId: '%s' doesn't exist!"),
        executionError("%s");

        private final String errorMessage;

        RunnableErrors(String errorMessage) {
            this.errorMessage = errorMessage;
        }

        /** Formats the template with {@code args} into a {@link GraphRunnerException}. */
        GraphRunnerException exception(String ... args) {
            return new GraphRunnerException(String.format(this.errorMessage, args));
        }
    }

    /**
     * Step-by-step executor of one graph run.
     *
     * <p>Each {@link #next()} call emits one output (START, a node, or END), a terminal interruption,
     * or completion. The instance keeps mutable cursors ({@code currentNodeId}/{@code nextNodeId})
     * and the evolving state. It has no synchronization, so it is presumably meant to be consumed
     * sequentially.
     */
    public class AsyncNodeGenerator<Output extends NodeOutput<State>>
    implements AsyncGenerator<Output> {
        Map<String, Object> currentState;
        String currentNodeId;
        String nextNodeId;
        int iteration = 0;
        RunnableConfig config;
        /** Set when an embedded (sub-graph) generator finished; the next call emits its node output. */
        boolean resumedFromEmbed = false;

        /**
         * Starts a fresh run or, for a {@link GraphResume} input, restores state and the next node
         * from the saved checkpoint.
         *
         * @throws IllegalStateException on a resume request without a saver or without a checkpoint
         */
        protected AsyncNodeGenerator(GraphInput input, RunnableConfig config) {
            boolean isResumeRequest = input instanceof GraphResume;
            if (isResumeRequest) {
                log.trace("RESUME REQUEST");
                BaseCheckpointSaver saver = CompiledGraph.this.compileConfig.checkpointSaver().orElseThrow(() -> new IllegalStateException("inputs cannot be null (ie. resume request) if no checkpoint saver is configured"));
                Checkpoint startCheckpoint = saver.get(config).orElseThrow(() -> new IllegalStateException("Resume request without a saved checkpoint!"));
                this.currentState = startCheckpoint.getState();
                this.config = config.withCheckPointId(null);
                this.nextNodeId = startCheckpoint.getNextNodeId();
                // A null current node disables interrupt-before/after for the resumed step.
                this.currentNodeId = null;
                log.trace("RESUME FROM {}", startCheckpoint.getNodeId());
            } else {
                log.trace("START");
                Map<String, Object> initState = CompiledGraph.this.getInitialState(((GraphArgs)input).value(), config);
                AgentState initializedState = (AgentState)CompiledGraph.this.stateGraph.getStateFactory().apply(initState);
                this.currentState = initializedState.data();
                this.nextNodeId = null;
                this.currentNodeId = StateGraph.START;
                this.config = config;
            }
        }

        /** Wraps a clone of the current state as the output of {@code nodeId}. */
        protected Output buildNodeOutput(String nodeId) throws Exception {
            return (Output)NodeOutput.of(nodeId, CompiledGraph.this.cloneState(this.currentState));
        }

        /** Wraps {@code checkpoint} as a {@link StateSnapshot} output. */
        protected Output buildStateSnapshot(Checkpoint checkpoint) throws Exception {
            return (Output)StateSnapshot.of(checkpoint, this.config, CompiledGraph.this.stateGraph.getStateFactory());
        }

        /**
         * If a node returned an {@link AsyncGenerator} among its partial-state values, composes it
         * into this stream.
         *
         * <p>Its outputs are flagged as sub-graph outputs. On completion, its {@code Map} result is
         * merged into the state, together with the node's other values, and the successor is
         * resolved.
         */
        private Optional<AsyncGenerator.Data<Output>> getEmbedGenerator(Map<String, Object> partialState) {
            return partialState.entrySet().stream().filter(e -> e.getValue() instanceof AsyncGenerator).findFirst().map(generatorEntry -> {
                AsyncGenerator generator = (AsyncGenerator)generatorEntry.getValue();
                return AsyncGenerator.Data.composeWith((AsyncGenerator)generator.map(n -> {
                    n.setSubGraph(true);
                    return n;
                }), data -> {
                    if (data != null) {
                        if (data instanceof Map) {
                            Map<String, Object> partialStateWithoutGenerator = partialState.entrySet().stream().filter(e -> !Objects.equals(e.getKey(), generatorEntry.getKey())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
                            Map<String, Object> intermediateState = AgentState.updateState(this.currentState, partialStateWithoutGenerator, CompiledGraph.this.stateGraph.getChannels());
                            this.currentState = AgentState.updateState(intermediateState, (Map<String, Object>)((Map)data), CompiledGraph.this.stateGraph.getChannels());
                        } else {
                            throw new IllegalArgumentException("Embedded generator must return a Map");
                        }
                    }
                    Command nextNodeCommand = CompiledGraph.this.nextNodeId(this.currentNodeId, this.currentState, this.config);
                    this.nextNodeId = nextNodeCommand.gotoNode();
                    this.currentState = nextNodeCommand.update();
                    this.resumedFromEmbed = true;
                });
            });
        }

        /**
         * Runs {@code action} and merges its result into the state.
         *
         * <p>It then resolves the successor and emits the node output. If the result embeds a
         * generator, that generator is composed in instead.
         */
        private CompletableFuture<AsyncGenerator.Data<Output>> evaluateAction(AsyncNodeActionWithConfig<State> action, State withState) {
            return action.apply(withState, this.config).thenApply(TryFunction.Try(updateState -> {
                Optional<AsyncGenerator.Data<Output>> embed = this.getEmbedGenerator((Map<String, Object>)updateState);
                if (embed.isPresent()) {
                    return embed.get();
                }
                this.currentState = AgentState.updateState(this.currentState, (Map<String, Object>)updateState, CompiledGraph.this.stateGraph.getChannels());
                Command nextNodeCommand = CompiledGraph.this.nextNodeId(this.currentNodeId, this.currentState, this.config);
                this.nextNodeId = nextNodeCommand.gotoNode();
                this.currentState = nextNodeCommand.update();
                return AsyncGenerator.Data.of(this.getNodeOutput());
            }));
        }

        /**
         * Checkpoints the current step, if a saver is configured, and returns its output. The
         * output is a snapshot in {@link StreamMode#SNAPSHOTS} mode when a checkpoint was stored.
         */
        private CompletableFuture<Output> getNodeOutput() throws Exception {
            Optional<Checkpoint> cp = CompiledGraph.this.addCheckpoint(this.config, this.currentNodeId, this.currentState, this.nextNodeId);
            return CompletableFuture.completedFuture(cp.isPresent() && this.config.streamMode() == StreamMode.SNAPSHOTS ? this.buildStateSnapshot(cp.get()) : this.buildNodeOutput(this.currentNodeId));
        }

        /** Releases the run's thread in the saver when configured to; empty otherwise. */
        private Optional<BaseCheckpointSaver.Tag> releaseThread() throws Exception {
            if (CompiledGraph.this.compileConfig.releaseThread() && CompiledGraph.this.compileConfig.checkpointSaver().isPresent()) {
                return Optional.of(CompiledGraph.this.compileConfig.checkpointSaver().get().release(this.config));
            }
            return Optional.empty();
        }

        /**
         * Advances the run by one step.
         *
         * <p>Failures, including exceeding the iteration limit, are returned as error data rather
         * than thrown.
         *
         * @return the next output, an interruption or completion marker, or error data
         */
        public AsyncGenerator.Data<Output> next() {
            try {
                if (++this.iteration > CompiledGraph.this.maxIterations) {
                    return AsyncGenerator.Data.error(new IllegalStateException(String.format("Maximum number of iterations (%d) reached!", CompiledGraph.this.maxIterations)));
                }
                // Both cursors cleared: nothing left to run.
                if (this.nextNodeId == null && this.currentNodeId == null) {
                    return this.releaseThread().map(AsyncGenerator.Data::done).orElseGet(() -> AsyncGenerator.Data.done(this.currentState));
                }
                // An embedded generator just completed: emit the output of the node that spawned it.
                if (this.resumedFromEmbed) {
                    CompletableFuture<Output> future = this.getNodeOutput();
                    this.resumedFromEmbed = false;
                    return AsyncGenerator.Data.of(future);
                }
                if (StateGraph.START.equals(this.currentNodeId)) {
                    Command nextNodeCommand = CompiledGraph.this.getEntryPoint(this.currentState, this.config);
                    this.nextNodeId = nextNodeCommand.gotoNode();
                    this.currentState = nextNodeCommand.update();
                    Optional<Checkpoint> cp = CompiledGraph.this.addCheckpoint(this.config, StateGraph.START, this.currentState, this.nextNodeId);
                    Output output = cp.isPresent() && this.config.streamMode() == StreamMode.SNAPSHOTS ? this.buildStateSnapshot(cp.get()) : this.buildNodeOutput(this.currentNodeId);
                    this.currentNodeId = this.nextNodeId;
                    return AsyncGenerator.Data.of(output);
                }
                if (StateGraph.END.equals(this.nextNodeId)) {
                    this.nextNodeId = null;
                    this.currentNodeId = null;
                    return AsyncGenerator.Data.of(this.buildNodeOutput(StateGraph.END));
                }
                State clonedState = CompiledGraph.this.cloneState(this.currentState);
                if (CompiledGraph.this.shouldInterruptAfter(this.currentNodeId, this.nextNodeId)) {
                    return AsyncGenerator.Data.done(InterruptionMetadata.builder(this.currentNodeId, clonedState).build());
                }
                if (CompiledGraph.this.shouldInterruptBefore(this.nextNodeId, this.currentNodeId)) {
                    return AsyncGenerator.Data.done(InterruptionMetadata.builder(this.currentNodeId, clonedState).build());
                }
                this.currentNodeId = this.nextNodeId;
                AsyncNodeActionWithConfig<State> action = CompiledGraph.this.nodes.get(this.currentNodeId);
                if (action == null) {
                    throw RunnableErrors.missingNode.exception(this.currentNodeId);
                }
                // A node may veto its own execution by returning interruption metadata.
                if (action instanceof InterruptableAction interruptable) {
                    Optional interruptMetadata = interruptable.interrupt(this.currentNodeId, clonedState);
                    if (interruptMetadata.isPresent()) {
                        return AsyncGenerator.Data.done(interruptMetadata.get());
                    }
                }
                return this.evaluateAction(action, clonedState).get();
            }
            catch (Exception e) {
                if (e instanceof InterruptedException) {
                    // Future.get() was interrupted: restore the flag for callers up the stack.
                    Thread.currentThread().interrupt();
                }
                log.error(e.getMessage(), e);
                return AsyncGenerator.Data.error(e);
            }
        }
    }

    /** How steps are emitted: plain node outputs or checkpoint snapshots. */
    public enum StreamMode {
        VALUES,
        SNAPSHOTS;

    }
}

