diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
index 878e31b924594..fbb74630a244d 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
@@ -215,6 +215,7 @@ def _defer(self) -> None:
                 poke_interval=self.poke_interval,
                 should_check_fn=bool(self.check_fn),
                 use_regex=self.use_regex,
+                metadata_keys=self.metadata_keys,
             ),
             method_name="execute_complete",
         )
@@ -226,7 +227,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
         if event["status"] == "running":
-            found_keys = self.check_fn(event["files"])  # type: ignore[misc]
+            found_keys = self.check_fn(event["files"], **context)  # type: ignore[misc]
             if not found_keys:
                 self._defer()
         elif event["status"] == "error":
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
index 0bc192b479845..edf811b42f0a8 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
@@ -41,6 +41,11 @@ class S3KeyTrigger(BaseTrigger):
         Unix wildcard pattern
     :param aws_conn_id: reference to the s3 connection
     :param use_regex: whether to use regex to check bucket
+    :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
+        Acceptable values: any top-level attribute returned by ``s3.head_object``; ``Size`` is
+        populated from the object's ``ContentLength``. Specify ``*`` to return all available attributes.
+        Default value: ``["Size", "Key"]``.
+        If a requested attribute is not found, the key is still included and its value is None.
     :param hook_params: optional params for the hook
     """

@@ -56,6 +61,7 @@ def __init__(
         region_name: str | None = None,
         verify: bool | str | None = None,
         botocore_config: dict | None = None,
+        metadata_keys: list[str] | None = None,
         **hook_params: Any,
     ):
         super().__init__()
@@ -70,6 +76,7 @@
         self.region_name = region_name
         self.verify = verify
         self.botocore_config = botocore_config
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size", "Key"]

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize S3KeyTrigger arguments and classpath."""
@@ -87,6 +94,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
                 "region_name": self.region_name,
                 "verify": self.verify,
                 "botocore_config": self.botocore_config,
+                "metadata_keys": self.metadata_keys,
             },
         )

@@ -108,11 +116,30 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                         client, self.bucket_name, self.bucket_key, self.wildcard_match, self.use_regex
                     ):
                         if self.should_check_fn:
