@@ -566,6 +566,8 @@ async def main():
566566 if output_toolset :
567567 output_toolset .max_retries = self ._max_result_retries
568568 output_toolset .output_validators = output_validators
569+ toolset = self ._get_toolset (output_toolset = output_toolset , additional_toolsets = toolsets )
570+ tool_manager = ToolManager [AgentDepsT ](toolset )
569571
570572 # Build the graph
571573 graph : Graph [_agent_graph .GraphAgentState , _agent_graph .GraphAgentDeps [AgentDepsT , Any ], FinalResult [Any ]] = (
@@ -581,88 +583,73 @@ async def main():
581583 run_step = 0 ,
582584 )
583585
586+ # Merge model settings in order of precedence: run > agent > model
587+ merged_settings = merge_model_settings (model_used .settings , self .model_settings )
588+ model_settings = merge_model_settings (merged_settings , model_settings )
589+ usage_limits = usage_limits or _usage .UsageLimits ()
590+
591+ async def get_instructions (run_context : RunContext [AgentDepsT ]) -> str | None :
592+ parts = [
593+ self ._instructions ,
594+ * [await func .run (run_context ) for func in self ._instructions_functions ],
595+ ]
596+
597+ model_profile = model_used .profile
598+ if isinstance (output_schema , _output .PromptedOutputSchema ):
599+ instructions = output_schema .instructions (model_profile .prompted_output_template )
600+ parts .append (instructions )
601+
602+ parts = [p for p in parts if p ]
603+ if not parts :
604+ return None
605+ return '\n \n ' .join (parts ).strip ()
606+
584607 if isinstance (model_used , InstrumentedModel ):
585608 instrumentation_settings = model_used .instrumentation_settings
586609 tracer = model_used .instrumentation_settings .tracer
587610 else :
588611 instrumentation_settings = None
589612 tracer = NoOpTracer ()
590613
591- run_context = RunContext [AgentDepsT ](
592- deps = deps ,
593- model = model_used ,
594- usage = usage ,
614+ graph_deps = _agent_graph .GraphAgentDeps [AgentDepsT , RunOutputDataT ](
615+ user_deps = deps ,
595616 prompt = user_prompt ,
596- messages = state .message_history ,
617+ new_message_index = new_message_index ,
618+ model = model_used ,
619+ model_settings = model_settings ,
620+ usage_limits = usage_limits ,
621+ max_result_retries = self ._max_result_retries ,
622+ end_strategy = self .end_strategy ,
623+ output_schema = output_schema ,
624+ output_validators = output_validators ,
625+ history_processors = self .history_processors ,
626+ builtin_tools = list (self ._builtin_tools ),
627+ tool_manager = tool_manager ,
597628 tracer = tracer ,
598- trace_include_content = instrumentation_settings is not None and instrumentation_settings .include_content ,
599- run_step = state .run_step ,
629+ get_instructions = get_instructions ,
630+ instrumentation_settings = instrumentation_settings ,
631+ )
632+ start_node = _agent_graph .UserPromptNode [AgentDepsT ](
633+ user_prompt = user_prompt ,
634+ instructions = self ._instructions ,
635+ instructions_functions = self ._instructions_functions ,
636+ system_prompts = self ._system_prompts ,
637+ system_prompt_functions = self ._system_prompt_functions ,
638+ system_prompt_dynamic_functions = self ._system_prompt_dynamic_functions ,
600639 )
601640
602- toolset = self ._get_toolset (output_toolset = output_toolset , additional_toolsets = toolsets )
603-
604- async with toolset :
605- # This will raise errors for any name conflicts
606- tool_manager = await ToolManager [AgentDepsT ].build (toolset , run_context )
607-
608- # Merge model settings in order of precedence: run > agent > model
609- merged_settings = merge_model_settings (model_used .settings , self .model_settings )
610- model_settings = merge_model_settings (merged_settings , model_settings )
611- usage_limits = usage_limits or _usage .UsageLimits ()
612- agent_name = self .name or 'agent'
613- run_span = tracer .start_span (
614- 'agent run' ,
615- attributes = {
616- 'model_name' : model_used .model_name if model_used else 'no-model' ,
617- 'agent_name' : agent_name ,
618- 'logfire.msg' : f'{ agent_name } run' ,
619- },
620- )
621-
622- async def get_instructions (run_context : RunContext [AgentDepsT ]) -> str | None :
623- parts = [
624- self ._instructions ,
625- * [await func .run (run_context ) for func in self ._instructions_functions ],
626- ]
627-
628- model_profile = model_used .profile
629- if isinstance (output_schema , _output .PromptedOutputSchema ):
630- instructions = output_schema .instructions (model_profile .prompted_output_template )
631- parts .append (instructions )
632-
633- parts = [p for p in parts if p ]
634- if not parts :
635- return None
636- return '\n \n ' .join (parts ).strip ()
637-
638- graph_deps = _agent_graph .GraphAgentDeps [AgentDepsT , RunOutputDataT ](
639- user_deps = deps ,
640- prompt = user_prompt ,
641- new_message_index = new_message_index ,
642- model = model_used ,
643- model_settings = model_settings ,
644- usage_limits = usage_limits ,
645- max_result_retries = self ._max_result_retries ,
646- end_strategy = self .end_strategy ,
647- output_schema = output_schema ,
648- output_validators = output_validators ,
649- history_processors = self .history_processors ,
650- builtin_tools = list (self ._builtin_tools ),
651- tool_manager = tool_manager ,
652- tracer = tracer ,
653- get_instructions = get_instructions ,
654- instrumentation_settings = instrumentation_settings ,
655- )
656- start_node = _agent_graph .UserPromptNode [AgentDepsT ](
657- user_prompt = user_prompt ,
658- instructions = self ._instructions ,
659- instructions_functions = self ._instructions_functions ,
660- system_prompts = self ._system_prompts ,
661- system_prompt_functions = self ._system_prompt_functions ,
662- system_prompt_dynamic_functions = self ._system_prompt_dynamic_functions ,
663- )
641+ agent_name = self .name or 'agent'
642+ run_span = tracer .start_span (
643+ 'agent run' ,
644+ attributes = {
645+ 'model_name' : model_used .model_name if model_used else 'no-model' ,
646+ 'agent_name' : agent_name ,
647+ 'logfire.msg' : f'{ agent_name } run' ,
648+ },
649+ )
664650
665- try :
651+ try :
652+ async with toolset :
666653 async with graph .iter (
667654 start_node ,
668655 state = state ,
@@ -682,12 +669,12 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
682669 else json .dumps (InstrumentedModel .serialize_any (final_result .output ))
683670 ),
684671 )
672+ finally :
673+ try :
674+ if instrumentation_settings and run_span .is_recording ():
675+ run_span .set_attributes (self ._run_span_end_attributes (state , usage , instrumentation_settings ))
685676 finally :
686- try :
687- if instrumentation_settings and run_span .is_recording ():
688- run_span .set_attributes (self ._run_span_end_attributes (state , usage , instrumentation_settings ))
689- finally :
690- run_span .end ()
677+ run_span .end ()
691678
692679 def _run_span_end_attributes (
693680 self , state : _agent_graph .GraphAgentState , usage : _usage .Usage , settings : InstrumentationSettings
0 commit comments