Skip to content

Commit bceb1cb

Browse files
committed
feat: 为 agent 添加 instruction args
1 parent e2a09ef commit bceb1cb

5 files changed

Lines changed: 81 additions & 43 deletions

File tree

docs/tutorials/agent/agent.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,19 @@ translator = Agent(name='translator',
2626

2727
当然你也可以将具体的指令需求直接写在 `instruction` 中, 这种方式将在 [工作流编排](#工作流编排) 中具体解释。
2828

29+
### 指令参数
30+
31+
在创建 `Agent` 对象时, `instruction` 不仅可以是一个字符串,也可以是一个函数,但这个函数 **必需** 返回一个字符串。该函数可以通过 `instruction_args` 传递参数, 这些参数将会在智能体初始化时传递给 `instruction` 函数。
32+
33+
```python
34+
def get_instruction(name: str):
35+
return f'你的名字是{name}'
36+
37+
agent = Agent(llm=llm,
38+
instruction=get_instruction,
39+
instruction_args={'name': '张三'})
40+
```
41+
2942
## 创建一个控制器
3043

3144
```python
@@ -194,7 +207,7 @@ controller = Controller(context_variables={'current_time': '2024/09/01'})
194207

195208
#### instruction中使用
196209

197-
在创建 `Agent` 对象时, `instruction` 不仅可以是一个字符串,也可以是一个函数,但这个函数 **必需** 返回一个字符串, 同时可以额外传递一个`ContextVariables` 类型的形参
210+
`instruction` 类型为函数时, 可以额外传递一个`ContextVariables` 类型的形参, 但 **必需** 标注这个形参的类型为 `ContextVariables` 类型
198211

199212
``` python
200213
from course_graph.agent import ContextVariables
@@ -214,7 +227,7 @@ assistant = Agent(name="assistant",
214227

215228
在定义外部工具函数时, 也可以传递一个 `ContextVariables` 类型的形参。同样的, 控制器也会在调用这些函数的时候自动注入上下文变量。
216229

217-
虽然不需要在文档中描述这个形参, 但是 **必需** 标注这个形参的类型为 `ContextVariables` 类型:
230+
相同的,虽然不需要在文档中描述这个形参, 但是 **必需** 标注这个形参的类型为 `ContextVariables` 类型:
218231

219232
```python
220233
def get_weather(location: str, context_variables: ContextVariables) -> str:
@@ -354,6 +367,7 @@ controller = Controller(trace_callback=pprint)
354367
- `TOOL_CALL`: 工具调用
355368
- `TOOL_RESULT`: 工具调用结果
356369
- `CONTEXT_UPDATE`: 上下文变量更新
370+
- `MCP_TOOL_CALL`: MCP 工具调用
357371

358372
## 多智能体编排
359373

src/course_graph/agent/agent.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
from openai.types.chat import *
1111
import inspect
1212
import docstring_parser
13-
from typing import Callable, Awaitable
14-
from typing import Literal
13+
from typing import Callable, Awaitable, Literal
1514
from openai import NOT_GIVEN, NotGiven
1615
from .mcp import MCPServer
16+
from shortuuid import uuid
1717

1818

1919
class Agent:
@@ -23,9 +23,10 @@ def __init__(
2323
llm: LLMBase,
2424
name: str = 'Assistant',
2525
functions: list[Callable | Awaitable] = None,
26-
tool_choice: str | NotGiven | Literal['required', 'auto', 'none'] = NOT_GIVEN,
27-
parallel_tool_calls: bool | NotGiven = NOT_GIVEN,
28-
instruction: str | Callable[[ContextVariables], str] | Callable[[], str] = 'You are a helpful assistant.',
26+
tool_choice: str | Literal['required', 'auto', 'none'] = 'auto',
27+
parallel_tool_calls: bool = False,
28+
instruction: str | Callable[..., str] = 'You are a helpful assistant.',
29+
instruction_args: dict = None,
2930
mcp_server: list[MCPServer] = None,
3031
mcp_impl: Literal['function_call'] = 'function_call'
3132
) -> None:
@@ -36,15 +37,17 @@ def __init__(
3637
name (str, optional): 名称. Defaults to 'Assistant'.
3738
functions: (list[Callable | Awaitable], optional): 工具函数. Defaults to None.
3839
parallel_tool_calls: (bool, optional): 允许工具并行调用. Defaults to False.
39-
tool_choice: (Literal['required', 'auto', 'none'] | NotGiven, optional). 强制使用工具函数, 选择模式或提供函数名称. Defaults to NOT_GIVEN.
40-
instruction (str | Callable[[ContextVariables], str] | Callable[[], str], optional): 指令. Defaults to 'You are a helpful assistant.'.
40+
tool_choice: (Literal['required', 'auto', 'none'], optional). 强制使用工具函数, 选择模式或提供函数名称. Defaults to 'auto'.
41+
instruction (str | Callable[Any, str], optional): 指令. Defaults to 'You are a helpful assistant.'.
42+
instruction_args: (dict, optional): 指令参数. Defaults to {}.
4143
mcp_server: (list[MCPServer], optional): MCP 服务器. Defaults to None.
4244
mcp_impl: (Literal['function_call'] | NotGiven, optional): MCP 协议实现方式, 目前只支持 'function_call'. Defaults to 'function_call'.
4345
"""
46+
self.id = str(uuid())
4447
self.llm = llm
4548
self.name = name
4649
self.instruction = instruction
47-
50+
self.instruction_args = instruction_args
4851
self.tools: list[ChatCompletionToolParam] = [] # for LLM
4952

5053
self.tool_functions: dict[str, Callable | Awaitable] = {} # for local function call
@@ -56,8 +59,9 @@ def __init__(
5659

5760
if functions:
5861
self.add_tool_functions(*functions)
62+
5963

60-
if tool_choice != NOT_GIVEN and tool_choice not in ['required', 'auto', 'none']: # 需要注意工具无限循环
64+
if tool_choice not in ['required', 'auto', 'none']: # 需要注意工具无限循环
6165
self.tool_choice = {
6266
"type": "function",
6367
"function": {
@@ -281,3 +285,6 @@ def add_tool_functions(self, *functions: Callable | Awaitable) -> 'Agent':
281285
}
282286
})
283287
return self
288+
289+
def __repr__(self) -> str:
290+
return f'Agent(id={self.id}, name={self.name})'

src/course_graph/agent/controller.py

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,13 @@ def set_agent_instruction(self, agent: Agent) -> None:
5555
match agent.instruction:
5656
case _ if callable(agent.instruction):
5757
parameters = inspect.signature(agent.instruction).parameters
58-
args = (self.context_variables,) if len(parameters) == 1 else ()
59-
agent.llm.instruction = agent.instruction(*args)
58+
args = {}
59+
for arg_name, p in parameters.items():
60+
if p.annotation == ContextVariables:
61+
args[arg_name] = self.context_variables
62+
else:
63+
args[arg_name] = agent.instruction_args.get(arg_name, p.default)
64+
agent.llm.instruction = agent.instruction(**args)
6065
case _:
6166
agent.llm.instruction = agent.instruction
6267

@@ -88,32 +93,26 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
8893
self._add_trace_event(TraceEvent(
8994
timestamp=datetime.now(),
9095
event_type=TraceEventType.USER_MESSAGE,
91-
agent_name=agent.name,
96+
agent=agent,
9297
data={'message': message}
9398
))
9499

95-
assistant_output = agent.chat_completion(message)
96-
97-
self._add_trace_event(TraceEvent(
98-
timestamp=datetime.now(),
99-
event_type=TraceEventType.AGENT_MESSAGE,
100-
agent_name=agent.name,
101-
data={'message': assistant_output.content}
102-
))
100+
assistant_output = agent.chat_completion(message)
103101

104102
while assistant_output.tool_calls: # None 或者空数组
105103
functions = assistant_output.tool_calls
106104
for item in functions:
107105
function = item.function
108106
args = json.loads(function.arguments)
109107

110-
self._add_trace_event(TraceEvent(
111-
timestamp=datetime.now(),
112-
event_type=TraceEventType.TOOL_CALL,
113-
agent_name=agent.name,
114-
data={'function': function.name, 'arguments': args}
115-
))
116108
if (tool_function := agent.tool_functions.get(function.name)) is not None:
109+
110+
self._add_trace_event(TraceEvent(
111+
timestamp=datetime.now(),
112+
event_type=TraceEventType.TOOL_CALL,
113+
agent=agent,
114+
data={'function': function.name, 'arguments': args}
115+
))
117116

118117
# 自动注入上下文变量
119118
if (var_name := agent.use_context_variables.get(function.name)) is not None:
@@ -141,6 +140,13 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
141140
result = Result()
142141

143142
elif (mcp_sever := agent.mcp_functions.get(function.name)) is not None:
143+
self._add_trace_event(TraceEvent(
144+
timestamp=datetime.now(),
145+
agent=agent,
146+
event_type=TraceEventType.MCP_TOOL_CALL,
147+
data={'function': function.name, 'arguments': args}
148+
))
149+
144150
resp = (await mcp_sever.session.call_tool(function.name, args)).content
145151
text_contents = []
146152
for content in resp:
@@ -158,11 +164,17 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
158164
else:
159165
result = Result(content=f'Failed to call tool: {function.name}')
160166

167+
trace_result = {'content': result.content}
168+
if result.context_variables._vars:
169+
trace_result['context_variables'] = result.context_variables._vars
170+
if not result.message:
171+
trace_result['message'] = False
172+
161173
self._add_trace_event(TraceEvent(
162174
timestamp=datetime.now(),
163-
agent_name=agent.name,
175+
agent=agent,
164176
event_type=TraceEventType.TOOL_RESULT,
165-
data={'function': function.name, 'result': result}
177+
data={'function': function.name, 'result': trace_result}
166178
))
167179

168180
agent.add_tool_call_message(result.content, item.id)
@@ -172,19 +184,19 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
172184

173185
self._add_trace_event(TraceEvent(
174186
timestamp=datetime.now(),
175-
agent_name=agent.name,
187+
agent=agent,
176188
event_type=TraceEventType.AGENT_SWITCH,
177189
data={'to_agent': result.agent.name}
178190
))
179191
agent = result.agent
180-
181-
self._add_trace_event(TraceEvent(
182-
timestamp=datetime.now(),
183-
agent_name=agent.name,
184-
event_type=TraceEventType.CONTEXT_UPDATE,
185-
data={'old_context': self.context_variables, 'new_context': result.context_variables}
186-
))
187-
self.context_variables.update(result.context_variables)
192+
if result.context_variables._vars:
193+
self._add_trace_event(TraceEvent(
194+
timestamp=datetime.now(),
195+
agent=agent,
196+
event_type=TraceEventType.CONTEXT_UPDATE,
197+
data={'old_context': self.context_variables, 'new_context': result.context_variables}
198+
))
199+
self.context_variables.update(result.context_variables)
188200

189201
self.set_agent_instruction(agent)
190202

@@ -193,13 +205,13 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
193205
if turn > self.max_turns:
194206
raise MaxTurnsException
195207

196-
message = assistant_output.content
197208
self._add_trace_event(TraceEvent(
198209
timestamp=datetime.now(),
199210
event_type=TraceEventType.AGENT_MESSAGE,
200-
agent_name=agent.name,
201-
data={'message': message}
211+
agent=agent,
212+
data={'message': assistant_output.content}
202213
))
214+
203215
self.trace['end_time'] = datetime.now()
204216

205217
return agent, assistant_output.content

src/course_graph/agent/trace.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import TypedDict, List
1010
from datetime import datetime
1111
from course_graph import set_logger, logger
12+
from course_graph.agent import Agent
1213

1314
set_logger(console=True, file=False)
1415

@@ -20,12 +21,13 @@ class TraceEventType(Enum):
2021
TOOL_CALL = 'tool_call'
2122
TOOL_RESULT = 'tool_result'
2223
CONTEXT_UPDATE = 'context_update'
24+
MCP_TOOL_CALL = 'mcp_tool_call'
2325

2426

2527
@dataclass
2628
class TraceEvent:
2729
timestamp: datetime
28-
agent_name: str
30+
agent: Agent
2931
event_type: TraceEventType
3032
data: dict
3133

src/course_graph/agent/types.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def update(self, other: Union[dict, 'ContextVariables']):
4444
self._vars.update(other)
4545
else:
4646
self._vars.update(other._vars)
47+
48+
def get(self, key: Any, default: Any = None) -> Any:
49+
return self._vars.get(key, default)
4750

4851

4952
@dataclass

0 commit comments

Comments
 (0)