/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.client.transport;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.client.transport.AsyncHttpRequestCustomizer;
import io.modelcontextprotocol.client.transport.ResponseSubscribers;
import io.modelcontextprotocol.client.transport.SyncHttpRequestCustomizer;
import io.modelcontextprotocol.spec.DefaultMcpTransportSession;
import io.modelcontextprotocol.spec.DefaultMcpTransportStream;
import io.modelcontextprotocol.spec.McpClientTransport;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpTransportSession;
import io.modelcontextprotocol.spec.McpTransportSessionNotFoundException;
import io.modelcontextprotocol.spec.McpTransportStream;
import io.modelcontextprotocol.util.Assert;
import io.modelcontextprotocol.util.Utils;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.time.Duration;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import org.reactivestreams.Publisher;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
import reactor.util.function.Tuples;

/**
 * {@link McpClientTransport} that speaks the MCP "Streamable HTTP" transport using the JDK
 * {@link HttpClient}.
 *
 * <p>Every outgoing JSON-RPC message is sent as an HTTP {@code POST} to a single endpoint
 * (resolved from the base URI and the configured endpoint path). The response is handled
 * according to its {@code Content-Type}:
 * <ul>
 * <li>no content type: no body is expected;</li>
 * <li>{@code text/event-stream}: each SSE event is parsed as one JSON-RPC message;</li>
 * <li>{@code application/json}: the body is parsed as one JSON-RPC message.</li>
 * </ul>
 *
 * <p>Parsed messages are passed to the handler registered in {@link #connect(Function)}. When the
 * active session reports (via {@code markInitialized}) that a {@code mcp-session-id} response
 * header established it, the transport additionally opens a {@code GET} SSE stream for
 * server-initiated messages.
 *
 * <p>Instances are created through {@link #builder(String)}.
 */
public class HttpClientStreamableHttpTransport implements McpClientTransport {

    private static final Logger logger = LoggerFactory.getLogger(HttpClientStreamableHttpTransport.class);

    /** Protocol revision announced in the {@code MCP-Protocol-Version} request header. */
    private static final String MCP_PROTOCOL_VERSION = "2025-03-26";

    /** Endpoint path used when the builder is not given one explicitly. */
    private static final String DEFAULT_ENDPOINT = "/mcp";

    /** SSE event type whose data carries a JSON-RPC message on the GET stream. */
    private static final String MESSAGE_EVENT_TYPE = "message";

    private static final String APPLICATION_JSON = "application/json";

    private static final String TEXT_EVENT_STREAM = "text/event-stream";

    /** Header carrying the server-assigned session id, on both requests and responses. */
    private static final String SESSION_ID_HEADER = "mcp-session-id";

    /** Header announcing the protocol revision on every request. */
    private static final String PROTOCOL_VERSION_HEADER = "MCP-Protocol-Version";

    // NOTE(review): these are public and non-final. They are kept that way for source
    // compatibility with any external code that references them, but they look like they were
    // meant to be constants.
    public static int NOT_FOUND = 404;
    public static int METHOD_NOT_ALLOWED = 405;
    public static int BAD_REQUEST = 400;

    private final HttpClient httpClient;

    /** Template request; it is {@code copy()}-ed before every request and never mutated directly. */
    private final HttpRequest.Builder requestBuilder;

    private final ObjectMapper objectMapper;
    private final URI baseUri;
    private final String endpoint;
    private final boolean openConnectionOnStartup;
    private final boolean resumableStreams;
    private final AsyncHttpRequestCustomizer httpRequestCustomizer;

    /** Current session. It is replaced (never cleared) on close or when the server rejects it. */
    private final AtomicReference<DefaultMcpTransportSession> activeSession = new AtomicReference<>();

    /** Handler that receives every inbound JSON-RPC message; set by {@link #connect(Function)}. */
    private final AtomicReference<Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>>> handler = new AtomicReference<>();

    /** Optional observer of transport errors; set by {@link #setExceptionHandler(Consumer)}. */
    private final AtomicReference<Consumer<Throwable>> exceptionHandler = new AtomicReference<>();

