diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst
index 6535e974519b4..792b42e7498aa 100644
--- a/airflow/providers/amazon/CHANGELOG.rst
+++ b/airflow/providers/amazon/CHANGELOG.rst
@@ -19,6 +19,13 @@ Changelog
 ---------
 
+2.1.0
+.....
+
+Bug Fixes
+~~~~~~~~~
+* ``AWS DataSync default polling adjusted from 5s to 30s (#11011)``
+
 2.0.0
 .....
diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py
index ec92bd739d0fa..d157f92f016e8 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -36,7 +36,7 @@ class AWSDataSyncHook(AwsBaseHook):
         :class:`~airflow.providers.amazon.aws.operators.datasync.AWSDataSyncOperator`
 
     :param wait_interval_seconds: Time to wait between two
-        consecutive calls to check TaskExecution status. Defaults to 5 seconds.
+        consecutive calls to check TaskExecution status. Defaults to 30 seconds.
     :type wait_interval_seconds: Optional[int]
     :raises ValueError: If wait_interval_seconds is not between 0 and 15*60 seconds.
     """
@@ -52,7 +52,7 @@ class AWSDataSyncHook(AwsBaseHook):
     TASK_EXECUTION_FAILURE_STATES = ("ERROR",)
     TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
 
-    def __init__(self, wait_interval_seconds: int = 5, *args, **kwargs) -> None:
+    def __init__(self, wait_interval_seconds: int = 30, *args, **kwargs) -> None:
         super().__init__(client_type='datasync', *args, **kwargs)  # type: ignore[misc]
         self.locations: list = []
         self.tasks: list = []
@@ -279,7 +279,7 @@ def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]:
             return task_description["CurrentTaskExecutionArn"]
         return None
 
-    def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = 2 * 180) -> bool:
+    def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int = 60) -> bool:
         """
         Wait for Task Execution status to be complete (SUCCESS/ERROR).
         The ``task_execution_arn`` must exist, or a boto3 ClientError will be raised.
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index 6c88eb1c339cd..88381a85edc40 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -21,7 +21,7 @@
 import random
 from typing import List, Optional
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook
 
@@ -46,6 +46,9 @@ class AWSDataSyncOperator(BaseOperator):
     :param wait_interval_seconds: Time to wait between two
         consecutive calls to check TaskExecution status.
     :type wait_interval_seconds: int
+    :param max_iterations: Maximum number of
+        consecutive calls to check TaskExecution status.
+    :type max_iterations: int
     :param task_arn: AWS DataSync TaskArn to use. If None, then this operator will
         attempt to either search for an existing Task or attempt to create a new Task.
     :type task_arn: str
@@ -128,7 +131,8 @@ def __init__(
         self,
         *,
         aws_conn_id: str = "aws_default",
-        wait_interval_seconds: int = 5,
+        wait_interval_seconds: int = 30,
+        max_iterations: int = 60,
         task_arn: Optional[str] = None,
         source_location_uri: Optional[str] = None,
         destination_location_uri: Optional[str] = None,
@@ -147,6 +151,7 @@ def __init__(
         # Assignments
         self.aws_conn_id = aws_conn_id
         self.wait_interval_seconds = wait_interval_seconds
+        self.max_iterations = max_iterations
 
         self.task_arn = task_arn
 
@@ -355,8 +360,14 @@ def _execute_datasync_task(self) -> None:
 
         # Wait for task execution to complete
         self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn)
-        result = hook.wait_for_task_execution(self.task_execution_arn)
+        try:
+            result = hook.wait_for_task_execution(self.task_execution_arn, max_iterations=self.max_iterations)
+        except (AirflowTaskTimeout, AirflowException) as e:
+            self.log.error('Cancelling TaskExecution after Exception: %s', e)
+            self._cancel_datasync_task_execution()
+            raise
         self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn)
+
         task_execution_description = hook.describe_task_execution(task_execution_arn=self.task_execution_arn)
         self.log.info("task_execution_description=%s", task_execution_description)
 
@@ -371,7 +382,7 @@ def _execute_datasync_task(self) -> None:
         if not result:
             raise AirflowException(f"Failed TaskExecutionArn {self.task_execution_arn}")
 
-    def on_kill(self) -> None:
+    def _cancel_datasync_task_execution(self):
         """Cancel the submitted DataSync task."""
         hook = self.get_hook()
         if self.task_execution_arn:
@@ -379,6 +390,10 @@ def on_kill(self) -> None:
             hook.cancel_task_execution(task_execution_arn=self.task_execution_arn)
             self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn)
 
+    def on_kill(self):
+        self.log.error('Cancelling TaskExecution after task was killed')
+        self._cancel_datasync_task_execution()
+
     def _delete_datasync_task(self) -> None:
         """Deletes an AWS DataSync Task."""
         if not self.task_arn:
diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py
index 587196cba1c13..6a4f9111ca7ae 100644
--- a/tests/providers/amazon/aws/operators/test_datasync.py
+++ b/tests/providers/amazon/aws/operators/test_datasync.py
@@ -710,7 +710,7 @@ def test_killed_task(self, mock_wait, mock_get_conn):
         # ### Begin tests:
 
         # Kill the task when doing wait_for_task_execution
-        def kill_task(*args):
+        def kill_task(*args, **kwargs):
             self.datasync.on_kill()
             return True