-                            s3_objects = await self.hook.get_files_async(
+                            raw_objects = await self.hook.get_files_async(
                                 client, self.bucket_name, self.bucket_key, self.wildcard_match
                             )
+                            files = []
+                            for f in raw_objects:
+                                metadata = {}
+                                obj = await self.hook.get_head_object_async(
+                                    client=client, key=f, bucket_name=self.bucket_name
+                                )
+                                if obj is None:
+                                    return
+
+                                if "*" in self.metadata_keys:
+                                    metadata = obj
+                                else:
+                                    for mk in self.metadata_keys:
+                                        if mk == "Size":
+                                            metadata[mk] = obj.get("ContentLength")
+                                        else:
+                                            metadata[mk] = obj.get(mk)
+                                metadata["Key"] = f
+                                files.append(metadata)
                             await asyncio.sleep(self.poke_interval)
-                            yield TriggerEvent({"status": "running", "files": s3_objects})
+                            yield TriggerEvent({"status": "running", "files": files})
                         else:
                             yield TriggerEvent({"status": "success"})
                         return
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
index f55a6a0f157e5..c43a39adeaf77 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
@@ -33,7 +33,10 @@ def test_serialization(self):
         and classpath.
         """
         trigger = S3KeyTrigger(
-            bucket_key="s3://test_bucket/file", bucket_name="test_bucket", wildcard_match=True
+            bucket_key="s3://test_bucket/file",
+            bucket_name="test_bucket",
+            wildcard_match=True,
+            metadata_keys=["Size", "LastModified"],
         )
         classpath, kwargs = trigger.serialize()
         assert classpath == "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger"
@@ -49,6 +52,7 @@ def test_serialization(self):
             "verify": None,
             "region_name": None,
             "botocore_config": None,
+            "metadata_keys": ["Size", "LastModified"],
         }

     @pytest.mark.asyncio
@@ -58,11 +62,13 @@ async def test_run_success(self, mock_client):
         """
         Test that the trigger run succeeds.
         """
         mock_client.return_value.return_value.check_key.return_value = True
-        trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file", bucket_name="test_bucket")
+        trigger = S3KeyTrigger(bucket_key="test_bucket/file", bucket_name="test_bucket")
         task = asyncio.create_task(trigger.run().__anext__())
         await asyncio.sleep(0.5)
         assert task.done() is True
+        result = await task
+        assert result == TriggerEvent({"status": "success"})
         asyncio.get_event_loop().stop()

     @pytest.mark.asyncio
@@ -73,13 +79,104 @@ async def test_run_pending(self, mock_client, mock_check_key_async):
         """
         Test that the trigger run stays pending when check_key returns false.
""" mock_check_key_async.return_value = False - trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file", bucket_name="test_bucket") + trigger = S3KeyTrigger(bucket_key="test_bucket/file", bucket_name="test_bucket") task = asyncio.create_task(trigger.run().__anext__()) await asyncio.sleep(0.5) assert task.done() is False asyncio.get_event_loop().stop() + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_files_async") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.check_key_async") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn") + async def test_run_with_metadata( + self, + mock_get_async_conn, + mock_check_key_async, + mock_get_head_object_async, + mock_get_files_async, + ): + """Test if the task retrieves metadata correctly when should_check_fn is True.""" + mock_check_key_async.return_value = True + mock_get_files_async.return_value = ["file1.txt", "file2.txt"] + + async def fake_get_head_object_async(*args, **kwargs): + key = kwargs.get("key") + if key == "file1.txt": + return {"ContentLength": 1024, "LastModified": "2023-10-01T12:00:00Z"} + if key == "file2.txt": + return {"ContentLength": 2048, "LastModified": "2023-10-02T12:00:00Z"} + + mock_get_head_object_async.side_effect = fake_get_head_object_async + mock_get_async_conn.return_value.__aenter__.return_value = async_mock.AsyncMock() + trigger = S3KeyTrigger( + bucket_key="test_bucket/file", + bucket_name="test_bucket", + should_check_fn=True, + metadata_keys=["Size", "LastModified"], + poke_interval=0.1, # reduce waiting time + ) + result = await asyncio.wait_for(trigger.run().__anext__(), timeout=2) + expected = TriggerEvent( + { + "status": "running", + "files": [ + {"Size": 1024, "LastModified": "2023-10-01T12:00:00Z", "Key": "file1.txt"}, + {"Size": 2048, "LastModified": "2023-10-02T12:00:00Z", "Key": "file2.txt"}, + ], + } + ) + + assert result == expected + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_files_async") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.check_key_async") + @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn") + async def test_run_with_all_metadata( + self, mock_get_async_conn, mock_check_key_async, mock_get_head_object_async, mock_get_files_async + ): + """ + Test if the task retrieves all metadata when metadata_keys contains '*'. 
+ """ + mock_check_key_async.return_value = True + mock_get_files_async.return_value = ["file1.txt"] + + async def fake_get_head_object_async(*args, **kwargs): + return { + "ContentLength": 1024, + "LastModified": "2023-10-01T12:00:00Z", + "ETag": "abc123", + } + + mock_get_head_object_async.side_effect = fake_get_head_object_async + mock_get_async_conn.return_value.__aenter__.return_value = async_mock.AsyncMock() + trigger = S3KeyTrigger( + bucket_key="test_bucket/file", + bucket_name="test_bucket", + should_check_fn=True, + metadata_keys=["*"], + poke_interval=0.1, + ) + result = await asyncio.wait_for(trigger.run().__anext__(), timeout=2) + expected = TriggerEvent( + { + "status": "running", + "files": [ + { + "ContentLength": 1024, + "LastModified": "2023-10-01T12:00:00Z", + "ETag": "abc123", + "Key": "file1.txt", + } + ], + } + ) + assert result == expected + class TestS3KeysUnchangedTrigger: def test_serialization(self): @@ -156,7 +253,6 @@ async def test_run_pending(self, mock_is_keys_unchanged, mock_client): trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test") task = asyncio.create_task(trigger.run().__anext__()) await asyncio.sleep(0.5) - # TriggerEvent was not returned assert task.done() is False asyncio.get_event_loop().stop()