Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def __init__(
job_poll_interval: int | float = 6,
waiter_delay: int = 60,
waiter_max_attempts: int = 75,
resume_glue_job_on_retry: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -168,6 +169,7 @@ def __init__(
self.s3_script_location: str | None = None
self.waiter_delay = waiter_delay
self.waiter_max_attempts = waiter_max_attempts
self.resume_glue_job_on_retry = resume_glue_job_on_retry

@property
def _hook_parameters(self):
Expand Down Expand Up @@ -217,13 +219,30 @@ def execute(self, context: Context):

:return: the current Glue job ID.
"""
self.log.info(
"Initializing AWS Glue Job: %s. Wait for completion: %s",
self.job_name,
self.wait_for_completion,
)
glue_job_run = self.hook.initialize_job(self.script_args, self.run_job_kwargs)
self._job_run_id = glue_job_run["JobRunId"]
previous_job_run_id = None
if self.resume_glue_job_on_retry:
ti = context["ti"]
previous_job_run_id = ti.xcom_pull(key="glue_job_run_id", task_ids=ti.task_id)
if previous_job_run_id:
try:
job_run = self.hook.conn.get_job_run(JobName=self.job_name, RunId=previous_job_run_id)
state = job_run.get("JobRun", {}).get("JobRunState")
self.log.info("Previous Glue job_run_id: %s, state: %s", previous_job_run_id, state)
if state in ("RUNNING", "STARTING", "STOPPING"):
self._job_run_id = previous_job_run_id
except Exception:
self.log.warning("Failed to get previous Glue job run state", exc_info=True)

if not self._job_run_id:
self.log.info(
"Initializing AWS Glue Job: %s. Wait for completion: %s",
self.job_name,
self.wait_for_completion,
)
glue_job_run = self.hook.initialize_job(self.script_args, self.run_job_kwargs)
self._job_run_id = glue_job_run["JobRunId"]
context["ti"].xcom_push(key="glue_job_run_id", value=self._job_run_id)

glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.hook.conn_partition),
region_name=self.hook.conn_region_name,
Expand Down
101 changes: 101 additions & 0 deletions providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,107 @@ def test_default_conn_passed_to_hook(self):
)
assert op.hook.aws_conn_id == DEFAULT_CONN

@mock.patch.object(GlueJobHook, "get_conn")
@mock.patch.object(GlueJobHook, "initialize_job")
def test_check_previous_job_id_run_reuse_in_progress(self, mock_initialize_job, mock_get_conn):
"""Test that when resume_glue_job_on_retry=True and previous job is in progress, it is reused."""
glue = GlueJobOperator(
task_id=TASK_ID,
job_name=JOB_NAME,
script_location="s3://folder/file",
aws_conn_id="aws_default",
region_name="us-west-2",
s3_bucket="some_bucket",
iam_role_name="my_test_role",
resume_glue_job_on_retry=True,
wait_for_completion=False,
)

# Mock the context and task instance
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti}

# Simulate previous job_run_id in XCom
previous_job_run_id = "previous_run_12345"
mock_ti.xcom_pull.return_value = previous_job_run_id

# Mock the Glue client to return RUNNING state for the previous job
mock_glue_client = mock.MagicMock()
glue.hook.conn = mock_glue_client
mock_glue_client.get_job_run.return_value = {
"JobRun": {
"JobRunState": "RUNNING",
}
}

# Execute the operator
glue.execute(mock_context)

# Verify that the previous job_run_id was reused
assert glue._job_run_id == previous_job_run_id
# Verify that initialize_job was NOT called
mock_initialize_job.assert_not_called()
# Verify that XCom push was not called for glue_job_run_id (since we reused the previous one)
# Note: xcom_push may be called for other purposes like glue_job_run_details
xcom_calls = [
call for call in mock_ti.xcom_push.call_args_list if call[1].get("key") == "glue_job_run_id"
]
assert len(xcom_calls) == 0, "Should not push new glue_job_run_id when reusing previous one"

@mock.patch.object(GlueJobHook, "get_conn")
@mock.patch.object(GlueJobHook, "initialize_job")
def test_check_previous_job_id_run_new_on_finished(self, mock_initialize_job, mock_get_conn):
"""Test that when previous job is finished, a new job is started and pushed to XCom."""
glue = GlueJobOperator(
task_id=TASK_ID,
job_name=JOB_NAME,
script_location="s3://folder/file",
aws_conn_id="aws_default",
region_name="us-west-2",
s3_bucket="some_bucket",
iam_role_name="my_test_role",
resume_glue_job_on_retry=True,
wait_for_completion=False,
)

# Mock the context and task instance
mock_ti = mock.MagicMock()
mock_context = {"ti": mock_ti}

# Simulate previous job_run_id in XCom
previous_job_run_id = "previous_run_12345"
mock_ti.xcom_pull.return_value = previous_job_run_id

# Mock the Glue client to return SUCCEEDED state for the previous job
mock_glue_client = mock.MagicMock()
glue.hook.conn = mock_glue_client
mock_glue_client.get_job_run.return_value = {
"JobRun": {
"JobRunState": "SUCCEEDED",
}
}

# Mock initialize_job to return a new job run ID
new_job_run_id = "new_run_67890"
mock_initialize_job.return_value = {
"JobRunState": "RUNNING",
"JobRunId": new_job_run_id,
}

# Execute the operator
glue.execute(mock_context)

# Verify that a new job_run_id was created
assert glue._job_run_id == new_job_run_id
# Verify that initialize_job was called
mock_initialize_job.assert_called_once()
# Verify that the new job_run_id was pushed to XCom
xcom_calls = [
call for call in mock_ti.xcom_push.call_args_list if call[1].get("key") == "glue_job_run_id"
]
assert len(xcom_calls) == 1, "Should push new glue_job_run_id"
assert xcom_calls[0][1]["value"] == new_job_run_id


class TestGlueDataQualityOperator:
RULE_SET_NAME = "TestRuleSet"
Expand Down