diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py index 7093120ceafd0..45b34511d5d78 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py @@ -24,6 +24,14 @@ from typing import Any from botocore.exceptions import ClientError +from tenacity import ( + AsyncRetrying, + Retrying, + before_sleep_log, + retry_if_exception, + stop_after_attempt, + wait_exponential, +) from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -51,6 +59,7 @@ class GlueJobHook(AwsBaseHook): :param iam_role_arn: AWS IAM Role ARN for Glue Job Execution, If set `iam_role_name` must equal None. :param create_job_kwargs: Extra arguments for Glue Job Creation :param update_config: Update job configuration on Glue (default: False) + :param api_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` & ``tenacity.AsyncRetrying`` classes. Additional arguments (such as ``aws_conn_id``) may be specified and are passed down to the underlying AwsBaseHook. @@ -80,6 +89,7 @@ def __init__( create_job_kwargs: dict | None = None, update_config: bool = False, job_poll_interval: int | float = 6, + api_retry_args: dict[Any, Any] | None = None, *args, **kwargs, ): @@ -96,6 +106,17 @@ def __init__( self.update_config = update_config self.job_poll_interval = job_poll_interval + self.retry_config: dict[str, Any] = { + "retry": retry_if_exception(self._should_retry_on_error), + "wait": wait_exponential(multiplier=1, min=1, max=60), + "stop": stop_after_attempt(5), + "before_sleep": before_sleep_log(self.log, log_level=20), + "reraise": True, + } + + if api_retry_args: + self.retry_config.update(api_retry_args) + worker_type_exists = "WorkerType" in self.create_job_kwargs num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs @@ -116,6 +137,29 @@ def __init__( kwargs["client_type"] = "glue" super().__init__(*args, **kwargs) + def _should_retry_on_error(self, exception: BaseException) -> bool: + """ + Determine if an exception should trigger a retry. + + :param exception: The exception that occurred + :return: True if the exception should trigger a retry, False otherwise + """ + if isinstance(exception, ClientError): + error_code = exception.response.get("Error", {}).get("Code", "") + retryable_errors = { + "ThrottlingException", + "RequestLimitExceeded", + "ServiceUnavailable", + "InternalFailure", + "InternalServerError", + "TooManyRequestsException", + "RequestTimeout", + "RequestTimeoutException", + "HttpTimeoutException", + } + return error_code in retryable_errors + return False + def create_glue_job_config(self) -> dict: default_command = { "Name": "glueetl", @@ -217,8 +261,21 @@ def get_job_state(self, job_name: str, run_id: str) -> str: :param run_id: The job-run ID of the predecessor job run :return: State of the Glue job """ - job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True) - return job_run["JobRun"]["JobRunState"] + for attempt in Retrying(**self.retry_config): + with attempt: + try: + job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True) + return job_run["JobRun"]["JobRunState"] + except ClientError as e: + self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e) + raise + except Exception as e: + self.log.error( + "Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e + ) + raise + # This should never be reached due to reraise=True, but mypy needs it + raise RuntimeError("Unexpected end of retry loop") async def async_get_job_state(self, job_name: str, run_id: str) -> str: """ @@ -226,9 +283,22 @@ async def async_get_job_state(self, job_name: str, run_id: str) -> str: The async version of get_job_state. """ - async with await self.get_async_conn() as client: - job_run = await client.get_job_run(JobName=job_name, RunId=run_id) - return job_run["JobRun"]["JobRunState"] + async for attempt in AsyncRetrying(**self.retry_config): + with attempt: + try: + async with await self.get_async_conn() as client: + job_run = await client.get_job_run(JobName=job_name, RunId=run_id) + return job_run["JobRun"]["JobRunState"] + except ClientError as e: + self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e) + raise + except Exception as e: + self.log.error( + "Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e + ) + raise + # This should never be reached due to reraise=True, but mypy needs it + raise RuntimeError("Unexpected end of retry loop") @cached_property def logs_hook(self): @@ -372,7 +442,7 @@ def _handle_state( ) return None - def has_job(self, job_name) -> bool: + def has_job(self, job_name: str) -> bool: """ Check if the job already exists. @@ -422,6 +492,9 @@ def get_or_create_glue_job(self) -> str | None: :return:Name of the Job """ + if self.job_name is None: + raise ValueError("job_name must be set to get or create a Glue job") + if self.has_job(self.job_name): return self.job_name @@ -441,6 +514,9 @@ def create_or_update_glue_job(self) -> str | None: :return:Name of the Job """ + if self.job_name is None: + raise ValueError("job_name must be set to create or update a Glue job") + config = self.create_glue_job_config() if self.has_job(self.job_name): diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_glue.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_glue.py index de1de82273890..df3e98401ae74 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_glue.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_glue.py @@ -479,6 +479,127 @@ async def test_async_job_completion_failure(self, get_state_mock: MagicMock): assert get_state_mock.call_count == 3 + @mock.patch.object(GlueJobHook, "conn") + def test_get_job_state_success(self, mock_conn): + hook = GlueJobHook() + job_name = "test_job" + run_id = "test_run_id" + expected_state = "SUCCEEDED" + + mock_conn.get_job_run.return_value = {"JobRun": {"JobRunState": expected_state}} + + result = hook.get_job_state(job_name, run_id) + + assert result == expected_state + mock_conn.get_job_run.assert_called_once_with( + JobName=job_name, RunId=run_id, PredecessorsIncluded=True + ) + + @mock.patch.object(GlueJobHook, "conn") + def test_get_job_state_retry_on_client_error(self, mock_conn): + hook = GlueJobHook() + job_name = "test_job" + run_id = "test_run_id" + expected_state = "SUCCEEDED" + + mock_conn.get_job_run.side_effect = [ + ClientError( + {"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, "get_job_run" + ), + {"JobRun": {"JobRunState": expected_state}}, + ] + + result = hook.get_job_state(job_name, run_id) + + assert result == expected_state + assert mock_conn.get_job_run.call_count == 2 + + @mock.patch.object(GlueJobHook, "conn") + def test_get_job_state_fails_after_all_retries(self, mock_conn): + """Test get_job_state raises exception when all retries are exhausted.""" + hook = GlueJobHook() + job_name = "test_job" + run_id = "test_run_id" + + mock_conn.get_job_run.side_effect = ClientError( + {"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, "get_job_run" + ) + + with pytest.raises(ClientError) as exc_info: + hook.get_job_state(job_name, run_id) + + assert exc_info.value.response["Error"]["Code"] == "ThrottlingException" + assert mock_conn.get_job_run.call_count == 5 + + @pytest.mark.asyncio + @mock.patch.object(GlueJobHook, "get_async_conn") + async def test_async_get_job_state_success(self, mock_get_async_conn): + hook = GlueJobHook() + job_name = "test_job" + run_id = "test_run_id" + expected_state = "RUNNING" + + mock_client = mock.AsyncMock() + mock_client.get_job_run.return_value = {"JobRun": {"JobRunState": expected_state}} + mock_context = mock.AsyncMock() + mock_context.__aenter__.return_value = mock_client + mock_context.__aexit__.return_value = None + mock_get_async_conn.return_value = mock_context + + result = await hook.async_get_job_state(job_name, run_id) + + assert result == expected_state + mock_client.get_job_run.assert_called_once_with(JobName=job_name, RunId=run_id) + + @pytest.mark.asyncio + @mock.patch.object(GlueJobHook, "get_async_conn") + async def test_async_get_job_state_retry_on_client_error(self, mock_get_async_conn): + hook = GlueJobHook() + job_name = "test_job" + run_id = "test_run_id" + expected_state = "FAILED" + + mock_client = mock.AsyncMock() + mock_client.get_job_run.side_effect = [ + ClientError( + {"Error": {"Code": "ServiceUnavailable", "Message": "Service temporarily unavailable"}}, + "get_job_run", + ), + {"JobRun": {"JobRunState": expected_state}}, + ] + mock_context = mock.AsyncMock() + mock_context.__aenter__.return_value = mock_client + mock_context.__aexit__.return_value = None + mock_get_async_conn.return_value = mock_context + + result = await hook.async_get_job_state(job_name, run_id) + + assert result == expected_state + assert mock_client.get_job_run.call_count == 2 + + @pytest.mark.asyncio + @mock.patch.object(GlueJobHook, "get_async_conn") + async def test_async_get_job_state_fails_after_all_retries(self, mock_get_async_conn): + hook = GlueJobHook() + job_name = "test_job" + run_id = "test_run_id" + + mock_client = mock.AsyncMock() + mock_client.get_job_run.side_effect = ClientError( + {"Error": {"Code": "ServiceUnavailable", "Message": "Service temporarily unavailable"}}, + "get_job_run", + ) + mock_context = mock.AsyncMock() + mock_context.__aenter__.return_value = mock_client + mock_context.__aexit__.return_value = None + mock_get_async_conn.return_value = mock_context + + with pytest.raises(ClientError) as exc_info: + await hook.async_get_job_state(job_name, run_id) + + assert exc_info.value.response["Error"]["Code"] == "ServiceUnavailable" + assert mock_client.get_job_run.call_count == 5 + class TestGlueDataQualityHook: RUN_ID = "1234" diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py index f2537226183bd..fc5654961def8 100644 --- a/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_hooks_signature.py @@ -58,6 +58,7 @@ "retry_limit", "num_of_dpus", "script_location", + "api_retry_args", }, "S3Hook": {"transfer_config_args", "aws_conn_id", "extra_args"}, }