diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py index 1bea70ae25544..8a34812eca6a4 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py @@ -107,7 +107,7 @@ def __init__( self.verify = verify self.deferrable = deferrable self.use_regex = use_regex - self.metadata_keys = metadata_keys if metadata_keys else ["Size"] + self.metadata_keys = metadata_keys if metadata_keys else ["Size", "Key"] def _check_key(self, key, context: Context): bucket_name, key = self.hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key") @@ -116,7 +116,8 @@ def _check_key(self, key, context: Context): """ Set variable `files` which contains a list of dict which contains attributes defined by the user Format: [{ - 'Size': int + 'Size': int, + 'Key': str, }] """ if self.wildcard_match: diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py index b9f8fb284bcd0..4ada09d092303 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py @@ -320,11 +320,19 @@ def test_fail_execute_complete(self): with pytest.raises(AirflowException, match=message): op.execute_complete(context={}, event={"status": "error", "message": message}) + @pytest.mark.parametrize( + "metadata_keys, expected", + [ + (["Size", "Key"], True), + (["Content"], False), + (None, True), + ], + ) @mock_aws - def test_custom_metadata_default_return_vals(self): + def test_custom_metadata_default_return_vals(self, metadata_keys, expected): def check_fn(files: list) -> bool: for f in files: - if "Size" not in f: + if "Size" not in f or "Key" not in f: return False return True @@ -335,31 +343,14 @@ def check_fn(files: list) -> bool: key="test-key", string_data="test-body", ) - - op = S3KeySensor( - task_id="test-metadata", - bucket_key="test-key", - bucket_name="test-bucket", - metadata_keys=["Size"], - check_fn=check_fn, - ) - assert op.poke(None) is True op = S3KeySensor( task_id="test-metadata", bucket_key="test-key", bucket_name="test-bucket", - metadata_keys=["Content"], + metadata_keys=metadata_keys, check_fn=check_fn, ) - assert op.poke(None) is False - - op = S3KeySensor( - task_id="test-metadata", - bucket_key="test-key", - bucket_name="test-bucket", - check_fn=check_fn, - ) - assert op.poke(None) is True + assert op.poke(None) is expected @mock_aws def test_custom_metadata_default_custom_vals(self): @@ -391,7 +382,7 @@ def test_custom_metadata_all_attributes(self): def check_fn(files: list) -> bool: hook = S3Hook() metadata_keys = set(hook.head_object(bucket_name="test-bucket", key="test-key").keys()) - test_data_keys = set(files[0].keys()) + test_data_keys = set(files[0].keys()) - {"Key"} return test_data_keys == metadata_keys