    private HttpClientStreamableHttpTransport(ObjectMapper objectMapper, HttpClient httpClient,
            HttpRequest.Builder requestBuilder, String baseUri, String endpoint, boolean resumableStreams,
            boolean openConnectionOnStartup, AsyncHttpRequestCustomizer httpRequestCustomizer) {
        this.objectMapper = objectMapper;
        this.httpClient = httpClient;
        this.requestBuilder = requestBuilder;
        this.baseUri = URI.create(baseUri);
        this.endpoint = endpoint;
        this.resumableStreams = resumableStreams;
        this.openConnectionOnStartup = openConnectionOnStartup;
        this.httpRequestCustomizer = httpRequestCustomizer;
        // Created last: the session's close callback captures `this` (it calls createDelete,
        // which reads httpRequestCustomizer), so every field must already be assigned.
        this.activeSession.set(this.createTransportSession());
    }

    /** Returns the protocol revision this transport announces to the server. */
    @Override
    public String protocolVersion() {
        return MCP_PROTOCOL_VERSION;
    }

    /**
     * Creates a builder for a transport targeting the given base URI.
     * @param baseUri base URI of the MCP server; must not be empty
     * @return a new builder
     */
    public static Builder builder(String baseUri) {
        return new Builder(baseUri);
    }

    /**
     * Registers the handler for inbound messages. If eager connection was requested, this also
     * opens the GET stream. A failure of that eager attempt is only logged, so the returned Mono
     * still completes.
     * @param handler receives each inbound JSON-RPC message
     * @return a Mono completing once the handler is registered (and any eager connect attempted)
     */
    @Override
    public Mono<Void> connect(Function<Mono<McpSchema.JSONRPCMessage>, Mono<McpSchema.JSONRPCMessage>> handler) {
        return Mono.deferContextual(ctx -> {
            this.handler.set(handler);
            if (this.openConnectionOnStartup) {
                logger.debug("Eagerly opening connection on startup");
                return this.reconnect(null).onErrorComplete(t -> {
                    logger.warn("Eager connect failed ", t);
                    return true;
                }).then();
            }
            return Mono.empty();
        });
    }

    /**
     * Creates a new session. When the session is closed, its callback sends an HTTP DELETE for
     * the session id, or does nothing if no id was ever assigned.
     */
    private DefaultMcpTransportSession createTransportSession() {
        Function<String, Publisher<Void>> onClose = sessionId -> sessionId == null ? Mono.empty()
                : this.createDelete(sessionId);
        return new DefaultMcpTransportSession(onClose);
    }

