Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: Deployments refactor; Add deployment service and fix deployment config setting #831

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 7 additions & 20 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from typing import Any

from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS,
get_default_deployment,
)
from backend.exceptions import DeploymentNotFoundError
malexw marked this conversation as resolved.
Show resolved Hide resolved
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services import deployment as deployment_service


def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
Expand All @@ -16,22 +14,11 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:

Returns:
BaseDeployment: Deployment implementation instance based on the deployment name.

Raises:
ValueError: If the deployment is not supported.
"""
kwargs["ctx"] = ctx
deployment = AVAILABLE_MODEL_DEPLOYMENTS.get(name)

# Check provided deployment against config const
if deployment is not None:
return deployment.deployment_class(**kwargs, **deployment.kwargs)

# Fallback to first available deployment
default = get_default_deployment(**kwargs)
if default is not None:
return default
try:
deployment = deployment_service.get_deployment_by_name(name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would the DeploymentNotFoundError trigger if no deployment is found when filtering through the DB?

Perhaps we should use the fallback logic in a:

if not deployment:
    .. get_default_deployment()

And wrap the whole thing in a try/catch instead

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I understand what you're seeing. With the new code here, deployment_service will throw if it can't find a deployment with the specified name. In that case, we catch and instead return a default deployment. And if there are no available deployments at all, get_default_deployment will also throw.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see, I confused the get_deployment_by_name call with a similarly named method I think. Good to go then.

except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment()

raise ValueError(
f"Deployment {name} is not supported, and no available deployments were found."
)
return deployment(**kwargs)
6 changes: 0 additions & 6 deletions src/backend/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@
from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as MANAGED_DEPLOYMENTS_SETUP,
)
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)
from community.config.tools import COMMUNITY_TOOLS_SETUP


Expand Down Expand Up @@ -51,9 +48,6 @@ def start():

# SET UP ENVIRONMENT FOR DEPLOYMENTS
all_deployments = MANAGED_DEPLOYMENTS_SETUP.copy()
if use_community_features:
all_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)

selected_deployments = select_deployments_prompt(all_deployments, secrets)

for deployment in selected_deployments:
Expand Down
137 changes: 16 additions & 121 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
@@ -1,140 +1,35 @@
from enum import StrEnum

from backend.config.settings import Settings
from backend.model_deployments import (
AzureDeployment,
BedrockDeployment,
CohereDeployment,
SageMakerDeployment,
SingleContainerDeployment,
)
from backend.model_deployments.azure import AZURE_ENV_VARS
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS
from backend.model_deployments.single_container import SC_ENV_VARS
from backend.schemas.deployment import Deployment
from backend.services.logger.utils import LoggerFactory

logger = LoggerFactory().get_logger()


class ModelDeploymentName(StrEnum):
CoherePlatform = "Cohere Platform"
SageMaker = "SageMaker"
Azure = "Azure"
Bedrock = "Bedrock"
SingleContainer = "Single Container"


use_community_features = Settings().get('feature_flags.use_community_features')
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() }

# TODO names in the map below should not be the display names but ids
ALL_MODEL_DEPLOYMENTS = {
ModelDeploymentName.CoherePlatform: Deployment(
id="cohere_platform",
name=ModelDeploymentName.CoherePlatform,
deployment_class=CohereDeployment,
models=CohereDeployment.list_models(),
is_available=CohereDeployment.is_available(),
env_vars=COHERE_ENV_VARS,
),
ModelDeploymentName.SingleContainer: Deployment(
id="single_container",
name=ModelDeploymentName.SingleContainer,
deployment_class=SingleContainerDeployment,
models=SingleContainerDeployment.list_models(),
is_available=SingleContainerDeployment.is_available(),
env_vars=SC_ENV_VARS,
),
ModelDeploymentName.SageMaker: Deployment(
id="sagemaker",
name=ModelDeploymentName.SageMaker,
deployment_class=SageMakerDeployment,
models=SageMakerDeployment.list_models(),
is_available=SageMakerDeployment.is_available(),
env_vars=SAGE_MAKER_ENV_VARS,
),
ModelDeploymentName.Azure: Deployment(
id="azure",
name=ModelDeploymentName.Azure,
deployment_class=AzureDeployment,
models=AzureDeployment.list_models(),
is_available=AzureDeployment.is_available(),
env_vars=AZURE_ENV_VARS,
),
ModelDeploymentName.Bedrock: Deployment(
id="bedrock",
name=ModelDeploymentName.Bedrock,
deployment_class=BedrockDeployment,
models=BedrockDeployment.list_models(),
is_available=BedrockDeployment.is_available(),
env_vars=BEDROCK_ENV_VARS,
),
}

def get_installed_deployments() -> list[type[BaseDeployment]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very small nit to rename get_available_deployments

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to rename this to whatever, but the reason I wanted to get away from the name available is because the models have an is_available method on them, and it might give the impression that a function named get_available_deployments was filtering based on is_available.

installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values())

def get_available_deployments() -> dict[ModelDeploymentName, Deployment]:
if use_community_features:
if Settings().get("feature_flags.use_community_features"):
try:
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)

model_deployments = ALL_MODEL_DEPLOYMENTS.copy()
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
return model_deployments
except ImportError:
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values())
except ImportError as e:
logger.warning(
event="[Deployments] No available community deployments have been configured"
event="[Deployments] No available community deployments have been configured", ex=e
)

deployments = Settings().get('deployments.enabled_deployments')
if deployments is not None and len(deployments) > 0:
return {
key: value
for key, value in ALL_MODEL_DEPLOYMENTS.items()
if value.id in Settings().get('deployments.enabled_deployments')
}

return ALL_MODEL_DEPLOYMENTS


def get_default_deployment(**kwargs) -> BaseDeployment:
# Fallback to the first available deployment
fallback = None
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.is_available:
fallback = deployment.deployment_class(**kwargs)
break

default = Settings().get('deployments.default_deployment')
if default:
return next(
(
v.deployment_class(**kwargs)
for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items()
if v.id == default
),
fallback,
)
else:
return fallback


def find_config_by_deployment_id(deployment_id: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.id == deployment_id:
return deployment
return None


def find_config_by_deployment_name(deployment_name: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.name == deployment_name:
return deployment
return None
enabled_deployment_ids = Settings().get("deployments.enabled_deployments")
if enabled_deployment_ids:
return [
deployment
for deployment in installed_deployments
if deployment.id() in enabled_deployment_ids
]

return installed_deployments

AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
AVAILABLE_MODEL_DEPLOYMENTS = get_installed_deployments()
18 changes: 9 additions & 9 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from backend.database_models import AgentDeploymentModel, Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
from backend.services.transaction import validate_transaction
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
from backend.schemas.deployment import (
DeploymentCreate,
DeploymentDefinition,
DeploymentUpdate,
)
from backend.services.transaction import validate_transaction


@validate_transaction
Expand All @@ -19,7 +19,7 @@ def create_deployment(db: Session, deployment: DeploymentCreate) -> Deployment:

Args:
db (Session): Database session.
deployment (DeploymentSchema): Deployment data to be created.
deployment (DeploymentDefinition): Deployment data to be created.

Returns:
Deployment: Created deployment.
Expand Down Expand Up @@ -193,14 +193,14 @@ def delete_deployment(db: Session, deployment_id: str) -> None:


@validate_transaction
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
def create_deployment_by_config(db: Session, deployment_config: DeploymentDefinition) -> Deployment:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t see this method being used anymore.

"""
Create a new deployment by config.

