Skip to content

Commit

Permalink
Add Central catalog support
Browse files Browse the repository at this point in the history
  • Loading branch information
blitzmohit committed Dec 7, 2023
1 parent 473a1b6 commit d85b416
Show file tree
Hide file tree
Showing 7 changed files with 213 additions and 77 deletions.
10 changes: 10 additions & 0 deletions backend/dataall/base/aws/sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,3 +359,13 @@ def generate_console_url(credentials, session_duration=None, region='eu-west-1',

# Send final URL to stdout
return request_url

@staticmethod
def is_assumable_pivot_role(accountid):
    """Check whether the data.all pivot role in *accountid* can be assumed.

    Attempts to open a remote session against the account; a session of
    ``None`` means the pivot role could not be assumed.

    :param accountid: AWS account id of the target environment
    :return: True when the pivot role is assumable, False otherwise
    """
    session = SessionHelper.remote_session(accountid=accountid)
    if session is not None:
        return True
    log.error(
        f'Failed to assume dataall pivot role in environment {accountid}'
    )
    return False
37 changes: 37 additions & 0 deletions backend/dataall/modules/dataset_sharing/aws/glue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from botocore.exceptions import ClientError

from dataall.base.aws.sts import SessionHelper
from dataall.modules.dataset_sharing.db.share_object_models import Catalog

log = logging.getLogger(__name__)

Expand All @@ -13,6 +14,7 @@ def __init__(self, account_id, region, database):
self._client = aws_session.client('glue', region_name=region)
self._database = database
self._account_id = account_id
self._region = region

def create_database(self, location):
try:
Expand Down Expand Up @@ -130,3 +132,38 @@ def delete_database(self):
f'due to: {e}'
)
raise e

def get_source_catalog(self):
    """Return the source catalog details when this database is a resource link.

    Calls ``glue:GetDatabase`` and inspects the ``TargetDatabase`` field of the
    response. When the database is a resource link pointing at another
    account's catalog, returns a ``Catalog`` describing that owning
    account/database/region; otherwise returns ``None``.

    Raises: any Glue client error, re-raised after being logged.
    """
    try:
        log.info(f'Fetching source catalog details for database {self._database}...')
        response = self._client.get_database(CatalogId=self._account_id, Name=self._database)
        linked_database = response.get('Database', {}).get('TargetDatabase', {})
        log.info(f'Fetched source catalog details for database {self._database} are: {linked_database}...')
        if linked_database:
            # Region may be absent from the API response; fall back to this
            # client's own region in that case.
            return Catalog(account_id=linked_database.get('CatalogId'),
                           database_name=linked_database.get('DatabaseName'),
                           region=linked_database.get('Region', self._region))
    except Exception:
        log.exception(f'Could not fetch source catalog details for database {self._database}')
        # Idiom fix: bare `raise` preserves the original traceback exactly
        # (the original `raise e` re-raises with an extended traceback).
        raise
    # Plain (non-linked) database: there is no separate source catalog.
    return None

def get_database_tags(self):
    """Return the tags attached to this Glue database as a dict.

    Builds the database ARN from the client's account id and region, then
    calls ``glue:GetTags``.

    Raises: any Glue client error, re-raised after being logged.
    """
    database = self._database
    try:
        log.info(f'Getting tags for database {database}...')
        # Glue databases are not returned with an ARN, so construct it.
        resource_arn = f'arn:aws:glue:{self._region}:{self._account_id}:database/{database}'
        response = self._client.get_tags(ResourceArn=resource_arn)
        tags = response['Tags']
        log.info(f'Successfully retrieved tags: {tags}')
        return tags
    except Exception:
        log.exception(f'Could not get tags for database {database}')
        # Idiom fix: bare `raise` keeps the original traceback (was `raise e`).
        raise
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import dataclass
from datetime import datetime, timedelta
from uuid import uuid4

Expand Down Expand Up @@ -58,3 +59,11 @@ class ShareObjectItem(Base):
S3AccessPointName = Column(String, nullable=True)
status = Column(String, nullable=False, default=ShareItemStatus.PendingApproval.value)
action = Column(String, nullable=True)


@dataclass
class Catalog:
    """ Can be expanded to include other details once broader catalog support is added to data.all """
    # AWS account id that owns the Glue catalog holding the source database
    account_id: str
    # Name of the Glue database inside the owning account's catalog
    database_name: str
    # AWS region of the owning catalog
    region: str
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from dataall.modules.dataset_sharing.services.share_processors.s3_process_share import ProcessS3Share

from dataall.base.db import Engine
from dataall.modules.dataset_sharing.db.enums import ShareObjectActions, ShareItemStatus, ShareableType
from dataall.modules.dataset_sharing.db.enums import ShareObjectActions, ShareItemStatus, ShareableType, \
ShareItemActions
from dataall.modules.dataset_sharing.db.share_object_repositories import ShareObjectSM, ShareObjectRepository, ShareItemSM
from dataall.modules.dataset_sharing.aws.glue_client import GlueClient

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -67,32 +69,22 @@ def approve_share(cls, engine: Engine, share_uri: str) -> bool:
)
log.info(f'sharing folders succeeded = {approved_folders_succeed}')

