diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b2c79dca..6be41cba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,11 +17,13 @@ repos: rev: v1.4.1 hooks: - id: remove-tabs - #- repo: https://github.com/pre-commit/mirrors-isort - # rev: v5.10.1 - # hooks: - # - id: isort - # name: isort (python) - # entry: isort - # language: python - # types: [python] + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + - repo: https://github.com/psf/black + rev: 23.3.0 + hooks: + - id: black + language_version: python3 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..7847e0a3 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,7 @@ +[tool.black] +line-length = 88 +target-version = ['py39'] +include = '\.pyi?$' + +[tool.isort] +profile = "black" diff --git a/src/glue/jobs/compare_parquet_datasets.py b/src/glue/jobs/compare_parquet_datasets.py index 525bf547..7793549d 100644 --- a/src/glue/jobs/compare_parquet_datasets.py +++ b/src/glue/jobs/compare_parquet_datasets.py @@ -1,10 +1,10 @@ -from collections import namedtuple import datetime import json import logging import os import sys import zipfile +from collections import namedtuple from io import BytesIO, StringIO from typing import Dict, List, NamedTuple, Union diff --git a/src/glue/jobs/json_to_parquet.py b/src/glue/jobs/json_to_parquet.py index 3d0741a0..2ac661b3 100644 --- a/src/glue/jobs/json_to_parquet.py +++ b/src/glue/jobs/json_to_parquet.py @@ -20,13 +20,13 @@ import pandas from awsglue import DynamicFrame from awsglue.context import GlueContext -from awsglue.job import Job from awsglue.gluetypes import StructType +from awsglue.job import Job from awsglue.utils import getResolvedOptions from pyspark import SparkContext from pyspark.sql import Window -from pyspark.sql.functions import row_number, col from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import col, row_number # Configure logger to use ECS formatting logger = logging.getLogger(__name__) @@ -46,14 +46,23 @@ "fitbitintradaycombined": ["ParticipantIdentifier", "Type", "DateTime"], "fitbitrestingheartrates": ["ParticipantIdentifier", "Date"], "fitbitsleeplogs": ["ParticipantIdentifier", "LogId"], - "healthkitv2characteristics": ["ParticipantIdentifier", "HealthKitCharacteristicKey"], + "healthkitv2characteristics": [ + "ParticipantIdentifier", + "HealthKitCharacteristicKey", + ], "healthkitv2samples": ["ParticipantIdentifier", "HealthKitSampleKey"], "healthkitv2heartbeat": ["ParticipantIdentifier", "HealthKitHeartbeatSampleKey"], "healthkitv2statistics": ["ParticipantIdentifier", "HealthKitStatisticKey"], - "healthkitv2clinicalrecords": ["ParticipantIdentifier", "HealthKitClinicalRecordKey"], + "healthkitv2clinicalrecords": [ + "ParticipantIdentifier", + "HealthKitClinicalRecordKey", + ], "healthkitv2electrocardiogram": ["ParticipantIdentifier", "HealthKitECGSampleKey"], "healthkitv2workouts": ["ParticipantIdentifier", "HealthKitWorkoutKey"], - "healthkitv2activitysummaries": ["ParticipantIdentifier", "HealthKitActivitySummaryKey"], + "healthkitv2activitysummaries": [ + "ParticipantIdentifier", + "HealthKitActivitySummaryKey", + ], "garminactivitydetailssummary": ["ParticipantIdentifier", "SummaryId"], "garminactivitysummary": ["ParticipantIdentifier", "SummaryId"], "garminbloodpressuresummary": ["ParticipantIdentifier", "SummaryId"], @@ -125,7 +134,7 @@ def get_table( glue_context: GlueContext, record_counts: dict, logger_context: dict, - ) -> 
DynamicFrame: +) -> DynamicFrame: """ Return a table as a DynamicFrame with an unambiguous schema. Additionally, we drop any superfluous partition_* fields which are added by Glue. @@ -168,7 +177,7 @@ def drop_table_duplicates( data_type: str, record_counts: dict[str, list], logger_context: dict, - ) -> DataFrame: +) -> DataFrame: """ Drop duplicate samples and superflous partition columns. @@ -193,19 +202,15 @@ def drop_table_duplicates( spark_df = table.toDF() if "InsertedDate" in spark_df.columns: window_ordered = window_unordered.orderBy( - col("InsertedDate").desc(), - col("export_end_date").desc() + col("InsertedDate").desc(), col("export_end_date").desc() ) else: - window_ordered = window_unordered.orderBy( - col("export_end_date").desc() - ) + window_ordered = window_unordered.orderBy(col("export_end_date").desc()) table_no_duplicates = ( - spark_df - .withColumn('ranking', row_number().over(window_ordered)) - .filter("ranking == 1") - .drop("ranking") - .cache() + spark_df.withColumn("ranking", row_number().over(window_ordered)) + .filter("ranking == 1") + .drop("ranking") + .cache() ) count_records_for_event( table=table_no_duplicates, @@ -224,7 +229,7 @@ def drop_deleted_healthkit_data( glue_database: str, record_counts: dict[str, list], logger_context: dict, - ) -> DataFrame: +) -> DataFrame: """ Drop records from a HealthKit table. @@ -255,8 +260,8 @@ def drop_deleted_healthkit_data( glue_client.get_table(DatabaseName=glue_database, Name=deleted_table_name) except glue_client.exceptions.EntityNotFoundException as error: logger.error( - f"Did not find table with name '{deleted_table_name}' ", - f"in database {glue_database}." + f"Did not find table with name '{deleted_table_name}' ", + f"in database {glue_database}.", ) raise error deleted_table_logger_context = deepcopy(logger_context) @@ -270,7 +275,9 @@ def drop_deleted_healthkit_data( logger_context=deleted_table_logger_context, ) if deleted_table_raw.count() == 0: - logger.info(f"The table for data type {deleted_data_type} did not contain any records.") + logger.info( + f"The table for data type {deleted_data_type} did not contain any records." + ) return table # we use `data_type` rather than `deleted_data_type` here because they share # an index (we don't bother including `deleted_data_type` in `INDEX_FIELD_MAP`). @@ -281,9 +288,9 @@ def drop_deleted_healthkit_data( logger_context=deleted_table_logger_context, ) table_with_deleted_samples_removed = table.join( - other=deleted_table, - on=INDEX_FIELD_MAP[data_type], - how="left_anti", + other=deleted_table, + on=INDEX_FIELD_MAP[data_type], + how="left_anti", ) count_records_for_event( table=table_with_deleted_samples_removed, @@ -300,7 +307,7 @@ def archive_existing_datasets( workflow_name: str, workflow_run_id: str, delete_upon_completion: bool, - ) -> list[dict]: +) -> list[dict]: """ Archives existing datasets in S3 by copying them to a timestamped subfolder within an "archive" folder. The format of the timestamped subfolder is: @@ -368,7 +375,7 @@ def write_table_to_s3( workflow_name: str, workflow_run_id: str, records_per_partition: int = int(1e6), - ) -> None: +) -> None: """ Write a DynamicFrame to S3 as a parquet dataset. @@ -441,7 +448,7 @@ def count_records_for_event( event: CountEventType, record_counts: dict[str, list], logger_context: dict, - ) -> dict[str, list]: +) -> dict[str, list]: """ Compute record count statistics for each `export_end_date`. 
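For reference, a minimal standalone sketch of the deduplication pattern that `drop_table_duplicates` expresses (now in black-formatted form): partition by the table's index fields, order by `export_end_date` descending, and keep only the first-ranked row. The column names mirror `INDEX_FIELD_MAP` entries; the toy rows are hypothetical.

from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, row_number

spark = SparkSession.builder.getOrCreate()
# Hypothetical sample data: two exports of the same sample key.
df = spark.createDataFrame(
    [("p1", "k1", "2023-01-01"), ("p1", "k1", "2023-02-01")],
    ["ParticipantIdentifier", "HealthKitSampleKey", "export_end_date"],
)
window_unordered = Window.partitionBy(["ParticipantIdentifier", "HealthKitSampleKey"])
window_ordered = window_unordered.orderBy(col("export_end_date").desc())
deduplicated = (
    df.withColumn("ranking", row_number().over(window_ordered))
    .filter("ranking == 1")
    .drop("ranking")
)
deduplicated.show()  # keeps only the most recent export per index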
@@ -488,7 +495,7 @@ def store_record_counts( namespace: str, workflow_name: str, workflow_run_id: str, - ) -> dict[str, str]: +) -> dict[str, str]: """ Uploads record counts as S3 objects. @@ -534,7 +541,7 @@ def add_index_to_table( table_name: str, processed_tables: dict[str, DynamicFrame], unprocessed_tables: dict[str, DynamicFrame], - ) -> DataFrame: +) -> DataFrame: """Add partition and index fields to a DynamicFrame. A DynamicFrame containing the top-level fields already includes the index @@ -667,9 +674,7 @@ def main() -> None: logger_context=logger_context, ) table_dynamic = DynamicFrame.fromDF( - dataframe=table, - glue_ctx=glue_context, - name=table_name + dataframe=table, glue_ctx=glue_context, name=table_name ) # Export new table records to parquet if has_nested_fields(table.schema): diff --git a/src/glue/jobs/s3_to_json.py b/src/glue/jobs/s3_to_json.py index 130c1a88..7d01669b 100644 --- a/src/glue/jobs/s3_to_json.py +++ b/src/glue/jobs/s3_to_json.py @@ -14,15 +14,16 @@ import sys import typing import zipfile + import boto3 import ecs_logging from awsglue.utils import getResolvedOptions DATA_TYPES_WITH_SUBTYPE = [ - "HealthKitV2Samples", - "HealthKitV2Statistics", - "HealthKitV2Samples_Deleted", - "HealthKitV2Statistics_Deleted" + "HealthKitV2Samples", + "HealthKitV2Statistics", + "HealthKitV2Samples_Deleted", + "HealthKitV2Statistics_Deleted", ] logger = logging.getLogger(__name__) @@ -32,13 +33,15 @@ logger.addHandler(handler) logger.propagate = False + def transform_object_to_array_of_objects( - json_obj_to_replace: dict, - key_name: str, - key_type: type, - value_name: str, - value_type: type, - logger_context: dict={},) -> list: + json_obj_to_replace: dict, + key_name: str, + key_type: type, + value_name: str, + value_type: type, + logger_context: dict = {}, +) -> list: """ Transforms a dictionary object into an array of dictionaries with specified key and value types. @@ -93,33 +96,26 @@ def transform_object_to_array_of_objects( except ValueError as error: key_value = None _log_error_transform_object_to_array_of_objects( - value=k, - value_type=key_type, - error=error, - logger_context=logger_context + value=k, value_type=key_type, error=error, logger_context=logger_context ) try: value_value = value_type(v) except ValueError as error: value_value = None _log_error_transform_object_to_array_of_objects( - value=v, - value_type=value_type, - error=error, - logger_context=logger_context + value=v, + value_type=value_type, + error=error, + logger_context=logger_context, ) - obj = { - key_name: key_value, - value_name: value_value - } + obj = {key_name: key_value, value_name: value_value} array_of_obj.append(obj) return array_of_obj + def _log_error_transform_object_to_array_of_objects( - value: typing.Any, - value_type: type, - error, - logger_context: dict) -> None: + value: typing.Any, value_type: type, error, logger_context: dict +) -> None: """ Logging helper for `transform_object_to_array_of_objects` @@ -135,40 +131,41 @@ def _log_error_transform_object_to_array_of_objects( """ value_error = "Failed to cast %s to %s." 
logger.error( - value_error, value, value_type, - extra=dict( - merge_dicts( - logger_context, - { - "error.message": repr(error), - "error.type": type(error).__name__, - "event.kind": "alert", - "event.category": ["configuration"], - "event.type": ["change"], - "event.outcome": "failure" - } - ) + value_error, + value, + value_type, + extra=dict( + merge_dicts( + logger_context, + { + "error.message": repr(error), + "error.type": type(error).__name__, + "event.kind": "alert", + "event.category": ["configuration"], + "event.type": ["change"], + "event.outcome": "failure", + }, ) + ), ) logger.warning( - "Setting %s to None", value, - extra=dict( - merge_dicts( - logger_context, - { - "event.kind": "alert", - "event.category": ["configuration"], - "event.type": ["deletion"], - "event.outcome": "success" - } - ) + "Setting %s to None", + value, + extra=dict( + merge_dicts( + logger_context, + { + "event.kind": "alert", + "event.category": ["configuration"], + "event.type": ["deletion"], + "event.outcome": "success", + }, ) + ), ) -def transform_json( - json_obj: dict, - metadata: dict, - logger_context: dict={}) -> dict: + +def transform_json(json_obj: dict, metadata: dict, logger_context: dict = {}) -> dict: """ Perform the following transformations: @@ -199,10 +196,7 @@ def transform_json( Returns: json_obj (dict) The JSON object with the relevant transformations applied. """ - json_obj = _add_universal_properties( - json_obj=json_obj, - metadata=metadata - ) + json_obj = _add_universal_properties(json_obj=json_obj, metadata=metadata) if metadata["type"] in DATA_TYPES_WITH_SUBTYPE: # This puts the `Type` property back where Apple intended it to be json_obj["Type"] = metadata["subtype"] @@ -211,51 +205,59 @@ def transform_json( json_obj["Value"] = json.loads(json_obj["Value"]) if metadata["type"] == "EnrolledParticipants": json_obj = _cast_custom_fields_to_array( - json_obj=json_obj, - logger_context=logger_context, + json_obj=json_obj, + logger_context=logger_context, ) # These Garmin data types have fields which would be better formatted # as an array of objects. 
garmin_transform_types = { - "GarminDailySummary": { - "TimeOffsetHeartRateSamples": (("OffsetInSeconds", int), ("HeartRate", int)) - }, - "GarminHrvSummary": { - "HrvValues": (("OffsetInSeconds", int), ("Hrv", int)) - }, - "GarminPulseOxSummary": { - "TimeOffsetSpo2Values": (("OffsetInSeconds", int), ("Spo2Value", int)) - }, - "GarminRespirationSummary": { - "TimeOffsetEpochToBreaths": (("OffsetInSeconds", int), ("Breaths", float)) - }, - "GarminSleepSummary": { - "TimeOffsetSleepSpo2": (("OffsetInSeconds", int), ("Spo2Value", int)), - "TimeOffsetSleepRespiration": (("OffsetInSeconds", int), ("Breaths", float)) - }, - "GarminStressDetailSummary": { - "TimeOffsetStressLevelValues": (("OffsetInSeconds", int), ("StressLevel", int)), - "TimeOffsetBodyBatteryValues": (("OffsetInSeconds", int), ("BodyBattery", int)) - }, - "GarminThirdPartyDailySummary": { - "TimeOffsetHeartRateSamples": (("OffsetInSeconds", int), ("HeartRate", int)) - }, - "GarminHealthSnapshotSummary": { - "Summaries.EpochSummaries": (("OffsetInSeconds", int), ("Value", float)) - } + "GarminDailySummary": { + "TimeOffsetHeartRateSamples": (("OffsetInSeconds", int), ("HeartRate", int)) + }, + "GarminHrvSummary": {"HrvValues": (("OffsetInSeconds", int), ("Hrv", int))}, + "GarminPulseOxSummary": { + "TimeOffsetSpo2Values": (("OffsetInSeconds", int), ("Spo2Value", int)) + }, + "GarminRespirationSummary": { + "TimeOffsetEpochToBreaths": (("OffsetInSeconds", int), ("Breaths", float)) + }, + "GarminSleepSummary": { + "TimeOffsetSleepSpo2": (("OffsetInSeconds", int), ("Spo2Value", int)), + "TimeOffsetSleepRespiration": ( + ("OffsetInSeconds", int), + ("Breaths", float), + ), + }, + "GarminStressDetailSummary": { + "TimeOffsetStressLevelValues": ( + ("OffsetInSeconds", int), + ("StressLevel", int), + ), + "TimeOffsetBodyBatteryValues": ( + ("OffsetInSeconds", int), + ("BodyBattery", int), + ), + }, + "GarminThirdPartyDailySummary": { + "TimeOffsetHeartRateSamples": (("OffsetInSeconds", int), ("HeartRate", int)) + }, + "GarminHealthSnapshotSummary": { + "Summaries.EpochSummaries": (("OffsetInSeconds", int), ("Value", float)) + }, } if metadata["type"] in garmin_transform_types: json_obj = _transform_garmin_data_types( - json_obj=json_obj, - data_type_transforms=garmin_transform_types[metadata["type"]], - logger_context=logger_context, + json_obj=json_obj, + data_type_transforms=garmin_transform_types[metadata["type"]], + logger_context=logger_context, ) return json_obj + def _add_universal_properties( - json_obj: dict, - metadata: dict, + json_obj: dict, + metadata: dict, ) -> dict: """ Adds properties which ought to exist for every JSON object. @@ -282,6 +284,7 @@ def _add_universal_properties( json_obj["cohort"] = metadata["cohort"] return json_obj + def _cast_custom_fields_to_array(json_obj: dict, logger_context: dict) -> dict: """ Cast `CustomFields` property values to an array. @@ -301,10 +304,9 @@ def _cast_custom_fields_to_array(json_obj: dict, logger_context: dict) -> dict: json_obj (dict) The JSON object with the relevant transformations applied. """ for field_name in ["Symptoms", "Treatments"]: - if ( - field_name in json_obj["CustomFields"] - and isinstance(json_obj["CustomFields"][field_name], str) - ): + if field_name in json_obj["CustomFields"] and isinstance( + json_obj["CustomFields"][field_name], str + ): if len(json_obj["CustomFields"][field_name]) > 0: # This JSON string was written in a couple different ways # in the testing data: "[{\\\"id\\\": ..." and "[{\"id\": ..." 
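A standalone sketch of what `transform_object_to_array_of_objects` does for the Garmin time-offset maps configured above; the sample payload and casts are illustrative, not taken from real exports, and the job's logging is omitted.

def to_array_of_objects(obj, key_name, key_type, value_name, value_type):
    # Each {offset: value} pair becomes its own object with typed fields;
    # values that fail to cast are set to None (the job also logs these).
    array_of_obj = []
    for k, v in obj.items():
        try:
            key_value = key_type(k)
        except ValueError:
            key_value = None
        try:
            value_value = value_type(v)
        except ValueError:
            value_value = None
        array_of_obj.append({key_name: key_value, value_name: value_value})
    return array_of_obj

time_offset_heart_rate = {"0": 70, "15": 72}  # hypothetical Garmin payload
print(to_array_of_objects(time_offset_heart_rate, "OffsetInSeconds", int, "HeartRate", int))
# [{'OffsetInSeconds': 0, 'HeartRate': 70}, {'OffsetInSeconds': 15, 'HeartRate': 72}]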
@@ -313,38 +315,41 @@ def _cast_custom_fields_to_array(json_obj: dict, logger_context: dict) -> dict: # written as an object rather than a string). try: json_obj["CustomFields"][field_name] = json.loads( - json_obj["CustomFields"][field_name] + json_obj["CustomFields"][field_name] ) except json.JSONDecodeError as error: # If it's not propertly formatted JSON, then we # can't read it, and instead store an empty list logger.error( - (f"Problem CustomFields.{field_name}: " - f"{json_obj['CustomFields'][field_name]}"), - extra=dict( - merge_dicts( - logger_context, - { - "error.message": repr(error), - "error.type": "json.JSONDecodeError", - "event.kind": "alert", - "event.category": ["change"], - "event.type": ["error"], - "event.outcome": "failure", - } - ) + ( + f"Problem CustomFields.{field_name}: " + f"{json_obj['CustomFields'][field_name]}" + ), + extra=dict( + merge_dicts( + logger_context, + { + "error.message": repr(error), + "error.type": "json.JSONDecodeError", + "event.kind": "alert", + "event.category": ["change"], + "event.type": ["error"], + "event.outcome": "failure", + }, ) + ), ) json_obj["CustomFields"][field_name] = [] else: json_obj["CustomFields"][field_name] = [] return json_obj + def _transform_garmin_data_types( - json_obj: dict, - data_type_transforms: dict, - logger_context: dict, - ) -> dict: + json_obj: dict, + data_type_transforms: dict, + logger_context: dict, +) -> dict: """ Transform objects to an array of objects for relevant Garmin data types. @@ -378,12 +383,12 @@ def _transform_garmin_data_types( prop_name = property_hierarchy[0] if prop_name in json_obj: array_of_obj = transform_object_to_array_of_objects( - json_obj_to_replace=json_obj[prop_name], - key_name=key_name, - key_type=key_type, - value_name=value_name, - value_type=value_type, - logger_context=logger_context + json_obj_to_replace=json_obj[prop_name], + key_name=key_name, + key_type=key_type, + value_name=value_name, + value_type=value_type, + logger_context=logger_context, ) json_obj[prop_name] = array_of_obj if len(property_hierarchy) == 2: @@ -393,16 +398,17 @@ def _transform_garmin_data_types( for obj in json_obj[prop_name]: if sub_prop_name in obj: array_of_obj = transform_object_to_array_of_objects( - json_obj_to_replace=obj[sub_prop_name], - key_name=key_name, - key_type=key_type, - value_name=value_name, - value_type=value_type, - logger_context=logger_context + json_obj_to_replace=obj[sub_prop_name], + key_name=key_name, + key_type=key_type, + value_name=value_name, + value_type=value_type, + logger_context=logger_context, ) obj[sub_prop_name] = array_of_obj return json_obj + def get_output_filename(metadata: dict, part_number: int) -> str: """ Get a formatted file name. 
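The `CustomFields` handling above reduces to this pattern: parse the field when it is a non-empty JSON string, and fall back to an empty list when it is empty or malformed. A minimal sketch, with the error logging omitted:

import json

def parse_custom_field(raw):
    if not isinstance(raw, str):
        return raw  # already written as an object rather than a string
    if len(raw) == 0:
        return []
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # Unreadable JSON is replaced with an empty list rather than failing.
        return []

print(parse_custom_field('[{"id": "a", "name": "Headache"}]'))  # parsed list
print(parse_custom_field("not json"))  # []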
@@ -423,32 +429,32 @@ def get_output_filename(metadata: dict, part_number: int) -> str: """ if metadata["type"] in DATA_TYPES_WITH_SUBTYPE: output_fname = "{}_{}_{}-{}.part{}.ndjson".format( - metadata["type"], - metadata["subtype"], - metadata["start_date"].strftime("%Y%m%d"), - metadata["end_date"].strftime("%Y%m%d"), - part_number + metadata["type"], + metadata["subtype"], + metadata["start_date"].strftime("%Y%m%d"), + metadata["end_date"].strftime("%Y%m%d"), + part_number, ) elif metadata["start_date"] is None: output_fname = "{}_{}.part{}.ndjson".format( - metadata["type"], - metadata["end_date"].strftime("%Y%m%d"), - part_number - ) + metadata["type"], metadata["end_date"].strftime("%Y%m%d"), part_number + ) else: output_fname = "{}_{}-{}.part{}.ndjson".format( - metadata["type"], - metadata["start_date"].strftime("%Y%m%d"), - metadata["end_date"].strftime("%Y%m%d"), - part_number + metadata["type"], + metadata["start_date"].strftime("%Y%m%d"), + metadata["end_date"].strftime("%Y%m%d"), + part_number, ) return output_fname + def transform_block( - input_json: typing.IO, - metadata: dict, - logger_context: dict={}, - block_size: int=10000): + input_json: typing.IO, + metadata: dict, + logger_context: dict = {}, + block_size: int = 10000, +): """ A generator function which yields a block of transformed JSON records. @@ -472,25 +478,25 @@ def transform_block( for json_line in input_json: json_obj = json.loads(json_line) json_obj = transform_json( - json_obj=json_obj, - metadata=metadata, - logger_context=logger_context + json_obj=json_obj, metadata=metadata, logger_context=logger_context ) block.append(json_obj) if len(block) == block_size: yield block block = [] - if block: # yield final block + if block: # yield final block yield block + def write_file_to_json_dataset( - z: zipfile.ZipFile, - json_path: str, - metadata: dict, - workflow_run_properties: dict, - logger_context: dict={}, - delete_upon_successful_upload: bool=True, - file_size_limit: float=1e8) -> list: + z: zipfile.ZipFile, + json_path: str, + metadata: dict, + workflow_run_properties: dict, + logger_context: dict = {}, + delete_upon_successful_upload: bool = True, + file_size_limit: float = 1e8, +) -> list: """ Write JSON from a zipfile to a JSON dataset. 
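The generator in `transform_block` follows a standard batching pattern; a self-contained sketch with a toy iterable in place of the NDJSON stream:

def yield_blocks(records, block_size=3):
    block = []
    for record in records:
        block.append(record)
        if len(block) == block_size:
            yield block
            block = []
    if block:  # yield the final, partial block
        yield block

print(list(yield_blocks(range(7))))
# [[0, 1, 2], [3, 4, 5], [6]]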
@@ -518,14 +524,12 @@ def write_file_to_json_dataset( """ # Configuration related to where we write our part files part_dir = os.path.join( - f"dataset={metadata['type']}", f"cohort={metadata['cohort']}") + f"dataset={metadata['type']}", f"cohort={metadata['cohort']}" + ) os.makedirs(part_dir, exist_ok=True) part_number = 0 output_path = get_part_path( - metadata=metadata, - part_number=part_number, - part_dir=part_dir, - touch=True + metadata=metadata, part_number=part_number, part_dir=part_dir, touch=True ) # We will attach file metadata to the uploaded S3 object @@ -535,28 +539,26 @@ def write_file_to_json_dataset( current_output_path = output_path line_count = 0 for transformed_block in transform_block( - input_json=input_json, - metadata=metadata, - logger_context=logger_context + input_json=input_json, metadata=metadata, logger_context=logger_context ): current_file_size = os.path.getsize(current_output_path) if current_file_size > file_size_limit: # Upload completed part file _upload_file_to_json_dataset( - file_path=current_output_path, - s3_metadata=s3_metadata, - workflow_run_properties=workflow_run_properties, - delete_upon_successful_upload=delete_upon_successful_upload + file_path=current_output_path, + s3_metadata=s3_metadata, + workflow_run_properties=workflow_run_properties, + delete_upon_successful_upload=delete_upon_successful_upload, ) uploaded_files.append(current_output_path) # Update output path to next part part_number += 1 current_output_path = get_part_path( - metadata=metadata, - part_number=part_number, - part_dir=part_dir, - touch=True + metadata=metadata, + part_number=part_number, + part_dir=part_dir, + touch=True, ) with open(current_output_path, "a") as f_out: # Write block data to part file @@ -565,10 +567,10 @@ def write_file_to_json_dataset( f_out.write("{}\n".format(json.dumps(transformed_record))) # Upload final block _upload_file_to_json_dataset( - file_path=current_output_path, - s3_metadata=s3_metadata, - workflow_run_properties=workflow_run_properties, - delete_upon_successful_upload=delete_upon_successful_upload + file_path=current_output_path, + s3_metadata=s3_metadata, + workflow_run_properties=workflow_run_properties, + delete_upon_successful_upload=delete_upon_successful_upload, ) uploaded_files.append(current_output_path) logger_extra = dict( @@ -581,16 +583,16 @@ def write_file_to_json_dataset( "event.type": ["info", "creation"], "event.action": "list-file-properties", "labels": { - k: v.isoformat() - if isinstance(v, datetime.datetime) else v + k: v.isoformat() if isinstance(v, datetime.datetime) else v for k, v in metadata.items() - } - } + }, + }, ) ) logger.info("Output file attributes", extra=logger_extra) return uploaded_files + def _derive_str_metadata(metadata: dict) -> dict: """ Format metadata values as strings @@ -601,19 +603,19 @@ def _derive_str_metadata(metadata: dict) -> dict: Returns: (dict) The S3 metadata """ - s3_metadata = { - k: v for k, v in metadata.items() if v is not None - } + s3_metadata = {k: v for k, v in metadata.items() if v is not None} for k, v in s3_metadata.items(): if isinstance(v, datetime.datetime): s3_metadata[k] = v.isoformat() return s3_metadata + def _upload_file_to_json_dataset( - file_path: str, - s3_metadata: dict, - workflow_run_properties: dict, - delete_upon_successful_upload: bool,) -> str: + file_path: str, + s3_metadata: dict, + workflow_run_properties: dict, + delete_upon_successful_upload: bool, +) -> str: """ A helper function for `write_file_to_json_dataset` which handles the actual 
uploading of the data to S3. @@ -633,19 +635,19 @@ def _upload_file_to_json_dataset( s3_output_key = os.path.join( workflow_run_properties["namespace"], workflow_run_properties["json_prefix"], - file_path + file_path, ) basic_file_info = get_basic_file_info(file_path=file_path) with open(file_path, "rb") as f_in: response = s3_client.put_object( - Body = f_in, - Bucket = workflow_run_properties["json_bucket"], - Key = s3_output_key, - Metadata = s3_metadata + Body=f_in, + Bucket=workflow_run_properties["json_bucket"], + Key=s3_output_key, + Metadata=s3_metadata, ) logger.info( "Upload to S3", - extra = { + extra={ **basic_file_info, "event.kind": "event", "event.category": ["database"], @@ -654,14 +656,15 @@ def _upload_file_to_json_dataset( "labels": { **s3_metadata, "bucket": workflow_run_properties["json_bucket"], - "key": s3_output_key - } - } + "key": s3_output_key, + }, + }, ) if delete_upon_successful_upload: os.remove(file_path) return s3_output_key + def merge_dicts(x: dict, y: dict) -> typing.Generator: """ Merge two dictionaries recursively. @@ -679,9 +682,9 @@ def merge_dicts(x: dict, y: dict) -> typing.Generator: overlapping_keys = x.keys() & y.keys() for key in all_keys: if ( - key in overlapping_keys - and isinstance(x[key], dict) - and isinstance(y[key], dict) + key in overlapping_keys + and isinstance(x[key], dict) + and isinstance(y[key], dict) ): # Merge child dictionaries yield (key, dict(merge_dicts(x[key], y[key]))) @@ -697,11 +700,13 @@ def merge_dicts(x: dict, y: dict) -> typing.Generator: # The key:value pair from y is retained yield (key, y[key]) + def get_part_path( - metadata: dict, - part_number: int, - part_dir: str, - touch: bool,): + metadata: dict, + part_number: int, + part_dir: str, + touch: bool, +): """ A helper function for `write_file_to_json_dataset` @@ -721,10 +726,7 @@ def get_part_path( FileExistsError: If touch is True and a file already exists at the part path. """ - output_filename = get_output_filename( - metadata=metadata, - part_number=part_number - ) + output_filename = get_output_filename(metadata=metadata, part_number=part_number) output_path = os.path.join(part_dir, output_filename) if touch: os.makedirs(part_dir, exist_ok=True) @@ -733,6 +735,7 @@ def get_part_path( pass return output_path + def get_metadata(basename: str) -> dict: """ Get metadata of a file by parsing its basename. @@ -756,20 +759,20 @@ def get_metadata(basename: str) -> dict: metadata["type"] = basename_components[0] if "-" in basename_components[-1]: start_date, end_date = basename_components[-1].split("-") - metadata["start_date"] = \ - datetime.datetime.strptime(start_date, "%Y%m%d") - metadata["end_date"] = \ - datetime.datetime.strptime(end_date, "%Y%m%d") + metadata["start_date"] = datetime.datetime.strptime(start_date, "%Y%m%d") + metadata["end_date"] = datetime.datetime.strptime(end_date, "%Y%m%d") else: metadata["start_date"] = None - metadata["end_date"] = \ - datetime.datetime.strptime(basename_components[-1], "%Y%m%d") + metadata["end_date"] = datetime.datetime.strptime( + basename_components[-1], "%Y%m%d" + ) if metadata["type"] in DATA_TYPES_WITH_SUBTYPE: metadata["subtype"] = basename_components[1] if "HealthKitV2" in metadata["type"] and basename_components[-2] == "Deleted": metadata["type"] = "{}_Deleted".format(metadata["type"]) return metadata + def get_basic_file_info(file_path: str) -> dict: """ Returns a dictionary of basic information about a file. 
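For context, the recursive merge that `merge_dicts` implements (used throughout to build structured logging `extra` payloads) behaves roughly like this sketch; the sample dictionaries are hypothetical.

def merge_dicts(x: dict, y: dict):
    for key in x.keys() | y.keys():
        if key in x and key in y and isinstance(x[key], dict) and isinstance(y[key], dict):
            yield (key, dict(merge_dicts(x[key], y[key])))  # merge child dicts recursively
        elif key in x and key not in y:
            yield (key, x[key])
        else:
            yield (key, y[key])  # y's value is retained on non-dict conflicts

base_context = {"labels": {"cohort": "adults_v1"}, "event.kind": "metric"}
extra = dict(merge_dicts(base_context, {"labels": {"type": "FitbitSleepLogs"}, "event.outcome": "success"}))
print(sorted(extra["labels"].items()))  # [('cohort', 'adults_v1'), ('type', 'FitbitSleepLogs')]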
@@ -788,15 +791,12 @@ def get_basic_file_info(file_path: str) -> dict: "file.type": "file", "file.path": file_path, "file.name": os.path.basename(file_path), - "file.extension": os.path.splitext(file_path)[-1][1:] + "file.extension": os.path.splitext(file_path)[-1][1:], } return basic_file_info -def process_record( - s3_obj: dict, - cohort: str, - workflow_run_properties: dict - ) -> None: + +def process_record(s3_obj: dict, cohort: str, workflow_run_properties: dict) -> None: """ Write the contents of a .zip archive stored on S3 to their respective JSON dataset. Each file contained in the .zip archive has its line count logged. @@ -814,104 +814,102 @@ def process_record( """ with zipfile.ZipFile(io.BytesIO(s3_obj["Body"])) as z: non_empty_contents = [ - f.filename for f in z.filelist - if "/" not in f.filename and - "Manifest" not in f.filename and - f.file_size > 0 + f.filename + for f in z.filelist + if "/" not in f.filename + and "Manifest" not in f.filename + and f.file_size > 0 ] for json_path in non_empty_contents: with z.open(json_path, "r") as f: line_count = sum(1 for _ in f) basic_file_info = get_basic_file_info(file_path=json_path) metadata = get_metadata( - basename=os.path.basename(json_path), + basename=os.path.basename(json_path), ) metadata["cohort"] = cohort metadata_str_keys = _derive_str_metadata(metadata=metadata) logger_context = { - **basic_file_info, - "labels": metadata_str_keys, - "process.parent.pid": workflow_run_properties["WORKFLOW_RUN_ID"], - "process.parent.name": workflow_run_properties["WORKFLOW_NAME"], + **basic_file_info, + "labels": metadata_str_keys, + "process.parent.pid": workflow_run_properties["WORKFLOW_RUN_ID"], + "process.parent.name": workflow_run_properties["WORKFLOW_NAME"], } logger.info( - "Input file attributes", - extra=dict( - merge_dicts( - logger_context, - { - "file.size": sys.getsizeof(json_path), - "file.LineCount": line_count, - "event.kind": "metric", - "event.category": ["file"], - "event.type": ["info", "access"], - "event.action": "list-file-properties", - } - ) + "Input file attributes", + extra=dict( + merge_dicts( + logger_context, + { + "file.size": sys.getsizeof(json_path), + "file.LineCount": line_count, + "event.kind": "metric", + "event.category": ["file"], + "event.type": ["info", "access"], + "event.action": "list-file-properties", + }, ) + ), ) write_file_to_json_dataset( - z=z, - json_path=json_path, - metadata=metadata, - workflow_run_properties=workflow_run_properties, - logger_context=logger_context) + z=z, + json_path=json_path, + metadata=metadata, + workflow_run_properties=workflow_run_properties, + logger_context=logger_context, + ) -def main() -> None: +def main() -> None: # Instantiate boto clients glue_client = boto3.client("glue") s3_client = boto3.client("s3") # Get job and workflow arguments - args = getResolvedOptions( - sys.argv, - ["WORKFLOW_NAME", - "WORKFLOW_RUN_ID" - ] - ) + args = getResolvedOptions(sys.argv, ["WORKFLOW_NAME", "WORKFLOW_RUN_ID"]) workflow_run_properties = glue_client.get_workflow_run_properties( - Name=args["WORKFLOW_NAME"], - RunId=args["WORKFLOW_RUN_ID"])["RunProperties"] + Name=args["WORKFLOW_NAME"], RunId=args["WORKFLOW_RUN_ID"] + )["RunProperties"] workflow_run_properties = {**args, **workflow_run_properties} logger.debug( - "getResolvedOptions", - extra={ - "event.kind": "event", - "event.category": ["process"], - "event.type": ["info"], - "event.action": "get-job-arguments", - "labels": args - } + "getResolvedOptions", + extra={ + "event.kind": "event", + "event.category": 
["process"], + "event.type": ["info"], + "event.action": "get-job-arguments", + "labels": args, + }, ) logger.debug( - "get_workflow_run_properties", - extra={ - "event.kind": "event", - "event.category": ["process"], - "event.type": ["info"], - "event.action": "get-workflow-arguments", - "labels": workflow_run_properties - } + "get_workflow_run_properties", + extra={ + "event.kind": "event", + "event.category": ["process"], + "event.type": ["info"], + "event.action": "get-workflow-arguments", + "labels": workflow_run_properties, + }, ) # Load messages to be processed messages = json.loads(workflow_run_properties["messages"]) for message in messages: logger.info( - "Retrieving S3 object", - extra={ - "event.kind": "event", - "event.category": ["database"], - "event.type": ["access"], - "event.action": "get-bucket-object", - "labels": {"bucket": message["source_bucket"], - "key": message["source_key"]} - } + "Retrieving S3 object", + extra={ + "event.kind": "event", + "event.category": ["database"], + "event.type": ["access"], + "event.action": "get-bucket-object", + "labels": { + "bucket": message["source_bucket"], + "key": message["source_key"], + }, + }, ) s3_obj = s3_client.get_object( - Bucket = message["source_bucket"], - Key = message["source_key"] + Bucket=message["source_bucket"], Key=message["source_key"] ) s3_obj["Body"] = s3_obj["Body"].read() cohort = None @@ -921,25 +919,28 @@ def main() -> None: cohort = "pediatric_v1" else: logger.warning( - "Could not determine the cohort of object at %s" - "This file will not be written to a JSON dataset.", - f"s3://{message['source_bucket']}/{message['source_key']}. ", - extra={ - "file.name": message["source_key"], - "event.kind": "alert", - "event.category": ["configuration"], - "event.type": ["creation"], - "event.outcome": "failure", - "labels": {"bucket": message["source_bucket"], - "key": message["source_key"]} - } + "Could not determine the cohort of object at %s" + "This file will not be written to a JSON dataset.", + f"s3://{message['source_bucket']}/{message['source_key']}. ", + extra={ + "file.name": message["source_key"], + "event.kind": "alert", + "event.category": ["configuration"], + "event.type": ["creation"], + "event.outcome": "failure", + "labels": { + "bucket": message["source_bucket"], + "key": message["source_key"], + }, + }, ) continue process_record( - s3_obj=s3_obj, - cohort=cohort, - workflow_run_properties=workflow_run_properties + s3_obj=s3_obj, + cohort=cohort, + workflow_run_properties=workflow_run_properties, ) + if __name__ == "__main__": main() diff --git a/src/lambda_function/dispatch/app.py b/src/lambda_function/dispatch/app.py index 35f14930..00564073 100644 --- a/src/lambda_function/dispatch/app.py +++ b/src/lambda_function/dispatch/app.py @@ -9,7 +9,7 @@ import logging import os import zipfile -from typing import Optional # use | for type hints in 3.10+ +from typing import Optional # use | for type hints in 3.10+ from urllib import parse import boto3 @@ -17,6 +17,7 @@ logger = logging.getLogger() logger.setLevel(logging.INFO) + def filter_object_info(object_info: dict) -> Optional[dict]: """ Filter out objects that should not be processed. @@ -60,6 +61,7 @@ def filter_object_info(object_info: dict) -> Optional[dict]: return None return object_info + def get_object_info(s3_event: dict) -> dict: """ Derive object info from an S3 event. 
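The member filter in `process_record` (and the equivalent check in the dispatch Lambda below) keeps only top-level, non-Manifest, non-empty files from the export archive. A self-contained sketch using an in-memory archive with hypothetical file names:

import io
import zipfile

def list_data_members(zip_bytes: bytes) -> list:
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as z:
        return [
            f.filename
            for f in z.filelist
            if "/" not in f.filename
            and "Manifest" not in f.filename
            and f.file_size > 0
        ]

buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as z:
    z.writestr("HealthKitV2Samples_20230112-20230114.json", '{"a": 1}\n')
    z.writestr("Manifest.json", "{}")
    z.writestr("nested/ignored.json", "{}")
print(list_data_members(buffer.getvalue()))  # only the HealthKitV2Samples file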
@@ -78,6 +80,7 @@ def get_object_info(s3_event: dict) -> dict: } return object_info + def get_archive_contents(archive_path: str, bucket: str, key: str) -> list[dict]: """ Inspect a ZIP archive for its file contents. @@ -100,20 +103,21 @@ def get_archive_contents(archive_path: str, bucket: str, key: str) -> list[dict] with zipfile.ZipFile(archive_path, "r") as archive: for path in archive.infolist(): if ( - "/" not in path.filename # necessary for pilot data only + "/" not in path.filename # necessary for pilot data only and "Manifest" not in path.filename and path.file_size > 0 ): file_info = { - "Bucket": bucket, - "Key": key, - "Path": path.filename, - "FileSize": path.file_size + "Bucket": bucket, + "Key": key, + "Path": path.filename, + "FileSize": path.file_size, } archive_contents.append(file_info) return archive_contents -def lambda_handler(event: dict, context:dict) -> None: + +def lambda_handler(event: dict, context: dict) -> None: """ This function serves as the entrypoint and will be triggered upon polling the input-to-dispatch SQS queue. @@ -138,16 +142,17 @@ def lambda_handler(event: dict, context:dict) -> None: s3_client=s3_client, sns_client=sns_client, dispatch_sns_arn=dispatch_sns_arn, - temp_zip_path=temp_zip_path + temp_zip_path=temp_zip_path, ) + def main( - event: dict, - context: dict, - sns_client: "botocore.client.SNS", - s3_client: "botocore.client.S3", - dispatch_sns_arn: str, - temp_zip_path: str + event: dict, + context: dict, + sns_client: "botocore.client.SNS", + s3_client: "botocore.client.S3", + dispatch_sns_arn: str, + temp_zip_path: str, ) -> None: """ This function should be invoked by `lambda_handler`. @@ -170,21 +175,20 @@ def main( logger.info(f"Received SNS message: {sns_message}") all_object_info_list = map(get_object_info, sns_message["Records"]) valid_object_info_list = [ - object_info - for object_info in all_object_info_list - if filter_object_info(object_info) is not None + object_info + for object_info in all_object_info_list + if filter_object_info(object_info) is not None ] for object_info in valid_object_info_list: s3_client.download_file(Filename=temp_zip_path, **object_info) logger.info(f"Getting archive contents for {object_info}") archive_contents = get_archive_contents( - archive_path=temp_zip_path, - bucket=object_info["Bucket"], - key=object_info["Key"] + archive_path=temp_zip_path, + bucket=object_info["Bucket"], + key=object_info["Key"], ) for file_info in archive_contents: logger.info(f"Publishing {file_info} to {dispatch_sns_arn}") sns_client.publish( - TopicArn=dispatch_sns_arn, - Message=json.dumps(file_info) + TopicArn=dispatch_sns_arn, Message=json.dumps(file_info) ) diff --git a/src/lambda_function/s3_event_config/app.py b/src/lambda_function/s3_event_config/app.py index 8d8d54ab..6cc22fb6 100644 --- a/src/lambda_function/s3_event_config/app.py +++ b/src/lambda_function/s3_event_config/app.py @@ -8,9 +8,9 @@ Only certain notification configurations work. Only `QueueConfigurations` are expected. 
""" -import os import json import logging +import os import typing from collections import defaultdict from enum import Enum @@ -31,7 +31,7 @@ def lambda_handler(event, context): delete_notification( s3_client=s3_client, bucket=os.environ["S3_SOURCE_BUCKET_NAME"], - bucket_key_prefix=os.environ["BUCKET_KEY_PREFIX"] + bucket_key_prefix=os.environ["BUCKET_KEY_PREFIX"], ) logger.info("Sending response to custom resource after Delete") elif event["RequestType"] in ["Update", "Create"]: @@ -53,6 +53,7 @@ class NotificationConfigurationType(Enum): """ Supported types for an S3 event configuration. """ + Topic = "Topic" Queue = "Queue" LambdaFunction = "LambdaFunction" @@ -62,6 +63,7 @@ class NotificationConfiguration: """ An abstraction of S3 event configurations. """ + def __init__(self, notification_type: NotificationConfigurationType, value: dict): self.type = notification_type.value self.value = value @@ -78,7 +80,9 @@ def get_arn(self): elif self.type == NotificationConfigurationType.LambdaFunction.value: arn = self.value["LambdaFunctionArn"] else: - raise ValueError(f"{self.type} is not a recognized notification configuration type.") + raise ValueError( + f"{self.type} is not a recognized notification configuration type." + ) return arn @@ -86,6 +90,7 @@ class BucketNotificationConfigurations: """ A convenience class for working with a collection of `NotificationConfiguration`s. """ + def __init__(self, notification_configurations: list[NotificationConfiguration]): self.configs = notification_configurations @@ -103,7 +108,7 @@ def to_dict(self): def get_bucket_notification_configurations( s3_client: boto3.client, bucket: str, - ) -> BucketNotificationConfigurations: +) -> BucketNotificationConfigurations: """ Gets the existing bucket notification configuration and the existing notification configurations for a specific destination type. 
@@ -116,30 +121,31 @@ def get_bucket_notification_configurations( Returns: BucketNotificationConfigurations """ - bucket_notification_configuration = \ - s3_client.get_bucket_notification_configuration(Bucket=bucket) + bucket_notification_configuration = s3_client.get_bucket_notification_configuration( + Bucket=bucket + ) all_notification_configurations = [] for configuration_type in NotificationConfigurationType: configuration_type_name = f"{configuration_type.value}Configurations" if configuration_type_name in bucket_notification_configuration: notification_configurations = [ - NotificationConfiguration( - notification_type=configuration_type, - value=config - ) - for config - in bucket_notification_configuration[configuration_type_name] + NotificationConfiguration( + notification_type=configuration_type, value=config + ) + for config in bucket_notification_configuration[configuration_type_name] ] all_notification_configurations.extend(notification_configurations) - bucket_notification_configurations = BucketNotificationConfigurations(all_notification_configurations) + bucket_notification_configurations = BucketNotificationConfigurations( + all_notification_configurations + ) return bucket_notification_configurations def get_notification_configuration( bucket_notification_configurations: BucketNotificationConfigurations, - bucket_key_prefix: typing.Union[str,None]=None, - bucket_key_suffix: typing.Union[str,None]=None, - ) -> typing.Union[NotificationConfiguration,None]: + bucket_key_prefix: typing.Union[str, None] = None, + bucket_key_suffix: typing.Union[str, None] = None, +) -> typing.Union[NotificationConfiguration, None]: """ Filter the list of existing notifications based on the unique S3 key prefix and suffix. @@ -165,11 +171,11 @@ def get_notification_configuration( for filter_rule in filter_rules: if filter_rule["Name"] == "Prefix" and bucket_key_prefix is not None: common_prefix_path = bool( - os.path.commonpath([filter_rule["Value"], bucket_key_prefix]) + os.path.commonpath([filter_rule["Value"], bucket_key_prefix]) ) elif filter_rule["Name"] == "Suffix" and bucket_key_suffix is not None: common_suffix_path = bool( - os.path.commonpath([filter_rule["Value"], bucket_key_suffix]) + os.path.commonpath([filter_rule["Value"], bucket_key_suffix]) ) if common_prefix_path and common_suffix_path: return notification_configuration @@ -178,7 +184,7 @@ def get_notification_configuration( def create_formatted_message( bucket: str, destination_type: str, destination_arn: str - ) -> str: +) -> str: """Creates a formatted message for logging purposes. Arguments: @@ -191,6 +197,7 @@ def create_formatted_message( """ return f"Bucket: {bucket}, DestinationType: {destination_type}, DestinationArn: {destination_arn}" + def normalize_filter_rules(config: NotificationConfiguration): """ Modify the filter rules of a notification configuration so that it is get/put agnostic. @@ -216,10 +223,10 @@ def normalize_filter_rules(config: NotificationConfiguration): config.value["Filter"]["Key"]["FilterRules"] = new_filter_rules return config + def notification_configuration_matches( - config: NotificationConfiguration, - other_config: NotificationConfiguration - ) -> bool: + config: NotificationConfiguration, other_config: NotificationConfiguration +) -> bool: """Determines if two S3 event notification configurations are functionally equivalent. 
Two notification configurations are considered equivalent if: @@ -237,38 +244,31 @@ def notification_configuration_matches( config = normalize_filter_rules(config) other_config = normalize_filter_rules(other_config) arn_match = other_config.arn == config.arn - events_match = ( - set(other_config.value["Events"]) == - set(config.value["Events"]) - ) - filter_rule_names_match = ( - { - filter_rule["Name"] - for filter_rule - in other_config.value["Filter"]["Key"]["FilterRules"] - } == - { - filter_rule["Name"] - for filter_rule - in config.value["Filter"]["Key"]["FilterRules"] - } - ) + events_match = set(other_config.value["Events"]) == set(config.value["Events"]) + filter_rule_names_match = { + filter_rule["Name"] + for filter_rule in other_config.value["Filter"]["Key"]["FilterRules"] + } == { + filter_rule["Name"] + for filter_rule in config.value["Filter"]["Key"]["FilterRules"] + } filter_rule_values_match = all( [ any( [ filter_rule["Value"] == other_filter_rule["Value"] - for filter_rule - in config.value["Filter"]["Key"]["FilterRules"] + for filter_rule in config.value["Filter"]["Key"]["FilterRules"] if filter_rule["Name"] == other_filter_rule["Name"] ] ) - for other_filter_rule - in other_config.value["Filter"]["Key"]["FilterRules"] + for other_filter_rule in other_config.value["Filter"]["Key"]["FilterRules"] ] ) configurations_match = ( - arn_match and events_match and filter_rule_names_match and filter_rule_values_match + arn_match + and events_match + and filter_rule_names_match + and filter_rule_values_match ) return configurations_match @@ -279,7 +279,7 @@ def add_notification( destination_arn: str, bucket: str, bucket_key_prefix: str, - ) -> None: +) -> None: """Adds the S3 notification configuration to an existing bucket. Notification configurations are identified by their unique prefix/suffix filter rules. @@ -303,18 +303,19 @@ def add_notification( "Events": ["s3:ObjectCreated:*"], "Filter": { "Key": { - "FilterRules": [{"Name": "prefix", "Value": os.path.join(bucket_key_prefix, "")}] + "FilterRules": [ + {"Name": "prefix", "Value": os.path.join(bucket_key_prefix, "")} + ] } }, } new_notification_configuration = NotificationConfiguration( notification_type=NotificationConfigurationType(destination_type), - value=new_notification_configuration_value + value=new_notification_configuration_value, ) ### Get any matching notification configuration bucket_notification_configurations = get_bucket_notification_configurations( - s3_client=s3_client, - bucket=bucket + s3_client=s3_client, bucket=bucket ) matching_notification_configuration = get_notification_configuration( bucket_notification_configurations=bucket_notification_configurations, @@ -322,8 +323,8 @@ def add_notification( ) if matching_notification_configuration is not None: is_the_same_configuration = notification_configuration_matches( - config=matching_notification_configuration, - other_config=new_notification_configuration + config=matching_notification_configuration, + other_config=new_notification_configuration, ) if is_the_same_configuration: logger.info( @@ -358,8 +359,8 @@ def add_notification( def delete_notification( - s3_client: boto3.client, bucket: str, bucket_key_prefix: str - ) -> None: + s3_client: boto3.client, bucket: str, bucket_key_prefix: str +) -> None: """ Deletes the S3 notification configuration from an existing bucket based on its unique S3 key prefix/suffix filter rules. 
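Stripped of the wrapper classes, the equivalence check above compares the destination ARN, the event set, and the prefix/suffix filter rules, with rule names compared case-insensitively so that get- and put-style payloads match. A simplified sketch over raw `QueueConfiguration` dictionaries with placeholder ARNs:

def filter_rules_as_dict(config: dict) -> dict:
    return {
        rule["Name"].lower(): rule["Value"]
        for rule in config["Filter"]["Key"]["FilterRules"]
    }

def configs_match(config: dict, other: dict) -> bool:
    return (
        config.get("QueueArn") == other.get("QueueArn")
        and set(config["Events"]) == set(other["Events"])
        and filter_rules_as_dict(config) == filter_rules_as_dict(other)
    )

existing = {
    "QueueArn": "arn:aws:sqs:us-east-1:111122223333:example-queue",
    "Events": ["s3:ObjectCreated:*"],
    "Filter": {"Key": {"FilterRules": [{"Name": "Prefix", "Value": "main/"}]}},
}
proposed = {
    "QueueArn": "arn:aws:sqs:us-east-1:111122223333:example-queue",
    "Events": ["s3:ObjectCreated:*"],
    "Filter": {"Key": {"FilterRules": [{"Name": "prefix", "Value": "main/"}]}},
}
print(configs_match(existing, proposed))  # True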
@@ -387,9 +388,9 @@ def delete_notification( ) return bucket_notification_configurations.configs = [ - config for config - in bucket_notification_configurations.configs - if config.arn != matching_notification_configuration.arn + config + for config in bucket_notification_configurations.configs + if config.arn != matching_notification_configuration.arn ] ### Delete matching notification configuration logger.info( @@ -397,7 +398,7 @@ def delete_notification( + create_formatted_message( bucket=bucket, destination_type=matching_notification_configuration.type, - destination_arn=matching_notification_configuration.arn + destination_arn=matching_notification_configuration.arn, ) ) s3_client.put_bucket_notification_configuration( @@ -410,5 +411,6 @@ def delete_notification( + create_formatted_message( bucket=bucket, destination_type=matching_notification_configuration.type, - destination_arn=matching_notification_configuration.arn) + destination_arn=matching_notification_configuration.arn, + ) ) diff --git a/src/lambda_function/s3_to_glue/app.py b/src/lambda_function/s3_to_glue/app.py index 51e0cfab..b1f3d209 100644 --- a/src/lambda_function/s3_to_glue/app.py +++ b/src/lambda_function/s3_to_glue/app.py @@ -4,12 +4,13 @@ Subsequently, the S3 objects which were contained in the SQS event are written as a JSON string to the `messages` workflow run property. """ -import os import json import logging -import boto3 +import os from urllib import parse +import boto3 + logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -78,7 +79,7 @@ def submit_s3_to_json_workflow(objects_info: list[dict[str, str]], workflow_name ) -def is_s3_test_event(record : dict) -> bool: +def is_s3_test_event(record: dict) -> bool: """ AWS always sends a s3 test event to the SQS queue whenever a new file is uploaded. We want to skip those @@ -92,6 +93,7 @@ def is_s3_test_event(record : dict) -> bool: else: return False + def get_object_info(s3_event) -> dict: """ Derive object info formatted for submission to Glue from an S3 event. @@ -110,6 +112,7 @@ def get_object_info(s3_event) -> dict: } return object_info + def lambda_handler(event, context) -> None: """ This main lambda function will be triggered by a SQS event and will @@ -145,8 +148,8 @@ def lambda_handler(event, context) -> None: f"{os.environ['S3_TO_JSON_WORKFLOW_NAME']}: {json.dumps(s3_objects_info)}" ) submit_s3_to_json_workflow( - objects_info=s3_objects_info, - workflow_name=os.environ["S3_TO_JSON_WORKFLOW_NAME"] + objects_info=s3_objects_info, + workflow_name=os.environ["S3_TO_JSON_WORKFLOW_NAME"], ) else: logger.info( diff --git a/src/lambda_function/s3_to_glue/events/generate_test_event.py b/src/lambda_function/s3_to_glue/events/generate_test_event.py index 4fa9914d..366c6ea7 100755 --- a/src/lambda_function/s3_to_glue/events/generate_test_event.py +++ b/src/lambda_function/s3_to_glue/events/generate_test_event.py @@ -12,51 +12,59 @@ for detailed information how to use this event with the Lambda. 
""" -import os import argparse import json +import os + import boto3 -SINGLE_RECORD_OUTFILE = 'single-record.json' -MULTI_RECORD_OUTFILE = 'records.json' +SINGLE_RECORD_OUTFILE = "single-record.json" +MULTI_RECORD_OUTFILE = "records.json" def read_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Generate a JSON file of a mocked S3 event for testing.", - formatter_class = argparse.ArgumentDefaultsHelpFormatter + description="Generate a JSON file of a mocked S3 event for testing.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "--input-bucket", - default="recover-dev-input-data", - help="S3 bucket name containing input data" + "--input-bucket", + default="recover-dev-input-data", + help="S3 bucket name containing input data", ) parser.add_argument( - "--input-key", - default="main/2023-01-12T22--02--17Z_77fefff8-b0e2-4c1b-b0c5-405554c92368", - help="A specific S3 key to generate an event for." + "--input-key", + default="main/2023-01-12T22--02--17Z_77fefff8-b0e2-4c1b-b0c5-405554c92368", + help="A specific S3 key to generate an event for.", ) parser.add_argument( - "--input-key-prefix", - help=("Takes precedence over `--input-key`. If you want " - "to generate a single event containing all data under a specific " - "S3 key prefix, specify that here.") + "--input-key-prefix", + help=( + "Takes precedence over `--input-key`. If you want " + "to generate a single event containing all data under a specific " + "S3 key prefix, specify that here." + ), ) parser.add_argument( - "--input-key-file", - help=("Takes precedence over `--input-key` and `--input-key-prefix`. " - "If you want to generate a single event containing all keys within" - "a newline delimited file, specify the path to that file here.") + "--input-key-file", + help=( + "Takes precedence over `--input-key` and `--input-key-prefix`. " + "If you want to generate a single event containing all keys within" + "a newline delimited file, specify the path to that file here." + ), ) parser.add_argument( - "--output-directory", - default = "./", - help=("Specifies the directory that the S3 notification json gets saved to. " - "Defaults to current directory that the script is running from. ") + "--output-directory", + default="./", + help=( + "Specifies the directory that the S3 notification json gets saved to. " + "Defaults to current directory that the script is running from. 
" + ), ) args = parser.parse_args() return args + def create_event(bucket: str, key: str, key_prefix: str, key_file: str) -> dict: """ Create an SQS event wrapping a SNS notification of an S3 event notification(s) @@ -88,28 +96,23 @@ def create_event(bucket: str, key: str, key_prefix: str, key_file: str) -> dict: test_data = key_file_contents.split("\n") elif key_prefix is not None: s3_client = boto3.client("s3") - test_objects = s3_client.list_objects_v2( - Bucket=bucket, - Prefix=key_prefix - ) + test_objects = s3_client.list_objects_v2(Bucket=bucket, Prefix=key_prefix) test_data = [ - obj["Key"] for obj in test_objects["Contents"] - if not obj["Key"].endswith("/") + obj["Key"] + for obj in test_objects["Contents"] + if not obj["Key"].endswith("/") ] else: test_data = [key] - s3_events = [ - create_s3_event_record(bucket=bucket, key=k) for k in test_data - ] - sns_notifications = [ - create_sns_notification(s3_event) for s3_event in s3_events - ] + s3_events = [create_s3_event_record(bucket=bucket, key=k) for k in test_data] + sns_notifications = [create_sns_notification(s3_event) for s3_event in s3_events] sqs_messages = [ - create_sqs_message(sns_notification) for sns_notification in sns_notifications + create_sqs_message(sns_notification) for sns_notification in sns_notifications ] sqs_event = {"Records": sqs_messages} return sqs_event + def create_s3_event_record(bucket: str, key: str) -> dict: """ Create an S3 event notification "Record" for an individual S3 object. @@ -123,43 +126,38 @@ def create_s3_event_record(bucket: str, key: str) -> dict: in an S3 event notification """ s3_event_record = { - "eventVersion": "2.0", - "eventSource": "aws:s3", - "awsRegion": "us-east-1", - "eventTime": "1970-01-01T00:00:00.000Z", - "eventName": "ObjectCreated:Put", - "userIdentity": { - "principalId": "EXAMPLE" - }, - "requestParameters": { - "sourceIPAddress": "127.0.0.1" - }, - "responseElements": { - "x-amz-request-id": "EXAMPLE123456789", - "x-amz-id-2": "EXAMPLE123/5678abcdefghijklambdaisawesome/mnopqrstuvwxyzABCDEFGH" - }, - "s3": { - "s3SchemaVersion": "1.0", - "configurationId": "testConfigRule", - "bucket": { - "name": "{bucket}", - "ownerIdentity": { - "principalId": "EXAMPLE" - }, - "arn": "arn:aws:s3:::bucket_arn" + "eventVersion": "2.0", + "eventSource": "aws:s3", + "awsRegion": "us-east-1", + "eventTime": "1970-01-01T00:00:00.000Z", + "eventName": "ObjectCreated:Put", + "userIdentity": {"principalId": "EXAMPLE"}, + "requestParameters": {"sourceIPAddress": "127.0.0.1"}, + "responseElements": { + "x-amz-request-id": "EXAMPLE123456789", + "x-amz-id-2": "EXAMPLE123/5678abcdefghijklambdaisawesome/mnopqrstuvwxyzABCDEFGH", + }, + "s3": { + "s3SchemaVersion": "1.0", + "configurationId": "testConfigRule", + "bucket": { + "name": "{bucket}", + "ownerIdentity": {"principalId": "EXAMPLE"}, + "arn": "arn:aws:s3:::bucket_arn", + }, + "object": { + "key": "{key}", + "size": 1024, + "eTag": "0123456789abcdef0123456789abcdef", + "sequencer": "0A1B2C3D4E5F678901", + }, }, - "object": { - "key": "{key}", - "size": 1024, - "eTag": "0123456789abcdef0123456789abcdef", - "sequencer": "0A1B2C3D4E5F678901" - } - } } s3_event_record["s3"]["bucket"]["name"] = bucket s3_event_record["s3"]["object"]["key"] = key return s3_event_record + def create_sqs_message(sns_notification: dict) -> dict: """ Create an SQS message wrapper around an individual SNS notification. 
@@ -174,25 +172,26 @@ def create_sqs_message(sns_notification: dict) -> dict: dict: A dictionary formatted as an SQS message """ sqs_event_record = { - "messageId": "bf7be842", - "receiptHandle": "AQEBLdQhbUa", - "body": None, - "attributes": { - "ApproximateReceiveCount": "1", - "SentTimestamp": "1694541052297", - "SenderId": "AIDAJHIPRHEMV73VRJEBU", - "ApproximateFirstReceiveTimestamp": "1694541052299" - }, - "messageAttributes": {}, - "md5OfMessageAttributes": None, - "md5OfBody": "abdc58591d121b6a0334fb44fd45aceb", - "eventSource": "aws:sqs", - "eventSourceARN": "arn:aws:sqs:us-east-1:914833433684:mynamespace-sqs-S3ToLambda-Queue", - "awsRegion": "us-east-1" + "messageId": "bf7be842", + "receiptHandle": "AQEBLdQhbUa", + "body": None, + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1694541052297", + "SenderId": "AIDAJHIPRHEMV73VRJEBU", + "ApproximateFirstReceiveTimestamp": "1694541052299", + }, + "messageAttributes": {}, + "md5OfMessageAttributes": None, + "md5OfBody": "abdc58591d121b6a0334fb44fd45aceb", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-1:914833433684:mynamespace-sqs-S3ToLambda-Queue", + "awsRegion": "us-east-1", } sqs_event_record["body"] = json.dumps(sns_notification) return sqs_event_record + def create_sns_notification(s3_event_record): """ Create an SNS message wrapper for an individual S3 event notification. @@ -207,32 +206,37 @@ def create_sns_notification(s3_event_record): dict: A dictionary formatted as an SQS message """ sns_notification = { - "Type": "string", - "MessageId": "string", - "TopicArn": "string", - "Subject": "string", - "Message": "string", - "Timestamp": "string" + "Type": "string", + "MessageId": "string", + "TopicArn": "string", + "Subject": "string", + "Message": "string", + "Timestamp": "string", } sns_notification["Message"] = json.dumps({"Records": [s3_event_record]}) return sns_notification + def main() -> None: args = read_args() print("Generating mock S3 event...") sqs_event = create_event( - bucket=args.input_bucket, - key=args.input_key, - key_prefix=args.input_key_prefix, - key_file=args.input_key_file + bucket=args.input_bucket, + key=args.input_key, + key_prefix=args.input_key_prefix, + key_file=args.input_key_file, ) if args.input_key_file is not None or args.input_key_prefix is not None: - with open(os.path.join(args.output_directory, MULTI_RECORD_OUTFILE), "w") as outfile: + with open( + os.path.join(args.output_directory, MULTI_RECORD_OUTFILE), "w" + ) as outfile: json.dump(sqs_event, outfile) print(f"Event with multiple records written to {outfile.name}.") else: - with open(os.path.join(args.output_directory, SINGLE_RECORD_OUTFILE), "w") as outfile: + with open( + os.path.join(args.output_directory, SINGLE_RECORD_OUTFILE), "w" + ) as outfile: json.dump(sqs_event, outfile) print(f"Event with single record written to {outfile.name}.") print("Done.") diff --git a/src/scripts/consume_logs/consume_logs.py b/src/scripts/consume_logs/consume_logs.py index 3a5e9af3..866107ac 100644 --- a/src/scripts/consume_logs/consume_logs.py +++ b/src/scripts/consume_logs/consume_logs.py @@ -22,53 +22,58 @@ is written which contains only the subset of files which have a line_count_difference != 0. The schema is the same as above. 
""" +import argparse import datetime import json import time from collections import defaultdict -from typing import List, Dict +from typing import Dict, List -import argparse import boto3 import pandas + def read_args(): parser = argparse.ArgumentParser( - description="Query S3 to JSON CloudWatch logs and compare line counts.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + description="Query S3 to JSON CloudWatch logs and compare line counts.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument( - "--log-group-name", - help="The name of the log group to query.", - default="/aws-glue/python-jobs/error" + "--log-group-name", + help="The name of the log group to query.", + default="/aws-glue/python-jobs/error", ) parser.add_argument( - "--query", - help="The query to run against the log group.", - default='fields @message | filter event.action = "list-file-properties"' + "--query", + help="The query to run against the log group.", + default='fields @message | filter event.action = "list-file-properties"', ) parser.add_argument( - "--start-datetime", - help=( - "Query start time (local time) expressed in a format parseable by " - "datetime.datetime.strptime. This argument should be " - "formatted as `--time-format`."), - required=True, + "--start-datetime", + help=( + "Query start time (local time) expressed in a format parseable by " + "datetime.datetime.strptime. This argument should be " + "formatted as `--time-format`." + ), + required=True, ) parser.add_argument( - "--end-datetime", - help=( - "Default is \"now\". Query end time (local time) expressed " - "in a format parseable by datetime.datetime.strptime. This " - "argument should be formatted as `--time-format`."), + "--end-datetime", + help=( + 'Default is "now". Query end time (local time) expressed ' + "in a format parseable by datetime.datetime.strptime. This " + "argument should be formatted as `--time-format`." + ), ) parser.add_argument( - "--datetime-format", - help="The time format to use with datetime.datetime.strptime", - default="%Y-%m-%d %H:%M:%S", + "--datetime-format", + help="The time format to use with datetime.datetime.strptime", + default="%Y-%m-%d %H:%M:%S", ) args = parser.parse_args() return args + def get_seconds_since_epoch(datetime_str: str, datetime_format: str) -> int: """ Returns the seconds since epoch for a specific datetime (local time) @@ -89,12 +94,13 @@ def get_seconds_since_epoch(datetime_str: str, datetime_format: str) -> int: seconds_since_epoch = int(parsed_datetime.timestamp()) return seconds_since_epoch + def query_logs( - log_group_name: str, - query_string: str, - start_unix_time: int, - end_unix_time: int, - **kwargs: dict, + log_group_name: str, + query_string: str, + start_unix_time: int, + end_unix_time: int, + **kwargs: dict, ) -> list: """ Query a CloudWatch log group. 
@@ -118,23 +124,24 @@ def query_logs(
     """
     logs_client = kwargs.get("logs_client", boto3.client("logs"))
     start_query_response = logs_client.start_query(
-        logGroupName=log_group_name,
-        startTime=start_unix_time,
-        endTime=end_unix_time,
-        queryString=query_string
+        logGroupName=log_group_name,
+        startTime=start_unix_time,
+        endTime=end_unix_time,
+        queryString=query_string,
     )
     query_response = logs_client.get_query_results(
-        queryId=start_query_response["queryId"]
+        queryId=start_query_response["queryId"]
     )
     _check_for_failed_query(query_response)
     while query_response["status"] != "Complete":
         time.sleep(1)
         query_response = logs_client.get_query_results(
-            queryId=start_query_response["queryId"]
+            queryId=start_query_response["queryId"]
         )
         _check_for_failed_query(query_response)
     return query_response["results"]
 
+
 def _check_for_failed_query(query_response) -> None:
     """
     Helper function for checking whether a query response has not succeeded.
@@ -142,7 +149,10 @@ def _check_for_failed_query(query_response) -> None:
     if query_response["status"] in ["Failed", "Cancelled", "Timeout", "Unknown"]:
         raise UserWarning(f"Query failed with status \"{query_response['status']}\"")
 
-def group_query_result_by_workflow_run(query_results: List[List[dict]]) -> Dict[str, List[dict]]:
+
+def group_query_result_by_workflow_run(
+    query_results: List[List[dict]],
+) -> Dict[str, List[dict]]:
     """
     Associates log records with their workflow run ID.
 
@@ -158,8 +168,8 @@ def group_query_result_by_workflow_run(query_results: List[List[dict]]) -> Dict[
     """
     # e.g., [{'@message': '{...}', '@ptr': '...'}, ...]
     log_records = [
-        {field['field']: field['value'] for field in log_message}
-        for log_message in query_results
+        {field["field"]: field["value"] for field in log_message}
+        for log_message in query_results
     ]
     workflow_run_logs = defaultdict(list)
     for log_record in log_records:
@@ -167,6 +177,7 @@ def group_query_result_by_workflow_run(query_results: List[List[dict]]) -> Dict[
         workflow_run_logs[log_message["process"]["parent"]["pid"]].append(log_message)
     return workflow_run_logs
 
+
 def transform_logs_to_dataframe(log_messages: List[dict]) -> pandas.DataFrame:
     """
     Construct a pandas DataFrame from log records
@@ -187,47 +198,40 @@ def transform_logs_to_dataframe(log_messages: List[dict]) -> pandas.DataFrame:
     dataframe_records = []
     for log_message in log_messages:
         if (
-            "event" not in log_message
-            or "type" not in log_message["event"]
-            or not any(
-                [
-                    k in log_message["event"]["type"]
-                    for k in ["access", "creation"]
-                ]
-            )
-            or all(
-                [
-                    k in log_message["event"]["type"]
-                    for k in ["access", "creation"]
-                ]
-            )
+            "event" not in log_message
+            or "type" not in log_message["event"]
+            or not any(
+                [k in log_message["event"]["type"] for k in ["access", "creation"]]
+            )
+            or all([k in log_message["event"]["type"] for k in ["access", "creation"]])
         ):
             raise KeyError(
-                "Did not find event.type in log message or "
-                "event.type contained unexpected values "
-                "for workflow run ID {log_message['process']['parent']['pid']} and "
-                f"file {json.dumps(log_message['file']['labels'])}"
+                "Did not find event.type in log message or "
+                "event.type contained unexpected values "
+                f"for workflow run ID {log_message['process']['parent']['pid']} and "
+                f"file {json.dumps(log_message['file']['labels'])}"
             )
         if "access" in log_message["event"]["type"]:
             event_type = "access"
        else:
            event_type = "creation"
        dataframe_record = {
-            "cohort": log_message["labels"]["cohort"],
-            "file_name": log_message["file"]["name"],
-            "event_type": event_type,
- "line_count": log_message["file"]["LineCount"] + "cohort": log_message["labels"]["cohort"], + "file_name": log_message["file"]["name"], + "event_type": event_type, + "line_count": log_message["file"]["LineCount"], } dataframe_records.append(dataframe_record) log_dataframe = pandas.DataFrame.from_records(dataframe_records) return log_dataframe + def report_results( - workflow_run_event_comparison: dict, - comparison_report_path: str="consume_logs_comparison_report.csv", - missing_data_report_path: str="consume_logs_missing_data_report.csv", - testing:bool=False, - ) -> pandas.DataFrame: + workflow_run_event_comparison: dict, + comparison_report_path: str = "consume_logs_comparison_report.csv", + missing_data_report_path: str = "consume_logs_missing_data_report.csv", + testing: bool = False, +) -> pandas.DataFrame: """ Report any missing data and save the results. @@ -250,8 +254,7 @@ def report_results( pandas.DataFrame """ all_comparisons = pandas.concat( - workflow_run_event_comparison, - names=["workflow_run_id", "index"] + workflow_run_event_comparison, names=["workflow_run_id", "index"] ) all_missing_data = pandas.DataFrame() for workflow_run in workflow_run_event_comparison: @@ -260,10 +263,10 @@ def report_results( ) if len(missing_data) != 0: print( - "Discovered differences between records read/write " - f"in workflow run {workflow_run}" + "Discovered differences between records read/write " + f"in workflow run {workflow_run}" ) - missing_data = missing_data.assign(workflow_run_id = workflow_run) + missing_data = missing_data.assign(workflow_run_id=workflow_run) all_missing_data = pandas.concat([all_missing_data, missing_data]) if len(all_missing_data) > 0: print(f"Writing missing data information to {missing_data_report_path}") @@ -278,56 +281,56 @@ def report_results( all_comparisons.to_csv(comparison_report_path) return all_missing_data + def main() -> None: args = read_args() start_unix_time = get_seconds_since_epoch( - datetime_str=args.start_datetime, - datetime_format=args.datetime_format, + datetime_str=args.start_datetime, + datetime_format=args.datetime_format, ) end_unix_time = get_seconds_since_epoch( - datetime_str=args.end_datetime, - datetime_format=args.datetime_format, + datetime_str=args.end_datetime, + datetime_format=args.datetime_format, ) file_property_logs = query_logs( - log_group_name=args.log_group_name, - query_string=args.query, - start_unix_time=start_unix_time, - end_unix_time=end_unix_time, + log_group_name=args.log_group_name, + query_string=args.query, + start_unix_time=start_unix_time, + end_unix_time=end_unix_time, ) if len(file_property_logs) == 0: print( - f"The query '{args.query}' did not return any results " - f"in the time range {start_unix_time}-{end_unix_time}" + f"The query '{args.query}' did not return any results " + f"in the time range {start_unix_time}-{end_unix_time}" ) return - workflow_run_logs = group_query_result_by_workflow_run(query_results=file_property_logs) + workflow_run_logs = group_query_result_by_workflow_run( + query_results=file_property_logs + ) workflow_run_event_comparison = {} for workflow_run in workflow_run_logs: workflow_run_dataframe = transform_logs_to_dataframe( - log_messages=workflow_run_logs[workflow_run] + log_messages=workflow_run_logs[workflow_run] ) - access_events = ( - workflow_run_dataframe - .query("event_type == 'access'") - .drop("event_type", axis=1) + access_events = workflow_run_dataframe.query("event_type == 'access'").drop( + "event_type", axis=1 ) - creation_events = ( - 
workflow_run_dataframe - .query("event_type == 'creation'") - .drop("event_type", axis=1) + creation_events = workflow_run_dataframe.query("event_type == 'creation'").drop( + "event_type", axis=1 ) event_comparison = access_events.merge( - creation_events, - how="left", - on=["cohort", "file_name"], - suffixes=("_access", "_creation") + creation_events, + how="left", + on=["cohort", "file_name"], + suffixes=("_access", "_creation"), + ) + event_comparison["line_count_difference"] = ( + event_comparison["line_count_access"] + - event_comparison["line_count_creation"] ) - event_comparison["line_count_difference"] = \ - event_comparison["line_count_access"] - event_comparison["line_count_creation"] workflow_run_event_comparison[workflow_run] = event_comparison - report_results( - workflow_run_event_comparison=workflow_run_event_comparison - ) + report_results(workflow_run_event_comparison=workflow_run_event_comparison) + if __name__ == "__main__": main() diff --git a/src/scripts/manage_artifacts/artifacts.py b/src/scripts/manage_artifacts/artifacts.py index d496c33c..cef93993 100755 --- a/src/scripts/manage_artifacts/artifacts.py +++ b/src/scripts/manage_artifacts/artifacts.py @@ -1,5 +1,5 @@ -import os import argparse +import os import subprocess @@ -9,7 +9,7 @@ def read_args(): """ parser = argparse.ArgumentParser(description="") parser.add_argument("--namespace") - parser.add_argument("--cfn_bucket", required = True) + parser.add_argument("--cfn_bucket", required=True) group = parser.add_mutually_exclusive_group(required=True) group.add_argument("--upload", action="store_true") group.add_argument("--remove", action="store_true") @@ -18,12 +18,12 @@ def read_args(): return args -def execute_command(cmd : str): +def execute_command(cmd: str): print(f'Invoking command: {" ".join(cmd)}') subprocess.run(cmd) -def upload(namespace : str, cfn_bucket : str): +def upload(namespace: str, cfn_bucket: str): """Copy Glue scripts to the artifacts bucket""" scripts_local_path = "src/glue/" @@ -33,7 +33,9 @@ def upload(namespace : str, cfn_bucket : str): """Copies Lambda code and template to the artifacts bucket""" lambda_local_path = "src/lambda_function/" - lambda_s3_path = os.path.join("s3://", cfn_bucket, namespace, "src/lambda_function/") + lambda_s3_path = os.path.join( + "s3://", cfn_bucket, namespace, "src/lambda_function/" + ) cmd = ["aws", "s3", "sync", lambda_local_path, lambda_s3_path] execute_command(cmd) @@ -44,14 +46,14 @@ def upload(namespace : str, cfn_bucket : str): execute_command(cmd) -def delete(namespace : str, cfn_bucket : str): +def delete(namespace: str, cfn_bucket: str): """Removes all files recursively for namespace""" s3_path = os.path.join("s3://", cfn_bucket, namespace) cmd = ["aws", "s3", "rm", "--recursive", s3_path] execute_command(cmd) -def list_namespaces(cfn_bucket : str): +def list_namespaces(cfn_bucket: str): """List all namespaces""" s3_path = os.path.join("s3://", cfn_bucket) cmd = ["aws", "s3", "ls", s3_path] diff --git a/src/scripts/manage_artifacts/clean_for_integration_test.py b/src/scripts/manage_artifacts/clean_for_integration_test.py index 45fbae87..cb9668a4 100755 --- a/src/scripts/manage_artifacts/clean_for_integration_test.py +++ b/src/scripts/manage_artifacts/clean_for_integration_test.py @@ -8,6 +8,7 @@ """ import argparse + import boto3 @@ -30,16 +31,16 @@ def delete_objects(bucket_prefix: str, bucket: str) -> None: """ print(f"Cleaning bucket: {bucket} with prefix: {bucket_prefix}") - s3_client = boto3.client('s3') + s3_client = boto3.client("s3") 
response = s3_client.list_objects_v2(Bucket=bucket, Prefix=bucket_prefix) # Check if objects are found - if 'Contents' in response: - for obj in response['Contents']: - object_key = obj['Key'] + if "Contents" in response: + for obj in response["Contents"]: + object_key = obj["Key"] # Skip the owner.txt file so it does not need to be re-created - if object_key.endswith('owner.txt'): + if object_key.endswith("owner.txt"): continue elif "main" in object_key: raise ValueError("Cannot delete objects in the main directory") @@ -52,12 +53,11 @@ def main() -> None: allow the owner.txt file to be kept.""" args = read_args() - if not args.bucket_prefix or args.bucket_prefix[-1] != '/': + if not args.bucket_prefix or args.bucket_prefix[-1] != "/": raise ValueError("Bucket prefix must be provided and end with a '/'") if "main" in args.bucket_prefix: - raise ValueError( - "Cannot delete objects in the main directory") + raise ValueError("Cannot delete objects in the main directory") try: delete_objects(bucket_prefix=args.bucket_prefix, bucket=args.bucket) diff --git a/src/scripts/setup_external_storage/setup_external_storage.py b/src/scripts/setup_external_storage/setup_external_storage.py index 0f58e750..abdcbc8f 100644 --- a/src/scripts/setup_external_storage/setup_external_storage.py +++ b/src/scripts/setup_external_storage/setup_external_storage.py @@ -2,11 +2,11 @@ Create an STS-enabled folder on Synapse over an S3 location. """ -import os -import json import argparse -import boto3 +import json +import os +import boto3 import synapseclient @@ -36,16 +36,25 @@ def read_args(): action="store_true", help="Whether this storage location should be STS enabled", ) - parser.add_argument("--profile", - help=("Optional. The AWS profile to use. Uses the default " - "profile if not specified.")) - parser.add_argument("--ssm-parameter", - help=("Optional. The name of the SSM parameter containing " - "the Synapse personal access token. " - "If not provided, cached credentials are used")) + parser.add_argument( + "--profile", + help=( + "Optional. The AWS profile to use. Uses the default " + "profile if not specified." + ), + ) + parser.add_argument( + "--ssm-parameter", + help=( + "Optional. The name of the SSM parameter containing " + "the Synapse personal access token. " + "If not provided, cached credentials are used" + ), + ) args = parser.parse_args() return args + def get_synapse_client(ssm_parameter=None, aws_session=None): """ Return an authenticated Synapse client. 
@@ -60,23 +69,20 @@ def get_synapse_client(ssm_parameter=None, aws_session=None): """ if ssm_parameter is not None: ssm_client = aws_session.client("ssm") - token = ssm_client.get_parameter( - Name=ssm_parameter, - WithDecryption=True) + token = ssm_client.get_parameter(Name=ssm_parameter, WithDecryption=True) syn = synapseclient.Synapse() syn.login(authToken=token["Parameter"]["Value"]) - else: # try cached credentials + else: # try cached credentials syn = synapseclient.login() return syn + def main(): args = read_args() aws_session = boto3.session.Session( - profile_name=args.profile, - region_name="us-east-1") - syn = get_synapse_client( - ssm_parameter=args.ssm_parameter, - aws_session=aws_session) + profile_name=args.profile, region_name="us-east-1" + ) + syn = get_synapse_client(ssm_parameter=args.ssm_parameter, aws_session=aws_session) synapse_folder, storage_location, synapse_project = syn.create_s3_storage_location( parent=args.synapse_parent, folder_name=args.synapse_folder_name, diff --git a/tests/conftest.py b/tests/conftest.py index e3a16a70..27b42d35 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -51,17 +51,16 @@ def dataset_fixture(request): @pytest.fixture def mock_s3_environment(mock_s3_bucket): - """This allows us to persist the bucket and s3 client - """ + """This allows us to persist the bucket and s3 client""" with mock_s3(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket=mock_s3_bucket) yield s3 @pytest.fixture def mock_s3_bucket(): - bucket_name = 'test-bucket' + bucket_name = "test-bucket" yield bucket_name diff --git a/tests/test_compare_parquet_datasets.py b/tests/test_compare_parquet_datasets.py index 2e4752a6..576d967c 100644 --- a/tests/test_compare_parquet_datasets.py +++ b/tests/test_compare_parquet_datasets.py @@ -1,7 +1,7 @@ -from collections import namedtuple import json import re import zipfile +from collections import namedtuple from io import BytesIO from unittest import mock @@ -388,7 +388,6 @@ def test_that_get_integration_test_exports_json_throws_json_decode_error( def test_that_get_exports_filter_values_returns_expected_results( s3, data_type, filelist, expected_filter ): - with mock.patch.object( compare_parquet, "get_integration_test_exports_json" ) as patch_test_exports, mock.patch.object( diff --git a/tests/test_consume_logs.py b/tests/test_consume_logs.py index c8853feb..ed0db332 100644 --- a/tests/test_consume_logs.py +++ b/tests/test_consume_logs.py @@ -1,88 +1,99 @@ import datetime + import boto3 import pytest from botocore.stub import Stubber from pandas import DataFrame, Index + from src.scripts.consume_logs import consume_logs + @pytest.fixture def log_group_name(): return "test-log-group" + @pytest.fixture def start_unix_time(): return 1635619200 + @pytest.fixture def end_unix_time(start_unix_time): return start_unix_time + 10 + @pytest.fixture def query_string(): return "display message" + def test_get_seconds_since_epoch(): now = datetime.datetime.now() actual_seconds_since_epoch = int(now.timestamp()) datetime_format = "%Y-%m-%d %H:%M:%S" datetime_str = now.strftime(datetime_format) calculated_seconds_since_epoch = consume_logs.get_seconds_since_epoch( - datetime_str=datetime_str, - datetime_format=datetime_format, + datetime_str=datetime_str, + datetime_format=datetime_format, ) assert calculated_seconds_since_epoch == actual_seconds_since_epoch + def test_get_seconds_since_epoch_datetime_str_is_none(): now = datetime.datetime.now() 
actual_seconds_since_epoch = int(now.timestamp()) calculated_seconds_since_epoch = consume_logs.get_seconds_since_epoch( - datetime_str=None, - datetime_format=None, + datetime_str=None, + datetime_format=None, ) # We don't know what should be the exact time computed within # `get_seconds_since_epoch`, but it ought to be within a second of # `actual_seconds_since_epoch` assert abs(calculated_seconds_since_epoch - actual_seconds_since_epoch) <= 1 + def test_query_logs_status_complete( - log_group_name, start_unix_time, end_unix_time, query_string - ): + log_group_name, start_unix_time, end_unix_time, query_string +): logs_client = _stub_logs_client( - status="Complete", - log_group_name=log_group_name, - start_unix_time=start_unix_time, - end_unix_time=end_unix_time, - query_string=query_string, + status="Complete", + log_group_name=log_group_name, + start_unix_time=start_unix_time, + end_unix_time=end_unix_time, + query_string=query_string, ) query_result = consume_logs.query_logs( - log_group_name=log_group_name, - query_string=query_string, - start_unix_time=start_unix_time, - end_unix_time=end_unix_time, - logs_client=logs_client, + log_group_name=log_group_name, + query_string=query_string, + start_unix_time=start_unix_time, + end_unix_time=end_unix_time, + logs_client=logs_client, ) assert len(query_result) == 1 assert query_result[0][0]["field"] == "message" + def test_query_logs_status_failed( - log_group_name, start_unix_time, end_unix_time, query_string - ): + log_group_name, start_unix_time, end_unix_time, query_string +): logs_client = _stub_logs_client( - status="Failed", - log_group_name=log_group_name, - start_unix_time=start_unix_time, - end_unix_time=end_unix_time, - query_string=query_string, + status="Failed", + log_group_name=log_group_name, + start_unix_time=start_unix_time, + end_unix_time=end_unix_time, + query_string=query_string, ) with pytest.raises(UserWarning): query_result = consume_logs.query_logs( - log_group_name=log_group_name, - query_string=query_string, - start_unix_time=start_unix_time, - end_unix_time=end_unix_time, - logs_client=logs_client, + log_group_name=log_group_name, + query_string=query_string, + start_unix_time=start_unix_time, + end_unix_time=end_unix_time, + logs_client=logs_client, ) + def test_check_for_failed_query(): with pytest.raises(UserWarning): consume_logs._check_for_failed_query({"status": "Failed"}) @@ -93,107 +104,58 @@ def test_check_for_failed_query(): with pytest.raises(UserWarning): consume_logs._check_for_failed_query({"status": "Unknown"}) + def test_group_query_results_by_workflow_run(): query_results = [ - [ - { - 'field': '@message', - 'value': '{"process": {"parent": {"pid": "one"}}}' - } - ], - [ - { - 'field': '@message', - 'value': '{"process": {"parent": {"pid": "one"}}}' - } - ], - [ - { - 'field': '@message', - 'value': '{"process": {"parent": {"pid": "two"}}}' - } - ] + [{"field": "@message", "value": '{"process": {"parent": {"pid": "one"}}}'}], + [{"field": "@message", "value": '{"process": {"parent": {"pid": "one"}}}'}], + [{"field": "@message", "value": '{"process": {"parent": {"pid": "two"}}}'}], ] workflow_run_groups = consume_logs.group_query_result_by_workflow_run( - query_results=query_results + query_results=query_results ) assert len(workflow_run_groups["one"]) == 2 assert len(workflow_run_groups["two"]) == 1 + def test_transform_logs_to_dataframe(): log_messages = [ { - "labels": { - "cohort": "adults_v1" - }, - "file": { - "name": "file_one", - "LineCount": 1 - }, - "event": { - "type": ["access"] - 
} + "labels": {"cohort": "adults_v1"}, + "file": {"name": "file_one", "LineCount": 1}, + "event": {"type": ["access"]}, }, { - "labels": { - "cohort": "adults_v1" - }, - "file": { - "name": "file_one", - "LineCount": 1 - }, - "event": { - "type": ["creation"] - } + "labels": {"cohort": "adults_v1"}, + "file": {"name": "file_one", "LineCount": 1}, + "event": {"type": ["creation"]}, }, ] resulting_dataframe = consume_logs.transform_logs_to_dataframe( - log_messages=log_messages + log_messages=log_messages ) expected_columns = ["cohort", "file_name", "event_type", "line_count"] for col in expected_columns: assert col in resulting_dataframe.columns + @pytest.mark.parametrize( "log_messages", [ - [ - { - "notevent": None - } - ], - [ - { - "event": {"nottype": None} - } - ], - [ - { - "event": { - "type": [ - "notaccess", "info" - ] - } - } - ], - [ - { - "event": { - "type": [ - "access", "creation" - ] - } - } - ], - ] + [{"notevent": None}], + [{"event": {"nottype": None}}], + [{"event": {"type": ["notaccess", "info"]}}], + [{"event": {"type": ["access", "creation"]}}], + ], ) def test_transform_logs_to_dataframe_malformatted_event_type(log_messages): with pytest.raises(KeyError): consume_logs.transform_logs_to_dataframe(log_messages=log_messages) + def _stub_logs_client( - status, log_group_name, start_unix_time, end_unix_time, query_string - ): + status, log_group_name, start_unix_time, end_unix_time, query_string +): """ A helper function for stubbing the boto3 logs client @@ -205,49 +167,44 @@ def _stub_logs_client( start_query_response = {"queryId": "a-query-id"} get_query_results_response = { "status": status, - "results": [[{"field": "message", "value": "Test log message 1"}]] + "results": [[{"field": "message", "value": "Test log message 1"}]], } # Set up the Stubber to mock the CloudWatch Logs API calls logs_client_stub.add_response( - "start_query", - start_query_response, - expected_params={ - "logGroupName": log_group_name, - "startTime": start_unix_time, - "endTime": end_unix_time, - "queryString": query_string - } + "start_query", + start_query_response, + expected_params={ + "logGroupName": log_group_name, + "startTime": start_unix_time, + "endTime": end_unix_time, + "queryString": query_string, + }, ) logs_client_stub.add_response( - "get_query_results", - get_query_results_response, - expected_params={ - "queryId": start_query_response["queryId"] - } + "get_query_results", + get_query_results_response, + expected_params={"queryId": start_query_response["queryId"]}, ) logs_client_stub.activate() return logs_client + @pytest.mark.parametrize( "workflow_run_event_comparison,missing_data", [ - ( - {"one": DataFrame({"line_count_difference": [0]})}, - DataFrame() - ), + ({"one": DataFrame({"line_count_difference": [0]})}, DataFrame()), ( {"one": DataFrame({"line_count_difference": [0, 1]})}, DataFrame( {"line_count_difference": [1]}, - index=Index(["one"], name="workflow_run_id") - ) - ) - ] + index=Index(["one"], name="workflow_run_id"), + ), + ), + ], ) def test_report_results(workflow_run_event_comparison, missing_data): this_missing_data = consume_logs.report_results( - workflow_run_event_comparison=workflow_run_event_comparison, - testing=True + workflow_run_event_comparison=workflow_run_event_comparison, testing=True ) assert this_missing_data.equals(missing_data) diff --git a/tests/test_lambda_dispatch.py b/tests/test_lambda_dispatch.py index 5de358a5..1acdd2b9 100644 --- a/tests/test_lambda_dispatch.py +++ b/tests/test_lambda_dispatch.py @@ -1,15 +1,17 @@ import json 
import os -from pathlib import Path import shutil import tempfile import zipfile +from pathlib import Path +from unittest import mock import boto3 import pytest -from moto import mock_sns, mock_s3 +from moto import mock_s3, mock_sns + from src.lambda_function.dispatch import app -from unittest import mock + @pytest.fixture def s3_event(): @@ -43,18 +45,20 @@ def s3_event(): } return s3_event + @pytest.fixture def sns_message_template(): sns_message_template = { - "Type": "string", - "MessageId": "string", - "TopicArn": "string", - "Subject": "string", - "Message": "string", - "Timestamp": "string" + "Type": "string", + "MessageId": "string", + "TopicArn": "string", + "Subject": "string", + "Message": "string", + "Timestamp": "string", } return sns_message_template + @pytest.fixture def sqs_message_template(): sqs_msg = { @@ -80,33 +84,34 @@ def sqs_message_template(): } yield sqs_msg + @pytest.fixture def event(s3_event, sns_message_template, sqs_message_template): sns_message_template["Message"] = json.dumps({"Records": [s3_event]}) sqs_message_template["Records"][0]["body"] = json.dumps(sns_message_template) return sqs_message_template + @pytest.fixture def archive_json_paths(): archive_json_paths = [ - "HealthKitV2Workouts_20240508-20240509.json", # normal file - "empty.json", # should have file size 0 and be ignored - "Manifest.json", # should be ignored - "dir/containing/stuff.json" # should be ignored + "HealthKitV2Workouts_20240508-20240509.json", # normal file + "empty.json", # should have file size 0 and be ignored + "Manifest.json", # should be ignored + "dir/containing/stuff.json", # should be ignored ] return archive_json_paths + @pytest.fixture def temp_zip_file(): - temp_zip_file = tempfile.NamedTemporaryFile( - delete=False, - suffix='.zip' - ) + temp_zip_file = tempfile.NamedTemporaryFile(delete=False, suffix=".zip") return temp_zip_file + @pytest.fixture def archive_path(archive_json_paths, temp_zip_file): - with zipfile.ZipFile(temp_zip_file.name, 'w', zipfile.ZIP_DEFLATED) as zip_file: + with zipfile.ZipFile(temp_zip_file.name, "w", zipfile.ZIP_DEFLATED) as zip_file: for file_path in archive_json_paths: if "/" in file_path: os.makedirs(os.path.dirname(file_path)) @@ -123,17 +128,23 @@ def archive_path(archive_json_paths, temp_zip_file): yield temp_zip_file.name os.remove(temp_zip_file.name) + def test_get_object_info(s3_event): object_info = app.get_object_info(s3_event=s3_event) assert object_info["Bucket"] == s3_event["s3"]["bucket"]["name"] assert object_info["Key"] == s3_event["s3"]["object"]["key"] + def test_get_object_info_unicode_characters_in_key(s3_event): - s3_event["s3"]["object"]["key"] = \ - "main/2023-09-26T00%3A06%3A39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" + s3_event["s3"]["object"][ + "key" + ] = "main/2023-09-26T00%3A06%3A39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" object_info = app.get_object_info(s3_event=s3_event) - assert object_info["Key"] == \ - "main/2023-09-26T00:06:39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" + assert ( + object_info["Key"] + == "main/2023-09-26T00:06:39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" + ) + @pytest.mark.parametrize( "object_info,expected", @@ -186,21 +197,22 @@ def test_get_object_info_unicode_characters_in_key(s3_event): ], ) def test_that_filter_object_info_returns_expected_result(object_info, expected): - assert app.filter_object_info(object_info) == expected + assert app.filter_object_info(object_info) == expected + def test_get_archive_contents(archive_path, archive_json_paths): dummy_bucket = "dummy_bucket" dummy_key 
= "dummy_key" archive_contents = app.get_archive_contents( - archive_path=archive_path, - bucket=dummy_bucket, - key=dummy_key + archive_path=archive_path, bucket=dummy_bucket, key=dummy_key ) assert all([content["Bucket"] == dummy_bucket for content in archive_contents]) assert all([content["Key"] == dummy_key for content in archive_contents]) assert all([content["FileSize"] > 0 for content in archive_contents]) - assert set([content["Path"] for content in archive_contents]) == \ - set(["HealthKitV2Workouts_20240508-20240509.json"]) + assert set([content["Path"] for content in archive_contents]) == set( + ["HealthKitV2Workouts_20240508-20240509.json"] + ) + @mock_sns @mock_s3 @@ -210,19 +222,17 @@ def test_main(event, temp_zip_file, s3_event, archive_path): bucket = s3_event["s3"]["bucket"]["name"] key = s3_event["s3"]["object"]["key"] s3_client.create_bucket(Bucket=bucket) - s3_client.upload_file( - Filename=archive_path, - Bucket=bucket, - Key=key - ) + s3_client.upload_file(Filename=archive_path, Bucket=bucket, Key=key) dispatch_sns = sns_client.create_topic(Name="test-sns-topic") - with mock.patch.object(sns_client, "publish", wraps=sns_client.publish) as mock_publish: + with mock.patch.object( + sns_client, "publish", wraps=sns_client.publish + ) as mock_publish: app.main( - event=event, - context=dict(), - sns_client=sns_client, - s3_client=s3_client, - dispatch_sns_arn=dispatch_sns["TopicArn"], - temp_zip_path=temp_zip_file.name + event=event, + context=dict(), + sns_client=sns_client, + s3_client=s3_client, + dispatch_sns_arn=dispatch_sns["TopicArn"], + temp_zip_path=temp_zip_file.name, ) mock_publish.assert_called() diff --git a/tests/test_s3_event_config_lambda.py b/tests/test_s3_event_config_lambda.py index 60d0e1d6..79eaa7b2 100644 --- a/tests/test_s3_event_config_lambda.py +++ b/tests/test_s3_event_config_lambda.py @@ -1,8 +1,9 @@ import copy from unittest import mock + import boto3 -from moto import mock_s3, mock_lambda, mock_iam, mock_sqs, mock_sns import pytest +from moto import mock_iam, mock_lambda, mock_s3, mock_sns, mock_sqs from src.lambda_function.s3_event_config import app @@ -30,6 +31,7 @@ def mock_lambda_function(mock_iam_role): ) yield client.get_function(FunctionName="some_function") + @pytest.fixture def mock_sns_topic_arn(): with mock_sns(): @@ -48,6 +50,7 @@ def mock_sqs_queue(mock_aws_credentials): QueueUrl=queue_url["QueueUrl"], AttributeNames=["QueueArn"] ) + @pytest.fixture def notification_configuration(): return app.NotificationConfiguration( @@ -56,15 +59,12 @@ def notification_configuration(): "Events": ["s3:ObjectCreated:*", "s3:ObjectRemoved:*"], "TopicArn": "arn:aws:sns:bla", "Filter": { - "Key": { - "FilterRules": [ - {"Name": "Prefix", "Value": "documents/"} - ] - } - } - } + "Key": {"FilterRules": [{"Name": "Prefix", "Value": "documents/"}]} + }, + }, ) + @pytest.fixture def bucket_notification_configurations(notification_configuration): ### Topic Configuration @@ -73,45 +73,55 @@ def bucket_notification_configurations(notification_configuration): queue_configuration_value = copy.deepcopy(notification_configuration.value) del queue_configuration_value["TopicArn"] queue_configuration_value["QueueArn"] = "arn:aws:sqs:bla" - queue_configuration_value["Filter"]["Key"]["FilterRules"][0] = \ - {"Name": "Suffix", "Value": "jpeg"} + queue_configuration_value["Filter"]["Key"]["FilterRules"][0] = { + "Name": "Suffix", + "Value": "jpeg", + } queue_configuration = app.NotificationConfiguration( - notification_type=app.NotificationConfigurationType("Queue"), 
- value=queue_configuration_value + notification_type=app.NotificationConfigurationType("Queue"), + value=queue_configuration_value, ) ### Lambda Configuration lambda_configuration_value = copy.deepcopy(notification_configuration.value) del lambda_configuration_value["TopicArn"] lambda_configuration_value["LambdaFunctionArn"] = "arn:aws:lambda:bla" - lambda_configuration_value["Filter"]["Key"]["FilterRules"][0] = \ - {"Name": "Suffix", "Value": "jpeg"} + lambda_configuration_value["Filter"]["Key"]["FilterRules"][0] = { + "Name": "Suffix", + "Value": "jpeg", + } lambda_configuration_value["Filter"]["Key"]["FilterRules"].append( - {"Name": "Prefix", "Value": "pictures/"} + {"Name": "Prefix", "Value": "pictures/"} ) lambda_configuration = app.NotificationConfiguration( - notification_type=app.NotificationConfigurationType("LambdaFunction"), - value=lambda_configuration_value + notification_type=app.NotificationConfigurationType("LambdaFunction"), + value=lambda_configuration_value, ) bucket_notification_configurations = app.BucketNotificationConfigurations( [topic_configuration, queue_configuration, lambda_configuration] ) return bucket_notification_configurations + class TestBucketNotificationConfigurations: def test_init(self, notification_configuration): configs = [notification_configuration, notification_configuration] - bucket_notification_configurations = app.BucketNotificationConfigurations(configs) + bucket_notification_configurations = app.BucketNotificationConfigurations( + configs + ) assert bucket_notification_configurations.configs == configs def test_to_dict(self, notification_configuration): other_notification_configuration = copy.deepcopy(notification_configuration) other_notification_configuration.type = "LambdaFunction" configs = [notification_configuration, other_notification_configuration] - bucket_notification_configurations = app.BucketNotificationConfigurations(configs) + bucket_notification_configurations = app.BucketNotificationConfigurations( + configs + ) bnc_as_dict = bucket_notification_configurations.to_dict() assert "TopicConfigurations" in bnc_as_dict assert "LambdaFunctionConfigurations" in bnc_as_dict + class TestGetBucketNotificationConfigurations: @mock_s3 def test_get_configurations(self, s3, notification_configuration): @@ -133,112 +143,121 @@ def test_get_configurations(self, s3, notification_configuration): "QueueConfigurations": [queue_configuration], "TopicConfigurations": [topic_configuration], "LambdaFunctionConfigurations": [lambda_configuration], - "EventBridgeConfiguration": event_bridge_configuration + "EventBridgeConfiguration": event_bridge_configuration, }, ): - bucket_notification_configurations = app.get_bucket_notification_configurations( - s3_client=s3, - bucket="some_bucket" + bucket_notification_configurations = ( + app.get_bucket_notification_configurations( + s3_client=s3, bucket="some_bucket" + ) ) # We should ignore 'EventBridgeConfiguration' assert len(bucket_notification_configurations.configs) == 3 + class TestGetNotificationConfiguration: def test_no_prefix_matching_suffix(self, bucket_notification_configurations): # No prefix provided, suffix provided matching_notification_configuration = app.get_notification_configuration( - bucket_notification_configurations, bucket_key_suffix="jpeg" + bucket_notification_configurations, bucket_key_suffix="jpeg" ) assert matching_notification_configuration is not None assert matching_notification_configuration.type == "Queue" def test_no_suffix_matching_prefix(self, 
bucket_notification_configurations): matching_notification_configuration = app.get_notification_configuration( - bucket_notification_configurations, bucket_key_prefix="documents" + bucket_notification_configurations, bucket_key_prefix="documents" ) assert matching_notification_configuration is not None assert matching_notification_configuration.type == "Topic" - def test_matching_prefix_not_matching_suffix(self, bucket_notification_configurations): + def test_matching_prefix_not_matching_suffix( + self, bucket_notification_configurations + ): matching_notification_configuration = app.get_notification_configuration( - bucket_notification_configurations=bucket_notification_configurations, - bucket_key_prefix="pictures", - bucket_key_suffix="png" + bucket_notification_configurations=bucket_notification_configurations, + bucket_key_prefix="pictures", + bucket_key_suffix="png", ) assert matching_notification_configuration is None - def test_matching_suffix_not_matching_prefix(self, bucket_notification_configurations): + def test_matching_suffix_not_matching_prefix( + self, bucket_notification_configurations + ): matching_notification_configuration = app.get_notification_configuration( - bucket_notification_configurations=bucket_notification_configurations, - bucket_key_prefix="documents", - bucket_key_suffix="jpeg" + bucket_notification_configurations=bucket_notification_configurations, + bucket_key_prefix="documents", + bucket_key_suffix="jpeg", ) assert matching_notification_configuration is None def test_no_match(self, bucket_notification_configurations): matching_notification_configuration = app.get_notification_configuration( - bucket_notification_configurations=bucket_notification_configurations, - bucket_key_prefix="downloads", + bucket_notification_configurations=bucket_notification_configurations, + bucket_key_prefix="downloads", ) assert matching_notification_configuration is None + class TestNormalizeFilterRules: def test_normalize_filter_rules(self, notification_configuration): normalized_notification_configuration = app.normalize_filter_rules( - config=notification_configuration + config=notification_configuration ) assert all( - [ - rule["Name"].lower() == rule["Name"] - for rule - in notification_configuration.value["Filter"]["Key"]["FilterRules"] + [ + rule["Name"].lower() == rule["Name"] + for rule in notification_configuration.value["Filter"]["Key"][ + "FilterRules" ] + ] ) + class TestNotificationConfigurationMatches: def test_all_true(self, notification_configuration): assert app.notification_configuration_matches( - config=notification_configuration, - other_config=notification_configuration + config=notification_configuration, other_config=notification_configuration ) def test_arn_false(self, notification_configuration): other_notification_configuration = copy.deepcopy(notification_configuration) other_notification_configuration.arn = "arn:aws:sns:hubba" assert not app.notification_configuration_matches( - config=notification_configuration, - other_config=other_notification_configuration + config=notification_configuration, + other_config=other_notification_configuration, ) def test_events_false(self, notification_configuration): other_notification_configuration = copy.deepcopy(notification_configuration) other_notification_configuration.value["Events"] = ["s3:ObjectCreated*"] assert not app.notification_configuration_matches( - config=notification_configuration, - other_config=other_notification_configuration + config=notification_configuration, + 
other_config=other_notification_configuration, ) def test_filter_rule_names_false(self, notification_configuration): other_notification_configuration = copy.deepcopy(notification_configuration) other_notification_configuration.value["Filter"]["Key"]["FilterRules"] = [ - {"Name": "Prefix", "Value": "documents/"}, - {"Name": "Suffix", "Value": "jpeg"}, + {"Name": "Prefix", "Value": "documents/"}, + {"Name": "Suffix", "Value": "jpeg"}, ] assert not app.notification_configuration_matches( - config=notification_configuration, - other_config=other_notification_configuration + config=notification_configuration, + other_config=other_notification_configuration, ) def test_filter_rule_values_false(self, notification_configuration): other_notification_configuration = copy.deepcopy(notification_configuration) other_notification_configuration.value["Filter"]["Key"]["FilterRules"] = [ - {"Name": "Prefix", "Value": "pictures/"} + {"Name": "Prefix", "Value": "pictures/"} ] assert not app.notification_configuration_matches( - config=notification_configuration, - other_config=other_notification_configuration + config=notification_configuration, + other_config=other_notification_configuration, ) + class TestAddNotification: @mock_s3 def test_adds_expected_settings_for_lambda(self, s3, mock_lambda_function): @@ -264,7 +283,14 @@ def test_adds_expected_settings_for_lambda(self, s3, mock_lambda_function): "s3:ObjectCreated:*" ] # moto prefix/Prefix discrepancy - assert len(get_config["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 + assert ( + len( + get_config["LambdaFunctionConfigurations"][0]["Filter"]["Key"][ + "FilterRules" + ] + ) + == 1 + ) @mock_s3 def test_adds_expected_settings_for_sns(self, s3, mock_sns_topic_arn): @@ -283,12 +309,12 @@ def test_adds_expected_settings_for_sns(self, s3, mock_sns_topic_arn): ) get_config = s3.get_bucket_notification_configuration(Bucket="some_bucket") assert get_config["TopicConfigurations"][0]["TopicArn"] == mock_sns_topic_arn - assert get_config["TopicConfigurations"][0]["Events"] == [ - "s3:ObjectCreated:*" - ] + assert get_config["TopicConfigurations"][0]["Events"] == ["s3:ObjectCreated:*"] # moto prefix/Prefix discrepancy - assert len(get_config["TopicConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 - + assert ( + len(get_config["TopicConfigurations"][0]["Filter"]["Key"]["FilterRules"]) + == 1 + ) @mock_s3 def test_adds_expected_settings_for_sqs(self, s3, mock_sqs_queue): @@ -312,8 +338,10 @@ def test_adds_expected_settings_for_sqs(self, s3, mock_sqs_queue): ) assert get_config["QueueConfigurations"][0]["Events"] == ["s3:ObjectCreated:*"] # moto prefix/Prefix discrepancy - assert len(get_config["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 - + assert ( + len(get_config["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"]) + == 1 + ) @mock_s3 def test_raise_exception_if_config_exists_for_prefix( @@ -326,11 +354,7 @@ def test_raise_exception_if_config_exists_for_prefix( with mock.patch.object( s3, "get_bucket_notification_configuration", - return_value={ - f"TopicConfigurations": [ - notification_configuration.value - ] - }, + return_value={f"TopicConfigurations": [notification_configuration.value]}, ): with pytest.raises(RuntimeError): app.add_notification( @@ -338,7 +362,9 @@ def test_raise_exception_if_config_exists_for_prefix( "Queue", "arn:aws:sqs:bla", "some_bucket", - notification_configuration.value["Filter"]["Key"]["FilterRules"][0]["Value"], + 
notification_configuration.value["Filter"]["Key"]["FilterRules"][0][ + "Value" + ], ) @mock_s3 @@ -351,17 +377,13 @@ def test_does_nothing_if_notification_already_exists(self, s3): "TopicArn": "arn:aws:sns:bla", "Events": ["s3:ObjectCreated:*"], "Filter": { - "Key": { - "FilterRules": [{"Name": "Prefix", "Value": f"documents/"}] - } + "Key": {"FilterRules": [{"Name": "Prefix", "Value": f"documents/"}]} }, } with mock.patch.object( s3, "get_bucket_notification_configuration", - return_value={ - f"TopicConfigurations": [notification_configuration] - }, + return_value={f"TopicConfigurations": [notification_configuration]}, ), mock.patch.object(s3, "put_bucket_notification_configuration") as put_config: # WHEN I add the existing matching `LambdaFunction` configuration app.add_notification( @@ -369,7 +391,9 @@ def test_does_nothing_if_notification_already_exists(self, s3): destination_type="Topic", destination_arn=notification_configuration["TopicArn"], bucket="some_bucket", - bucket_key_prefix=notification_configuration["Filter"]["Key"]["FilterRules"][0]["Value"], + bucket_key_prefix=notification_configuration["Filter"]["Key"][ + "FilterRules" + ][0]["Value"], ) # AND I get the notification configuration @@ -378,7 +402,6 @@ def test_does_nothing_if_notification_already_exists(self, s3): # THEN I expect nothing to have been saved in our mocked environment assert not put_config.called - @mock_s3 def test_does_nothing_if_notification_already_exists_even_in_different_dict_order( self, s3, mock_lambda_function @@ -424,7 +447,6 @@ def test_does_nothing_if_notification_already_exists_even_in_different_dict_orde # THEN I expect nothing to have been saved in our mocked environment assert not put_config.called - @mock_s3 def test_adds_config_if_requested_notification_does_not_exist( self, s3, mock_lambda_function, mock_sqs_queue @@ -474,13 +496,21 @@ def test_adds_config_if_requested_notification_does_not_exist( "s3:ObjectCreated:*" ] # moto prefix/Prefix discrepancy - assert len(get_config["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 + assert ( + len( + get_config["LambdaFunctionConfigurations"][0]["Filter"]["Key"][ + "FilterRules" + ] + ) + == 1 + ) + class TestDeleteNotification: @mock_s3 def test_is_successful_for_configuration_that_exists( self, s3, mock_lambda_function - ): + ): # GIVEN an S3 bucket s3.create_bucket(Bucket="some_bucket") @@ -508,15 +538,12 @@ def test_is_successful_for_configuration_that_exists( ): # WHEN I delete the notification app.delete_notification( - s3_client=s3, - bucket="some_bucket", - bucket_key_prefix="test_folder" + s3_client=s3, bucket="some_bucket", bucket_key_prefix="test_folder" ) # THEN the notification should be deleted get_config = s3.get_bucket_notification_configuration(Bucket="some_bucket") assert "LambdaFunctionConfigurations" not in get_config - @mock_s3 def test_does_nothing_when_deleting_configuration_that_does_not_exist( self, s3, mock_lambda_function @@ -551,7 +578,7 @@ def test_does_nothing_when_deleting_configuration_that_does_not_exist( app.delete_notification( s3_client=s3, bucket="some_bucket", - bucket_key_prefix="another_test_folder" + bucket_key_prefix="another_test_folder", ) # THEN nothing should have been called assert not put_config.called diff --git a/tests/test_s3_to_glue_lambda.py b/tests/test_s3_to_glue_lambda.py index d51580d2..83ef13cd 100644 --- a/tests/test_s3_to_glue_lambda.py +++ b/tests/test_s3_to_glue_lambda.py @@ -1,8 +1,8 @@ -import pytest +import json from unittest import mock -import json import 
boto3 +import pytest from moto import mock_sqs from src.lambda_function.s3_to_glue import app @@ -31,12 +31,12 @@ def sqs_queue(self, sqs_queue_name): @pytest.fixture def sns_message(self): sns_message_wrapper = { - "Type": "string", - "MessageId": "string", - "TopicArn": "string", - "Subject": "string", - "Message": "string", - "Timestamp": "string" + "Type": "string", + "MessageId": "string", + "TopicArn": "string", + "Subject": "string", + "Message": "string", + "Timestamp": "string", } return sns_message_wrapper @@ -165,11 +165,14 @@ def test_that_lambda_handler_calls_submit_s3_to_json_workflow_if_queue_has_messa ) def test_get_object_info_unicode_characters_in_key(self, s3_event): - s3_event["s3"]["object"]["key"] = \ - "main/2023-09-26T00%3A06%3A39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" + s3_event["s3"]["object"][ + "key" + ] = "main/2023-09-26T00%3A06%3A39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" object_info = app.get_object_info(s3_event=s3_event) - assert object_info["source_key"] == \ - "main/2023-09-26T00:06:39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" + assert ( + object_info["source_key"] + == "main/2023-09-26T00:06:39Z_d873eafb-554f-4f8a-9e61-cdbcb7de07eb" + ) @pytest.mark.parametrize( "object_info,expected", @@ -226,7 +229,6 @@ def test_that_filter_object_info_returns_expected_result( ): assert app.filter_object_info(object_info) == expected - def test_that_is_s3_test_event_returns_true_when_s3_test_event_is_present( self, s3_test_event ): diff --git a/tests/test_s3_to_json.py b/tests/test_s3_to_json.py index 3b5a2a63..5cede403 100644 --- a/tests/test_s3_to_json.py +++ b/tests/test_s3_to_json.py @@ -1,12 +1,12 @@ -import os +import datetime import io import json +import os import shutil import zipfile -import datetime -from dateutil.tz import tzutc import pytest +from dateutil.tz import tzutc from src.glue.jobs import s3_to_json @@ -52,8 +52,7 @@ def s3_obj(self, shared_datadir): } # sample test data with open( - shared_datadir - / "2023-01-13T21--08--51Z_TESTDATA", + shared_datadir / "2023-01-13T21--08--51Z_TESTDATA", "rb", ) as z: s3_obj["Body"] = z.read() @@ -62,11 +61,11 @@ def s3_obj(self, shared_datadir): @pytest.fixture def sample_metadata(self): sample_metadata = { - "type": "FitbitDevices", - "start_date": datetime.datetime(2022, 1, 12, 0, 0), - "end_date": datetime.datetime(2023, 1, 14, 0, 0), - "subtype": "FakeSubtype", - "cohort": "adults_v1" + "type": "FitbitDevices", + "start_date": datetime.datetime(2022, 1, 12, 0, 0), + "end_date": datetime.datetime(2023, 1, 14, 0, 0), + "subtype": "FakeSubtype", + "cohort": "adults_v1", } return sample_metadata @@ -99,41 +98,25 @@ def json_file_basenames_dict(self): return json_file_basenames def test_transform_object_to_array_of_objects(self): - json_obj_to_replace = { - "0": 60.0, - "1": 61.2, - "2": "99" - } + json_obj_to_replace = {"0": 60.0, "1": 61.2, "2": "99"} transformed_object = s3_to_json.transform_object_to_array_of_objects( - json_obj_to_replace=json_obj_to_replace, - key_name="key", - key_type=int, - value_name="value", - value_type=int + json_obj_to_replace=json_obj_to_replace, + key_name="key", + key_type=int, + value_name="value", + value_type=int, ) expected_object = [ - { - "key": 0, - "value": 60 - }, - { - "key": 1, - "value": 61 - }, - { - "key":2, - "value": 99 - } + {"key": 0, "value": 60}, + {"key": 1, "value": 61}, + {"key": 2, "value": 99}, ] assert all([obj in expected_object for obj in transformed_object]) def test_log_error_transform_object_to_array_of_objects(self, caplog): 
s3_to_json.logger.propagate = True s3_to_json._log_error_transform_object_to_array_of_objects( - value="a", - value_type=int, - error=ValueError, - logger_context={} + value="a", value_type=int, error=ValueError, logger_context={} ) s3_to_json.logger.propagate = False assert len(caplog.records) == 2 @@ -141,8 +124,7 @@ def test_log_error_transform_object_to_array_of_objects(self, caplog): def test_transform_json_with_subtype(self, sample_metadata): sample_metadata["type"] = "HealthKitV2Samples" transformed_json = s3_to_json.transform_json( - json_obj={}, - metadata=sample_metadata + json_obj={}, metadata=sample_metadata ) assert sample_metadata["subtype"] == transformed_json["Type"] @@ -159,8 +141,7 @@ def test_transform_json_symptom_log(self, sample_metadata): sample_metadata["type"] = "SymptomLog" transformed_value = {"a": 1, "b": 2} transformed_json = s3_to_json.transform_json( - json_obj={"Value": json.dumps(transformed_value)}, - metadata=sample_metadata + json_obj={"Value": json.dumps(transformed_value)}, metadata=sample_metadata ) assert ( @@ -175,23 +156,25 @@ def test_transform_json_symptom_log(self, sample_metadata): def test_add_universal_properties_start_date(self, sample_metadata): json_obj = s3_to_json._add_universal_properties( - json_obj={}, - metadata=sample_metadata, + json_obj={}, + metadata=sample_metadata, + ) + assert ( + json_obj["export_start_date"] == sample_metadata["start_date"].isoformat() ) - assert json_obj["export_start_date"] == sample_metadata["start_date"].isoformat() def test_add_universal_properties_no_start_date(self, sample_metadata): sample_metadata["start_date"] = None json_obj = s3_to_json._add_universal_properties( - json_obj={}, - metadata=sample_metadata, + json_obj={}, + metadata=sample_metadata, ) assert json_obj["export_start_date"] is None def test_add_universal_properties_generic(self, sample_metadata): json_obj = s3_to_json._add_universal_properties( - json_obj={}, - metadata=sample_metadata, + json_obj={}, + metadata=sample_metadata, ) assert json_obj["export_end_date"] == sample_metadata["end_date"].isoformat() assert json_obj["cohort"] == sample_metadata["cohort"] @@ -199,118 +182,87 @@ def test_add_universal_properties_generic(self, sample_metadata): def test_cast_custom_fields_to_array(self): sample_symptoms = {"id": "123", "symptom": "sick"} transformed_json = s3_to_json._cast_custom_fields_to_array( - json_obj={ - "CustomFields": { - "Symptoms": json.dumps(sample_symptoms) - } - }, - logger_context={}, + json_obj={"CustomFields": {"Symptoms": json.dumps(sample_symptoms)}}, + logger_context={}, ) assert all( - [ - item in transformed_json["CustomFields"]["Symptoms"].items() - for item in sample_symptoms.items() - ] + [ + item in transformed_json["CustomFields"]["Symptoms"].items() + for item in sample_symptoms.items() + ] ) def test_cast_custom_fields_to_array_malformatted_str(self, caplog): s3_to_json.logger.propagate = True transformed_json = s3_to_json._cast_custom_fields_to_array( - json_obj={ - "CustomFields": { - "Symptoms": r'[{\\\"id\\\": "123", \\\"symptom\\\": "sick"}]' - } - }, - logger_context={}, + json_obj={ + "CustomFields": { + "Symptoms": r'[{\\\"id\\\": "123", \\\"symptom\\\": "sick"}]' + } + }, + logger_context={}, ) s3_to_json.logger.propagate = False assert len(caplog.records) == 1 assert transformed_json["CustomFields"]["Symptoms"] == [] def test_transform_garmin_data_types_one_level_hierarchy(self): - time_offset_heartrate_samples = { - "0": 60.0, - "1": 61.0, - "2": 99.0 - } + time_offset_heartrate_samples = 
{"0": 60.0, "1": 61.0, "2": 99.0} transformed_time_offset_heartrate_samples = [ - { - "OffsetInSeconds": 0, - "HeartRate": 60 - }, - { - "OffsetInSeconds": 1, - "HeartRate": 61 - }, - { - "OffsetInSeconds":2, - "HeartRate": 99 - } + {"OffsetInSeconds": 0, "HeartRate": 60}, + {"OffsetInSeconds": 1, "HeartRate": 61}, + {"OffsetInSeconds": 2, "HeartRate": 99}, ] - data_type_transforms={"TimeOffsetHeartRateSamples": (("OffsetInSeconds", int), ("HeartRate", int))} + data_type_transforms = { + "TimeOffsetHeartRateSamples": (("OffsetInSeconds", int), ("HeartRate", int)) + } transformed_json = s3_to_json._transform_garmin_data_types( - json_obj={"TimeOffsetHeartRateSamples": time_offset_heartrate_samples}, - data_type_transforms=data_type_transforms, - logger_context={}, + json_obj={"TimeOffsetHeartRateSamples": time_offset_heartrate_samples}, + data_type_transforms=data_type_transforms, + logger_context={}, ) assert all( - [ - obj in transformed_json["TimeOffsetHeartRateSamples"] - for obj in transformed_time_offset_heartrate_samples - ] + [ + obj in transformed_json["TimeOffsetHeartRateSamples"] + for obj in transformed_time_offset_heartrate_samples + ] ) def test_transform_garmin_data_types_two_level_hierarchy(self): - epoch_summaries = { - "0": 60.0, - "1": 61.0, - "2": 99.0 - } + epoch_summaries = {"0": 60.0, "1": 61.0, "2": 99.0} transformed_epoch_summaries = [ - { - "OffsetInSeconds": 0, - "Value": 60.0 - }, - { - "OffsetInSeconds": 1, - "Value": 61.0 - }, - { - "OffsetInSeconds":2, - "Value": 99.0 - } + {"OffsetInSeconds": 0, "Value": 60.0}, + {"OffsetInSeconds": 1, "Value": 61.0}, + {"OffsetInSeconds": 2, "Value": 99.0}, ] - data_type_transforms= { - "Summaries.EpochSummaries": (("OffsetInSeconds", int), ("Value", float)) + data_type_transforms = { + "Summaries.EpochSummaries": (("OffsetInSeconds", int), ("Value", float)) } transformed_json = s3_to_json._transform_garmin_data_types( - json_obj={ - "Summaries": [ - { - "EpochSummaries": epoch_summaries, - "Dummy": 1 - }, - { - "EpochSummaries": epoch_summaries, - }, - ] - }, - data_type_transforms=data_type_transforms, - logger_context={}, + json_obj={ + "Summaries": [ + {"EpochSummaries": epoch_summaries, "Dummy": 1}, + { + "EpochSummaries": epoch_summaries, + }, + ] + }, + data_type_transforms=data_type_transforms, + logger_context={}, ) print(transformed_json) assert all( - [ - obj in transformed_json["Summaries"][0]["EpochSummaries"] - for obj in transformed_epoch_summaries - ] + [ + obj in transformed_json["Summaries"][0]["EpochSummaries"] + for obj in transformed_epoch_summaries + ] ) assert transformed_json["Summaries"][0]["Dummy"] == 1 assert all( - [ - obj in transformed_json["Summaries"][1]["EpochSummaries"] - for obj in transformed_epoch_summaries - ] + [ + obj in transformed_json["Summaries"][1]["EpochSummaries"] + for obj in transformed_epoch_summaries + ] ) def test_transform_block_empty_file(self, s3_obj, sample_metadata): @@ -319,9 +271,7 @@ def test_transform_block_empty_file(self, s3_obj, sample_metadata): json_path = "HealthKitV2Samples_Weight_20230112-20230114.json" with z.open(json_path, "r") as input_json: transformed_block = s3_to_json.transform_block( - input_json=input_json, - metadata=sample_metadata, - block_size=2 + input_json=input_json, metadata=sample_metadata, block_size=2 ) with pytest.raises(StopIteration): next(transformed_block) @@ -332,15 +282,12 @@ def test_transform_block_non_empty_file_block_size(self, s3_obj, sample_metadata json_path = "FitbitSleepLogs_20230112-20230114.json" with z.open(json_path, 
"r") as input_json: transformed_block = s3_to_json.transform_block( - input_json=input_json, - metadata=sample_metadata, - block_size=2 + input_json=input_json, metadata=sample_metadata, block_size=2 ) first_block = next(transformed_block) assert len(first_block) == 2 - assert ( - isinstance(first_block[0], dict) - and isinstance(first_block[1], dict) + assert isinstance(first_block[0], dict) and isinstance( + first_block[1], dict ) def test_transform_block_non_empty_file_all_blocks(self, s3_obj, sample_metadata): @@ -351,9 +298,7 @@ def test_transform_block_non_empty_file_all_blocks(self, s3_obj, sample_metadata record_count = len(input_json.readlines()) with z.open(json_path, "r") as input_json: transformed_block = s3_to_json.transform_block( - input_json=input_json, - metadata=sample_metadata, - block_size=10 + input_json=input_json, metadata=sample_metadata, block_size=10 ) counter = 0 for block in transformed_block: @@ -363,35 +308,42 @@ def test_transform_block_non_empty_file_all_blocks(self, s3_obj, sample_metadata def test_get_output_filename_generic(self, sample_metadata): output_filename = s3_to_json.get_output_filename( - metadata=sample_metadata, - part_number=0 + metadata=sample_metadata, part_number=0 + ) + assert ( + output_filename + == f"{sample_metadata['type']}_20220112-20230114.part0.ndjson" ) - assert output_filename == f"{sample_metadata['type']}_20220112-20230114.part0.ndjson" def test_get_output_filename_no_start_date(self, sample_metadata): sample_metadata["start_date"] = None output_filename = s3_to_json.get_output_filename( - metadata=sample_metadata, - part_number=0 + metadata=sample_metadata, part_number=0 ) assert output_filename == f"{sample_metadata['type']}_20230114.part0.ndjson" def test_get_output_filename_subtype(self, sample_metadata): sample_metadata["type"] = "HealthKitV2Samples" output_filename = s3_to_json.get_output_filename( - metadata=sample_metadata, - part_number=0 + metadata=sample_metadata, part_number=0 + ) + assert ( + output_filename + == "HealthKitV2Samples_FakeSubtype_20220112-20230114.part0.ndjson" ) - assert output_filename == "HealthKitV2Samples_FakeSubtype_20220112-20230114.part0.ndjson" - def test_upload_file_to_json_dataset_delete_local_copy(self, namespace, sample_metadata, monkeypatch, shared_datadir): + def test_upload_file_to_json_dataset_delete_local_copy( + self, namespace, sample_metadata, monkeypatch, shared_datadir + ): monkeypatch.setattr("boto3.client", lambda x: MockAWSClient()) workflow_run_properties = { "namespace": namespace, "json_prefix": "raw-json", "json_bucket": "json-bucket", } - original_file_path = os.path.join(shared_datadir, "2023-01-13T21--08--51Z_TESTDATA") + original_file_path = os.path.join( + shared_datadir, "2023-01-13T21--08--51Z_TESTDATA" + ) temp_dir = f"dataset={sample_metadata['type']}" os.makedirs(temp_dir) new_file_path = shutil.copy(original_file_path, temp_dir) @@ -405,18 +357,22 @@ def test_upload_file_to_json_dataset_delete_local_copy(self, namespace, sample_m assert not os.path.exists(new_file_path) shutil.rmtree(temp_dir) - def test_upload_file_to_json_dataset_s3_key(self, namespace, monkeypatch, shared_datadir): + def test_upload_file_to_json_dataset_s3_key( + self, namespace, monkeypatch, shared_datadir + ): monkeypatch.setattr("boto3.client", lambda x: MockAWSClient()) sample_metadata = { - "type": "HealthKitV2Samples", - "subtype": "Weight", + "type": "HealthKitV2Samples", + "subtype": "Weight", } workflow_run_properties = { "namespace": namespace, "json_prefix": "raw-json", 
"json_bucket": "json-bucket", } - original_file_path = os.path.join(shared_datadir, "2023-01-13T21--08--51Z_TESTDATA") + original_file_path = os.path.join( + shared_datadir, "2023-01-13T21--08--51Z_TESTDATA" + ) temp_dir = f"dataset={sample_metadata['type']}" os.makedirs(temp_dir) new_file_path = shutil.copy(original_file_path, temp_dir) @@ -435,7 +391,9 @@ def test_upload_file_to_json_dataset_s3_key(self, namespace, monkeypatch, shared assert s3_key == correct_s3_key shutil.rmtree(temp_dir) - def test_write_file_to_json_dataset_delete_local_copy(self, s3_obj, sample_metadata, namespace, monkeypatch): + def test_write_file_to_json_dataset_delete_local_copy( + self, s3_obj, sample_metadata, namespace, monkeypatch + ): sample_metadata["type"] = "HealthKitV2Samples" monkeypatch.setattr("boto3.client", lambda x: MockAWSClient()) workflow_run_properties = { @@ -457,7 +415,8 @@ def test_write_file_to_json_dataset_delete_local_copy(self, s3_obj, sample_metad shutil.rmtree(f"dataset={sample_metadata['type']}") def test_write_file_to_json_dataset_record_consistency( - self, s3_obj, sample_metadata, namespace, monkeypatch): + self, s3_obj, sample_metadata, namespace, monkeypatch + ): monkeypatch.setattr("boto3.client", lambda x: MockAWSClient()) sample_metadata["start_date"] = None workflow_run_properties = { @@ -493,7 +452,8 @@ def test_write_file_to_json_dataset_record_consistency( shutil.rmtree(f"dataset={sample_metadata['type']}", ignore_errors=True) def test_write_file_to_json_dataset_multiple_parts( - self, s3_obj, sample_metadata, namespace, monkeypatch): + self, s3_obj, sample_metadata, namespace, monkeypatch + ): monkeypatch.setattr("boto3.client", lambda x: MockAWSClient()) sample_metadata["type"] = "FitbitIntradayCombined" workflow_run_properties = { @@ -511,7 +471,7 @@ def test_write_file_to_json_dataset_multiple_parts( metadata=sample_metadata, workflow_run_properties=workflow_run_properties, delete_upon_successful_upload=False, - file_size_limit=1e6 + file_size_limit=1e6, ) output_line_count = 0 for output_file in output_files: @@ -530,19 +490,19 @@ def test_derive_str_metadata(self, sample_metadata): def test_get_part_path_no_touch(self, sample_metadata): sample_metadata["start_date"] = None part_path = s3_to_json.get_part_path( - metadata=sample_metadata, - part_number=0, - part_dir=sample_metadata["type"], - touch=False + metadata=sample_metadata, + part_number=0, + part_dir=sample_metadata["type"], + touch=False, ) assert part_path == "FitbitDevices/FitbitDevices_20230114.part0.ndjson" def test_get_part_path_touch(self, sample_metadata): part_path = s3_to_json.get_part_path( - metadata=sample_metadata, - part_number=0, - part_dir=sample_metadata["type"], - touch=True + metadata=sample_metadata, + part_number=0, + part_dir=sample_metadata["type"], + touch=True, ) assert os.path.exists(part_path) shutil.rmtree(sample_metadata["type"], ignore_errors=True) @@ -558,8 +518,9 @@ def test_get_metadata_startdate_enddate(self, json_file_basenames_dict): def test_get_metadata_no_startdate(self, json_file_basenames_dict): basename = json_file_basenames_dict["EnrolledParticipants"] assert s3_to_json.get_metadata(basename)["start_date"] is None - assert s3_to_json.get_metadata(basename)["end_date"] == \ - datetime.datetime(2023, 1, 3, 0, 0) + assert s3_to_json.get_metadata(basename)["end_date"] == datetime.datetime( + 2023, 1, 3, 0, 0 + ) def test_get_metadata_subtype(self, json_file_basenames_dict): basename = json_file_basenames_dict["HealthKitV2Samples"] @@ -581,15 +542,17 @@ def 
test_get_metadata_no_subtype(self, json_file_basenames_dict): subtypes = [ "subtype" in record.keys() for record in metadata - if record["type"] not in [ + if record["type"] + not in [ "HealthKitV2Samples", "HealthKitV2Samples_Deleted", "HealthKitV2Statistics", - "HealthKitV2Statistics_Deleted" + "HealthKitV2Statistics_Deleted", ] ] - assert not any(subtypes),\ - "Some data types that are not HealthKitV2Samples or HealthKitV2Statistics have the metadata subtype key" + assert not any( + subtypes + ), "Some data types that are not HealthKitV2Samples or HealthKitV2Statistics have the metadata subtype key" def test_get_metadata_type(self, json_file_basenames_dict): # check that all file basenames match their type @@ -598,13 +561,15 @@ def test_get_metadata_type(self, json_file_basenames_dict): == basename for basename in json_file_basenames_dict.keys() ] - assert all(metadata_check),\ - "Some data types' metadata type key are incorrect" + assert all(metadata_check), "Some data types' metadata type key are incorrect" def test_get_basic_file_info(self): file_path = "my/dir/HealthKitV2Samples_Weight_20230112-20230114.json" basic_file_info = s3_to_json.get_basic_file_info(file_path=file_path) required_fields = ["file.type", "file.path", "file.name", "file.extension"] assert all([field in basic_file_info for field in required_fields]) - assert basic_file_info["file.name"] == "HealthKitV2Samples_Weight_20230112-20230114.json" + assert ( + basic_file_info["file.name"] + == "HealthKitV2Samples_Weight_20230112-20230114.json" + ) assert basic_file_info["file.extension"] == "json" diff --git a/tests/test_setup_external_storage.py b/tests/test_setup_external_storage.py index f4094d59..8492008a 100644 --- a/tests/test_setup_external_storage.py +++ b/tests/test_setup_external_storage.py @@ -27,7 +27,7 @@ def test_setup_external_storage_success( namespace: str, test_synapse_folder_id: str, test_sts_permission: str, - ssm_parameter : str + ssm_parameter: str, ): """This test tests that it can get the STS token credentials and view and list the files in the S3 bucket location to verify that it has access"""