@@ -774,90 +774,91 @@ async def main():

         toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets)
         # This will raise errors for any name conflicts
-        run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
-
-        # Merge model settings in order of precedence: run > agent > model
-        merged_settings = merge_model_settings(model_used.settings, self.model_settings)
-        model_settings = merge_model_settings(merged_settings, model_settings)
-        usage_limits = usage_limits or _usage.UsageLimits()
-        agent_name = self.name or 'agent'
-        run_span = tracer.start_span(
-            'agent run',
-            attributes={
-                'model_name': model_used.model_name if model_used else 'no-model',
-                'agent_name': agent_name,
-                'logfire.msg': f'{agent_name} run',
-            },
-        )
-
-        async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
-            parts = [
-                self._instructions,
-                *[await func.run(run_context) for func in self._instructions_functions],
-            ]
-
-            model_profile = model_used.profile
-            if isinstance(output_schema, _output.PromptedOutputSchema):
-                instructions = output_schema.instructions(model_profile.prompted_output_template)
-                parts.append(instructions)
+        async with toolset:
+            run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context)
+
+            # Merge model settings in order of precedence: run > agent > model
+            merged_settings = merge_model_settings(model_used.settings, self.model_settings)
+            model_settings = merge_model_settings(merged_settings, model_settings)
+            usage_limits = usage_limits or _usage.UsageLimits()
+            agent_name = self.name or 'agent'
+            run_span = tracer.start_span(
+                'agent run',
+                attributes={
+                    'model_name': model_used.model_name if model_used else 'no-model',
+                    'agent_name': agent_name,
+                    'logfire.msg': f'{agent_name} run',
+                },
+            )

-            parts = [p for p in parts if p]
-            if not parts:
-                return None
-            return '\n\n'.join(parts).strip()
+            async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
+                parts = [
+                    self._instructions,
+                    *[await func.run(run_context) for func in self._instructions_functions],
+                ]

-        graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
-            user_deps=deps,
-            prompt=user_prompt,
-            new_message_index=new_message_index,
-            model=model_used,
-            model_settings=model_settings,
-            usage_limits=usage_limits,
-            max_result_retries=self._max_result_retries,
-            end_strategy=self.end_strategy,
-            output_schema=output_schema,
-            output_validators=output_validators,
-            history_processors=self.history_processors,
-            tool_manager=run_toolset,
-            tracer=tracer,
-            get_instructions=get_instructions,
-            instrumentation_settings=instrumentation_settings,
-        )
-        start_node = _agent_graph.UserPromptNode[AgentDepsT](
-            user_prompt=user_prompt,
-            instructions=self._instructions,
-            instructions_functions=self._instructions_functions,
-            system_prompts=self._system_prompts,
-            system_prompt_functions=self._system_prompt_functions,
-            system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
-        )
+                model_profile = model_used.profile
+                if isinstance(output_schema, _output.PromptedOutputSchema):
+                    instructions = output_schema.instructions(model_profile.prompted_output_template)
+                    parts.append(instructions)
+
+                parts = [p for p in parts if p]
+                if not parts:
+                    return None
+                return '\n\n'.join(parts).strip()
+
+            graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT](
+                user_deps=deps,
+                prompt=user_prompt,
+                new_message_index=new_message_index,
+                model=model_used,
+                model_settings=model_settings,
+                usage_limits=usage_limits,
+                max_result_retries=self._max_result_retries,
+                end_strategy=self.end_strategy,
+                output_schema=output_schema,
+                output_validators=output_validators,
+                history_processors=self.history_processors,
+                tool_manager=run_toolset,
+                tracer=tracer,
+                get_instructions=get_instructions,
+                instrumentation_settings=instrumentation_settings,
+            )
+            start_node = _agent_graph.UserPromptNode[AgentDepsT](
+                user_prompt=user_prompt,
+                instructions=self._instructions,
+                instructions_functions=self._instructions_functions,
+                system_prompts=self._system_prompts,
+                system_prompt_functions=self._system_prompt_functions,
+                system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
+            )

-        try:
-            async with graph.iter(
-                start_node,
-                state=state,
-                deps=graph_deps,
-                span=use_span(run_span) if run_span.is_recording() else None,
-                infer_name=False,
-            ) as graph_run:
-                agent_run = AgentRun(graph_run)
-                yield agent_run
-                if (final_result := agent_run.result) is not None and run_span.is_recording():
-                    if instrumentation_settings and instrumentation_settings.include_content:
-                        run_span.set_attribute(
-                            'final_result',
-                            (
-                                final_result.output
-                                if isinstance(final_result.output, str)
-                                else json.dumps(InstrumentedModel.serialize_any(final_result.output))
-                            ),
-                        )
-        finally:
             try:
-                if instrumentation_settings and run_span.is_recording():
-                    run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
+                async with graph.iter(
+                    start_node,
+                    state=state,
+                    deps=graph_deps,
+                    span=use_span(run_span) if run_span.is_recording() else None,
+                    infer_name=False,
+                ) as graph_run:
+                    agent_run = AgentRun(graph_run)
+                    yield agent_run
+                    if (final_result := agent_run.result) is not None and run_span.is_recording():
+                        if instrumentation_settings and instrumentation_settings.include_content:
+                            run_span.set_attribute(
+                                'final_result',
+                                (
+                                    final_result.output
+                                    if isinstance(final_result.output, str)
+                                    else json.dumps(InstrumentedModel.serialize_any(final_result.output))
+                                ),
+                            )
             finally:
-                run_span.end()
+                try:
+                    if instrumentation_settings and run_span.is_recording():
+                        run_span.set_attributes(self._run_span_end_attributes(state, usage, instrumentation_settings))
+                finally:
+                    run_span.end()

     def _run_span_end_attributes(
         self, state: _agent_graph.GraphAgentState, usage: _usage.Usage, settings: InstrumentationSettings
@@ -2173,7 +2174,7 @@ async def __anext__(
     ) -> _agent_graph.AgentNode[AgentDepsT, OutputDataT] | End[FinalResult[OutputDataT]]:
         """Advance to the next node automatically based on the last returned node."""
         next_node = await self._graph_run.__anext__()
-        if _agent_graph.is_agent_node(next_node):
+        if _agent_graph.is_agent_node(node=next_node):
             return next_node
         assert isinstance(next_node, End), f'Unexpected node type: {type(next_node)}'
         return next_node
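
The first hunk moves the body of the run inside `async with toolset:`, so the combined toolset is entered as an async context manager before `ToolManager.build` and exited only after the graph run and span teardown complete. Below is a minimal, self-contained sketch of that lifecycle pattern; `DemoToolset` and `run_agent` are hypothetical names used for illustration only and are not part of the pydantic-ai API.

import asyncio


class DemoToolset:
    """Hypothetical toolset whose setup/teardown must bracket a run."""

    async def __aenter__(self):
        # e.g. open connections to external tool servers
        print('toolset entered')
        return self

    async def __aexit__(self, *exc_info):
        # guaranteed to run even if the run body raises
        print('toolset exited')

    def tool_names(self) -> list[str]:
        return ['roll_dice']


async def run_agent(toolset: DemoToolset) -> str:
    # Mirrors the diff: everything that uses the toolset happens inside the context,
    # so setup/teardown is no longer the caller's responsibility.
    async with toolset:
        tools = toolset.tool_names()
        return f'ran with tools: {tools}'


print(asyncio.run(run_agent(DemoToolset())))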