From 8e979ea9a27bea4369acb30533b51d485cd64352 Mon Sep 17 00:00:00 2001
From: Henry Chen
Date: Sun, 14 Dec 2025 02:56:09 +0800
Subject: [PATCH] feat: Add resume_glue_job_on_retry to GlueJobOperator

Add an opt-in ``resume_glue_job_on_retry`` flag (default ``False``).

When the flag is set, ``execute`` pulls the run id stored under the XCom
key ``glue_job_run_id`` for this task and asks Glue for that run's state.
If the run is RUNNING or STARTING the operator attaches to it instead of
starting a duplicate run. Any other state, or a failed state lookup
(logged as a warning), falls back to starting a new run.

A run in STOPPING state is deliberately not resumed: it has already been
asked to cancel and will not complete its work, so the retry must start a
fresh run.

Whenever a new run is started its id is pushed to XCom under
``glue_job_run_id`` so that a later try can find it.

---
 .../providers/amazon/aws/operators/glue.py    |  33 ++++--
 .../unit/amazon/aws/operators/test_glue.py    | 101 ++++++++++++++++++
 2 files changed, 127 insertions(+), 7 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
index 536121de67919..4c014d92db36c 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/glue.py
@@ -139,6 +139,7 @@ def __init__(
         job_poll_interval: int | float = 6,
         waiter_delay: int = 60,
         waiter_max_attempts: int = 75,
+        resume_glue_job_on_retry: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -168,6 +169,7 @@ def __init__(
         self.s3_script_location: str | None = None
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.resume_glue_job_on_retry = resume_glue_job_on_retry
 
     @property
     def _hook_parameters(self):
@@ -217,13 +219,30 @@ def execute(self, context: Context):
 
         :return: the current Glue job ID.
         """
-        self.log.info(
-            "Initializing AWS Glue Job: %s. Wait for completion: %s",
-            self.job_name,
-            self.wait_for_completion,
-        )
-        glue_job_run = self.hook.initialize_job(self.script_args, self.run_job_kwargs)
-        self._job_run_id = glue_job_run["JobRunId"]
+        previous_job_run_id = None
+        if self.resume_glue_job_on_retry:
+            ti = context["ti"]
+            previous_job_run_id = ti.xcom_pull(key="glue_job_run_id", task_ids=ti.task_id)
+        if previous_job_run_id:
+            try:
+                job_run = self.hook.conn.get_job_run(JobName=self.job_name, RunId=previous_job_run_id)
+                state = job_run.get("JobRun", {}).get("JobRunState")
+                self.log.info("Previous Glue job_run_id: %s, state: %s", previous_job_run_id, state)
+                if state in ("RUNNING", "STARTING"):  # not STOPPING: that run is being cancelled
+                    self._job_run_id = previous_job_run_id
+            except Exception:
+                self.log.warning("Failed to get previous Glue job run state", exc_info=True)
+
+        if not self._job_run_id:
+            self.log.info(
+                "Initializing AWS Glue Job: %s. Wait for completion: %s",
+                self.job_name,
+                self.wait_for_completion,
+            )
+            glue_job_run = self.hook.initialize_job(self.script_args, self.run_job_kwargs)
+            self._job_run_id = glue_job_run["JobRunId"]
+            context["ti"].xcom_push(key="glue_job_run_id", value=self._job_run_id)
+
         glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
             aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.hook.conn_partition),
             region_name=self.hook.conn_region_name,
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
index fedf55431a66f..d4e299f26a727 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_glue.py
@@ -432,6 +432,107 @@ def test_default_conn_passed_to_hook(self):
         )
         assert op.hook.aws_conn_id == DEFAULT_CONN
 
+    @mock.patch.object(GlueJobHook, "get_conn")
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    def test_check_previous_job_id_run_reuse_in_progress(self, mock_initialize_job, mock_get_conn):
+        """Test that when resume_glue_job_on_retry=True and previous job is in progress, it is reused."""
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            resume_glue_job_on_retry=True,
+            wait_for_completion=False,
+        )
+
+        # Mock the context and task instance
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti}
+
+        # Simulate previous job_run_id in XCom
+        previous_job_run_id = "previous_run_12345"
+        mock_ti.xcom_pull.return_value = previous_job_run_id
+
+        # Mock the Glue client to return RUNNING state for the previous job
+        mock_glue_client = mock.MagicMock()
+        glue.hook.conn = mock_glue_client
+        mock_glue_client.get_job_run.return_value = {
+            "JobRun": {
+                "JobRunState": "RUNNING",
+            }
+        }
+
+        # Execute the operator
+        glue.execute(mock_context)
+
+        # Verify that the previous job_run_id was reused
+        assert glue._job_run_id == previous_job_run_id
+        # Verify that initialize_job was NOT called
+        mock_initialize_job.assert_not_called()
+        # Verify that XCom push was not called for glue_job_run_id (since we reused the previous one)
+        # Note: xcom_push may be called for other purposes like glue_job_run_details
+        xcom_calls = [
+            call for call in mock_ti.xcom_push.call_args_list if call[1].get("key") == "glue_job_run_id"
+        ]
+        assert len(xcom_calls) == 0, "Should not push new glue_job_run_id when reusing previous one"
+
+    @mock.patch.object(GlueJobHook, "get_conn")
+    @mock.patch.object(GlueJobHook, "initialize_job")
+    def test_check_previous_job_id_run_new_on_finished(self, mock_initialize_job, mock_get_conn):
+        """Test that when previous job is finished, a new job is started and pushed to XCom."""
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            resume_glue_job_on_retry=True,
+            wait_for_completion=False,
+        )
+
+        # Mock the context and task instance
+        mock_ti = mock.MagicMock()
+        mock_context = {"ti": mock_ti}
+
+        # Simulate previous job_run_id in XCom
+        previous_job_run_id = "previous_run_12345"
+        mock_ti.xcom_pull.return_value = previous_job_run_id
+
+        # Mock the Glue client to return SUCCEEDED state for the previous job
+        mock_glue_client = mock.MagicMock()
+        glue.hook.conn = mock_glue_client
+        mock_glue_client.get_job_run.return_value = {
+            "JobRun": {
+                "JobRunState": "SUCCEEDED",
+            }
+        }
+
+        # Mock initialize_job to return a new job run ID
+        new_job_run_id = "new_run_67890"
+        mock_initialize_job.return_value = {
+            "JobRunState": "RUNNING",
+            "JobRunId": new_job_run_id,
+        }
+
+        # Execute the operator
+        glue.execute(mock_context)
+
+        # Verify that a new job_run_id was created
+        assert glue._job_run_id == new_job_run_id
+        # Verify that initialize_job was called
+        mock_initialize_job.assert_called_once()
+        # Verify that the new job_run_id was pushed to XCom
+        xcom_calls = [
+            call for call in mock_ti.xcom_push.call_args_list if call[1].get("key") == "glue_job_run_id"
+        ]
+        assert len(xcom_calls) == 1, "Should push new glue_job_run_id"
+        assert xcom_calls[0][1]["value"] == new_job_run_id
+
 
 class TestGlueDataQualityOperator:
     RULE_SET_NAME = "TestRuleSet"