|
4 | 4 | # File Name: course_graph/agent/mcp.py |
5 | 5 | # Description: MCP Server 和 Client 实现相关 |
6 | 6 |
|
7 | | -from typing import Literal |
| 7 | +from typing import TypedDict, Required, NotRequired |
8 | 8 | from mcp import ClientSession, StdioServerParameters |
9 | 9 | from mcp.client.stdio import stdio_client |
| 10 | +from mcp.client.sse import sse_client |
10 | 11 | from contextlib import AsyncExitStack |
11 | 12 | from mcp.types import Tool |
12 | 13 |
|
13 | 14 |
|
| 15 | +class STDIO(TypedDict): |
| 16 | + command: Required[str] |
| 17 | + args: Required[list[str]] |
| 18 | + envs: NotRequired[dict[str, str]] |
| 19 | + |
| 20 | + |
| 21 | +class SSE(TypedDict): |
| 22 | + url: Required[str] |
| 23 | + headers: NotRequired[dict[str, str]] |
| 24 | + |
| 25 | + |
14 | 26 | class MCPServer: |
15 | | - def __init__(self, type: Literal['stdio'], command: str, args: list[str], envs: dict[str, str] = None): |
| 27 | + def __init__(self, server: STDIO | SSE): |
16 | 28 | """ MCP 服务器 |
17 | 29 |
|
18 | 30 | Args: |
19 | | - type (Literal['stdio']): 服务器类型, 目前只支持 'stdio' |
20 | | - command (str): 命令 |
21 | | - args (list[str]): 参数 |
22 | | - envs (dict[str, str], optional): 环境变量. Defaults to None. |
| 31 | + server (STDIO | SSE): MCP 服务器配置 |
23 | 32 | """ |
24 | | - self.params = StdioServerParameters(command=command, args=args, envs=envs) |
| 33 | + self.server = server |
25 | 34 | self.stack = AsyncExitStack() |
26 | | - self.session = None |
| 35 | + self.session: ClientSession = None |
| 36 | + self.tools: list[Tool] = None |
27 | 37 |
|
28 | 38 | async def __aenter__(self): |
29 | | - stdio_transport = await self.stack.enter_async_context(stdio_client(self.params)) |
30 | | - self.stdio, self.write = stdio_transport |
31 | | - self.session = await self.stack.enter_async_context(ClientSession(self.stdio, self.write)) |
| 39 | + if 'command' in self.server.keys(): |
| 40 | + self.params = StdioServerParameters(command=self.server['command'], args=self.server['args'], envs=self.server.get('envs')) |
| 41 | + transport = await self.stack.enter_async_context(stdio_client(self.params)) |
| 42 | + else: |
| 43 | + transport = await self.stack.enter_async_context(sse_client(url=self.server['url'], headers=self.server.get('headers'))) |
| 44 | + self.session = await self.stack.enter_async_context(ClientSession(transport[0], transport[1])) |
32 | 45 | await self.session.initialize() |
| 46 | + |
| 47 | + self.tools = (await self.session.list_tools()).tools |
33 | 48 | return self |
34 | 49 |
|
35 | | - async def list_tools(self) -> list[Tool]: |
36 | | - return (await self.session.list_tools()).tools |
37 | | - |
38 | 50 | async def __aexit__(self, exc_type, exc_value, traceback): |
39 | 51 | await self.stack.aclose() |
0 commit comments