diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index fa1188bd..e8a3c939 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -74,6 +74,9 @@ class Agent(Generic[AgentDeps, ResultData]):
     model: models.Model | models.KnownModelName | None
     """The default model configured for this agent."""
 
+    fallback_models: list[models.Model | models.KnownModelName] | None
+    """Fallback models to try, in order, if a request to the current model raises an exception."""
+
     name: str | None
     """The name of the agent, used for logging.
 
@@ -111,6 +114,7 @@ def __init__(
         self,
         model: models.Model | models.KnownModelName | None = None,
         *,
+        fallback_models: list[models.Model | models.KnownModelName] | None = None,
         result_type: type[ResultData] = str,
         system_prompt: str | Sequence[str] = (),
         deps_type: type[AgentDeps] = NoneType,
@@ -129,6 +133,7 @@ def __init__(
         Args:
             model: The default model to use for this agent, if not provide,
                 you must provide the model when calling it.
+            fallback_models: Models to fall back to, in order, if a request to the current model raises an exception.
            result_type: The type of the result data, used to validate the result data, defaults to `str`.
            system_prompt: Static system prompts to use for this agent, you can also register system
                prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -158,6 +163,7 @@ def __init__(
         else:
             self.model = models.infer_model(model)
 
+        self.fallback_models = fallback_models
         self.end_strategy = end_strategy
         self.name = name
         self.model_settings = model_settings
@@ -252,7 +258,24 @@ async def run(
                 agent_model = await self._prepare_model(run_context)
 
                 with _logfire.span('model request', run_step=run_step) as model_req_span:
-                    model_response, request_usage = await agent_model.request(messages, model_settings)
+                    try:
+                        model_response, request_usage = await agent_model.request(messages, model_settings)
+                    except Exception:
+                        if self.fallback_models:
+                            # retry the whole run from the start with the next fallback model
+                            fallback_model = self.fallback_models.pop(0)
+                            return await self.run(
+                                user_prompt,
+                                message_history=message_history,
+                                model=fallback_model,
+                                deps=deps,
+                                model_settings=model_settings,
+                                usage_limits=usage_limits,
+                                infer_name=infer_name,
+                            )
+                        else:
+                            raise
+
                     model_req_span.set_attribute('response', model_response)
                     model_req_span.set_attribute('usage', request_usage)
 
@@ -411,69 +434,86 @@ async def main():
             agent_model = await self._prepare_model(run_context)
 
             with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
-                    async with agent_model.request_stream(messages, model_settings) as model_response:
-                        usage.requests += 1
-                        model_req_span.set_attribute('response_type', model_response.__class__.__name__)
-                        # We want to end the "model request" span here, but we can't exit the context manager
-                        # in the traditional way
-                        model_req_span.__exit__(None, None, None)
-
-                        with _logfire.span('handle model response') as handle_span:
-                            maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
-
-                            # Check if we got a final result
-                            if isinstance(maybe_final_result, _MarkFinalResult):
-                                result_stream = maybe_final_result.data
-                                result_tool_name = maybe_final_result.tool_name
-                                handle_span.message = 'handle model response -> final result'
-
-                                async def on_complete():
-                                    """Called when the stream has completed.
-
-                                    The model response will have been added to messages by now
-                                    by `StreamedRunResult._marked_completed`.
-                                    """
-                                    last_message = messages[-1]
-                                    assert isinstance(last_message, _messages.ModelResponse)
-                                    tool_calls = [
-                                        part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
-                                    ]
-                                    parts = await self._process_function_tools(
-                                        tool_calls, result_tool_name, run_context
-                                    )
-                                    if parts:
-                                        messages.append(_messages.ModelRequest(parts))
-                                    run_span.set_attribute('all_messages', messages)
-
-                                yield result.StreamedRunResult(
-                                    messages,
-                                    new_message_index,
-                                    usage,
-                                    usage_limits,
-                                    result_stream,
-                                    self._result_schema,
-                                    run_context,
-                                    self._result_validators,
-                                    result_tool_name,
-                                    on_complete,
-                                )
-                                return
-                            else:
-                                # continue the conversation
-                                model_response_msg, tool_responses = maybe_final_result
-                                # if we got a model response add that to messages
-                                messages.append(model_response_msg)
-                                if tool_responses:
-                                    # if we got one or more tool response parts, add a model request message
-                                    messages.append(_messages.ModelRequest(tool_responses))
-
-                                handle_span.set_attribute('tool_responses', tool_responses)
-                                tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
-                                handle_span.message = f'handle model response -> {tool_responses_str}'
-                            # the model_response should have been fully streamed by now, we can add its usage
-                            model_response_usage = model_response.usage()
-                            usage += model_response_usage
-                            usage_limits.check_tokens(usage)
+                    try:
+                        async with agent_model.request_stream(messages, model_settings) as model_response:
+                            usage.requests += 1
+                            model_req_span.set_attribute('response_type', model_response.__class__.__name__)
+                            # We want to end the "model request" span here, but we can't exit the context manager
+                            # in the traditional way
+                            model_req_span.__exit__(None, None, None)
+
+                            with _logfire.span('handle model response') as handle_span:
+                                maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)
+
+                                # Check if we got a final result
+                                if isinstance(maybe_final_result, _MarkFinalResult):
+                                    result_stream = maybe_final_result.data
+                                    result_tool_name = maybe_final_result.tool_name
+                                    handle_span.message = 'handle model response -> final result'
+
+                                    async def on_complete():
+                                        """Called when the stream has completed.
+
+                                        The model response will have been added to messages by now
+                                        by `StreamedRunResult._marked_completed`.
+                                        """
+                                        last_message = messages[-1]
+                                        assert isinstance(last_message, _messages.ModelResponse)
+                                        tool_calls = [
+                                            part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
+                                        ]
+                                        parts = await self._process_function_tools(
+                                            tool_calls, result_tool_name, run_context
+                                        )
+                                        if parts:
+                                            messages.append(_messages.ModelRequest(parts))
+                                        run_span.set_attribute('all_messages', messages)
+
+                                    yield result.StreamedRunResult(
+                                        messages,
+                                        new_message_index,
+                                        usage,
+                                        usage_limits,
+                                        result_stream,
+                                        self._result_schema,
+                                        run_context,
+                                        self._result_validators,
+                                        result_tool_name,
+                                        on_complete,
+                                    )
+                                    return
+                                else:
+                                    # continue the conversation
+                                    model_response_msg, tool_responses = maybe_final_result
+                                    # if we got a model response add that to messages
+                                    messages.append(model_response_msg)
+                                    if tool_responses:
+                                        # if we got one or more tool response parts, add a model request message
+                                        messages.append(_messages.ModelRequest(tool_responses))
+
+                                    handle_span.set_attribute('tool_responses', tool_responses)
+                                    tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
+                                    handle_span.message = f'handle model response -> {tool_responses_str}'
+                                # the model_response should have been fully streamed by now, we can add its usage
+                                model_response_usage = model_response.usage()
+                                usage += model_response_usage
+                                usage_limits.check_tokens(usage)
+                    except Exception:
+                        if self.fallback_models:
+                            fallback_model = self.fallback_models.pop(0)
+                            async with self.run_stream(
+                                user_prompt,
+                                message_history=message_history,
+                                model=fallback_model,
+                                deps=deps,
+                                model_settings=model_settings,
+                                usage_limits=usage_limits,
+                                infer_name=infer_name,
+                            ) as fallback_response:
+                                yield fallback_response
+                                return
+                        else:
+                            raise
 
     @contextmanager
     def override(
diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py
index 1a1bc9f4..f1bf5e57 100644
--- a/pydantic_ai_slim/pydantic_ai/models/__init__.py
+++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -90,6 +90,8 @@
 class Model(ABC):
     """Abstract class for a model."""
 
+    fallback_models: list[Model | KnownModelName] | None = None
+
     @abstractmethod
     async def agent_model(
         self,
@@ -120,6 +122,8 @@ def name(self) -> str:
 class AgentModel(ABC):
     """Model configured for each step of an Agent run."""
 
+    fallback_models: list[Model | KnownModelName] | None = None
+
     @abstractmethod
     async def request(
         self, messages: list[ModelMessage], model_settings: ModelSettings | None
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 801eac3b..0b9dc7b7 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -1181,3 +1181,47 @@ async def get_location(loc_name: str) -> str:
             ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
         ]
     )
+
+
+def test_fallback_models(set_event_loop: None):
+    def raise_error(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse:
+        raise RuntimeError('Primary model failed')
+
+    primary_model = FunctionModel(raise_error)
+    fallback_model = TestModel()
+
+    agent = Agent(primary_model, fallback_models=[fallback_model], result_type=tuple[str, str])
+
+    result = agent.run_sync('Hello')
+    assert result.data == snapshot(('a', 'a'))
+    assert agent.model == primary_model
+    assert agent.fallback_models == []
+
+    assert fallback_model.agent_model_function_tools == snapshot([])
+    assert fallback_model.agent_model_allow_text_result is False
+
+    assert fallback_model.agent_model_result_tools is not None
+    assert len(fallback_model.agent_model_result_tools) == 1
+
+    assert fallback_model.agent_model_result_tools == snapshot(
+        [
+            ToolDefinition(
+                name='final_result',
+                description='The final response which ends this conversation',
+                parameters_json_schema={
+                    'properties': {
+                        'response': {
+                            'maxItems': 2,
+                            'minItems': 2,
+                            'prefixItems': [{'type': 'string'}, {'type': 'string'}],
+                            'title': 'Response',
+                            'type': 'array',
+                        }
+                    },
+                    'required': ['response'],
+                    'type': 'object',
+                },
+                outer_typed_dict_key='response',
+            )
+        ]
+    )
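
Usage sketch (not part of the diff above): assuming these changes are applied as written, the new `fallback_models` argument would be used roughly as below. The model names and prompt are illustrative, not taken from the PR.

    from pydantic_ai import Agent

    # The agent sends each request to the primary model first; if the request
    # raises, the next entry is popped from `fallback_models` and the run is
    # retried from the start with that model.
    agent = Agent(
        'openai:gpt-4o',
        fallback_models=['openai:gpt-4o-mini'],
    )

    result = agent.run_sync('What is the capital of France?')
    print(result.data)

Note that because the fallback list is consumed with `pop(0)`, each fallback model is only available once per agent instance; subsequent runs that fail will have fewer fallbacks to try.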