Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/mcp/client/_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ async def connect(
# Unwrap FastMCP to get underlying Server
actual_server: Server[Any]
if isinstance(self._server, FastMCP):
actual_server = self._server._mcp_server # type: ignore[reportPrivateUsage]
actual_server = self._server.mcp_server
else:
actual_server = self._server

Expand Down
1 change: 0 additions & 1 deletion src/mcp/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,5 +296,4 @@ async def list_tools(self, *, cursor: str | None = None, meta: RequestParamsMeta

async def send_roots_list_changed(self) -> None:
"""Send a notification that the roots list has changed."""
# TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support.
await self.session.send_roots_list_changed() # pragma: no cover
12 changes: 10 additions & 2 deletions src/mcp/server/fastmcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,14 @@ def session_manager(self) -> StreamableHTTPSessionManager:
"""
return self._mcp_server.session_manager # pragma: no cover

@property
def mcp_server(self):
"""Get the underlying MCP server instance.

This is exposed to enable advanced use cases like in-memory testing.
"""
return self._mcp_server

@overload
def run(self, transport: Literal["stdio"] = ...) -> None: ...

Expand Down Expand Up @@ -255,8 +263,8 @@ def run(
transport: Transport protocol to use ("stdio", "sse", or "streamable-http")
**kwargs: Transport-specific options (see overloads for details)
"""
TRANSPORTS = Literal["stdio", "sse", "streamable-http"]
if transport not in TRANSPORTS.__args__: # type: ignore # pragma: no cover
SUPPORTED_TRANSPORTS = {"stdio", "sse", "streamable-http"}
if transport not in SUPPORTED_TRANSPORTS: # pragma: no cover
raise ValueError(f"Unknown transport: {transport}")

match transport:
Expand Down
50 changes: 32 additions & 18 deletions src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,34 +133,45 @@ def check_client_capability(self, capability: types.ClientCapabilities) -> bool:

client_caps = self._client_params.capabilities

if capability.roots is not None: # pragma: lax no cover
if client_caps.roots is None:
return False
if capability.roots.list_changed and not client_caps.roots.list_changed:
return False
# Check roots capability
if capability.roots and not client_caps.roots: # pragma: lax no cover
return False
if (
capability.roots
and capability.roots.list_changed
and client_caps.roots
and not client_caps.roots.list_changed
): # pragma: lax no cover
return False

if capability.sampling is not None: # pragma: lax no cover
if client_caps.sampling is None:
return False
if capability.sampling.context is not None and client_caps.sampling.context is None:
# Check sampling capability
if capability.sampling and not client_caps.sampling: # pragma: lax no cover
return False
if capability.sampling and client_caps.sampling: # pragma: lax no cover
if capability.sampling.context and not client_caps.sampling.context: # pragma: lax no cover
return False
if capability.sampling.tools is not None and client_caps.sampling.tools is None:
if capability.sampling.tools and not client_caps.sampling.tools: # pragma: lax no cover
return False

if capability.elicitation is not None and client_caps.elicitation is None: # pragma: lax no cover
# Check elicitation capability
if capability.elicitation and not client_caps.elicitation: # pragma: lax no cover
return False

if capability.experimental is not None: # pragma: lax no cover
if client_caps.experimental is None:
# Check experimental capability
if capability.experimental: # pragma: lax no cover
if not client_caps.experimental: # pragma: lax no cover
return False
for exp_key, exp_value in capability.experimental.items():
if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
for exp_key, exp_value in capability.experimental.items(): # pragma: lax no cover
if (
exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value
): # pragma: lax no cover
return False

if capability.tasks is not None: # pragma: lax no cover
if client_caps.tasks is None:
# Check tasks capability
if capability.tasks: # pragma: lax no cover
if not client_caps.tasks: # pragma: lax no cover
return False
if not check_tasks_capability(capability.tasks, client_caps.tasks):
if not check_tasks_capability(capability.tasks, client_caps.tasks): # pragma: lax no cover
return False

return True
Expand Down Expand Up @@ -207,6 +218,9 @@ async def _received_notification(self, notification: types.ClientNotification) -
match notification:
case types.InitializedNotification():
self._initialization_state = InitializationState.Initialized
case types.RootsListChangedNotification():
# When roots list changes, server should request updated list
await self.list_roots() # pragma: no cover
case _:
if self._initialization_state != InitializationState.Initialized: # pragma: no cover
raise RuntimeError("Received notification before initialization was complete")
Expand Down
Loading