Merged
304 changes: 147 additions & 157 deletions providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py

Large diffs are not rendered by default.

73 changes: 31 additions & 42 deletions providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py
@@ -23,7 +23,6 @@
import re
from collections.abc import Sequence
from datetime import datetime, timedelta
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, cast

from airflow.configuration import conf
@@ -34,11 +33,13 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger, S3KeyTrigger
from airflow.sensors.base import BaseSensorOperator, poke_mode_only
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.sensors.base import poke_mode_only


class S3KeySensor(BaseSensorOperator):
class S3KeySensor(AwsBaseSensor[S3Hook]):
"""
Waits for one or multiple keys (a file-like instance on S3) to be present in a S3 bucket.

Expand All @@ -65,27 +66,25 @@ class S3KeySensor(BaseSensorOperator):

def check_fn(files: List, **kwargs) -> bool:
return any(f.get('Size', 0) > 1048576 for f in files)
:param aws_conn_id: a reference to the s3 connection
:param verify: Whether to verify SSL certificates for S3 connection.
By default, SSL certificates are verified.
You can provide the following values:

- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:param deferrable: Run operator in the deferrable mode
:param use_regex: whether to use regex to check bucket
:param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
Acceptable values: Any top level attribute returned by s3.head_object. Specify * to return
all available attributes.
Default value: "Size".
If the requested attribute is not found, the key is still included and the value is None.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
"""

template_fields: Sequence[str] = ("bucket_key", "bucket_name")
template_fields: Sequence[str] = aws_template_fields("bucket_key", "bucket_name")
aws_hook_class = S3Hook

def __init__(
self,
@@ -94,7 +93,6 @@ def __init__(
bucket_name: str | None = None,
wildcard_match: bool = False,
check_fn: Callable[..., bool] | None = None,
aws_conn_id: str | None = "aws_default",
verify: str | bool | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
use_regex: bool = False,
@@ -106,14 +104,13 @@
self.bucket_key = bucket_key
self.wildcard_match = wildcard_match
self.check_fn = check_fn
self.aws_conn_id = aws_conn_id
self.verify = verify
self.deferrable = deferrable
self.use_regex = use_regex
self.metadata_keys = metadata_keys if metadata_keys else ["Size"]

def _check_key(self, key, context: Context):
bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
bucket_name, key = self.hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
self.log.info("Poking for key : s3://%s/%s", bucket_name, key)

"""
@@ -199,7 +196,9 @@ def _defer(self) -> None:
bucket_key=self.bucket_key,
wildcard_match=self.wildcard_match,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
poke_interval=self.poke_interval,
should_check_fn=bool(self.check_fn),
use_regex=self.use_regex,
@@ -220,13 +219,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
elif event["status"] == "error":
raise AirflowException(event["message"])

@cached_property
def hook(self) -> S3Hook:
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)


@poke_mode_only
class S3KeysUnchangedSensor(BaseSensorOperator):
class S3KeysUnchangedSensor(AwsBaseSensor[S3Hook]):
"""
Return True if inactivity_period has passed with no increase in the number of objects matching prefix.

@@ -239,17 +234,7 @@ class S3KeysUnchangedSensor(BaseSensorOperator):

:param bucket_name: Name of the S3 bucket
:param prefix: The prefix being waited on. Relative path from bucket root level.
:param aws_conn_id: a reference to the s3 connection
:param verify: Whether or not to verify SSL certificates for S3 connection.
By default SSL certificates are verified.
You can provide the following values:

- ``False``: do not validate SSL certificates. SSL will still be used
(unless use_ssl is False), but SSL certificates will not be
verified.
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param inactivity_period: The total seconds of inactivity to designate
keys unchanged. Note, this mechanism is not real time and
this operator may not return until a poke_interval after this period
@@ -261,16 +246,24 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
between pokes valid behavior. If true a warning message will be logged
when this happens. If false an error will be raised.
:param deferrable: Run sensor in the deferrable mode
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
"""

template_fields: Sequence[str] = ("bucket_name", "prefix")
template_fields: Sequence[str] = aws_template_fields("bucket_name", "prefix")
aws_hook_class = S3Hook

def __init__(
self,
*,
bucket_name: str,
prefix: str,
aws_conn_id: str | None = "aws_default",
verify: bool | str | None = None,
inactivity_period: float = 60 * 60,
min_objects: int = 1,
@@ -291,15 +284,9 @@ def __init__(
self.inactivity_seconds = 0
self.allow_delete = allow_delete
self.deferrable = deferrable
self.aws_conn_id = aws_conn_id
self.verify = verify
self.last_activity_time: datetime | None = None

@cached_property
def hook(self):
"""Returns S3Hook."""
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)

def is_keys_unchanged(self, current_objects: set[str]) -> bool:
"""
Check for new objects after the inactivity_period and update the sensor state accordingly.
@@ -382,7 +369,9 @@ def execute(self, context: Context) -> None:
inactivity_seconds=self.inactivity_seconds,
allow_delete=self.allow_delete,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
last_activity_time=self.last_activity_time,
),
method_name="execute_complete",
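Taken together, the sensor changes above delete the hand-written hook cached properties and the per-sensor aws_conn_id arguments: AwsBaseSensor[S3Hook] now provides the hook and accepts aws_conn_id, region_name, verify, and botocore_config, aws_template_fields() keeps those fields templated, and _defer() forwards the new fields to S3KeyTrigger. A minimal sketch of configuring the refactored sensor, assuming the standard Airflow DAG API; the bucket, key, and region values below are illustrative, not taken from this PR:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor

with DAG(dag_id="s3_key_watch", start_date=datetime(2025, 1, 1), schedule=None):
    wait_for_key = S3KeySensor(
        task_id="wait_for_key",
        bucket_name="example-bucket",    # illustrative bucket
        bucket_key="incoming/data.csv",  # illustrative key
        # Accepted via AwsBaseSensor and, in deferrable mode, forwarded to
        # S3KeyTrigger by _defer():
        region_name="eu-west-1",
        verify=True,
        deferrable=True,
    )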
providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py
@@ -53,6 +53,9 @@ def __init__(
poke_interval: float = 5.0,
should_check_fn: bool = False,
use_regex: bool = False,
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
**hook_params: Any,
):
super().__init__()
@@ -64,6 +67,9 @@
self.poke_interval = poke_interval
self.should_check_fn = should_check_fn
self.use_regex = use_regex
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize S3KeyTrigger arguments and classpath."""
@@ -78,12 +84,20 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"poke_interval": self.poke_interval,
"should_check_fn": self.should_check_fn,
"use_regex": self.use_regex,
"region_name": self.region_name,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@cached_property
def hook(self) -> S3Hook:
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
return S3Hook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make an asynchronous connection using S3HookAsync."""
@@ -143,7 +157,9 @@ def __init__(
allow_delete: bool = True,
aws_conn_id: str | None = "aws_default",
last_activity_time: datetime | None = None,
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
**hook_params: Any,
):
super().__init__()
@@ -160,8 +176,10 @@
self.allow_delete = allow_delete
self.aws_conn_id = aws_conn_id
self.last_activity_time = last_activity_time
self.verify = verify
self.polling_period_seconds = 0
self.region_name = region_name
self.verify = verify
self.botocore_config = botocore_config
self.hook_params = hook_params

def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -179,14 +197,21 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"aws_conn_id": self.aws_conn_id,
"last_activity_time": self.last_activity_time,
"hook_params": self.hook_params,
"verify": self.verify,
"polling_period_seconds": self.polling_period_seconds,
"region_name": self.region_name,
"verify": self.verify,
"botocore_config": self.botocore_config,
},
)

@cached_property
def hook(self) -> S3Hook:
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
return S3Hook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make an asynchronous connection using S3Hook."""
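The trigger side mirrors the sensors: region_name, verify, and botocore_config become explicit constructor arguments, travel through serialize(), and reach the S3Hook as keyword arguments instead of being fished out of hook_params. The serialize() changes matter because the triggerer rebuilds every trigger from its (classpath, kwargs) pair, so a field left out of that dict would be silently dropped after deferral. A sketch of the round trip, with illustrative values:

from airflow.providers.amazon.aws.triggers.s3 import S3KeyTrigger

trigger = S3KeyTrigger(
    bucket_name="example-bucket",    # illustrative
    bucket_key="incoming/data.csv",  # illustrative
    region_name="eu-west-1",
    verify=True,
    botocore_config={"retries": {"max_attempts": 3}},  # handed to S3Hook as config
)

# serialize() returns the classpath plus the kwargs the triggerer will use to
# reinstantiate the trigger; the new fields must appear here to survive.
classpath, kwargs = trigger.serialize()
assert kwargs["region_name"] == "eu-west-1"
assert kwargs["botocore_config"] == {"retries": {"max_attempts": 3}}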
26 changes: 12 additions & 14 deletions providers/amazon/tests/unit/amazon/aws/operators/test_s3.py
@@ -415,20 +415,19 @@ def test_template_fields(self):


class TestS3ListOperator:
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
def test_execute(self, mock_hook):
mock_hook.return_value.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]

def test_execute(self):
operator = S3ListOperator(
task_id="test-s3-list-operator",
bucket=BUCKET_NAME,
prefix="TEST",
delimiter=".csv",
)
operator.hook = mock.MagicMock()
operator.hook.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]

files = operator.execute(None)

mock_hook.return_value.list_keys.assert_called_once_with(
operator.hook.list_keys.assert_called_once_with(
bucket_name=BUCKET_NAME,
prefix="TEST",
delimiter=".csv",
@@ -447,17 +446,16 @@ def test_template_fields(self):


class TestS3ListPrefixesOperator:
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
def test_execute(self, mock_hook):
mock_hook.return_value.list_prefixes.return_value = ["test/"]

def test_execute(self):
operator = S3ListPrefixesOperator(
task_id="test-s3-list-prefixes-operator", bucket=BUCKET_NAME, prefix="test/", delimiter="/"
)
operator.hook = mock.MagicMock()
operator.hook.list_prefixes.return_value = ["test/"]

subfolders = operator.execute(None)

mock_hook.return_value.list_prefixes.assert_called_once_with(
operator.hook.list_prefixes.assert_called_once_with(
bucket_name=BUCKET_NAME, prefix="test/", delimiter="/"
)
assert subfolders == ["test/"]
@@ -870,8 +868,7 @@ def test_validate_keys_and_prefix_in_execute(self, keys, prefix, from_datetime,
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test

@pytest.mark.parametrize("keys", ("path/data.txt", ["path/data.txt"]))
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys):
def test_get_openlineage_facets_on_complete_single_object(self, keys):
bucket = "testbucket"
expected_input = Dataset(
namespace=f"s3://{bucket}",
Expand All @@ -888,14 +885,14 @@ def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys)
)

op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
op.hook = mock.MagicMock()
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
assert len(lineage.inputs) == 1
assert lineage.inputs[0] == expected_input

@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook):
def test_get_openlineage_facets_on_complete_multiple_objects(self):
bucket = "testbucket"
keys = ["path/data1.txt", "path/data2.txt"]
expected_inputs = [
@@ -928,6 +925,7 @@ def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook):
]

op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
op.hook = mock.MagicMock()
op.execute(None)

lineage = op.get_openlineage_facets_on_complete(None)
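The test changes follow directly from the refactor: aws_hook_class is bound to S3Hook when the operator class is defined, so patching the S3Hook name in the operators module would no longer intercept hook creation. Assigning over the attribute on the instance does, because functools.cached_property is a non-data descriptor that an entry in the instance __dict__ shadows. A condensed sketch of the pattern, which the sensor tests below use as well; names and values are illustrative:

from unittest import mock

from airflow.providers.amazon.aws.operators.s3 import S3ListOperator

op = S3ListOperator(task_id="t", bucket="example-bucket", prefix="p/", delimiter="/")
op.hook = mock.MagicMock()  # shadows the cached hook property on this instance
op.hook.list_keys.return_value = ["p/a.csv"]
assert op.execute(None) == ["p/a.csv"]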
6 changes: 3 additions & 3 deletions providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py
@@ -538,10 +538,10 @@ def test_key_changes(self, current_objects, expected_returns, inactivity_periods
assert self.sensor.inactivity_seconds == period
time_machine.coordinates.shift(10)

@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine):
def test_poke_succeeds_on_upload_complete(self, time_machine):
time_machine.move_to(DEFAULT_DATE)
mock_hook.return_value.list_keys.return_value = {"a"}
self.sensor.hook = mock.MagicMock()
self.sensor.hook.list_keys.return_value = {"a"}
assert not self.sensor.poke(dict())
time_machine.coordinates.shift(10)
assert not self.sensor.poke(dict())
5 changes: 5 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py
@@ -46,6 +46,9 @@ def test_serialization(self):
"poke_interval": 5.0,
"should_check_fn": False,
"use_regex": False,
"verify": None,
"region_name": None,
"botocore_config": None,
}

@pytest.mark.asyncio
@@ -106,6 +109,8 @@ def test_serialization(self):
"last_activity_time": None,
"hook_params": {},
"verify": None,
"region_name": None,
"botocore_config": None,
"polling_period_seconds": 0,
}
