19 changes: 6 additions & 13 deletions docs/docs/providers/inference/remote_bedrock.mdx
@@ -1,5 +1,5 @@
---
description: "AWS Bedrock inference provider for accessing various AI models through AWS's managed service."
description: "AWS Bedrock inference provider using OpenAI compatible endpoint."
sidebar_label: Remote - Bedrock
title: remote::bedrock
---
@@ -8,27 +8,20 @@ title: remote::bedrock

## Description

AWS Bedrock inference provider for accessing various AI models through AWS's managed service.
AWS Bedrock inference provider using OpenAI compatible endpoint.

## Configuration

| Field | Type | Required | Default | Description |
|-------|------|----------|---------|-------------|
| `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
| `refresh_models` | `<class 'bool'>` | No | False | Whether to refresh models periodically from the provider |
| `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
| `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
| `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
| `region_name` | `str \| None` | No | | The default AWS Region to use, for example, us-west-1 or us-west-2.Default use environment variable: AWS_DEFAULT_REGION |
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
| `api_key` | `str \| None` | No | | Amazon Bedrock API key |
| `region_name` | `<class 'str'>` | No | us-east-2 | AWS Region for the Bedrock Runtime endpoint |

## Sample Configuration

```yaml
{}
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
```
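
The new provider is a thin wrapper over Bedrock's OpenAI-compatible REST endpoint, so the sample configuration above maps directly onto a plain OpenAI client call. A minimal sketch of the equivalent direct request, assuming the `openai` package, an `AWS_BEDROCK_API_KEY` in the environment, and an illustrative model ID (use one enabled in your Bedrock account):

```python
# Sketch only: call Bedrock's OpenAI-compatible endpoint directly.
# The base URL mirrors get_base_url() in the adapter changes below.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://bedrock-runtime.us-east-2.amazonaws.com/openai/v1",
    api_key=os.environ["AWS_BEDROCK_API_KEY"],
)
response = client.chat.completions.create(
    model="openai.gpt-oss-20b-1:0",  # illustrative model ID
    messages=[{"role": "user", "content": "Hello from Bedrock"}],
)
print(response.choices[0].message.content)
```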
3 changes: 3 additions & 0 deletions llama_stack/distributions/ci-tests/run.yaml
@@ -47,6 +47,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
3 changes: 3 additions & 0 deletions llama_stack/distributions/starter-gpu/run.yaml
@@ -47,6 +47,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
3 changes: 3 additions & 0 deletions llama_stack/distributions/starter/run.yaml
@@ -47,6 +47,9 @@ providers:
api_key: ${env.TOGETHER_API_KEY:=}
- provider_id: bedrock
provider_type: remote::bedrock
config:
api_key: ${env.AWS_BEDROCK_API_KEY:=}
region_name: ${env.AWS_DEFAULT_REGION:=us-east-2}
- provider_id: ${env.NVIDIA_API_KEY:+nvidia}
provider_type: remote::nvidia
config:
5 changes: 3 additions & 2 deletions llama_stack/providers/registry/inference.py
@@ -131,10 +131,11 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
pip_packages=[],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
provider_data_validator="llama_stack.providers.remote.inference.bedrock.config.BedrockProviderDataValidator",
description="AWS Bedrock inference provider using OpenAI compatible endpoint.",
),
RemoteProviderSpec(
api=Api.inference,
llama_stack/providers/remote/inference/bedrock/__init__.py
@@ -11,7 +11,7 @@ async def get_adapter_impl(config: BedrockConfig, _deps):

assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"

impl = BedrockInferenceAdapter(config)
impl = BedrockInferenceAdapter(config=config)

await impl.initialize()

171 changes: 69 additions & 102 deletions llama_stack/providers/remote/inference/bedrock/bedrock.py
@@ -4,139 +4,106 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import json
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterable

from botocore.client import BaseClient
from openai import AuthenticationError

from llama_stack.apis.inference import (
ChatCompletionRequest,
Inference,
Model,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletion,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIEmbeddingsResponse,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
)
from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_strategy_options,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
)

from .models import MODEL_ENTRIES

REGION_PREFIX_MAP = {
"us": "us.",
"eu": "eu.",
"ap": "ap.",
}


def _get_region_prefix(region: str | None) -> str:
# AWS requires region prefixes for inference profiles
if region is None:
return "us." # default to US when we don't know

# Handle case insensitive region matching
region_lower = region.lower()
for prefix in REGION_PREFIX_MAP:
if region_lower.startswith(f"{prefix}-"):
return REGION_PREFIX_MAP[prefix]

# Fallback to US for anything we don't recognize
return "us."


def _to_inference_profile_id(model_id: str, region: str = None) -> str:
# Return ARNs unchanged
if model_id.startswith("arn:"):
return model_id

# Return inference profile IDs that already have regional prefixes
if any(model_id.startswith(p) for p in REGION_PREFIX_MAP.values()):
return model_id

# Default to US East when no region is provided
if region is None:
region = "us-east-1"

return _get_region_prefix(region) + model_id
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.telemetry.tracing import get_current_span

from .config import BedrockConfig

class BedrockInferenceAdapter(
ModelRegistryHelper,
Inference,
):
def __init__(self, config: BedrockConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self._config = config
self._client = None

@property
def client(self) -> BaseClient:
if self._client is None:
self._client = create_bedrock_client(self._config)
return self._client
class BedrockInferenceAdapter(OpenAIMixin):
"""
Adapter for AWS Bedrock's OpenAI-compatible API endpoints.

async def initialize(self) -> None:
pass
Supports Llama models across regions and GPT-OSS models (us-west-2 only).

async def shutdown(self) -> None:
if self._client is not None:
self._client.close()
Note: Bedrock's OpenAI-compatible endpoint does not support /v1/models
for dynamic model discovery. Models must be pre-registered in the config.
"""

async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> dict:
bedrock_model = request.model
config: BedrockConfig
provider_data_api_key_field: str = "aws_bedrock_api_key"

sampling_params = request.sampling_params
options = get_sampling_strategy_options(sampling_params)
def get_api_key(self) -> str:
"""Get API key for OpenAI client."""
if not self.config.api_key:
raise ValueError(
"API key is not set. Please provide a valid API key in the "
"provider config or via AWS_BEDROCK_API_KEY environment variable."
)
return self.config.api_key

if sampling_params.max_tokens:
options["max_gen_len"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
options["repetition_penalty"] = sampling_params.repetition_penalty
def get_base_url(self) -> str:
"""Get base URL for OpenAI client."""
return f"https://bedrock-runtime.{self.config.region_name}.amazonaws.com/openai/v1"

prompt = await chat_completion_request_to_prompt(request, self.get_llama_model(request.model))
async def list_provider_model_ids(self) -> Iterable[str]:
"""
Bedrock's OpenAI-compatible endpoint does not support the /v1/models endpoint.
Returns empty list since models must be pre-registered in the config.
"""
return []

# Convert foundation model ID to inference profile ID
region_name = self.client.meta.region_name
inference_profile_id = _to_inference_profile_id(bedrock_model, region_name)
async def register_model(self, model: Model) -> Model:
"""
Register a model with the Bedrock provider.

return {
"modelId": inference_profile_id,
"body": json.dumps(
{
"prompt": prompt,
**options,
}
),
}
Bedrock doesn't support dynamic model listing via /v1/models, so we skip
the availability check and accept all models registered in the config.
"""
return model

async def openai_embeddings(
self,
params: OpenAIEmbeddingsRequestWithExtraBody,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
"""Bedrock's OpenAI-compatible API does not support the /v1/embeddings endpoint."""
raise NotImplementedError(
"Bedrock's OpenAI-compatible API does not support /v1/embeddings endpoint. "
"See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
)

async def openai_completion(
self,
params: OpenAICompletionRequestWithExtraBody,
) -> OpenAICompletion:
raise NotImplementedError("OpenAI completion not supported by the Bedrock provider")
"""Bedrock's OpenAI-compatible API does not support the /v1/completions endpoint."""
raise NotImplementedError(
"Bedrock's OpenAI-compatible API does not support /v1/completions endpoint. "
"Only /v1/chat/completions is supported. "
"See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-chat-completions.html"
)

async def openai_chat_completion(
self,
params: OpenAIChatCompletionRequestWithExtraBody,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
raise NotImplementedError("OpenAI chat completion not supported by the Bedrock provider")
"""Override to enable streaming usage metrics and handle authentication errors."""
# Enable streaming usage metrics when telemetry is active
if params.stream and get_current_span() is not None:
if params.stream_options is None:
params.stream_options = {"include_usage": True}
elif "include_usage" not in params.stream_options:
params.stream_options = {**params.stream_options, "include_usage": True}

# Wrap call in try/except to catch authentication errors
try:
return await super().openai_chat_completion(params=params)
except AuthenticationError as e:
raise ValueError(
f"AWS Bedrock authentication failed: {e.message}. "
"Please check your API key in the provider config or x-llamastack-provider-data header."
) from e
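
Since `provider_data_api_key_field` is wired to `aws_bedrock_api_key`, callers can also supply credentials per request through the `x-llamastack-provider-data` header rather than the provider config; the `AuthenticationError` handler above points users at both paths. A hedged sketch, assuming the `llama_stack_client` package and a server on localhost (key and model values are placeholders):

```python
# Sketch: per-request Bedrock credentials via the provider-data header.
import json

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(
    base_url="http://localhost:8321",
    default_headers={
        # Field name matches provider_data_api_key_field in the adapter.
        "x-llamastack-provider-data": json.dumps({"aws_bedrock_api_key": "..."})
    },
)
stream = client.chat.completions.create(
    model="...",  # a model pre-registered with the bedrock provider
    messages=[{"role": "user", "content": "Hi"}],
    stream=True,  # with an active telemetry span, usage lands in the final chunk
)
for chunk in stream:
    print(chunk)
```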
31 changes: 28 additions & 3 deletions llama_stack/providers/remote/inference/bedrock/config.py
@@ -4,8 +4,33 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
import os

from pydantic import BaseModel, Field

class BedrockConfig(BedrockBaseConfig):
pass
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig


class BedrockProviderDataValidator(BaseModel):
aws_bedrock_api_key: str | None = Field(
default=None,
description="API key for Amazon Bedrock",
)


class BedrockConfig(RemoteInferenceProviderConfig):
api_key: str | None = Field(
default_factory=lambda: os.getenv("AWS_BEDROCK_API_KEY"),
description="Amazon Bedrock API key",
)
region_name: str = Field(
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION", "us-east-2"),
description="AWS Region for the Bedrock Runtime endpoint",
)

@classmethod
def sample_run_config(cls, **kwargs):
return {
"api_key": "${env.AWS_BEDROCK_API_KEY:=}",
"region_name": "${env.AWS_DEFAULT_REGION:=us-east-2}",
}
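
Because both fields use `default_factory`, the environment is consulted when the config object is constructed, and explicit values still take precedence. A small sketch of that behavior, assuming the import path shown in this diff:

```python
# Sketch: environment fallback in BedrockConfig (pydantic default_factory).
import os

from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig

os.environ["AWS_DEFAULT_REGION"] = "eu-west-1"
cfg = BedrockConfig()
assert cfg.region_name == "eu-west-1"  # picked up from the environment

cfg = BedrockConfig(region_name="us-east-2")
assert cfg.region_name == "us-east-2"  # explicit value wins over the env default
```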
29 changes: 0 additions & 29 deletions llama_stack/providers/remote/inference/bedrock/models.py

This file was deleted.
