diff --git a/airflow/providers/amazon/aws/hooks/comprehend.py b/airflow/providers/amazon/aws/hooks/comprehend.py
new file mode 100644
index 0000000000000..897aaf72ee782
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/comprehend.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
+class ComprehendHook(AwsBaseHook):
+    """
+    Interact with AWS Comprehend.
+
+    Provide thin wrapper around :external+boto3:py:class:`boto3.client("comprehend") <Comprehend.Client>`.
+
+    Additional arguments (such as ``aws_conn_id``) may be specified and
+    are passed down to the underlying AwsBaseHook.
+
+    .. seealso::
+        - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+    """
+
+    def __init__(self, *args, **kwargs) -> None:
+        kwargs["client_type"] = "comprehend"
+        super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/comprehend.py b/airflow/providers/amazon/aws/operators/comprehend.py
new file mode 100644
index 0000000000000..780e227af408a
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/comprehend.py
@@ -0,0 +1,192 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
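+# This module adds operators for Amazon Comprehend: a shared ComprehendBaseOperator
+# and ComprehendStartPiiEntitiesDetectionJobOperator, which can block on a waiter,
+# defer to a trigger, or return immediately so a sensor can track the job.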
+from __future__ import annotations
+
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+from airflow.utils.timezone import utcnow
+
+if TYPE_CHECKING:
+    import boto3
+
+    from airflow.utils.context import Context
+
+
+class ComprehendBaseOperator(AwsBaseOperator[ComprehendHook]):
+    """
+    This is the base operator for Comprehend Service operators (not supposed to be used directly in DAGs).
+
+    :param input_data_config: The input properties for a PII entities detection job. (templated)
+    :param output_data_config: Provides configuration parameters for the output of PII entity detection
+        jobs. (templated)
+    :param data_access_role_arn: The Amazon Resource Name (ARN) of the IAM role that grants Amazon Comprehend
+        read access to your input data. (templated)
+    :param language_code: The language of the input documents. (templated)
+    """
+
+    aws_hook_class = ComprehendHook
+
+    template_fields: Sequence[str] = aws_template_fields(
+        "input_data_config", "output_data_config", "data_access_role_arn", "language_code"
+    )
+
+    template_fields_renderers: dict = {"input_data_config": "json", "output_data_config": "json"}
+
+    def __init__(
+        self,
+        input_data_config: dict,
+        output_data_config: dict,
+        data_access_role_arn: str,
+        language_code: str,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.input_data_config = input_data_config
+        self.output_data_config = output_data_config
+        self.data_access_role_arn = data_access_role_arn
+        self.language_code = language_code
+
+    @cached_property
+    def client(self) -> boto3.client:
+        """Create and return the Comprehend client."""
+        return self.hook.conn
+
+    def execute(self, context: Context):
+        """Must be overridden in child classes."""
+        raise NotImplementedError("Please implement execute() in subclass")
+
+
+class ComprehendStartPiiEntitiesDetectionJobOperator(ComprehendBaseOperator):
+    """
+    Create a Comprehend pii entities detection job for a collection of documents.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the guide:
+        :ref:`howto/operator:ComprehendStartPiiEntitiesDetectionJobOperator`
+
+    :param input_data_config: The input properties for a PII entities detection job. (templated)
+    :param output_data_config: Provides configuration parameters for the output of PII entity detection
+        jobs. (templated)
+    :param mode: Specifies whether the output provides the locations (offsets) of PII entities or a file in
+        which PII entities are redacted. If you set the mode parameter to ONLY_REDACTION, you must provide
+        a RedactionConfig in start_pii_entities_kwargs.
+    :param data_access_role_arn: The Amazon Resource Name (ARN) of the IAM role that grants Amazon Comprehend
+        read access to your input data. (templated)
+    :param language_code: The language of the input documents. (templated)
+    :param start_pii_entities_kwargs: Any optional parameters to pass to the job. If JobName is not provided
+        in start_pii_entities_kwargs, the operator will generate one.
+
+    :param wait_for_completion: Whether to wait for the job to complete. (default: True)
+    :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
+    :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 20)
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    def __init__(
+        self,
+        input_data_config: dict,
+        output_data_config: dict,
+        mode: str,
+        data_access_role_arn: str,
+        language_code: str,
+        start_pii_entities_kwargs: dict[str, Any] | None = None,
+        wait_for_completion: bool = True,
+        waiter_delay: int = 60,
+        waiter_max_attempts: int = 20,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        super().__init__(
+            input_data_config=input_data_config,
+            output_data_config=output_data_config,
+            data_access_role_arn=data_access_role_arn,
+            language_code=language_code,
+            **kwargs,
+        )
+        self.mode = mode
+        self.start_pii_entities_kwargs = start_pii_entities_kwargs or {}
+        self.wait_for_completion = wait_for_completion
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> str:
+        if self.start_pii_entities_kwargs.get("JobName", None) is None:
+            self.start_pii_entities_kwargs["JobName"] = (
+                f"start_pii_entities_detection_job-{int(utcnow().timestamp())}"
+            )
+
+        self.log.info(
+            "Submitting start pii entities detection job '%s'.", self.start_pii_entities_kwargs["JobName"]
+        )
+        job_id = self.client.start_pii_entities_detection_job(
+            InputDataConfig=self.input_data_config,
+            OutputDataConfig=self.output_data_config,
+            Mode=self.mode,
+            DataAccessRoleArn=self.data_access_role_arn,
+            LanguageCode=self.language_code,
+            **self.start_pii_entities_kwargs,
+        )["JobId"]
+
+        message_description = f"start pii entities detection job {job_id} to complete."
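+        # Three ways to finish from here: defer to the trigger's async waiter,
+        # block on the boto3 waiter, or (with wait_for_completion=False) return
+        # the job id immediately and let a sensor track the job.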
+        if self.deferrable:
+            self.log.info("Deferring %s", message_description)
+            self.defer(
+                trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger(
+                    job_id=job_id,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        elif self.wait_for_completion:
+            self.log.info("Waiting for %s", message_description)
+            self.hook.get_waiter("pii_entities_detection_job_complete").wait(
+                JobId=job_id,
+                WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
+            )
+
+        return job_id
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
+        event = validate_execute_complete_event(event)
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+
+        self.log.info("Comprehend pii entities detection job `%s` complete.", event["job_id"])
+        return event["job_id"]
diff --git a/airflow/providers/amazon/aws/sensors/comprehend.py b/airflow/providers/amazon/aws/sensors/comprehend.py
new file mode 100644
index 0000000000000..8f0e328cbc221
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/comprehend.py
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import abc
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class ComprehendBaseSensor(AwsBaseSensor[ComprehendHook]):
+    """
+    General sensor behavior for Amazon Comprehend.
+
+    Subclasses must implement the following methods:
+        - ``get_state()``
+
+    Subclasses must set the following fields:
+        - ``INTERMEDIATE_STATES``
+        - ``FAILURE_STATES``
+        - ``SUCCESS_STATES``
+        - ``FAILURE_MESSAGE``
+
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    """
+
+    aws_hook_class = ComprehendHook
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ()
+    FAILURE_STATES: tuple[str, ...] = ()
+    SUCCESS_STATES: tuple[str, ...] = ()
+    FAILURE_MESSAGE = ""
+
+    ui_color = "#66c3ff"
+
+    def __init__(
+        self,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.deferrable = deferrable
+
+    def poke(self, context: Context, **kwargs) -> bool:
+        state = self.get_state()
+        if state in self.FAILURE_STATES:
+            # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1
+            if self.soft_fail:
+                raise AirflowSkipException(self.FAILURE_MESSAGE)
+            raise AirflowException(self.FAILURE_MESSAGE)
+
+        return state not in self.INTERMEDIATE_STATES
+
+    @abc.abstractmethod
+    def get_state(self) -> str:
+        """Implement in subclasses."""
+
+
+class ComprehendStartPiiEntitiesDetectionJobCompletedSensor(ComprehendBaseSensor):
+    """
+    Poll the state of the pii entities detection job until it reaches a completed state; fails if the job fails.
+
+    .. seealso::
+        For more information on how to use this sensor, take a look at the guide:
+        :ref:`howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor`
+
+    :param job_id: The id of the Comprehend pii entities detection job.
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120)
+    :param max_retries: Maximum number of polling attempts before returning the current state. (default: 75)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    INTERMEDIATE_STATES: tuple[str, ...] = ("IN_PROGRESS",)
+    FAILURE_STATES: tuple[str, ...] = ("FAILED", "STOP_REQUESTED", "STOPPED")
+    SUCCESS_STATES: tuple[str, ...] = ("COMPLETED",)
+    FAILURE_MESSAGE = "Comprehend start pii entities detection job sensor failed."
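+    # These states drive poke(): COMPLETED ends the wait, FAILED, STOP_REQUESTED
+    # and STOPPED raise, and IN_PROGRESS keeps the sensor polling. The custom
+    # pii_entities_detection_job_complete waiter treats the same states as terminal.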
+ + template_fields: Sequence[str] = aws_template_fields("job_id") + + def __init__( + self, + *, + job_id: str, + max_retries: int = 75, + poke_interval: int = 120, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.job_id = job_id + self.max_retries = max_retries + self.poke_interval = poke_interval + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=ComprehendPiiEntitiesDetectionJobCompletedTrigger( + job_id=self.job_id, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="poke", + ) + else: + super().execute(context=context) + + def get_state(self) -> str: + return self.hook.conn.describe_pii_entities_detection_job(JobId=self.job_id)[ + "PiiEntitiesDetectionJobProperties" + ]["JobStatus"] diff --git a/airflow/providers/amazon/aws/triggers/comprehend.py b/airflow/providers/amazon/aws/triggers/comprehend.py new file mode 100644 index 0000000000000..7de6650c87090 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/comprehend.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + +from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + + +class ComprehendPiiEntitiesDetectionJobCompletedTrigger(AwsBaseWaiterTrigger): + """ + Trigger when a Comprehend pii entities detection job is complete. + + :param job_id: The id of the Comprehend pii entities detection job. + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) + :param aws_conn_id: The Airflow connection used for AWS credentials. 
+ """ + + def __init__( + self, + *, + job_id: str, + waiter_delay: int = 120, + waiter_max_attempts: int = 75, + aws_conn_id: str | None = "aws_default", + ) -> None: + super().__init__( + serialized_fields={"job_id": job_id}, + waiter_name="pii_entities_detection_job_complete", + waiter_args={"JobId": job_id}, + failure_message="Comprehend start pii entities detection job failed.", + status_message="Status of Comprehend start pii entities detection job is", + status_queries=["PiiEntitiesDetectionJobProperties.JobStatus"], + return_key="job_id", + return_value=job_id, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return ComprehendHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/waiters/comprehend.json b/airflow/providers/amazon/aws/waiters/comprehend.json new file mode 100644 index 0000000000000..9df82f319ff47 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/comprehend.json @@ -0,0 +1,49 @@ +{ + "version": 2, + "waiters": { + "pii_entities_detection_job_complete": { + "delay": 120, + "maxAttempts": 75, + "operation": "DescribePiiEntitiesDetectionJob", + "acceptors": [ + { + "matcher": "path", + "argument": "PiiEntitiesDetectionJobProperties.JobStatus", + "expected": "SUBMITTED", + "state": "retry" + }, + { + "matcher": "path", + "argument": "PiiEntitiesDetectionJobProperties.JobStatus", + "expected": "IN_PROGRESS", + "state": "retry" + }, + { + "matcher": "path", + "argument": "PiiEntitiesDetectionJobProperties.JobStatus", + "expected": "COMPLETED", + "state": "success" + }, + { + "matcher": "path", + "argument": "PiiEntitiesDetectionJobProperties.JobStatus", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "PiiEntitiesDetectionJobProperties.JobStatus", + "expected": "STOP_REQUESTED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "PiiEntitiesDetectionJobProperties.JobStatus", + "expected": "STOPPED", + "state": "failure" + } + + ] + } + } +} diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 641c315d62b6a..7c06879143ef1 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -166,6 +166,12 @@ integrations: external-doc-url: https://aws.amazon.com/cloudwatch/ logo: /integration-logos/aws/Amazon-CloudWatch_light-bg@4x.png tags: [aws] + - integration-name: Amazon Comprehend + external-doc-url: https://aws.amazon.com/comprehend/ + logo: /integration-logos/aws/Amazon-Comprehend_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/comprehend.rst + tags: [aws] - integration-name: Amazon DataSync external-doc-url: https://aws.amazon.com/datasync/ how-to-guide: @@ -385,6 +391,9 @@ operators: - integration-name: Amazon CloudFormation python-modules: - airflow.providers.amazon.aws.operators.cloud_formation + - integration-name: Amazon Comprehend + python-modules: + - airflow.providers.amazon.aws.operators.comprehend - integration-name: Amazon DataSync python-modules: - airflow.providers.amazon.aws.operators.datasync @@ -470,6 +479,9 @@ sensors: - integration-name: Amazon CloudFormation python-modules: - airflow.providers.amazon.aws.sensors.cloud_formation + - integration-name: Amazon Comprehend + python-modules: + - airflow.providers.amazon.aws.sensors.comprehend - integration-name: AWS Database Migration Service python-modules: - airflow.providers.amazon.aws.sensors.dms @@ -545,6 +557,9 @@ 
hooks:
   - integration-name: Amazon Chime
     python-modules:
       - airflow.providers.amazon.aws.hooks.chime
+  - integration-name: Amazon Comprehend
+    python-modules:
+      - airflow.providers.amazon.aws.hooks.comprehend
   - integration-name: Amazon DynamoDB
     python-modules:
       - airflow.providers.amazon.aws.hooks.dynamodb
@@ -672,6 +687,9 @@ triggers:
   - integration-name: Amazon Bedrock
     python-modules:
      - airflow.providers.amazon.aws.triggers.bedrock
+  - integration-name: Amazon Comprehend
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.comprehend
   - integration-name: Amazon EC2
     python-modules:
       - airflow.providers.amazon.aws.triggers.ec2
diff --git a/docs/apache-airflow-providers-amazon/operators/comprehend.rst b/docs/apache-airflow-providers-amazon/operators/comprehend.rst
new file mode 100644
index 0000000000000..dd79e2df6ae36
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/comprehend.rst
@@ -0,0 +1,74 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements. See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership. The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+=================
+Amazon Comprehend
+=================
+
+`Amazon Comprehend <https://aws.amazon.com/comprehend/>`__ uses natural language processing (NLP) to
+extract insights about the content of documents. It develops insights by recognizing the entities, key phrases,
+language, sentiments, and other common elements in a document.
+
+Prerequisite Tasks
+------------------
+
+.. include:: ../_partials/prerequisite_tasks.rst
+
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
+Operators
+---------
+
+.. _howto/operator:ComprehendStartPiiEntitiesDetectionJobOperator:
+
+Create an Amazon Comprehend Start PII Entities Detection Job
+============================================================
+
+To create an Amazon Comprehend Start PII Entities Detection Job, you can use
+:class:`~airflow.providers.amazon.aws.operators.comprehend.ComprehendStartPiiEntitiesDetectionJobOperator`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_comprehend.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_start_pii_entities_detection_job]
+    :end-before: [END howto_operator_start_pii_entities_detection_job]
+
+Sensors
+-------
+
+.. _howto/sensor:ComprehendStartPiiEntitiesDetectionJobCompletedSensor:
+
+Wait for an Amazon Comprehend Start PII Entities Detection Job
+==============================================================
+
+To wait on the state of an Amazon Comprehend Start PII Entities Detection Job until it reaches a terminal
+state you can use
+:class:`~airflow.providers.amazon.aws.sensors.comprehend.ComprehendStartPiiEntitiesDetectionJobCompletedSensor`.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_comprehend.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_sensor_start_pii_entities_detection_job]
+    :end-before: [END howto_sensor_start_pii_entities_detection_job]
+
+Reference
+---------
+
+* `AWS boto3 library documentation for Amazon Comprehend <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/comprehend.html>`__
diff --git a/docs/integration-logos/aws/Amazon-Comprehend_light-bg@4x.png b/docs/integration-logos/aws/Amazon-Comprehend_light-bg@4x.png
new file mode 100644
index 0000000000000..24e8c34962f53
Binary files /dev/null and b/docs/integration-logos/aws/Amazon-Comprehend_light-bg@4x.png differ
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index da04944231852..40b744fc5d1d6 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1177,6 +1177,8 @@ picklable
 pid
 pidbox
 pigcmd
+Pii
+pii
 pinecone
 pinodb
 Pinot
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 7ad127ac3b370..428570ee58954 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -522,6 +522,8 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
         "airflow.providers.amazon.aws.sensors.ecs.EcsBaseSensor",
         "airflow.providers.amazon.aws.sensors.eks.EksBaseSensor",
         "airflow.providers.amazon.aws.transfers.base.AwsToAwsBaseOperator",
+        "airflow.providers.amazon.aws.operators.comprehend.ComprehendBaseOperator",
+        "airflow.providers.amazon.aws.sensors.comprehend.ComprehendBaseSensor",
     }

     MISSING_EXAMPLES_FOR_CLASSES = {
diff --git a/tests/providers/amazon/aws/hooks/test_comprehend.py b/tests/providers/amazon/aws/hooks/test_comprehend.py
new file mode 100644
index 0000000000000..fded25e446023
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_comprehend.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook
+
+
+class TestComprehendHook:
+    @pytest.mark.parametrize(
+        "test_hook, service_name",
+        [pytest.param(ComprehendHook(), "comprehend", id="comprehend")],
+    )
+    def test_comprehend_hook(self, test_hook, service_name):
+        assert test_hook.client_type == service_name
+        assert test_hook.conn is not None
diff --git a/tests/providers/amazon/aws/operators/test_comprehend.py b/tests/providers/amazon/aws/operators/test_comprehend.py
new file mode 100644
index 0000000000000..b970b590adffc
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_comprehend.py
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Generator +from unittest import mock + +import pytest +from moto import mock_aws + +from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook +from airflow.providers.amazon.aws.operators.comprehend import ( + ComprehendBaseOperator, + ComprehendStartPiiEntitiesDetectionJobOperator, +) +from airflow.utils.types import NOTSET + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection + +INPUT_DATA_CONFIG = { + "S3Uri": "s3://input-data-comprehend/sample_data.txt", + "InputFormat": "ONE_DOC_PER_LINE", +} +OUTPUT_DATA_CONFIG = {"S3Uri": "s3://output-data-comprehend/redacted_output/"} +LANGUAGE_CODE = "en" +ROLE_ARN = "role_arn" + + +class TestComprehendBaseOperator: + @pytest.mark.parametrize("aws_conn_id", [None, NOTSET, "aws_test_conn"]) + @pytest.mark.parametrize("region_name", [None, NOTSET, "ca-central-1"]) + def test_initialize_comprehend_base_operator(self, aws_conn_id, region_name): + op_kw = {"aws_conn_id": aws_conn_id, "region_name": region_name} + op_kw = {k: v for k, v in op_kw.items() if v is not NOTSET} + + comprehend_base_op = ComprehendBaseOperator( + task_id="comprehend_base_operator", + input_data_config=INPUT_DATA_CONFIG, + output_data_config=OUTPUT_DATA_CONFIG, + language_code=LANGUAGE_CODE, + data_access_role_arn=ROLE_ARN, + **op_kw, + ) + + assert comprehend_base_op.aws_conn_id == (aws_conn_id if aws_conn_id is not NOTSET else "aws_default") + assert comprehend_base_op.region_name == (region_name if region_name is not NOTSET else None) + + @mock.patch.object(ComprehendBaseOperator, "hook", new_callable=mock.PropertyMock) + def test_initialize_comprehend_base_operator_hook(self, comprehend_base_operator_mock_hook): + comprehend_base_op = ComprehendBaseOperator( + task_id="comprehend_base_operator", + input_data_config=INPUT_DATA_CONFIG, + output_data_config=OUTPUT_DATA_CONFIG, + language_code=LANGUAGE_CODE, + data_access_role_arn=ROLE_ARN, + ) + mocked_hook = mock.MagicMock(name="MockHook") + mocked_client = mock.MagicMock(name="MockClient") + mocked_hook.conn = mocked_client + comprehend_base_operator_mock_hook.return_value = mocked_hook + assert comprehend_base_op.client == mocked_client + comprehend_base_operator_mock_hook.assert_called_once() + + +class TestComprehendStartPiiEntitiesDetectionJobOperator: + JOB_ID = "random-job-id-1234567" + MODE = "ONLY_REDACTION" + JOB_NAME = "TEST_START_PII_ENTITIES_DETECTION_JOB-1" + DEFAULT_JOB_NAME_STARTS_WITH = "start_pii_entities_detection_job" + REDACTION_CONFIG = {"PiiEntityTypes": ["NAME", "ADDRESS"], "MaskMode": "REPLACE_WITH_PII_ENTITY_TYPE"} + + @pytest.fixture + def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(ComprehendHook, "conn") as _conn: + _conn.start_pii_entities_detection_job.return_value = {"JobId": self.JOB_ID} + yield _conn + + @pytest.fixture + def comprehend_hook(self) -> 
Generator[ComprehendHook, None, None]: + with mock_aws(): + hook = ComprehendHook(aws_conn_id="aws_default") + yield hook + + def setup_method(self): + self.operator = ComprehendStartPiiEntitiesDetectionJobOperator( + task_id="start_pii_entities_detection_job", + input_data_config=INPUT_DATA_CONFIG, + output_data_config=OUTPUT_DATA_CONFIG, + data_access_role_arn=ROLE_ARN, + mode=self.MODE, + language_code=LANGUAGE_CODE, + start_pii_entities_kwargs={"JobName": self.JOB_NAME, "RedactionConfig": self.REDACTION_CONFIG}, + ) + self.operator.defer = mock.MagicMock() + + def test_init(self): + assert self.operator.input_data_config == INPUT_DATA_CONFIG + assert self.operator.output_data_config == OUTPUT_DATA_CONFIG + assert self.operator.data_access_role_arn == ROLE_ARN + assert self.operator.mode == self.MODE + assert self.operator.language_code == LANGUAGE_CODE + assert self.operator.start_pii_entities_kwargs.get("JobName") == self.JOB_NAME + assert self.operator.start_pii_entities_kwargs.get("RedactionConfig") == self.REDACTION_CONFIG + + @mock.patch.object(ComprehendHook, "conn") + def test_start_pii_entities_detection_job_name_starts_with_service_name(self, comprehend_mock_conn): + self.op = ComprehendStartPiiEntitiesDetectionJobOperator( + task_id="start_pii_entities_detection_job", + input_data_config=INPUT_DATA_CONFIG, + output_data_config=OUTPUT_DATA_CONFIG, + data_access_role_arn=ROLE_ARN, + mode=self.MODE, + language_code=LANGUAGE_CODE, + start_pii_entities_kwargs={"RedactionConfig": self.REDACTION_CONFIG}, + ) + self.op.wait_for_completion = False + self.op.execute({}) + assert self.op.start_pii_entities_kwargs.get("JobName").startswith(self.DEFAULT_JOB_NAME_STARTS_WITH) + comprehend_mock_conn.start_pii_entities_detection_job.assert_called_once_with( + InputDataConfig=INPUT_DATA_CONFIG, + OutputDataConfig=OUTPUT_DATA_CONFIG, + Mode=self.MODE, + DataAccessRoleArn=ROLE_ARN, + LanguageCode=LANGUAGE_CODE, + RedactionConfig=self.REDACTION_CONFIG, + JobName=self.op.start_pii_entities_kwargs.get("JobName"), + ) + + @pytest.mark.parametrize( + "wait_for_completion, deferrable", + [ + pytest.param(False, False, id="no_wait"), + pytest.param(True, False, id="wait"), + pytest.param(False, True, id="defer"), + ], + ) + @mock.patch.object(ComprehendHook, "get_waiter") + def test_start_pii_entities_detection_job_wait_combinations( + self, _, wait_for_completion, deferrable, mock_conn, comprehend_hook + ): + self.operator.wait_for_completion = wait_for_completion + self.operator.deferrable = deferrable + + response = self.operator.execute({}) + + assert response == self.JOB_ID + assert comprehend_hook.get_waiter.call_count == wait_for_completion + assert self.operator.defer.call_count == deferrable diff --git a/tests/providers/amazon/aws/sensors/test_comprehend.py b/tests/providers/amazon/aws/sensors/test_comprehend.py new file mode 100644 index 0000000000000..e066349031140 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_comprehend.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook +from airflow.providers.amazon.aws.sensors.comprehend import ( + ComprehendStartPiiEntitiesDetectionJobCompletedSensor, +) + + +class TestComprehendStartPiiEntitiesDetectionJobCompletedSensor: + SENSOR = ComprehendStartPiiEntitiesDetectionJobCompletedSensor + + def setup_method(self): + self.default_op_kwargs = dict( + task_id="test_pii_entities_detection_job_sensor", + job_id="job_id", + poke_interval=5, + max_retries=1, + ) + self.sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None) + + def test_base_aws_op_attributes(self): + op = self.SENSOR(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = self.SENSOR( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + @pytest.mark.parametrize("state", SENSOR.SUCCESS_STATES) + @mock.patch.object(ComprehendHook, "conn") + def test_poke_success_state(self, mock_conn, state): + mock_conn.describe_pii_entities_detection_job.return_value = { + "PiiEntitiesDetectionJobProperties": {"JobStatus": state} + } + assert self.sensor.poke({}) is True + + @pytest.mark.parametrize("state", SENSOR.INTERMEDIATE_STATES) + @mock.patch.object(ComprehendHook, "conn") + def test_intermediate_state(self, mock_conn, state): + mock_conn.describe_pii_entities_detection_job.return_value = { + "PiiEntitiesDetectionJobProperties": {"JobStatus": state} + } + assert self.sensor.poke({}) is False + + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(False, AirflowException, id="not-soft-fail"), + pytest.param(True, AirflowSkipException, id="soft-fail"), + ], + ) + @pytest.mark.parametrize("state", SENSOR.FAILURE_STATES) + @mock.patch.object(ComprehendHook, "conn") + def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + mock_conn.describe_pii_entities_detection_job.return_value = { + "PiiEntitiesDetectionJobProperties": {"JobStatus": state} + } + sensor = self.SENSOR(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail) + + with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): + sensor.poke({}) diff --git a/tests/providers/amazon/aws/triggers/test_comprehend.py b/tests/providers/amazon/aws/triggers/test_comprehend.py new file mode 100644 index 0000000000000..1c52aa8810175 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_comprehend.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest + +from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook +from airflow.providers.amazon.aws.triggers.comprehend import ComprehendPiiEntitiesDetectionJobCompletedTrigger +from airflow.triggers.base import TriggerEvent +from tests.providers.amazon.aws.utils.test_waiter import assert_expected_waiter_type + +BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.comprehend." + + +class TestBaseComprehendTrigger: + EXPECTED_WAITER_NAME: str | None = None + JOB_ID: str | None = None + + def test_setup(self): + # Ensure that all subclasses have an expected waiter name set. + if self.__class__.__name__ != "TestBaseComprehendTrigger": + assert isinstance(self.EXPECTED_WAITER_NAME, str) + assert isinstance(self.JOB_ID, str) + + +class TestComprehendPiiEntitiesDetectionJobCompletedTrigger(TestBaseComprehendTrigger): + EXPECTED_WAITER_NAME = "pii_entities_detection_job_complete" + JOB_ID = "job_id" + + def test_serialization(self): + """Assert that arguments and classpath are correctly serialized.""" + trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id=self.JOB_ID) + classpath, kwargs = trigger.serialize() + assert classpath == BASE_TRIGGER_CLASSPATH + "ComprehendPiiEntitiesDetectionJobCompletedTrigger" + assert kwargs.get("job_id") == self.JOB_ID + + @pytest.mark.asyncio + @mock.patch.object(ComprehendHook, "get_waiter") + @mock.patch.object(ComprehendHook, "async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + trigger = ComprehendPiiEntitiesDetectionJobCompletedTrigger(job_id=self.JOB_ID) + + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "job_id": self.JOB_ID}) + assert_expected_waiter_type(mock_get_waiter, self.EXPECTED_WAITER_NAME) + mock_get_waiter().wait.assert_called_once() diff --git a/tests/providers/amazon/aws/waiters/test_comprehend.py b/tests/providers/amazon/aws/waiters/test_comprehend.py new file mode 100644 index 0000000000000..a514ea198f6f5 --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_comprehend.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import boto3 +import botocore +import pytest + +from airflow.providers.amazon.aws.hooks.comprehend import ComprehendHook +from airflow.providers.amazon.aws.sensors.comprehend import ( + ComprehendStartPiiEntitiesDetectionJobCompletedSensor, +) + + +class TestComprehendCustomWaiters: + def test_service_waiters(self): + assert "pii_entities_detection_job_complete" in ComprehendHook().list_waiters() + + +class TestComprehendCustomWaitersBase: + @pytest.fixture(autouse=True) + def mock_conn(self, monkeypatch): + self.client = boto3.client("comprehend") + monkeypatch.setattr(ComprehendHook, "conn", self.client) + + +class TestComprehendStartPiiEntitiesDetectionJobCompleteWaiter(TestComprehendCustomWaitersBase): + WAITER_NAME = "pii_entities_detection_job_complete" + + @pytest.fixture + def mock_get_job(self): + with mock.patch.object(self.client, "describe_pii_entities_detection_job") as mock_getter: + yield mock_getter + + @pytest.mark.parametrize("state", ComprehendStartPiiEntitiesDetectionJobCompletedSensor.SUCCESS_STATES) + def test_pii_entities_detection_job_complete(self, state, mock_get_job): + mock_get_job.return_value = {"PiiEntitiesDetectionJobProperties": {"JobStatus": state}} + + ComprehendHook().get_waiter(self.WAITER_NAME).wait(JobId="job_id") + + @pytest.mark.parametrize("state", ComprehendStartPiiEntitiesDetectionJobCompletedSensor.FAILURE_STATES) + def test_pii_entities_detection_job_failed(self, state, mock_get_job): + mock_get_job.return_value = {"PiiEntitiesDetectionJobProperties": {"JobStatus": state}} + + with pytest.raises(botocore.exceptions.WaiterError): + ComprehendHook().get_waiter(self.WAITER_NAME).wait(JobId="job_id") + + def test_pii_entities_detection_job_wait(self, mock_get_job): + wait = {"PiiEntitiesDetectionJobProperties": {"JobStatus": "IN_PROGRESS"}} + success = {"PiiEntitiesDetectionJobProperties": {"JobStatus": "COMPLETED"}} + mock_get_job.side_effect = [wait, wait, success] + + ComprehendHook().get_waiter(self.WAITER_NAME).wait( + JobId="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} + ) diff --git a/tests/system/providers/amazon/aws/example_comprehend.py b/tests/system/providers/amazon/aws/example_comprehend.py new file mode 100644 index 0000000000000..58e34329b67f7 --- /dev/null +++ b/tests/system/providers/amazon/aws/example_comprehend.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from datetime import datetime + +from airflow import DAG +from airflow.decorators import task_group +from airflow.models.baseoperator import chain +from airflow.providers.amazon.aws.operators.comprehend import ComprehendStartPiiEntitiesDetectionJobOperator +from airflow.providers.amazon.aws.operators.s3 import ( + S3CreateBucketOperator, + S3CreateObjectOperator, + S3DeleteBucketOperator, +) +from airflow.providers.amazon.aws.sensors.comprehend import ( + ComprehendStartPiiEntitiesDetectionJobCompletedSensor, +) +from airflow.utils.trigger_rule import TriggerRule +from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder + +ROLE_ARN_KEY = "ROLE_ARN" +sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build() + +DAG_ID = "example_comprehend" +INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB = "start-pii-entities-detection-job/sample_data.txt" + +SAMPLE_DATA = { + "username": "bob1234", + "name": "Bob", + "sex": "M", + "address": "1773 Raymond Ville Suite 682", + "mail": "test@hotmail.com", +} + + +@task_group +def pii_entities_detection_job_workflow(): + # [START howto_operator_start_pii_entities_detection_job] + start_pii_entities_detection_job = ComprehendStartPiiEntitiesDetectionJobOperator( + task_id="start_pii_entities_detection_job", + input_data_config=input_data_configurations, + output_data_config=output_data_configurations, + mode="ONLY_REDACTION", + data_access_role_arn=test_context[ROLE_ARN_KEY], + language_code="en", + start_pii_entities_kwargs=pii_entities_kwargs, + ) + # [END howto_operator_start_pii_entities_detection_job] + start_pii_entities_detection_job.wait_for_completion = False + + # [START howto_sensor_start_pii_entities_detection_job] + await_start_pii_entities_detection_job = ComprehendStartPiiEntitiesDetectionJobCompletedSensor( + task_id="await_start_pii_entities_detection_job", job_id=start_pii_entities_detection_job.output + ) + # [END howto_sensor_start_pii_entities_detection_job] + + chain(start_pii_entities_detection_job, await_start_pii_entities_detection_job) + + +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + env_id = test_context["ENV_ID"] + bucket_name = f"{env_id}-comprehend" + input_data_configurations = { + "S3Uri": f"s3://{bucket_name}/{INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB}", + "InputFormat": "ONE_DOC_PER_LINE", + } + output_data_configurations = {"S3Uri": f"s3://{bucket_name}/redacted_output/"} + pii_entities_kwargs = { + "RedactionConfig": { + "PiiEntityTypes": ["NAME", "ADDRESS"], + "MaskMode": "REPLACE_WITH_PII_ENTITY_TYPE", + } + } + + create_bucket = S3CreateBucketOperator( + task_id="create_bucket", + bucket_name=bucket_name, + ) + + upload_sample_data = S3CreateObjectOperator( + task_id="upload_sample_data", + s3_bucket=bucket_name, + s3_key=INPUT_S3_KEY_START_PII_ENTITIES_DETECTION_JOB, + data=json.dumps(SAMPLE_DATA), + ) + + delete_bucket = S3DeleteBucketOperator( + task_id="delete_bucket", + trigger_rule=TriggerRule.ALL_DONE, + bucket_name=bucket_name, + force_delete=True, + ) + + chain( + # TEST SETUP + test_context, + create_bucket, + upload_sample_data, + # TEST BODY + pii_entities_detection_job_workflow(), + # TEST TEARDOWN + delete_bucket, + ) + + from tests.system.utils.watcher import watcher + + # This 
test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + +from tests.system.utils import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag)
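
Note for reviewers: beyond the bundled system test above, the new pieces compose in an ordinary DAG as sketched below. This is a minimal, hypothetical example; the bucket, key, and role ARN are placeholders, not values from this PR:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.comprehend import (
        ComprehendStartPiiEntitiesDetectionJobOperator,
    )
    from airflow.providers.amazon.aws.sensors.comprehend import (
        ComprehendStartPiiEntitiesDetectionJobCompletedSensor,
    )

    with DAG(dag_id="comprehend_pii_demo", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        # Hypothetical S3 locations and IAM role; replace with real resources.
        start_job = ComprehendStartPiiEntitiesDetectionJobOperator(
            task_id="start_pii_entities_detection_job",
            input_data_config={
                "S3Uri": "s3://my-bucket/input/data.txt",
                "InputFormat": "ONE_DOC_PER_LINE",
            },
            output_data_config={"S3Uri": "s3://my-bucket/redacted_output/"},
            mode="ONLY_REDACTION",  # ONLY_REDACTION requires a RedactionConfig
            data_access_role_arn="arn:aws:iam::123456789012:role/ComprehendDataAccessRole",
            language_code="en",
            start_pii_entities_kwargs={
                "RedactionConfig": {
                    "PiiEntityTypes": ["NAME", "ADDRESS"],
                    "MaskMode": "REPLACE_WITH_PII_ENTITY_TYPE",
                }
            },
            wait_for_completion=False,  # hand the waiting off to the sensor below
        )

        # The operator returns the JobId, so the sensor can consume it via XCom.
        await_job = ComprehendStartPiiEntitiesDetectionJobCompletedSensor(
            task_id="await_pii_entities_detection_job",
            job_id=start_job.output,
        )

        start_job >> await_job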