Merged

21 commits
3d714b2
Pass context to check_fn in S3 sensor
carlinix Oct 20, 2025
101503e
Enhance S3 trigger with metadata handling
carlinix Oct 20, 2025
183e5fc
Add metadata_keys parameter to S3KeyTrigger
carlinix Oct 20, 2025
d468764
Add metadata_keys to S3 sensor configuration
carlinix Oct 20, 2025
69910c1
Fix indentation in S3 trigger metadata handling
carlinix Oct 20, 2025
ac16add
Add metadata_keys to S3 trigger object retrieval
carlinix Oct 21, 2025
fa62a58
Enhance S3KeyTrigger tests with metadata assertions
carlinix Oct 21, 2025
f329241
Update context passing in S3 sensor check function
carlinix Oct 21, 2025
710c0ef
Improve S3 trigger test cases
carlinix Oct 21, 2025
40d2c00
Add sleep interval before yielding trigger events
carlinix Oct 21, 2025
39166c5
Update S3 trigger test to include key check mock
carlinix Oct 21, 2025
1149d21
Merge branch 'main' into main
carlinix Oct 30, 2025
bfbac6a
Merge branch 'main' into main
carlinix Oct 31, 2025
c9dd046
Merge branch 'apache:main' into main
carlinix Nov 3, 2025
d55af3d
Merge branch 'apache:main' into main
carlinix Nov 3, 2025
b742555
Refactor S3 trigger code for improved readability and consistency
carlinix Nov 3, 2025
30ec160
Merge branch 'apache:main' into main
carlinix Nov 3, 2025
33dddb3
Merge remote-tracking branch 'origin/main'
carlinix Nov 3, 2025
ea60cfc
Refactor S3 trigger code for improved readability and consistency
carlinix Nov 3, 2025
f83f63f
Merge branch 'apache:main' into main
carlinix Nov 3, 2025
28f5390
Remove unnecessary blank line in S3 trigger test
carlinix Nov 3, 2025
@@ -215,6 +215,7 @@ def _defer(self) -> None:
                 poke_interval=self.poke_interval,
                 should_check_fn=bool(self.check_fn),
                 use_regex=self.use_regex,
+                metadata_keys=self.metadata_keys,
             ),
             method_name="execute_complete",
         )
@@ -226,7 +227,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
         Relies on trigger to throw an exception, otherwise it assumes execution was successful.
         """
         if event["status"] == "running":
-            found_keys = self.check_fn(event["files"])  # type: ignore[misc]
+            found_keys = self.check_fn(event["files"], **context)  # type: ignore[misc]
             if not found_keys:
                 self._defer()
         elif event["status"] == "error":
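Note: with `execute_complete` now calling `self.check_fn(event["files"], **context)`, a user-supplied `check_fn` used with the deferrable sensor must accept the task context as keyword arguments alongside the file list. A minimal sketch of a compatible callable — the function name and the size check are illustrative, not part of this PR:

```python
from typing import Any


def my_check(files: list[dict[str, Any]], **context: Any) -> bool:
    """Accept the trigger event's file list plus the task context kwargs.

    With the default metadata_keys, each entry looks like
    {"Size": 1024, "Key": "data/part-0001.csv"}.
    """
    ti = context["ti"]  # task instance, forwarded by execute_complete
    ti.log.info("S3 sensor matched %d key(s)", len(files))
    # Succeed only when every matched object is non-empty.
    return all((f.get("Size") or 0) > 0 for f in files)
```

Returning False re-defers the sensor via `_defer()`, so the files are re-checked on the next trigger event.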
@@ -41,6 +41,11 @@ class S3KeyTrigger(BaseTrigger):
         Unix wildcard pattern
     :param aws_conn_id: reference to the s3 connection
     :param use_regex: whether to use regex to check bucket
+    :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
+        Acceptable values: any top-level attribute returned by ``s3.head_object``. Specify ``*``
+        to return all available attributes.
+        Default value: ``["Size", "Key"]`` ("Key" is always included).
+        If a requested attribute is not found, the key is still included and its value is None.
     :param hook_params: optional params for the hook
     """

@@ -56,6 +61,7 @@ def __init__(
         region_name: str | None = None,
         verify: bool | str | None = None,
         botocore_config: dict | None = None,
+        metadata_keys: list[str] | None = None,
         **hook_params: Any,
     ):
         super().__init__()
@@ -70,6 +76,7 @@ def __init__(
         self.region_name = region_name
         self.verify = verify
         self.botocore_config = botocore_config
+        self.metadata_keys = metadata_keys if metadata_keys else ["Size", "Key"]

     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize S3KeyTrigger arguments and classpath."""
@@ -87,6 +94,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
             "region_name": self.region_name,
             "verify": self.verify,
             "botocore_config": self.botocore_config,
+            "metadata_keys": self.metadata_keys,
         },
     )

@@ -108,11 +116,30 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
                         client, self.bucket_name, self.bucket_key, self.wildcard_match, self.use_regex
                     ):
                         if self.should_check_fn:
-                            s3_objects = await self.hook.get_files_async(
+                            raw_objects = await self.hook.get_files_async(
                                 client, self.bucket_name, self.bucket_key, self.wildcard_match
                             )
+                            files = []
+                            for f in raw_objects:
+                                metadata = {}
+                                obj = await self.hook.get_head_object_async(
+                                    client=client, key=f, bucket_name=self.bucket_name
+                                )
+                                if obj is None:
+                                    return
+
+                                if "*" in self.metadata_keys:
+                                    metadata = obj
+                                else:
+                                    for mk in self.metadata_keys:
+                                        if mk == "Size":
+                                            metadata[mk] = obj.get("ContentLength")
+                                        else:
+                                            metadata[mk] = obj.get(mk, None)
+                                metadata["Key"] = f
+                                files.append(metadata)
                             await asyncio.sleep(self.poke_interval)
-                            yield TriggerEvent({"status": "running", "files": s3_objects})
+                            yield TriggerEvent({"status": "running", "files": files})
                         else:
                             yield TriggerEvent({"status": "success"})
                             return
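Together with the sensor changes above, the new parameter can be exercised roughly as follows — a sketch assuming `S3KeySensor` forwards `metadata_keys` to `S3KeyTrigger` as `_defer` does in this PR; the task id, bucket, and key pattern are placeholders:

```python
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor


def large_enough(files, **context):
    # Each dict carries the requested head_object attributes plus "Key";
    # "Size" is mapped from head_object's ContentLength by the trigger.
    return all((f.get("Size") or 0) >= 1_000_000 for f in files)


wait_for_exports = S3KeySensor(
    task_id="wait_for_exports",        # placeholder task id
    bucket_name="example-bucket",      # placeholder bucket
    bucket_key="exports/*.csv",
    wildcard_match=True,
    deferrable=True,                   # route through S3KeyTrigger
    metadata_keys=["Size", "LastModified"],
    check_fn=large_enough,
)
```

Passing `metadata_keys=["*"]` would surface every attribute returned by `head_object` instead, as the `test_run_with_all_metadata` case below verifies.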
104 changes: 100 additions & 4 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
@@ -33,7 +33,10 @@ def test_serialization(self):
         and classpath.
         """
         trigger = S3KeyTrigger(
-            bucket_key="s3://test_bucket/file", bucket_name="test_bucket", wildcard_match=True
+            bucket_key="s3://test_bucket/file",
+            bucket_name="test_bucket",
+            wildcard_match=True,
+            metadata_keys=["Size", "LastModified"],
         )
         classpath, kwargs = trigger.serialize()
         assert classpath == "airflow.providers.amazon.aws.triggers.s3.S3KeyTrigger"
@@ -49,6 +52,7 @@ def test_serialization(self):
             "verify": None,
             "region_name": None,
             "botocore_config": None,
+            "metadata_keys": ["Size", "LastModified"],
         }

     @pytest.mark.asyncio
@@ -58,11 +62,13 @@ async def test_run_success(self, mock_client):
         Test if the task is run in the trigger successfully.
         """
         mock_client.return_value.return_value.check_key.return_value = True
-        trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file", bucket_name="test_bucket")
+        trigger = S3KeyTrigger(bucket_key="test_bucket/file", bucket_name="test_bucket")
         task = asyncio.create_task(trigger.run().__anext__())
         await asyncio.sleep(0.5)

         assert task.done() is True
+        result = await task
+        assert result == TriggerEvent({"status": "success"})
         asyncio.get_event_loop().stop()

     @pytest.mark.asyncio
@@ -73,13 +79,104 @@ async def test_run_pending(self, mock_client, mock_check_key_async):
         Test that the task remains pending in the trigger when check_key returns false.
         """
         mock_check_key_async.return_value = False
-        trigger = S3KeyTrigger(bucket_key="s3://test_bucket/file", bucket_name="test_bucket")
+        trigger = S3KeyTrigger(bucket_key="test_bucket/file", bucket_name="test_bucket")
         task = asyncio.create_task(trigger.run().__anext__())
         await asyncio.sleep(0.5)

         assert task.done() is False
         asyncio.get_event_loop().stop()

+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_files_async")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.check_key_async")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn")
+    async def test_run_with_metadata(
+        self,
+        mock_get_async_conn,
+        mock_check_key_async,
+        mock_get_head_object_async,
+        mock_get_files_async,
+    ):
+        """Test if the task retrieves metadata correctly when should_check_fn is True."""
+        mock_check_key_async.return_value = True
+        mock_get_files_async.return_value = ["file1.txt", "file2.txt"]
+
+        async def fake_get_head_object_async(*args, **kwargs):
+            key = kwargs.get("key")
+            if key == "file1.txt":
+                return {"ContentLength": 1024, "LastModified": "2023-10-01T12:00:00Z"}
+            if key == "file2.txt":
+                return {"ContentLength": 2048, "LastModified": "2023-10-02T12:00:00Z"}
+
+        mock_get_head_object_async.side_effect = fake_get_head_object_async
+        mock_get_async_conn.return_value.__aenter__.return_value = async_mock.AsyncMock()
+        trigger = S3KeyTrigger(
+            bucket_key="test_bucket/file",
+            bucket_name="test_bucket",
+            should_check_fn=True,
+            metadata_keys=["Size", "LastModified"],
+            poke_interval=0.1,  # reduce waiting time
+        )
+        result = await asyncio.wait_for(trigger.run().__anext__(), timeout=2)
+        expected = TriggerEvent(
+            {
+                "status": "running",
+                "files": [
+                    {"Size": 1024, "LastModified": "2023-10-01T12:00:00Z", "Key": "file1.txt"},
+                    {"Size": 2048, "LastModified": "2023-10-02T12:00:00Z", "Key": "file2.txt"},
+                ],
+            }
+        )
+
+        assert result == expected
+
+    @pytest.mark.asyncio
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_files_async")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_head_object_async")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.check_key_async")
+    @async_mock.patch("airflow.providers.amazon.aws.triggers.s3.S3Hook.get_async_conn")
+    async def test_run_with_all_metadata(
+        self, mock_get_async_conn, mock_check_key_async, mock_get_head_object_async, mock_get_files_async
+    ):
+        """
+        Test if the task retrieves all metadata when metadata_keys contains '*'.
+        """
+        mock_check_key_async.return_value = True
+        mock_get_files_async.return_value = ["file1.txt"]
+
+        async def fake_get_head_object_async(*args, **kwargs):
+            return {
+                "ContentLength": 1024,
+                "LastModified": "2023-10-01T12:00:00Z",
+                "ETag": "abc123",
+            }
+
+        mock_get_head_object_async.side_effect = fake_get_head_object_async
+        mock_get_async_conn.return_value.__aenter__.return_value = async_mock.AsyncMock()
+        trigger = S3KeyTrigger(
+            bucket_key="test_bucket/file",
+            bucket_name="test_bucket",
+            should_check_fn=True,
+            metadata_keys=["*"],
+            poke_interval=0.1,
+        )
+        result = await asyncio.wait_for(trigger.run().__anext__(), timeout=2)
+        expected = TriggerEvent(
+            {
+                "status": "running",
+                "files": [
+                    {
+                        "ContentLength": 1024,
+                        "LastModified": "2023-10-01T12:00:00Z",
+                        "ETag": "abc123",
+                        "Key": "file1.txt",
+                    }
+                ],
+            }
+        )
+        assert result == expected


 class TestS3KeysUnchangedTrigger:
     def test_serialization(self):
@@ -156,7 +253,6 @@ async def test_run_pending(self, mock_is_keys_unchanged, mock_client):
         trigger = S3KeysUnchangedTrigger(bucket_name="test_bucket", prefix="test")
         task = asyncio.create_task(trigger.run().__anext__())
         await asyncio.sleep(0.5)
-
         # TriggerEvent was not returned
         assert task.done() is False
         asyncio.get_event_loop().stop()