Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with duplicate resources - backwards compatibility #416

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 95 additions & 10 deletions backend/dataall/cdkproxy/stacks/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aws_cdk import (
custom_resources as cr,
aws_ec2 as ec2,
aws_sagemaker as sagemaker,
aws_s3 as s3,
aws_s3_deployment,
aws_iam as iam,
Expand All @@ -26,15 +27,16 @@
Tags,
)
from constructs import DependencyGroup
from botocore.exceptions import ClientError

from .manager import stack
from .pivot_role import PivotRole
from .sagemakerstudio import SageMakerDomain
from .policies.data_policy import DataPolicy
from .policies.service_policy import ServicePolicy
from ... import db
from ...aws.handlers.parameter_store import ParameterStoreManager
from ...aws.handlers.sts import SessionHelper
from ...aws.handlers.sagemaker_studio import SagemakerStudio
from ...db import models
from ...utils.cdk_nag_utils import CDKNagUtil
from ...utils.runtime_stacks_tagging import TagsUtil
Expand Down Expand Up @@ -131,6 +133,28 @@ def get_all_environment_datasets(engine, environment: models.Environment) -> [mo
.all()
)

def check_existing_sagemaker_studio_domain(self, environment: models.Environment):
logger.info('Check if there is an existing sagemaker studio domain in the account')
try:
logger.info('check sagemaker studio domain created as part of data.all environment stack.')
cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
accountid=environment.AwsAccountId, region=environment.region
)
dataall_created_domain = ParameterStoreManager.client(
AwsAccountId=environment.AwsAccountId, region=environment.region,
role=cdk_look_up_role_arn
).get_parameter(
Name=f'/dataall/{environment.environmentUri}/sagemaker/sagemakerstudio/domain_id')
return False
except ClientError as e:
logger.info(
f'check sagemaker studio domain created outside of data.all. Parameter data.all not found: {e}')
existing_domain = SagemakerStudio.get_sagemaker_studio_domain(
AwsAccountId=environment.AwsAccountId, region=environment.region,
role=cdk_look_up_role_arn
)
return existing_domain.get('DomainId', False)

def __init__(self, scope, id, target_uri: str = None, **kwargs):
super().__init__(
scope,
Expand Down Expand Up @@ -555,8 +579,9 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs):
)

# Create or import SageMaker Studio domain if ML Studio enabled
if self._environment.mlStudiosEnabled:
# Create dependency group - Sagemaker depends on group IAM roles
self.sagemaker_domain_exists = self.check_existing_sagemaker_studio_domain(environment=self._environment)
if self._environment.mlStudiosEnabled and not self.sagemaker_domain_exists:
# Create dependency group - Sagemaker KMS key policy depends on group IAM roles
sagemaker_dependency_group = DependencyGroup()
sagemaker_dependency_group.add(default_role)
for group_role in group_roles:
Expand All @@ -575,13 +600,73 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs):
f"Default VPC not found, Exception: {e}. If you don't own a default VPC, modify the networking configuration, or disable ML Studio upon environment creation."
)

sagemaker_domain_stack = SageMakerDomain(self, 'SageMakerDomain',
environment=self._environment,
sagemaker_principals=[default_role] + group_roles,
vpc_id=vpc_id,
subnet_ids=subnet_ids
)
sagemaker_domain_stack.node.add_dependency(sagemaker_dependency_group)
sagemaker_domain_role = iam.Role(
self,
'RoleForSagemakerStudioUsers',
assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'),
role_name='RoleSagemakerStudioUsers',
managed_policies=[
iam.ManagedPolicy.from_managed_policy_arn(
self,
id='SagemakerFullAccess',
managed_policy_arn='arn:aws:iam::aws:policy/AmazonSageMakerFullAccess',
),
iam.ManagedPolicy.from_managed_policy_arn(
self, id='S3FullAccess',
managed_policy_arn='arn:aws:iam::aws:policy/AmazonS3FullAccess'
),
],
)

sagemaker_domain_key = kms.Key(
self,
'SagemakerDomainKmsKey',
alias='SagemakerStudioDomain',
enable_key_rotation=True,
policy=iam.PolicyDocument(
assign_sids=True,
statements=[
iam.PolicyStatement(
resources=['*'],
effect=iam.Effect.ALLOW,
principals=[
iam.AccountPrincipal(account_id=self._environment.AwsAccountId),
sagemaker_domain_role,
default_role,
] + group_roles,
actions=['kms:*'],
)
],
),
)
sagemaker_domain_key.node.add_dependency(sagemaker_dependency_group)

sagemaker_domain = sagemaker.CfnDomain(
self,
'SagemakerStudioDomain',
domain_name=f'SagemakerStudioDomain-{self._environment.region}-{self._environment.AwsAccountId}',
auth_mode='IAM',
default_user_settings=sagemaker.CfnDomain.UserSettingsProperty(
execution_role=sagemaker_domain_role.role_arn,
security_groups=[],
sharing_settings=sagemaker.CfnDomain.SharingSettingsProperty(
notebook_output_option='Allowed',
s3_kms_key_id=sagemaker_domain_key.key_id,
s3_output_path=f's3://sagemaker-{self._environment.region}-{self._environment.AwsAccountId}',
),
),
vpc_id=vpc_id,
subnet_ids=subnet_ids,
app_network_access_type='VpcOnly',
kms_key_id=sagemaker_domain_key.key_id,
)

