Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion ads/aqua/extension/deployment_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,15 @@ def post(self, *args, **kwargs): # noqa: ARG002
if not input_data:
raise HTTPError(400, Errors.NO_INPUT_DATA)

self.finish(AquaDeploymentApp().create(**input_data))
model_deployment_id = input_data.pop("model_deployment_id", None)
if model_deployment_id:
self.finish(
AquaDeploymentApp().update(
model_deployment_id=model_deployment_id, **input_data
)
)
else:
self.finish(AquaDeploymentApp().create(**input_data))

def read(self, id):
"""Read the information of an Aqua model deployment."""
Expand Down
226 changes: 222 additions & 4 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@
AquaDeploymentDetail,
ConfigValidationError,
CreateModelDeploymentDetails,
ModelDeploymentDetails,
UpdateModelDeploymentDetails,
)
from ads.aqua.modeldeployment.model_group_config import ModelGroupConfig
from ads.common.object_storage_details import ObjectStorageDetails
Expand All @@ -100,6 +102,9 @@
ModelDeploymentInfrastructure,
ModelDeploymentMode,
)
from ads.model.deployment.model_deployment import (
ModelDeploymentUpdateType,
)
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
from ads.telemetry import telemetry

Expand Down Expand Up @@ -385,14 +390,14 @@ def create(

def _validate_input_models(
self,
create_deployment_details: CreateModelDeploymentDetails,
deployment_details: ModelDeploymentDetails,
):
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details for stacked or multi model deployment."""
"""Validates the base models and associated fine tuned models from 'models' in create_deployment_details or update_deployment_details for stacked or multi model deployment."""
# Collect all unique model IDs (including fine-tuned models)
source_model_ids = list(
{
model_id
for model in create_deployment_details.models
for model in deployment_details.models
for model_id in model.all_model_ids()
}
)
Expand All @@ -403,7 +408,7 @@ def _validate_input_models(
source_models = self.get_multi_source(source_model_ids) or {}

try:
create_deployment_details.validate_input_models(model_details=source_models)
deployment_details.validate_input_models(model_details=source_models)
except ConfigValidationError as err:
raise AquaValueError(f"{err}") from err

Expand Down Expand Up @@ -1214,6 +1219,219 @@ def _get_container_type_key(

return container_type_key

@telemetry(entry_point="plugin=deployment&action=update", name="aqua")
def update(
self,
model_deployment_id: str,
update_model_deployment_details: Optional[UpdateModelDeploymentDetails] = None,
**kwargs,
) -> AquaDeployment:
"""Updates a AQUA model group deployment.

Args:
update_model_deployment_details : UpdateModelDeploymentDetails, optional
An instance of UpdateModelDeploymentDetails containing all optional
fields for updating a model deployment via Aqua.
kwargs:
display_name (str): The name of the model deployment.
description (Optional[str]): The description of the deployment.
models (Optional[List[AquaMultiModelRef]]): List of models for deployment.
instance_count (int): Number of instances used for deployment.
log_group_id (Optional[str]): OCI logging group ID for logs.
access_log_id (Optional[str]): OCID for access logs.
predict_log_id (Optional[str]): OCID for prediction logs.
bandwidth_mbps (Optional[int]): Bandwidth limit on the load balancer in Mbps.
web_concurrency (Optional[int]): Number of worker processes/threads for handling requests.
memory_in_gbs (Optional[float]): Memory (in GB) for the selected shape.
ocpus (Optional[float]): OCPU count for the selected shape.
freeform_tags (Optional[Dict]): Freeform tags for model deployment.
defined_tags (Optional[Dict]): Defined tags for model deployment.

Returns
-------
AquaDeployment
An Aqua deployment instance.
"""
if not update_model_deployment_details:
try:
update_model_deployment_details = UpdateModelDeploymentDetails(**kwargs)
except ValidationError as ex:
custom_errors = build_pydantic_error_message(ex)
raise AquaValueError(
f"Invalid parameters for updating a model group deployment. Error details: {custom_errors}."
) from ex

model_deployment = ModelDeployment.from_id(model_deployment_id)

infrastructure = model_deployment.infrastructure
runtime = model_deployment.runtime

if not runtime.model_group_id:
raise AquaValueError(
"Invalid 'model_deployment_id'. Only model group deployment is supported to update."
)

# updates model group if fine tuned weights changed.
model = self._update_model_group(
runtime.model_group_id, update_model_deployment_details
)

# updates model group deployment infrastructure
(
infrastructure.with_bandwidth_mbps(
update_model_deployment_details.bandwidth_mbps
or infrastructure.bandwidth_mbps
)
.with_replica(
update_model_deployment_details.instance_count or infrastructure.replica
)
.with_web_concurrency(
update_model_deployment_details.web_concurrency
or infrastructure.web_concurrency
)
)

if (
update_model_deployment_details.log_group_id
and update_model_deployment_details.access_log_id
):
infrastructure.with_access_log(
log_group_id=update_model_deployment_details.log_group_id,
log_id=update_model_deployment_details.access_log_id,
)

if (
update_model_deployment_details.log_group_id
and update_model_deployment_details.predict_log_id
):
infrastructure.with_predict_log(
log_group_id=update_model_deployment_details.log_group_id,
log_id=update_model_deployment_details.predict_log_id,
)

if (
update_model_deployment_details.memory_in_gbs
and update_model_deployment_details.ocpus
and infrastructure.shape_name.endswith("Flex")
):
infrastructure.with_shape_config_details(
ocpus=update_model_deployment_details.ocpus,
memory_in_gbs=update_model_deployment_details.memory_in_gbs,
)

# applies ZDT as default type to update parameters if model group id hasn't been changed
update_type = ModelDeploymentUpdateType.ZDT
# applies LIVE update if model group id has been changed
if runtime.model_group_id != model.id:
runtime.with_model_group_id(model.id)
update_type = ModelDeploymentUpdateType.LIVE

freeform_tags = (
update_model_deployment_details.freeform_tags
or model_deployment.freeform_tags
)
defined_tags = (
update_model_deployment_details.defined_tags
or model_deployment.defined_tags
)

# updates model group deployment
(
model_deployment.with_display_name(
update_model_deployment_details.display_name
or model_deployment.display_name
)
.with_description(
update_model_deployment_details.description
or model_deployment.description
)
.with_freeform_tags(**(freeform_tags or {}))
.with_defined_tags(**(defined_tags or {}))
.with_infrastructure(infrastructure)
.with_runtime(runtime)
)

model_deployment.update(wait_for_completion=False, update_type=update_type)

logger.info(f"Updating Aqua Model Deployment {model_deployment.id}.")

return AquaDeployment.from_oci_model_deployment(
model_deployment.dsc_model_deployment, self.region
)

def _update_model_group(
self,
model_group_id: str,
update_model_deployment_details: UpdateModelDeploymentDetails,
) -> DataScienceModelGroup:
"""Creates a new model group if fine tuned weights changed.

Parameters
----------
model_group_id: str
The model group id.
update_model_deployment_details: UpdateModelDeploymentDetails
An instance of UpdateModelDeploymentDetails containing all optional
fields for updating a model deployment via Aqua.

Returns
-------
DataScienceModelGroup
The instance of DataScienceModelGroup.
"""
model_group = DataScienceModelGroup.from_id(model_group_id)
# create a new model group if fine tune weights changed as member models in ds model group is inmutable
if update_model_deployment_details.models:
if len(update_model_deployment_details.models) != 1:
raise AquaValueError(
"Invalid 'models' provided. Only one base model is required for updating model stack deployment."
)
# validates input base and fine tune models
self._validate_input_models(update_model_deployment_details)
target_stacked_model = update_model_deployment_details.models[0]
target_base_model_id = target_stacked_model.model_id
if model_group.base_model_id != target_base_model_id:
raise AquaValueError(
"Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
)

# add member models
member_models = [
{
"inference_key": fine_tune_weight.model_name,
"model_id": fine_tune_weight.model_id,
}
for fine_tune_weight in target_stacked_model.fine_tune_weights
]
# add base model
member_models.append(
{
"inference_key": target_stacked_model.model_name,
"model_id": target_base_model_id,
}
)

# creates a model group with the same configurations from original model group except member models
model_group = (
DataScienceModelGroup()
.with_compartment_id(model_group.compartment_id)
.with_project_id(model_group.project_id)
.with_display_name(model_group.display_name)
.with_description(model_group.description)
.with_freeform_tags(**(model_group.freeform_tags or {}))
.with_defined_tags(**(model_group.defined_tags or {}))
.with_custom_metadata_list(model_group.custom_metadata_list)
.with_base_model_id(target_base_model_id)
.with_member_models(member_models)
.create()
)

logger.info(
f"Model group of base model {target_base_model_id} has been updated: {model_group.id}."
)

return model_group

@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
def list(self, **kwargs) -> List["AquaDeployment"]:
"""List Aqua model deployments in a given compartment and under certain project.
Expand Down
Loading