@@ -55,8 +55,13 @@ def set_agent_instruction(self, agent: Agent) -> None:
5555 match agent .instruction :
5656 case _ if callable (agent .instruction ):
5757 parameters = inspect .signature (agent .instruction ).parameters
58- args = (self .context_variables ,) if len (parameters ) == 1 else ()
59- agent .llm .instruction = agent .instruction (* args )
58+ args = {}
59+ for arg_name , p in parameters .items ():
60+ if p .annotation == ContextVariables :
61+ args [arg_name ] = self .context_variables
62+ else :
63+ args [arg_name ] = agent .instruction_args .get (arg_name , p .default )
64+ agent .llm .instruction = agent .instruction (** args )
6065 case _:
6166 agent .llm .instruction = agent .instruction
6267
@@ -88,32 +93,26 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
8893 self ._add_trace_event (TraceEvent (
8994 timestamp = datetime .now (),
9095 event_type = TraceEventType .USER_MESSAGE ,
91- agent_name = agent . name ,
96+ agent = agent ,
9297 data = {'message' : message }
9398 ))
9499
95- assistant_output = agent .chat_completion (message )
96-
97- self ._add_trace_event (TraceEvent (
98- timestamp = datetime .now (),
99- event_type = TraceEventType .AGENT_MESSAGE ,
100- agent_name = agent .name ,
101- data = {'message' : assistant_output .content }
102- ))
100+ assistant_output = agent .chat_completion (message )
103101
104102 while assistant_output .tool_calls : # None 或者空数组
105103 functions = assistant_output .tool_calls
106104 for item in functions :
107105 function = item .function
108106 args = json .loads (function .arguments )
109107
110- self ._add_trace_event (TraceEvent (
111- timestamp = datetime .now (),
112- event_type = TraceEventType .TOOL_CALL ,
113- agent_name = agent .name ,
114- data = {'function' : function .name , 'arguments' : args }
115- ))
116108 if (tool_function := agent .tool_functions .get (function .name )) is not None :
109+
110+ self ._add_trace_event (TraceEvent (
111+ timestamp = datetime .now (),
112+ event_type = TraceEventType .TOOL_CALL ,
113+ agent = agent ,
114+ data = {'function' : function .name , 'arguments' : args }
115+ ))
117116
118117 # 自动注入上下文变量
119118 if (var_name := agent .use_context_variables .get (function .name )) is not None :
@@ -141,6 +140,13 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
141140 result = Result ()
142141
143142 elif (mcp_sever := agent .mcp_functions .get (function .name )) is not None :
143+ self ._add_trace_event (TraceEvent (
144+ timestamp = datetime .now (),
145+ agent = agent ,
146+ event_type = TraceEventType .MCP_TOOL_CALL ,
147+ data = {'function' : function .name , 'arguments' : args }
148+ ))
149+
144150 resp = (await mcp_sever .session .call_tool (function .name , args )).content
145151 text_contents = []
146152 for content in resp :
@@ -158,11 +164,17 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
158164 else :
159165 result = Result (content = f'Failed to call tool: { function .name } ' )
160166
167+ trace_result = {'content' : result .content }
168+ if result .context_variables ._vars :
169+ trace_result ['context_variables' ] = result .context_variables ._vars
170+ if not result .message :
171+ trace_result ['message' ] = False
172+
161173 self ._add_trace_event (TraceEvent (
162174 timestamp = datetime .now (),
163- agent_name = agent . name ,
175+ agent = agent ,
164176 event_type = TraceEventType .TOOL_RESULT ,
165- data = {'function' : function .name , 'result' : result }
177+ data = {'function' : function .name , 'result' : trace_result }
166178 ))
167179
168180 agent .add_tool_call_message (result .content , item .id )
@@ -172,19 +184,19 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
172184
173185 self ._add_trace_event (TraceEvent (
174186 timestamp = datetime .now (),
175- agent_name = agent . name ,
187+ agent = agent ,
176188 event_type = TraceEventType .AGENT_SWITCH ,
177189 data = {'to_agent' : result .agent .name }
178190 ))
179191 agent = result .agent
180-
181- self ._add_trace_event (TraceEvent (
182- timestamp = datetime .now (),
183- agent_name = agent . name ,
184- event_type = TraceEventType .CONTEXT_UPDATE ,
185- data = {'old_context' : self .context_variables , 'new_context' : result .context_variables }
186- ))
187- self .context_variables .update (result .context_variables )
192+ if result . context_variables . _vars :
193+ self ._add_trace_event (TraceEvent (
194+ timestamp = datetime .now (),
195+ agent = agent ,
196+ event_type = TraceEventType .CONTEXT_UPDATE ,
197+ data = {'old_context' : self .context_variables , 'new_context' : result .context_variables }
198+ ))
199+ self .context_variables .update (result .context_variables )
188200
189201 self .set_agent_instruction (agent )
190202
@@ -193,13 +205,13 @@ async def run(self, agent: Agent, message: str = None) -> tuple[Agent, str]:
193205 if turn > self .max_turns :
194206 raise MaxTurnsException
195207
196- message = assistant_output .content
197208 self ._add_trace_event (TraceEvent (
198209 timestamp = datetime .now (),
199210 event_type = TraceEventType .AGENT_MESSAGE ,
200- agent_name = agent . name ,
201- data = {'message' : message }
211+ agent = agent ,
212+ data = {'message' : assistant_output . content }
202213 ))
214+
203215 self .trace ['end_time' ] = datetime .now ()
204216
205217 return agent , assistant_output .content
0 commit comments