ssm.StringParameter(
self,
'SagemakerStudioDomainId',
string_value=sagemaker_domain.attr_domain_id,
parameter_name=f'/dataall/{self._environment.environmentUri}/sagemaker/sagemakerstudio/domain_id',
)

# print the IAM role arn for this service account
CfnOutput(
Expand Down
107 changes: 0 additions & 107 deletions backend/dataall/cdkproxy/stacks/sagemakerstudio.py
Original file line number Diff line number Diff line change
@@ -1,126 +1,19 @@
import logging
import os
import pathlib
from constructs import Construct
from aws_cdk import (
cloudformation_include as cfn_inc,
aws_ec2 as ec2,
aws_iam as iam,
aws_kms as kms,
aws_lambda as _lambda,
aws_sagemaker as sagemaker,
aws_ssm as ssm,
CustomResource,
Duration,
NestedStack,
Stack
)
from botocore.exceptions import ClientError
from .manager import stack
from ... import db
from ...db import models
from ...db.api import Environment
from ...aws.handlers.parameter_store import ParameterStoreManager
from ...aws.handlers.sts import SessionHelper
from ...aws.handlers.sagemaker_studio import (
SagemakerStudio,
)
from ...utils.cdk_nag_utils import CDKNagUtil
from ...utils.runtime_stacks_tagging import TagsUtil

logger = logging.getLogger(__name__)


class SageMakerDomain(NestedStack):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should we delete SageMakerDomain ? If we decided not to go with the nested stack and rollback it, I'd still leave this class, remove only NestedStack, changed init to create_ml_studio_part` (or something like that) and keep using it :)


def check_existing_sagemaker_studio_domain(self, environment: models.Environment):
logger.info('Check if there is an existing sagemaker studio domain in the account')
try:
logger.info('check sagemaker studio domain created as part of data.all environment stack.')
cdk_look_up_role_arn = SessionHelper.get_cdk_look_up_role_arn(
accountid=environment.AwsAccountId, region=environment.region
)
dataall_created_domain = ParameterStoreManager.client(
AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn
).get_parameter(Name=f'/dataall/{environment.environmentUri}/sagemaker/sagemakerstudio/domain_id')
return False
except ClientError as e:
logger.info(f'check sagemaker studio domain created outside of data.all. Parameter data.all not found: {e}')
existing_domain = SagemakerStudio.get_sagemaker_studio_domain(
AwsAccountId=environment.AwsAccountId, region=environment.region, role=cdk_look_up_role_arn
)
return existing_domain.get('DomainId', False)

def __init__(self, scope: Construct, construct_id: str, environment: models.Environment, sagemaker_principals, vpc_id, subnet_ids, **kwargs) -> None:
super().__init__(scope, construct_id, **kwargs)

self._environment = environment
self.existing_sagemaker_domain = self.check_existing_sagemaker_studio_domain(environment=self._environment)

if self._environment.mlStudiosEnabled and not self.existing_sagemaker_domain:
sagemaker_domain_role = iam.Role(
self,
'RoleForSagemakerStudioUsers',
assumed_by=iam.ServicePrincipal('sagemaker.amazonaws.com'),
role_name='RoleSagemakerStudioUsers',
managed_policies=[
iam.ManagedPolicy.from_managed_policy_arn(
self,
id='SagemakerFullAccess',
managed_policy_arn='arn:aws:iam::aws:policy/AmazonSageMakerFullAccess',
),
iam.ManagedPolicy.from_managed_policy_arn(
self, id='S3FullAccess', managed_policy_arn='arn:aws:iam::aws:policy/AmazonS3FullAccess'
),
],
)

sagemaker_domain_key = kms.Key(
self,
'SagemakerDomainKmsKey',
alias='SagemakerStudioDomain',
enable_key_rotation=True,
policy=iam.PolicyDocument(
assign_sids=True,
statements=[
iam.PolicyStatement(
resources=['*'],
effect=iam.Effect.ALLOW,
principals=[iam.AccountPrincipal(account_id=self._environment.AwsAccountId), sagemaker_domain_role] + sagemaker_principals,
actions=['kms:*'],
)
],
),
)

sagemaker_domain = sagemaker.CfnDomain(
self,
'SagemakerStudioDomain',
domain_name=f'SagemakerStudioDomain-{self._environment.region}-{self._environment.AwsAccountId}',
auth_mode='IAM',
default_user_settings=sagemaker.CfnDomain.UserSettingsProperty(
execution_role=sagemaker_domain_role.role_arn,
security_groups=[],
sharing_settings=sagemaker.CfnDomain.SharingSettingsProperty(
notebook_output_option='Allowed',
s3_kms_key_id=sagemaker_domain_key.key_id,
s3_output_path=f's3://sagemaker-{self._environment.region}-{self._environment.AwsAccountId}',
),
),
vpc_id=vpc_id,
subnet_ids=subnet_ids,
app_network_access_type='VpcOnly',
kms_key_id=sagemaker_domain_key.key_id,
)

ssm.StringParameter(
self,
'SagemakerStudioDomainId',
string_value=sagemaker_domain.attr_domain_id,
parameter_name=f'/dataall/{self._environment.environmentUri}/sagemaker/sagemakerstudio/domain_id',
)


@stack(stack='sagemakerstudiouserprofile')
class SagemakerStudioUserProfile(Stack):
module_name = __file__
Expand Down