diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 4515d18bc9..fda19acda4 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -341,6 +341,7 @@ async def stream( ctx.deps.output_schema, ctx.deps.output_validators, build_run_context(ctx), + _output.build_trace_context(ctx), ctx.deps.usage_limits, ) yield agent_stream @@ -529,7 +530,8 @@ async def _handle_tool_calls( if isinstance(output_schema, _output.ToolOutputSchema): for call, output_tool in output_schema.find_tool(tool_calls): try: - result_data = await output_tool.process(call, run_context) + trace_context = _output.build_trace_context(ctx) + result_data = await output_tool.process(call, run_context, trace_context) result_data = await _validate_output(result_data, ctx, call) except _output.ToolRetryError as e: # TODO: Should only increment retry stuff once per node execution, not for each tool call @@ -586,7 +588,8 @@ async def _handle_text_response( try: if isinstance(output_schema, _output.TextOutputSchema): run_context = build_run_context(ctx) - result_data = await output_schema.process(text, run_context) + trace_context = _output.build_trace_context(ctx) + result_data = await output_schema.process(text, run_context, trace_context) else: m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index bd882bd6d0..c3199dd95c 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import dataclasses import inspect import json from abc import ABC, abstractmethod @@ -7,10 +8,13 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload +from opentelemetry.trace import Tracer from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator from typing_extensions import TypedDict, TypeVar, assert_never +from pydantic_graph.nodes import GraphRunContext + from . 
import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, UserError @@ -29,6 +33,8 @@ from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition if TYPE_CHECKING: + from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState + from .profiles import ModelProfile T = TypeVar('T') @@ -66,6 +72,71 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' +@dataclass(frozen=True) +class TraceContext: + """A context for tracing output processing.""" + + tracer: Tracer + include_content: bool + call: _messages.ToolCallPart | None = None + + def with_call(self, call: _messages.ToolCallPart): + return dataclasses.replace(self, call=call) + + async def execute_function_with_span( + self, + function_schema: _function_schema.FunctionSchema, + run_context: RunContext[AgentDepsT], + args: dict[str, Any] | Any, + call: _messages.ToolCallPart, + include_tool_call_id: bool = True, + ) -> Any: + """Execute a function call within a traced span, automatically recording the response.""" + # Set up span attributes + attributes = { + 'gen_ai.tool.name': call.tool_name, + 'logfire.msg': f'running output function: {call.tool_name}', + } + if include_tool_call_id: + attributes['gen_ai.tool.call.id'] = call.tool_call_id + if self.include_content: + attributes['tool_arguments'] = call.args_as_json_str() + attributes['logfire.json_schema'] = json.dumps( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + + # Execute function within span + with self.tracer.start_as_current_span('running output function', attributes=attributes) as span: + output = await function_schema.call(args, run_context) + + # Record response if content inclusion is enabled + if self.include_content and span.is_recording(): + from .models.instrumented import InstrumentedModel + + span.set_attribute( + 'tool_response', + output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), + ) + + return output + + +def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext: + """Build a `TraceContext` from the current agent graph run context.""" + return TraceContext( + tracer=ctx.deps.tracer, + include_content=( + ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content + ), + ) + + class ToolRetryError(Exception): """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" @@ -96,6 +167,7 @@ async def validate( result: The result data after Pydantic validation the message content. tool_call: The original tool call message, `None` if there was no tool call. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. Returns: Result of either the validated result data (ok) or a retry message (Err). @@ -349,6 +421,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -371,6 +444,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -379,6 +453,7 @@ async def process( Args: text: The output text to validate. run_context: The current run context. 
+ trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -389,7 +464,7 @@ async def process( return cast(OutputDataT, text) return await self.processor.process( - text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -417,6 +492,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -425,6 +501,7 @@ async def process( Args: text: The output text to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -432,7 +509,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ return await self.processor.process( - text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -468,6 +545,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -476,6 +554,7 @@ async def process( Args: text: The output text to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -485,7 +564,7 @@ async def process( text = _utils.strip_markdown_fences(text) return await self.processor.process( - text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -568,6 +647,7 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -637,6 +717,7 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -645,6 +726,7 @@ async def process( Args: data: The output data to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -670,8 +752,18 @@ async def process( output = output[k] if self._function_schema: + # Wraps the output function call in an OpenTelemetry span. 
+ if trace_context.call: + call = trace_context.call + include_tool_call_id = True + else: + function_name = getattr(self._function_schema.function, '__name__', 'output_function') + call = _messages.ToolCallPart(tool_name=function_name, args=data) + include_tool_call_id = False try: - output = await self._function_schema.call(output, run_context) + output = await trace_context.execute_function_with_span( + self._function_schema, run_context, output, call, include_tool_call_id + ) except ModelRetry as r: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -784,11 +876,12 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: union_object = await self._union_processor.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) result = union_object.result @@ -804,7 +897,7 @@ async def process( raise return await processor.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -835,13 +928,20 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: args = {self._str_argument_name: data} - + # Wraps the output function call in an OpenTelemetry span. + # Note: PlainTextOutputProcessor handles plain text responses rather than tool calls, so there is + # no real tool call: a synthetic ToolCallPart named after the function provides gen_ai.tool.name, and gen_ai.tool.call.id is omitted. + function_name = getattr(self._function_schema.function, '__name__', 'text_output_function') + call = _messages.ToolCallPart(tool_name=function_name, args=args) try: - output = await self._function_schema.call(args, run_context) + output = await trace_context.execute_function_with_span( + self._function_schema, run_context, args, call, include_tool_call_id=False + ) except ModelRetry as r: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -881,6 +981,7 @@ async def process( self, tool_call: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -889,6 +990,7 @@ async def process( Args: tool_call: The tool call from the LLM to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. 
@@ -897,7 +999,11 @@ async def process( """ try: output = await self.processor.process( - tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False + tool_call.args, + run_context, + trace_context.with_call(tool_call), + allow_partial=allow_partial, + wrap_validation_errors=False, ) except ValidationError as e: if wrap_validation_errors: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 9c87fee517..3ff881294c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1089,6 +1089,7 @@ async def on_complete() -> None: streamed_response, graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), + _output.build_trace_context(graph_ctx), graph_ctx.deps.output_validators, final_result_details.tool_name, on_complete, diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 0b5c04fa84..f700482662 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -19,6 +19,7 @@ PlainTextOutputSchema, TextOutputSchema, ToolOutputSchema, + TraceContext, ) from ._run_context import AgentDepsT, RunContext from .messages import AgentStreamEvent, FinalResultEvent @@ -46,6 +47,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _output_schema: OutputSchema[OutputDataT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] + _trace_ctx: TraceContext _usage_limits: UsageLimits | None _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) @@ -105,13 +107,17 @@ async def _validate_response( call, output_tool = match result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + call, + self._run_ctx, + self._trace_ctx, + allow_partial=allow_partial, + wrap_validation_errors=False, ) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -177,6 +183,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _stream_response: models.StreamedResponse _output_schema: OutputSchema[OutputDataT] _run_ctx: RunContext[AgentDepsT] + _trace_ctx: TraceContext _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] @@ -423,13 +430,17 @@ async def validate_structured_output( call, output_tool = match result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + call, + self._run_ctx, + self._trace_ctx, + allow_partial=allow_partial, + wrap_validation_errors=False, ) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover diff --git a/tests/test_logfire.py 
b/tests/test_logfire.py index 691b85d9b1..97ba871cc9 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -6,12 +6,17 @@ import pytest from dirty_equals import IsInt, IsJson, IsList from inline_snapshot import snapshot +from pydantic import BaseModel from typing_extensions import NotRequired, TypedDict from pydantic_ai import Agent from pydantic_ai._utils import get_traceparent +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart +from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel from pydantic_ai.models.test import TestModel +from pydantic_ai.output import PromptedOutput, TextOutput +from pydantic_ai.tools import RunContext from .conftest import IsStr @@ -705,3 +710,714 @@ async def add_numbers(x: int, y: int) -> int: 'logfire.span_type': 'span', } ) + + +class WeatherInfo(BaseModel): + temperature: float + description: str + + +def get_weather_info(city: str) -> WeatherInfo: + return WeatherInfo(temperature=28.7, description='sunny') + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=get_weather_info) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_with_run_context_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def get_weather_with_ctx(ctx: RunContext[None], city: str) -> WeatherInfo: + assert ctx is not None + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": 
"Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=get_weather_with_ctx) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_with_retry_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def get_weather_with_retry(city: str) -> WeatherInfo: + if city != 'Mexico City': + from pydantic_ai import ModelRetry + + raise ModelRetry('City not found, I only know Mexico City') + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + args_json = '{"city": "New York City"}' + else: + args_json = '{"city": "Mexico City"}' + + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('New York City', output_type=get_weather_with_retry) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + output_function_attributes = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "New York City"}', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.json_schema': IsJson( + 
snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + }, + ] + ) + else: + assert output_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.span_type': 'span', + }, + ] + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_with_custom_tool_name_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + from pydantic_ai.output import ToolOutput + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=ToolOutput(get_weather_info, name='get_weather')) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes with custom tool name + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'get_weather' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'get_weather', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: get_weather', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'get_weather', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: get_weather', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_bound_instance_method_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, city: str): + return self + + weather = Weather(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), 
instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=weather.get_weather) + assert result.output == Weather(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_bound_instance_method_with_run_context_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, ctx: RunContext[None], city: str): + assert ctx is not None + return self + + weather = Weather(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=weather.get_weather) + assert result.output == Weather(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_async_function_logfire_attributes( + get_logfire_summary: Callable[[], 
LogfireSummary], + include_content: bool, +) -> None: + """Test logfire attributes for async output function types.""" + + async def get_weather_async(city: str) -> WeatherInfo: + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=get_weather_async) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +def upcase_text(text: str) -> str: + """Convert text to uppercase.""" + return text.upper() + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_text_output_function_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + """Test logfire attributes for TextOutput functions (PlainTextOutputProcessor).""" + + def call_text_response(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Return a plain text response (not a tool call) + from pydantic_ai.messages import TextPart + + return ModelResponse(parts=[TextPart(content='hello world')]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_text_response), instrument=instrumentation_settings) + + result = my_agent.run_sync('Say hello', output_type=TextOutput(upcase_text)) + assert result.output == 'HELLO WORLD' + + summary = get_logfire_summary() + + # Find the text output function span attributes + [text_function_attributes] = [ + attributes + for attributes in summary.attributes.values() + if 'running output function: upcase_text' in attributes.get('logfire.msg', '') + ] + + if include_content: + assert text_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'tool_arguments': '{"text":"hello world"}', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': 
'HELLO WORLD', + } + ) + else: + assert text_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_prompted_output_function_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + """Test that spans are created for PromptedOutput functions with appropriate attributes.""" + + def upcase_text(text: str) -> str: + return text.upper() + + call_count = 0 + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + # Simulate the model returning JSON that will be parsed and used to call the function + return ModelResponse(parts=[TextPart(content='{"text": "hello world"}')]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + agent = Agent( + model=FunctionModel(call_tool), instrument=instrumentation_settings, output_type=PromptedOutput(upcase_text) + ) + + result = agent.run_sync('test') + + # Check that the function was called and returned the expected result + assert result.output == 'HELLO WORLD' + assert call_count == 1 + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes + for attributes in summary.attributes.values() + if attributes.get('logfire.msg', '').startswith('running output function: upcase_text') + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'tool_arguments': '{"text": "hello world"}', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': 'HELLO WORLD', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_text_output_function_with_retry_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def get_weather_with_retry(ctx: RunContext[None], city: str) -> WeatherInfo: + assert ctx is not None + if city != 'Mexico City': + from pydantic_ai import ModelRetry + + raise ModelRetry('City not found, I only know Mexico City') + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + city = 'New York City' + else: + city = 'Mexico City' + + return ModelResponse(parts=[TextPart(content=city)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('New York City', output_type=TextOutput(get_weather_with_retry)) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + 
text_function_attributes = [ + attributes + for attributes in summary.attributes.values() + if 'running output function: get_weather_with_retry' in attributes.get('logfire.msg', '') + ] + + if include_content: + assert text_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'tool_arguments': '{"city":"New York City"}', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'tool_arguments': '{"city":"Mexico City"}', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + }, + ] + ) + else: + assert text_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.span_type': 'span', + }, + ] + )
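For reviewers, a minimal end-to-end sketch of how the new span surfaces (not part of the diff; it mirrors the FunctionModel / InstrumentationSettings setup used in the tests above, and the attribute names are the ones set by TraceContext.execute_function_with_span):

# Minimal usage sketch, assuming the same public API exercised by the tests above.
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.instrumented import InstrumentationSettings


class WeatherInfo(BaseModel):
    temperature: float
    description: str


def get_weather_info(city: str) -> WeatherInfo:
    # Output function: with this PR, the call is wrapped in a
    # 'running output function: final_result' span.
    return WeatherInfo(temperature=28.7, description='sunny')


def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
    # FunctionModel stand-in for an LLM: always calls the single output tool.
    assert info.output_tools is not None
    return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, '{"city": "Mexico City"}')])


agent = Agent(
    model=FunctionModel(call_tool),
    # include_content=True additionally records tool_arguments / tool_response on the span.
    instrument=InstrumentationSettings(include_content=True),
)
result = agent.run_sync('Mexico City', output_type=get_weather_info)
print(result.output)  # temperature=28.7 description='sunny'

Run with an OpenTelemetry exporter configured (e.g. logfire), this should produce a 'running output function: final_result' span carrying gen_ai.tool.name, gen_ai.tool.call.id, tool_arguments and tool_response, matching the snapshots asserted in test_output_type_function_logfire_attributes.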