/*
 * Decompiled with CFR 0.152.
 */
package io.modelcontextprotocol.spec;

import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.server.McpAsyncServerExchange;
import io.modelcontextprotocol.server.McpInitRequestHandler;
import io.modelcontextprotocol.server.McpNotificationHandler;
import io.modelcontextprotocol.server.McpRequestHandler;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpLoggableSession;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.util.Assert;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoSink;
import reactor.core.publisher.Sinks;

public class McpServerSession
implements McpLoggableSession {
    private static final Logger logger = LoggerFactory.getLogger(McpServerSession.class);

    // Lifecycle markers stored in `state`.
    private static final int STATE_UNINITIALIZED = 0;
    private static final int STATE_INITIALIZING = 1;
    private static final int STATE_INITIALIZED = 2;

    /** Sinks of outbound requests that are still waiting for a response, keyed by request id. */
    private final ConcurrentHashMap<Object, MonoSink<McpSchema.JSONRPCResponse>> pendingResponses = new ConcurrentHashMap<>();

    /** Session identifier; also used as the prefix of generated request ids. */
    private final String id;

    /** Upper bound on how long an outbound request waits for its response. */
    private final Duration requestTimeout;

    /** Monotonic suffix for generated request ids. */
    private final AtomicLong requestCounter = new AtomicLong(0L);

    /** Handles the "initialize" request. */
    private final McpInitRequestHandler initRequestHandler;

    /** Request handlers keyed by JSON-RPC method name. */
    private final Map<String, McpRequestHandler<?>> requestHandlers;

    /** Notification handlers keyed by JSON-RPC method name. */
    private final Map<String, McpNotificationHandler> notificationHandlers;

    /** Transport used to send messages to the client and to unmarshal payloads. */
    private final McpServerTransport transport;

    /** Receives the exchange once the "notifications/initialized" notification is handled. */
    private final Sinks.One<McpAsyncServerExchange> exchangeSink = Sinks.one();

    /** Capabilities reported by the client in its "initialize" request. */
    private final AtomicReference<McpSchema.ClientCapabilities> clientCapabilities = new AtomicReference<>();

    /** Implementation info reported by the client in its "initialize" request. */
    private final AtomicReference<McpSchema.Implementation> clientInfo = new AtomicReference<>();

    /** Current lifecycle marker (one of the STATE_* constants). */
    private final AtomicInteger state = new AtomicInteger(STATE_UNINITIALIZED);

    /** Minimum level a logging notification must have to be allowed; starts at INFO. */
    private volatile McpSchema.LoggingLevel minLoggingLevel = McpSchema.LoggingLevel.INFO;

    /**
     * Creates a session bound to the given transport.
     *
     * @param id session identifier
     * @param requestTimeout maximum time to wait for a response to an outbound request
     * @param transport transport used to exchange messages with the client
     * @param initHandler handler invoked for the "initialize" request
     * @param requestHandlers request handlers keyed by method name
     * @param notificationHandlers notification handlers keyed by method name
     */
    public McpServerSession(
            String id,
            Duration requestTimeout,
            McpServerTransport transport,
            McpInitRequestHandler initHandler,
            Map<String, McpRequestHandler<?>> requestHandlers,
            Map<String, McpNotificationHandler> notificationHandlers) {
        // Identity and wiring.
        this.id = id;
        this.transport = transport;
        this.requestTimeout = requestTimeout;

        // Message handlers.
        this.initRequestHandler = initHandler;
        this.requestHandlers = requestHandlers;
        this.notificationHandlers = notificationHandlers;
    }

    /**
     * Creates a session bound to the given transport.
     *
     * @param id session identifier
     * @param requestTimeout maximum time to wait for a response to an outbound request
     * @param transport transport used to exchange messages with the client
     * @param initHandler handler invoked for the "initialize" request
     * @param initNotificationHandler not used by this constructor
     * @param requestHandlers request handlers keyed by method name
     * @param notificationHandlers notification handlers keyed by method name
     * @deprecated {@code initNotificationHandler} is ignored; use the constructor that
     * does not take it.
     */
    @Deprecated
    public McpServerSession(
            String id,
            Duration requestTimeout,
            McpServerTransport transport,
            McpInitRequestHandler initHandler,
            InitNotificationHandler initNotificationHandler,
            Map<String, McpRequestHandler<?>> requestHandlers,
            Map<String, McpNotificationHandler> notificationHandlers) {
        this(id, requestTimeout, transport, initHandler, requestHandlers, notificationHandlers);
    }

    /**
     * Returns the identifier this session was created with.
     *
     * @return the session id
     */
    public String getId() {
        return id;
    }

    /**
     * Stores the capabilities and implementation info the client reported about itself.
     * Within this class it is called while handling the "initialize" request; the stored
     * values are read back when the exchange is created for "notifications/initialized".
     *
     * @param clientCapabilities capabilities declared by the client
     * @param clientInfo implementation info declared by the client
     */
    public void init(McpSchema.ClientCapabilities clientCapabilities, McpSchema.Implementation clientInfo) {
        // lazySet gives an ordered write without the full fence of set().
        this.clientCapabilities.lazySet(clientCapabilities);
        this.clientInfo.lazySet(clientInfo);
    }

    /**
     * Builds the id for the next outbound request as {@code <sessionId>-<sequence>},
     * advancing the per-session counter.
     *
     * @return the new request id
     */
    private String generateRequestId() {
        long sequence = requestCounter.getAndIncrement();
        return id + "-" + sequence;
    }

    /**
     * Sets the minimum level a logging notification must have to be allowed.
     *
     * @param minLoggingLevel the new threshold; must not be {@code null}
     */
    @Override
    public void setMinLoggingLevel(McpSchema.LoggingLevel minLoggingLevel) {
        Assert.notNull(minLoggingLevel, "minLoggingLevel must not be null");
        this.minLoggingLevel = minLoggingLevel;
    }

    /**
     * Tells whether a logging notification of the given level passes the session's
     * current minimum level.
     *
     * @param loggingLevel level of the notification to check
     * @return {@code true} when the level is at or above the configured minimum
     */
    @Override
    public boolean isNotificationForLevelAllowed(McpSchema.LoggingLevel loggingLevel) {
        var requested = loggingLevel.level();
        var threshold = this.minLoggingLevel.level();
        return requested >= threshold;
    }

    /**
     * Sends a JSON-RPC request to the client and returns its decoded result.
     *
     * <p>The returned Mono registers a pending-response sink under a freshly generated
     * request id, sends the request through the transport and waits up to the configured
     * request timeout for {@link #handle} to deliver the matching response.
     *
     * @param method JSON-RPC method name
     * @param requestParams request parameters, passed to the transport as-is
     * @param typeRef target type of the result; {@code Void} completes without a value
     * @param <T> result type
     * @return a Mono emitting the unmarshalled result, completing empty for {@code Void},
     * or failing with {@link McpError} when the response carries an error, with the
     * transport's send error, or with a timeout error
     */
    @Override
    public <T> Mono<T> sendRequest(String method, Object requestParams, TypeRef<T> typeRef) {
        String requestId = this.generateRequestId();
        return Mono.<McpSchema.JSONRPCResponse>create(pending -> {
            this.pendingResponses.put(requestId, pending);
            // Unregister on every way the sink can end (response, send failure, timeout,
            // cancellation) so abandoned requests do not accumulate in the map. The
            // two-argument remove only drops the entry while it still maps to this sink.
            pending.onDispose(() -> this.pendingResponses.remove(requestId, pending));
            McpSchema.JSONRPCRequest jsonrpcRequest = new McpSchema.JSONRPCRequest("2.0", method, requestId, requestParams);
            // Failing the sink also triggers the onDispose cleanup above.
            this.transport.sendMessage(jsonrpcRequest).subscribe(v -> {}, pending::error);
        }).timeout(this.requestTimeout).handle((jsonRpcResponse, resultSink) -> {
            if (jsonRpcResponse.error() != null) {
                resultSink.error(new McpError(jsonRpcResponse.error()));
            } else if (typeRef.getType().equals(Void.class)) {
                resultSink.complete();
            } else {
                resultSink.next(this.transport.unmarshalFrom(jsonRpcResponse.result(), typeRef));
            }
        });
    }

    /**
     * Sends a JSON-RPC notification to the client through the transport.
     *
     * @param method notification method name
     * @param params notification parameters
     * @return the Mono returned by the transport for the send operation
     */
    @Override
    public Mono<Void> sendNotification(String method, Object params) {
        return this.transport.sendMessage(new McpSchema.JSONRPCNotification("2.0", method, params));
    }

    /**
     * Dispatches an inbound JSON-RPC message.
     *
     * <ul>
     * <li>Responses complete the pending sink registered under the response id.</li>
     * <li>Requests are routed to {@code handleIncomingRequest}; the resulting response,
     * or an error response built from a failure, is sent back through the transport.</li>
     * <li>Notifications are routed to {@code handleIncomingNotification}; failures are
     * logged and still propagated to the subscriber.</li>
     * </ul>
     *
     * The transport context is read from the Reactor context key
     * {@code "MCP_TRANSPORT_CONTEXT"}, falling back to {@link McpTransportContext#EMPTY}.
     *
     * @param message the inbound message
     * @return a Mono that completes when the message has been processed
     */
    public Mono<Void> handle(McpSchema.JSONRPCMessage message) {
        return Mono.deferContextual(ctx -> {
            McpTransportContext transportContext = ctx.getOrDefault("MCP_TRANSPORT_CONTEXT", McpTransportContext.EMPTY);

            if (message instanceof McpSchema.JSONRPCResponse response) {
                logger.debug("Received response: {}", response);
                if (response.id() == null) {
                    logger.error("Discarded MCP request response without session id. This is an indication of a bug in the request sender code that can lead to memory leaks as pending requests will never be completed.");
                    return Mono.empty();
                }
                MonoSink<McpSchema.JSONRPCResponse> pending = this.pendingResponses.remove(response.id());
                if (pending == null) {
                    logger.warn("Unexpected response for unknown id {}", response.id());
                } else {
                    pending.success(response);
                }
                return Mono.empty();
            }

            if (message instanceof McpSchema.JSONRPCRequest request) {
                logger.debug("Received request: {}", request);
                return this.handleIncomingRequest(request, transportContext).onErrorResume(error -> {
                    // Prefer the JSON-RPC error carried by an McpError; otherwise report an
                    // internal error (-32603) built from the exception.
                    McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError;
                    if (error instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
                        jsonRpcError = mcpError.getJsonRpcError();
                    } else {
                        jsonRpcError = new McpSchema.JSONRPCResponse.JSONRPCError(-32603, error.getMessage(), McpError.aggregateExceptionMessages(error));
                    }
                    McpSchema.JSONRPCResponse errorResponse = new McpSchema.JSONRPCResponse("2.0", request.id(), null, jsonRpcError);
                    // The error response is sent here, so continue with an empty Mono to
                    // keep the flatMap below from sending anything else.
                    return this.transport.sendMessage(errorResponse).then(Mono.empty());
                }).flatMap(this.transport::sendMessage);
            }

            if (message instanceof McpSchema.JSONRPCNotification notification) {
                logger.debug("Received notification: {}", notification);
                // Pass the throwable as the last argument so the stack trace is logged too.
                return this.handleIncomingNotification(notification, transportContext)
                        .doOnError(error -> logger.error("Error handling notification: {}", error.getMessage(), error));
            }

            logger.warn("Received unknown message type: {}", message);
            return Mono.empty();
        });
    }

    /**
     * Produces the JSON-RPC response for an inbound request.
     *
     * <p>An "initialize" request is unmarshalled, recorded on the session (state and
     * client details) and passed to the init handler. Any other method is looked up in
     * the request handlers: an unknown method yields a "method not found" (-32601) error
     * response, a known one runs once the exchange sink has emitted the exchange.
     * Failures signalled by the handler's Mono are converted into an error response.
     *
     * @param request the inbound request
     * @param transportContext transport context to attach to the exchange given to handlers
     * @return a Mono emitting the response to send back
     */
    private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCRequest request, McpTransportContext transportContext) {
        return Mono.defer(() -> {
            // Handlers produce different result types, hence the wildcard.
            Mono<?> resultMono;
            if ("initialize".equals(request.method())) {
                McpSchema.InitializeRequest initializeRequest = this.transport.unmarshalFrom(request.params(), new TypeRef<McpSchema.InitializeRequest>() {});
                this.state.lazySet(STATE_INITIALIZING);
                this.init(initializeRequest.capabilities(), initializeRequest.clientInfo());
                resultMono = this.initRequestHandler.handle(initializeRequest);
            } else {
                McpRequestHandler<?> handler = this.requestHandlers.get(request.method());
                if (handler == null) {
                    MethodNotFoundError notFound = this.getMethodNotFoundError(request.method());
                    return Mono.just(new McpSchema.JSONRPCResponse("2.0", request.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError(-32601, notFound.message(), notFound.data())));
                }
                // Waits for the exchange, which is emitted on "notifications/initialized".
                resultMono = this.exchangeSink.asMono().flatMap(exchange -> handler.handle(this.copyExchange(exchange, transportContext), request.params()));
            }
            return resultMono.map(result -> new McpSchema.JSONRPCResponse("2.0", request.id(), result, null)).onErrorResume(error -> {
                // Prefer the JSON-RPC error carried by an McpError; otherwise report an
                // internal error (-32603) built from the exception.
                McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError;
                if (error instanceof McpError mcpError && mcpError.getJsonRpcError() != null) {
                    jsonRpcError = mcpError.getJsonRpcError();
                } else {
                    jsonRpcError = new McpSchema.JSONRPCResponse.JSONRPCError(-32603, error.getMessage(), McpError.aggregateExceptionMessages(error));
                }
                return Mono.just(new McpSchema.JSONRPCResponse("2.0", request.id(), null, jsonRpcError));
            });
        });
    }

    /**
     * Processes an inbound notification.
     *
     * <p>On "notifications/initialized" the session is marked initialized and the
     * exchange, built from the stored client details, is emitted on the exchange sink.
     * The notification is then passed to the handler registered for its method, if any;
     * without a handler a warning is logged and the Mono completes empty.
     *
     * @param notification the inbound notification
     * @param transportContext transport context to attach to the exchange given to the handler
     * @return a Mono that completes when the handler has finished
     */
    private Mono<Void> handleIncomingNotification(McpSchema.JSONRPCNotification notification, McpTransportContext transportContext) {
        return Mono.defer(() -> {
            if ("notifications/initialized".equals(notification.method())) {
                this.state.lazySet(STATE_INITIALIZED);
                Sinks.EmitResult emitResult = this.exchangeSink.tryEmitValue(new McpAsyncServerExchange(this.id, this, this.clientCapabilities.get(), this.clientInfo.get(), transportContext));
                if (emitResult.isFailure()) {
                    // The sink accepts a single value; report a rejected emission instead
                    // of dropping the result silently.
                    logger.warn("Server exchange was not published for initialized notification (emit result: {})", emitResult);
                }
            }
            McpNotificationHandler handler = this.notificationHandlers.get(notification.method());
            if (handler == null) {
                logger.warn("No handler registered for notification method: {}", notification);
                return Mono.empty();
            }
            return this.exchangeSink.asMono().flatMap(exchange -> handler.handle(this.copyExchange(exchange, transportContext), notification.params()));
        });
    }

    /**
     * Creates a new exchange with the same session id and client details as the given
     * one, bound to this session and to the supplied transport context.
     *
     * @param exchange exchange to copy the session id and client details from
     * @param transportContext transport context for the new exchange
     * @return the new exchange
     */
    private McpAsyncServerExchange copyExchange(McpAsyncServerExchange exchange, McpTransportContext transportContext) {
        var sessionId = exchange.sessionId();
        var capabilities = exchange.getClientCapabilities();
        var client = exchange.getClientInfo();
        return new McpAsyncServerExchange(sessionId, this, capabilities, client, transportContext);
    }

    /**
     * Builds the error details for a request whose method has no registered handler.
     *
     * @param method the unknown method name
     * @return error details with the message {@code "Method not found: <method>"} and no data
     */
    private MethodNotFoundError getMethodNotFoundError(String method) {
        String message = "Method not found: " + method;
        return new MethodNotFoundError(method, message, null);
    }

    /**
     * Closes the session by delegating to the transport's graceful close.
     *
     * @return the Mono returned by the transport
     */
    @Override
    public Mono<Void> closeGracefully() {
        return transport.closeGracefully();
    }

    /**
     * Closes the session by delegating to the transport's {@code close()}.
     */
    @Override
    public void close() {
        transport.close();
    }

    /**
     * Details used to build the error response for a request with an unknown method.
     *
     * @param method the requested method name
     * @param message human-readable error message
     * @param data optional additional error data
     */
    record MethodNotFoundError(String method, String message, Object data) {}

    /**
     * Creates {@link McpServerSession} instances for a transport.
     */
    @FunctionalInterface
    public interface Factory {
        /**
         * Creates a session for the given transport.
         *
         * @param var1 the transport the session will use
         * @return the new session
         */
        McpServerSession create(McpServerTransport var1);
    }

    /**
     * Handler for an inbound request that produces a result of type {@code T}.
     *
     * @param <T> result type
     * @deprecated the session's request handlers are typed as {@link McpRequestHandler}.
     */
    @Deprecated
    public interface RequestHandler<T> {
        /**
         * Handles a request.
         *
         * @param var1 the exchange for the current session
         * @param var2 the request parameters
         * @return a Mono emitting the result
         */
        Mono<T> handle(McpAsyncServerExchange var1, Object var2);
    }

    /**
     * Handler for an inbound notification.
     *
     * @deprecated the session's notification handlers are typed as
     * {@link McpNotificationHandler}.
     */
    @Deprecated
    public interface NotificationHandler {
        /**
         * Handles a notification.
         *
         * @param var1 the exchange for the current session
         * @param var2 the notification parameters
         * @return a Mono that completes when handling is done
         */
        Mono<Void> handle(McpAsyncServerExchange var1, Object var2);
    }

    public static interface InitNotificationHandler {
        public Mono<Void> handle();
    }

    /**
     * Handler for the "initialize" request.
     *
     * @deprecated the session's init handler is typed as {@link McpInitRequestHandler}.
     */
    @Deprecated
    public interface InitRequestHandler {
        /**
         * Handles the initialize request.
         *
         * @param var1 the initialize request sent by the client
         * @return a Mono emitting the initialize result
         */
        Mono<McpSchema.InitializeResult> handle(McpSchema.InitializeRequest var1);
    }
}

