Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set model deployment configuration through the UI at runtime (versus .env) #151

Merged
merged 17 commits into from
May 24, 2024
4 changes: 3 additions & 1 deletion .env-template
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ NEXT_PUBLIC_HAS_CUSTOM_LOGO=false
COHERE_API_KEY=<API_KEY_HERE>

# 2 - SageMaker
SAGE_MAKER_PROFILE_NAME=<PROFILE NAME>
SAGE_MAKER_ACCESS_KEY=<ACCESS KEY>
SAGE_MAKER_SECRET_KEY=<SECRET KEY>
SAGE_MAKER_SESSION_TOKEN=<SESSION TOKEN>
SAGE_MAKER_REGION_NAME=<REGION NAME>
SAGE_MAKER_ENDPOINT_NAME=<ENDPOINT NAME>

Expand Down
2 changes: 0 additions & 2 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ services:
- ./src/backend/alembic:/workspace/src/backend/alembic
# Mount data folder to sync uploaded files
- ./src/backend/data:/workspace/src/backend/data
# For SageMaker: The line below for AWS configure file to sync credentials
- $HOME/.aws:/root/.aws
# network_mode: host

frontend:
Expand Down
4 changes: 2 additions & 2 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from fastapi import HTTPException

from backend.chat.base import BaseChat
from backend.chat.custom.utils import get_deployment
from backend.config.tools import AVAILABLE_TOOLS, ToolName
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_deployment
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.tool import Category, Tool
from backend.services.logger import get_logger
Expand All @@ -30,7 +30,7 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
Generator[StreamResponse, None, None]: Chat response.
"""
# Choose the deployment model - validation already performed by request validator
deployment_model = get_deployment(kwargs.get("deployment_name"))
deployment_model = get_deployment(kwargs.get("deployment_name"), **kwargs)
self.logger.info(f"Using deployment {deployment_model.__class__.__name__}")

if len(chat_request.tools) > 0 and len(chat_request.documents) > 0:
Expand Down
32 changes: 32 additions & 0 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import Any

from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS
from backend.model_deployments.base import BaseDeployment


def get_deployment(name, **kwargs: Any) -> BaseDeployment:
    """Get the deployment implementation.

    Args:
        name (str): Deployment name, looked up in AVAILABLE_MODEL_DEPLOYMENTS.
        **kwargs (Any): Extra keyword arguments forwarded to the constructor
            of the selected deployment class.

    Returns:
        BaseDeployment: Deployment implementation instance based on the deployment name.

    Raises:
        ValueError: If the deployment is not supported and no available
            deployment could be used as a fallback.
    """
    deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)

    # Check provided deployment against config const.
    # NOTE: a known deployment is instantiated without consulting
    # `is_available` - presumably because its configuration may be supplied
    # at call time through kwargs rather than the environment (confirm).
    if deployment is not None:
        return deployment.deployment_class(**kwargs, **deployment.kwargs)

    # Fallback to first available deployment. Its configured kwargs are
    # forwarded as well, so both code paths construct the class the same way.
    for candidate in AVAILABLE_MODEL_DEPLOYMENTS.values():
        if candidate.is_available:
            return candidate.deployment_class(**kwargs, **candidate.kwargs)

    raise ValueError(
        f"Deployment {name} is not supported, and no available deployments were found."
    )
26 changes: 8 additions & 18 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
CohereDeployment,
SageMakerDeployment,
)
from backend.model_deployments.azure import AZURE_ENV_VARS
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
from backend.schemas.deployment import Deployment


Expand All @@ -28,42 +32,28 @@ class ModelDeploymentName(StrEnum):
deployment_class=CohereDeployment,
models=CohereDeployment.list_models(),
is_available=CohereDeployment.is_available(),
env_vars=[
"COHERE_API_KEY",
],
env_vars=COHERE_ENV_VARS,
),
ModelDeploymentName.SageMaker: Deployment(
name=ModelDeploymentName.SageMaker,
deployment_class=SageMakerDeployment,
models=SageMakerDeployment.list_models(),
is_available=SageMakerDeployment.is_available(),
env_vars=[
"SAGE_MAKER_REGION_NAME",
"SAGE_MAKER_ENDPOINT_NAME",
"SAGE_MAKER_PROFILE_NAME",
],
env_vars=SAGE_MAKER_ENV_VARS,
),
ModelDeploymentName.Azure: Deployment(
name=ModelDeploymentName.Azure,
deployment_class=AzureDeployment,
models=AzureDeployment.list_models(),
is_available=AzureDeployment.is_available(),
env_vars=[
"AZURE_API_KEY",
"AZURE_CHAT_ENDPOINT_URL",
],
env_vars=AZURE_ENV_VARS,
),
ModelDeploymentName.Bedrock: Deployment(
name=ModelDeploymentName.Bedrock,
deployment_class=BedrockDeployment,
models=BedrockDeployment.list_models(),
is_available=BedrockDeployment.is_available(),
env_vars=[
"BEDROCK_ACCESS_KEY",
"BEDROCK_SECRET_KEY",
"BEDROCK_SESSION_TOKEN",
"BEDROCK_REGION_NAME",
],
env_vars=BEDROCK_ENV_VARS,
),
}

Expand Down
20 changes: 14 additions & 6 deletions src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,15 @@
from cohere.types import StreamedChatResponse

from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.schemas.cohere_chat import CohereChatRequest

AZURE_API_KEY_ENV_VAR = "AZURE_API_KEY"
# Example URL: "https://<endpoint>.<region>.inference.ai.azure.com/v1"
# Note: It must have /v1 and it should not have /chat
AZURE_CHAT_URL_ENV_VAR = "AZURE_CHAT_ENDPOINT_URL"
AZURE_ENV_VARS = [AZURE_API_KEY_ENV_VAR, AZURE_CHAT_URL_ENV_VAR]


class AzureDeployment(BaseDeployment):
"""
Expand All @@ -16,14 +23,15 @@ class AzureDeployment(BaseDeployment):
"""

DEFAULT_MODELS = ["azure-command"]
api_key = os.environ.get("AZURE_API_KEY")
# Example URL: "https://<endpoint>.<region>.inference.ai.azure.com/v1"
# Note: It must have /v1 and it should not have /chat
chat_endpoint_url = os.environ.get("AZURE_CHAT_ENDPOINT_URL")

def __init__(self):
def __init__(self, **kwargs: Any):
# Override the environment variable from the request
self.api_key = get_model_config_var(AZURE_API_KEY_ENV_VAR, **kwargs)
self.chat_endpoint_url = get_model_config_var(AZURE_CHAT_URL_ENV_VAR, **kwargs)

if not self.chat_endpoint_url.endswith("/v1"):
self.chat_endpoint_url = self.chat_endpoint_url + "/v1"
print("Azure chat endpoint url: ", self.chat_endpoint_url)
self.client = cohere.Client(
base_url=self.chat_endpoint_url, api_key=self.api_key
)
Expand All @@ -41,7 +49,7 @@ def list_models(cls) -> List[str]:

@classmethod
def is_available(cls) -> bool:
return all([cls.api_key is not None, cls.chat_endpoint_url is not None])
return all([os.environ.get(var) is not None for var in AZURE_ENV_VARS])

def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
return self.client.chat(
Expand Down
37 changes: 20 additions & 17 deletions src/backend/model_deployments/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,36 @@
from cohere.types import StreamedChatResponse

from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.schemas.cohere_chat import CohereChatRequest

BEDROCK_ACCESS_KEY_ENV_VAR = "BEDROCK_ACCESS_KEY"
BEDROCK_SECRET_KEY_ENV_VAR = "BEDROCK_SECRET_KEY"
BEDROCK_SESSION_TOKEN_ENV_VAR = "BEDROCK_SESSION_TOKEN"
BEDROCK_REGION_NAME_ENV_VAR = "BEDROCK_REGION_NAME"
BEDROCK_ENV_VARS = [
BEDROCK_ACCESS_KEY_ENV_VAR,
BEDROCK_SECRET_KEY_ENV_VAR,
BEDROCK_SESSION_TOKEN_ENV_VAR,
BEDROCK_REGION_NAME_ENV_VAR,
]


class BedrockDeployment(BaseDeployment):
DEFAULT_MODELS = ["cohere.command-r-plus-v1:0"]
access_key = os.environ.get("BEDROCK_ACCESS_KEY")
secret_key = os.environ.get("BEDROCK_SECRET_KEY")
session_token = os.environ.get("BEDROCK_SESSION_TOKEN")
region_name = os.environ.get("BEDROCK_REGION_NAME")

def __init__(self):
def __init__(self, **kwargs: Any):
self.client = cohere.BedrockClient(
# TODO: remove hardcoded models once the SDK is updated
chat_model="cohere.command-r-plus-v1:0",
embed_model="cohere.embed-multilingual-v3",
generate_model="cohere.command-text-v14",
aws_access_key=self.access_key,
aws_secret_key=self.secret_key,
aws_session_token=self.session_token,
aws_region=self.region_name,
aws_access_key=get_model_config_var(BEDROCK_ACCESS_KEY_ENV_VAR, **kwargs),
aws_secret_key=get_model_config_var(BEDROCK_SECRET_KEY_ENV_VAR, **kwargs),
aws_session_token=get_model_config_var(
BEDROCK_SESSION_TOKEN_ENV_VAR, **kwargs
),
aws_region=get_model_config_var(BEDROCK_REGION_NAME_ENV_VAR, **kwargs),
)

@property
Expand All @@ -40,14 +50,7 @@ def list_models(cls) -> List[str]:

@classmethod
def is_available(cls) -> bool:
return all(
[
cls.access_key is not None,
cls.secret_key is not None,
cls.session_token is not None,
cls.region_name is not None,
]
)
return all([os.environ.get(var) is not None for var in BEDROCK_ENV_VARS])

def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
# bedrock accepts a subset of the chat request fields
Expand Down
12 changes: 9 additions & 3 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@
from cohere.types import StreamedChatResponse

from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.schemas.cohere_chat import CohereChatRequest

COHERE_API_KEY_ENV_VAR = "COHERE_API_KEY"
COHERE_ENV_VARS = [COHERE_API_KEY_ENV_VAR]


class CohereDeployment(BaseDeployment):
"""Cohere Platform Deployment."""

api_key = os.environ.get("COHERE_API_KEY")
client_name = "cohere-toolkit"
api_key = None

def __init__(self):
def __init__(self, **kwargs: Any):
# Override the environment variable from the request
self.api_key = get_model_config_var(COHERE_API_KEY_ENV_VAR, **kwargs)
self.client = cohere.Client(api_key=self.api_key, client_name=self.client_name)

@property
Expand Down Expand Up @@ -49,7 +55,7 @@ def list_models(cls) -> List[str]:

@classmethod
def is_available(cls) -> bool:
return cls.api_key is not None
return all([os.environ.get(var) is not None for var in COHERE_ENV_VARS])

def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:
return self.client.chat(
Expand Down
46 changes: 32 additions & 14 deletions src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,22 @@
from cohere.types import StreamedChatResponse

from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.schemas.cohere_chat import CohereChatRequest

SAGE_MAKER_ACCESS_KEY_ENV_VAR = "SAGE_MAKER_ACCESS_KEY"
SAGE_MAKER_SECRET_KEY_ENV_VAR = "SAGE_MAKER_SECRET_KEY"
SAGE_MAKER_SESSION_TOKEN_ENV_VAR = "SAGE_MAKER_SESSION_TOKEN"
SAGE_MAKER_REGION_NAME_ENV_VAR = "SAGE_MAKER_REGION_NAME"
SAGE_MAKER_ENDPOINT_NAME_ENV_VAR = "SAGE_MAKER_ENDPOINT_NAME"
SAGE_MAKER_ENV_VARS = [
SAGE_MAKER_ACCESS_KEY_ENV_VAR,
SAGE_MAKER_SECRET_KEY_ENV_VAR,
SAGE_MAKER_SESSION_TOKEN_ENV_VAR,
SAGE_MAKER_REGION_NAME_ENV_VAR,
SAGE_MAKER_ENDPOINT_NAME_ENV_VAR,
]


class SageMakerDeployment(BaseDeployment):
"""
Expand All @@ -18,16 +32,26 @@ class SageMakerDeployment(BaseDeployment):
"""

DEFAULT_MODELS = ["sagemaker-command"]
profile_name = os.environ.get("SAGE_MAKER_PROFILE_NAME")
region_name = os.environ.get("SAGE_MAKER_REGION_NAME")
endpoint_name = os.environ.get("SAGE_MAKER_ENDPOINT_NAME")

def __init__(self):
boto3.setup_default_session(profile_name=self.profile_name)
def __init__(self, **kwargs: Any):
# Create the AWS client for the SageMaker runtime with boto3
self.client = boto3.client("sagemaker-runtime", region_name=self.region_name)
self.client = boto3.client(
"sagemaker-runtime",
region_name=get_model_config_var(SAGE_MAKER_REGION_NAME_ENV_VAR, **kwargs),
aws_access_key_id=get_model_config_var(
SAGE_MAKER_ACCESS_KEY_ENV_VAR, **kwargs
),
aws_secret_access_key=get_model_config_var(
SAGE_MAKER_SECRET_KEY_ENV_VAR, **kwargs
),
aws_session_token=get_model_config_var(
SAGE_MAKER_SESSION_TOKEN_ENV_VAR, **kwargs
),
)
self.params = {
"EndpointName": self.endpoint_name,
"EndpointName": get_model_config_var(
SAGE_MAKER_ENDPOINT_NAME_ENV_VAR, **kwargs
),
"ContentType": "application/json",
}

Expand All @@ -44,13 +68,7 @@ def list_models(cls) -> List[str]:

@classmethod
def is_available(cls) -> bool:
return all(
[
cls.profile_name is not None,
cls.region_name is not None,
cls.endpoint_name is not None,
]
)
return all([os.environ.get(var) is not None for var in SAGE_MAKER_ENV_VARS])

def invoke_chat_stream(
self, chat_request: CohereChatRequest, **kwargs: Any
Expand Down
36 changes: 15 additions & 21 deletions src/backend/model_deployments/utils.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS
from backend.model_deployments.base import BaseDeployment
import os
from typing import Any


def get_deployment(deployment_name) -> BaseDeployment:
"""Get the deployment implementation.
def get_model_config_var(var_name: str, **kwargs: Any) -> str:
"""Get the model config variable.

Args:
deployment (str): Deployment name.
var_name (str): Variable name.
model_config (dict): Model config.

Returns:
BaseDeployment: Deployment implementation instance based on the deployment name.
str: Model config variable value.

Raises:
ValueError: If the deployment is not supported.
"""
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(deployment_name)

# Check provided deployment against config const
if deployment is not None and deployment.is_available:
return deployment.deployment_class(**deployment.kwargs)

# Fallback to first available deployment
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.is_available:
return deployment.deployment_class()

raise ValueError(
f"Deployment {deployment_name} is not supported, and no available deployments were found."
model_config = kwargs.get("deployment_config")
config = (
model_config[var_name]
if model_config and model_config.get(var_name)
else os.environ.get(var_name)
)
if not config:
raise ValueError(f"Missing model config variable: {var_name}")
return config
Loading
Loading