Skip to content

Commit ecca00d

Browse files
committed
feat: 支持 SSE 连接
1 parent f29ef14 commit ecca00d

5 files changed

Lines changed: 39 additions & 35 deletions

File tree

docs/tutorials/agent/agent.md

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,17 +309,15 @@ import asyncio
309309
qwen = Qwen()
310310

311311
async def main():
312-
async with MCPServer(
313-
type='stdio',
314-
command='uv',
315-
args=['--directory', 'examples/agent', 'run', 'mcp_server.py'],
316-
) as mcp_server:
312+
async with MCPServer( {
313+
'command': 'uv',
314+
'args': ['--directory', 'examples/agent', 'run', 'mcp_server.py'],
315+
} ) as mcp_server:
317316

318317
agent = Agent(
319318
llm=qwen,
320319
mcp_server=[mcp_server]
321320
)
322-
await agent.initialize()
323321
controller = Controller()
324322
_, resp = await controller.run(agent, "帮我查询南京今天的天气")
325323
print(resp)
@@ -328,12 +326,10 @@ if __name__ == '__main__':
328326
asyncio.run(main())
329327
```
330328

331-
这里有三需要注意:
329+
这里有两点需要注意:
332330

333331
- MCP Server 必须使用 `async with` 语句块来启动。
334332

335-
- 如果给 `Agent` 传递了 `mcp_server` 参数, 必须调用 `await agent.initialize()` 方法等待初始化完成。
336-
337333
- `Controller` 必须使用异步的 `run` 方法来启动。不能使用同步的 `run_sync` 方法, 因为本质上 `run_sync` 只是 `asyncio.run(controller.run(...))` 的包装。
338334

339335
### MCP Server 与外部工具的比较

examples/agent/use_mcp.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,24 @@
44
# File Name: examples/agent/use_mcp.py
55
# Description: 使用 MCP 工具
66

7-
from course_graph.agent import Agent, Controller, MCPServer, TraceEvent
7+
from course_graph.agent import Agent, Controller, MCPServer
88
from course_graph.llm import Qwen
99
import asyncio
1010
qwen = Qwen()
1111

1212

1313
async def main():
1414
async with MCPServer(
15-
type='stdio',
16-
command='uv',
17-
args=['--directory', 'examples/agent', 'run', 'mcp_server.py'],
15+
{
16+
'command': 'uv',
17+
'args': ['--directory', 'examples/agent', 'run', 'mcp_server.py'],
18+
}
1819
) as mcp_server:
1920

2021
agent = Agent(
2122
llm=qwen,
2223
mcp_server=[mcp_server]
2324
)
24-
await agent.initialize()
2525
controller = Controller()
2626
_, resp = await controller.run(agent, "帮我查询今天南京的天气")
2727
print(resp)

src/course_graph/agent/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@
88
from .agent import Agent
99
from .controller import Controller
1010
from .types import Result, ContextVariables, TraceEvent
11-
from .mcp import MCPServer
11+
from .mcp import MCPServer, STDIO, SSE

src/course_graph/agent/agent.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,9 @@ def __init__(
6969
self.tool_choice = tool_choice
7070

7171
self.messages: list[ChatCompletionMessageParam] = []
72-
self.mcp_server = mcp_server
7372

74-
async def initialize(self):
75-
""" 等待初始化智能体
76-
"""
77-
for server in self.mcp_server:
78-
tools = await server.list_tools()
73+
for server in mcp_server:
74+
tools = server.tools
7975
for tool in tools:
8076
self.tools.append({
8177
'type': 'function',

src/course_graph/agent/mcp.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,36 +4,48 @@
44
# File Name: course_graph/agent/mcp.py
55
# Description: MCP Server 和 Client 实现相关
66

7-
from typing import Literal
7+
from typing import TypedDict, Required, NotRequired
88
from mcp import ClientSession, StdioServerParameters
99
from mcp.client.stdio import stdio_client
10+
from mcp.client.sse import sse_client
1011
from contextlib import AsyncExitStack
1112
from mcp.types import Tool
1213

1314

15+
class STDIO(TypedDict):
16+
command: Required[str]
17+
args: Required[list[str]]
18+
envs: NotRequired[dict[str, str]]
19+
20+
21+
class SSE(TypedDict):
22+
url: Required[str]
23+
headers: NotRequired[dict[str, str]]
24+
25+
1426
class MCPServer:
15-
def __init__(self, type: Literal['stdio'], command: str, args: list[str], envs: dict[str, str] = None):
27+
def __init__(self, server: STDIO | SSE):
1628
""" MCP 服务器
1729
1830
Args:
19-
type (Literal['stdio']): 服务器类型, 目前只支持 'stdio'
20-
command (str): 命令
21-
args (list[str]): 参数
22-
envs (dict[str, str], optional): 环境变量. Defaults to None.
31+
server (STDIO | SSE): MCP 服务器配置
2332
"""
24-
self.params = StdioServerParameters(command=command, args=args, envs=envs)
33+
self.server = server
2534
self.stack = AsyncExitStack()
26-
self.session = None
35+
self.session: ClientSession = None
36+
self.tools: list[Tool] = None
2737

2838
async def __aenter__(self):
29-
stdio_transport = await self.stack.enter_async_context(stdio_client(self.params))
30-
self.stdio, self.write = stdio_transport
31-
self.session = await self.stack.enter_async_context(ClientSession(self.stdio, self.write))
39+
if 'command' in self.server.keys():
40+
self.params = StdioServerParameters(command=self.server['command'], args=self.server['args'], envs=self.server.get('envs'))
41+
transport = await self.stack.enter_async_context(stdio_client(self.params))
42+
else:
43+
transport = await self.stack.enter_async_context(sse_client(url=self.server['url'], headers=self.server.get('headers')))
44+
self.session = await self.stack.enter_async_context(ClientSession(transport[0], transport[1]))
3245
await self.session.initialize()
46+
47+
self.tools = (await self.session.list_tools()).tools
3348
return self
3449

35-
async def list_tools(self) -> list[Tool]:
36-
return (await self.session.list_tools()).tools
37-
3850
async def __aexit__(self, exc_type, exc_value, traceback):
3951
await self.stack.aclose()

0 commit comments

Comments
 (0)