From 78bacccc2a75520e9db47cc6a08123abd44b2979 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lars-Olav=20V=C3=A5gene?= <42863501+LarsV123@users.noreply.github.com> Date: Tue, 14 Jan 2025 08:43:15 +0100 Subject: [PATCH] style: Reformat with ruff (#5) --- cli.py | 4 +- commands/awslambda.py | 15 ++- commands/cognito.py | 17 ++- commands/customers.py | 30 ++++- commands/handle.py | 72 ++++++++--- commands/services/aws_utils.py | 10 +- commands/services/cognito_api.py | 46 ++++--- commands/services/customers_api.py | 58 +++++---- commands/services/dynamodb_export.py | 71 ++++++---- commands/services/external_user.py | 74 ++++++----- commands/services/handle_api.py | 47 ++++--- commands/services/handle_task_executor.py | 2 +- commands/services/handle_task_writer.py | 40 +++--- commands/services/lambda_api.py | 32 +++-- commands/services/publication_api.py | 55 +++++--- commands/services/users_api.py | 41 +++--- commands/users.py | 52 ++++++-- fileService.py | 151 ++++++++++------------ 18 files changed, 500 insertions(+), 317 deletions(-) diff --git a/cli.py b/cli.py index 0835247..9c0831d 100755 --- a/cli.py +++ b/cli.py @@ -7,10 +7,12 @@ from commands.customers import customers from commands.awslambda import awslambda + @click.group() def cli(): pass + cli.add_command(cognito) cli.add_command(handle) cli.add_command(users) @@ -18,4 +20,4 @@ def cli(): cli.add_command(awslambda) if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/commands/awslambda.py b/commands/awslambda.py index 4e144d2..08eb8d0 100644 --- a/commands/awslambda.py +++ b/commands/awslambda.py @@ -1,12 +1,19 @@ import click from commands.services.lambda_api import LambdaService + @click.group() def awslambda(): pass + @awslambda.command(help="Delete old versions of AWS Lambda functions.") -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') -@click.option('--delete', is_flag=True, default=True, help='Delete old versions.') -def delete_old_versions(profile:str, delete:bool) -> None: - LambdaService(profile).delete_old_versions(delete) \ No newline at end of file +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) +@click.option("--delete", is_flag=True, default=True, help="Delete old versions.") +def delete_old_versions(profile: str, delete: bool) -> None: + LambdaService(profile).delete_old_versions(delete) diff --git a/commands/cognito.py b/commands/cognito.py index eaea049..9163699 100644 --- a/commands/cognito.py +++ b/commands/cognito.py @@ -3,14 +3,21 @@ from commands.services.cognito_api import CognitoService from commands.services.aws_utils import prettify + @click.group() def cognito(): pass + @cognito.command(help="Search users by user attribute values") -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') -@click.argument('search_term', required=True, nargs=-1) -def search(profile:str, search_term:str) -> None: - search_term = ' '.join(search_term) +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) +@click.argument("search_term", required=True, nargs=-1) +def search(profile: str, search_term: str) -> None: + search_term = " ".join(search_term) result = CognitoService(profile).search(search_term) - click.echo(prettify(result)) \ No newline at end of file + click.echo(prettify(result)) diff --git a/commands/customers.py b/commands/customers.py index fb92e9c..f1a745b 100644 --- a/commands/customers.py +++ b/commands/customers.py @@ -1,20 +1,38 @@ import click -from commands.services.customers_api import list_missing_customers, list_duplicate_customers +from commands.services.customers_api import ( + list_missing_customers, + list_duplicate_customers, +) from commands.services.aws_utils import prettify + @click.group() def customers(): pass -@customers.command(help="Search customer references from users that does not exsist in the customer table") -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') + +@customers.command( + help="Search customer references from users that does not exsist in the customer table" +) +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) def list_missing(profile) -> None: result = list_missing_customers(profile) click.echo(prettify(result)) + @customers.command(help="Search dubplicate customer references (same cristin id)") -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') -def list_duplicate(profile:str) -> None: +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) +def list_duplicate(profile: str) -> None: result = list_duplicate_customers(profile) - click.echo(prettify(result)) \ No newline at end of file + click.echo(prettify(result)) diff --git a/commands/handle.py b/commands/handle.py index 5890ff7..51071f1 100644 --- a/commands/handle.py +++ b/commands/handle.py @@ -4,25 +4,51 @@ import shutil from boto3.dynamodb.conditions import Key from commands.services.handle_task_writer import HandleTaskWriterService -from commands.services.handle_task_executor import HandleTaskExecutorService +from commands.services.handle_task_executor import HandleTaskExecutorService from commands.services.dynamodb_export import DynamodbExport, get_account_alias + @click.group() def handle(): pass + @handle.command() -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') -@click.option('-c', '--customer', required=True, help='Customer UUID. e.g. bb3d0c0c-5065-4623-9b98-5810983c2478') -@click.option('-r', '--resource-owner', required=True, help='Resource owner ID. e.g. ntnu@194.0.0.0') -@click.option('-o', '--output-folder', required=False, help='Output folder path. e.g. sikt-nva-sandbox-resources-ntnu@194.0.0.0-handle-tasks') -def prepare(profile:str, customer:str, resource_owner:str, output_folder:str) -> None: - table_pattern = '^nva-resources-master-pipelines-NvaPublicationApiPipeline-.*-nva-publication-api$' - condition = Key('PK0').eq(f'Resource:{customer}:{resource_owner}') +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) +@click.option( + "-c", + "--customer", + required=True, + help="Customer UUID. e.g. bb3d0c0c-5065-4623-9b98-5810983c2478", +) +@click.option( + "-r", + "--resource-owner", + required=True, + help="Resource owner ID. e.g. ntnu@194.0.0.0", +) +@click.option( + "-o", + "--output-folder", + required=False, + help="Output folder path. e.g. sikt-nva-sandbox-resources-ntnu@194.0.0.0-handle-tasks", +) +def prepare( + profile: str, customer: str, resource_owner: str, output_folder: str +) -> None: + table_pattern = "^nva-resources-master-pipelines-NvaPublicationApiPipeline-.*-nva-publication-api$" + condition = Key("PK0").eq(f"Resource:{customer}:{resource_owner}") batch_size = 700 if not output_folder: - output_folder = f'{get_account_alias(profile)}-resources-{resource_owner}-handle-tasks' + output_folder = ( + f"{get_account_alias(profile)}-resources-{resource_owner}-handle-tasks" + ) # Create output folder if not exists if not os.path.exists(output_folder): @@ -31,13 +57,13 @@ def prepare(profile:str, customer:str, resource_owner:str, output_folder:str) -> action_counts = {} def process_batch(batch, batch_counter): - with open(f'{output_folder}/batch_{batch_counter}.jsonl', 'w') as outfile: + with open(f"{output_folder}/batch_{batch_counter}.jsonl", "w") as outfile: for data in batch: task = HandleTaskWriterService().process_item(data) - action = task.get('action') + action = task.get("action") action_counts[action] = action_counts.get(action, 0) + 1 json.dump(task, outfile) - outfile.write('\n') + outfile.write("\n") DynamodbExport(profile, table_pattern, condition, batch_size).process(process_batch) @@ -47,19 +73,29 @@ def process_batch(batch, batch_counter): @handle.command() -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') -@click.option('-i', '--input-folder', required=True, help='Input folder path. e.g. sikt-nva-sandbox-resources-ntnu@194.0.0.0-handle-tasks') -def execute(profile:str, input_folder:str) -> None: - complete_folder = os.path.join(input_folder, 'complete') +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) +@click.option( + "-i", + "--input-folder", + required=True, + help="Input folder path. e.g. sikt-nva-sandbox-resources-ntnu@194.0.0.0-handle-tasks", +) +def execute(profile: str, input_folder: str) -> None: + complete_folder = os.path.join(input_folder, "complete") os.makedirs(complete_folder, exist_ok=True) for batch_file in os.listdir(input_folder): file_path = os.path.join(input_folder, batch_file) if os.path.isfile(file_path): - with open(file_path, 'r') as infile: + with open(file_path, "r") as infile: batch = [json.loads(line) for line in infile] HandleTaskExecutorService(profile).execute(batch) - + # Move the file to the 'complete' folder after processing new_file_path = os.path.join(complete_folder, batch_file) shutil.move(file_path, new_file_path) diff --git a/commands/services/aws_utils.py b/commands/services/aws_utils.py index 544efe9..28aa7a2 100644 --- a/commands/services/aws_utils.py +++ b/commands/services/aws_utils.py @@ -2,18 +2,20 @@ import boto3 import json -def get_account_alias(profile:str=None) -> str: + +def get_account_alias(profile: str = None) -> str: # Create a default Boto3 session session = boto3.Session(profile_name=profile) if profile else boto3.Session() # Create an IAM client - iam = session.client('iam') + iam = session.client("iam") # Get the account alias - account_aliases = iam.list_account_aliases()['AccountAliases'] + account_aliases = iam.list_account_aliases()["AccountAliases"] # Return the first account alias or None if the list is empty return account_aliases[0] if account_aliases else None + def prettify(object) -> str: - return json.dumps(object, indent=2, sort_keys=True, default=str, ensure_ascii=False) \ No newline at end of file + return json.dumps(object, indent=2, sort_keys=True, default=str, ensure_ascii=False) diff --git a/commands/services/cognito_api.py b/commands/services/cognito_api.py index 31b0d49..57e0536 100644 --- a/commands/services/cognito_api.py +++ b/commands/services/cognito_api.py @@ -1,5 +1,6 @@ import boto3 + class CognitoService: def __init__(self, profile): self.profile = profile @@ -11,46 +12,55 @@ def search(self, search_term): return self._lookup_users_by_attribute_value(search_term, users) def _get_user_pool_id(self): - session = boto3.Session(profile_name=self.profile) if self.profile else boto3.Session() - client = session.client('ssm') + session = ( + boto3.Session(profile_name=self.profile) + if self.profile + else boto3.Session() + ) + client = session.client("ssm") - parameter_name = 'CognitoUserPoolId' + parameter_name = "CognitoUserPoolId" - response = client.get_parameter( - Name=parameter_name, - WithDecryption=True - ) + response = client.get_parameter(Name=parameter_name, WithDecryption=True) + + return response["Parameter"]["Value"] - return response['Parameter']['Value'] - def _get_all_users(self, user_pool_id): - session = boto3.Session(profile_name=self.profile) if self.profile else boto3.Session() - cognito = session.client('cognito-idp') - + session = ( + boto3.Session(profile_name=self.profile) + if self.profile + else boto3.Session() + ) + cognito = session.client("cognito-idp") + pagination_token = None users = [] while True: if pagination_token: - response = cognito.list_users(UserPoolId=user_pool_id, PaginationToken=pagination_token) + response = cognito.list_users( + UserPoolId=user_pool_id, PaginationToken=pagination_token + ) else: response = cognito.list_users(UserPoolId=user_pool_id) - users.extend(response['Users']) + users.extend(response["Users"]) - pagination_token = response.get('PaginationToken') + pagination_token = response.get("PaginationToken") if not pagination_token: break return users - + def _lookup_users_by_attribute_value(self, search_term, users): search_words = search_term.split() matches = [] for user in users: - user_attributes = ' '.join(attribute['Value'] for attribute in user['Attributes']) + user_attributes = " ".join( + attribute["Value"] for attribute in user["Attributes"] + ) if all(word in user_attributes for word in search_words): matches.append(user) - return matches if matches else None \ No newline at end of file + return matches if matches else None diff --git a/commands/services/customers_api.py b/commands/services/customers_api.py index 7ac25be..b7dc256 100644 --- a/commands/services/customers_api.py +++ b/commands/services/customers_api.py @@ -4,85 +4,95 @@ def list_missing_customers(profile): session = boto3.Session(profile_name=profile) if profile else boto3.Session() - dynamodb = session.resource('dynamodb') + dynamodb = session.resource("dynamodb") - customers_table = dynamodb.Table(_get_table_name(profile, 'nva-customers')) - users_table = dynamodb.Table(_get_table_name(profile, 'nva-users-and-roles')) + customers_table = dynamodb.Table(_get_table_name(profile, "nva-customers")) + users_table = dynamodb.Table(_get_table_name(profile, "nva-users-and-roles")) customer_identifiers = _extract_customer_identifiers(customers_table) missing_customers = _find_missing_customers(users_table, customer_identifiers) return missing_customers + def list_duplicate_customers(profile): session = boto3.Session(profile_name=profile) if profile else boto3.Session() - dynamodb = session.resource('dynamodb') - customers_table = dynamodb.Table(_get_table_name(profile, 'nva-customers')) + dynamodb = session.resource("dynamodb") + customers_table = dynamodb.Table(_get_table_name(profile, "nva-customers")) duplicate_customers = _find_duplicate_customers(customers_table) return duplicate_customers + def _find_duplicate_customers(customers_table): cristinId_counts = {} matching_items = [] customers_response = _scan_table(customers_table) for item in customers_response: - if 'cristinId' in item: + if "cristinId" in item: # Extract the first number from the cristinId - match = re.search(r'\d+', item['cristinId']) + match = re.search(r"\d+", item["cristinId"]) if match: first_number = match.group() - cristinId_counts[first_number] = cristinId_counts.get(first_number, 0) + 1 + cristinId_counts[first_number] = ( + cristinId_counts.get(first_number, 0) + 1 + ) if cristinId_counts[first_number] >= 2: matching_items.append(item) return matching_items + def _scan_table(table): items = [] response = table.scan() - items.extend(response['Items']) + items.extend(response["Items"]) - while 'LastEvaluatedKey' in response: - response = table.scan(ExclusiveStartKey=response['LastEvaluatedKey']) - items.extend(response['Items']) + while "LastEvaluatedKey" in response: + response = table.scan(ExclusiveStartKey=response["LastEvaluatedKey"]) + items.extend(response["Items"]) return items + def _extract_customer_identifiers(customers_table): customer_identifiers = set() customers_response = _scan_table(customers_table) for customer in customers_response: - customer_identifiers.add(customer['identifier']) + customer_identifiers.add(customer["identifier"]) return customer_identifiers + def _find_missing_customers(users_table, customer_identifiers): missing_customers = [] users_response = _scan_table(users_table) for user in users_response: - if 'institution' in user: + if "institution" in user: # Extract the customer identifier from the institution - match = re.search(r'(?<=customer/).+', user['institution']) + match = re.search(r"(?<=customer/).+", user["institution"]) if match: customer_id = match.group() if customer_id not in customer_identifiers: - missing_customers.append({ - 'PrimaryKeyHashKey': user['PrimaryKeyHashKey'], - 'MissingCustomerId': customer_id - }) + missing_customers.append( + { + "PrimaryKeyHashKey": user["PrimaryKeyHashKey"], + "MissingCustomerId": customer_id, + } + ) return missing_customers + def _get_table_name(profile, name): session = boto3.Session(profile_name=profile) if profile else boto3.Session() - dynamodb = session.client('dynamodb') + dynamodb = session.client("dynamodb") response = dynamodb.list_tables() - - for table_name in response['TableNames']: + + for table_name in response["TableNames"]: if table_name.startswith(name): return table_name - - raise ValueError('No valid table found.') \ No newline at end of file + + raise ValueError("No valid table found.") diff --git a/commands/services/dynamodb_export.py b/commands/services/dynamodb_export.py index 256d9fc..b42ead5 100644 --- a/commands/services/dynamodb_export.py +++ b/commands/services/dynamodb_export.py @@ -6,6 +6,7 @@ import re from boto3.dynamodb.types import Binary + class DynamodbExport: def __init__(self, profile, table_pattern, condition, batch_size): self.condition = condition @@ -16,87 +17,101 @@ def __init__(self, profile, table_pattern, condition, batch_size): def get_table(self): session = boto3.Session(profile_name=self.profile) - dynamodb = session.client('dynamodb') + dynamodb = session.client("dynamodb") response = dynamodb.list_tables() - table_names = response['TableNames'] - table_name = next((name for name in table_names if re.match(self.table_pattern, name)), None) - + table_names = response["TableNames"] + table_name = next( + (name for name in table_names if re.match(self.table_pattern, name)), None + ) + if table_name is None: print(f"No table found matching {self.table_pattern}") return - - dynamodb_resource = session.resource('dynamodb') - return dynamodb_resource.Table(table_name) + dynamodb_resource = session.resource("dynamodb") + return dynamodb_resource.Table(table_name) def _iterate_batches(self, table, custom_batch_processor): response = table.query( Limit=self.batch_size, KeyConditionExpression=self.condition, - ReturnConsumedCapacity='TOTAL' + ReturnConsumedCapacity="TOTAL", ) - items = response['Items'] + items = response["Items"] batch = self._inflate_batch(items) custom_batch_processor(batch, self.batch_counter) self.batch_counter += 1 total_count = len(items) - total_consumed_capacity = response['ConsumedCapacity']['CapacityUnits'] - print(f"Processed {len(items)} items, Total: {total_count}, ConsumedCapacity: {total_consumed_capacity}") - - while 'LastEvaluatedKey' in response: - response = table.query(ExclusiveStartKey=response['LastEvaluatedKey'], Limit=self.batch_size, KeyConditionExpression=self.condition, ReturnConsumedCapacity='TOTAL') - items = response['Items'] + total_consumed_capacity = response["ConsumedCapacity"]["CapacityUnits"] + print( + f"Processed {len(items)} items, Total: {total_count}, ConsumedCapacity: {total_consumed_capacity}" + ) + + while "LastEvaluatedKey" in response: + response = table.query( + ExclusiveStartKey=response["LastEvaluatedKey"], + Limit=self.batch_size, + KeyConditionExpression=self.condition, + ReturnConsumedCapacity="TOTAL", + ) + items = response["Items"] if items: batch = self._inflate_batch(items) custom_batch_processor(batch, self.batch_counter) self.batch_counter += 1 total_count += len(items) - total_consumed_capacity += response['ConsumedCapacity']['CapacityUnits'] - print(f"Processed {len(items)} items, Total: {total_count}, ConsumedCapacity: {total_consumed_capacity}") + total_consumed_capacity += response["ConsumedCapacity"]["CapacityUnits"] + print( + f"Processed {len(items)} items, Total: {total_count}, ConsumedCapacity: {total_consumed_capacity}" + ) def _inflate_batch(self, items): inflated_items = [] for item in items: - item = {k: (base64.b64encode(v.value).decode() if isinstance(v, Binary) else v) for k, v in item.items()} + item = { + k: (base64.b64encode(v.value).decode() if isinstance(v, Binary) else v) + for k, v in item.items() + } inflated_item = self._inflate_item(item) inflated_items.append(inflated_item) return inflated_items def _inflate_item(self, item): - if 'data' in item: - data = item['data'] + if "data" in item: + data = item["data"] decoded_data = base64.b64decode(data) inflated_data = zlib.decompress(decoded_data, -zlib.MAX_WBITS) - inflated_str = inflated_data.decode('utf-8') + inflated_str = inflated_data.decode("utf-8") return json.loads(inflated_str) def _save_inflated_items_to_file(self, inflated_items, batch_counter): if not os.path.exists(self.output_folder): os.makedirs(self.output_folder) - filename = os.path.join(self.output_folder, f'batch_{batch_counter}.jsonl') - with open(filename, 'w') as file: + filename = os.path.join(self.output_folder, f"batch_{batch_counter}.jsonl") + with open(filename, "w") as file: for inflated_item in inflated_items: file.write(json.dumps(inflated_item)) - file.write('\n') + file.write("\n") def save_to_folder(self, output_folder): self.output_folder = output_folder table = self.get_table() self._iterate_batches(table, self._save_inflated_items_to_file) - + def process(self, action): table = self.get_table() self._iterate_batches(table, action) + def get_account_alias(profile=None): # Create a default Boto3 session session = boto3.Session(profile_name=profile) if profile else boto3.Session() # Create an IAM client - iam = session.client('iam') + iam = session.client("iam") # Get the account alias - account_aliases = iam.list_account_aliases()['AccountAliases'] + account_aliases = iam.list_account_aliases()["AccountAliases"] # Return the first account alias or None if the list is empty - return account_aliases[0] if account_aliases else None \ No newline at end of file + return account_aliases[0] if account_aliases else None diff --git a/commands/services/external_user.py b/commands/services/external_user.py index 0264919..3d43e57 100644 --- a/commands/services/external_user.py +++ b/commands/services/external_user.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta from dataclasses import dataclass -''' +""" # example of usage from services import ExternalUserService @@ -18,41 +18,45 @@ external_user_service = ExternalUserService(profile) external_user = external_user_service.create(customer_id, intended_purpose, scopes) external_user.save_to_file() -''' +""" + + class ExternalUserService: def __init__(self, profile): self.session = boto3.Session(profile_name=profile) - self.ssm = self.session.client('ssm') - self.secretsmanager = self.session.client('secretsmanager') - self.api_domain = self._get_system_parameter('/NVA/ApiDomain') - self.cognito_uri = self._get_system_parameter('/NVA/CognitoUri') - self.client_credentials = self._get_secret('BackendCognitoClientCredentials') + self.ssm = self.session.client("ssm") + self.secretsmanager = self.session.client("secretsmanager") + self.api_domain = self._get_system_parameter("/NVA/ApiDomain") + self.cognito_uri = self._get_system_parameter("/NVA/CognitoUri") + self.client_credentials = self._get_secret("BackendCognitoClientCredentials") self.token = None self.token_expiry_time = None def _get_system_parameter(self, name): response = self.ssm.get_parameter(Name=name) - return response['Parameter']['Value'] + return response["Parameter"]["Value"] def _get_secret(self, name): response = self.secretsmanager.get_secret_value(SecretId=name) - secret_string = response['SecretString'] + secret_string = response["SecretString"] secret = json.loads(secret_string) return secret def _get_cognito_token(self): url = f"{self.cognito_uri}/oauth2/token" - headers = {'Content-Type': 'application/x-www-form-urlencoded'} + headers = {"Content-Type": "application/x-www-form-urlencoded"} data = { - 'grant_type': 'client_credentials', - 'client_id': self.client_credentials['backendClientId'], - 'client_secret': self.client_credentials['backendClientSecret'], + "grant_type": "client_credentials", + "client_id": self.client_credentials["backendClientId"], + "client_secret": self.client_credentials["backendClientSecret"], } response = requests.post(url, headers=headers, data=data) response.raise_for_status() response_json = response.json() - token_expiry_time = datetime.now() + timedelta(seconds=response_json['expires_in']) - return response_json['access_token'], token_expiry_time + token_expiry_time = datetime.now() + timedelta( + seconds=response_json["expires_in"] + ) + return response_json["access_token"], token_expiry_time def _get_token(self): if not self.token or self._is_token_expired(): @@ -63,20 +67,20 @@ def _is_token_expired(self): if not self.token_expiry_time: return True return datetime.now() > self.token_expiry_time - timedelta(seconds=30) - + def _create_external_client_token(self, scopes): url = f"https://{self.api_domain}/users-roles/external-clients" headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'Authorization': f"Bearer {self._get_token()}" + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {self._get_token()}", } data = { "clientName": f"{self.org_abbreviation}-{self.intended_purpose}-integration", "customerUri": self.customer_id, "cristinOrgUri": self.org_id, "actingUser": f"{self.intended_purpose}-integration@{self.org_abbreviation}", - "scopes": scopes + "scopes": scopes, } response = requests.post(url, headers=headers, json=data) response.raise_for_status() @@ -89,28 +93,28 @@ def _create_external_client_token(self, scopes): "customerUri": data["customerUri"], "cristinOrgUri": data["cristinOrgUri"], "actingUser": data["actingUser"], - "scopes": data["scopes"] + "scopes": data["scopes"], } - + def _get_customer_data(self, customer_id): url = f"https://{self.api_domain}/customer/{customer_id}" headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'Authorization': f"Bearer {self._get_token()}" + "Content-Type": "application/json", + "Accept": "application/json", + "Authorization": f"Bearer {self._get_token()}", } response = requests.get(url, headers=headers) response.raise_for_status() # raises HTTPError for 4xx and 5xx status codes return response.json() def create(self, customer_id, intended_purpose, scopes): - customer_data = self._get_customer_data(customer_id) - self.org_id = customer_data['cristinId'] - self.customer_id = customer_data['id'] - self.org_abbreviation = customer_data['shortName'].lower() - self.intended_purpose = intended_purpose - client_data = self._create_external_client_token(scopes) - return ExternalUser(self.org_abbreviation, self.intended_purpose, client_data) + customer_data = self._get_customer_data(customer_id) + self.org_id = customer_data["cristinId"] + self.customer_id = customer_data["id"] + self.org_abbreviation = customer_data["shortName"].lower() + self.intended_purpose = intended_purpose + client_data = self._create_external_client_token(scopes) + return ExternalUser(self.org_abbreviation, self.intended_purpose, client_data) @dataclass @@ -120,5 +124,7 @@ class ExternalUser: client_data: dict def save_to_file(self): - with open(f"{self.org_abbreviation}-{self.intended_purpose}-credentials.json", 'w') as json_file: - json.dump(self.client_data, json_file, indent=4) \ No newline at end of file + with open( + f"{self.org_abbreviation}-{self.intended_purpose}-credentials.json", "w" + ) as json_file: + json.dump(self.client_data, json_file, indent=4) diff --git a/commands/services/handle_api.py b/commands/services/handle_api.py index cc5f61e..86d4ba5 100644 --- a/commands/services/handle_api.py +++ b/commands/services/handle_api.py @@ -2,7 +2,8 @@ import requests import json from datetime import datetime, timedelta -''' + +""" # example of usage from services import HandleApiService @@ -24,40 +25,44 @@ # Call the create_handle method create_response = service.create_handle(request_body) print(create_response) -''' +""" + + class HandleApiService: def __init__(self): self.session = boto3.Session() - self.ssm = self.session.client('ssm') - self.secretsmanager = self.session.client('secretsmanager') - self.api_domain = self._get_system_parameter('/NVA/ApiDomain') - self.cognito_uri = self._get_system_parameter('/NVA/CognitoUri') - self.client_credentials = self._get_secret('BackendCognitoClientCredentials') + self.ssm = self.session.client("ssm") + self.secretsmanager = self.session.client("secretsmanager") + self.api_domain = self._get_system_parameter("/NVA/ApiDomain") + self.cognito_uri = self._get_system_parameter("/NVA/CognitoUri") + self.client_credentials = self._get_secret("BackendCognitoClientCredentials") self.token = self._get_cognito_token() self.token_expiry_time = datetime.now() # Initialize with current time def _get_system_parameter(self, name): response = self.ssm.get_parameter(Name=name) - return response['Parameter']['Value'] + return response["Parameter"]["Value"] def _get_secret(self, name): response = self.secretsmanager.get_secret_value(SecretId=name) - secret_string = response['SecretString'] + secret_string = response["SecretString"] secret = json.loads(secret_string) return secret def _get_cognito_token(self): url = f"{self.cognito_uri}/oauth2/token" - headers = {'Content-Type': 'application/x-www-form-urlencoded'} + headers = {"Content-Type": "application/x-www-form-urlencoded"} data = { - 'grant_type': 'client_credentials', - 'client_id': self.client_credentials['backendClientId'], - 'client_secret': self.client_credentials['backendClientSecret'], + "grant_type": "client_credentials", + "client_id": self.client_credentials["backendClientId"], + "client_secret": self.client_credentials["backendClientSecret"], } response = requests.post(url, headers=headers, data=data) response_json = response.json() - self.token_expiry_time = datetime.now() + timedelta(seconds=response_json['expires_in']) # Set the expiry time - return response_json['access_token'] + self.token_expiry_time = datetime.now() + timedelta( + seconds=response_json["expires_in"] + ) # Set the expiry time + return response_json["access_token"] def _is_token_expired(self): # If there are less than 30 seconds until the token expires, consider it expired @@ -70,12 +75,18 @@ def _get_token(self): def update_handle(self, prefix, suffix, request_body): url = f"https://{self.api_domain}/handle/{prefix}/{suffix}" - headers = {'Authorization': f"Bearer {self._get_token()}", 'Content-Type': 'application/json'} + headers = { + "Authorization": f"Bearer {self._get_token()}", + "Content-Type": "application/json", + } response = requests.put(url, headers=headers, json=request_body) return response.json() def create_handle(self, request_body): url = f"https://{self.api_domain}/handle/" - headers = {'Authorization': f"Bearer {self._get_token()}", 'Content-Type': 'application/json'} + headers = { + "Authorization": f"Bearer {self._get_token()}", + "Content-Type": "application/json", + } response = requests.post(url, headers=headers, json=request_body) - return response.json() \ No newline at end of file + return response.json() diff --git a/commands/services/handle_task_executor.py b/commands/services/handle_task_executor.py index 806212c..2b9b35c 100644 --- a/commands/services/handle_task_executor.py +++ b/commands/services/handle_task_executor.py @@ -4,4 +4,4 @@ def __init__(self, profile): pass def execute(self, batch): - pass \ No newline at end of file + pass diff --git a/commands/services/handle_task_writer.py b/commands/services/handle_task_writer.py index ba79aaf..3fe3fab 100644 --- a/commands/services/handle_task_writer.py +++ b/commands/services/handle_task_writer.py @@ -1,39 +1,43 @@ class HandleTaskWriterService: def __init__(self): - self.sikt_prefix = '//hdl.handle.net/11250' - self.sikt_prefix_test = '//hdl.handle.net/11250.1' + self.sikt_prefix = "//hdl.handle.net/11250" + self.sikt_prefix_test = "//hdl.handle.net/11250.1" pass def _is_sikt_handle(self, handle): return self.sikt_prefix in handle or self.sikt_prefix_test in handle def _get_sikt_additional_identifier_handle(self, publication): - for additionalIdentifier in publication.get('additionalIdentifiers', []): - if additionalIdentifier.get('type') == 'HandleIdentifier' or \ - (additionalIdentifier.get('source') == 'handle' and additionalIdentifier.get('type') == 'AdditionalIdentifier'): - handle = additionalIdentifier.get('value') + for additionalIdentifier in publication.get("additionalIdentifiers", []): + if additionalIdentifier.get("type") == "HandleIdentifier" or ( + additionalIdentifier.get("source") == "handle" + and additionalIdentifier.get("type") == "AdditionalIdentifier" + ): + handle = additionalIdentifier.get("value") if self._is_sikt_handle(handle): return handle return None def process_item(self, publication): task = {} - additional_identifier_handle = self._get_sikt_additional_identifier_handle(publication) - top_handle = publication.get('publication') - task['identifiter'] = publication['identifier'] - task['publication'] = publication + additional_identifier_handle = self._get_sikt_additional_identifier_handle( + publication + ) + top_handle = publication.get("publication") + task["identifiter"] = publication["identifier"] + task["publication"] = publication if top_handle and self._is_sikt_handle(top_handle): - task['action'] = "nop" # all good, already sikt managed handle in place + task["action"] = "nop" # all good, already sikt managed handle in place elif top_handle and not self._is_sikt_handle(top_handle): if additional_identifier_handle: - task['action'] = "move_top_to_additional_and_promote_additional" - task['updated_handle'] = additional_identifier_handle + task["action"] = "move_top_to_additional_and_promote_additional" + task["updated_handle"] = additional_identifier_handle else: - task['action'] = "move_top_to_additional_and_create_new_top" + task["action"] = "move_top_to_additional_and_create_new_top" elif not top_handle and additional_identifier_handle: - task['action'] = "promote_additional" - task['updated_handle'] = additional_identifier_handle + task["action"] = "promote_additional" + task["updated_handle"] = additional_identifier_handle elif not top_handle and not additional_identifier_handle: - task['action'] = "create_new_top" - return task \ No newline at end of file + task["action"] = "create_new_top" + return task diff --git a/commands/services/lambda_api.py b/commands/services/lambda_api.py index 1455128..45d3200 100644 --- a/commands/services/lambda_api.py +++ b/commands/services/lambda_api.py @@ -1,5 +1,6 @@ import boto3 + class LambdaService: def __init__(self, profile): self.profile = profile @@ -7,21 +8,28 @@ def __init__(self, profile): def delete_old_versions(self, delete): session = boto3.Session(profile_name=self.profile) - client = session.client('lambda') + client = session.client("lambda") - functions_paginator = client.get_paginator('list_functions') - version_paginator = client.get_paginator('list_versions_by_function') + functions_paginator = client.get_paginator("list_functions") + version_paginator = client.get_paginator("list_versions_by_function") for function_page in functions_paginator.paginate(): - for function in function_page['Functions']: - aliases = client.list_aliases(FunctionName=function['FunctionArn']) - alias_versions = [alias['FunctionVersion'] for alias in aliases['Aliases']] - for version_page in version_paginator.paginate(FunctionName=function['FunctionArn']): - for version in version_page['Versions']: - arn = version['FunctionArn'] - if version['Version'] != function['Version'] and version['Version'] not in alias_versions: - print(' 🥊 {}'.format(arn)) + for function in function_page["Functions"]: + aliases = client.list_aliases(FunctionName=function["FunctionArn"]) + alias_versions = [ + alias["FunctionVersion"] for alias in aliases["Aliases"] + ] + for version_page in version_paginator.paginate( + FunctionName=function["FunctionArn"] + ): + for version in version_page["Versions"]: + arn = version["FunctionArn"] + if ( + version["Version"] != function["Version"] + and version["Version"] not in alias_versions + ): + print(" 🥊 {}".format(arn)) if delete: client.delete_function(FunctionName=arn) else: - print(' 💚 {}'.format(arn)) \ No newline at end of file + print(" 💚 {}".format(arn)) diff --git a/commands/services/publication_api.py b/commands/services/publication_api.py index fbf31c3..2369bcf 100644 --- a/commands/services/publication_api.py +++ b/commands/services/publication_api.py @@ -2,47 +2,57 @@ import requests import json from datetime import datetime, timedelta -''' + +""" # example of usage -''' +""" + + class PublicationApiService: def __init__(self, client_id=None, client_secret=None): self.session = boto3.Session() - self.ssm = self.session.client('ssm') - self.secretsmanager = self.session.client('secretsmanager') - self.api_domain = self._get_system_parameter('/NVA/ApiDomain') - self.cognito_uri = self._get_system_parameter('/NVA/CognitoUri') + self.ssm = self.session.client("ssm") + self.secretsmanager = self.session.client("secretsmanager") + self.api_domain = self._get_system_parameter("/NVA/ApiDomain") + self.cognito_uri = self._get_system_parameter("/NVA/CognitoUri") if client_id and client_secret: - self.client_credentials = {'backendClientId': client_id, 'backendClientSecret': client_secret} + self.client_credentials = { + "backendClientId": client_id, + "backendClientSecret": client_secret, + } else: - self.client_credentials = self._get_secret('BackendCognitoClientCredentials') + self.client_credentials = self._get_secret( + "BackendCognitoClientCredentials" + ) self.token = self._get_cognito_token() self.token_expiry_time = datetime.now() # Initialize with current time def _get_system_parameter(self, name): response = self.ssm.get_parameter(Name=name) - return response['Parameter']['Value'] + return response["Parameter"]["Value"] def _get_secret(self, name): response = self.secretsmanager.get_secret_value(SecretId=name) - secret_string = response['SecretString'] + secret_string = response["SecretString"] secret = json.loads(secret_string) return secret def _get_cognito_token(self): url = f"{self.cognito_uri}/oauth2/token" - headers = {'Content-Type': 'application/x-www-form-urlencoded'} + headers = {"Content-Type": "application/x-www-form-urlencoded"} data = { - 'grant_type': 'client_credentials', - 'client_id': self.client_credentials['backendClientId'], - 'client_secret': self.client_credentials['backendClientSecret'], + "grant_type": "client_credentials", + "client_id": self.client_credentials["backendClientId"], + "client_secret": self.client_credentials["backendClientSecret"], } response = requests.post(url, headers=headers, data=data) response_json = response.json() - self.token_expiry_time = datetime.now() + timedelta(seconds=response_json['expires_in']) # Set the expiry time - return response_json['access_token'] + self.token_expiry_time = datetime.now() + timedelta( + seconds=response_json["expires_in"] + ) # Set the expiry time + return response_json["access_token"] def _is_token_expired(self): # If there are less than 30 seconds until the token expires, consider it expired @@ -55,7 +65,10 @@ def _get_token(self): def fetch_publication(self, publicationIdentifier, doNotRedirect=True): url = f"https://{self.api_domain}/publication/{publicationIdentifier}?doNotRedirect={doNotRedirect}" - headers = {'Authorization': f"Bearer {self._get_token()}", 'Accept': 'application/json'} + headers = { + "Authorization": f"Bearer {self._get_token()}", + "Accept": "application/json", + } response = requests.get(url, headers=headers) if response.status_code == 200: # If the status code indicates success return response.json() @@ -65,6 +78,10 @@ def fetch_publication(self, publicationIdentifier, doNotRedirect=True): def update_publication(self, publicationIdentifier, request_body): url = f"https://{self.api_domain}/publication/{publicationIdentifier}" - headers = {'Authorization': f"Bearer {self._get_token()}", 'Content-Type': 'application/json', 'Accept': 'application/json'} + headers = { + "Authorization": f"Bearer {self._get_token()}", + "Content-Type": "application/json", + "Accept": "application/json", + } response = requests.put(url, headers=headers, json=request_body) - return response.json() \ No newline at end of file + return response.json() diff --git a/commands/services/users_api.py b/commands/services/users_api.py index 8fe5dec..188329e 100644 --- a/commands/services/users_api.py +++ b/commands/services/users_api.py @@ -1,5 +1,6 @@ import boto3 + class UsersAndRolesService: def __init__(self, profile): self.profile = profile @@ -7,8 +8,12 @@ def __init__(self, profile): def search(self, search_term): table_name = self._get_table_name() - session = boto3.Session(profile_name=self.profile) if self.profile else boto3.Session() - dynamodb = session.resource('dynamodb') + session = ( + boto3.Session(profile_name=self.profile) + if self.profile + else boto3.Session() + ) + dynamodb = session.resource("dynamodb") table = dynamodb.Table(table_name) # Split search term into individual words @@ -20,31 +25,35 @@ def search(self, search_term): # Collect all items that match the value matching_items = [] - while 'LastEvaluatedKey' in response: - matching_items.extend(self._items_search(response['Items'], search_words)) + while "LastEvaluatedKey" in response: + matching_items.extend(self._items_search(response["Items"], search_words)) # Paginate results - response = table.scan(ExclusiveStartKey=response['LastEvaluatedKey']) + response = table.scan(ExclusiveStartKey=response["LastEvaluatedKey"]) # Don't forget to process the last page of results - matching_items.extend(self._items_search(response['Items'], search_words)) + matching_items.extend(self._items_search(response["Items"], search_words)) return matching_items - + def _items_search(self, items, search_words): matching_items = [] for item in items: - item_values = ' '.join(str(value) for value in item.values()) + item_values = " ".join(str(value) for value in item.values()) if all(word in item_values for word in search_words): matching_items.append(item) return matching_items - + def _get_table_name(self): - session = boto3.Session(profile_name=self.profile) if self.profile else boto3.Session() - dynamodb = session.client('dynamodb') + session = ( + boto3.Session(profile_name=self.profile) + if self.profile + else boto3.Session() + ) + dynamodb = session.client("dynamodb") response = dynamodb.list_tables() - - for table_name in response['TableNames']: - if table_name.startswith('nva-users-and-roles'): + + for table_name in response["TableNames"]: + if table_name.startswith("nva-users-and-roles"): return table_name - - raise ValueError('No valid table found.') \ No newline at end of file + + raise ValueError("No valid table found.") diff --git a/commands/users.py b/commands/users.py index 91f0563..78f8747 100644 --- a/commands/users.py +++ b/commands/users.py @@ -4,24 +4,56 @@ from commands.services.aws_utils import prettify from commands.services.external_user import ExternalUserService + @click.group() def users(): pass + @users.command(help="Search users by user values") -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') -@click.argument('search_term', required=True, nargs=-1) -def search(profile:str, search_term:str) -> None: - search_term = ' '.join(search_term) +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) +@click.argument("search_term", required=True, nargs=-1) +def search(profile: str, search_term: str) -> None: + search_term = " ".join(search_term) result = UsersAndRolesService(profile).search(search_term) click.echo(prettify(result)) + @users.command(help="Add external API user") -@click.option('--profile', envvar='AWS_PROFILE', default='default', help='The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config') -@click.option('-c', '--customer', required=True, help='Customer UUID. e.g. bb3d0c0c-5065-4623-9b98-5810983c2478') -@click.option('-i', '--intended_purpose', required=True, help='The intended purpose. e.g. oslomet-thesis-integration') -@click.option('-s', '--scopes', required=True, help='Comma-separated list of scopes without whitespace, e.g., https://api.nva.unit.no/scopes/third-party/publication-read,https://api.nva.unit.no/scopes/third-party/publication-upsert') -def create_external(profile:str, customer:str, intended_purpose:str, scopes:str) -> None: - external_user = ExternalUserService(profile).create(customer, intended_purpose, scopes.split(',')) +@click.option( + "--profile", + envvar="AWS_PROFILE", + default="default", + help="The AWS profile to use. e.g. sikt-nva-sandbox, configure your profiles in ~/.aws/config", +) +@click.option( + "-c", + "--customer", + required=True, + help="Customer UUID. e.g. bb3d0c0c-5065-4623-9b98-5810983c2478", +) +@click.option( + "-i", + "--intended_purpose", + required=True, + help="The intended purpose. e.g. oslomet-thesis-integration", +) +@click.option( + "-s", + "--scopes", + required=True, + help="Comma-separated list of scopes without whitespace, e.g., https://api.nva.unit.no/scopes/third-party/publication-read,https://api.nva.unit.no/scopes/third-party/publication-upsert", +) +def create_external( + profile: str, customer: str, intended_purpose: str, scopes: str +) -> None: + external_user = ExternalUserService(profile).create( + customer, intended_purpose, scopes.split(",") + ) external_user.save_to_file() click.echo(prettify(external_user.client_data)) diff --git a/fileService.py b/fileService.py index 5f2639c..b8e4a44 100644 --- a/fileService.py +++ b/fileService.py @@ -5,19 +5,18 @@ from datetime import datetime, timedelta import pytz -UTC=pytz.UTC +UTC = pytz.UTC OneWeekAgo = UTC.localize(datetime.now() - timedelta(weeks=1)) -MetadataKey = 'nva-publication-identifier' +MetadataKey = "nva-publication-identifier" def delete_untagged_files(s3_client, account_id): - storage_bucket = f'nva-resource-storage-{account_id}' + storage_bucket = f"nva-resource-storage-{account_id}" - paginator = s3_client.get_paginator('list_objects_v2') + paginator = s3_client.get_paginator("list_objects_v2") page_iterator = paginator.paginate( - Bucket=storage_bucket, - PaginationConfig={'PageSize': 1000} + Bucket=storage_bucket, PaginationConfig={"PageSize": 1000} ) evaluated_files = 0 deleted_files = 0 @@ -27,174 +26,164 @@ def delete_untagged_files(s3_client, account_id): objects_to_delete = [] for page in page_iterator: - for obj in page['Contents']: - key = obj['Key'] - last_modified = obj['LastModified'] + for obj in page["Contents"]: + key = obj["Key"] + last_modified = obj["LastModified"] if last_modified < OneWeekAgo: metadata = fetch_metadata(s3_client, account_id, key) if MetadataKey not in metadata: - objects_to_delete.append({'Key': key}) + objects_to_delete.append({"Key": key}) deleted_files = deleted_files + 1 if len(objects_to_delete) >= 999: # Delete the object if the metadata key is missing delete_response = s3_client.delete_objects( - Bucket=storage_bucket, - Delete={'Objects': objects_to_delete}) + Bucket=storage_bucket, Delete={"Objects": objects_to_delete} + ) report_delete_response(len(objects_to_delete), delete_response) objects_to_delete.clear() - print(f'Deleted 999 files missing metadata key {MetadataKey})') + print(f"Deleted 999 files missing metadata key {MetadataKey})") evaluated_files = evaluated_files + 1 if evaluated_files % 100 == 0: - print(f'Evaluated {evaluated_files} files, deleted {deleted_files}') + print(f"Evaluated {evaluated_files} files, deleted {deleted_files}") if len(objects_to_delete) > 0: s3_client.delete_objects( - Bucket=storage_bucket, - Delete={'Objects': objects_to_delete}, - Quiet=True) + Bucket=storage_bucket, Delete={"Objects": objects_to_delete}, Quiet=True + ) - print(f'Evaluated {evaluated_files} files, deleted {deleted_files}') + print(f"Evaluated {evaluated_files} files, deleted {deleted_files}") def report_delete_response(expected_deletes, response): - deleted = response['Deleted'] - print(f'Deleted {len(deleted)} of {expected_deletes}') + deleted = response["Deleted"] + print(f"Deleted {len(deleted)} of {expected_deletes}") - if 'Errors' in response: - errors = response['Errors'] + if "Errors" in response: + errors = response["Errors"] if len(errors) > 0: print(errors) def should_delete_object(obj, metadata): - last_modified = obj['LastModified'] + last_modified = obj["LastModified"] return MetadataKey not in metadata and last_modified < OneWeekAgo def fetch_metadata(s3_client, account_id, key): - storage_bucket = f'nva-resource-storage-{account_id}' - return s3_client.head_object(Bucket=storage_bucket, - Key=key)['Metadata'] + storage_bucket = f"nva-resource-storage-{account_id}" + return s3_client.head_object(Bucket=storage_bucket, Key=key)["Metadata"] def tag_referenced_files(dynamo_client, s3_resource, account_id, resources_table_name): - storage_bucket = f'nva-resource-storage-{account_id}' + storage_bucket = f"nva-resource-storage-{account_id}" tagged_files = 0 evaluated_files = 0 - paginator = dynamo_client.get_paginator('scan') + paginator = dynamo_client.get_paginator("scan") page_iterator = paginator.paginate( TableName=resources_table_name, - IndexName='ResourcesByIdentifier', - PaginationConfig={'PageSize': 700} + IndexName="ResourcesByIdentifier", + PaginationConfig={"PageSize": 700}, ) for page in page_iterator: - - items = page['Items'] + items = page["Items"] for item in items: data = extract_item_data(item) result = json.loads(data) - identifier = result['identifier'] - if 'entityDescription' in result: - entity_description = result['entityDescription'] - if 'publicationDate' in entity_description: - if 'associatedArtifacts' in result: - associated_artifacts = result['associatedArtifacts'] + identifier = result["identifier"] + if "entityDescription" in result: + entity_description = result["entityDescription"] + if "publicationDate" in entity_description: + if "associatedArtifacts" in result: + associated_artifacts = result["associatedArtifacts"] for associated_artifact in associated_artifacts: - if 'identifier' in associated_artifact: + if "identifier" in associated_artifact: evaluated_files = evaluated_files + 1 - key = associated_artifact['identifier'] + key = associated_artifact["identifier"] tagged_files = tagged_files + update_file_metadata( - s3_resource, - identifier, - key, - storage_bucket) + s3_resource, identifier, key, storage_bucket + ) if evaluated_files % 100 == 0: - print(f'Evaluated {evaluated_files} files, ' - + f'tagged {tagged_files}') + print( + f"Evaluated {evaluated_files} files, " + + f"tagged {tagged_files}" + ) - print(f'Evaluated {evaluated_files} files, ' - + f'tagged {tagged_files}') + print(f"Evaluated {evaluated_files} files, " + f"tagged {tagged_files}") def reset_tags(s3_client, s3_resource, accountId): - storage_bucket = f'nva-resource-storage-{accountId}' + storage_bucket = f"nva-resource-storage-{accountId}" - paginator = s3_client.get_paginator('list_objects_v2') + paginator = s3_client.get_paginator("list_objects_v2") page_iterator = paginator.paginate( - Bucket=storage_bucket, - PaginationConfig={'PageSize': 1000} + Bucket=storage_bucket, PaginationConfig={"PageSize": 1000} ) count = 0 for page in page_iterator: - for obj in page['Contents']: - key = obj['Key'] + for obj in page["Contents"]: + key = obj["Key"] s3_resource.Object(storage_bucket, key).copy_from( - CopySource={'Bucket': storage_bucket, 'Key': key}, + CopySource={"Bucket": storage_bucket, "Key": key}, Metadata={}, - MetadataDirective='REPLACE' + MetadataDirective="REPLACE", ) count = count + 1 - print(f'Reset tags on {count} objects!') - print('Done!') + print(f"Reset tags on {count} objects!") + print("Done!") def extract_item_data(item): - gz_data = item['data'] - return zlib.decompress(gz_data['B'], -zlib.MAX_WBITS) + gz_data = item["data"] + return zlib.decompress(gz_data["B"], -zlib.MAX_WBITS) -def update_file_metadata( - s3_resource, - publication_identifier, - file_key, - bucket_name): +def update_file_metadata(s3_resource, publication_identifier, file_key, bucket_name): target = s3_resource.Object(bucket_name, file_key) if MetadataKey in target.metadata: return 0 else: - target.metadata.update( - {'nva-publication-identifier': publication_identifier}) + target.metadata.update({"nva-publication-identifier": publication_identifier}) s3_resource.Object(bucket_name, file_key).copy_from( - CopySource={'Bucket': bucket_name, 'Key': target.key}, + CopySource={"Bucket": bucket_name, "Key": target.key}, Metadata=target.metadata, - MetadataDirective='REPLACE' + MetadataDirective="REPLACE", ) - print('Updated metadata for file ' + file_key) + print("Updated metadata for file " + file_key) return 1 -if __name__ == '__main__': +if __name__ == "__main__": argParser = argparse.ArgumentParser() - argParser.add_argument("command", - choices=[ - "tag-files", - "delete-untagged-files", - "reset-tags"]) + argParser.add_argument( + "command", choices=["tag-files", "delete-untagged-files", "reset-tags"] + ) argParser.add_argument("resourcesTableName") args = argParser.parse_args() - _dynamodb_client = boto3.client('dynamodb', region_name='eu-west-1') - _s3_client = boto3.client('s3', region_name='eu-west-1') - _s3_resource = boto3.resource('s3', region_name='eu-west-1') + _dynamodb_client = boto3.client("dynamodb", region_name="eu-west-1") + _s3_client = boto3.client("s3", region_name="eu-west-1") + _s3_resource = boto3.resource("s3", region_name="eu-west-1") _resources_table_name = args.resourcesTableName _session = boto3.Session() - _sts_client = _session.client('sts') + _sts_client = _session.client("sts") _accountId = _sts_client.get_caller_identity() if args.command == "tag-files": - tag_referenced_files(_dynamodb_client, _s3_resource, _accountId, _resources_table_name) + tag_referenced_files( + _dynamodb_client, _s3_resource, _accountId, _resources_table_name + ) elif args.command == "delete-untagged-files": delete_untagged_files(_s3_client, _accountId) elif args.command == "reset-tags":