diff --git a/backend/dataall/base/aws/ec2_client.py b/backend/dataall/base/aws/ec2_client.py index 6b53ebaab..21f61a0fc 100644 --- a/backend/dataall/base/aws/ec2_client.py +++ b/backend/dataall/base/aws/ec2_client.py @@ -9,7 +9,7 @@ class EC2: @staticmethod def get_client(account_id: str, region: str, role=None): - session = SessionHelper.remote_session(accountid=account_id, role=role) + session = SessionHelper.remote_session(accountid=account_id, region=region, role=role) return session.client('ec2', region_name=region) @staticmethod diff --git a/backend/dataall/base/aws/iam.py b/backend/dataall/base/aws/iam.py index ad52ac0a5..7ea87df17 100644 --- a/backend/dataall/base/aws/iam.py +++ b/backend/dataall/base/aws/iam.py @@ -8,15 +8,15 @@ class IAM: @staticmethod - def client(account_id: str, role=None): - session = SessionHelper.remote_session(accountid=account_id, role=role) + def client(account_id: str, region: str, role=None): + session = SessionHelper.remote_session(accountid=account_id, region=region, role=role) return session.client('iam') @staticmethod - def get_role(account_id: str, role_arn: str, role=None): + def get_role(account_id: str, region: str, role_arn: str, role=None): log.info(f'Getting IAM role = {role_arn}') try: - client = IAM.client(account_id=account_id, role=role) + client = IAM.client(account_id=account_id, region=region, role=role) response = client.get_role(RoleName=role_arn.split('/')[-1]) assert response['Role']['Arn'] == role_arn, "Arn doesn't match the role name. Check Arn and try again." except ClientError as e: @@ -30,10 +30,10 @@ def get_role(account_id: str, role_arn: str, role=None): return response['Role'] @staticmethod - def get_role_arn_by_name(account_id: str, role_name: str, role=None): + def get_role_arn_by_name(account_id: str, region: str, role_name: str, role=None): log.info(f'Getting IAM role name= {role_name}') try: - client = IAM.client(account_id=account_id, role=role) + client = IAM.client(account_id=account_id, region=region, role=role) response = client.get_role(RoleName=role_name) return response['Role']['Arn'] except ClientError as e: @@ -47,11 +47,12 @@ def get_role_arn_by_name(account_id: str, role_name: str, role=None): @staticmethod def get_role_policy( account_id: str, + region: str, role_name: str, policy_name: str, ): try: - client = IAM.client(account_id) + client = IAM.client(account_id, region) response = client.get_role_policy( RoleName=role_name, PolicyName=policy_name, @@ -68,11 +69,12 @@ def get_role_policy( @staticmethod def delete_role_policy( account_id: str, + region: str, role_name: str, policy_name: str, ): try: - client = IAM.client(account_id) + client = IAM.client(account_id, region) client.delete_role_policy( RoleName=role_name, PolicyName=policy_name, @@ -85,10 +87,10 @@ def delete_role_policy( log.error(f'Failed to delete policy {policy_name} of role {role_name} : {e}') @staticmethod - def get_managed_policy_by_name(account_id: str, policy_name: str): + def get_managed_policy_by_name(account_id: str, region: str, policy_name: str): try: arn = f'arn:aws:iam::{account_id}:policy/{policy_name}' - client = IAM.client(account_id) + client = IAM.client(account_id, region) response = client.get_policy(PolicyArn=arn) return response['Policy'] except ClientError as e: @@ -100,9 +102,9 @@ def get_managed_policy_by_name(account_id: str, policy_name: str): return None @staticmethod - def create_managed_policy(account_id: str, policy_name: str, policy: str): + def create_managed_policy(account_id: str, region: str, policy_name: str, policy: str): try: - client = IAM.client(account_id) + client = IAM.client(account_id, region) response = client.create_policy( PolicyName=policy_name, PolicyDocument=policy, @@ -118,10 +120,10 @@ def create_managed_policy(account_id: str, policy_name: str, policy: str): raise Exception(f'Failed to create managed policy {policy_name} : {e}') @staticmethod - def delete_managed_policy_by_name(account_id: str, policy_name): + def delete_managed_policy_by_name(account_id: str, region: str, policy_name): try: arn = f'arn:aws:iam::{account_id}:policy/{policy_name}' - client = IAM.client(account_id) + client = IAM.client(account_id, region) client.delete_policy(PolicyArn=arn) except ClientError as e: if e.response['Error']['Code'] == 'AccessDenied': @@ -131,10 +133,10 @@ def delete_managed_policy_by_name(account_id: str, policy_name): raise Exception(f'Failed to delete managed policy {policy_name} : {e}') @staticmethod - def get_managed_policy_default_version(account_id: str, policy_name: str): + def get_managed_policy_default_version(account_id: str, region: str, policy_name: str): try: arn = f'arn:aws:iam::{account_id}:policy/{policy_name}' - client = IAM.client(account_id) + client = IAM.client(account_id, region) response = client.get_policy(PolicyArn=arn) versionId = response['Policy']['DefaultVersionId'] policyVersion = client.get_policy_version(PolicyArn=arn, VersionId=versionId) @@ -150,11 +152,11 @@ def get_managed_policy_default_version(account_id: str, policy_name: str): @staticmethod def update_managed_policy_default_version( - account_id: str, policy_name: str, old_version_id: str, policy_document: str + account_id: str, region: str, policy_name: str, old_version_id: str, policy_document: str ): try: arn = f'arn:aws:iam::{account_id}:policy/{policy_name}' - client = IAM.client(account_id) + client = IAM.client(account_id, region) client.create_policy_version(PolicyArn=arn, PolicyDocument=policy_document, SetAsDefault=True) client.delete_policy_version(PolicyArn=arn, VersionId=old_version_id) @@ -168,11 +170,12 @@ def update_managed_policy_default_version( @staticmethod def delete_managed_policy_non_default_versions( account_id: str, + region: str, policy_name: str, ): try: arn = f'arn:aws:iam::{account_id}:policy/{policy_name}' - client = IAM.client(account_id) + client = IAM.client(account_id, region) # List all policy versions paginator = client.get_paginator('list_policy_versions') @@ -197,9 +200,9 @@ def delete_managed_policy_non_default_versions( return None, None @staticmethod - def is_policy_attached(account_id: str, policy_name: str, role_name: str): + def is_policy_attached(account_id: str, region: str, policy_name: str, role_name: str): try: - client = IAM.client(account_id) + client = IAM.client(account_id, region) paginator = client.get_paginator('list_attached_role_policies') pages = paginator.paginate(RoleName=role_name) policies = [] @@ -215,9 +218,9 @@ def is_policy_attached(account_id: str, policy_name: str, role_name: str): return False @staticmethod - def attach_role_policy(account_id, role_name, policy_arn): + def attach_role_policy(account_id, region: str, role_name, policy_arn): try: - client = IAM.client(account_id) + client = IAM.client(account_id, region) response = client.attach_role_policy(RoleName=role_name, PolicyArn=policy_arn) return True except ClientError as e: @@ -229,10 +232,10 @@ def attach_role_policy(account_id, role_name, policy_arn): raise e @staticmethod - def detach_policy_from_role(account_id: str, role_name: str, policy_name: str): + def detach_policy_from_role(account_id: str, region: str, role_name: str, policy_name: str): try: arn = f'arn:aws:iam::{account_id}:policy/{policy_name}' - client = IAM.client(account_id) + client = IAM.client(account_id, region) client.detach_role_policy(RoleName=role_name, PolicyArn=arn) except ClientError as e: if e.response['Error']['Code'] == 'AccessDenied': diff --git a/backend/dataall/base/aws/parameter_store.py b/backend/dataall/base/aws/parameter_store.py index a78b7cb16..cda2acfd8 100644 --- a/backend/dataall/base/aws/parameter_store.py +++ b/backend/dataall/base/aws/parameter_store.py @@ -19,7 +19,7 @@ def __init__(self): def client(AwsAccountId=None, region=None, role=None): if AwsAccountId: log.info(f"SSM Parameter remote session with role:{role if role else 'PivotRole'}") - session = SessionHelper.remote_session(accountid=AwsAccountId, role=role) + session = SessionHelper.remote_session(accountid=AwsAccountId, region=region, role=role) else: log.info('SSM Parameter session in central account') session = SessionHelper.get_session() diff --git a/backend/dataall/base/aws/quicksight.py b/backend/dataall/base/aws/quicksight.py index 3c6675d83..5f73fbbd0 100644 --- a/backend/dataall/base/aws/quicksight.py +++ b/backend/dataall/base/aws/quicksight.py @@ -31,24 +31,26 @@ def __init__(self): pass @staticmethod - def get_quicksight_client(AwsAccountId, region='eu-west-1'): + def get_quicksight_client(AwsAccountId, region, session_region='eu-west-1'): """Returns a boto3 quicksight client in the provided account/region Args: AwsAccountId(str) : aws account id - region(str) : aws region + region(str) : aws region of the environment + session_region(str) : region to create the session Returns : boto3.client ("quicksight") """ - session = SessionHelper.remote_session(accountid=AwsAccountId) - return session.client('quicksight', region_name=region) + session = SessionHelper.remote_session(accountid=AwsAccountId, region=region) + return session.client('quicksight', region_name=session_region) @staticmethod - def get_identity_region(AwsAccountId): + def get_identity_region(AwsAccountId, region): """Quicksight manages identities in one region, and there is no API to retrieve it However, when using Quicksight user/group apis in the wrong region, the client will throw and exception showing the region Quicksight's using as its identity region. Args: AwsAccountId(str) : aws account id + AwsAccountId(str) : aws region of environment Returns: str the region quicksight uses as identity region """ @@ -59,7 +61,9 @@ def get_identity_region(AwsAccountId): try: identity_region = QuicksightClient.QUICKSIGHT_IDENTITY_REGIONS[index].get('code') index += 1 - client = QuicksightClient.get_quicksight_client(AwsAccountId=AwsAccountId, region=identity_region) + client = QuicksightClient.get_quicksight_client( + AwsAccountId=AwsAccountId, region=region, session_region=identity_region + ) response = client.describe_account_settings(AwsAccountId=AwsAccountId) logger.info(f'Returning identity region = {identity_region} for account {AwsAccountId}') return identity_region @@ -84,15 +88,16 @@ def get_identity_region(AwsAccountId): ) @staticmethod - def get_quicksight_client_in_identity_region(AwsAccountId): + def get_quicksight_client_in_identity_region(AwsAccountId, region): """Returns a boto3 quicksight client in the Quicksight identity region for the provided account Args: AwsAccountId(str) : aws account id + region(str) : aws region of the environment Returns : boto3.client ("quicksight") """ - identity_region = QuicksightClient.get_identity_region(AwsAccountId) - session = SessionHelper.remote_session(accountid=AwsAccountId) + identity_region = QuicksightClient.get_identity_region(AwsAccountId, region) + session = SessionHelper.remote_session(accountid=AwsAccountId, region=region) return session.client('quicksight', region_name=identity_region) @staticmethod @@ -105,7 +110,7 @@ def check_quicksight_enterprise_subscription(AwsAccountId, region=None): True if Quicksight Enterprise Edition is enabled in the AWS Account """ logger.info(f'Checking Quicksight subscription in AWS account = {AwsAccountId}') - client = QuicksightClient.get_quicksight_client(AwsAccountId=AwsAccountId, region=region) + client = QuicksightClient.get_quicksight_client(AwsAccountId=AwsAccountId, region=region, session_region=region) try: response = client.describe_account_subscription(AwsAccountId=AwsAccountId) if not response['AccountInfo']: @@ -141,8 +146,8 @@ def create_quicksight_group(AwsAccountId, region, GroupName=DEFAULT_GROUP_NAME): Returns:dict quicksight.describe_group response """ - client = QuicksightClient.get_quicksight_client_in_identity_region(AwsAccountId) - group = QuicksightClient.describe_group(client, AwsAccountId, GroupName) + client = QuicksightClient.get_quicksight_client_in_identity_region(AwsAccountId, region) + group = QuicksightClient.describe_group(client, AwsAccountId, region, GroupName) if not group: if GroupName == QuicksightClient.DEFAULT_GROUP_NAME: logger.info(f'Initializing data.all default group = {GroupName}') @@ -161,16 +166,16 @@ def create_quicksight_group(AwsAccountId, region, GroupName=DEFAULT_GROUP_NAME): return group @staticmethod - def describe_group(client, AwsAccountId, GroupName=DEFAULT_GROUP_NAME): + def describe_group(client, AwsAccountId, region, GroupName=DEFAULT_GROUP_NAME): try: response = client.describe_group(AwsAccountId=AwsAccountId, GroupName=GroupName, Namespace='default') logger.info( f'Quicksight {GroupName} group already exists in {AwsAccountId} ' - f'(using identity region {QuicksightClient.get_identity_region(AwsAccountId)}): ' + f'(using identity region {QuicksightClient.get_identity_region(AwsAccountId, region)}): ' f'{response}' ) return response except client.exceptions.ResourceNotFoundException: logger.info( - f'Creating Quicksight group in {AwsAccountId} (using identity region {QuicksightClient.get_identity_region(AwsAccountId)})' + f'Creating Quicksight group in {AwsAccountId} (using identity region {QuicksightClient.get_identity_region(AwsAccountId, region)})' ) diff --git a/backend/dataall/base/aws/secrets_manager.py b/backend/dataall/base/aws/secrets_manager.py index 42275ca64..21cf06260 100644 --- a/backend/dataall/base/aws/secrets_manager.py +++ b/backend/dataall/base/aws/secrets_manager.py @@ -14,7 +14,7 @@ class SecretsManager: def __init__(self, account_id=None, region=_DEFAULT_REGION): if account_id: - session = SessionHelper.remote_session(account_id) + session = SessionHelper.remote_session(account_id, region) self._client = session.client('secretsmanager', region_name=region) else: self._client = boto3.client('secretsmanager', region_name=region) diff --git a/backend/dataall/base/aws/sts.py b/backend/dataall/base/aws/sts.py index 4f7787950..2b9d7a6aa 100644 --- a/backend/dataall/base/aws/sts.py +++ b/backend/dataall/base/aws/sts.py @@ -6,6 +6,7 @@ import boto3 from botocore.client import Config from botocore.exceptions import ClientError +from dataall.base.config import config from dataall.version import __version__, __pkg_name__ @@ -101,14 +102,20 @@ def get_external_id_secret(cls): ) @classmethod - def get_delegation_role_name(cls): + def get_delegation_role_name(cls, region): """Returns the role name that this package assumes on remote accounts Returns: string: name of the assumed role """ - return SessionHelper._get_parameter_value( + base_name = SessionHelper._get_parameter_value( parameter_path=f'/dataall/{os.getenv("envname", "local")}/pivotRole/pivotRoleName' ) + return ( + f'{base_name}-{region}' + if config.get_property('core.features.cdk_pivot_role_multiple_environments_same_account', default=False) + and base_name != 'dataallPivotRole' + else base_name + ) @classmethod def get_console_access_url(cls, boto3_session, region='eu-west-1', bucket=None): @@ -150,20 +157,22 @@ def get_console_access_url(cls, boto3_session, region='eu-west-1', bucket=None): return request_url @classmethod - def get_delegation_role_arn(cls, accountid): + def get_delegation_role_arn(cls, accountid, region): """Returns the name that will be assumed to perform IAM actions on a given AWS accountid Args: accountid(string) : aws account id + region(string) : aws account region Returns: string : arn of the delegation role on the target aws account id """ - return 'arn:aws:iam::{}:role/{}'.format(accountid, cls.get_delegation_role_name()) + return 'arn:aws:iam::{}:role/{}'.format(accountid, cls.get_delegation_role_name(region)) @classmethod def get_cdk_look_up_role_arn(cls, accountid, region): """Returns the name that will be assumed to perform IAM actions on a given AWS accountid using CDK Toolkit role Args: accountid(string) : aws account id + region(string) : aws account region Returns: string : arn of the CDKToolkit role on the target aws account id """ @@ -177,6 +186,7 @@ def get_cdk_exec_role_arn(cls, accountid, region): """Returns the name that will be assumed to perform IAM actions on a given AWS accountid using CDK Toolkit role Args: accountid(string) : aws account id + region(string) : aws account region Returns: string : arn of the CDKToolkit role on the target aws account id """ @@ -186,23 +196,11 @@ def get_cdk_exec_role_arn(cls, accountid, region): return 'arn:aws:iam::{}:role/cdk-hnb659fds-cfn-exec-role-{}-{}'.format(accountid, accountid, region) @classmethod - def get_delegation_role_id(cls, accountid): - """Returns the name that will be assumed to perform IAM actions on a given AWS accountid - Args: - accountid(string) : aws account id - Returns : - string : RoleId of the role - """ - session = SessionHelper.remote_session(accountid=accountid) - client = session.client('iam', region_name='eu-west-1') - response = client.get_role(RoleName=cls.get_delegation_role_name()) - return response['Role']['RoleId'] - - @classmethod - def remote_session(cls, accountid, role=None): + def remote_session(cls, accountid, region, role=None): """Creates a remote boto3 session on the remote AWS account , assuming the delegation Role Args: accountid(string) : aws account id + region(string) : aws region role(string) : arn of the IAM role to assume in the boto3 session Returns : boto3.session.Session: boto3 Session, on the target aws accountid, assuming the delegation role or a provided role @@ -213,7 +211,7 @@ def remote_session(cls, accountid, role=None): role_arn = role else: log.info(f'Remote boto3 session using pivot role for account= {accountid}') - role_arn = cls.get_delegation_role_arn(accountid=accountid) + role_arn = cls.get_delegation_role_arn(accountid=accountid, region=region) session = SessionHelper.get_session(base_session=base_session, role_arn=role_arn) return session @@ -248,8 +246,8 @@ def get_organization_id(cls, session=None): return response['Organization']['Id'] @staticmethod - def get_role_id(accountid, name): - session = SessionHelper.remote_session(accountid=accountid) + def get_role_id(accountid, region, name): + session = SessionHelper.remote_session(accountid=accountid, region=region) client = session.client('iam') try: response = client.get_role(RoleName=name) @@ -363,12 +361,12 @@ def generate_console_url(credentials, session_duration=None, region='eu-west-1', return request_url @staticmethod - def is_assumable_pivot_role(accountid): + def is_assumable_pivot_role(accountid, region): try: - SessionHelper.remote_session(accountid=accountid) + SessionHelper.remote_session(accountid=accountid, region=region) except ClientError as e: log.error( - f'Failed to assume dataall pivot role session in environment with account id {accountid} due to {e}' + f'Failed to assume dataall pivot role session in environment with account id {accountid} region {region} due to {e}' ) return False except Exception as e: diff --git a/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py b/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py index e8cd3f403..43e1ccb17 100644 --- a/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py +++ b/backend/dataall/base/cdkproxy/cdk_cli_wrapper.py @@ -60,7 +60,7 @@ def aws_configure(profile_name='default'): def update_stack_output(session, stack): outputs = {} stack_outputs = None - aws = SessionHelper.remote_session(stack.accountid) + aws = SessionHelper.remote_session(stack.accountid, stack.region) cfn = aws.resource('cloudformation', region_name=stack.region) try: stack_outputs = cfn.Stack(f'{stack.name}').outputs @@ -204,7 +204,7 @@ def describe_stack(stack, engine: Engine = None, stackid: str = None): stack = session.query(Stack).get(stackid) if stack.status == 'DELETE_COMPLETE': return {'StackId': stack.stackid, 'StackStatus': stack.status} - session = SessionHelper.remote_session(stack.accountid) + session = SessionHelper.remote_session(stack.accountid, stack.region) resource = session.resource('cloudformation', region_name=stack.region) try: meta = resource.Stack(f'{stack.name}') diff --git a/backend/dataall/core/environment/api/resolvers.py b/backend/dataall/core/environment/api/resolvers.py index 968d76823..20ddc5d08 100644 --- a/backend/dataall/core/environment/api/resolvers.py +++ b/backend/dataall/core/environment/api/resolvers.py @@ -73,8 +73,8 @@ def check_environment(context: Context, source, account_id, region, data): cdk_role_name = CloudFormation.check_existing_cdk_toolkit_stack(AwsAccountId=account_id, region=region) if not pivot_role_as_part_of_environment: log.info('Check if PivotRole exist in the account') - pivot_role_arn = SessionHelper.get_delegation_role_arn(accountid=account_id) - role = IAM.get_role(account_id=account_id, role_arn=pivot_role_arn, role=cdk_look_up_role_arn) + pivot_role_arn = SessionHelper.get_delegation_role_arn(accountid=account_id, region=region) + role = IAM.get_role(account_id=account_id, region=region, role_arn=pivot_role_arn, role=cdk_look_up_role_arn) if not role: raise exceptions.AWSResourceNotFound( action='CHECK_PIVOT_ROLE', @@ -171,7 +171,7 @@ def invite_group(context: Context, source, input): def add_consumption_role(context: Context, source, input): with context.engine.scoped_session() as session: env = EnvironmentService.get_environment_by_uri(session, input['environmentUri']) - role = IAM.get_role(env.AwsAccountId, input['IAMRoleArn']) + role = IAM.get_role(env.AwsAccountId, env.region, input['IAMRoleArn']) if not role: raise exceptions.AWSResourceNotFound( action='ADD_CONSUMPTION_ROLE', @@ -357,6 +357,7 @@ def get_policies(context: Context, source, **kwargs): role_name=source.IAMRoleName, environmentUri=environment.environmentUri, account=environment.AwsAccountId, + region=environment.region, resource_prefix=environment.resourcePrefix, ).get_all_policies() @@ -404,7 +405,7 @@ def _get_environment_group_aws_session(session, username, groups, environment, g action='ENVIRONMENT_AWS_ACCESS', message=f'User: {username} is not member of the team {groupUri}', ) - pivot_session = SessionHelper.remote_session(environment.AwsAccountId) + pivot_session = SessionHelper.remote_session(environment.AwsAccountId, environment.region) if not groupUri: if environment.SamlGroupName in groups: aws_session = SessionHelper.get_session( @@ -686,7 +687,7 @@ def get_pivot_role_name(context: Context, source, organizationUri=None): resource_uri=organizationUri, permission_name=GET_ORGANIZATION, ) - pivot_role_name = SessionHelper.get_delegation_role_name() + pivot_role_name = SessionHelper.get_delegation_role_name(region='') if not pivot_role_name: raise exceptions.AWSResourceNotFound( action='GET_PIVOT_ROLE_NAME', diff --git a/backend/dataall/core/environment/cdk/environment_stack.py b/backend/dataall/core/environment/cdk/environment_stack.py index 6fd3211d6..1d428255f 100644 --- a/backend/dataall/core/environment/cdk/environment_stack.py +++ b/backend/dataall/core/environment/cdk/environment_stack.py @@ -118,7 +118,6 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): ) # Read input self.target_uri = target_uri - self.pivot_role_name = SessionHelper.get_delegation_role_name() self.external_id = SessionHelper.get_external_id_secret() self.dataall_central_account = SessionHelper.get_account() @@ -130,6 +129,7 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): self.engine = self.get_engine() self._environment = self.get_target(target_uri=target_uri) + self.pivot_role_name = SessionHelper.get_delegation_role_name(region=self._environment.region) self.environment_groups: [EnvironmentGroup] = self.get_environment_groups( self.engine, environment=self._environment @@ -438,6 +438,7 @@ def create_group_environment_role(self, group: EnvironmentGroup, id: str): resource_prefix=self._environment.resourcePrefix, role_name=group.environmentIAMRoleName, account=self._environment.AwsAccountId, + region=self._environment.region, ) try: managed_policies = policy_manager.get_all_policies() diff --git a/backend/dataall/core/environment/cdk/pivot_role_stack.py b/backend/dataall/core/environment/cdk/pivot_role_stack.py index 6185b9a97..33965e659 100644 --- a/backend/dataall/core/environment/cdk/pivot_role_stack.py +++ b/backend/dataall/core/environment/cdk/pivot_role_stack.py @@ -40,7 +40,7 @@ def generate_policies(self) -> List[iam.ManagedPolicy]: iam.ManagedPolicy( self.stack, f'PivotRolePolicy-{index + 1}', - managed_policy_name=f'{self.env_resource_prefix}-pivot-role-cdk-policy-{index + 1}', + managed_policy_name=f'{self.env_resource_prefix}-pivot-role-cdk-policy-{self.region}-{index + 1}', statements=chunk, ) ) diff --git a/backend/dataall/core/environment/services/environment_service.py b/backend/dataall/core/environment/services/environment_service.py index a62e273fc..647a8ef2e 100644 --- a/backend/dataall/core/environment/services/environment_service.py +++ b/backend/dataall/core/environment/services/environment_service.py @@ -242,6 +242,7 @@ def invite_group(session, uri, data=None) -> (Environment, EnvironmentGroup): role_name=env_group_iam_role_name, environmentUri=environment.environmentUri, account=environment.AwsAccountId, + region=environment.region, resource_prefix=environment.resourcePrefix, ).create_all_policies(managed=env_role_imported) @@ -335,6 +336,7 @@ def remove_group(session, uri, group): role_name=group_membership.environmentIAMRoleName, environmentUri=environment.environmentUri, account=environment.AwsAccountId, + region=environment.region, resource_prefix=environment.resourcePrefix, ).delete_all_policies() @@ -444,6 +446,7 @@ def add_consumption_role(session, uri, data=None) -> (Environment, EnvironmentGr role_name=consumption_role.IAMRoleName, environmentUri=environment.environmentUri, account=environment.AwsAccountId, + region=environment.region, resource_prefix=environment.resourcePrefix, ).create_all_policies(managed=consumption_role.dataallManaged) @@ -478,6 +481,7 @@ def remove_consumption_role(session, uri, env_uri): role_name=consumption_role.IAMRoleName, environmentUri=environment.environmentUri, account=environment.AwsAccountId, + region=environment.region, resource_prefix=environment.resourcePrefix, ).delete_all_policies() @@ -935,6 +939,7 @@ def delete_environment(session, uri, environment): role_name=environment.EnvironmentDefaultIAMRoleName, environmentUri=environment.environmentUri, account=environment.AwsAccountId, + region=environment.region, resource_prefix=environment.resourcePrefix, ).delete_all_policies() diff --git a/backend/dataall/core/environment/services/managed_iam_policies.py b/backend/dataall/core/environment/services/managed_iam_policies.py index 872391673..e55505267 100644 --- a/backend/dataall/core/environment/services/managed_iam_policies.py +++ b/backend/dataall/core/environment/services/managed_iam_policies.py @@ -13,9 +13,10 @@ class ManagedPolicy(ABC): """ @abstractmethod - def __init__(self, role_name, account, environmentUri, resource_prefix): + def __init__(self, role_name, account, region, environmentUri, resource_prefix): self.role_name = role_name self.account = account + self.region = region self.environmentUri = environmentUri self.resource_prefix = resource_prefix @@ -43,17 +44,17 @@ def generate_empty_policy(self) -> dict: def check_if_policy_exists(self) -> bool: policy_name = self.generate_policy_name() - share_policy = IAM.get_managed_policy_by_name(self.account, policy_name) + share_policy = IAM.get_managed_policy_by_name(self.account, self.region, policy_name) return share_policy is not None def check_if_policy_attached(self): policy_name = self.generate_policy_name() - return IAM.is_policy_attached(self.account, policy_name, self.role_name) + return IAM.is_policy_attached(self.account, self.region, policy_name, self.role_name) def attach_policy(self): policy_arn = f'arn:aws:iam::{self.account}:policy/{self.generate_policy_name()}' try: - IAM.attach_role_policy(self.account, self.role_name, policy_arn) + IAM.attach_role_policy(self.account, self.region, self.role_name, policy_arn) except Exception as e: raise Exception(f"Required customer managed policy {policy_arn} can't be attached: {e}") @@ -63,11 +64,13 @@ def __init__( self, role_name, account, + region, environmentUri, resource_prefix, ): self.role_name = role_name self.account = account + self.region = region self.environmentUri = environmentUri self.resource_prefix = resource_prefix self.ManagedPolicies = ManagedPolicy.__subclasses__() @@ -78,6 +81,7 @@ def _initialize_policy(self, managedPolicy): return managedPolicy( role_name=self.role_name, account=self.account, + region=self.region, environmentUri=self.environmentUri, resource_prefix=self.resource_prefix, ) @@ -94,12 +98,16 @@ def create_all_policies(self, managed) -> bool: logger.info(f'Creating policy {policy_name}') IAM.create_managed_policy( - account_id=self.account, policy_name=policy_name, policy=json.dumps(empty_policy) + account_id=self.account, + region=self.region, + policy_name=policy_name, + policy=json.dumps(empty_policy), ) if managed: IAM.attach_role_policy( account_id=self.account, + region=self.region, role_name=self.role_name, policy_arn=f'arn:aws:iam::{self.account}:policy/{policy_name}', ) @@ -118,11 +126,15 @@ def delete_all_policies(self) -> bool: logger.info(f'Deleting policy {policy_name}') if Policy.check_if_policy_attached(): IAM.detach_policy_from_role( - account_id=self.account, role_name=self.role_name, policy_name=policy_name + account_id=self.account, region=self.region, role_name=self.role_name, policy_name=policy_name ) if Policy.check_if_policy_exists(): - IAM.delete_managed_policy_non_default_versions(account_id=self.account, policy_name=policy_name) - IAM.delete_managed_policy_by_name(account_id=self.account, policy_name=policy_name) + IAM.delete_managed_policy_non_default_versions( + account_id=self.account, region=self.region, policy_name=policy_name + ) + IAM.delete_managed_policy_by_name( + account_id=self.account, region=self.region, policy_name=policy_name + ) except Exception as e: raise e return True diff --git a/backend/dataall/core/stacks/aws/cloudformation.py b/backend/dataall/core/stacks/aws/cloudformation.py index 4f113819e..b9864ff8f 100644 --- a/backend/dataall/core/stacks/aws/cloudformation.py +++ b/backend/dataall/core/stacks/aws/cloudformation.py @@ -17,7 +17,7 @@ def __init__(self): @staticmethod def client(AwsAccountId, region, role=None): - session = SessionHelper.remote_session(accountid=AwsAccountId, role=role) + session = SessionHelper.remote_session(accountid=AwsAccountId, region=region, role=role) return session.client('cloudformation', region_name=region) @staticmethod @@ -46,7 +46,7 @@ def delete_cloudformation_stack(**data): region = data['region'] stack_name = data['stack_name'] try: - aws_session = SessionHelper.remote_session(accountid=accountid) + aws_session = SessionHelper.remote_session(accountid=accountid, region=region) cfnclient = aws_session.client('cloudformation', region_name=region) response = cfnclient.delete_stack( StackName=stack_name, @@ -63,7 +63,7 @@ def _get_stack(**data) -> dict: accountid = data['accountid'] region = data['region'] stack_name = data['stack_name'] - aws_session = SessionHelper.remote_session(accountid=accountid) + aws_session = SessionHelper.remote_session(accountid=accountid, region=region) cfnclient = aws_session.client('cloudformation', region_name=region) response = cfnclient.describe_stacks(StackName=stack_name) return response['Stacks'][0] @@ -137,7 +137,7 @@ def _describe_stack_resources(**data): accountid = data['accountid'] region = data.get('region', 'eu-west-1') stack_name = data['stack_name'] - aws_session = SessionHelper.remote_session(accountid=accountid) + aws_session = SessionHelper.remote_session(accountid=accountid, region=region) client = aws_session.client('cloudformation', region_name=region) try: stack_resources = client.describe_stack_resources(StackName=stack_name) @@ -151,7 +151,7 @@ def _describe_stack_events(**data): accountid = data['accountid'] region = data.get('region', 'eu-west-1') stack_name = data['stack_name'] - aws_session = SessionHelper.remote_session(accountid=accountid) + aws_session = SessionHelper.remote_session(accountid=accountid, region=region) client = aws_session.client('cloudformation', region_name=region) try: stack_events = client.describe_stack_events(StackName=stack_name) diff --git a/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py b/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py index 5cc271a01..bf7380e30 100644 --- a/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py +++ b/backend/dataall/modules/dashboards/aws/dashboard_quicksight_client.py @@ -19,10 +19,14 @@ def __init__(self, username, aws_account_id, region='eu-west-1'): self._account_id = aws_account_id self._region = region self._username = username - self._client = QuicksightClient.get_quicksight_client(aws_account_id, region) + self._client = QuicksightClient.get_quicksight_client( + AwsAccountId=aws_account_id, region=region, session_region=region + ) def register_user_in_group(self, group_name, user_role='READER'): - identity_region_client = QuicksightClient.get_quicksight_client_in_identity_region(self._account_id) + identity_region_client = QuicksightClient.get_quicksight_client_in_identity_region( + self._account_id, self._region + ) QuicksightClient.create_quicksight_group( AwsAccountId=self._account_id, region=self._region, GroupName=group_name ) @@ -80,7 +84,7 @@ def get_reader_session(self, user_role='READER', dashboard_id=None, domain_name: def get_shared_reader_session(self, group_name, user_role='READER', dashboard_id=None): aws_account_id = self._account_id - identity_region = QuicksightClient.get_identity_region(aws_account_id) + identity_region = QuicksightClient.get_identity_region(aws_account_id, self._region) group_principal = f'arn:aws:quicksight:{identity_region}:{aws_account_id}:group/default/{group_name}' user = self.register_user_in_group(group_name, user_role) @@ -239,7 +243,7 @@ def _check_dashboard_permissions(self, dashboard_id): return read_principals, write_principals def _list_user_groups(self): - client = QuicksightClient.get_quicksight_client_in_identity_region(self._account_id) + client = QuicksightClient.get_quicksight_client_in_identity_region(self._account_id, self._region) user = self._describe_user() if not user: return [] @@ -248,7 +252,7 @@ def _list_user_groups(self): def _describe_user(self): """Describes a QS user, returns None if not found""" - client = QuicksightClient.get_quicksight_client_in_identity_region(self._account_id) + client = QuicksightClient.get_quicksight_client_in_identity_region(self._account_id, self._region) try: response = client.describe_user(UserName=self._username, AwsAccountId=self._account_id, Namespace='default') except ClientError: diff --git a/backend/dataall/modules/datapipelines/aws/codecommit_datapipeline_client.py b/backend/dataall/modules/datapipelines/aws/codecommit_datapipeline_client.py index 34c1c040d..906c0a82e 100644 --- a/backend/dataall/modules/datapipelines/aws/codecommit_datapipeline_client.py +++ b/backend/dataall/modules/datapipelines/aws/codecommit_datapipeline_client.py @@ -3,7 +3,7 @@ class DatapipelineCodecommitClient: def __init__(self, aws_account_id, region) -> None: - self._session = SessionHelper.remote_session(aws_account_id) + self._session = SessionHelper.remote_session(aws_account_id, region) self._client = self._session.client('codecommit', region_name=region) def delete_repository(self, repository): diff --git a/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_cli_wrapper_extension.py b/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_cli_wrapper_extension.py index 8ccfe26a6..8d7befe5b 100644 --- a/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_cli_wrapper_extension.py +++ b/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_cli_wrapper_extension.py @@ -27,7 +27,7 @@ def extend_deployment(self, stack, session, env): update_stack_output(session, stack) return True, path, app_path - aws = SessionHelper.remote_session(stack.accountid) + aws = SessionHelper.remote_session(stack.accountid, stack.region) creds = aws.get_credentials() env.update( { diff --git a/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_pipeline.py b/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_pipeline.py index eda876c38..b0c6ce698 100644 --- a/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_pipeline.py +++ b/backend/dataall/modules/datapipelines/cdk/datapipelines_cdk_pipeline.py @@ -287,7 +287,7 @@ def _check_repository(codecommit_client, repo_name): @staticmethod def _set_env_vars(pipeline_environment): - aws = SessionHelper.remote_session(pipeline_environment.AwsAccountId) + aws = SessionHelper.remote_session(pipeline_environment.AwsAccountId, pipeline_environment.region) env_creds = aws.get_credentials() python_path = '/:'.join(sys.path)[1:] + ':/code' + os.getenv('PATH') diff --git a/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py b/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py index ce4efa754..d02c55dfc 100644 --- a/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py +++ b/backend/dataall/modules/datapipelines/cdk/datapipelines_pipeline.py @@ -522,7 +522,7 @@ def initialize_repo(pipeline, code_dir_path, env_vars): @staticmethod def _set_env_vars(pipeline_environment): - aws = SessionHelper.remote_session(pipeline_environment.AwsAccountId) + aws = SessionHelper.remote_session(pipeline_environment.AwsAccountId, pipeline_environment.region) env_creds = aws.get_credentials() env = { diff --git a/backend/dataall/modules/datapipelines/services/datapipelines_service.py b/backend/dataall/modules/datapipelines/services/datapipelines_service.py index aa0504d9e..fcc18c3c0 100644 --- a/backend/dataall/modules/datapipelines/services/datapipelines_service.py +++ b/backend/dataall/modules/datapipelines/services/datapipelines_service.py @@ -272,12 +272,12 @@ def get_credentials(uri): aws_account_id = pipeline.AwsAccountId return DataPipelineService._get_credentials_from_aws( - env_role_arn=env_role_arn, aws_account_id=aws_account_id + env_role_arn=env_role_arn, aws_account_id=aws_account_id, region=pipeline.region ) @staticmethod - def _get_credentials_from_aws(env_role_arn, aws_account_id): - aws_session = SessionHelper.remote_session(aws_account_id) + def _get_credentials_from_aws(env_role_arn, aws_account_id, region): + aws_session = SessionHelper.remote_session(aws_account_id, region) env_session = SessionHelper.get_session(aws_session, role_arn=env_role_arn) c = env_session.get_credentials() body = json.dumps( diff --git a/backend/dataall/modules/dataset_sharing/aws/glue_client.py b/backend/dataall/modules/dataset_sharing/aws/glue_client.py index 32e8a64a7..61795b199 100644 --- a/backend/dataall/modules/dataset_sharing/aws/glue_client.py +++ b/backend/dataall/modules/dataset_sharing/aws/glue_client.py @@ -9,7 +9,7 @@ class GlueClient: def __init__(self, account_id, region, database): - aws_session = SessionHelper.remote_session(accountid=account_id) + aws_session = SessionHelper.remote_session(accountid=account_id, region=region) self._client = aws_session.client('glue', region_name=region) self._database = database self._account_id = account_id diff --git a/backend/dataall/modules/dataset_sharing/aws/kms_client.py b/backend/dataall/modules/dataset_sharing/aws/kms_client.py index 5b2506ebe..ff0e0aad3 100644 --- a/backend/dataall/modules/dataset_sharing/aws/kms_client.py +++ b/backend/dataall/modules/dataset_sharing/aws/kms_client.py @@ -10,7 +10,7 @@ class KmsClient: _DEFAULT_POLICY_NAME = 'default' def __init__(self, account_id: str, region: str): - session = SessionHelper.remote_session(accountid=account_id) + session = SessionHelper.remote_session(accountid=account_id, region=region) self._client = session.client('kms', region_name=region) self._account_id = account_id diff --git a/backend/dataall/modules/dataset_sharing/aws/lakeformation_client.py b/backend/dataall/modules/dataset_sharing/aws/lakeformation_client.py index 9b3f17c37..5214d57a5 100644 --- a/backend/dataall/modules/dataset_sharing/aws/lakeformation_client.py +++ b/backend/dataall/modules/dataset_sharing/aws/lakeformation_client.py @@ -11,7 +11,7 @@ class LakeFormationClient: def __init__(self, account_id, region): - self._session = SessionHelper.remote_session(accountid=account_id) + self._session = SessionHelper.remote_session(accountid=account_id, region=region) self._client = self._session.client('lakeformation', region_name=region) def grant_permissions_to_database( diff --git a/backend/dataall/modules/dataset_sharing/aws/ram_client.py b/backend/dataall/modules/dataset_sharing/aws/ram_client.py index 84b7b0915..01d6a8f71 100644 --- a/backend/dataall/modules/dataset_sharing/aws/ram_client.py +++ b/backend/dataall/modules/dataset_sharing/aws/ram_client.py @@ -10,7 +10,7 @@ class RamClient: def __init__(self, account_id, region): - session = SessionHelper.remote_session(accountid=account_id) + session = SessionHelper.remote_session(accountid=account_id, region=region) self._client = session.client('ram', region_name=region) self._account_id = account_id diff --git a/backend/dataall/modules/dataset_sharing/aws/s3_client.py b/backend/dataall/modules/dataset_sharing/aws/s3_client.py index 637c4feae..bc6880492 100755 --- a/backend/dataall/modules/dataset_sharing/aws/s3_client.py +++ b/backend/dataall/modules/dataset_sharing/aws/s3_client.py @@ -7,7 +7,7 @@ class S3ControlClient: def __init__(self, account_id: str, region: str): - session = SessionHelper.remote_session(accountid=account_id) + session = SessionHelper.remote_session(accountid=account_id, region=region) self._client = session.client('s3control', region_name=region) self._account_id = account_id @@ -114,7 +114,7 @@ def generate_default_bucket_policy(s3_bucket_name: str): class S3Client: def __init__(self, account_id, region): - session = SessionHelper.remote_session(accountid=account_id) + session = SessionHelper.remote_session(accountid=account_id, region=region) self._client = session.client('s3', region_name=region) self._account_id = account_id diff --git a/backend/dataall/modules/dataset_sharing/services/managed_share_policy_service.py b/backend/dataall/modules/dataset_sharing/services/managed_share_policy_service.py index 6d920df40..8bce96207 100644 --- a/backend/dataall/modules/dataset_sharing/services/managed_share_policy_service.py +++ b/backend/dataall/modules/dataset_sharing/services/managed_share_policy_service.py @@ -19,9 +19,10 @@ class SharePolicyService(ManagedPolicy): - def __init__(self, role_name, account, environmentUri, resource_prefix): + def __init__(self, role_name, account, region, environmentUri, resource_prefix): self.role_name = role_name self.account = account + self.region = region self.environmentUri = environmentUri self.resource_prefix = resource_prefix @@ -141,7 +142,7 @@ def create_managed_policy_from_inline_and_delete_inline(self): policy_document = self._generate_managed_policy_from_inline_policies() log.info(f'Creating policy from inline backwards compatibility. Policy = {str(policy_document)}') policy_arn = IAM.create_managed_policy( - self.account, self.generate_policy_name(), json.dumps(policy_document) + self.account, self.region, self.generate_policy_name(), json.dumps(policy_document) ) # Delete obsolete inline policies @@ -189,7 +190,7 @@ def _get_policy_resources_from_inline_policy(self, policy_name): # This function can only be used for backwards compatibility where policies had statement[0] for s3 # and statement[1] for KMS permissions try: - existing_policy = IAM.get_role_policy(self.account, self.role_name, policy_name) + existing_policy = IAM.get_role_policy(self.account, self.region, self.role_name, policy_name) if existing_policy is not None: kms_resources = ( existing_policy['Statement'][1]['Resource'] if len(existing_policy['Statement']) > 1 else [] @@ -225,10 +226,10 @@ def _update_policy_resources_from_inline_policy(self, policy, statement_sid, exi def _delete_old_inline_policies(self): for policy_name in [OLD_IAM_S3BUCKET_ROLE_POLICY, OLD_IAM_ACCESS_POINT_ROLE_POLICY]: try: - existing_policy = IAM.get_role_policy(self.account, self.role_name, policy_name) + existing_policy = IAM.get_role_policy(self.account, self.region, self.role_name, policy_name) if existing_policy is not None: log.info(f'Deleting inline policy {policy_name}...') - IAM.delete_role_policy(self.account, self.role_name, policy_name) + IAM.delete_role_policy(self.account, self.region, self.role_name, policy_name) else: pass except Exception as e: diff --git a/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py b/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py index e733cc999..4dee75ae1 100644 --- a/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py +++ b/backend/dataall/modules/dataset_sharing/services/share_managers/lf_share_manager.py @@ -90,7 +90,9 @@ def get_share_principals(self) -> [str]: :return: List of principals' arns """ principal_iam_role_arn = IAM.get_role_arn_by_name( - account_id=self.target_environment.AwsAccountId, role_name=self.share.principalIAMRoleName + account_id=self.target_environment.AwsAccountId, + region=self.target_environment.region, + role_name=self.share.principalIAMRoleName, ) principals = [principal_iam_role_arn] dashboard_enabled = EnvironmentService.get_boolean_env_param( @@ -210,7 +212,7 @@ def grant_pivot_role_all_database_permissions_to_source_database(self) -> True: :return: True if it is successful """ self.lf_client_in_source.grant_permissions_to_database( - principals=[SessionHelper.get_delegation_role_arn(self.source_account_id)], + principals=[SessionHelper.get_delegation_role_arn(self.source_account_id, self.source_account_region)], database_name=self.source_database_name, permissions=['ALL'], ) @@ -243,7 +245,11 @@ def grant_pivot_role_all_database_permissions_to_shared_database(self) -> True: :return: True if it is successful """ self.lf_client_in_target.grant_permissions_to_database( - principals=[SessionHelper.get_delegation_role_arn(self.target_environment.AwsAccountId)], + principals=[ + SessionHelper.get_delegation_role_arn( + self.target_environment.AwsAccountId, self.target_environment.region + ) + ], database_name=self.shared_db_name, permissions=['ALL'], ) @@ -254,7 +260,7 @@ def check_pivot_role_permissions_to_source_database(self) -> None: Checks 'ALL' Lake Formation permissions to data.all PivotRole to the source database in source account :return: True if the permissions exists and are applied """ - principal = SessionHelper.get_delegation_role_arn(self.source_account_id) + principal = SessionHelper.get_delegation_role_arn(self.source_account_id, self.source_account_region) return self.lf_client_in_source.check_permissions_to_database( principals=[principal], database_name=self.source_database_name, @@ -266,7 +272,9 @@ def check_pivot_role_permissions_to_shared_database(self) -> None: Checks 'ALL' Lake Formation permissions to data.all PivotRole to the shared database in target account :return: True if the permissions exists and are applied """ - principal = SessionHelper.get_delegation_role_arn(self.target_environment.AwsAccountId) + principal = SessionHelper.get_delegation_role_arn( + self.target_environment.AwsAccountId, self.target_environment.region + ) return self.lf_client_in_target.check_permissions_to_database( principals=[principal], database_name=self.shared_db_name, @@ -322,7 +330,11 @@ def grant_pivot_role_drop_permissions_to_resource_link_table(self, table: Datase :return: True if it is successful """ self.lf_client_in_target.grant_permissions_to_table( - principals=[SessionHelper.get_delegation_role_arn(self.target_environment.AwsAccountId)], + principals=[ + SessionHelper.get_delegation_role_arn( + self.target_environment.AwsAccountId, self.target_environment.region + ) + ], database_name=self.shared_db_name, table_name=table.GlueTableName, catalog_id=self.target_environment.AwsAccountId, @@ -625,7 +637,7 @@ def _verify_catalog_ownership(self, catalog_account_id, catalog_region, catalog_ f'Database {self.dataset.GlueDatabaseName} is a resource link and ' f'the source database {catalog_database} belongs to a catalog account {catalog_account_id}' ) - if SessionHelper.is_assumable_pivot_role(catalog_account_id): + if SessionHelper.is_assumable_pivot_role(catalog_account_id, catalog_region): self._validate_catalog_ownership_tag(catalog_account_id, catalog_region, catalog_database) else: raise Exception(f'Pivot role is not assumable, catalog account {catalog_account_id} is not onboarded') diff --git a/backend/dataall/modules/dataset_sharing/services/share_managers/s3_access_point_share_manager.py b/backend/dataall/modules/dataset_sharing/services/share_managers/s3_access_point_share_manager.py index 51007e814..a2913f5f4 100644 --- a/backend/dataall/modules/dataset_sharing/services/share_managers/s3_access_point_share_manager.py +++ b/backend/dataall/modules/dataset_sharing/services/share_managers/s3_access_point_share_manager.py @@ -175,6 +175,7 @@ def check_target_role_access_policy(self) -> None: share_policy_service = SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, + region=self.target_environment.region, role_name=self.target_requester_IAMRoleName, resource_prefix=self.target_environment.resourcePrefix, ) @@ -202,7 +203,7 @@ def check_target_role_access_policy(self) -> None: ] version_id, policy_document = IAM.get_managed_policy_default_version( - self.target_environment.AwsAccountId, share_resource_policy_name + self.target_environment.AwsAccountId, self.target_environment.region, share_resource_policy_name ) s3_statement_index = SharePolicyService._get_statement_by_sid( @@ -282,6 +283,7 @@ def grant_target_role_access_policy(self): share_policy_service = SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, + region=self.target_environment.region, role_name=self.target_requester_IAMRoleName, resource_prefix=self.target_environment.resourcePrefix, ) @@ -306,7 +308,7 @@ def grant_target_role_access_policy(self): share_resource_policy_name = share_policy_service.generate_policy_name() version_id, policy_document = IAM.get_managed_policy_default_version( - self.target_account_id, share_resource_policy_name + self.target_account_id, self.target_environment.region, share_resource_policy_name ) key_alias = f'alias/{self.dataset.KmsAlias}' @@ -339,7 +341,11 @@ def grant_target_role_access_policy(self): ) IAM.update_managed_policy_default_version( - self.target_account_id, share_resource_policy_name, version_id, json.dumps(policy_document) + self.target_account_id, + self.target_environment.region, + share_resource_policy_name, + version_id, + json.dumps(policy_document), ) def check_access_point_and_policy(self) -> None: @@ -361,7 +367,9 @@ def check_access_point_and_policy(self) -> None: existing_policy = json.loads(existing_policy) statements = {item['Sid']: item for item in existing_policy['Statement']} - target_requester_id = SessionHelper.get_role_id(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_id = SessionHelper.get_role_id( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) error = False if f'{target_requester_id}0' not in statements.keys(): error = True @@ -425,7 +433,9 @@ def manage_access_point_and_policy(self): retries += 1 existing_policy = s3_client.get_access_point_policy(self.access_point_name) # requester will use this role to access resources - target_requester_id = SessionHelper.get_role_id(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_id = SessionHelper.get_role_id( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) if existing_policy: # Update existing access point policy logger.info( @@ -482,7 +492,9 @@ def check_dataset_bucket_key_policy(self) -> None: self.folder_errors.append(ShareErrorFormatter.dne_error_msg('KMS Key Policy', kms_key_id)) return - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) existing_policy = json.loads(existing_policy) counter = count() statements = {item.get('Sid', next(counter)): item for item in existing_policy.get('Statement', {})} @@ -510,8 +522,10 @@ def update_dataset_bucket_key_policy(self): kms_client = KmsClient(self.source_account_id, self.source_environment.region) kms_key_id = kms_client.get_key_id(key_alias) existing_policy = kms_client.get_key_policy(kms_key_id) - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) - pivot_role_name = SessionHelper.get_delegation_role_name() + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) + pivot_role_name = SessionHelper.get_delegation_role_name(self.source_environment.region) if existing_policy: existing_policy = json.loads(existing_policy) @@ -519,16 +533,17 @@ def update_dataset_bucket_key_policy(self): statements = {item.get('Sid', next(counter)): item for item in existing_policy.get('Statement', {})} if DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID in statements.keys(): - logger.info(f'KMS key policy already contains share statement {DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID}') + logger.info( + f'KMS key policy already contains share statement {DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID}, updating existing statement' + ) + else: logger.info( f'KMS key policy does not contain statement {DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID}, generating a new one' ) - statements[DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID] = ( - self.generate_enable_pivot_role_permissions_policy_statement( - pivot_role_name, self.dataset_account_id - ) - ) + statements[DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID] = ( + self.generate_enable_pivot_role_permissions_policy_statement(pivot_role_name, self.dataset_account_id) + ) if DATAALL_ACCESS_POINT_KMS_DECRYPT_SID in statements.keys(): logger.info( @@ -566,7 +581,9 @@ def revoke_access_in_access_point_policy(self): s3_client = S3ControlClient(self.source_account_id, self.source_environment.region) access_point_policy = json.loads(s3_client.get_access_point_policy(self.access_point_name)) access_point_arn = s3_client.get_bucket_access_point_arn(self.access_point_name) - target_requester_id = SessionHelper.get_role_id(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_id = SessionHelper.get_role_id( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) statements = {item['Sid']: item for item in access_point_policy['Statement']} if f'{target_requester_id}0' in statements.keys(): prefix_list = statements[f'{target_requester_id}0']['Condition']['StringLike']['s3:prefix'] @@ -610,6 +627,7 @@ def revoke_target_role_access_policy(self): share_policy_service = SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, + region=self.target_environment.region, role_name=self.target_requester_IAMRoleName, resource_prefix=self.target_environment.resourcePrefix, ) @@ -625,7 +643,7 @@ def revoke_target_role_access_policy(self): share_resource_policy_name = share_policy_service.generate_policy_name() version_id, policy_document = IAM.get_managed_policy_default_version( - self.target_account_id, share_resource_policy_name + self.target_account_id, self.target_environment.region, share_resource_policy_name ) key_alias = f'alias/{self.dataset.KmsAlias}' @@ -652,7 +670,11 @@ def revoke_target_role_access_policy(self): policy_document=policy_document, ) IAM.update_managed_policy_default_version( - self.target_account_id, share_resource_policy_name, version_id, json.dumps(policy_document) + self.target_account_id, + self.target_environment.region, + share_resource_policy_name, + version_id, + json.dumps(policy_document), ) def delete_dataset_bucket_key_policy( @@ -664,7 +686,9 @@ def delete_dataset_bucket_key_policy( kms_client = KmsClient(dataset.AwsAccountId, dataset.region) kms_key_id = kms_client.get_key_id(key_alias) existing_policy = json.loads(kms_client.get_key_policy(kms_key_id)) - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) counter = count() statements = {item.get('Sid', next(counter)): item for item in existing_policy.get('Statement', {})} if DATAALL_ACCESS_POINT_KMS_DECRYPT_SID in statements.keys(): @@ -738,6 +762,8 @@ def generate_enable_pivot_role_permissions_policy_statement(pivot_role_name, dat 'kms:ReEncrypt*', 'kms:TagResource', 'kms:UntagResource', + 'kms:DescribeKey', + 'kms:List*', ], 'Resource': '*', } diff --git a/backend/dataall/modules/dataset_sharing/services/share_managers/s3_bucket_share_manager.py b/backend/dataall/modules/dataset_sharing/services/share_managers/s3_bucket_share_manager.py index 7674931b9..7380da65b 100644 --- a/backend/dataall/modules/dataset_sharing/services/share_managers/s3_bucket_share_manager.py +++ b/backend/dataall/modules/dataset_sharing/services/share_managers/s3_bucket_share_manager.py @@ -89,6 +89,7 @@ def check_s3_iam_access(self) -> None: share_policy_service = SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, + region=self.target_environment.region, role_name=self.target_requester_IAMRoleName, resource_prefix=self.target_environment.resourcePrefix, ) @@ -111,7 +112,7 @@ def check_s3_iam_access(self) -> None: s3_target_resources = [f'arn:aws:s3:::{self.bucket_name}', f'arn:aws:s3:::{self.bucket_name}/*'] version_id, policy_document = IAM.get_managed_policy_default_version( - self.target_environment.AwsAccountId, share_resource_policy_name + self.target_environment.AwsAccountId, self.target_environment.region, share_resource_policy_name ) s3_statement_index = SharePolicyService._get_statement_by_sid( policy_document, f'{IAM_S3_BUCKETS_STATEMENT_SID}S3' @@ -190,6 +191,7 @@ def grant_s3_iam_access(self): share_policy_service = SharePolicyService( environmentUri=self.target_environment.environmentUri, account=self.target_environment.AwsAccountId, + region=self.target_environment.region, role_name=self.target_requester_IAMRoleName, resource_prefix=self.target_environment.resourcePrefix, ) @@ -216,7 +218,7 @@ def grant_s3_iam_access(self): logger.info(f'Share policy name is {share_resource_policy_name}') version_id, policy_document = IAM.get_managed_policy_default_version( - self.target_account_id, share_resource_policy_name + self.target_account_id, self.target_environment.region, share_resource_policy_name ) key_alias = f'alias/{self.target_bucket.KmsAlias}' @@ -244,7 +246,11 @@ def grant_s3_iam_access(self): ) IAM.update_managed_policy_default_version( - self.target_account_id, share_resource_policy_name, version_id, json.dumps(policy_document) + self.target_account_id, + self.target_environment.region, + share_resource_policy_name, + version_id, + json.dumps(policy_document), ) def get_bucket_policy_or_default(self): @@ -270,7 +276,9 @@ def check_role_bucket_policy(self) -> None: and add to bucket errors if check fails :return: None """ - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) s3_client = S3Client(self.source_account_id, self.source_environment.region) bucket_policy = s3_client.get_bucket_policy(self.bucket_name) error = False @@ -300,7 +308,9 @@ def grant_role_bucket_policy(self): """ logger.info(f'Granting access via Bucket policy for {self.bucket_name}') try: - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) bucket_policy = self.get_bucket_policy_or_default() counter = count() statements = {item.get('Sid', next(counter)): item for item in bucket_policy.get('Statement', {})} @@ -353,7 +363,9 @@ def check_dataset_bucket_key_policy(self) -> None: self.bucket_errors.append(ShareErrorFormatter.dne_error_msg('KMS Key Policy', kms_key_id)) return - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) existing_policy = json.loads(existing_policy) counter = count() statements = {item.get('Sid', next(counter)): item for item in existing_policy.get('Statement', {})} @@ -382,8 +394,10 @@ def grant_dataset_bucket_key_policy(self): kms_client = KmsClient(self.source_account_id, self.source_environment.region) kms_key_id = kms_client.get_key_id(key_alias) existing_policy = kms_client.get_key_policy(kms_key_id) - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) - pivot_role_name = SessionHelper.get_delegation_role_name() + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) + pivot_role_name = SessionHelper.get_delegation_role_name(self.source_environment.region) if existing_policy: existing_policy = json.loads(existing_policy) @@ -392,17 +406,18 @@ def grant_dataset_bucket_key_policy(self): if DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID in statements.keys(): logger.info( - f'KMS key policy already contains share statement {DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID}' + f'KMS key policy already contains share statement {DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID}, updating existing statement' ) + else: logger.info( f'KMS key policy does not contain statement {DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID}, generating a new one' ) - statements[DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID] = ( - self.generate_enable_pivot_role_permissions_policy_statement( - pivot_role_name, self.source_account_id - ) + statements[DATAALL_KMS_PIVOT_ROLE_PERMISSIONS_SID] = ( + self.generate_enable_pivot_role_permissions_policy_statement( + pivot_role_name, self.source_account_id ) + ) if DATAALL_BUCKET_KMS_DECRYPT_SID in statements.keys(): logger.info( @@ -438,7 +453,9 @@ def delete_target_role_bucket_policy(self): try: s3_client = S3Client(self.source_account_id, self.source_environment.region) bucket_policy = json.loads(s3_client.get_bucket_policy(self.bucket_name)) - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) counter = count() statements = {item.get('Sid', next(counter)): item for item in bucket_policy.get('Statement', {})} if DATAALL_READ_ONLY_SID in statements.keys(): @@ -466,6 +483,7 @@ def delete_target_role_access_policy( share_policy_service = SharePolicyService( role_name=share.principalIAMRoleName, account=target_environment.AwsAccountId, + region=self.target_environment.region, environmentUri=target_environment.environmentUri, resource_prefix=target_environment.resourcePrefix, ) @@ -479,7 +497,7 @@ def delete_target_role_access_policy( share_resource_policy_name = share_policy_service.generate_policy_name() version_id, policy_document = IAM.get_managed_policy_default_version( - self.target_account_id, share_resource_policy_name + self.target_account_id, self.target_environment.region, share_resource_policy_name ) key_alias = f'alias/{target_bucket.KmsAlias}' @@ -504,7 +522,11 @@ def delete_target_role_access_policy( ) IAM.update_managed_policy_default_version( - self.target_account_id, share_resource_policy_name, version_id, json.dumps(policy_document) + self.target_account_id, + self.target_environment.region, + share_resource_policy_name, + version_id, + json.dumps(policy_document), ) def delete_target_role_bucket_key_policy( @@ -517,7 +539,9 @@ def delete_target_role_bucket_key_policy( kms_client = KmsClient(target_bucket.AwsAccountId, target_bucket.region) kms_key_id = kms_client.get_key_id(key_alias) existing_policy = json.loads(kms_client.get_key_policy(kms_key_id)) - target_requester_arn = IAM.get_role_arn_by_name(self.target_account_id, self.target_requester_IAMRoleName) + target_requester_arn = IAM.get_role_arn_by_name( + self.target_account_id, self.target_environment.region, self.target_requester_IAMRoleName + ) counter = count() statements = {item.get('Sid', next(counter)): item for item in existing_policy.get('Statement', {})} if DATAALL_BUCKET_KMS_DECRYPT_SID in statements.keys(): @@ -602,6 +626,8 @@ def generate_enable_pivot_role_permissions_policy_statement(pivot_role_name, sou 'kms:ReEncrypt*', 'kms:TagResource', 'kms:UntagResource', + 'kms:DescribeKey', + 'kms:List*', ], 'Resource': '*', } diff --git a/backend/dataall/modules/dataset_sharing/services/share_object_service.py b/backend/dataall/modules/dataset_sharing/services/share_object_service.py index bbd00526b..3297a1273 100644 --- a/backend/dataall/modules/dataset_sharing/services/share_object_service.py +++ b/backend/dataall/modules/dataset_sharing/services/share_object_service.py @@ -114,6 +114,7 @@ def create_share_object( share_policy_service = SharePolicyService( account=environment.AwsAccountId, + region=environment.region, role_name=principal_iam_role_name, environmentUri=environment.environmentUri, resource_prefix=environment.resourcePrefix, diff --git a/backend/dataall/modules/datasets/aws/athena_table_client.py b/backend/dataall/modules/datasets/aws/athena_table_client.py index 4181ca9e4..3c5bf01ec 100644 --- a/backend/dataall/modules/datasets/aws/athena_table_client.py +++ b/backend/dataall/modules/datasets/aws/athena_table_client.py @@ -14,7 +14,7 @@ class AthenaTableClient: def __init__(self, env: Environment, table: DatasetTable): - session = SessionHelper.remote_session(accountid=table.AWSAccountId) + session = SessionHelper.remote_session(accountid=table.AWSAccountId, region=env.region) self._client = session.client('athena', region_name=env.region) self._creds = session.get_credentials() self._env = env diff --git a/backend/dataall/modules/datasets/aws/glue_dataset_client.py b/backend/dataall/modules/datasets/aws/glue_dataset_client.py index 93a42a858..0cc589254 100644 --- a/backend/dataall/modules/datasets/aws/glue_dataset_client.py +++ b/backend/dataall/modules/datasets/aws/glue_dataset_client.py @@ -9,9 +9,8 @@ class DatasetCrawler: def __init__(self, dataset: Dataset): - session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) - region = dataset.region if dataset.region else 'eu-west-1' - self._client = session.client('glue', region_name=region) + session = SessionHelper.remote_session(accountid=dataset.AwsAccountId, region=dataset.region) + self._client = session.client('glue', region_name=dataset.region) self._dataset = dataset def get_crawler(self, crawler_name=None): diff --git a/backend/dataall/modules/datasets/aws/glue_profiler_client.py b/backend/dataall/modules/datasets/aws/glue_profiler_client.py index 4b6f3eac3..0526a7137 100644 --- a/backend/dataall/modules/datasets/aws/glue_profiler_client.py +++ b/backend/dataall/modules/datasets/aws/glue_profiler_client.py @@ -12,7 +12,7 @@ class GlueDatasetProfilerClient: """Controls glue profiling jobs in AWS""" def __init__(self, dataset: Dataset): - session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) + session = SessionHelper.remote_session(accountid=dataset.AwsAccountId, region=dataset.region) self._client = session.client('glue', region_name=dataset.region) self._name = dataset.GlueProfilingJobName diff --git a/backend/dataall/modules/datasets/aws/lf_dataset_client.py b/backend/dataall/modules/datasets/aws/lf_dataset_client.py index f10ae963f..13cffc52d 100644 --- a/backend/dataall/modules/datasets/aws/lf_dataset_client.py +++ b/backend/dataall/modules/datasets/aws/lf_dataset_client.py @@ -11,7 +11,7 @@ class LakeFormationDatasetClient: def __init__(self, env: Environment, dataset: Dataset): - session = SessionHelper.remote_session(env.AwsAccountId) + session = SessionHelper.remote_session(env.AwsAccountId, env.region) self._client = session.client('lakeformation', region_name=env.region) self._dataset = dataset self._env = env diff --git a/backend/dataall/modules/datasets/aws/lf_table_client.py b/backend/dataall/modules/datasets/aws/lf_table_client.py index a248d004c..c0daade69 100644 --- a/backend/dataall/modules/datasets/aws/lf_table_client.py +++ b/backend/dataall/modules/datasets/aws/lf_table_client.py @@ -12,7 +12,7 @@ class LakeFormationTableClient: def __init__(self, table: DatasetTable, aws_session=None): if not aws_session: - aws_session = SessionHelper.remote_session(table.AWSAccountId) + aws_session = SessionHelper.remote_session(table.AWSAccountId, table.region) self._client = aws_session.client('lakeformation', region_name=table.region) self._table = table @@ -22,7 +22,7 @@ def grant_pivot_role_all_table_permissions(self): for tables managed inside dataall """ table = self._table - principal = SessionHelper.get_delegation_role_arn(table.AWSAccountId) + principal = SessionHelper.get_delegation_role_arn(table.AWSAccountId, table.region) self._grant_permissions_to_table(principal, ['SELECT', 'ALTER', 'DROP', 'INSERT']) def grant_principals_all_table_permissions(self, principals: [str]): diff --git a/backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py b/backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py index f79ca6d5d..bc449326e 100644 --- a/backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py +++ b/backend/dataall/modules/datasets/aws/s3_dataset_bucket_policy_client.py @@ -11,7 +11,7 @@ class S3DatasetBucketPolicyClient: def __init__(self, dataset: Dataset): - session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) + session = SessionHelper.remote_session(accountid=dataset.AwsAccountId, region=dataset.region) self._client = session.client('s3') self._dataset = dataset diff --git a/backend/dataall/modules/datasets/aws/s3_dataset_client.py b/backend/dataall/modules/datasets/aws/s3_dataset_client.py index aa5ffc98f..dd199aec0 100644 --- a/backend/dataall/modules/datasets/aws/s3_dataset_client.py +++ b/backend/dataall/modules/datasets/aws/s3_dataset_client.py @@ -16,7 +16,7 @@ def __init__(self, dataset: Dataset): It first starts a session assuming the pivot role, then we define another session assuming the dataset role from the pivot role """ - self._pivot_role_session = SessionHelper.remote_session(accountid=dataset.AwsAccountId) + self._pivot_role_session = SessionHelper.remote_session(accountid=dataset.AwsAccountId, region=dataset.region) self._client = self._pivot_role_session.client('s3') self._dataset = dataset diff --git a/backend/dataall/modules/datasets/aws/s3_location_client.py b/backend/dataall/modules/datasets/aws/s3_location_client.py index ef79e8484..3d1041183 100644 --- a/backend/dataall/modules/datasets/aws/s3_location_client.py +++ b/backend/dataall/modules/datasets/aws/s3_location_client.py @@ -12,7 +12,7 @@ def __init__(self, location: DatasetStorageLocation, dataset: Dataset): It first starts a session assuming the pivot role, then we define another session assuming the dataset role from the pivot role """ - pivot_role_session = SessionHelper.remote_session(accountid=location.AWSAccountId) + pivot_role_session = SessionHelper.remote_session(accountid=location.AWSAccountId, region=location.region) session = SessionHelper.get_session(base_session=pivot_role_session, role_arn=dataset.IAMDatasetAdminRoleArn) self._client = session.client('s3', region_name=location.region) self._location = location diff --git a/backend/dataall/modules/datasets/aws/s3_profiler_client.py b/backend/dataall/modules/datasets/aws/s3_profiler_client.py index 210f99f2a..0f0deb7f4 100644 --- a/backend/dataall/modules/datasets/aws/s3_profiler_client.py +++ b/backend/dataall/modules/datasets/aws/s3_profiler_client.py @@ -8,7 +8,7 @@ class S3ProfilerClient: def __init__(self, env: Environment): - self._client = SessionHelper.remote_session(env.AwsAccountId).client('s3', region_name=env.region) + self._client = SessionHelper.remote_session(env.AwsAccountId, env.region).client('s3', region_name=env.region) self._env = env def get_profiling_results_from_s3(self, dataset, table, run): diff --git a/backend/dataall/modules/datasets/aws/sns_dataset_client.py b/backend/dataall/modules/datasets/aws/sns_dataset_client.py index 5114f927b..9aaa9e63b 100644 --- a/backend/dataall/modules/datasets/aws/sns_dataset_client.py +++ b/backend/dataall/modules/datasets/aws/sns_dataset_client.py @@ -12,7 +12,7 @@ class SnsDatasetClient: def __init__(self, environment: Environment, dataset: Dataset): - aws_session = SessionHelper.remote_session(accountid=environment.AwsAccountId) + aws_session = SessionHelper.remote_session(accountid=environment.AwsAccountId, region=environment.region) self._client = aws_session.client('sns', region_name=environment.region) self._topic = ( diff --git a/backend/dataall/modules/datasets/cdk/dataset_stack.py b/backend/dataall/modules/datasets/cdk/dataset_stack.py index dbd19b07e..67cb8809e 100644 --- a/backend/dataall/modules/datasets/cdk/dataset_stack.py +++ b/backend/dataall/modules/datasets/cdk/dataset_stack.py @@ -95,10 +95,10 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): # Read input self.target_uri = target_uri - self.pivot_role_name = SessionHelper.get_delegation_role_name() dataset = self.get_target() env = self.get_env(dataset) env_group = self.get_env_group(dataset) + self.pivot_role_name = SessionHelper.get_delegation_role_name(region=env.region) quicksight_default_group_arn = None if self.has_quicksight_enabled(env): @@ -151,9 +151,7 @@ def __init__(self, scope, id, target_uri: str = None, **kwargs): 'kms:ReEncrypt*', 'kms:TagResource', 'kms:UntagResource', - 'kms:DeleteAlias', 'kms:DescribeKey', - 'kms:CreateAlias', 'kms:List*', ], resources=['*'], diff --git a/backend/dataall/modules/datasets/handlers/glue_table_sync_handler.py b/backend/dataall/modules/datasets/handlers/glue_table_sync_handler.py index 17eb96ba5..9d56df046 100644 --- a/backend/dataall/modules/datasets/handlers/glue_table_sync_handler.py +++ b/backend/dataall/modules/datasets/handlers/glue_table_sync_handler.py @@ -20,7 +20,7 @@ def update_table_columns(engine, task: Task): column: DatasetTableColumn = session.query(DatasetTableColumn).get(task.targetUri) table: DatasetTable = session.query(DatasetTable).get(column.tableUri) - aws_session = SessionHelper.remote_session(table.AWSAccountId) + aws_session = SessionHelper.remote_session(table.AWSAccountId, table.region) lf_client = LakeFormationTableClient(table=table, aws_session=aws_session) lf_client.grant_pivot_role_all_table_permissions() diff --git a/backend/dataall/modules/datasets/services/dataset_column_service.py b/backend/dataall/modules/datasets/services/dataset_column_service.py index 5e5214e83..8a95d2f32 100644 --- a/backend/dataall/modules/datasets/services/dataset_column_service.py +++ b/backend/dataall/modules/datasets/services/dataset_column_service.py @@ -51,7 +51,7 @@ def sync_table_columns(cls, table_uri: str): context = get_context() with context.db_engine.scoped_session() as session: table: DatasetTable = DatasetTableRepository.get_dataset_table_by_uri(session, table_uri) - aws = SessionHelper.remote_session(table.AWSAccountId) + aws = SessionHelper.remote_session(table.AWSAccountId, table.region) glue_table = GlueTableClient(aws, table).get_table() DatasetTableRepository.sync_table_columns(session, table, glue_table['Table']) diff --git a/backend/dataall/modules/datasets/services/dataset_service.py b/backend/dataall/modules/datasets/services/dataset_service.py index 05e601728..f47bf6e71 100644 --- a/backend/dataall/modules/datasets/services/dataset_service.py +++ b/backend/dataall/modules/datasets/services/dataset_service.py @@ -293,11 +293,13 @@ def get_dataset_assume_role_url(uri): ) role_arn = env_group.environmentIAMRoleArn account_id = shared_environment.AwsAccountId + region = shared_environment.region else: role_arn = dataset.IAMDatasetAdminRoleArn account_id = dataset.AwsAccountId + region = dataset.region - pivot_session = SessionHelper.remote_session(account_id) + pivot_session = SessionHelper.remote_session(account_id, region) aws_session = SessionHelper.get_session(base_session=pivot_session, role_arn=role_arn) url = SessionHelper.get_console_access_url( aws_session, @@ -354,7 +356,7 @@ def generate_dataset_access_token(uri): with get_context().db_engine.scoped_session() as session: dataset = DatasetRepository.get_dataset_by_uri(session, uri) - pivot_session = SessionHelper.remote_session(dataset.AwsAccountId) + pivot_session = SessionHelper.remote_session(dataset.AwsAccountId, dataset.region) aws_session = SessionHelper.get_session(base_session=pivot_session, role_arn=dataset.IAMDatasetAdminRoleArn) c = aws_session.get_credentials() credentials = { diff --git a/backend/dataall/modules/datasets/tasks/tables_syncer.py b/backend/dataall/modules/datasets/tasks/tables_syncer.py index 719e56ce7..67032618a 100644 --- a/backend/dataall/modules/datasets/tasks/tables_syncer.py +++ b/backend/dataall/modules/datasets/tasks/tables_syncer.py @@ -60,7 +60,7 @@ def sync_tables(engine): for table in tables: LakeFormationTableClient(table).grant_principals_all_table_permissions( principals=[ - SessionHelper.get_delegation_role_arn(env.AwsAccountId), + SessionHelper.get_delegation_role_arn(env.AwsAccountId, env.region), env_group.environmentIAMRoleArn, ], ) @@ -79,7 +79,7 @@ def sync_tables(engine): def is_assumable_pivot_role(env: Environment): - aws_session = SessionHelper.remote_session(accountid=env.AwsAccountId) + aws_session = SessionHelper.remote_session(accountid=env.AwsAccountId, region=env.region) if not aws_session: log.error(f'Failed to assume dataall pivot role in environment {env.AwsAccountId}') return False diff --git a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py index 98482e825..3c67d47d4 100644 --- a/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py +++ b/backend/dataall/modules/mlstudio/aws/sagemaker_studio_client.py @@ -8,7 +8,7 @@ def get_client(AwsAccountId, region): - session = SessionHelper.remote_session(AwsAccountId) + session = SessionHelper.remote_session(AwsAccountId, region) return session.client('sagemaker', region_name=region) diff --git a/backend/dataall/modules/notebooks/aws/sagemaker_notebook_client.py b/backend/dataall/modules/notebooks/aws/sagemaker_notebook_client.py index 0d674ca35..0988d1f66 100644 --- a/backend/dataall/modules/notebooks/aws/sagemaker_notebook_client.py +++ b/backend/dataall/modules/notebooks/aws/sagemaker_notebook_client.py @@ -13,7 +13,7 @@ class SagemakerClient: """ def __init__(self, notebook: SagemakerNotebook): - session = SessionHelper.remote_session(notebook.AWSAccountId) + session = SessionHelper.remote_session(notebook.AWSAccountId, notebook.region) self._client = session.client('sagemaker', region_name=notebook.region) self._instance_name = notebook.NotebookInstanceName diff --git a/backend/dataall/modules/worksheets/aws/athena_client.py b/backend/dataall/modules/worksheets/aws/athena_client.py index d2808808d..60b552b36 100644 --- a/backend/dataall/modules/worksheets/aws/athena_client.py +++ b/backend/dataall/modules/worksheets/aws/athena_client.py @@ -7,7 +7,7 @@ class AthenaClient: @staticmethod def run_athena_query(aws_account_id, env_group, s3_staging_dir, region, sql=None): - base_session = SessionHelper.remote_session(accountid=aws_account_id) + base_session = SessionHelper.remote_session(accountid=aws_account_id, region=region) boto3_session = SessionHelper.get_session(base_session=base_session, role_arn=env_group.environmentIAMRoleArn) creds = boto3_session.get_credentials() connection = connect( diff --git a/config.json b/config.json index a79a17d0f..140595224 100644 --- a/config.json +++ b/config.json @@ -38,7 +38,8 @@ }, "core": { "features": { - "env_aws_actions": true + "env_aws_actions": true, + "cdk_pivot_role_multiple_environments_same_account": false } } } \ No newline at end of file diff --git a/deploy/stacks/container.py b/deploy/stacks/container.py index 7ded368c0..288ae714d 100644 --- a/deploy/stacks/container.py +++ b/deploy/stacks/container.py @@ -472,7 +472,7 @@ def create_task_role(self, envname, resource_prefix, pivot_role_name): 'sts:AssumeRole', ], resources=[ - f'arn:aws:iam::*:role/{pivot_role_name}', + f'arn:aws:iam::*:role/{pivot_role_name}*', 'arn:aws:iam::*:role/cdk*', f'arn:aws:iam::{self.account}:role/{resource_prefix}-{envname}-ecs-tasks-role', ], diff --git a/deploy/stacks/lambda_api.py b/deploy/stacks/lambda_api.py index 3fce0d215..f2952d732 100644 --- a/deploy/stacks/lambda_api.py +++ b/deploy/stacks/lambda_api.py @@ -316,7 +316,7 @@ def create_function_role(self, envname, resource_prefix, fn_name, pivot_role_nam 'sts:AssumeRole', ], resources=[ - f'arn:aws:iam::*:role/{pivot_role_name}', + f'arn:aws:iam::*:role/{pivot_role_name}*', 'arn:aws:iam::*:role/cdk-hnb659fds-lookup-role-*', ], ), diff --git a/tests/modules/datasets/tasks/test_lf_share_manager.py b/tests/modules/datasets/tasks/test_lf_share_manager.py index 8d5a698a6..46aa422e5 100644 --- a/tests/modules/datasets/tasks/test_lf_share_manager.py +++ b/tests/modules/datasets/tasks/test_lf_share_manager.py @@ -108,7 +108,7 @@ def processor_with_mocks( ) mocker.patch( 'dataall.base.aws.iam.IAM.get_role_arn_by_name', - side_effect=lambda account_id, role_name: f'arn:aws:iam::{account_id}:role/{role_name}', + side_effect=lambda account_id, region, role_name: f'arn:aws:iam::{account_id}:role/{role_name}', ) mock_glue_client().get_glue_database.return_value = False @@ -166,7 +166,7 @@ def test_get_share_principals( processor, lf_client, glue_client, mock_glue_client = processor_with_mocks get_iam_role_arn_mock = mocker.patch( 'dataall.base.aws.iam.IAM.get_role_arn_by_name', - side_effect=lambda account_id, role_name: f'arn:aws:iam::{account_id}:role/{role_name}', + side_effect=lambda account_id, region, role_name: f'arn:aws:iam::{account_id}:role/{role_name}', ) # Then, it should return @@ -754,7 +754,7 @@ def test_check_catalog_account_exists_and_update_processor_with_catalog_exists( ) mocker.patch( 'dataall.base.aws.iam.IAM.get_role_arn_by_name', - side_effect=lambda account_id, role_name: f'arn:aws:iam::{account_id}:role/{role_name}', + side_effect=lambda account_id, region, role_name: f'arn:aws:iam::{account_id}:role/{role_name}', ) mock_glue_client().get_glue_database.return_value = False diff --git a/tests/modules/datasets/tasks/test_s3_access_point_share_manager.py b/tests/modules/datasets/tasks/test_s3_access_point_share_manager.py index 57da4b355..17124433b 100644 --- a/tests/modules/datasets/tasks/test_s3_access_point_share_manager.py +++ b/tests/modules/datasets/tasks/test_s3_access_point_share_manager.py @@ -370,11 +370,16 @@ def test_grant_target_role_access_policy_test_empty_policy( environmentUri=target_environment.environmentUri, role_name=share1.principalIAMRoleName, account=target_environment.AwsAccountId, + region=target_environment.region, resource_prefix=target_environment.resourcePrefix, ).generate_policy_name() # Then iam_update_role_policy_mock.assert_called_with( - target_environment.AwsAccountId, expected_policy_name, 'v1', json.dumps(expected_policy) + target_environment.AwsAccountId, + target_environment.region, + expected_policy_name, + 'v1', + json.dumps(expected_policy), ) @@ -919,11 +924,16 @@ def test_delete_target_role_access_policy_no_remaining_statement( environmentUri=target_environment.environmentUri, role_name=share1.principalIAMRoleName, account=target_environment.AwsAccountId, + region=target_environment.region, resource_prefix=target_environment.resourcePrefix, ).generate_policy_name() iam_update_role_policy_mock.assert_called_with( - target_environment.AwsAccountId, expected_policy_name, 'v1', json.dumps(expected_remaining_target_role_policy) + target_environment.AwsAccountId, + target_environment.region, + expected_policy_name, + 'v1', + json.dumps(expected_remaining_target_role_policy), ) @@ -1013,11 +1023,16 @@ def test_delete_target_role_access_policy_with_remaining_statement( environmentUri=target_environment.environmentUri, role_name=share1.principalIAMRoleName, account=target_environment.AwsAccountId, + region=target_environment.region, resource_prefix=target_environment.resourcePrefix, ).generate_policy_name() iam_update_role_policy_mock.assert_called_with( - target_environment.AwsAccountId, expected_policy_name, 'v1', json.dumps(expected_remaining_target_role_policy) + target_environment.AwsAccountId, + target_environment.region, + expected_policy_name, + 'v1', + json.dumps(expected_remaining_target_role_policy), )