if source_environment.AwsAccountId != target_environment.AwsAccountId:
processor = ProcessLFCrossAccountShare(
session,
dataset,
share,
shared_tables,
[],
source_environment,
target_environment,
env_group,
)
processor = DataSharingService.create_lf_processor(
session=session,
dataset=dataset,
share=share,
shared_tables=shared_tables,
revoked_tables=[],
source_environment=source_environment,
target_environment=target_environment,
env_group=env_group,
)
if processor:
log.info(f'Granting permissions to tables: {shared_tables}')
approved_tables_succeed = processor.process_approved_shares()
log.info(f'sharing tables succeeded = {approved_tables_succeed}')
else:
processor = ProcessLFSameAccountShare(
session,
dataset,
share,
shared_tables,
[],
source_environment,
target_environment,
env_group
)

log.info(f'Granting permissions to tables: {shared_tables}')
approved_tables_succeed = processor.process_approved_shares()
log.info(f'sharing tables succeeded = {approved_tables_succeed}')
approved_tables_succeed = False

new_share_state = share_sm.run_transition(ShareObjectActions.Finish.value)
share_sm.update_state(session, share, new_share_state)
Expand Down Expand Up @@ -173,32 +165,22 @@ def revoke_share(cls, engine: Engine, share_uri: str):
)
log.info(f"Clean up S3 successful = {clean_up_folders}")

if source_environment.AwsAccountId != target_environment.AwsAccountId:
processor = ProcessLFCrossAccountShare(
session,
dataset,
share,
[],
revoked_tables,
source_environment,
target_environment,
env_group,
)
processor = DataSharingService.create_lf_processor(
session=session,
dataset=dataset,
share=share,
shared_tables=[],
revoked_tables=revoked_tables,
source_environment=source_environment,
target_environment=target_environment,
env_group=env_group,
)
if processor:
log.info(f'Revoking permissions to tables: {revoked_tables}')
revoked_tables_succeed = processor.process_revoked_shares()
log.info(f'revoking tables succeeded = {revoked_tables_succeed}')
else:
processor = ProcessLFSameAccountShare(
session,
dataset,
share,
[],
revoked_tables,
source_environment,
target_environment,
env_group)

log.info(f'Revoking permissions to tables: {revoked_tables}')
revoked_tables_succeed = processor.process_revoked_shares()
log.info(f'revoking tables succeeded = {revoked_tables_succeed}')

revoked_tables_succeed = False
existing_shared_items = ShareObjectRepository.check_existing_shared_items_of_type(
session,
share_uri,
Expand All @@ -218,3 +200,63 @@ def revoke_share(cls, engine: Engine, share_uri: str):
share_sm.update_state(session, share, new_share_state)

return revoked_tables_succeed and revoked_folders_succeed

@staticmethod
def create_lf_processor(session,
                        dataset,
                        share,
                        shared_tables,
                        revoked_tables,
                        source_environment,
                        target_environment,
                        env_group):
    """Build the Lake Formation share processor for a share operation.

    Resolves the true owner of the dataset's Glue database (the catalog
    account when the database is a resource link, the source environment
    account otherwise) and picks the cross-account or same-account
    processor accordingly. On any failure, marks every queued table item
    as failed and returns None.
    """
    try:
        glue_client = GlueClient(database=dataset.GlueDatabaseName,
                                 account_id=source_environment.AwsAccountId,
                                 region=source_environment.region)
        catalog_details = glue_client.get_source_catalog()
        # The catalog owner (if any) decides same- vs cross-account sharing.
        source_account_id = catalog_details.account_id if catalog_details else source_environment.AwsAccountId

        if source_account_id == target_environment.AwsAccountId:
            return ProcessLFSameAccountShare(
                session,
                dataset,
                share,
                shared_tables,
                revoked_tables,
                source_environment,
                target_environment,
                env_group,
            )
        return ProcessLFCrossAccountShare(
            session,
            dataset,
            share,
            shared_tables,
            revoked_tables,
            source_environment,
            target_environment,
            env_group,
            catalog_details
        )
    except Exception as e:
        log.error(f"Error creating LF processor: {e}")
        # Flag every table queued for this run as failed, then fall through
        # and implicitly return None so the caller skips processing.
        for table in shared_tables:
            DataSharingService._handle_table_share_failure(session, share, table, ShareItemStatus.Share_Approved.value)
        for table in revoked_tables:
            DataSharingService._handle_table_share_failure(session, share, table, ShareItemStatus.Revoke_Approved.value)

@staticmethod
def _handle_table_share_failure(session, share, table, share_item_status):
    """ Mark the share item as failed for the approved/revoked tables """
    log.error(f'Marking share item as failed for table {table.GlueTableName}')
    share_item = ShareObjectRepository.find_sharable_item(
        session, share.shareUri, table.tableUri
    )
    state_machine = ShareItemSM(share_item_status)
    # Drive the item through Start then Failure so it lands in the failed state.
    for action in (ShareObjectActions.Start.value, ShareItemActions.Failure.value):
        next_state = state_machine.run_transition(action)
        state_machine.update_state_single_item(session, share_item, next_state)
Loading

0 comments on commit d85b416

Please sign in to comment.