diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py index 05c5bd88634b3..406d9f597527b 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py @@ -29,8 +29,9 @@ from dateutil import parser from airflow.exceptions import AirflowException -from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.utils.helpers import exactly_one if TYPE_CHECKING: @@ -41,7 +42,7 @@ BUCKET_DOES_NOT_EXIST_MSG = "Bucket with name: %s doesn't exist" -class S3CreateBucketOperator(BaseOperator): +class S3CreateBucketOperator(AwsBaseOperator[S3Hook]): """ This operator creates an S3 bucket. @@ -51,38 +52,38 @@ class S3CreateBucketOperator(BaseOperator): :param bucket_name: This is bucket name you want to create :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). - :param region_name: AWS region_name. If not specified fetched from connection. + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("bucket_name",) + template_fields: Sequence[str] = aws_template_fields("bucket_name") + aws_hook_class = S3Hook def __init__( self, *, bucket_name: str, - aws_conn_id: str | None = "aws_default", - region_name: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name - self.region_name = region_name - self.aws_conn_id = aws_conn_id def execute(self, context: Context): - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name) - if not s3_hook.check_for_bucket(self.bucket_name): - s3_hook.create_bucket(bucket_name=self.bucket_name, region_name=self.region_name) + if not self.hook.check_for_bucket(self.bucket_name): + self.hook.create_bucket(bucket_name=self.bucket_name, region_name=self.region_name) self.log.info("Created bucket with name: %s", self.bucket_name) else: self.log.info("Bucket with name: %s already exists", self.bucket_name) -class S3DeleteBucketOperator(BaseOperator): +class S3DeleteBucketOperator(AwsBaseOperator[S3Hook]): """ This operator deletes an S3 bucket. @@ -93,36 +94,39 @@ class S3DeleteBucketOperator(BaseOperator): :param bucket_name: This is bucket name you want to delete :param force_delete: Forcibly delete all objects in the bucket before deleting the bucket :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. 
If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("bucket_name",) + template_fields: Sequence[str] = aws_template_fields("bucket_name") + aws_hook_class = S3Hook def __init__( self, bucket_name: str, force_delete: bool = False, - aws_conn_id: str | None = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name self.force_delete = force_delete - self.aws_conn_id = aws_conn_id def execute(self, context: Context): - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) - if s3_hook.check_for_bucket(self.bucket_name): - s3_hook.delete_bucket(bucket_name=self.bucket_name, force_delete=self.force_delete) + if self.hook.check_for_bucket(self.bucket_name): + self.hook.delete_bucket(bucket_name=self.bucket_name, force_delete=self.force_delete) self.log.info("Deleted bucket with name: %s", self.bucket_name) else: self.log.info("Bucket with name: %s doesn't exist", self.bucket_name) -class S3GetBucketTaggingOperator(BaseOperator): +class S3GetBucketTaggingOperator(AwsBaseOperator[S3Hook]): """ This operator gets tagging from an S3 bucket. @@ -132,31 +136,34 @@ class S3GetBucketTaggingOperator(BaseOperator): :param bucket_name: This is bucket name you want to reference :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("bucket_name",) + template_fields: Sequence[str] = aws_template_fields("bucket_name") + aws_hook_class = S3Hook def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name - self.aws_conn_id = aws_conn_id def execute(self, context: Context): - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) - - if s3_hook.check_for_bucket(self.bucket_name): + if self.hook.check_for_bucket(self.bucket_name): self.log.info("Getting tags for bucket %s", self.bucket_name) - return s3_hook.get_bucket_tagging(self.bucket_name) + return self.hook.get_bucket_tagging(self.bucket_name) else: self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) return None -class S3PutBucketTaggingOperator(BaseOperator): +class S3PutBucketTaggingOperator(AwsBaseOperator[S3Hook]): """ This operator puts tagging for an S3 bucket. 
@@ -171,14 +178,20 @@ class S3PutBucketTaggingOperator(BaseOperator): If a value is provided, a key must be provided as well. :param tag_set: A dictionary containing the tags, or a List of key/value pairs. :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or - empty, then the default boto3 configuration would be used (and must be + empty, then default boto3 configuration would be used (and must be maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("bucket_name",) + template_fields: Sequence[str] = aws_template_fields("bucket_name") template_fields_renderers = {"tag_set": "json"} + aws_hook_class = S3Hook def __init__( self, @@ -186,7 +199,6 @@ def __init__( key: str | None = None, value: str | None = None, tag_set: dict | list[dict[str, str]] | None = None, - aws_conn_id: str | None = "aws_default", **kwargs, ) -> None: super().__init__(**kwargs) @@ -194,14 +206,11 @@ def __init__( self.value = value self.tag_set = tag_set self.bucket_name = bucket_name - self.aws_conn_id = aws_conn_id def execute(self, context: Context): - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) - - if s3_hook.check_for_bucket(self.bucket_name): + if self.hook.check_for_bucket(self.bucket_name): self.log.info("Putting tags for bucket %s", self.bucket_name) - return s3_hook.put_bucket_tagging( + return self.hook.put_bucket_tagging( key=self.key, value=self.value, tag_set=self.tag_set, bucket_name=self.bucket_name ) else: @@ -209,7 +218,7 @@ def execute(self, context: Context): return None -class S3DeleteBucketTaggingOperator(BaseOperator): +class S3DeleteBucketTaggingOperator(AwsBaseOperator[S3Hook]): """ This operator deletes tagging from an S3 bucket. @@ -219,31 +228,38 @@ class S3DeleteBucketTaggingOperator(BaseOperator): :param bucket_name: This is the name of the bucket to delete tags from. :param aws_conn_id: The Airflow connection used for AWS credentials. - If this is None or empty then the default boto3 behaviour is used. If + If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or empty, then default boto3 configuration would be used (and must be maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("bucket_name",) + template_fields: Sequence[str] = aws_template_fields("bucket_name") + aws_hook_class = S3Hook - def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None: + def __init__( + self, + bucket_name: str, + **kwargs, + ) -> None: super().__init__(**kwargs) self.bucket_name = bucket_name - self.aws_conn_id = aws_conn_id def execute(self, context: Context): - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id) - - if s3_hook.check_for_bucket(self.bucket_name): + if self.hook.check_for_bucket(self.bucket_name): self.log.info("Deleting tags for bucket %s", self.bucket_name) - return s3_hook.delete_bucket_tagging(self.bucket_name) + return self.hook.delete_bucket_tagging(self.bucket_name) else: self.log.warning(BUCKET_DOES_NOT_EXIST_MSG, self.bucket_name) return None -class S3CopyObjectOperator(BaseOperator): +class S3CopyObjectOperator(AwsBaseOperator[S3Hook]): """ Creates a copy of an object that is already stored in S3. @@ -269,30 +285,29 @@ class S3CopyObjectOperator(BaseOperator): It should be omitted when `dest_bucket_key` is provided as a full s3:// url. :param source_version_id: Version ID of the source object (OPTIONAL) - :param aws_conn_id: Connection id of the S3 connection to use - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - - You can provide the following values: - - - False: do not validate SSL certificates. SSL will still be used, - but SSL certificates will not be - verified. - - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html :param acl_policy: String specifying the canned ACL policy for the file being uploaded to the S3 bucket. :param meta_data_directive: Whether to `COPY` the metadata from the source object or `REPLACE` it with metadata that's provided in the request. 
""" - template_fields: Sequence[str] = ( + template_fields: Sequence[str] = aws_template_fields( "source_bucket_key", "dest_bucket_key", "source_bucket_name", "dest_bucket_name", ) + aws_hook_class = S3Hook def __init__( self, @@ -302,8 +317,6 @@ def __init__( source_bucket_name: str | None = None, dest_bucket_name: str | None = None, source_version_id: str | None = None, - aws_conn_id: str | None = "aws_default", - verify: str | bool | None = None, acl_policy: str | None = None, meta_data_directive: str | None = None, **kwargs, @@ -315,14 +328,11 @@ def __init__( self.source_bucket_name = source_bucket_name self.dest_bucket_name = dest_bucket_name self.source_version_id = source_version_id - self.aws_conn_id = aws_conn_id - self.verify = verify self.acl_policy = acl_policy self.meta_data_directive = meta_data_directive def execute(self, context: Context): - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - s3_hook.copy_object( + self.hook.copy_object( self.source_bucket_key, self.dest_bucket_key, self.source_bucket_name, @@ -336,11 +346,11 @@ def get_openlineage_facets_on_start(self): from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage - dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key( + dest_bucket_name, dest_bucket_key = self.hook.get_s3_bucket_key( self.dest_bucket_name, self.dest_bucket_key, "dest_bucket_name", "dest_bucket_key" ) - source_bucket_name, source_bucket_key = S3Hook.get_s3_bucket_key( + source_bucket_name, source_bucket_key = self.hook.get_s3_bucket_key( self.source_bucket_name, self.source_bucket_key, "source_bucket_name", "source_bucket_key" ) @@ -359,7 +369,7 @@ def get_openlineage_facets_on_start(self): ) -class S3CreateObjectOperator(BaseOperator): +class S3CreateObjectOperator(AwsBaseOperator[S3Hook]): """ Creates a new object from `data` as string or bytes. @@ -382,22 +392,21 @@ class S3CreateObjectOperator(BaseOperator): It should be specified only when `data` is provided as string. :param compression: Type of compression to use, currently only gzip is supported. It can be specified only when `data` is provided as string. - :param aws_conn_id: Connection id of the S3 connection to use - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - - You can provide the following values: - - - False: do not validate SSL certificates. SSL will still be used, - but SSL certificates will not be - verified. - - path/to/cert/bundle.pem: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. 
See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("s3_bucket", "s3_key", "data") + template_fields: Sequence[str] = aws_template_fields("s3_bucket", "s3_key", "data") + aws_hook_class = S3Hook def __init__( self, @@ -410,8 +419,6 @@ def __init__( acl_policy: str | None = None, encoding: str | None = None, compression: str | None = None, - aws_conn_id: str | None = "aws_default", - verify: str | bool | None = None, **kwargs, ): super().__init__(**kwargs) @@ -424,16 +431,14 @@ def __init__( self.acl_policy = acl_policy self.encoding = encoding self.compression = compression - self.aws_conn_id = aws_conn_id - self.verify = verify def execute(self, context: Context): - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - - s3_bucket, s3_key = s3_hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key") + s3_bucket, s3_key = self.hook.get_s3_bucket_key( + self.s3_bucket, self.s3_key, "dest_bucket", "dest_key" + ) if isinstance(self.data, str): - s3_hook.load_string( + self.hook.load_string( self.data, s3_key, s3_bucket, @@ -444,13 +449,13 @@ def execute(self, context: Context): self.compression, ) else: - s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy) + self.hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy) def get_openlineage_facets_on_start(self): from airflow.providers.common.compat.openlineage.facet import Dataset from airflow.providers.openlineage.extractors import OperatorLineage - bucket, key = S3Hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key") + bucket, key = self.hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key") output_dataset = Dataset( namespace=f"s3://{bucket}", @@ -462,7 +467,7 @@ def get_openlineage_facets_on_start(self): ) -class S3DeleteObjectsOperator(BaseOperator): +class S3DeleteObjectsOperator(AwsBaseOperator[S3Hook]): """ To enable users to delete single object or multiple objects from a bucket using a single HTTP request. @@ -485,21 +490,22 @@ class S3DeleteObjectsOperator(BaseOperator): All objects which LastModified Date is greater than this datetime in the bucket will be deleted. :param to_datetime: less LastModified Date of objects to delete. (templated) All objects which LastModified Date is less than this datetime in the bucket will be deleted. - :param aws_conn_id: Connection id of the S3 connection to use - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - - You can provide the following values: - - - ``False``: do not validate SSL certificates. SSL will still be used, - but SSL certificates will not be - verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. 
See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html """ - template_fields: Sequence[str] = ("keys", "bucket", "prefix", "from_datetime", "to_datetime") + template_fields: Sequence[str] = aws_template_fields( + "keys", "bucket", "prefix", "from_datetime", "to_datetime" + ) + aws_hook_class = S3Hook def __init__( self, @@ -509,8 +515,6 @@ def __init__( prefix: str | None = None, from_datetime: datetime | str | None = None, to_datetime: datetime | str | None = None, - aws_conn_id: str | None = "aws_default", - verify: str | bool | None = None, **kwargs, ): super().__init__(**kwargs) @@ -519,8 +523,6 @@ def __init__( self.prefix = prefix self.from_datetime = from_datetime self.to_datetime = to_datetime - self.aws_conn_id = aws_conn_id - self.verify = verify self._keys: str | list[str] = "" @@ -546,16 +548,14 @@ def execute(self, context: Context): if isinstance(self.from_datetime, str): self.from_datetime = parser.parse(self.from_datetime).replace(tzinfo=pytz.UTC) - s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - - keys = self.keys or s3_hook.list_keys( + keys = self.keys or self.hook.list_keys( bucket_name=self.bucket, prefix=self.prefix, from_datetime=self.from_datetime, to_datetime=self.to_datetime, ) if keys: - s3_hook.delete_objects(bucket=self.bucket, keys=keys) + self.hook.delete_objects(bucket=self.bucket, keys=keys) self._keys = keys def get_openlineage_facets_on_complete(self, task_instance): @@ -598,7 +598,7 @@ def get_openlineage_facets_on_complete(self, task_instance): ) -class S3FileTransformOperator(BaseOperator): +class S3FileTransformOperator(AwsBaseOperator[S3Hook]): """ Copies data from a source S3 location to a temporary location on the local filesystem. @@ -644,9 +644,10 @@ class S3FileTransformOperator(BaseOperator): :param replace: Replace dest S3 key if it already exists """ - template_fields: Sequence[str] = ("source_s3_key", "dest_s3_key", "script_args") + template_fields: Sequence[str] = aws_template_fields("source_s3_key", "dest_s3_key", "script_args") template_ext: Sequence[str] = () ui_color = "#f9c915" + aws_hook_class = S3Hook def __init__( self, @@ -682,6 +683,7 @@ def execute(self, context: Context): if self.transform_script is None and self.select_expression is None: raise AirflowException("Either transform_script or select_expression must be specified") + # Keep these hooks constructed here since we are using two unique conn_ids source_s3 = S3Hook(aws_conn_id=self.source_aws_conn_id, verify=self.source_verify) dest_s3 = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify) @@ -770,7 +772,7 @@ def get_openlineage_facets_on_start(self): ) -class S3ListOperator(BaseOperator): +class S3ListOperator(AwsBaseOperator[S3Hook]): """ List all objects from the bucket with the given string prefix in name. @@ -785,17 +787,16 @@ class S3ListOperator(BaseOperator): :param prefix: Prefix string to filters the objects whose name begin with such prefix. (templated) :param delimiter: the delimiter marks key hierarchy. (templated) - :param aws_conn_id: The connection ID to use when connecting to S3 storage. - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - You can provide the following values: - - - ``False``: do not validate SSL certificates. 
SSL will still be used - (unless use_ssl is False), but SSL certificates will not be - verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html :param apply_wildcard: whether to treat '*' as a wildcard or a plain symbol in the prefix. @@ -813,8 +814,9 @@ class S3ListOperator(BaseOperator): ) """ - template_fields: Sequence[str] = ("bucket", "prefix", "delimiter") + template_fields: Sequence[str] = aws_template_fields("bucket", "prefix", "delimiter") ui_color = "#ffd700" + aws_hook_class = S3Hook def __init__( self, @@ -822,8 +824,6 @@ def __init__( bucket: str, prefix: str = "", delimiter: str = "", - aws_conn_id: str | None = "aws_default", - verify: str | bool | None = None, apply_wildcard: bool = False, **kwargs, ): @@ -831,13 +831,9 @@ def __init__( self.bucket = bucket self.prefix = prefix self.delimiter = delimiter - self.aws_conn_id = aws_conn_id - self.verify = verify self.apply_wildcard = apply_wildcard def execute(self, context: Context): - hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - self.log.info( "Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)", self.bucket, @@ -845,7 +841,7 @@ def execute(self, context: Context): self.delimiter, ) - return hook.list_keys( + return self.hook.list_keys( bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter, @@ -853,7 +849,7 @@ def execute(self, context: Context): ) -class S3ListPrefixesOperator(BaseOperator): +class S3ListPrefixesOperator(AwsBaseOperator[S3Hook]): """ List all subfolders from the bucket with the given string prefix in name. @@ -868,17 +864,16 @@ class S3ListPrefixesOperator(BaseOperator): :param prefix: Prefix string to filter the subfolders whose name begin with such prefix. (templated) :param delimiter: the delimiter marks subfolder hierarchy. (templated) - :param aws_conn_id: The connection ID to use when connecting to S3 storage. - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - You can provide the following values: - - - ``False``: do not validate SSL certificates. SSL will still be used - (unless use_ssl is False), but SSL certificates will not be - verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. 
If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html **Example**: @@ -894,8 +889,9 @@ class S3ListPrefixesOperator(BaseOperator): ) """ - template_fields: Sequence[str] = ("bucket", "prefix", "delimiter") + template_fields: Sequence[str] = aws_template_fields("bucket", "prefix", "delimiter") ui_color = "#ffd700" + aws_hook_class = S3Hook def __init__( self, @@ -903,20 +899,14 @@ def __init__( bucket: str, prefix: str, delimiter: str, - aws_conn_id: str | None = "aws_default", - verify: str | bool | None = None, **kwargs, ): super().__init__(**kwargs) self.bucket = bucket self.prefix = prefix self.delimiter = delimiter - self.aws_conn_id = aws_conn_id - self.verify = verify def execute(self, context: Context): - hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - self.log.info( "Getting the list of subfolders from bucket: %s in prefix: %s (Delimiter %s)", self.bucket, @@ -924,4 +914,4 @@ def execute(self, context: Context): self.delimiter, ) - return hook.list_prefixes(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter) + return self.hook.list_prefixes(bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py index bb3616597d272..c59bf53ea1c91 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py @@ -23,7 +23,6 @@ import re from collections.abc import Sequence from datetime import datetime, timedelta -from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, cast from airflow.configuration import conf @@ -34,11 +33,13 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger, S3KeyTrigger -from airflow.sensors.base import BaseSensorOperator, poke_mode_only +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields +from airflow.sensors.base import poke_mode_only -class S3KeySensor(BaseSensorOperator): +class S3KeySensor(AwsBaseSensor[S3Hook]): """ Waits for one or multiple keys (a file-like instance on S3) to be present in a S3 bucket. @@ -65,17 +66,6 @@ class S3KeySensor(BaseSensorOperator): def check_fn(files: List, **kwargs) -> bool: return any(f.get('Size', 0) > 1048576 for f in files) - :param aws_conn_id: a reference to the s3 connection - :param verify: Whether to verify SSL certificates for S3 connection. - By default, SSL certificates are verified. - You can provide the following values: - - - ``False``: do not validate SSL certificates. SSL will still be used - (unless use_ssl is False), but SSL certificates will not be - verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. 
- You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. :param deferrable: Run operator in the deferrable mode :param use_regex: whether to use regex to check bucket :param metadata_keys: List of head_object attributes to gather and send to ``check_fn``. @@ -83,9 +73,18 @@ def check_fn(files: List, **kwargs) -> bool: all available attributes. Default value: "Size". If the requested attribute is not found, the key is still included and the value is None. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html """ - template_fields: Sequence[str] = ("bucket_key", "bucket_name") + template_fields: Sequence[str] = aws_template_fields("bucket_key", "bucket_name") + aws_hook_class = S3Hook def __init__( self, @@ -94,7 +93,6 @@ def __init__( bucket_name: str | None = None, wildcard_match: bool = False, check_fn: Callable[..., bool] | None = None, - aws_conn_id: str | None = "aws_default", verify: str | bool | None = None, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), use_regex: bool = False, @@ -106,14 +104,13 @@ def __init__( self.bucket_key = bucket_key self.wildcard_match = wildcard_match self.check_fn = check_fn - self.aws_conn_id = aws_conn_id self.verify = verify self.deferrable = deferrable self.use_regex = use_regex self.metadata_keys = metadata_keys if metadata_keys else ["Size"] def _check_key(self, key, context: Context): - bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key") + bucket_name, key = self.hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key") self.log.info("Poking for key : s3://%s/%s", bucket_name, key) """ @@ -199,7 +196,9 @@ def _defer(self) -> None: bucket_key=self.bucket_key, wildcard_match=self.wildcard_match, aws_conn_id=self.aws_conn_id, + region_name=self.region_name, verify=self.verify, + botocore_config=self.botocore_config, poke_interval=self.poke_interval, should_check_fn=bool(self.check_fn), use_regex=self.use_regex, @@ -220,13 +219,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None: elif event["status"] == "error": raise AirflowException(event["message"]) - @cached_property - def hook(self) -> S3Hook: - return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - @poke_mode_only -class S3KeysUnchangedSensor(BaseSensorOperator): +class S3KeysUnchangedSensor(AwsBaseSensor[S3Hook]): """ Return True if inactivity_period has passed with no increase in the number of objects matching prefix. @@ -239,17 +234,7 @@ class S3KeysUnchangedSensor(BaseSensorOperator): :param bucket_name: Name of the S3 bucket :param prefix: The prefix being waited on. Relative path from bucket root level. - :param aws_conn_id: a reference to the s3 connection - :param verify: Whether or not to verify SSL certificates for S3 connection. - By default SSL certificates are verified. - You can provide the following values: - - - ``False``: do not validate SSL certificates. 
SSL will still be used - (unless use_ssl is False), but SSL certificates will not be - verified. - - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses. - You can specify this argument if you want to use a different - CA cert bundle than the one used by botocore. + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html :param inactivity_period: The total seconds of inactivity to designate keys unchanged. Note, this mechanism is not real time and this operator may not return until a poke_interval after this period @@ -261,16 +246,24 @@ class S3KeysUnchangedSensor(BaseSensorOperator): between pokes valid behavior. If true a warning message will be logged when this happens. If false an error will be raised. :param deferrable: Run sensor in the deferrable mode + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html """ - template_fields: Sequence[str] = ("bucket_name", "prefix") + template_fields: Sequence[str] = aws_template_fields("bucket_name", "prefix") + aws_hook_class = S3Hook def __init__( self, *, bucket_name: str, prefix: str, - aws_conn_id: str | None = "aws_default", verify: bool | str | None = None, inactivity_period: float = 60 * 60, min_objects: int = 1, @@ -291,15 +284,9 @@ def __init__( self.inactivity_seconds = 0 self.allow_delete = allow_delete self.deferrable = deferrable - self.aws_conn_id = aws_conn_id self.verify = verify self.last_activity_time: datetime | None = None - @cached_property - def hook(self): - """Returns S3Hook.""" - return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify) - def is_keys_unchanged(self, current_objects: set[str]) -> bool: """ Check for new objects after the inactivity_period and update the sensor state accordingly. 
@@ -382,7 +369,9 @@ def execute(self, context: Context) -> None: inactivity_seconds=self.inactivity_seconds, allow_delete=self.allow_delete, aws_conn_id=self.aws_conn_id, + region_name=self.region_name, verify=self.verify, + botocore_config=self.botocore_config, last_activity_time=self.last_activity_time, ), method_name="execute_complete", diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py index 0be6c992cc8b6..9d2b055fe44ce 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py @@ -53,6 +53,9 @@ def __init__( poke_interval: float = 5.0, should_check_fn: bool = False, use_regex: bool = False, + region_name: str | None = None, + verify: bool | str | None = None, + botocore_config: dict | None = None, **hook_params: Any, ): super().__init__() @@ -64,6 +67,9 @@ def __init__( self.poke_interval = poke_interval self.should_check_fn = should_check_fn self.use_regex = use_regex + self.region_name = region_name + self.verify = verify + self.botocore_config = botocore_config def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize S3KeyTrigger arguments and classpath.""" @@ -78,12 +84,20 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "poke_interval": self.poke_interval, "should_check_fn": self.should_check_fn, "use_regex": self.use_regex, + "region_name": self.region_name, + "verify": self.verify, + "botocore_config": self.botocore_config, }, ) @cached_property def hook(self) -> S3Hook: - return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify")) + return S3Hook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) async def run(self) -> AsyncIterator[TriggerEvent]: """Make an asynchronous connection using S3HookAsync.""" @@ -143,7 +157,9 @@ def __init__( allow_delete: bool = True, aws_conn_id: str | None = "aws_default", last_activity_time: datetime | None = None, + region_name: str | None = None, verify: bool | str | None = None, + botocore_config: dict | None = None, **hook_params: Any, ): super().__init__() @@ -160,8 +176,10 @@ def __init__( self.allow_delete = allow_delete self.aws_conn_id = aws_conn_id self.last_activity_time = last_activity_time - self.verify = verify self.polling_period_seconds = 0 + self.region_name = region_name + self.verify = verify + self.botocore_config = botocore_config self.hook_params = hook_params def serialize(self) -> tuple[str, dict[str, Any]]: @@ -179,14 +197,21 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "aws_conn_id": self.aws_conn_id, "last_activity_time": self.last_activity_time, "hook_params": self.hook_params, - "verify": self.verify, "polling_period_seconds": self.polling_period_seconds, + "region_name": self.region_name, + "verify": self.verify, + "botocore_config": self.botocore_config, }, ) @cached_property def hook(self) -> S3Hook: - return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify")) + return S3Hook( + aws_conn_id=self.aws_conn_id, + region_name=self.region_name, + verify=self.verify, + config=self.botocore_config, + ) async def run(self) -> AsyncIterator[TriggerEvent]: """Make an asynchronous connection using S3Hook.""" diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py index 0195afd6e2a59..b17d4374c4442 100644 --- 
a/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3.py @@ -415,20 +415,19 @@ def test_template_fields(self): class TestS3ListOperator: - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") - def test_execute(self, mock_hook): - mock_hook.return_value.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"] - + def test_execute(self): operator = S3ListOperator( task_id="test-s3-list-operator", bucket=BUCKET_NAME, prefix="TEST", delimiter=".csv", ) + operator.hook = mock.MagicMock() + operator.hook.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"] files = operator.execute(None) - mock_hook.return_value.list_keys.assert_called_once_with( + operator.hook.list_keys.assert_called_once_with( bucket_name=BUCKET_NAME, prefix="TEST", delimiter=".csv", @@ -447,17 +446,16 @@ def test_template_fields(self): class TestS3ListPrefixesOperator: - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") - def test_execute(self, mock_hook): - mock_hook.return_value.list_prefixes.return_value = ["test/"] - + def test_execute(self): operator = S3ListPrefixesOperator( task_id="test-s3-list-prefixes-operator", bucket=BUCKET_NAME, prefix="test/", delimiter="/" ) + operator.hook = mock.MagicMock() + operator.hook.list_prefixes.return_value = ["test/"] subfolders = operator.execute(None) - mock_hook.return_value.list_prefixes.assert_called_once_with( + operator.hook.list_prefixes.assert_called_once_with( bucket_name=BUCKET_NAME, prefix="test/", delimiter="/" ) assert subfolders == ["test/"] @@ -870,8 +868,7 @@ def test_validate_keys_and_prefix_in_execute(self, keys, prefix, from_datetime, assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test @pytest.mark.parametrize("keys", ("path/data.txt", ["path/data.txt"])) - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") - def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys): + def test_get_openlineage_facets_on_complete_single_object(self, keys): bucket = "testbucket" expected_input = Dataset( namespace=f"s3://{bucket}", @@ -888,14 +885,14 @@ def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys) ) op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys) + op.hook = mock.MagicMock() op.execute(None) lineage = op.get_openlineage_facets_on_complete(None) assert len(lineage.inputs) == 1 assert lineage.inputs[0] == expected_input - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") - def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook): + def test_get_openlineage_facets_on_complete_multiple_objects(self): bucket = "testbucket" keys = ["path/data1.txt", "path/data2.txt"] expected_inputs = [ @@ -928,6 +925,7 @@ def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook): ] op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys) + op.hook = mock.MagicMock() op.execute(None) lineage = op.get_openlineage_facets_on_complete(None) diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py index 9c169f86cf0c6..b9f8fb284bcd0 100644 --- a/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py @@ -538,10 +538,10 @@ def test_key_changes(self, current_objects, expected_returns, inactivity_periods assert 
self.sensor.inactivity_seconds == period time_machine.coordinates.shift(10) - @mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook") - def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine): + def test_poke_succeeds_on_upload_complete(self, time_machine): time_machine.move_to(DEFAULT_DATE) - mock_hook.return_value.list_keys.return_value = {"a"} + self.sensor.hook = mock.MagicMock() + self.sensor.hook.list_keys.return_value = {"a"} assert not self.sensor.poke(dict()) time_machine.coordinates.shift(10) assert not self.sensor.poke(dict()) diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py index 01533d298875b..14c79f1e462ed 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py @@ -46,6 +46,9 @@ def test_serialization(self): "poke_interval": 5.0, "should_check_fn": False, "use_regex": False, + "verify": None, + "region_name": None, + "botocore_config": None, } @pytest.mark.asyncio @@ -106,6 +109,8 @@ def test_serialization(self): "last_activity_time": None, "hook_params": {}, "verify": None, + "region_name": None, + "botocore_config": None, "polling_period_seconds": 0, } diff --git a/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py b/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py index 9539e257aff71..78821ed041f94 100644 --- a/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py +++ b/providers/google/tests/unit/google/cloud/transfers/test_s3_to_gcs.py @@ -98,9 +98,8 @@ def test_init(self): assert operator.poll_interval == POLL_INTERVAL @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook") - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook") - def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): + def test_execute(self, gcs_mock_hook, s3_mock_hook): """Test the execute function when the run is successful.""" operator = S3ToGCSOperator( @@ -112,9 +111,9 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): dest_gcs=GCS_PATH_PREFIX, google_impersonation_chain=IMPERSONATION_CHAIN, ) + operator.hook = mock.MagicMock() - s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES - s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES + operator.hook.list_keys.return_value = MOCK_FILES uploaded_files = operator.execute(context={}) gcs_mock_hook.return_value.upload.assert_has_calls( @@ -126,8 +125,8 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): any_order=True, ) - s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) - s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) + operator.hook.list_keys.assert_called_once() + s3_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) gcs_mock_hook.assert_called_once_with( gcp_conn_id=GCS_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -137,9 +136,8 @@ def test_execute(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_hook): assert sorted(MOCK_FILES) == sorted(uploaded_files) @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook") - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook") - def test_execute_with_gzip(self, gcs_mock_hook, 
s3_one_mock_hook, s3_two_mock_hook): + def test_execute_with_gzip(self, gcs_mock_hook, s3_mock_hook): """Test the execute function when the run is successful.""" operator = S3ToGCSOperator( @@ -152,8 +150,9 @@ def test_execute_with_gzip(self, gcs_mock_hook, s3_one_mock_hook, s3_two_mock_ho gzip=True, ) - s3_one_mock_hook.return_value.list_keys.return_value = MOCK_FILES - s3_two_mock_hook.return_value.list_keys.return_value = MOCK_FILES + operator.hook = mock.MagicMock() + + operator.hook.list_keys.return_value = MOCK_FILES operator.execute(context={}) gcs_mock_hook.assert_called_once_with( @@ -226,13 +225,11 @@ def test_gcs_to_s3_object(self, apply_gcs_prefix, s3_prefix, s3_object, gcs_dest @pytest.mark.parametrize(*PARAMETRIZED_OBJECT_PATHS) @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook") - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook") def test_execute_apply_gcs_prefix( self, gcs_mock_hook, - s3_one_mock_hook, - s3_two_mock_hook, + s3_mock_hook, apply_gcs_prefix, s3_prefix, s3_object, @@ -249,9 +246,8 @@ def test_execute_apply_gcs_prefix( google_impersonation_chain=IMPERSONATION_CHAIN, apply_gcs_prefix=apply_gcs_prefix, ) - - s3_one_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object] - s3_two_mock_hook.return_value.list_keys.return_value = [s3_prefix + s3_object] + operator.hook = mock.MagicMock() + operator.hook.list_keys.return_value = [s3_prefix + s3_object] uploaded_files = operator.execute(context={}) gcs_mock_hook.return_value.upload.assert_has_calls( @@ -261,8 +257,8 @@ def test_execute_apply_gcs_prefix( any_order=True, ) - s3_one_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) - s3_two_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) + operator.hook.list_keys.assert_called_once() + s3_mock_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=None) gcs_mock_hook.assert_called_once_with( gcp_conn_id=GCS_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN, @@ -306,14 +302,12 @@ def test_get_openlineage_facets_on_start( class TestS3ToGoogleCloudStorageOperatorDeferrable: @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.CloudDataTransferServiceHook") @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.S3Hook") - @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook") @mock.patch("airflow.providers.google.cloud.transfers.s3_to_gcs.GCSHook") - def test_execute_deferrable(self, mock_gcs_hook, mock_s3_super_hook, mock_s3_hook, mock_transfer_hook): + def test_execute_deferrable(self, mock_gcs_hook, mock_s3_hook, mock_transfer_hook): mock_gcs_hook.return_value.project_id = PROJECT_ID - mock_list_keys = mock.MagicMock() - mock_list_keys.return_value = MOCK_FILES - mock_s3_super_hook.return_value.list_keys = mock_list_keys + mock_s3_super_hook = mock.MagicMock() + mock_s3_super_hook.list_keys.return_value = MOCK_FILES mock_s3_hook.conn_config = mock.MagicMock( aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, @@ -335,11 +329,13 @@ def test_execute_deferrable(self, mock_gcs_hook, mock_s3_super_hook, mock_s3_hoo deferrable=True, ) + operator.hook = mock_s3_super_hook + with pytest.raises(TaskDeferred) as exception_info: operator.execute(None) - mock_s3_super_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, verify=operator.verify) - mock_list_keys.assert_called_once_with( + mock_s3_hook.assert_called_once_with(aws_conn_id=AWS_CONN_ID, 
verify=operator.verify) + mock_s3_super_hook.list_keys.assert_called_once_with( bucket_name=S3_BUCKET, prefix=S3_PREFIX, delimiter=S3_DELIMITER, apply_wildcard=False ) mock_create_transfer_job.assert_called_once()
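
For reference, a minimal usage sketch of the migrated interface (the bucket name, region, and retry values below are illustrative, and an existing ``aws_default`` connection plus a surrounding DAG are assumed): with ``AwsBaseOperator[S3Hook]`` the connection, region, SSL-verification, and botocore settings are supplied as base-class keyword arguments instead of per-operator parameters, and the underlying ``S3Hook`` is reached through ``self.hook``.

    # Minimal usage sketch, not part of this patch.
    # Assumptions: an "aws_default" connection exists; the bucket name, region,
    # and retry settings are placeholders chosen for illustration only.
    from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator

    create_bucket = S3CreateBucketOperator(
        task_id="create_bucket",
        bucket_name="example-bucket",  # illustrative bucket name
        # These are now accepted via AwsBaseOperator rather than being declared
        # on each S3 operator individually:
        aws_conn_id="aws_default",
        region_name="eu-west-1",  # illustrative region
        verify=True,  # or a path to a CA certificate bundle
        botocore_config={"retries": {"max_attempts": 3}},
    )

The same base-class keyword arguments apply to the sensors migrated to ``AwsBaseSensor[S3Hook]``, and the S3 triggers now forward ``region_name``, ``verify``, and ``botocore_config`` when constructing their hooks, as shown in the trigger changes above.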