diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f557b8a..54150ab4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### **Changed** - fixed model deploy cross-account permissions - added bucket and model package group names as stack outputs in the `sagemaker-templates` module -- refactor inputs for `mlflow-fargate` and `mlflow-image` -- refactor inputs for `sagemaker-studio` +- refactor inputs for the following modules to use Pydantic: + - `mlflow-fargate` + - `mlflow-image` + - `sagemaker-studio` + - `sagemaker-endpoint` - rename seedfarmer project name to `aiops` - chore: adding some missing auto_delete attributes - chore: Add `auto_delete` to `mlflow-fargate` elb access logs bucket diff --git a/modules/sagemaker/sagemaker-endpoint/app.py b/modules/sagemaker/sagemaker-endpoint/app.py index 5aa61788..da27b2e2 100644 --- a/modules/sagemaker/sagemaker-endpoint/app.py +++ b/modules/sagemaker/sagemaker-endpoint/app.py @@ -1,90 +1,40 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import json -import os - import aws_cdk import cdk_nag +from settings import ApplicationSettings from stack import DeployEndpointStack - -def _param(name: str) -> str: - return f"SEEDFARMER_PARAMETER_{name}" - - -project_name = os.getenv("SEEDFARMER_PROJECT_NAME", "") -deployment_name = os.getenv("SEEDFARMER_DEPLOYMENT_NAME", "") -module_name = os.getenv("SEEDFARMER_MODULE_NAME", "") -app_prefix = f"{project_name}-{deployment_name}-{module_name}" - -DEFAULT_SAGEMAKER_PROJECT_ID = None -DEFAULT_SAGEMAKER_PROJECT_NAME = None -DEFAULT_MODEL_PACKAGE_ARN = None -DEFAULT_MODEL_PACKAGE_GROUP_NAME = None -DEFAULT_MODEL_EXECUTION_ROLE_ARN = None -DEFAULT_MODEL_ARTIFACTS_BUCKET_ARN = None -DEFAULT_ECR_REPO_ARN = None -DEFAULT_VARIANT_NAME = "AllTraffic" -DEFAULT_INITIAL_INSTANCE_COUNT = 1 -DEFAULT_INITIAL_VARIANT_WEIGHT = 1 -DEFAULT_INSTANCE_TYPE = "ml.m4.xlarge" -DEFAULT_SCALING_MIN_INSTANCE_COUNT = 1 -DEFAULT_SCALING_MAX_INSTANCE_COUNT = 10 - -environment = aws_cdk.Environment( - account=os.environ["CDK_DEFAULT_ACCOUNT"], - region=os.environ["CDK_DEFAULT_REGION"], -) - -vpc_id = os.getenv(_param("VPC_ID")) -subnet_ids = json.loads(os.getenv(_param("SUBNET_IDS"), "[]")) -sagemaker_project_id = os.getenv(_param("SAGEMAKER_PROJECT_ID"), DEFAULT_SAGEMAKER_PROJECT_ID) -sagemaker_project_name = os.getenv(_param("SAGEMAKER_PROJECT_NAME"), DEFAULT_SAGEMAKER_PROJECT_NAME) -model_package_arn = os.getenv(_param("MODEL_PACKAGE_ARN"), DEFAULT_MODEL_PACKAGE_ARN) -model_package_group_name = os.getenv(_param("MODEL_PACKAGE_GROUP_NAME"), DEFAULT_MODEL_PACKAGE_GROUP_NAME) -model_execution_role_arn = os.getenv(_param("MODEL_EXECUTION_ROLE_ARN"), DEFAULT_MODEL_EXECUTION_ROLE_ARN) -model_artifacts_bucket_arn = os.getenv(_param("MODEL_ARTIFACTS_BUCKET_ARN"), DEFAULT_MODEL_ARTIFACTS_BUCKET_ARN) -ecr_repo_arn = os.getenv(_param("ECR_REPO_ARN"), DEFAULT_ECR_REPO_ARN) -variant_name = os.getenv(_param("VARIANT_NAME"), DEFAULT_VARIANT_NAME) -initial_instance_count = int(os.getenv(_param("INITIAL_INSTANCE_COUNT"), DEFAULT_INITIAL_INSTANCE_COUNT)) -initial_variant_weight = int(os.getenv(_param("INITIAL_VARIANT_WEIGHT"), DEFAULT_INITIAL_VARIANT_WEIGHT)) -instance_type = os.getenv(_param("INSTANCE_TYPE"), DEFAULT_INSTANCE_TYPE) -managed_instance_scaling = bool(os.getenv(_param("MANAGED_INSTANCE_SCALING"), False)) -scaling_min_instance_count = int(os.getenv(_param("SCALING_MIN_INSTANCE_COUNT"), DEFAULT_SCALING_MIN_INSTANCE_COUNT)) -scaling_max_instance_count = int(os.getenv(_param("SCALING_MAX_INSTANCE_COUNT"), DEFAULT_SCALING_MAX_INSTANCE_COUNT)) - -if not vpc_id: - raise ValueError("Missing input parameter vpc-id") - -if not model_package_arn and not model_package_group_name: - raise ValueError("Parameter model-package-arn or model-package-group-name is required") - - app = aws_cdk.App() +app_settings = ApplicationSettings() + stack = DeployEndpointStack( scope=app, - id=app_prefix, - sagemaker_project_id=sagemaker_project_id, - sagemaker_project_name=sagemaker_project_name, - model_package_arn=model_package_arn, - model_package_group_name=model_package_group_name, - model_execution_role_arn=model_execution_role_arn, - vpc_id=vpc_id, - subnet_ids=subnet_ids, - model_artifacts_bucket_arn=model_artifacts_bucket_arn, - ecr_repo_arn=ecr_repo_arn, + id=app_settings.seedfarmer_settings.app_prefix, + sagemaker_project_id=app_settings.module_settings.sagemaker_project_id, + sagemaker_project_name=app_settings.module_settings.sagemaker_project_name, + model_package_arn=app_settings.module_settings.model_package_arn, + model_package_group_name=app_settings.module_settings.model_package_group_name, + model_execution_role_arn=app_settings.module_settings.model_execution_role_arn, + vpc_id=app_settings.module_settings.vpc_id, + subnet_ids=app_settings.module_settings.subnet_ids, + model_artifacts_bucket_arn=app_settings.module_settings.model_artifacts_bucket_arn, + ecr_repo_arn=app_settings.module_settings.ecr_repo_arn, endpoint_config_prod_variant={ - "initial_instance_count": initial_instance_count, - "initial_variant_weight": initial_variant_weight, - "instance_type": instance_type, - "variant_name": variant_name, + "initial_instance_count": app_settings.module_settings.initial_instance_count, + "initial_variant_weight": app_settings.module_settings.initial_variant_weight, + "instance_type": app_settings.module_settings.instance_type, + "variant_name": app_settings.module_settings.variant_name, }, - managed_instance_scaling=managed_instance_scaling, - scaling_min_instance_count=scaling_min_instance_count, - scaling_max_instance_count=scaling_max_instance_count, - env=environment, + managed_instance_scaling=app_settings.module_settings.managed_instance_scaling, + scaling_min_instance_count=app_settings.module_settings.scaling_min_instance_count, + scaling_max_instance_count=app_settings.module_settings.scaling_max_instance_count, + env=aws_cdk.Environment( + account=app_settings.cdk_settings.account, + region=app_settings.cdk_settings.region, + ), ) aws_cdk.CfnOutput( @@ -103,4 +53,12 @@ def _param(name: str) -> str: aws_cdk.Aspects.of(app).add(cdk_nag.AwsSolutionsChecks(log_ignores=True)) +if app_settings.module_settings.tags: + for tag_key, tag_value in app_settings.module_settings.tags.items(): + aws_cdk.Tags.of(app).add(tag_key, tag_value) + +aws_cdk.Tags.of(app).add("SeedFarmerDeploymentName", app_settings.seedfarmer_settings.deployment_name) +aws_cdk.Tags.of(app).add("SeedFarmerModuleName", app_settings.seedfarmer_settings.module_name) +aws_cdk.Tags.of(app).add("SeedFarmerProjectName", app_settings.seedfarmer_settings.project_name) + app.synth() diff --git a/modules/sagemaker/sagemaker-endpoint/coverage.ini b/modules/sagemaker/sagemaker-endpoint/coverage.ini deleted file mode 100644 index c3878739..00000000 --- a/modules/sagemaker/sagemaker-endpoint/coverage.ini +++ /dev/null @@ -1,3 +0,0 @@ -[run] -omit = - tests/* \ No newline at end of file diff --git a/modules/sagemaker/sagemaker-endpoint/pyproject.toml b/modules/sagemaker/sagemaker-endpoint/pyproject.toml index ef8db5f9..0a619237 100644 --- a/modules/sagemaker/sagemaker-endpoint/pyproject.toml +++ b/modules/sagemaker/sagemaker-endpoint/pyproject.toml @@ -23,7 +23,6 @@ fixable = ["ALL"] [tool.mypy] python_version = "3.8" strict = true -ignore_missing_imports = true disallow_untyped_decorators = false exclude = "codeseeder.out/|example/|tests/" warn_unused_ignores = false diff --git a/modules/sagemaker/sagemaker-endpoint/requirements.txt b/modules/sagemaker/sagemaker-endpoint/requirements.txt index ec40e171..9797ea41 100644 --- a/modules/sagemaker/sagemaker-endpoint/requirements.txt +++ b/modules/sagemaker/sagemaker-endpoint/requirements.txt @@ -1,4 +1,5 @@ aws-cdk-lib==2.126.0 cdk-nag==2.28.27 -yamldataclassconfig==1.5.0 -boto3==1.34.35 \ No newline at end of file +boto3==1.34.35 +pydantic==2.7.2 +pydantic-settings==2.2.1 diff --git a/modules/sagemaker/sagemaker-endpoint/settings.py b/modules/sagemaker/sagemaker-endpoint/settings.py new file mode 100644 index 00000000..ceae6f5d --- /dev/null +++ b/modules/sagemaker/sagemaker-endpoint/settings.py @@ -0,0 +1,96 @@ +"""Defines the stack settings.""" + +from abc import ABC +from typing import Dict, List, Optional + +from pydantic import Field, computed_field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class CdkBaseSettings(BaseSettings, ABC): + """Defines common configuration for settings.""" + + model_config = SettingsConfigDict( + case_sensitive=False, + env_nested_delimiter="__", + protected_namespaces=(), + extra="ignore", + populate_by_name=True, + ) + + +class ModuleSettings(CdkBaseSettings): + """Seedfarmer Parameters. + + These parameters are required for the module stack. + """ + + model_config = SettingsConfigDict(env_prefix="SEEDFARMER_PARAMETER_") + + vpc_id: str + subnet_ids: List[str] + + sagemaker_project_id: Optional[str] = Field(default=None) + sagemaker_project_name: Optional[str] = Field(default=None) + model_package_arn: Optional[str] = Field(default=None) + model_package_group_name: Optional[str] = Field(default=None) + model_execution_role_arn: Optional[str] = Field(default=None) + model_artifacts_bucket_arn: Optional[str] = Field(default=None) + ecr_repo_arn: Optional[str] = Field(default=None) + variant_name: str = Field(default="AllTraffic") + + initial_instance_count: int = Field(default=1) + initial_variant_weight: int = Field(default=1) + instance_type: str = Field(default="ml.m4.xlarge") + managed_instance_scaling: bool = Field(default=False) + scaling_min_instance_count: int = Field(default=1) + scaling_max_instance_count: int = Field(default=10) + + tags: Optional[Dict[str, str]] = Field(default=None) + + @model_validator(mode="after") + def check_model_package(self) -> "ModuleSettings": + if not self.model_package_arn and not self.model_package_group_name: + raise ValueError("Parameter model-package-arn or model-package-group-name is required") + + return self + + +class SeedFarmerSettings(CdkBaseSettings): + """Seedfarmer Settings. + + These parameters comes from seedfarmer by default. + """ + + model_config = SettingsConfigDict(env_prefix="SEEDFARMER_") + + project_name: str = Field(default="") + deployment_name: str = Field(default="") + module_name: str = Field(default="") + + @computed_field # type: ignore + @property + def app_prefix(self) -> str: + """Application prefix.""" + prefix = "-".join([self.project_name, self.deployment_name, self.module_name]) + return prefix + + +class CDKSettings(CdkBaseSettings): + """CDK Default Settings. + + These parameters comes from AWS CDK by default. + """ + + model_config = SettingsConfigDict(env_prefix="CDK_DEFAULT_") + + account: str + region: str + + +class ApplicationSettings(CdkBaseSettings): + """Application settings.""" + + seedfarmer_settings: SeedFarmerSettings = Field(default_factory=SeedFarmerSettings) + module_settings: ModuleSettings = Field(default_factory=ModuleSettings) + cdk_settings: CDKSettings = Field(default_factory=CDKSettings) diff --git a/modules/sagemaker/sagemaker-endpoint/tests/test_app.py b/modules/sagemaker/sagemaker-endpoint/tests/test_app.py index 53842afe..78aa9748 100644 --- a/modules/sagemaker/sagemaker-endpoint/tests/test_app.py +++ b/modules/sagemaker/sagemaker-endpoint/tests/test_app.py @@ -1,33 +1,40 @@ import os import sys +from unittest import mock import pytest +from pydantic import ValidationError -@pytest.fixture(scope="function") +@pytest.fixture(scope="function", autouse=True) def stack_defaults(): - os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project" - os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment" - os.environ["SEEDFARMER_MODULE_NAME"] = "test-module" - os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" - os.environ["CDK_DEFAULT_REGION"] = "us-east-1" + with mock.patch.dict(os.environ, {}, clear=True): + os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project" + os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment" + os.environ["SEEDFARMER_MODULE_NAME"] = "test-module" - os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345" - os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_ID"] = "12345" - os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_NAME"] = "sagemaker-project" - os.environ["SEEDFARMER_PARAMETER_MODEL_PACKAGE_ARN"] = "example-arn" + os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" + os.environ["CDK_DEFAULT_REGION"] = "us-east-1" - # Unload the app import so that subsequent tests don't reuse - if "app" in sys.modules: - del sys.modules["app"] + os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345" + os.environ["SEEDFARMER_PARAMETER_SUBNET_IDS"] = '["subnet-1","subnet-2","subnet-3"]' + os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_ID"] = "12345" + os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_NAME"] = "sagemaker-project" + os.environ["SEEDFARMER_PARAMETER_MODEL_PACKAGE_ARN"] = "example-arn" + # Unload the app import so that subsequent tests don't reuse + if "app" in sys.modules: + del sys.modules["app"] -def test_app(stack_defaults): + yield None + + +def test_app() -> None: import app # noqa: F401 -def test_vpc_id(stack_defaults): +def test_vpc_id() -> None: del os.environ["SEEDFARMER_PARAMETER_VPC_ID"] - with pytest.raises(Exception, match="Missing input parameter vpc-id"): + with pytest.raises(ValidationError): import app # noqa: F401 diff --git a/modules/sagemaker/sagemaker-endpoint/tests/test_stack.py b/modules/sagemaker/sagemaker-endpoint/tests/test_stack.py index f1052735..c37faaf6 100644 --- a/modules/sagemaker/sagemaker-endpoint/tests/test_stack.py +++ b/modules/sagemaker/sagemaker-endpoint/tests/test_stack.py @@ -10,14 +10,17 @@ from botocore.stub import Stubber -@pytest.fixture(scope="function") -def stack_defaults() -> None: - os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" - os.environ["CDK_DEFAULT_REGION"] = "us-east-1" +@pytest.fixture(scope="function", autouse=True) +def stack_defaults(): + with mock.patch.dict(os.environ, {}, clear=True): + os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111" + os.environ["CDK_DEFAULT_REGION"] = "us-east-1" + + # Unload the app import so that subsequent tests don't reuse + if "stack" in sys.modules: + del sys.modules["stack"] - # Unload the app import so that subsequent tests don't reuse - if "stack" in sys.modules: - del sys.modules["stack"] + yield @pytest.fixture(scope="function")