Skip to content

Commit

Permalink
feat: optimization technique related validations. (#4921)
Browse files Browse the repository at this point in the history
* Enable quantization and compilation in the same optimization job via ModelBuilder and add validations to block compilation jobs using TRTLLM and Llama-3.1.

* Require EULA acceptance when using a gated 1p draft model via ModelBuilder.

* add accept_draft_model_eula to JumpStartModel when deployment config with gated draft model is selected

* add map of valid optimization combinations

* Add ModelBuilder support for JumpStart-provided draft models.

* Tweak draft model EULA validations and messaging. Remove redundant deployment_config flow validation in optimize_utils in favor of the one directly on jumpstart/factory/model.

* Add "Auto" speculative decoding ModelProvider option; add validations to differentiate SageMaker/JumpStart draft models.

* Fix JumpStartModel.AdditionalModelDataSource model access config assignment.

* move the accept eula configurations into deploy flow

* move the accept eula configurations into deploy flow

* Use correct bucket for SM/JS draft models and minor formatting/validation updates.

* Remove obsolete docstring.

* remove references to accept_draft_model_eula

* renaming of eula fn and error msg

* fix: pin testing deps (#4925)

Co-authored-by: nileshvd <113946607+nileshvd@users.noreply.github.com>

* Revert "change: add TGI 2.4.0 image uri (#4922)" (#4926)

* fix naming and messaging

* ModelBuilder speculative decoding UTs and minor fixes.

* Fix set union.

* add UTs for JumpStart deployment

* fix formatting issues

* address validation comments

* fix doc strings

* Add TRTLLM compilation + speculative decoding validation.

* address nits

---------

Co-authored-by: Joseph Zhang <cjz@amazon.com>
Co-authored-by: EC2 Default User <ec2-user@ip-172-16-2-151.us-west-2.compute.internal>
Co-authored-by: Gary Wang 😤 <garywan@amazon.com>
Co-authored-by: Gary Wang <38331932+gwang111@users.noreply.github.com>
Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com>
Co-authored-by: nileshvd <113946607+nileshvd@users.noreply.github.com>
Co-authored-by: Haotian An <33510317+Captainia@users.noreply.github.com>
  • Loading branch information
8 people authored Nov 19, 2024
1 parent 4b5659d commit 64e138b
Show file tree
Hide file tree
Showing 14 changed files with 1,537 additions and 93 deletions.
16 changes: 13 additions & 3 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@


from typing import Any, Dict, List, Optional, Union
from sagemaker_core.shapes import ModelAccessConfig
from sagemaker import environment_variables, image_uris, instance_types, model_uris, script_uris
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
Expand Down Expand Up @@ -53,11 +54,11 @@
add_hub_content_arn_tags,
add_jumpstart_model_info_tags,
get_default_jumpstart_session_with_user_agent_suffix,
get_neo_content_bucket,
get_top_ranked_config_name,
update_dict_if_key_not_present,
resolve_model_sagemaker_config_field,
verify_model_region_and_return_specs,
get_draft_model_content_bucket,
)

from sagemaker.jumpstart.factory.utils import (
Expand All @@ -70,7 +71,12 @@

from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig
from sagemaker.session import Session
from sagemaker.utils import camel_case_to_pascal_case, name_from_base, format_tags, Tags
from sagemaker.utils import (
camel_case_to_pascal_case,
name_from_base,
format_tags,
Tags,
)
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker import resource_requirements
Expand Down Expand Up @@ -565,7 +571,9 @@ def _add_additional_model_data_sources_to_kwargs(
# Append speculative decoding data source from metadata
speculative_decoding_data_sources = specs.get_speculative_decoding_s3_data_sources()
for data_source in speculative_decoding_data_sources:
data_source.s3_data_source.set_bucket(get_neo_content_bucket(region=kwargs.region))
data_source.s3_data_source.set_bucket(
get_draft_model_content_bucket(provider=data_source.provider, region=kwargs.region)
)
api_shape_additional_model_data_sources = (
[
camel_case_to_pascal_case(data_source.to_json())
Expand Down Expand Up @@ -648,6 +656,7 @@ def get_deploy_kwargs(
training_config_name: Optional[str] = None,
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
) -> JumpStartModelDeployKwargs:
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Model` object."""

Expand Down Expand Up @@ -684,6 +693,7 @@ def get_deploy_kwargs(
resources=resources,
config_name=config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
)
deploy_kwargs, orig_session = _set_temp_sagemaker_session_if_not_set(kwargs=deploy_kwargs)
deploy_kwargs.specs = verify_model_region_and_return_specs(
Expand Down
41 changes: 31 additions & 10 deletions src/sagemaker/jumpstart/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
from botocore.exceptions import ClientError

from sagemaker_core.shapes import ModelAccessConfig
from sagemaker import payloads
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
from sagemaker.base_deserializers import BaseDeserializer
Expand Down Expand Up @@ -51,6 +52,7 @@
add_instance_rate_stats_to_benchmark_metrics,
deployment_config_response_data,
_deployment_config_lru_cache,
_add_model_access_configs_to_model_data_sources,
)
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
from sagemaker.jumpstart.enums import JumpStartModelType
Expand Down Expand Up @@ -540,12 +542,16 @@ def attach(
inferred_model_id = inferred_model_version = inferred_inference_component_name = None

if inference_component_name is None or model_id is None or model_version is None:
inferred_model_id, inferred_model_version, inferred_inference_component_name, _, _ = (
get_model_info_from_endpoint(
endpoint_name=endpoint_name,
inference_component_name=inference_component_name,
sagemaker_session=sagemaker_session,
)
(
inferred_model_id,
inferred_model_version,
inferred_inference_component_name,
_,
_,
) = get_model_info_from_endpoint(
endpoint_name=endpoint_name,
inference_component_name=inference_component_name,
sagemaker_session=sagemaker_session,
)

model_id = model_id or inferred_model_id
Expand Down Expand Up @@ -659,6 +665,7 @@ def deploy(
managed_instance_scaling: Optional[str] = None,
endpoint_type: EndpointType = EndpointType.MODEL_BASED,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, ModelAccessConfig]] = None,
) -> PredictorBase:
"""Creates endpoint by calling base ``Model`` class `deploy` method.
Expand Down Expand Up @@ -755,6 +762,11 @@ def deploy(
(Default: EndpointType.MODEL_BASED).
routing_config (Optional[Dict]): Settings that control how the endpoint routes
incoming traffic to the instances that the endpoint hosts.
model_access_configs (Optional[Dict[str, ModelAccessConfig]]): For models that require
ModelAccessConfig, provide a `{ "model_id": ModelAccessConfig(accept_eula=True) }`
to indicate whether model terms of use have been accepted. The `accept_eula` value
must be explicitly defined as `True` in order to accept the end-user license
agreement (EULA) that some models require. (Default: None)
Raises:
MarketplaceModelSubscriptionError: If the caller is not subscribed to the model.
Expand Down Expand Up @@ -795,6 +807,7 @@ def deploy(
model_type=self.model_type,
config_name=self.config_name,
routing_config=routing_config,
model_access_configs=model_access_configs,
)
if (
self.model_type == JumpStartModelType.PROPRIETARY
Expand All @@ -804,6 +817,13 @@ def deploy(
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
)

self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
self.additional_model_data_sources,
deploy_kwargs.model_access_configs,
deploy_kwargs.model_id,
deploy_kwargs.region,
)

try:
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())
except ClientError as e:
Expand Down Expand Up @@ -1016,10 +1036,11 @@ def _get_deployment_configs(
)

if metadata_config.benchmark_metrics:
err, metadata_config.benchmark_metrics = (
add_instance_rate_stats_to_benchmark_metrics(
self.region, metadata_config.benchmark_metrics
)
(
err,
metadata_config.benchmark_metrics,
) = add_instance_rate_stats_to_benchmark_metrics(
self.region, metadata_config.benchmark_metrics
)

config_components = metadata_config.config_components.get(config_name)
Expand Down
15 changes: 12 additions & 3 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from copy import deepcopy
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union
from sagemaker_core.shapes import ModelAccessConfig as CoreModelAccessConfig
from sagemaker.model_card.model_card import ModelCard, ModelPackageModelCard
from sagemaker.utils import (
S3_PREFIX,
Expand Down Expand Up @@ -1081,9 +1082,9 @@ def set_bucket(self, bucket: str) -> None:
class AdditionalModelDataSource(JumpStartDataHolderType):
"""Data class of additional model data source mirrors CreateModel API."""

SERIALIZATION_EXCLUSION_SET: Set[str] = set()
SERIALIZATION_EXCLUSION_SET = {"provider"}

__slots__ = ["channel_name", "s3_data_source"]
__slots__ = ["channel_name", "s3_data_source", "hosting_eula_key"]

def __init__(self, spec: Dict[str, Any]):
"""Initializes a AdditionalModelDataSource object.
Expand All @@ -1101,6 +1102,8 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
"""
self.channel_name: str = json_obj["channel_name"]
self.s3_data_source: S3DataSource = S3DataSource(json_obj["s3_data_source"])
self.hosting_eula_key: str = json_obj.get("hosting_eula_key")
self.provider: Dict = json_obj.get("provider", {})

def to_json(self, exclude_keys=True) -> Dict[str, Any]:
"""Returns json representation of AdditionalModelDataSource object."""
Expand All @@ -1119,7 +1122,9 @@ def to_json(self, exclude_keys=True) -> Dict[str, Any]:
class JumpStartModelDataSource(AdditionalModelDataSource):
"""Data class JumpStart additional model data source."""

SERIALIZATION_EXCLUSION_SET = {"artifact_version"}
SERIALIZATION_EXCLUSION_SET = AdditionalModelDataSource.SERIALIZATION_EXCLUSION_SET.union(
{"artifact_version"}
)

__slots__ = list(SERIALIZATION_EXCLUSION_SET) + AdditionalModelDataSource.__slots__

Expand Down Expand Up @@ -2239,6 +2244,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"config_name",
"routing_config",
"specs",
"model_access_configs",
]

SERIALIZATION_EXCLUSION_SET = {
Expand All @@ -2252,6 +2258,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"sagemaker_session",
"training_instance_type",
"config_name",
"model_access_configs",
}

def __init__(
Expand Down Expand Up @@ -2290,6 +2297,7 @@ def __init__(
endpoint_type: Optional[EndpointType] = None,
config_name: Optional[str] = None,
routing_config: Optional[Dict[str, Any]] = None,
model_access_configs: Optional[Dict[str, CoreModelAccessConfig]] = None,
) -> None:
"""Instantiates JumpStartModelDeployKwargs object."""

Expand Down Expand Up @@ -2327,6 +2335,7 @@ def __init__(
self.endpoint_type = endpoint_type
self.config_name = config_name
self.routing_config = routing_config
self.model_access_configs = model_access_configs


class JumpStartEstimatorInitKwargs(JumpStartKwargs):
Expand Down
93 changes: 91 additions & 2 deletions src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
"""This module contains utilities related to SageMaker JumpStart."""
from __future__ import absolute_import

from copy import copy
import logging
import os
Expand All @@ -22,6 +23,7 @@
from botocore.exceptions import ClientError
from packaging.version import Version
import botocore
from sagemaker_core.shapes import ModelAccessConfig
import sagemaker
from sagemaker.config.config_schema import (
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
Expand Down Expand Up @@ -55,6 +57,7 @@
TagsDict,
get_instance_rate_per_hour,
get_domain_for_region,
camel_case_to_pascal_case,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.user_agent import get_user_agent_extra_suffix
Expand Down Expand Up @@ -555,11 +558,18 @@ def get_eula_message(model_specs: JumpStartModelSpecs, region: str) -> str:
"""Returns EULA message to display if one is available, else empty string."""
if model_specs.hosting_eula_key is None:
return ""
return get_formatted_eula_message_template(
model_id=model_specs.model_id, region=region, hosting_eula_key=model_specs.hosting_eula_key
)


def get_formatted_eula_message_template(model_id: str, region: str, hosting_eula_key: str) -> str:
"""Returns a formatted EULA message."""
return (
f"Model '{model_specs.model_id}' requires accepting end-user license agreement (EULA). "
f"Model '{model_id}' requires accepting end-user license agreement (EULA). "
f"See https://{get_jumpstart_content_bucket(region=region)}.s3.{region}."
f"{get_domain_for_region(region)}"
f"/{model_specs.hosting_eula_key} for terms of use."
f"/{hosting_eula_key} for terms of use."
)


Expand Down Expand Up @@ -1525,3 +1535,82 @@ def wrapped_f(*args, **kwargs):
if _func is None:
return wrapper_cache
return wrapper_cache(_func)


def _add_model_access_configs_to_model_data_sources(
model_data_sources: List[Dict[str, any]],
model_access_configs: Dict[str, ModelAccessConfig],
model_id: str,
region: str,
) -> List[Dict[str, any]]:
"""Iterate over the accept EULA configs to ensure all channels are matched
Args:
model_data_sources (DeploymentConfigMetadata): Model data sources that will be updated
model_access_configs (DeploymentConfigMetadata): Config holding accept_eula field
model_id (DeploymentConfigMetadata): Jumpstart model id.
region (str): Region where the user is operating in.
Returns:
List[Dict[str, Any]]: List of model data sources with accept EULA configs applied
Raise:
ValueError if at least one channel that requires EULA acceptance as not passed.
"""
if not model_data_sources:
return model_data_sources

acked_model_data_sources = []
for model_data_source in model_data_sources:
hosting_eula_key = model_data_source.get("HostingEulaKey")
mutable_model_data_source = model_data_source.copy()
if hosting_eula_key:
if (
not model_access_configs
or not model_access_configs.get(model_id)
or not model_access_configs.get(model_id).accept_eula
):
eula_message_template = (
"{model_source}{base_eula_message}{model_access_configs_message}"
)
model_access_config_entry = (
'"{model_id}":ModelAccessConfig(accept_eula=True)'.format(model_id=model_id)
)
raise ValueError(
eula_message_template.format(
model_source="Additional " if model_data_source.get("ChannelName") else "",
base_eula_message=get_formatted_eula_message_template(
model_id=model_id, region=region, hosting_eula_key=hosting_eula_key
),
model_access_configs_message=(
"Please add a ModelAccessConfig entry:"
f" {model_access_config_entry} "
"to model_access_configs to accept the EULA."
),
)
)
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is applied
mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
)
acked_model_data_sources.append(mutable_model_data_source)
else:
mutable_model_data_source.pop(
"HostingEulaKey"
) # pop when model access config is not applicable
acked_model_data_sources.append(mutable_model_data_source)
return acked_model_data_sources


def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
    """Pick the content bucket that hosts a 1p draft model.

    The Neo content bucket is the fallback: it is returned when ``provider`` is
    empty or when its ``name`` is anything other than ``"JumpStart"``. For a
    JumpStart provider, a ``classification`` of ``"gated"`` selects the gated
    JumpStart content bucket; any other value (including a missing one) selects
    the regular JumpStart content bucket.
    """
    # Resolved up front, before the provider is inspected.
    fallback_bucket = get_neo_content_bucket(region=region)

    is_jumpstart_provider = bool(provider) and provider.get("name", "") == "JumpStart"
    if not is_jumpstart_provider:
        return fallback_bucket

    if provider.get("classification", "ungated") == "gated":
        return get_jumpstart_gated_content_bucket(region=region)
    return get_jumpstart_content_bucket(region=region)
Loading

0 comments on commit 64e138b

Please sign in to comment.