Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions pydantic_ai_slim/pydantic_ai/providers/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,19 @@ class BedrockModelProfile(ModelProfile):
ALL FIELDS MUST BE `bedrock_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
"""

bedrock_supports_tool_choice: bool = True
bedrock_supports_tool_choice: bool = False
bedrock_tool_result_format: Literal['text', 'json'] = 'text'
bedrock_send_back_thinking_parts: bool = False


def bedrock_amazon_model_profile(model_name: str) -> ModelProfile | None:
"""Get the model profile for an Amazon model used via Bedrock."""
profile = amazon_model_profile(model_name)
if 'nova' in model_name:
return BedrockModelProfile(bedrock_supports_tool_choice=True).update(profile)
return profile


class BedrockProvider(Provider[BaseClient]):
"""Provider for AWS Bedrock."""

Expand All @@ -58,13 +66,13 @@ def client(self) -> BaseClient:
def model_profile(self, model_name: str) -> ModelProfile | None:
provider_to_profile: dict[str, Callable[[str], ModelProfile | None]] = {
'anthropic': lambda model_name: BedrockModelProfile(
bedrock_supports_tool_choice=False, bedrock_send_back_thinking_parts=True
bedrock_supports_tool_choice=True, bedrock_send_back_thinking_parts=True
).update(anthropic_model_profile(model_name)),
'mistral': lambda model_name: BedrockModelProfile(bedrock_tool_result_format='json').update(
mistral_model_profile(model_name)
),
'cohere': cohere_model_profile,
'amazon': amazon_model_profile,
'amazon': bedrock_amazon_model_profile,
'meta': meta_model_profile,
'deepseek': deepseek_model_profile,
}
Expand Down
27 changes: 24 additions & 3 deletions tests/models/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,15 +844,15 @@ async def test_bedrock_mistral_tool_result_format(bedrock_provider: BedrockProvi
)


async def test_bedrock_anthropic_no_tool_choice(bedrock_provider: BedrockProvider):
async def test_bedrock_no_tool_choice(bedrock_provider: BedrockProvider):
my_tool = ToolDefinition(
name='my_tool',
description='This is my tool',
parameters_json_schema={'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}},
)
mrp = ModelRequestParameters(output_mode='tool', function_tools=[my_tool], allow_text_output=False, output_tools=[])

# Models other than Anthropic support tool_choice
# Amazon Nova supports tool_choice
model = BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider)
tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage]

Expand All @@ -873,10 +873,31 @@ async def test_bedrock_anthropic_no_tool_choice(bedrock_provider: BedrockProvide
}
)

# Anthropic models don't support tool_choice
# Anthropic supports tool_choice
model = BedrockConverseModel('us.anthropic.claude-3-7-sonnet-20250219-v1:0', provider=bedrock_provider)
tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage]

assert tool_config == snapshot(
{
'tools': [
{
'toolSpec': {
'name': 'my_tool',
'description': 'This is my tool',
'inputSchema': {
'json': {'type': 'object', 'title': 'Result', 'properties': {'spam': {'type': 'number'}}}
},
}
}
],
'toolChoice': {'any': {}},
}
)

# Other models don't support tool_choice
model = BedrockConverseModel('us.meta.llama4-maverick-17b-instruct-v1:0', provider=bedrock_provider)
tool_config = model._map_tool_config(mrp) # type: ignore[reportPrivateUsage]

assert tool_config == snapshot(
{
'tools': [
Expand Down
12 changes: 9 additions & 3 deletions tests/providers/test_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def test_bedrock_provider_model_profile(env: TestEnv, mocker: MockerFixture):
anthropic_profile = provider.model_profile('us.anthropic.claude-3-5-sonnet-20240620-v1:0')
anthropic_model_profile_mock.assert_called_with('claude-3-5-sonnet-20240620')
assert isinstance(anthropic_profile, BedrockModelProfile)
assert not anthropic_profile.bedrock_supports_tool_choice
assert anthropic_profile.bedrock_supports_tool_choice is True

anthropic_profile = provider.model_profile('anthropic.claude-instant-v1')
anthropic_model_profile_mock.assert_called_with('claude-instant')
assert isinstance(anthropic_profile, BedrockModelProfile)
assert not anthropic_profile.bedrock_supports_tool_choice
assert anthropic_profile.bedrock_supports_tool_choice is True

mistral_profile = provider.model_profile('mistral.mistral-large-2407-v1:0')
mistral_model_profile_mock.assert_called_with('mistral-large-2407')
Expand All @@ -84,7 +84,13 @@ def test_bedrock_provider_model_profile(env: TestEnv, mocker: MockerFixture):
assert deepseek_profile is not None
assert deepseek_profile.ignore_streamed_leading_whitespace is True

amazon_profile = provider.model_profile('amazon.titan-text-express-v1')
amazon_profile = provider.model_profile('us.amazon.nova-pro-v1:0')
amazon_model_profile_mock.assert_called_with('nova-pro')
assert isinstance(amazon_profile, BedrockModelProfile)
assert amazon_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer
assert amazon_profile.bedrock_supports_tool_choice is True

amazon_profile = provider.model_profile('us.amazon.titan-text-express-v1:0')
amazon_model_profile_mock.assert_called_with('titan-text-express')
assert amazon_profile is not None
assert amazon_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer
Expand Down