providers/amazon/src/airflow/providers/amazon/aws/hooks/glue.py (88 changes: 82 additions & 6 deletions)

@@ -24,6 +24,14 @@
from typing import Any

from botocore.exceptions import ClientError
from tenacity import (
AsyncRetrying,
Retrying,
before_sleep_log,
retry_if_exception,
stop_after_attempt,
wait_exponential,
)

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -51,6 +59,7 @@ class GlueJobHook(AwsBaseHook):
:param iam_role_arn: AWS IAM Role ARN for Glue Job Execution. If set, `iam_role_name` must be None.
:param create_job_kwargs: Extra arguments for Glue Job Creation
:param update_config: Update job configuration on Glue (default: False)
:param api_retry_args: An optional dictionary of keyword arguments passed to the ``tenacity.Retrying`` and ``tenacity.AsyncRetrying`` classes.

Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
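
For orientation, a minimal usage sketch (not part of this PR) of the new parameter. Because the dictionary is merged over the defaults via ``dict.update`` (see ``__init__`` below), any subset of the ``tenacity`` keyword arguments can be overridden; the job name here is a placeholder.

```python
from tenacity import stop_after_attempt, wait_exponential

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

# Override only the pieces of the default retry policy we care about;
# unspecified keys (retry predicate, before_sleep logging, reraise)
# keep the hook's defaults.
hook = GlueJobHook(
    job_name="my_glue_job",  # placeholder job name
    api_retry_args={
        "stop": stop_after_attempt(10),  # more attempts than the default 5
        "wait": wait_exponential(multiplier=2, min=2, max=120),  # slower backoff
    },
)
```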
@@ -80,6 +89,7 @@ def __init__(
create_job_kwargs: dict | None = None,
update_config: bool = False,
job_poll_interval: int | float = 6,
api_retry_args: dict[Any, Any] | None = None,
*args,
**kwargs,
):
@@ -96,6 +106,17 @@ def __init__(
self.update_config = update_config
self.job_poll_interval = job_poll_interval

self.retry_config: dict[str, Any] = {
"retry": retry_if_exception(self._should_retry_on_error),
"wait": wait_exponential(multiplier=1, min=1, max=60),
"stop": stop_after_attempt(5),
"before_sleep": before_sleep_log(self.log, log_level=20),
"reraise": True,
}

if api_retry_args:
self.retry_config.update(api_retry_args)

worker_type_exists = "WorkerType" in self.create_job_kwargs
num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs

@@ -116,6 +137,29 @@ def __init__(
kwargs["client_type"] = "glue"
super().__init__(*args, **kwargs)

def _should_retry_on_error(self, exception: BaseException) -> bool:
"""
Determine if an exception should trigger a retry.

:param exception: The exception that occurred
:return: True if the exception should trigger a retry, False otherwise
"""
if isinstance(exception, ClientError):
error_code = exception.response.get("Error", {}).get("Code", "")
retryable_errors = {
"ThrottlingException",
"RequestLimitExceeded",
"ServiceUnavailable",
"InternalFailure",
"InternalServerError",
"TooManyRequestsException",
"RequestTimeout",
"RequestTimeoutException",
"HttpTimeoutException",
}
return error_code in retryable_errors
return False
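
To make the predicate's contract concrete, a hedged sketch (assuming a default-constructed hook and standard ``botocore`` error shapes):

```python
from botocore.exceptions import ClientError

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

hook = GlueJobHook()

# Transient AWS-side errors are retried...
throttled = ClientError({"Error": {"Code": "ThrottlingException"}}, "get_job_run")
assert hook._should_retry_on_error(throttled) is True

# ...while permanent errors and non-AWS exceptions are raised immediately.
denied = ClientError({"Error": {"Code": "AccessDeniedException"}}, "get_job_run")
assert hook._should_retry_on_error(denied) is False
assert hook._should_retry_on_error(ValueError("not a ClientError")) is False
```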

def create_glue_job_config(self) -> dict:
default_command = {
"Name": "glueetl",
@@ -217,18 +261,44 @@ def get_job_state(self, job_name: str, run_id: str) -> str:
:param run_id: The job-run ID of the predecessor job run
:return: State of the Glue job
"""
job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
return job_run["JobRun"]["JobRunState"]
for attempt in Retrying(**self.retry_config):
with attempt:
try:
job_run = self.conn.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
return job_run["JobRun"]["JobRunState"]
except ClientError as e:
self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e)
raise
except Exception as e:
self.log.error(
"Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e
)
raise
# This should never be reached due to reraise=True, but mypy needs it
raise RuntimeError("Unexpected end of retry loop")
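
The ``for attempt in Retrying(...)`` / ``with attempt:`` pairing above is the standard ``tenacity`` idiom; a self-contained sketch of the mechanics:

```python
from tenacity import Retrying, stop_after_attempt, wait_fixed

# Retrying yields one AttemptManager per attempt; the `with attempt:` block
# captures any exception so tenacity can decide whether to sleep and retry
# or, with reraise=True, re-raise the original error once attempts run out.
for attempt in Retrying(stop=stop_after_attempt(3), wait=wait_fixed(0.1), reraise=True):
    with attempt:
        print(f"attempt #{attempt.retry_state.attempt_number}")
        # raising here would trigger up to two more attempts, then re-raise
```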

async def async_get_job_state(self, job_name: str, run_id: str) -> str:
"""
Get state of the Glue job; the job state can be running, finished, failed, stopped or timeout.

The async version of get_job_state.
"""
async with await self.get_async_conn() as client:
job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
return job_run["JobRun"]["JobRunState"]
async for attempt in AsyncRetrying(**self.retry_config):
with attempt:
try:
async with await self.get_async_conn() as client:
job_run = await client.get_job_run(JobName=job_name, RunId=run_id)
return job_run["JobRun"]["JobRunState"]
except ClientError as e:
self.log.error("Failed to get job state for job %s run %s: %s", job_name, run_id, e)
raise
except Exception as e:
self.log.error(
"Unexpected error getting job state for job %s run %s: %s", job_name, run_id, e
)
raise
# This should never be reached due to reraise=True, but mypy needs it
raise RuntimeError("Unexpected end of retry loop")
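
For context, a hedged sketch of how a deferrable operator or trigger might poll this coroutine; the helper name and the set of terminal states are assumptions, not part of the PR:

```python
import asyncio

from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

# Hypothetical polling helper; Glue's documented terminal states include
# at least these (STARTING/RUNNING/STOPPING keep the loop going).
TERMINAL_STATES = {"SUCCEEDED", "FAILED", "STOPPED", "TIMEOUT", "ERROR"}

async def wait_for_terminal_state(hook: GlueJobHook, job_name: str, run_id: str) -> str:
    while True:
        state = await hook.async_get_job_state(job_name, run_id)
        if state in TERMINAL_STATES:
            return state
        await asyncio.sleep(hook.job_poll_interval)  # defaults to 6 seconds
```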

@cached_property
def logs_hook(self):
@@ -372,7 +442,7 @@ def _handle_state(
)
return None

def has_job(self, job_name) -> bool:
def has_job(self, job_name: str) -> bool:
"""
Check if the job already exists.

Expand Down Expand Up @@ -422,6 +492,9 @@ def get_or_create_glue_job(self) -> str | None:

:return: Name of the Job
"""
if self.job_name is None:
raise ValueError("job_name must be set to get or create a Glue job")

if self.has_job(self.job_name):
return self.job_name

@@ -441,6 +514,9 @@ def create_or_update_glue_job(self) -> str | None:

:return: Name of the Job
"""
if self.job_name is None:
raise ValueError("job_name must be set to create or update a Glue job")

config = self.create_glue_job_config()

if self.has_job(self.job_name):
providers/amazon/tests/unit/amazon/aws/hooks/test_glue.py (121 changes: 121 additions & 0 deletions)

@@ -479,6 +479,127 @@ async def test_async_job_completion_failure(self, get_state_mock: MagicMock):

assert get_state_mock.call_count == 3

@mock.patch.object(GlueJobHook, "conn")
def test_get_job_state_success(self, mock_conn):
hook = GlueJobHook()
job_name = "test_job"
run_id = "test_run_id"
expected_state = "SUCCEEDED"

mock_conn.get_job_run.return_value = {"JobRun": {"JobRunState": expected_state}}

result = hook.get_job_state(job_name, run_id)

assert result == expected_state
mock_conn.get_job_run.assert_called_once_with(
JobName=job_name, RunId=run_id, PredecessorsIncluded=True
)

@mock.patch.object(GlueJobHook, "conn")
def test_get_job_state_retry_on_client_error(self, mock_conn):
hook = GlueJobHook()
job_name = "test_job"
run_id = "test_run_id"
expected_state = "SUCCEEDED"

mock_conn.get_job_run.side_effect = [
ClientError(
{"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, "get_job_run"
),
{"JobRun": {"JobRunState": expected_state}},
]

result = hook.get_job_state(job_name, run_id)

assert result == expected_state
assert mock_conn.get_job_run.call_count == 2

@mock.patch.object(GlueJobHook, "conn")
def test_get_job_state_fails_after_all_retries(self, mock_conn):
"""Test get_job_state raises exception when all retries are exhausted."""
hook = GlueJobHook()
job_name = "test_job"
run_id = "test_run_id"

mock_conn.get_job_run.side_effect = ClientError(
{"Error": {"Code": "ThrottlingException", "Message": "Rate exceeded"}}, "get_job_run"
)

with pytest.raises(ClientError) as exc_info:
hook.get_job_state(job_name, run_id)

assert exc_info.value.response["Error"]["Code"] == "ThrottlingException"
assert mock_conn.get_job_run.call_count == 5

@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "get_async_conn")
async def test_async_get_job_state_success(self, mock_get_async_conn):
hook = GlueJobHook()
job_name = "test_job"
run_id = "test_run_id"
expected_state = "RUNNING"

mock_client = mock.AsyncMock()
mock_client.get_job_run.return_value = {"JobRun": {"JobRunState": expected_state}}
mock_context = mock.AsyncMock()
mock_context.__aenter__.return_value = mock_client
mock_context.__aexit__.return_value = None
mock_get_async_conn.return_value = mock_context

result = await hook.async_get_job_state(job_name, run_id)

assert result == expected_state
mock_client.get_job_run.assert_called_once_with(JobName=job_name, RunId=run_id)

@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "get_async_conn")
async def test_async_get_job_state_retry_on_client_error(self, mock_get_async_conn):
hook = GlueJobHook()
job_name = "test_job"
run_id = "test_run_id"
expected_state = "FAILED"

mock_client = mock.AsyncMock()
mock_client.get_job_run.side_effect = [
ClientError(
{"Error": {"Code": "ServiceUnavailable", "Message": "Service temporarily unavailable"}},
"get_job_run",
),
{"JobRun": {"JobRunState": expected_state}},
]
mock_context = mock.AsyncMock()
mock_context.__aenter__.return_value = mock_client
mock_context.__aexit__.return_value = None
mock_get_async_conn.return_value = mock_context

result = await hook.async_get_job_state(job_name, run_id)

assert result == expected_state
assert mock_client.get_job_run.call_count == 2

@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "get_async_conn")
async def test_async_get_job_state_fails_after_all_retries(self, mock_get_async_conn):
hook = GlueJobHook()
job_name = "test_job"
run_id = "test_run_id"

mock_client = mock.AsyncMock()
mock_client.get_job_run.side_effect = ClientError(
{"Error": {"Code": "ServiceUnavailable", "Message": "Service temporarily unavailable"}},
"get_job_run",
)
mock_context = mock.AsyncMock()
mock_context.__aenter__.return_value = mock_client
mock_context.__aexit__.return_value = None
mock_get_async_conn.return_value = mock_context

with pytest.raises(ClientError) as exc_info:
await hook.async_get_job_state(job_name, run_id)

assert exc_info.value.response["Error"]["Code"] == "ServiceUnavailable"
assert mock_client.get_job_run.call_count == 5


class TestGlueDataQualityHook:
RUN_ID = "1234"
A third file (name not shown above) registers the new parameter in what appears to be the per-hook allowed-parameter mapping used by the provider's hook-signature tests, so ``api_retry_args`` is accepted as a legitimate ``GlueJobHook`` argument:

@@ -58,6 +58,7 @@
"retry_limit",
"num_of_dpus",
"script_location",
"api_retry_args",
},
"S3Hook": {"transfer_config_args", "aws_conn_id", "extra_args"},
}