Args:
db (Session): Database session.
deployment (str): Deployment data to be created.
deployment_config (DeploymentSchema): Deployment config.
deployment_config (DeploymentDefinition): Deployment config.

Returns:
Deployment: Created deployment.
Expand All @@ -213,7 +213,7 @@ def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema
for env_var in deployment_config.env_vars
},
deployment_class_name=deployment_config.deployment_class.__name__,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure that we have deployment_class attribute for new configuration

is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
is_community=deployment_config.is_community,
)
db.add(deployment)
db.commit()
Expand Down
6 changes: 3 additions & 3 deletions src/backend/crud/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from backend.database_models import AgentDeploymentModel, Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentDefinition
from backend.schemas.model import ModelCreate, ModelUpdate
from backend.services.transaction import validate_transaction

Expand Down Expand Up @@ -157,14 +157,14 @@ def get_models_by_agent_id(
)


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentDefinition, model: str) -> Model:
"""
Create a new model by config if present

Args:
db (Session): Database session.
deployment (Deployment): Deployment data.
deployment_config (DeploymentSchema): Deployment config data.
deployment_config (DeploymentDefinition): Deployment config data.
model (str): Model data.

Returns:
Expand Down
19 changes: 13 additions & 6 deletions src/backend/database_models/seeders/deplyments_models_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
from sqlalchemy import text
from sqlalchemy.orm import Session

from backend.config.deployments import ALL_MODEL_DEPLOYMENTS, ModelDeploymentName
from backend.config.deployments import ALL_MODEL_DEPLOYMENTS
from backend.database_models import Deployment, Model, Organization
from backend.model_deployments import (
CohereDeployment,
SingleContainerDeployment,
SageMakerDeployment,
AzureDeployment,
BedrockDeployment,
)
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)
Expand All @@ -18,7 +25,7 @@
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)
malexw marked this conversation as resolved.
Show resolved Hide resolved

MODELS_NAME_MAPPING = {
ModelDeploymentName.CoherePlatform: {
CohereDeployment.name(): {
"command": {
"cohere_name": "command",
"is_default": False,
Expand Down Expand Up @@ -60,7 +67,7 @@
"is_default": False,
},
},
ModelDeploymentName.SingleContainer: {
SingleContainerDeployment.name(): {
"command": {
"cohere_name": "command",
"is_default": False,
Expand Down Expand Up @@ -102,19 +109,19 @@
"is_default": False,
},
},
ModelDeploymentName.SageMaker: {
SageMakerDeployment.name(): {
"sagemaker-command": {
"cohere_name": "command",
"is_default": True,
},
},
ModelDeploymentName.Azure: {
AzureDeployment.name(): {
"azure-command": {
"cohere_name": "command-r",
"is_default": True,
},
},
ModelDeploymentName.Bedrock: {
BedrockDeployment.name(): {
"cohere.command-r-plus-v1:0": {
"cohere_name": "command-r-plus",
"is_default": True,
Expand Down
13 changes: 13 additions & 0 deletions src/backend/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class ToolkitException(Exception):
"""
Base class for all toolkit exceptions.
"""

class DeploymentNotFoundError(ToolkitException):
def __init__(self, deployment_id: str):
super(DeploymentNotFoundError, self).__init__(f"Deployment {deployment_id} not found")
self.deployment_id = deployment_id

class NoAvailableDeploymentsError(ToolkitException):
malexw marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self):
super(NoAvailableDeploymentsError, self).__init__("No deployments have been configured. Have the appropriate config values been added to configuration.yaml or secrets.yaml?")
15 changes: 15 additions & 0 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from backend.config.routers import ROUTER_DEPENDENCIES
from backend.config.settings import Settings
from backend.exceptions import DeploymentNotFoundError
from backend.routers.agent import router as agent_router
from backend.routers.auth import router as auth_router
from backend.routers.chat import router as chat_router
Expand Down Expand Up @@ -111,6 +112,20 @@ async def validation_exception_handler(request: Request, exc: Exception):
)


@app.exception_handler(DeploymentNotFoundError)
async def deployment_not_found_handler(request: Request, exc: DeploymentNotFoundError):
ctx = get_context(request)
logger = ctx.get_logger()
logger.error(
event="Deployment not found",
deployment_id=exc.deployment_id,
)
return JSONResponse(
status_code=404,
malexw marked this conversation as resolved.
Show resolved Hide resolved
content={"detail": str(exc)},
)


@app.on_event("startup")
async def startup_event():
"""
Expand Down
Loading