Skip to content

Commit

Permalink
chore: Refactor sagemaker-endpoint to use pydantic for inputs (#100)
Browse files Browse the repository at this point in the history
* chore: Refactor sagemaker-endpoint to use pydantic for inputs

* fix pydantic dependency

* add pydantic_settings

* update changelog
  • Loading branch information
LeonLuttenberger authored May 29, 2024
1 parent e1e70aa commit 89e277a
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 105 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### **Changed**
- fixed model deploy cross-account permissions
- added bucket and model package group names as stack outputs in the `sagemaker-templates` module
- refactor inputs for `mlflow-fargate` and `mlflow-image`
- refactor inputs for `sagemaker-studio`
- refactor inputs for the following modules to use Pydantic:
- `mlflow-fargate`
- `mlflow-image`
- `sagemaker-studio`
- `sagemaker-endpoint`
- rename seedfarmer project name to `aiops`
- chore: adding some missing auto_delete attributes
- chore: Add `auto_delete` to `mlflow-fargate` elb access logs bucket
Expand Down
106 changes: 32 additions & 74 deletions modules/sagemaker/sagemaker-endpoint/app.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,40 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import json
import os

import aws_cdk
import cdk_nag

from settings import ApplicationSettings
from stack import DeployEndpointStack


def _param(name: str) -> str:
return f"SEEDFARMER_PARAMETER_{name}"


project_name = os.getenv("SEEDFARMER_PROJECT_NAME", "")
deployment_name = os.getenv("SEEDFARMER_DEPLOYMENT_NAME", "")
module_name = os.getenv("SEEDFARMER_MODULE_NAME", "")
app_prefix = f"{project_name}-{deployment_name}-{module_name}"

DEFAULT_SAGEMAKER_PROJECT_ID = None
DEFAULT_SAGEMAKER_PROJECT_NAME = None
DEFAULT_MODEL_PACKAGE_ARN = None
DEFAULT_MODEL_PACKAGE_GROUP_NAME = None
DEFAULT_MODEL_EXECUTION_ROLE_ARN = None
DEFAULT_MODEL_ARTIFACTS_BUCKET_ARN = None
DEFAULT_ECR_REPO_ARN = None
DEFAULT_VARIANT_NAME = "AllTraffic"
DEFAULT_INITIAL_INSTANCE_COUNT = 1
DEFAULT_INITIAL_VARIANT_WEIGHT = 1
DEFAULT_INSTANCE_TYPE = "ml.m4.xlarge"
DEFAULT_SCALING_MIN_INSTANCE_COUNT = 1
DEFAULT_SCALING_MAX_INSTANCE_COUNT = 10

environment = aws_cdk.Environment(
account=os.environ["CDK_DEFAULT_ACCOUNT"],
region=os.environ["CDK_DEFAULT_REGION"],
)

vpc_id = os.getenv(_param("VPC_ID"))
subnet_ids = json.loads(os.getenv(_param("SUBNET_IDS"), "[]"))
sagemaker_project_id = os.getenv(_param("SAGEMAKER_PROJECT_ID"), DEFAULT_SAGEMAKER_PROJECT_ID)
sagemaker_project_name = os.getenv(_param("SAGEMAKER_PROJECT_NAME"), DEFAULT_SAGEMAKER_PROJECT_NAME)
model_package_arn = os.getenv(_param("MODEL_PACKAGE_ARN"), DEFAULT_MODEL_PACKAGE_ARN)
model_package_group_name = os.getenv(_param("MODEL_PACKAGE_GROUP_NAME"), DEFAULT_MODEL_PACKAGE_GROUP_NAME)
model_execution_role_arn = os.getenv(_param("MODEL_EXECUTION_ROLE_ARN"), DEFAULT_MODEL_EXECUTION_ROLE_ARN)
model_artifacts_bucket_arn = os.getenv(_param("MODEL_ARTIFACTS_BUCKET_ARN"), DEFAULT_MODEL_ARTIFACTS_BUCKET_ARN)
ecr_repo_arn = os.getenv(_param("ECR_REPO_ARN"), DEFAULT_ECR_REPO_ARN)
variant_name = os.getenv(_param("VARIANT_NAME"), DEFAULT_VARIANT_NAME)
initial_instance_count = int(os.getenv(_param("INITIAL_INSTANCE_COUNT"), DEFAULT_INITIAL_INSTANCE_COUNT))
initial_variant_weight = int(os.getenv(_param("INITIAL_VARIANT_WEIGHT"), DEFAULT_INITIAL_VARIANT_WEIGHT))
instance_type = os.getenv(_param("INSTANCE_TYPE"), DEFAULT_INSTANCE_TYPE)
managed_instance_scaling = bool(os.getenv(_param("MANAGED_INSTANCE_SCALING"), False))
scaling_min_instance_count = int(os.getenv(_param("SCALING_MIN_INSTANCE_COUNT"), DEFAULT_SCALING_MIN_INSTANCE_COUNT))
scaling_max_instance_count = int(os.getenv(_param("SCALING_MAX_INSTANCE_COUNT"), DEFAULT_SCALING_MAX_INSTANCE_COUNT))

if not vpc_id:
raise ValueError("Missing input parameter vpc-id")

if not model_package_arn and not model_package_group_name:
raise ValueError("Parameter model-package-arn or model-package-group-name is required")


app = aws_cdk.App()
app_settings = ApplicationSettings()

stack = DeployEndpointStack(
scope=app,
id=app_prefix,
sagemaker_project_id=sagemaker_project_id,
sagemaker_project_name=sagemaker_project_name,
model_package_arn=model_package_arn,
model_package_group_name=model_package_group_name,
model_execution_role_arn=model_execution_role_arn,
vpc_id=vpc_id,
subnet_ids=subnet_ids,
model_artifacts_bucket_arn=model_artifacts_bucket_arn,
ecr_repo_arn=ecr_repo_arn,
id=app_settings.seedfarmer_settings.app_prefix,
sagemaker_project_id=app_settings.module_settings.sagemaker_project_id,
sagemaker_project_name=app_settings.module_settings.sagemaker_project_name,
model_package_arn=app_settings.module_settings.model_package_arn,
model_package_group_name=app_settings.module_settings.model_package_group_name,
model_execution_role_arn=app_settings.module_settings.model_execution_role_arn,
vpc_id=app_settings.module_settings.vpc_id,
subnet_ids=app_settings.module_settings.subnet_ids,
model_artifacts_bucket_arn=app_settings.module_settings.model_artifacts_bucket_arn,
ecr_repo_arn=app_settings.module_settings.ecr_repo_arn,
endpoint_config_prod_variant={
"initial_instance_count": initial_instance_count,
"initial_variant_weight": initial_variant_weight,
"instance_type": instance_type,
"variant_name": variant_name,
"initial_instance_count": app_settings.module_settings.initial_instance_count,
"initial_variant_weight": app_settings.module_settings.initial_variant_weight,
"instance_type": app_settings.module_settings.instance_type,
"variant_name": app_settings.module_settings.variant_name,
},
managed_instance_scaling=managed_instance_scaling,
scaling_min_instance_count=scaling_min_instance_count,
scaling_max_instance_count=scaling_max_instance_count,
env=environment,
managed_instance_scaling=app_settings.module_settings.managed_instance_scaling,
scaling_min_instance_count=app_settings.module_settings.scaling_min_instance_count,
scaling_max_instance_count=app_settings.module_settings.scaling_max_instance_count,
env=aws_cdk.Environment(
account=app_settings.cdk_settings.account,
region=app_settings.cdk_settings.region,
),
)

aws_cdk.CfnOutput(
Expand All @@ -103,4 +53,12 @@ def _param(name: str) -> str:

aws_cdk.Aspects.of(app).add(cdk_nag.AwsSolutionsChecks(log_ignores=True))

if app_settings.module_settings.tags:
for tag_key, tag_value in app_settings.module_settings.tags.items():
aws_cdk.Tags.of(app).add(tag_key, tag_value)

aws_cdk.Tags.of(app).add("SeedFarmerDeploymentName", app_settings.seedfarmer_settings.deployment_name)
aws_cdk.Tags.of(app).add("SeedFarmerModuleName", app_settings.seedfarmer_settings.module_name)
aws_cdk.Tags.of(app).add("SeedFarmerProjectName", app_settings.seedfarmer_settings.project_name)

app.synth()
3 changes: 0 additions & 3 deletions modules/sagemaker/sagemaker-endpoint/coverage.ini

This file was deleted.

1 change: 0 additions & 1 deletion modules/sagemaker/sagemaker-endpoint/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ fixable = ["ALL"]
[tool.mypy]
python_version = "3.8"
strict = true
ignore_missing_imports = true
disallow_untyped_decorators = false
exclude = "codeseeder.out/|example/|tests/"
warn_unused_ignores = false
Expand Down
5 changes: 3 additions & 2 deletions modules/sagemaker/sagemaker-endpoint/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
aws-cdk-lib==2.126.0
cdk-nag==2.28.27
yamldataclassconfig==1.5.0
boto3==1.34.35
boto3==1.34.35
pydantic==2.7.2
pydantic-settings==2.2.1
96 changes: 96 additions & 0 deletions modules/sagemaker/sagemaker-endpoint/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Defines the stack settings."""

from abc import ABC
from typing import Dict, List, Optional

from pydantic import Field, computed_field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict


class CdkBaseSettings(BaseSettings, ABC):
    """Shared base class for the settings groups defined in this module.

    Carries the common pydantic-settings configuration; concrete subclasses
    extend it (for example with their own ``env_prefix``).
    """

    model_config = SettingsConfigDict(
        # Environment variable names are matched without regard to case.
        case_sensitive=False,
        # "__" inside a variable name addresses a field of a nested model.
        env_nested_delimiter="__",
        # Clear pydantic's protected "model_" namespace so that fields such
        # as ``model_package_arn`` can be declared.
        protected_namespaces=(),
        # Inputs that match no declared field are ignored, not rejected.
        extra="ignore",
        # Allow population by field name in addition to any alias.
        populate_by_name=True,
    )


class ModuleSettings(CdkBaseSettings):
    """Module input parameters supplied by seedfarmer.

    Fields are read from environment variables carrying the
    ``SEEDFARMER_PARAMETER_`` prefix. ``vpc_id`` and ``subnet_ids`` have no
    default and must be provided. At least one of ``model_package_arn`` and
    ``model_package_group_name`` must also be set (see
    ``check_model_package``).
    """

    model_config = SettingsConfigDict(env_prefix="SEEDFARMER_PARAMETER_")

    # Mandatory networking inputs.
    vpc_id: str
    subnet_ids: List[str]

    # Optional identifiers; each defaults to None when not provided.
    sagemaker_project_id: Optional[str] = None
    sagemaker_project_name: Optional[str] = None
    model_package_arn: Optional[str] = None
    model_package_group_name: Optional[str] = None
    model_execution_role_arn: Optional[str] = None
    model_artifacts_bucket_arn: Optional[str] = None
    ecr_repo_arn: Optional[str] = None
    variant_name: str = "AllTraffic"

    # Endpoint sizing and scaling inputs with their defaults.
    initial_instance_count: int = 1
    initial_variant_weight: int = 1
    instance_type: str = "ml.m4.xlarge"
    managed_instance_scaling: bool = False
    scaling_min_instance_count: int = 1
    scaling_max_instance_count: int = 10

    # Optional key/value tags.
    tags: Optional[Dict[str, str]] = None

    @model_validator(mode="after")
    def check_model_package(self) -> "ModuleSettings":
        """Ensure a model package ARN or a model package group name is set.

        Returns:
            The validated settings instance, unchanged.

        Raises:
            ValueError: If both ``model_package_arn`` and
                ``model_package_group_name`` are empty or missing.
        """
        if self.model_package_arn or self.model_package_group_name:
            return self

        raise ValueError("Parameter model-package-arn or model-package-group-name is required")


class SeedFarmerSettings(CdkBaseSettings):
    """Deployment identifiers provided by seedfarmer.

    Fields are read from environment variables carrying the ``SEEDFARMER_``
    prefix; each one falls back to an empty string when not set.
    """

    model_config = SettingsConfigDict(env_prefix="SEEDFARMER_")

    project_name: str = ""
    deployment_name: str = ""
    module_name: str = ""

    @computed_field  # type: ignore
    @property
    def app_prefix(self) -> str:
        """Return the project, deployment and module names joined by hyphens."""
        return f"{self.project_name}-{self.deployment_name}-{self.module_name}"


class CDKSettings(CdkBaseSettings):
    """Default deployment environment provided by AWS CDK.

    Fields are read from environment variables carrying the ``CDK_DEFAULT_``
    prefix. Neither field has a default, so both must be set.
    """

    model_config = SettingsConfigDict(env_prefix="CDK_DEFAULT_")

    # Read from CDK_DEFAULT_ACCOUNT.
    account: str
    # Read from CDK_DEFAULT_REGION.
    region: str


class ApplicationSettings(CdkBaseSettings):
    """Aggregate of every settings group used by the application.

    Each group is created through ``default_factory``, so unless a value is
    passed explicitly the corresponding settings class is instantiated (and
    therefore validated) when ``ApplicationSettings()`` is constructed.
    """

    # Seedfarmer project / deployment / module identifiers.
    seedfarmer_settings: SeedFarmerSettings = Field(default_factory=SeedFarmerSettings)
    # Module input parameters.
    module_settings: ModuleSettings = Field(default_factory=ModuleSettings)
    # CDK default account and region.
    cdk_settings: CDKSettings = Field(default_factory=CDKSettings)
39 changes: 23 additions & 16 deletions modules/sagemaker/sagemaker-endpoint/tests/test_app.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,40 @@
import os
import sys
from unittest import mock

import pytest
from pydantic import ValidationError


@pytest.fixture(scope="function")
@pytest.fixture(scope="function", autouse=True)
def stack_defaults():
os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project"
os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment"
os.environ["SEEDFARMER_MODULE_NAME"] = "test-module"
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"
with mock.patch.dict(os.environ, {}, clear=True):
os.environ["SEEDFARMER_PROJECT_NAME"] = "test-project"
os.environ["SEEDFARMER_DEPLOYMENT_NAME"] = "test-deployment"
os.environ["SEEDFARMER_MODULE_NAME"] = "test-module"

os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345"
os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_ID"] = "12345"
os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_NAME"] = "sagemaker-project"
os.environ["SEEDFARMER_PARAMETER_MODEL_PACKAGE_ARN"] = "example-arn"
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"

# Unload the app import so that subsequent tests don't reuse
if "app" in sys.modules:
del sys.modules["app"]
os.environ["SEEDFARMER_PARAMETER_VPC_ID"] = "vpc-12345"
os.environ["SEEDFARMER_PARAMETER_SUBNET_IDS"] = '["subnet-1","subnet-2","subnet-3"]'
os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_ID"] = "12345"
os.environ["SEEDFARMER_PARAMETER_SAGEMAKER_PROJECT_NAME"] = "sagemaker-project"
os.environ["SEEDFARMER_PARAMETER_MODEL_PACKAGE_ARN"] = "example-arn"

# Unload the app import so that subsequent tests don't reuse
if "app" in sys.modules:
del sys.modules["app"]

def test_app(stack_defaults):
yield None


def test_app() -> None:
import app # noqa: F401


def test_vpc_id(stack_defaults):
def test_vpc_id() -> None:
del os.environ["SEEDFARMER_PARAMETER_VPC_ID"]

with pytest.raises(Exception, match="Missing input parameter vpc-id"):
with pytest.raises(ValidationError):
import app # noqa: F401
17 changes: 10 additions & 7 deletions modules/sagemaker/sagemaker-endpoint/tests/test_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
from botocore.stub import Stubber


@pytest.fixture(scope="function")
def stack_defaults() -> None:
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"
@pytest.fixture(scope="function", autouse=True)
def stack_defaults():
with mock.patch.dict(os.environ, {}, clear=True):
os.environ["CDK_DEFAULT_ACCOUNT"] = "111111111111"
os.environ["CDK_DEFAULT_REGION"] = "us-east-1"

# Unload the app import so that subsequent tests don't reuse
if "stack" in sys.modules:
del sys.modules["stack"]

# Unload the app import so that subsequent tests don't reuse
if "stack" in sys.modules:
del sys.modules["stack"]
yield


@pytest.fixture(scope="function")
Expand Down

0 comments on commit 89e277a

Please sign in to comment.