Add fallback models support #532

Closed · wants to merge 1 commit
166 changes: 103 additions & 63 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -74,6 +74,9 @@ class Agent(Generic[AgentDeps, ResultData]):
model: models.Model | models.KnownModelName | None
"""The default model configured for this agent."""

fallback_models: list[models.Model | models.KnownModelName] | None
"""A list of fallback models to use if the primary model fails."""

name: str | None
"""The name of the agent, used for logging.

@@ -111,6 +114,7 @@ def __init__(
self,
model: models.Model | models.KnownModelName | None = None,
*,
fallback_models: list[models.Model | models.KnownModelName] | None = None,
result_type: type[ResultData] = str,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDeps] = NoneType,
@@ -129,6 +133,7 @@ def __init__(
Args:
model: The default model to use for this agent, if not provided,
you must provide the model when calling it.
fallback_models: A list of fallback models to use if the primary model fails.
result_type: The type of the result data, used to validate the result data, defaults to `str`.
system_prompt: Static system prompts to use for this agent, you can also register system
prompts via a function with [`system_prompt`][pydantic_ai.Agent.system_prompt].
@@ -158,6 +163,7 @@ def __init__(
else:
self.model = models.infer_model(model)

self.fallback_models = fallback_models
self.end_strategy = end_strategy
self.name = name
self.model_settings = model_settings
@@ -252,7 +258,24 @@ async def run(
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request', run_step=run_step) as model_req_span:
model_response, request_usage = await agent_model.request(messages, model_settings)
try:
model_response, request_usage = await agent_model.request(messages, model_settings)
except Exception as e:
if self.fallback_models:
fallback_model = self.fallback_models.pop(0)
self.model = fallback_model
return await self.run(
user_prompt,
message_history=message_history,
model=fallback_model,
deps=deps,
model_settings=model_settings,
usage_limits=usage_limits,
infer_name=infer_name,
)
else:
raise e

model_req_span.set_attribute('response', model_response)
model_req_span.set_attribute('usage', request_usage)

@@ -411,69 +434,86 @@ async def main():
agent_model = await self._prepare_model(run_context)

with _logfire.span('model request {run_step=}', run_step=run_step) as model_req_span:
async with agent_model.request_stream(messages, model_settings) as model_response:
usage.requests += 1
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
# We want to end the "model request" span here, but we can't exit the context manager
# in the traditional way
model_req_span.__exit__(None, None, None)

with _logfire.span('handle model response') as handle_span:
maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)

# Check if we got a final result
if isinstance(maybe_final_result, _MarkFinalResult):
result_stream = maybe_final_result.data
result_tool_name = maybe_final_result.tool_name
handle_span.message = 'handle model response -> final result'

async def on_complete():
"""Called when the stream has completed.

The model response will have been added to messages by now
by `StreamedRunResult._marked_completed`.
"""
last_message = messages[-1]
assert isinstance(last_message, _messages.ModelResponse)
tool_calls = [
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
]
parts = await self._process_function_tools(
tool_calls, result_tool_name, run_context
try:
async with agent_model.request_stream(messages, model_settings) as model_response:
usage.requests += 1
model_req_span.set_attribute('response_type', model_response.__class__.__name__)
# We want to end the "model request" span here, but we can't exit the context manager
# in the traditional way
model_req_span.__exit__(None, None, None)

with _logfire.span('handle model response') as handle_span:
maybe_final_result = await self._handle_streamed_model_response(model_response, run_context)

# Check if we got a final result
if isinstance(maybe_final_result, _MarkFinalResult):
result_stream = maybe_final_result.data
result_tool_name = maybe_final_result.tool_name
handle_span.message = 'handle model response -> final result'

async def on_complete():
"""Called when the stream has completed.

The model response will have been added to messages by now
by `StreamedRunResult._marked_completed`.
"""
last_message = messages[-1]
assert isinstance(last_message, _messages.ModelResponse)
tool_calls = [
part for part in last_message.parts if isinstance(part, _messages.ToolCallPart)
]
parts = await self._process_function_tools(
tool_calls, result_tool_name, run_context
)
if parts:
messages.append(_messages.ModelRequest(parts))
run_span.set_attribute('all_messages', messages)

yield result.StreamedRunResult(
messages,
new_message_index,
usage,
usage_limits,
result_stream,
self._result_schema,
run_context,
self._result_validators,
result_tool_name,
on_complete,
)
if parts:
messages.append(_messages.ModelRequest(parts))
run_span.set_attribute('all_messages', messages)

yield result.StreamedRunResult(
messages,
new_message_index,
usage,
usage_limits,
result_stream,
self._result_schema,
run_context,
self._result_validators,
result_tool_name,
on_complete,
)
return
else:
# continue the conversation
model_response_msg, tool_responses = maybe_final_result
# if we got a model response add that to messages
messages.append(model_response_msg)
if tool_responses:
# if we got one or more tool response parts, add a model request message
messages.append(_messages.ModelRequest(tool_responses))

handle_span.set_attribute('tool_responses', tool_responses)
tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
handle_span.message = f'handle model response -> {tool_responses_str}'
# the model_response should have been fully streamed by now, we can add its usage
model_response_usage = model_response.usage()
usage += model_response_usage
usage_limits.check_tokens(usage)
return
else:
# continue the conversation
model_response_msg, tool_responses = maybe_final_result
# if we got a model response add that to messages
messages.append(model_response_msg)
if tool_responses:
# if we got one or more tool response parts, add a model request message
messages.append(_messages.ModelRequest(tool_responses))

handle_span.set_attribute('tool_responses', tool_responses)
tool_responses_str = ' '.join(r.part_kind for r in tool_responses)
handle_span.message = f'handle model response -> {tool_responses_str}'
# the model_response should have been fully streamed by now, we can add its usage
model_response_usage = model_response.usage()
usage += model_response_usage
usage_limits.check_tokens(usage)
except Exception as e:
if self.fallback_models:
fallback_model = self.fallback_models.pop(0)
self.model = fallback_model
async with self.run_stream(
user_prompt,
message_history=message_history,
model=fallback_model,
deps=deps,
model_settings=model_settings,
usage_limits=usage_limits,
infer_name=infer_name,
) as fallback_response:
yield fallback_response
else:
raise e

@contextmanager
def override(
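A minimal usage sketch of the non-streaming API this PR proposes. The model names and prompt below are placeholders, not part of the diff; with this change, any exception raised by the primary model's request causes the agent to retry with the next entry in `fallback_models`:

```python
# Sketch only: assumes this PR's `fallback_models` parameter; model names are illustrative.
from pydantic_ai import Agent

agent = Agent(
    'openai:gpt-4o',                         # primary model
    fallback_models=['openai:gpt-4o-mini'],  # tried in order if the primary request raises
)

# If the primary request fails, run()/run_sync() pops the first fallback,
# assigns it as the agent's model, and re-runs the request with the same arguments.
result = agent.run_sync('Hello')
print(result.data)
```

Note that, as written, the fallback list is consumed in place (`self.fallback_models.pop(0)`), so a fallback used on one run is no longer available on later runs.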
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
@@ -90,6 +90,8 @@
class Model(ABC):
"""Abstract class for a model."""

fallback_models: list[Model | KnownModelName] | None = None

@abstractmethod
async def agent_model(
self,
@@ -120,6 +122,8 @@ def name(self) -> str:
class AgentModel(ABC):
"""Model configured for each step of an Agent run."""

fallback_models: list[Model | KnownModelName] | None = None

@abstractmethod
async def request(
self, messages: list[ModelMessage], model_settings: ModelSettings | None
44 changes: 44 additions & 0 deletions tests/test_agent.py
@@ -1181,3 +1181,47 @@ async def get_location(loc_name: str) -> str:
ModelResponse.from_text(content='final response', timestamp=IsNow(tz=timezone.utc)),
]
)


def test_fallback_models(set_event_loop: None):
def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse:
raise RuntimeError("Primary model failed")

primary_model = FunctionModel(return_model)
fallback_model = TestModel()

agent = Agent(primary_model, fallback_models=[fallback_model], result_type=tuple[str, str])

result = agent.run_sync('Hello')
assert result.data == snapshot(('a', 'a'))
assert agent.model == primary_model
assert agent.fallback_models == []

assert fallback_model.agent_model_function_tools == snapshot([])
assert fallback_model.agent_model_allow_text_result is False

assert fallback_model.agent_model_result_tools is not None
assert len(fallback_model.agent_model_result_tools) == 1

assert fallback_model.agent_model_result_tools == snapshot(
[
ToolDefinition(
name='final_result',
description='The final response which ends this conversation',
parameters_json_schema={
'properties': {
'response': {
'maxItems': 2,
'minItems': 2,
'prefixItems': [{'type': 'string'}, {'type': 'string'}],
'title': 'Response',
'type': 'array',
}
},
'required': ['response'],
'type': 'object',
},
outer_typed_dict_key='response',
)
]
)
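The PR applies the same fallback handling to the streaming path: `run_stream` wraps the stream request in a try/except and yields the fallback's streamed result instead. A hedged sketch of how that looks from the caller's side; the model names are placeholders and the exact methods available on `StreamedRunResult` may differ from what is shown:

```python
# Sketch only: assumes the streaming fallback path added in this PR.
import asyncio

from pydantic_ai import Agent

agent = Agent(
    'openai:gpt-4o',                       # primary model
    fallback_models=['gemini-1.5-flash'],  # used if the primary stream request raises
)


async def main():
    # run_stream is an async context manager yielding a StreamedRunResult,
    # whether the stream came from the primary model or from a fallback.
    async with agent.run_stream('Tell me a short joke') as response:
        async for text in response.stream_text():
            print(text)


asyncio.run(main())
```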