    /**
     * Builds a lazy publisher that sends {@code DELETE} for the given session to the MCP
     * endpoint. The response body is read as a string and discarded.
     */
    private Publisher<Void> createDelete(String sessionId) {
        URI uri = Utils.resolveUri(this.baseUri, this.endpoint);
        return Mono.defer(() -> {
            HttpRequest.Builder builder = this.requestBuilder.copy()
                .uri(uri)
                .header("Cache-Control", "no-cache")
                .header(SESSION_ID_HEADER, sessionId)
                .header(PROTOCOL_VERSION_HEADER, MCP_PROTOCOL_VERSION)
                .DELETE();
            return Mono.from(this.httpRequestCustomizer.customize(builder, "DELETE", uri, null));
        }).flatMap(customized -> {
            HttpRequest request = customized.build();
            return Mono.fromFuture(() -> this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()));
        }).then();
    }

    @Override
    public void setExceptionHandler(Consumer<Throwable> handler) {
        logger.debug("Exception handler registered");
        this.exceptionHandler.set(handler);
    }

    /**
     * Central error sink for both the GET stream and POST exchanges. If the server no longer
     * recognizes the session, the active session is swapped for a new one and the old one is
     * closed. The error is then forwarded to the registered exception handler, if any.
     */
    private void handleException(Throwable t) {
        logger.debug("Handling exception for session {}", sessionIdOrPlaceholder(this.activeSession.get()), t);
        if (t instanceof McpTransportSessionNotFoundException) {
            McpTransportSession<?> invalidSession = this.activeSession.getAndSet(this.createTransportSession());
            logger.warn("Server does not recognize session {}. Invalidating.", invalidSession.sessionId());
            invalidSession.close();
        }
        Consumer<Throwable> handler = this.exceptionHandler.get();
        if (handler != null) {
            handler.accept(t);
        }
    }

    /** Replaces the active session with a new one and gracefully closes the previous one. */
    @Override
    public Mono<Void> closeGracefully() {
        return Mono.defer(() -> {
            logger.debug("Graceful close triggered");
            DefaultMcpTransportSession currentSession = this.activeSession.getAndSet(this.createTransportSession());
            if (currentSession != null) {
                return currentSession.closeGracefully();
            }
            return Mono.empty();
        });
    }

    /**
     * Opens (or re-opens) the {@code GET} SSE stream on which the server can push messages.
     *
     * <p>The subscription is started immediately and registered with the active session. It is
     * removed from the session again when the stream terminates. Errors are routed to
     * {@link #handleException(Throwable)} rather than propagated.
     * @param stream the stream being resumed (its last event id is sent as
     * {@code last-event-id}), or {@code null} for a brand-new stream
     * @return a Mono emitting the {@link Disposable} of the running subscription
     */
    private Mono<Disposable> reconnect(McpTransportStream<Disposable> stream) {
        return Mono.deferContextual(ctx -> {
            if (stream != null) {
                logger.debug("Reconnecting stream {} with lastId {}", stream.streamId(), stream.lastId());
            } else {
                logger.debug("Reconnecting with no prior stream");
            }
            AtomicReference<Disposable> disposableRef = new AtomicReference<>();
            DefaultMcpTransportSession transportSession = this.activeSession.get();
            URI uri = Utils.resolveUri(this.baseUri, this.endpoint);

            Disposable connection = Mono.defer(() -> {
                HttpRequest.Builder builder = withSessionHeader(this.requestBuilder.copy(), transportSession);
                if (stream != null) {
                    stream.lastId().ifPresent(lastId -> builder.header("last-event-id", lastId));
                }
                builder.uri(uri)
                    .header("Accept", TEXT_EVENT_STREAM)
                    .header("Cache-Control", "no-cache")
                    .header(PROTOCOL_VERSION_HEADER, MCP_PROTOCOL_VERSION)
                    .GET();
                return Mono.from(this.httpRequestCustomizer.customize(builder, "GET", uri, null));
            })
                .flatMapMany(customized -> Flux.<ResponseSubscribers.ResponseEvent>create(sseSink -> this.httpClient
                    .sendAsync(customized.build(),
                            responseInfo -> ResponseSubscribers.sseToBodySubscriber(responseInfo, sseSink))
                    .whenComplete((response, throwable) -> {
                        if (throwable != null) {
                            sseSink.error(throwable);
                        } else {
                            logger.debug("SSE connection established successfully");
                        }
                    }))
                    .map(responseEvent -> (ResponseSubscribers.SseResponseEvent) responseEvent)
                    .flatMap(responseEvent -> this.handleStreamEvent(responseEvent, stream, transportSession))
                    .flatMap(message -> this.handler.get().apply(Mono.just(message)))
                    .onErrorMap(CompletionException.class, Throwable::getCause)
                    .onErrorComplete(t -> {
                        this.handleException(t);
                        return true;
                    })
                    .doFinally(signal -> {
                        Disposable ref = disposableRef.getAndSet(null);
                        if (ref != null) {
                            transportSession.removeConnection(ref);
                        }
                    }))
                .contextWrite(ctx)
                .subscribe();

            disposableRef.set(connection);
            transportSession.addConnection(connection);
            return Mono.just(connection);
        });
    }

    /**
     * Translates one event received on the GET stream into zero or one JSON-RPC messages.
     * <ul>
     * <li>2xx with event type {@code message}: the data is parsed and fed through the transport
     * stream (a new one is created when {@code stream} is null);</li>
     * <li>2xx with any other event type: ignored;</li>
     * <li>405: the server does not offer a GET stream, so nothing is emitted;</li>
     * <li>404 or 400: fails with {@link McpTransportSessionNotFoundException};</li>
     * <li>anything else: fails with {@link McpError}.</li>
     * </ul>
     */
    private Flux<McpSchema.JSONRPCMessage> handleStreamEvent(ResponseSubscribers.SseResponseEvent responseEvent,
            McpTransportStream<Disposable> stream, McpTransportSession<?> transportSession) {
        int statusCode = responseEvent.responseInfo().statusCode();
        if (statusCode >= 200 && statusCode < 300) {
            if (!MESSAGE_EVENT_TYPE.equals(responseEvent.sseEvent().event())) {
                logger.debug("Received SSE event with type: {}", responseEvent.sseEvent().event());
                return Flux.empty();
            }
            try {
                // Batching is not supported: each SSE event carries exactly one message.
                McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(this.objectMapper,
                        responseEvent.sseEvent().data());
                Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> idWithMessages = Tuples
                    .of(Optional.ofNullable(responseEvent.sseEvent().id()), List.of(message));
                McpTransportStream<Disposable> sessionStream = stream != null ? stream
                        : new DefaultMcpTransportStream<>(this.resumableStreams, this::reconnect);
                logger.debug("Connected stream {}", sessionStream.streamId());
                return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages)));
            } catch (IOException ioException) {
                return Flux.<McpSchema.JSONRPCMessage>error(
                        new McpError("Error parsing JSON-RPC message: " + responseEvent.sseEvent().data()));
            }
        }
        if (statusCode == METHOD_NOT_ALLOWED) {
            logger.debug("The server does not support SSE streams, using request-response mode.");
            return Flux.empty();
        }
        if (statusCode == NOT_FOUND || statusCode == BAD_REQUEST) {
            return Flux.<McpSchema.JSONRPCMessage>error(sessionNotFound(transportSession));
        }
        return Flux.<McpSchema.JSONRPCMessage>error(
                new McpError("Received unrecognized SSE event type: " + responseEvent.sseEvent().event()));
    }

    /**
     * Returns a body handler that picks the response subscriber from the response's
     * {@code Content-Type}: SSE parsing, whole-body aggregation for JSON, or a discarding
     * subscriber for anything else.
     */
    private HttpResponse.BodyHandler<Void> toSendMessageBodySubscriber(FluxSink<ResponseSubscribers.ResponseEvent> sink) {
        return responseInfo -> {
            String contentType = contentTypeOf(responseInfo);
            if (contentType.contains(TEXT_EVENT_STREAM)) {
                logger.debug("Received SSE stream response, using line subscriber");
                return ResponseSubscribers.sseToBodySubscriber(responseInfo, sink);
            }
            if (contentType.contains(APPLICATION_JSON)) {
                logger.debug("Received response, using string subscriber");
                return ResponseSubscribers.aggregateBodySubscriber(responseInfo, sink);
            }
            logger.debug("Received Bodyless response, using discarding subscriber");
            return ResponseSubscribers.bodilessBodySubscriber(responseInfo, sink);
        };
    }

    /**
     * Serializes a JSON-RPC message with this transport's {@link ObjectMapper}.
     * @throws RuntimeException wrapping the serialization failure
     */
    public String toString(McpSchema.JSONRPCMessage message) {
        try {
            return this.objectMapper.writeValueAsString(message);
        } catch (IOException e) {
            throw new RuntimeException("Failed to serialize JSON-RPC message", e);
        }
    }

    /**
     * POSTs a JSON-RPC message and feeds any reply to the registered handler.
     *
     * <p>The returned Mono completes once the response has been classified successfully:
     * <ul>
     * <li>there is no content type;</li>
     * <li>an SSE event is parsed;</li>
     * <li>a JSON body is received.</li>
     * </ul>
     * It errors when the request or the response handling fails. In that case the error is also
     * routed through {@link #handleException(Throwable)}.
     * @param sentMessage message to send
     * @return completion signal for the delivery of {@code sentMessage}
     */
    @Override
    public Mono<Void> sendMessage(McpSchema.JSONRPCMessage sentMessage) {
        return Mono.create(deliveredSink -> {
            logger.debug("Sending message {}", sentMessage);
            AtomicReference<Disposable> disposableRef = new AtomicReference<>();
            DefaultMcpTransportSession transportSession = this.activeSession.get();
            URI uri = Utils.resolveUri(this.baseUri, this.endpoint);
            String jsonBody = this.toString(sentMessage);

            Disposable connection = Mono.defer(() -> {
                HttpRequest.Builder builder = withSessionHeader(this.requestBuilder.copy(), transportSession)
                    .uri(uri)
                    .header("Accept", "application/json, text/event-stream")
                    .header("Content-Type", APPLICATION_JSON)
                    .header("Cache-Control", "no-cache")
                    .header(PROTOCOL_VERSION_HEADER, MCP_PROTOCOL_VERSION)
                    .POST(HttpRequest.BodyPublishers.ofString(jsonBody));
                // The customizer must be told the real HTTP method of this request.
                return Mono.from(this.httpRequestCustomizer.customize(builder, "POST", uri, jsonBody));
            }).flatMapMany(customized -> Flux.<ResponseSubscribers.ResponseEvent>create(responseEventSink -> Mono
                .fromFuture(this.httpClient
                    .sendAsync(customized.build(), this.toSendMessageBodySubscriber(responseEventSink))
                    .whenComplete((response, throwable) -> {
                        if (throwable != null) {
                            responseEventSink.error(throwable);
                        } else {
                            logger.debug("SSE connection established successfully");
                        }
                    }))
                // Failures were already pushed into responseEventSink above; this subscription
                // only drives the future.
                .onErrorMap(CompletionException.class, Throwable::getCause)
                .onErrorComplete()
                .subscribe())).flatMap(responseEvent -> {
                    HttpResponse.ResponseInfo responseInfo = responseEvent.responseInfo();
                    if (transportSession.markInitialized(responseInfo.headers().firstValue(SESSION_ID_HEADER).orElse(null))) {
                        // The session was just initialized: open the server-to-client GET stream.
                        this.reconnect(null).contextWrite(deliveredSink.contextView()).subscribe();
                    }
                    String sessionRepresentation = sessionIdOrPlaceholder(transportSession);
                    int statusCode = responseInfo.statusCode();
                    if (statusCode >= 200 && statusCode < 300) {
                        String contentType = contentTypeOf(responseInfo);
                        if (contentType.isBlank()) {
                            logger.debug("No content type returned for POST in session {}", sessionRepresentation);
                            deliveredSink.success();
                            return Flux.empty();
                        }
                        if (contentType.contains(TEXT_EVENT_STREAM)) {
                            return Flux.just(((ResponseSubscribers.SseResponseEvent) responseEvent).sseEvent())
                                .flatMap(sseEvent -> {
                                    try {
                                        // Batching is not supported: one message per SSE event.
                                        McpSchema.JSONRPCMessage message = McpSchema
                                            .deserializeJsonRpcMessage(this.objectMapper, sseEvent.data());
                                        Tuple2<Optional<String>, Iterable<McpSchema.JSONRPCMessage>> idWithMessages = Tuples
                                            .of(Optional.ofNullable(sseEvent.id()), List.of(message));
                                        McpTransportStream<Disposable> sessionStream = new DefaultMcpTransportStream<>(
                                                this.resumableStreams, this::reconnect);
                                        logger.debug("Connected stream {}", sessionStream.streamId());
                                        deliveredSink.success();
                                        return Flux.from(sessionStream.consumeSseStream(Flux.just(idWithMessages)));
                                    } catch (IOException ioException) {
                                        return Flux.<McpSchema.JSONRPCMessage>error(
                                                new McpError("Error parsing JSON-RPC message: " + sseEvent.data()));
                                    }
                                });
                        }
                        if (contentType.contains(APPLICATION_JSON)) {
                            deliveredSink.success();
                            String data = ((ResponseSubscribers.AggregateResponseEvent) responseEvent).data();
                            if (sentMessage instanceof McpSchema.JSONRPCNotification && Utils.hasText(data)) {
                                logger.warn("Notification: {} received non-compliant response: {}", sentMessage, data);
                                return Mono.empty();
                            }
                            try {
                                return Mono.just(McpSchema.deserializeJsonRpcMessage(this.objectMapper, data));
                            } catch (IOException e) {
                                return Mono.error(e);
                            }
                        }
                        logger.warn("Unknown media type {} returned for POST in session {}", contentType,
                                sessionRepresentation);
                        return Flux.<McpSchema.JSONRPCMessage>error(
                                new RuntimeException("Unknown media type returned: " + contentType));
                    }
                    if (statusCode == NOT_FOUND || statusCode == BAD_REQUEST) {
                        return Flux.<McpSchema.JSONRPCMessage>error(sessionNotFound(transportSession));
                    }
                    return Flux.<McpSchema.JSONRPCMessage>error(
                            new RuntimeException("Failed to send message: " + String.valueOf(responseEvent)));
                })
                .flatMap(jsonRpcMessage -> this.handler.get().apply(Mono.just(jsonRpcMessage)))
                .onErrorMap(CompletionException.class, Throwable::getCause)
                .onErrorComplete(t -> {
                    // Let the transport react (e.g. invalidate the session) before the caller is told.
                    this.handleException(t);
                    deliveredSink.error(t);
                    return true;
                })
                .doFinally(signal -> {
                    logger.debug("SendMessage finally: {}", signal);
                    Disposable ref = disposableRef.getAndSet(null);
                    if (ref != null) {
                        transportSession.removeConnection(ref);
                    }
                })
                .contextWrite(deliveredSink.contextView())
                .subscribe();

            disposableRef.set(connection);
            transportSession.addConnection(connection);
        });
    }

    /**
     * Adds the {@code mcp-session-id} header when the session has an id; otherwise leaves the
     * builder unchanged.
     * @return the same builder, for chaining
     */
    private static HttpRequest.Builder withSessionHeader(HttpRequest.Builder builder,
            McpTransportSession<?> transportSession) {
        if (transportSession != null) {
            transportSession.sessionId().ifPresent(id -> builder.header(SESSION_ID_HEADER, id));
        }
        return builder;
    }

    /**
     * Returns the response's {@code Content-Type}, lower-cased, or {@code ""} when absent.
     * {@link Locale#ROOT} is used so the result does not depend on the JVM default locale
     * (e.g. Turkish dotless-i mapping of {@code "APPLICATION/JSON"}).
     */
    private static String contentTypeOf(HttpResponse.ResponseInfo responseInfo) {
        return responseInfo.headers().firstValue("Content-Type").orElse("").toLowerCase(Locale.ROOT);
    }

    /** Builds the error raised when the server answers 404/400 for the given session. */
    private static McpTransportSessionNotFoundException sessionNotFound(McpTransportSession<?> transportSession) {
        return new McpTransportSessionNotFoundException(
                "Session not found for session ID: " + sessionIdOrPlaceholder(transportSession));
    }

    /** Returns the session id, or a fixed placeholder for log and error messages when none is assigned. */
    private static String sessionIdOrPlaceholder(McpTransportSession<?> transportSession) {
        return transportSession.sessionId().orElse("[missing_session_id]");
    }

    /** Converts an already-parsed value to {@code T} using this transport's {@link ObjectMapper}. */
    @Override
    public <T> T unmarshalFrom(Object data, TypeReference<T> typeRef) {
        return this.objectMapper.convertValue(data, typeRef);
    }

    /**
     * Builder for {@link HttpClientStreamableHttpTransport}.
     *
     * <p>Defaults:
     * <ul>
     * <li>HTTP/1.1 client with a 10 second connect timeout;</li>
     * <li>endpoint {@code /mcp};</li>
     * <li>resumable streams enabled;</li>
     * <li>no eager connection;</li>
     * <li>no-op request customizer;</li>
     * <li>a default {@link ObjectMapper} if none is provided.</li>
     * </ul>
     */
    public static class Builder {
        private final String baseUri;
        private ObjectMapper objectMapper;
        private HttpClient.Builder clientBuilder = HttpClient.newBuilder()
            .version(HttpClient.Version.HTTP_1_1)
            .connectTimeout(Duration.ofSeconds(10L));
        private String endpoint = DEFAULT_ENDPOINT;
        private boolean resumableStreams = true;
        private boolean openConnectionOnStartup = false;
        private HttpRequest.Builder requestBuilder = HttpRequest.newBuilder();
        private AsyncHttpRequestCustomizer httpRequestCustomizer = AsyncHttpRequestCustomizer.NOOP;

        private Builder(String baseUri) {
            Assert.hasText(baseUri, "baseUri must not be empty");
            this.baseUri = baseUri;
        }

        /** Replaces the HTTP client builder used to create the transport's client. */
        public Builder clientBuilder(HttpClient.Builder clientBuilder) {
            Assert.notNull(clientBuilder, "clientBuilder must not be null");
            this.clientBuilder = clientBuilder;
            return this;
        }

        /** Applies a customization to the current HTTP client builder. */
        public Builder customizeClient(Consumer<HttpClient.Builder> clientCustomizer) {
            Assert.notNull(clientCustomizer, "clientCustomizer must not be null");
            clientCustomizer.accept(this.clientBuilder);
            return this;
        }

        /** Replaces the template request builder that every request is copied from. */
        public Builder requestBuilder(HttpRequest.Builder requestBuilder) {
            Assert.notNull(requestBuilder, "requestBuilder must not be null");
            this.requestBuilder = requestBuilder;
            return this;
        }

        /** Applies a customization to the current template request builder. */
        public Builder customizeRequest(Consumer<HttpRequest.Builder> requestCustomizer) {
            Assert.notNull(requestCustomizer, "requestCustomizer must not be null");
            requestCustomizer.accept(this.requestBuilder);
            return this;
        }

        /** Sets the JSON mapper used for (de)serializing messages. */
        public Builder objectMapper(ObjectMapper objectMapper) {
            Assert.notNull(objectMapper, "ObjectMapper must not be null");
            this.objectMapper = objectMapper;
            return this;
        }

        /** Sets the endpoint path, resolved against the base URI. */
        public Builder endpoint(String endpoint) {
            Assert.hasText(endpoint, "endpoint must be a non-empty String");
            this.endpoint = endpoint;
            return this;
        }

        /** Sets the resumable flag passed to each created transport stream. */
        public Builder resumableStreams(boolean resumableStreams) {
            this.resumableStreams = resumableStreams;
            return this;
        }

        /** Whether {@code connect} should eagerly open the GET SSE stream. */
        public Builder openConnectionOnStartup(boolean openConnectionOnStartup) {
            this.openConnectionOnStartup = openConnectionOnStartup;
            return this;
        }

        /** Sets a synchronous per-request customizer, adapted to the async interface. */
        public Builder httpRequestCustomizer(SyncHttpRequestCustomizer syncHttpRequestCustomizer) {
            Assert.notNull(syncHttpRequestCustomizer, "syncHttpRequestCustomizer must not be null");
            this.httpRequestCustomizer = AsyncHttpRequestCustomizer.fromSync(syncHttpRequestCustomizer);
            return this;
        }

        /** Sets an asynchronous per-request customizer. */
        public Builder asyncHttpRequestCustomizer(AsyncHttpRequestCustomizer asyncHttpRequestCustomizer) {
            Assert.notNull(asyncHttpRequestCustomizer, "asyncHttpRequestCustomizer must not be null");
            this.httpRequestCustomizer = asyncHttpRequestCustomizer;
            return this;
        }

        /** Builds the transport; a default {@link ObjectMapper} is created if none was set. */
        public HttpClientStreamableHttpTransport build() {
            ObjectMapper mapper = this.objectMapper != null ? this.objectMapper : new ObjectMapper();
            return new HttpClientStreamableHttpTransport(mapper, this.clientBuilder.build(), this.requestBuilder,
                    this.baseUri, this.endpoint, this.resumableStreams, this.openConnectionOnStartup,
                    this.httpRequestCustomizer);
        }
    }
}

