diff --git a/docs/docs/providers/inference/remote_bedrock.mdx b/docs/docs/providers/inference/remote_bedrock.mdx
index 683ec12f8a..c6804e9c5d 100644
--- a/docs/docs/providers/inference/remote_bedrock.mdx
+++ b/docs/docs/providers/inference/remote_bedrock.mdx
@@ -1,5 +1,5 @@
 ---
-description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
+description: "AWS Bedrock inference provider using an OpenAI-compatible endpoint."
 sidebar_label: Remote - Bedrock
 title: remote::bedrock
 ---
@@ -8,7 +8,7 @@ title: remote::bedrock
 
 ## Description
 
-AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
+AWS Bedrock inference provider using an OpenAI-compatible endpoint.
 
 ## Configuration
 
@@ -16,19 +16,12 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
-| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
-| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
-| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
-| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
-| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
-| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
-| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
-| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
-| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
-| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
+| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
+| `region_name` | `str` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |
 
 ## Sample Configuration
 
 ```yaml
-{}
+api_key: ${env.AWS_BEDROCK_API_KEY:=}
+region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
 ```
diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml
index 40f4d8a0a7..7946f2412a 100644
--- a/llama_stack/distributions/ci-tests/run.yaml
+++ b/llama_stack/distributions/ci-tests/run.yaml
@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml
index b281218155..062f236264 100644
--- a/llama_stack/distributions/starter-gpu/run.yaml
+++ b/llama_stack/distributions/starter-gpu/run.yaml
@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
    config:
diff --git a/llama_stack/distributions/starter/run.yaml b/llama_stack/distributions/starter/run.yaml
index 341b51a976..73f5c2a696 100644
--- a/llama_stack/distributions/starter/run.yaml
+++ b/llama_stack/distributions/starter/run.yaml
@@ -47,6 +47,9 @@ providers:
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
+    config:
+      api_key: ${env.AWS_BEDROCK_API_KEY:=}
+      region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py
index f895658928..93bfeffddf 100644
--- a/llama_stack/providers/registry/inference.py
+++ b/llama_stack/providers/registry/inference.py
@@ -131,10 +131,11 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.inference,
             adapter_type="bedrock",
             provider_type="remote::bedrock",
-            pip_packages=["boto3"],
+            pip_packages=[],
             module="llama_stack.providers.remote.inference.bedrock",
             config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
-            description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
+            provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
+            description="AWS Bedrock inference provider using an OpenAI-compatible endpoint.",
         ),
         RemoteProviderSpec(
             api=Api.inference,
diff --git a/llama_stack/providers/remote/inference/bedrock/__init__.py b/llama_stack/providers/remote/inference/bedrock/__init__.py
index 4d98f4999a..4b0686b187 100644
--- a/llama_stack/providers/remote/inference/bedrock/__init__.py
+++ b/llama_stack/providers/remote/inference/bedrock/__init__.py
@@ -11,7 +11,7 @@
 async def get_adapter_impl(config: BedrockConfig, _deps):
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
-    impl = BedrockInferenceAdapter(config)
+    impl = BedrockInferenceAdapter(config=config)
     await impl.initialize()
diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py
index 
d266f9e6f7..a4093aefe8 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -4,139 +4,106 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Iterable -from botocore.client import BaseClient +from openai import AuthenticationError from llama_stack.apis.inference import ( - ChatCompletionRequest, - Inference, + Model, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, OpenAIChatCompletionRequestWithExtraBody, + OpenAICompletion, OpenAICompletionRequestWithExtraBody, OpenAIEmbeddingsRequestWithExtraBody, OpenAIEmbeddingsResponse, ) -from llama_stack.apis.inference.inference import ( - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, -) -from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig -from llama_stack.providers.utils.bedrock.client import create_bedrock_client -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) -from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_strategy_options, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_prompt, -) - -from .models import MODEL_ENTRIES - -REGION_PREFIX_MAP = { - "us": "us.", - "eu": "eu.", - "ap": "ap.", -} - - -def _get_region_prefix(region: str | None) -> str: - # AWS requires region prefixes for inference profiles - if region is None: - return "us." # default to US when we don't know - - # Handle case insensitive region matching - region_lower = region.lower() - for prefix in REGION_PREFIX_MAP: - if region_lower.startswith(f"{prefix}-"): - return REGION_PREFIX_MAP[prefix] - - # Fallback to US for anything we don't recognize - return "us." - - -def _to_inference_profile_id(model_id: str, region: str = None) -> str: - # Return ARNs unchanged - if model_id.startswith("arn:"): - return model_id - - # Return inference profile IDs that already have regional prefixes - if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()): - return model_id - - # Default to US East when no region is provided - if region is None: - region = "us-east-1" - - return _get_region_prefix(region) + model_id +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin +from llama_stack.providers.utils.telemetry.tracing import get_current_span +from .config import BedrockConfig -class BedrockInferenceAdapter( - ModelRegistryHelper, - Inference, -): - def __init__(self, config: BedrockConfig) -> None: - ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) - self._config = config - self._client = None - @property - def client(self) -> BaseClient: - if self._client is None: - self._client = create_bedrock_client(self._config) - return self._client +class BedrockInferenceAdapter(OpenAIMixin): + """ + Adapter for AWS Bedrock's OpenAI-compatible API endpoints. - async def initialize(self) -> None: - pass + Supports Llama models across regions and GPT-OSS models (us-west-2 only). - async def shutdown(self) -> None: - if self._client is not None: - self._client.close() + Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models + for dynamic model discovery. Models must be pre-registered in the config. 
+ """ - async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict: - bedrock_model = request.model + config: BedrockConfig + provider_data_api_key_field: str = "aws_bedrock_api_key" - sampling_params = request.sampling_params - options = get_sampling_strategy_options(sampling_params) + def get_api_key(self) -> str: + """Get API key for OpenAI client.""" + if not self.config.api_key: + raise ValueError( + "API key is not set. Please provide a valid API key in the " + "provider config or via AWS_BEDROCK_API_KEY environment variable." + ) + return self.config.api_key - if sampling_params.max_tokens: - options["max_gen_len"] = sampling_params.max_tokens - if sampling_params.repetition_penalty > 0: - options["repetition_penalty"] = sampling_params.repetition_penalty + def get_base_url(self) -> str: + """Get base URL for OpenAI client.""" + return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1" - prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model)) + async def list_provider_model_ids(self) -> Iterable[str]: + """ + Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint. + Returns empty list since models must be pre-registered in the config. + """ + return [] - # Convert foundation model ID to inference profile ID - region_name = self.client.meta.region_name - inference_profile_id = _to_inference_profile_id(bedrock_model, region_name) + async def register_model(self, model: Model) -> Model: + """ + Register a model with the Bedrock provider. - return { - "modelId": inference_profile_id, - "body": json.dumps( - { - "prompt": prompt, - **options, - } - ), - } + Bedrock doesn't support dynamic model listing via /v1/models, so we skip + the availability check and accept all models registered in the config. + """ + return model async def openai_embeddings( self, params: OpenAIEmbeddingsRequestWithExtraBody, ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() + """Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint.""" + raise NotImplementedError( + "Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. " + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" + ) async def openai_completion( self, params: OpenAICompletionRequestWithExtraBody, ) -> OpenAICompletion: - raise NotImplementedError("OpenAI completion not supported by the Bedrock provider") + """Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint.""" + raise NotImplementedError( + "Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. " + "Only /v1/chat/completions is supported. 
" + "See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html" + ) async def openai_chat_completion( self, params: OpenAIChatCompletionRequestWithExtraBody, ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider") + """Override to enable streaming usage metrics and handle authentication errors.""" + # Enable streaming usage metrics when telemetry is active + if params.stream and get_current_span() is not None: + if params.stream_options is None: + params.stream_options = {"include_usage": True} + elif "include_usage" not in params.stream_options: + params.stream_options = {**params.stream_options, "include_usage": True} + + # Wrap call in try/except to catch authentication errors + try: + return await super().openai_chat_completion(params=params) + except AuthenticationError as e: + raise ValueError( + f"AWS Bedrock authentication failed: {e.message}. " + "Please check your API key in the provider config or x-llamastack-provider-data header." + ) from e diff --git a/llama_stack/providers/remote/inference/bedrock/config.py b/llama_stack/providers/remote/inference/bedrock/config.py index 5961a2f153..2b236e9022 100644 --- a/llama_stack/providers/remote/inference/bedrock/config.py +++ b/llama_stack/providers/remote/inference/bedrock/config.py @@ -4,8 +4,33 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig +import os +from pydantic import BaseModel, Field -class BedrockConfig(BedrockBaseConfig): - pass +from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig + + +class BedrockProviderDataValidator(BaseModel): + aws_bedrock_api_key: str | None = Field( + default=None, + description="API key for Amazon Bedrock", + ) + + +class BedrockConfig(RemoteInferenceProviderConfig): + api_key: str | None = Field( + default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"), + description="Amazon Bedrock API key", + ) + region_name: str = Field( + default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"), + description="AWS Region for the Bedrock Runtime endpoint", + ) + + @classmethod + def sample_run_config(cls, **kwargs): + return { + "api_key": "${env.AWS_BEDROCK_API_KEY:=}", + "region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}", + } diff --git a/llama_stack/providers/remote/inference/bedrock/models.py b/llama_stack/providers/remote/inference/bedrock/models.py deleted file mode 100644 index 17273c1220..0000000000 --- a/llama_stack/providers/remote/inference/bedrock/models.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from llama_stack.models.llama.sku_types import CoreModelId -from llama_stack.providers.utils.inference.model_registry import ( - build_hf_repo_model_entry, -) - -SAFETY_MODELS_ENTRIES = [] - - -# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html -MODEL_ENTRIES = [ - build_hf_repo_model_entry( - "meta.llama3-1-8b-instruct-v1:0", - CoreModelId.llama3_1_8b_instruct.value, - ), - build_hf_repo_model_entry( - "meta.llama3-1-70b-instruct-v1:0", - CoreModelId.llama3_1_70b_instruct.value, - ), - build_hf_repo_model_entry( - "meta.llama3-1-405b-instruct-v1:0", - CoreModelId.llama3_1_405b_instruct.value, - ), -] + SAFETY_MODELS_ENTRIES diff --git a/tests/unit/providers/inference/test_bedrock_adapter.py b/tests/unit/providers/inference/test_bedrock_adapter.py new file mode 100644 index 0000000000..2144059d60 --- /dev/null +++ b/tests/unit/providers/inference/test_bedrock_adapter.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest +from openai import AuthenticationError + +from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody +from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig + + +def test_adapter_initialization(): + config = BedrockConfig(api_key="test-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + + assert adapter.config.api_key == "test-key" + assert adapter.config.region_name == "us-east-1" + + +def test_client_url_construction(): + config = BedrockConfig(api_key="test-key", region_name="us-west-2") + adapter = BedrockInferenceAdapter(config=config) + + assert adapter.get_base_url() == "https://bedrock-runtime.us-west-2.amazonaws.com/openai/v1" + assert adapter.get_api_key() == "test-key" + + +def test_api_key_from_config(): + """Test API key is read from config""" + config = BedrockConfig(api_key="config-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + + assert adapter.get_api_key() == "config-key" + + +def test_api_key_from_header_overrides_config(): + """Test API key from request header overrides config via client property""" + config = BedrockConfig(api_key="config-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + adapter.provider_data_api_key_field = "aws_bedrock_api_key" + adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bedrock_api_key="header-key")) + + # The client property is where header override happens (in OpenAIMixin) + assert adapter.client.api_key == "header-key" + + +async def test_authentication_error_handling(): + """Test that AuthenticationError from OpenAI client is converted to ValueError with helpful message""" + config = BedrockConfig(api_key="invalid-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + + # Mock the parent class method to raise AuthenticationError + mock_response = MagicMock() + mock_response.message = "Invalid authentication credentials" + auth_error = AuthenticationError(message="Invalid authentication credentials", response=mock_response, body=None) + + # Create a mock that raises the error + mock_super = AsyncMock(side_effect=auth_error) + + # 
Patch the parent class method + original_method = BedrockInferenceAdapter.__bases__[0].openai_chat_completion + BedrockInferenceAdapter.__bases__[0].openai_chat_completion = mock_super + + try: + with pytest.raises(ValueError) as exc_info: + params = OpenAIChatCompletionRequestWithExtraBody( + model="test-model", messages=[{"role": "user", "content": "test"}] + ) + await adapter.openai_chat_completion(params=params) + + assert "AWS Bedrock authentication failed" in str(exc_info.value) + assert "Please check your API key" in str(exc_info.value) + finally: + # Restore original method + BedrockInferenceAdapter.__bases__[0].openai_chat_completion = original_method diff --git a/tests/unit/providers/inference/test_bedrock_config.py b/tests/unit/providers/inference/test_bedrock_config.py new file mode 100644 index 0000000000..4d97900e40 --- /dev/null +++ b/tests/unit/providers/inference/test_bedrock_config.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig + + +def test_bedrock_config_defaults_no_env(monkeypatch): + """Test BedrockConfig defaults when env vars are not set""" + monkeypatch.delenv("AWS_BEDROCK_API_KEY", raising=False) + monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False) + config = BedrockConfig() + assert config.api_key is None + assert config.region_name == "us-east-2" + + +def test_bedrock_config_defaults_with_env(monkeypatch): + """Test BedrockConfig reads from environment variables""" + monkeypatch.setenv("AWS_BEDROCK_API_KEY", "env-key") + monkeypatch.setenv("AWS_DEFAULT_REGION", "eu-west-1") + config = BedrockConfig() + assert config.api_key == "env-key" + assert config.region_name == "eu-west-1" + + +def test_bedrock_config_with_values(): + """Test BedrockConfig accepts explicit values""" + config = BedrockConfig(api_key="test-key", region_name="us-west-2") + assert config.api_key == "test-key" + assert config.region_name == "us-west-2" + + +def test_bedrock_config_sample(): + """Test BedrockConfig sample_run_config returns correct format""" + sample = BedrockConfig.sample_run_config() + assert "api_key" in sample + assert "region_name" in sample + assert sample["api_key"] == "${env.AWS_BEDROCK_API_KEY:=}" + assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}" diff --git a/tests/unit/providers/test_bedrock.py b/tests/unit/providers/test_bedrock.py index 1ff07bbbe8..18ab912584 100644 --- a/tests/unit/providers/test_bedrock.py +++ b/tests/unit/providers/test_bedrock.py @@ -4,50 +4,63 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from llama_stack.providers.remote.inference.bedrock.bedrock import ( - _get_region_prefix, - _to_inference_profile_id, -) +from types import SimpleNamespace +from unittest.mock import AsyncMock, PropertyMock, patch +from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody +from llama_stack.providers.remote.inference.bedrock.bedrock import BedrockInferenceAdapter +from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig -def test_region_prefixes(): - assert _get_region_prefix("us-east-1") == "us." - assert _get_region_prefix("eu-west-1") == "eu." - assert _get_region_prefix("ap-south-1") == "ap." - assert _get_region_prefix("ca-central-1") == "us." 
- # Test case insensitive - assert _get_region_prefix("US-EAST-1") == "us." - assert _get_region_prefix("EU-WEST-1") == "eu." - assert _get_region_prefix("Ap-South-1") == "ap." +def test_can_create_adapter(): + config = BedrockConfig(api_key="test-key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) - # Test None region - assert _get_region_prefix(None) == "us." + assert adapter is not None + assert adapter.config.region_name == "us-east-1" + assert adapter.get_api_key() == "test-key" -def test_model_id_conversion(): - # Basic conversion - assert ( - _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "us-east-1") == "us.meta.llama3-1-70b-instruct-v1:0" - ) +def test_different_aws_regions(): + # just check a couple regions to verify URL construction works + config = BedrockConfig(api_key="key", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) + assert adapter.get_base_url() == "https://bedrock-runtime.us-east-1.amazonaws.com/openai/v1" - # Already has prefix - assert ( - _to_inference_profile_id("us.meta.llama3-1-70b-instruct-v1:0", "us-east-1") - == "us.meta.llama3-1-70b-instruct-v1:0" - ) + config = BedrockConfig(api_key="key", region_name="eu-west-1") + adapter = BedrockInferenceAdapter(config=config) + assert adapter.get_base_url() == "https://bedrock-runtime.eu-west-1.amazonaws.com/openai/v1" - # ARN should be returned unchanged - arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/us.meta.llama3-1-70b-instruct-v1:0" - assert _to_inference_profile_id(arn, "us-east-1") == arn - # ARN should be returned unchanged even without region - assert _to_inference_profile_id(arn) == arn +async def test_basic_chat_completion(): + """Test basic chat completion works with OpenAIMixin""" + config = BedrockConfig(api_key="k", region_name="us-east-1") + adapter = BedrockInferenceAdapter(config=config) - # Optional region parameter defaults to us-east-1 - assert _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0") == "us.meta.llama3-1-70b-instruct-v1:0" + class FakeModelStore: + async def get_model(self, model_id): + return SimpleNamespace(provider_resource_id="meta.llama3-1-8b-instruct-v1:0") - # Different regions work with optional parameter - assert ( - _to_inference_profile_id("meta.llama3-1-70b-instruct-v1:0", "eu-west-1") == "eu.meta.llama3-1-70b-instruct-v1:0" + adapter.model_store = FakeModelStore() + + fake_response = SimpleNamespace( + id="chatcmpl-123", + choices=[SimpleNamespace(message=SimpleNamespace(content="Hello!", role="assistant"), finish_reason="stop")], ) + + mock_create = AsyncMock(return_value=fake_response) + + class FakeClient: + def __init__(self): + self.chat = SimpleNamespace(completions=SimpleNamespace(create=mock_create)) + + with patch.object(type(adapter), "client", new_callable=PropertyMock, return_value=FakeClient()): + params = OpenAIChatCompletionRequestWithExtraBody( + model="llama3-1-8b", + messages=[{"role": "user", "content": "hello"}], + stream=False, + ) + response = await adapter.openai_chat_completion(params=params) + + assert response.id == "chatcmpl-123" + assert mock_create.await_count == 1
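
For reference, a minimal sketch of calling the OpenAI-compatible Bedrock Runtime endpoint that the rewritten adapter targets, using the `openai` client directly. This is not part of the patch; the region and model id below are illustrative, and the model must be enabled for your AWS account.

```python
# Minimal sketch: talk to Bedrock's OpenAI-compatible endpoint directly,
# mirroring what BedrockInferenceAdapter.get_base_url()/get_api_key() set up.
# Assumes AWS_BEDROCK_API_KEY is set; region and model id are illustrative.
import os

from openai import OpenAI

region = os.getenv("AWS_DEFAULT_REGION", "us-east-2")
client = OpenAI(
    base_url=f"https://bedrock-runtime.{region}.amazonaws.com/openai/v1",
    api_key=os.environ["AWS_BEDROCK_API_KEY"],
)

# Only /v1/chat/completions is supported; per the adapter above,
# /v1/completions and /v1/embeddings raise NotImplementedError.
response = client.chat.completions.create(
    model="meta.llama3-1-8b-instruct-v1:0",  # must be enabled in your account/region
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```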