From 741dafb962427e7488017d25ddf537c8577fb28c Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sat, 21 Jun 2025 09:47:39 +0900 Subject: [PATCH 01/19] add: DataprocSubmitJobTrigger --- .../google/cloud/operators/dataproc.py | 34 +++++ .../google/cloud/triggers/dataproc.py | 120 ++++++++++++++++++ 2 files changed, 154 insertions(+) diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py index 4b47aef0e2d62..9fbb0c0c9aa2e 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py @@ -74,6 +74,20 @@ from airflow.utils.context import Context +try: + from airflow.triggers.base import StartTriggerArgs +except ImportError: + # TODO: Remove this when min airflow version is 2.10.0 for standard provider + @dataclass + class StartTriggerArgs: # type: ignore[no-redef] + """Arguments required for start task execution from triggerer.""" + + trigger_cls: str + next_method: str + trigger_kwargs: dict[str, Any] | None = None + next_kwargs: dict[str, Any] | None = None + timeout: timedelta | None = None + class PreemptibilityType(Enum): """Contains possible Type values of Preemptibility applicable for every secondary worker of Cluster.""" @@ -1831,6 +1845,15 @@ class DataprocSubmitJobOperator(GoogleCloudBaseOperator): operator_extra_links = (DataprocJobLink(),) + start_trigger_args = StartTriggerArgs( + trigger_cls="airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger", + trigger_kwargs={}, + next_method="execute_complete", + next_kwargs=None, + timeout=None, + ) + start_from_trigger = False + def __init__( self, *, @@ -1845,6 +1868,7 @@ def __init__( impersonation_chain: str | Sequence[str] | None = None, asynchronous: bool = False, deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + start_from_trigger: bool = False, polling_interval_seconds: int = 10, cancel_on_kill: bool = True, wait_timeout: int | None = None, @@ -1877,6 +1901,16 @@ def __init__( self.wait_timeout = wait_timeout self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info self.openlineage_inject_transport_info = openlineage_inject_transport_info + self.start_trigger_args.trigger_kwargs = { + "project_id": self.project_id, + "region": self.region, + "job": self.job, + "request_id": self.request_id, + "retry": self.retry, + "timeout": self.timeout, + "metadata": self.metadata, + } + self.start_from_trigger = start_from_trigger def execute(self, context: Context): self.log.info("Submitting job") diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index 80afd381cc389..ec5ae5b4694ba 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -26,6 +26,8 @@ from typing import TYPE_CHECKING, Any from google.api_core.exceptions import NotFound +from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault +from google.api_core.retry import Retry from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus from airflow.exceptions import AirflowException @@ -187,6 +189,124 @@ async def run(self): raise e +class DataprocSubmitJobTrigger(DataprocBaseTrigger): + """DataprocSubmitJobTrigger runs on the trigger worker to perform Build 
operation.""" + + def __init__( + self, + region: str, + job: dict, + project_id: str = PROVIDE_PROJECT_ID, + request_id: str | None = None, + retry: Retry | _MethodDefault = DEFAULT, + timeout: float | None = None, + metadata: Sequence[tuple[str, str]] = (), + **kwargs, + ): + super().__init__(**kwargs) + self.project_id = project_id + self.region = region + self.job = job + self.request_id = request_id + self.retry = retry + self.timeout = timeout + self.metadata = metadata + + def serialize(self): + return ( + "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger", + { + "project_id": self.project_id, + "region": self.region, + "job": self.job, + "request_id": self.request_id, + "retry": self.retry, + "timeout": self.timeout, + "metadata": self.metadata, + "gcp_conn_id": self.gcp_conn_id, + "impersonation_chain": self.impersonation_chain, + "polling_interval_seconds": self.polling_interval_seconds, + "cancel_on_kill": self.cancel_on_kill, + }, + ) + + @provide_session + def get_task_instance(self, session: Session) -> TaskInstance: + """ + Get the task instance for the current task. + + :param session: Sqlalchemy session + """ + query = session.query(TaskInstance).filter( + TaskInstance.dag_id == self.task_instance.dag_id, + TaskInstance.task_id == self.task_instance.task_id, + TaskInstance.run_id == self.task_instance.run_id, + TaskInstance.map_index == self.task_instance.map_index, + ) + task_instance = query.one_or_none() + if task_instance is None: + raise AirflowException( + "TaskInstance with dag_id: %s,task_id: %s, run_id: %s and map_index: %s is not found", + self.task_instance.dag_id, + self.task_instance.task_id, + self.task_instance.run_id, + self.task_instance.map_index, + ) + return task_instance + + def safe_to_cancel(self) -> bool: + """ + Whether it is safe to cancel the external job which is being executed by this trigger. + + This is to avoid the case that `asyncio.CancelledError` is called because the trigger itself is stopped. + Because in those cases, we should NOT cancel the external job. + """ + # Database query is needed to get the latest state of the task instance. + task_instance = self.get_task_instance() # type: ignore[call-arg] + return task_instance.state != TaskInstanceState.DEFERRED + + async def run(self): + try: + # Create a new Dataproc job + job = self.get_sync_hook().submit_job( + project_id=self.project_id, + region=self.region, + job=self.job, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + self.job_id = job.reference.job_id + while True: + job = await self.get_async_hook().get_job( + project_id=self.project_id, region=self.region, job_id=self.job_id + ) + state = job.status.state + self.log.info("Dataproc job: %s is in state: %s", self.job_id, state) + if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR): + break + await asyncio.sleep(self.polling_interval_seconds) + yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job}) + except asyncio.CancelledError: + self.log.info("Task got cancelled.") + try: + if self.job_id and self.cancel_on_kill and self.safe_to_cancel(): + self.log.info( + "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not" + " in deferred state." 
+ ) + self.log.info("Cancelling the job: %s", self.job_id) + self.get_sync_hook().cancel_job( + job_id=self.job_id, project_id=self.project_id, region=self.region + ) + self.log.info("Job: %s is cancelled", self.job_id) + yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING}) + except Exception as e: + self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e)) + raise e + + class DataprocClusterTrigger(DataprocBaseTrigger): """ DataprocClusterTrigger run on the trigger worker to perform create Build operation. From fcc4d930617dfa5b9fd233dd0a6f9693d5437d83 Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 22 Jun 2025 10:43:29 +0900 Subject: [PATCH 02/19] add: unit test --- .../google/cloud/triggers/dataproc.py | 25 +-- .../google/cloud/triggers/test_dataproc.py | 211 ++++++++++++++++++ 2 files changed, 223 insertions(+), 13 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index ec5ae5b4694ba..bdd524e7556b9 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -194,9 +194,7 @@ class DataprocSubmitJobTrigger(DataprocBaseTrigger): def __init__( self, - region: str, job: dict, - project_id: str = PROVIDE_PROJECT_ID, request_id: str | None = None, retry: Retry | _MethodDefault = DEFAULT, timeout: float | None = None, @@ -204,8 +202,6 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.project_id = project_id - self.region = region self.job = job self.request_id = request_id self.retry = retry @@ -265,18 +261,21 @@ def safe_to_cancel(self) -> bool: task_instance = self.get_task_instance() # type: ignore[call-arg] return task_instance.state != TaskInstanceState.DEFERRED + def submit_job(self): + return self.get_sync_hook().submit_job( + project_id=self.project_id, + region=self.region, + job=self.job, + request_id=self.request_id, + retry=self.retry, + timeout=self.timeout, + metadata=self.metadata, + ) + async def run(self): try: # Create a new Dataproc job - job = self.get_sync_hook().submit_job( - project_id=self.project_id, - region=self.region, - job=self.job, - request_id=self.request_id, - retry=self.retry, - timeout=self.timeout, - metadata=self.metadata, - ) + job = self.submit_job() self.job_id = job.reference.job_id while True: job = await self.get_async_hook().get_job( diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py index 385572e8493ad..6a61ab0333647 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py @@ -23,6 +23,7 @@ from unittest import mock import pytest +from google.api_core.retry import Retry from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus from google.protobuf.any_pb2 import Any from google.rpc.status_pb2 import Status @@ -31,6 +32,7 @@ DataprocBatchTrigger, DataprocClusterTrigger, DataprocOperationTrigger, + DataprocSubmitJobTrigger, DataprocSubmitTrigger, ) from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType @@ -51,6 +53,8 @@ TEST_GCP_CONN_ID = "google_cloud_default" TEST_OPERATION_NAME = "name" TEST_JOB_ID = "test-job-id" +RETRY = mock.MagicMock(Retry) +METADATA = [("key", "value")] @pytest.fixture @@ -130,6 +134,22 @@ 
def submit_trigger(): ) +@pytest.fixture +def submit_job_trigger(): + return DataprocSubmitJobTrigger( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + job={}, + request_id=None, + retry=RETRY, + timeout=None, + metadata=METADATA, + gcp_conn_id=TEST_GCP_CONN_ID, + polling_interval_seconds=TEST_POLL_INTERVAL, + cancel_on_kill=True + ) + + @pytest.fixture def async_get_batch(): def func(**kwargs): @@ -632,3 +652,194 @@ async def test_submit_trigger_run_cancelled( # Clean up the generator await async_gen.aclose() + + +@pytest.mark.db_test +class TestDataprocSubmitJobTrigger: + def test_submit_job_trigger_serialization(self, submit_job_trigger): + """Test that the DataprocSubmitJobTrigger serializes its configuration correctly.""" + classpath, kwargs = submit_job_trigger.serialize() + assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger" + assert kwargs == { + "project_id": TEST_PROJECT_ID, + "region": TEST_REGION, + "job": {}, + "request_id": None, + "retry": RETRY, + "timeout": None, + "metadata": METADATA, + "gcp_conn_id": TEST_GCP_CONN_ID, + "polling_interval_seconds": TEST_POLL_INTERVAL, + "cancel_on_kill": True, + "impersonation_chain": None, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") + async def test_submit_job_trigger_run_success(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): + """Test the trigger correctly submits a job and handles job completion.""" + # Mock sync hook for job submission + mock_sync_hook = mock_get_sync_hook.return_value + mock_job = mock.MagicMock() + mock_job.reference.job_id = TEST_JOB_ID + mock_sync_hook.submit_job.return_value = mock_job + + # Mock async hook for job polling + mock_async_hook = mock_get_async_hook.return_value + mock_async_hook.get_job = mock.AsyncMock( + return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE)) + ) + + async_gen = submit_job_trigger.run() + event = await async_gen.asend(None) + + # Verify job was submitted + mock_sync_hook.submit_job.assert_called_once_with( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + job={}, + request_id=None, + retry=RETRY, + timeout=None, + metadata=METADATA, + ) + + # Verify job was polled + mock_async_hook.get_job.assert_called_once_with( + project_id=TEST_PROJECT_ID, + region=TEST_REGION, + job_id=TEST_JOB_ID, + ) + + expected_event = TriggerEvent( + {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_async_hook.get_job.return_value} + ) + assert event.payload == expected_event.payload + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") + async def test_submit_job_trigger_run_error(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): + """Test the trigger correctly handles a job error.""" + # Mock sync hook for job submission + mock_sync_hook = mock_get_sync_hook.return_value + mock_job = mock.MagicMock() + mock_job.reference.job_id = TEST_JOB_ID + mock_sync_hook.submit_job.return_value = mock_job + + # Mock async hook for job polling + mock_async_hook = mock_get_async_hook.return_value + mock_async_hook.get_job = mock.AsyncMock( + 
return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.ERROR)) + ) + + async_gen = submit_job_trigger.run() + event = await async_gen.asend(None) + + expected_event = TriggerEvent( + {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.ERROR, "job": mock_async_hook.get_job.return_value} + ) + assert event.payload == expected_event.payload + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") + async def test_submit_job_trigger_run_cancelled(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): + """Test the trigger correctly handles a job cancellation.""" + # Mock sync hook for job submission + mock_sync_hook = mock_get_sync_hook.return_value + mock_job = mock.MagicMock() + mock_job.reference.job_id = TEST_JOB_ID + mock_sync_hook.submit_job.return_value = mock_job + + # Mock async hook for job polling + mock_async_hook = mock_get_async_hook.return_value + mock_async_hook.get_job = mock.AsyncMock( + return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.CANCELLED)) + ) + + async_gen = submit_job_trigger.run() + event = await async_gen.asend(None) + + expected_event = TriggerEvent( + {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.CANCELLED, "job": mock_async_hook.get_job.return_value} + ) + assert event.payload == expected_event.payload + + @pytest.mark.asyncio + @pytest.mark.parametrize("is_safe_to_cancel", [True, False]) + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.safe_to_cancel") + async def test_submit_job_trigger_run_cancelled_with_exception( + self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger, is_safe_to_cancel + ): + """Test the trigger correctly handles an asyncio.CancelledError during job polling.""" + mock_safe_to_cancel.return_value = is_safe_to_cancel + + # Mock sync hook for job submission + mock_sync_hook = mock_get_sync_hook.return_value + mock_job = mock.MagicMock() + mock_job.reference.job_id = TEST_JOB_ID + mock_sync_hook.submit_job.return_value = mock_job + mock_sync_hook.cancel_job = mock.MagicMock() + + # Mock async hook for job polling + mock_async_hook = mock_get_async_hook.return_value + mock_async_hook.get_job.side_effect = asyncio.CancelledError + + async_gen = submit_job_trigger.run() + + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await async_gen.asend(None) + # Should raise StopAsyncIteration if no more items to yield + await async_gen.asend(None) + + # Check if cancel_job was correctly called + if submit_job_trigger.cancel_on_kill and is_safe_to_cancel: + mock_sync_hook.cancel_job.assert_called_once_with( + job_id=submit_job_trigger.job_id, + project_id=submit_job_trigger.project_id, + region=submit_job_trigger.region, + ) + else: + mock_sync_hook.cancel_job.assert_not_called() + + # Clean up the generator + await async_gen.aclose() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") + @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") + async def 
test_submit_job_trigger_polling_loop(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger, caplog): + """Test the trigger correctly polls for job status until completion.""" + # Mock sync hook for job submission + mock_sync_hook = mock_get_sync_hook.return_value + mock_job = mock.MagicMock() + mock_job.reference.job_id = TEST_JOB_ID + mock_sync_hook.submit_job.return_value = mock_job + + # Mock async hook for job polling + mock_async_hook = mock_get_async_hook.return_value + + # Create mock objects for the side effect + mock_running_job = mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.RUNNING)) + mock_done_job = mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE)) + + # First call returns RUNNING, second call returns DONE + mock_async_hook.get_job = mock.AsyncMock() + mock_async_hook.get_job.side_effect = [mock_running_job, mock_done_job] + + caplog.set_level(logging.INFO) + + async_gen = submit_job_trigger.run() + event = await async_gen.asend(None) + + # Verify job was polled multiple times + assert mock_async_hook.get_job.call_count == 2 + + expected_event = TriggerEvent( + {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_done_job} + ) + assert event.payload == expected_event.payload From 01d686aaf8ab1e1c3b0f96cd8b42e9671a140e1e Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 22 Jun 2025 11:52:08 +0900 Subject: [PATCH 03/19] add: mock_log --- .../google/cloud/triggers/test_dataproc.py | 46 +++++++++++-------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py index 6a61ab0333647..848fd277c617c 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py @@ -18,7 +18,6 @@ import asyncio import contextlib -import logging from asyncio import CancelledError, Future, sleep from unittest import mock @@ -220,8 +219,9 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully( return_value=asyncio.Future(), ) @mock.patch("google.auth.default") + @mock.patch.object(DataprocClusterTrigger, "log") async def test_async_cluster_trigger_run_returns_error_event( - self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog + self, mock_log, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster ): mock_credentials = mock.MagicMock() mock_credentials.universe_domain = "googleapis.com" @@ -238,19 +238,22 @@ async def test_async_cluster_trigger_run_returns_error_event( status=ClusterStatus(state=ClusterStatus.State.ERROR), ) - caplog.set_level(logging.INFO) - trigger_event = None async for event in cluster_trigger.run(): trigger_event = event + # Verify logging was called for cluster deletion + mock_log.info.assert_any_call("Deleting cluster %s.", TEST_CLUSTER_NAME) + mock_log.info.assert_any_call("Cluster %s has been deleted.", TEST_CLUSTER_NAME) + assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster") + @mock.patch.object(DataprocClusterTrigger, "log") async def test_cluster_run_loop_is_still_running( - self, mock_hook, cluster_trigger, caplog, async_get_cluster + self, mock_log, mock_hook, cluster_trigger, 
async_get_cluster ): mock_hook.return_value = async_get_cluster( project_id=TEST_PROJECT_ID, @@ -259,20 +262,20 @@ async def test_cluster_run_loop_is_still_running( status=ClusterStatus(state=ClusterStatus.State.CREATING), ) - caplog.set_level(logging.INFO) - task = asyncio.create_task(cluster_trigger.run().__anext__()) await asyncio.sleep(0.5) assert not task.done() - assert f"Current state is: {ClusterStatus.State.CREATING}." - assert f"Sleeping for {TEST_POLL_INTERVAL} seconds." + # Verify logging was called for state updates + mock_log.info.assert_any_call("Current state is %s", ClusterStatus.State.CREATING) + mock_log.info.assert_any_call("Sleeping for %s seconds.", TEST_POLL_INTERVAL) @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook") + @mock.patch.object(DataprocClusterTrigger, "log") async def test_cluster_trigger_cancellation_handling( - self, mock_get_sync_hook, mock_get_async_hook, caplog + self, mock_log, mock_get_sync_hook, mock_get_async_hook ): cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING)) mock_get_async_hook.return_value.get_cluster.return_value = asyncio.Future() @@ -306,8 +309,9 @@ async def test_cluster_trigger_cancellation_handling( cluster_name=cluster_trigger.cluster_name, project_id=cluster_trigger.project_id, ) - assert "Deleting cluster" in caplog.text - assert "Deleted cluster" in caplog.text + # Verify logging was called for cluster deletion + mock_log.info.assert_any_call("Deleting cluster %s.", cluster_trigger.cluster_name) + mock_log.info.assert_any_call("Deleted cluster %s during cancellation.", cluster_trigger.cluster_name) else: mock_delete_cluster.assert_not_called() except Exception as e: @@ -469,19 +473,19 @@ async def test_create_batch_run_returns_cancelled_event(self, mock_hook, batch_t @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch") + @mock.patch.object(DataprocBatchTrigger, "log") async def test_create_batch_run_loop_is_still_running( - self, mock_hook, batch_trigger, caplog, async_get_batch + self, mock_log, mock_hook, batch_trigger, async_get_batch ): mock_hook.return_value = async_get_batch(state=Batch.State.RUNNING) - caplog.set_level(logging.INFO) - task = asyncio.create_task(batch_trigger.run().__anext__()) await asyncio.sleep(0.5) assert not task.done() - assert f"Current state is: {Batch.State.RUNNING}" - assert f"Sleeping for {TEST_POLL_INTERVAL} seconds." 
+ # Verify logging was called for state updates + mock_log.info.assert_any_call("Current state is %s", Batch.State.RUNNING) + mock_log.info.assert_any_call("Sleeping for %s seconds.", TEST_POLL_INTERVAL) class TestDataprocOperationTrigger: @@ -812,7 +816,8 @@ async def test_submit_job_trigger_run_cancelled_with_exception( @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") - async def test_submit_job_trigger_polling_loop(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger, caplog): + @mock.patch.object(DataprocSubmitJobTrigger, "log") + async def test_submit_job_trigger_polling_loop(self, mock_log, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): """Test the trigger correctly polls for job status until completion.""" # Mock sync hook for job submission mock_sync_hook = mock_get_sync_hook.return_value @@ -831,14 +836,15 @@ async def test_submit_job_trigger_polling_loop(self, mock_get_sync_hook, mock_ge mock_async_hook.get_job = mock.AsyncMock() mock_async_hook.get_job.side_effect = [mock_running_job, mock_done_job] - caplog.set_level(logging.INFO) - async_gen = submit_job_trigger.run() event = await async_gen.asend(None) # Verify job was polled multiple times assert mock_async_hook.get_job.call_count == 2 + # Verify logging was called for job status updates + mock_log.info.assert_called_with("Dataproc job: %s is in state: %s", TEST_JOB_ID, JobStatus.State.DONE) + expected_event = TriggerEvent( {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_done_job} ) From e2fbfd4692ec6e01f6c85aab654c9b6d96d51222 Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 22 Jun 2025 12:34:03 +0900 Subject: [PATCH 04/19] fix: TestDataprocCreateBatchOperator test_execute_openlineage_all_info_injection --- .../tests/unit/google/cloud/operators/test_dataproc.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index 00923cb590a1f..dbe9c66da6663 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -3453,8 +3453,14 @@ def test_execute_openlineage_all_info_injection( mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) ) + expected_labels = { + "airflow-dag-id": "adhoc_airflow", + "airflow-task-id": "task-id", + } + expected_batch = { **BATCH, + "labels": expected_labels, "runtime_config": { "properties": { **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, From a6de84a5c3a988ff0a5882c15cdb469ba4e65846 Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 22 Jun 2025 16:50:51 +0900 Subject: [PATCH 05/19] fix: test_dataproc pre-commit --- .../google/cloud/triggers/test_dataproc.py | 51 +++++++++++++++---- 1 file changed, 40 insertions(+), 11 deletions(-) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py index 848fd277c617c..a46cdcd857fcc 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py @@ -145,7 +145,7 @@ def submit_job_trigger(): metadata=METADATA, 
gcp_conn_id=TEST_GCP_CONN_ID, polling_interval_seconds=TEST_POLL_INTERVAL, - cancel_on_kill=True + cancel_on_kill=True, ) @@ -311,7 +311,9 @@ async def test_cluster_trigger_cancellation_handling( ) # Verify logging was called for cluster deletion mock_log.info.assert_any_call("Deleting cluster %s.", cluster_trigger.cluster_name) - mock_log.info.assert_any_call("Deleted cluster %s during cancellation.", cluster_trigger.cluster_name) + mock_log.info.assert_any_call( + "Deleted cluster %s during cancellation.", cluster_trigger.cluster_name + ) else: mock_delete_cluster.assert_not_called() except Exception as e: @@ -681,7 +683,9 @@ def test_submit_job_trigger_serialization(self, submit_job_trigger): @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") - async def test_submit_job_trigger_run_success(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): + async def test_submit_job_trigger_run_success( + self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger + ): """Test the trigger correctly submits a job and handles job completion.""" # Mock sync hook for job submission mock_sync_hook = mock_get_sync_hook.return_value @@ -717,14 +721,20 @@ async def test_submit_job_trigger_run_success(self, mock_get_sync_hook, mock_get ) expected_event = TriggerEvent( - {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_async_hook.get_job.return_value} + { + "job_id": TEST_JOB_ID, + "job_state": JobStatus.State.DONE, + "job": mock_async_hook.get_job.return_value, + } ) assert event.payload == expected_event.payload @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") - async def test_submit_job_trigger_run_error(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): + async def test_submit_job_trigger_run_error( + self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger + ): """Test the trigger correctly handles a job error.""" # Mock sync hook for job submission mock_sync_hook = mock_get_sync_hook.return_value @@ -742,14 +752,20 @@ async def test_submit_job_trigger_run_error(self, mock_get_sync_hook, mock_get_a event = await async_gen.asend(None) expected_event = TriggerEvent( - {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.ERROR, "job": mock_async_hook.get_job.return_value} + { + "job_id": TEST_JOB_ID, + "job_state": JobStatus.State.ERROR, + "job": mock_async_hook.get_job.return_value, + } ) assert event.payload == expected_event.payload @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") - async def test_submit_job_trigger_run_cancelled(self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): + async def test_submit_job_trigger_run_cancelled( + self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger + ): """Test the trigger correctly handles a job cancellation.""" # Mock sync hook for job submission mock_sync_hook = mock_get_sync_hook.return_value @@ -767,7 +783,11 @@ async def test_submit_job_trigger_run_cancelled(self, mock_get_sync_hook, mock_g event = await async_gen.asend(None) expected_event = 
TriggerEvent( - {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.CANCELLED, "job": mock_async_hook.get_job.return_value} + { + "job_id": TEST_JOB_ID, + "job_state": JobStatus.State.CANCELLED, + "job": mock_async_hook.get_job.return_value, + } ) assert event.payload == expected_event.payload @@ -777,7 +797,12 @@ async def test_submit_job_trigger_run_cancelled(self, mock_get_sync_hook, mock_g @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.safe_to_cancel") async def test_submit_job_trigger_run_cancelled_with_exception( - self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger, is_safe_to_cancel + self, + mock_safe_to_cancel, + mock_get_sync_hook, + mock_get_async_hook, + submit_job_trigger, + is_safe_to_cancel, ): """Test the trigger correctly handles an asyncio.CancelledError during job polling.""" mock_safe_to_cancel.return_value = is_safe_to_cancel @@ -817,7 +842,9 @@ async def test_submit_job_trigger_run_cancelled_with_exception( @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") @mock.patch.object(DataprocSubmitJobTrigger, "log") - async def test_submit_job_trigger_polling_loop(self, mock_log, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger): + async def test_submit_job_trigger_polling_loop( + self, mock_log, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger + ): """Test the trigger correctly polls for job status until completion.""" # Mock sync hook for job submission mock_sync_hook = mock_get_sync_hook.return_value @@ -843,7 +870,9 @@ async def test_submit_job_trigger_polling_loop(self, mock_log, mock_get_sync_hoo assert mock_async_hook.get_job.call_count == 2 # Verify logging was called for job status updates - mock_log.info.assert_called_with("Dataproc job: %s is in state: %s", TEST_JOB_ID, JobStatus.State.DONE) + mock_log.info.assert_called_with( + "Dataproc job: %s is in state: %s", TEST_JOB_ID, JobStatus.State.DONE + ) expected_event = TriggerEvent( {"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_done_job} From 826f4fa09d3d80c6fee3b1c7492ff75f1df0f557 Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 22 Jun 2025 17:45:30 +0900 Subject: [PATCH 06/19] fix: Move Retry into type-checking block --- .../src/airflow/providers/google/cloud/triggers/dataproc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index 7d13996ef5995..63bb5d5fd5ccd 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -28,7 +28,6 @@ from asgiref.sync import sync_to_async from google.api_core.exceptions import NotFound from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault -from google.api_core.retry import Retry from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus from airflow.exceptions import AirflowException @@ -42,6 +41,7 @@ from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: + from google.api_core.retry import Retry from sqlalchemy.orm.session import Session From 961eb3c93a1be3bfb3a77225062f44c5ce5f86fc Mon 
Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 29 Jun 2025 15:44:27 +0900 Subject: [PATCH 07/19] fix: add task_state for af3 --- .../google/cloud/triggers/dataproc.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index 63bb5d5fd5ccd..192b8984c31cc 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -277,7 +277,28 @@ def get_task_instance(self, session: Session) -> TaskInstance: ) return task_instance - def safe_to_cancel(self) -> bool: + async def get_task_state(self): + from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance + + task_states_response = await sync_to_async(RuntimeTaskInstance.get_task_states)( + dag_id=self.task_instance.dag_id, + task_ids=[self.task_instance.task_id], + run_ids=[self.task_instance.run_id], + map_index=self.task_instance.map_index, + ) + try: + task_state = task_states_response[self.task_instance.run_id][self.task_instance.task_id] + except Exception: + raise AirflowException( + "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and map_index: %s is not found", + self.task_instance.dag_id, + self.task_instance.task_id, + self.task_instance.run_id, + self.task_instance.map_index, + ) + return task_state + + async def safe_to_cancel(self) -> bool: """ Whether it is safe to cancel the external job which is being executed by this trigger. @@ -285,8 +306,13 @@ def safe_to_cancel(self) -> bool: Because in those cases, we should NOT cancel the external job. """ # Database query is needed to get the latest state of the task instance. - task_instance = self.get_task_instance() # type: ignore[call-arg] - return task_instance.state != TaskInstanceState.DEFERRED + if AIRFLOW_V_3_0_PLUS: + task_state = await self.get_task_state() + else: + # Database query is needed to get the latest state of the task instance. + task_instance = self.get_task_instance() # type: ignore[call-arg] + task_state = task_instance.state + return task_state != TaskInstanceState.DEFERRED def submit_job(self): return self.get_sync_hook().submit_job( From 1a2fdc31ddc8d0622b9a20bce51be01c4125761f Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 29 Jun 2025 15:55:34 +0900 Subject: [PATCH 08/19] add: start_from_trigger system test --- .../example_dataproc_start_from_trigger.py | 98 +++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py diff --git a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py new file mode 100644 index 0000000000000..37bdd7d7a1578 --- /dev/null +++ b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Example Airflow DAG for DataprocSubmitJobOperator with start_from_trigger. +""" + +from __future__ import annotations + +import os +from datetime import datetime + +from airflow import DAG +from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator + +from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") +DAG_ID = "dataproc_spark" +PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID + +CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-") +CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-") +CLUSTER_NAME = CLUSTER_NAME_BASE if len(CLUSTER_NAME_FULL) >= 33 else CLUSTER_NAME_FULL + +BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" +RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" +JOB_FILE = "dataproc-pyspark-job-pi.py" +GCS_JOB_FILE = f"gs://{BUCKET_NAME}/dataproc/{JOB_FILE}" +REGION = "us-central1" + +# Cluster definition +CLUSTER_CONFIG = { + "master_config": { + "num_instances": 1, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, + "worker_config": { + "num_instances": 2, + "machine_type_uri": "n1-standard-4", + "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 32}, + }, +} + +# Jobs definitions +# Define a sample PySpark job +PYSPARK_JOB = { + "reference": {"project_id": PROJECT_ID}, + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": {"main_python_file_uri": GCS_JOB_FILE}, +} + +# Create DAG +with DAG( + dag_id=DAG_ID, + schedule="@once", + start_date=datetime(2023, 1, 1), + catchup=False, + tags=["dataproc", "start_from_trigger"], +) as dag: + # Task to test start_from_trigger=True + submit_job_with_trigger = DataprocSubmitJobOperator( + task_id="submit_job_with_trigger", + job=PYSPARK_JOB, + region=REGION, + project_id=PROJECT_ID, + start_from_trigger=True, + ) + + # Define task dependencies + submit_job_with_trigger + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "teardown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) From ed2fd8000c8244498dd9a083674728849ee6f2e6 Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 29 Jun 2025 16:40:41 +0900 Subject: [PATCH 09/19] fix: system-test --- .../example_dataproc_start_from_trigger.py | 59 ++++++++++++++----- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py index 37bdd7d7a1578..d432944b16020 100644 --- a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py +++ 
b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py @@ -24,24 +24,27 @@ import os from datetime import datetime +from google.api_core.retry import Retry + from airflow import DAG -from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator +from airflow.providers.google.cloud.operators.dataproc import ( + DataprocCreateClusterOperator, + DataprocDeleteClusterOperator, + DataprocSubmitJobOperator, +) +from airflow.utils.trigger_rule import TriggerRule from system.google import DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default") -DAG_ID = "dataproc_spark" +DAG_ID = "dataproc_start_from_trigger" PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT") or DEFAULT_GCP_SYSTEM_TEST_PROJECT_ID CLUSTER_NAME_BASE = f"cluster-{DAG_ID}".replace("_", "-") CLUSTER_NAME_FULL = CLUSTER_NAME_BASE + f"-{ENV_ID}".replace("_", "-") CLUSTER_NAME = CLUSTER_NAME_BASE if len(CLUSTER_NAME_FULL) >= 33 else CLUSTER_NAME_FULL -BUCKET_NAME = f"bucket_{DAG_ID}_{ENV_ID}" -RESOURCE_DATA_BUCKET = "airflow-system-tests-resources" -JOB_FILE = "dataproc-pyspark-job-pi.py" -GCS_JOB_FILE = f"gs://{BUCKET_NAME}/dataproc/{JOB_FILE}" -REGION = "us-central1" +REGION = "europe-west1" # Cluster definition CLUSTER_CONFIG = { @@ -58,11 +61,13 @@ } # Jobs definitions -# Define a sample PySpark job -PYSPARK_JOB = { +SPARK_JOB = { "reference": {"project_id": PROJECT_ID}, "placement": {"cluster_name": CLUSTER_NAME}, - "pyspark_job": {"main_python_file_uri": GCS_JOB_FILE}, + "spark_job": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, } # Create DAG @@ -73,17 +78,41 @@ catchup=False, tags=["dataproc", "start_from_trigger"], ) as dag: - # Task to test start_from_trigger=True - submit_job_with_trigger = DataprocSubmitJobOperator( - task_id="submit_job_with_trigger", - job=PYSPARK_JOB, + create_cluster = DataprocCreateClusterOperator( + task_id="create_cluster", + project_id=PROJECT_ID, + cluster_config=CLUSTER_CONFIG, + region=REGION, + cluster_name=CLUSTER_NAME, + retry=Retry(maximum=100.0, initial=10.0, multiplier=1.0), + num_retries_if_resource_is_not_ready=3, + ) + + spark_job_with_start_from_trigger = DataprocSubmitJobOperator( + task_id="spark_job_with_start_from_trigger", + job=SPARK_JOB, region=REGION, project_id=PROJECT_ID, start_from_trigger=True, ) + delete_cluster = DataprocDeleteClusterOperator( + task_id="delete_cluster", + project_id=PROJECT_ID, + region=REGION, + cluster_name=CLUSTER_NAME, + trigger_rule=TriggerRule.ALL_DONE, + ) + # Define task dependencies - submit_job_with_trigger + ( + # TEST SETUP + create_cluster + # TEST BODY + >> spark_job_with_start_from_trigger + # TEST TEARDOWN + >> delete_cluster + ) from tests_common.test_utils.watcher import watcher From f2b84e2e1825b03fecd37619cfa2ef6ff7a1c572 Mon Sep 17 00:00:00 2001 From: geonwoo Date: Sun, 29 Jun 2025 21:53:15 +0900 Subject: [PATCH 10/19] fix: add await to safe_to_cancel --- .../google/cloud/triggers/dataproc.py | 2 +- .../google/cloud/triggers/test_dataproc.py | 56 +++++-------------- 2 files changed, 15 insertions(+), 43 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index 192b8984c31cc..3cb95d6158e33 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ 
b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -343,7 +343,7 @@ async def run(self): except asyncio.CancelledError: self.log.info("Task got cancelled.") try: - if self.job_id and self.cancel_on_kill and self.safe_to_cancel(): + if self.job_id and self.cancel_on_kill and await self.safe_to_cancel(): self.log.info( "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not" " in deferred state." diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py index a46cdcd857fcc..0f378ade5e0f3 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py @@ -760,43 +760,12 @@ async def test_submit_job_trigger_run_error( ) assert event.payload == expected_event.payload - @pytest.mark.asyncio - @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") - @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") - async def test_submit_job_trigger_run_cancelled( - self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger - ): - """Test the trigger correctly handles a job cancellation.""" - # Mock sync hook for job submission - mock_sync_hook = mock_get_sync_hook.return_value - mock_job = mock.MagicMock() - mock_job.reference.job_id = TEST_JOB_ID - mock_sync_hook.submit_job.return_value = mock_job - - # Mock async hook for job polling - mock_async_hook = mock_get_async_hook.return_value - mock_async_hook.get_job = mock.AsyncMock( - return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.CANCELLED)) - ) - - async_gen = submit_job_trigger.run() - event = await async_gen.asend(None) - - expected_event = TriggerEvent( - { - "job_id": TEST_JOB_ID, - "job_state": JobStatus.State.CANCELLED, - "job": mock_async_hook.get_job.return_value, - } - ) - assert event.payload == expected_event.payload - @pytest.mark.asyncio @pytest.mark.parametrize("is_safe_to_cancel", [True, False]) @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.safe_to_cancel") - async def test_submit_job_trigger_run_cancelled_with_exception( + async def test_submit_job_trigger_run_cancelled( self, mock_safe_to_cancel, mock_get_sync_hook, @@ -804,26 +773,29 @@ async def test_submit_job_trigger_run_cancelled_with_exception( submit_job_trigger, is_safe_to_cancel, ): - """Test the trigger correctly handles an asyncio.CancelledError during job polling.""" + """Test the trigger correctly handles an asyncio.CancelledError.""" mock_safe_to_cancel.return_value = is_safe_to_cancel + mock_async_hook = mock_get_async_hook.return_value + mock_async_hook.get_job.side_effect = asyncio.CancelledError - # Mock sync hook for job submission mock_sync_hook = mock_get_sync_hook.return_value - mock_job = mock.MagicMock() - mock_job.reference.job_id = TEST_JOB_ID - mock_sync_hook.submit_job.return_value = mock_job mock_sync_hook.cancel_job = mock.MagicMock() - # Mock async hook for job polling - mock_async_hook = mock_get_async_hook.return_value - mock_async_hook.get_job.side_effect = asyncio.CancelledError - async_gen = submit_job_trigger.run() - with 
contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + try: await async_gen.asend(None) # Should raise StopAsyncIteration if no more items to yield await async_gen.asend(None) + except asyncio.CancelledError: + # Handle the cancellation as expected + pass + except StopAsyncIteration: + # The generator should be properly closed after handling the cancellation + pass + except Exception as e: + # Catch any other exceptions that should not occur + pytest.fail(f"Unexpected exception raised: {e}") # Check if cancel_job was correctly called if submit_job_trigger.cancel_on_kill and is_safe_to_cancel: From d5da7e81b652f93de8b6584096064fe61bbd057b Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Mon, 30 Jun 2025 23:52:56 +0900 Subject: [PATCH 11/19] fix: test_execute_openlineage_transport_info_injection --- .../tests/unit/google/cloud/operators/test_dataproc.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index dbe9c66da6663..c9f8b1b1b6019 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -3408,8 +3408,14 @@ def test_execute_openlineage_transport_info_injection( mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) ) + expected_labels = { + "airflow-dag-id": "adhoc_airflow", + "airflow-task-id": "task-id", + } + expected_batch = { **BATCH, + "labels": expected_labels, "runtime_config": {"properties": OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES}, } From 8f7c7118b63cadfd0a0c589280fd4c79fe3b6358 Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Wed, 2 Jul 2025 08:19:26 +0900 Subject: [PATCH 12/19] fix: test --- .../unit/google/cloud/operators/test_dataproc.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py index e678ae5d4a07d..60dd36cdf52a4 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py @@ -3475,7 +3475,8 @@ def test_execute_openlineage_transport_info_injection( HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) ) expected_labels = { - "airflow-dag-id": "adhoc_airflow", + "airflow-dag-id": "test-dataproc-operators", + "airflow-dag-display-name": "test-dataproc-operators", "airflow-task-id": "task-id", } @@ -3537,7 +3538,8 @@ def test_execute_openlineage_all_info_injection( HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) ) expected_labels = { - "airflow-dag-id": "adhoc_airflow", + "airflow-dag-id": "test-dataproc-operators", + "airflow-dag-display-name": "test-dataproc-operators", "airflow-task-id": "task-id", } @@ -3635,9 +3637,14 @@ def test_execute_openlineage_transport_info_injection_skipped_when_already_prese mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) ) + expected_labels = { + "airflow-dag-id": "test-dataproc-operators", + "airflow-dag-display-name": "test-dataproc-operators", + "airflow-task-id": "task-id", + } batch = { **BATCH, - "labels": EXPECTED_LABELS, + "labels": expected_labels, "runtime_config": { 
"properties": { "spark.openlineage.transport.type": "console", From de092b49051608f5f71ab13241819b008db7adb5 Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Wed, 2 Jul 2025 09:28:48 +0900 Subject: [PATCH 13/19] fix --- .../cloud/dataproc/example_dataproc_start_from_trigger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py index d432944b16020..db1923da52f33 100644 --- a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py +++ b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py @@ -124,4 +124,4 @@ from tests_common.test_utils.system_tests import get_test_run # noqa: E402 # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) -test_run = get_test_run(dag) +test_run = get_test_run(dag) \ No newline at end of file From cd45912403d30122e46c090a86306a263bc0a635 Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Wed, 2 Jul 2025 14:10:29 +0900 Subject: [PATCH 14/19] fix: example_dataproc_start_from_trigger pre-commit --- .../cloud/dataproc/example_dataproc_start_from_trigger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py index db1923da52f33..d432944b16020 100644 --- a/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py +++ b/providers/google/tests/system/google/cloud/dataproc/example_dataproc_start_from_trigger.py @@ -124,4 +124,4 @@ from tests_common.test_utils.system_tests import get_test_run # noqa: E402 # Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) -test_run = get_test_run(dag) \ No newline at end of file +test_run = get_test_run(dag) From a93d55b1b13157b6afcbe5568228e8af350f7278 Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Sun, 20 Jul 2025 13:35:02 +0900 Subject: [PATCH 15/19] change to async hook --- .../src/airflow/providers/google/cloud/triggers/dataproc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index 11dd2c4459841..50229d8fa83b8 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py @@ -315,7 +315,7 @@ async def safe_to_cancel(self) -> bool: return task_state != TaskInstanceState.DEFERRED def submit_job(self): - return self.get_sync_hook().submit_job( + return self.get_async_hook().submit_job( project_id=self.project_id, region=self.region, job=self.job, From 8f60603a8d4504acc387e232ec099815d9bab845 Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Sun, 20 Jul 2025 22:18:34 +0900 Subject: [PATCH 16/19] add: normalize_retry_value --- .../google/cloud/triggers/dataproc.py | 37 ++++++++++++------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py index 50229d8fa83b8..e4beaa628ea23 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py +++ 
b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -235,6 +235,18 @@ def __init__(
         self.timeout = timeout
         self.metadata = metadata
 
+    def _normalize_retry_value(self, retry_value):
+        """
+        Normalize the retry value for serialization and API calls.
+
+        Since DEFAULT sentinels and Retry objects don't serialize well, we convert them to None.
+        """
+        if retry_value is DEFAULT or retry_value is None:
+            return None
+        # Retry instances are likewise complex objects that do not survive trigger
+        # serialization, so they are also normalized to None (the API default applies).
+        return None
+
     def serialize(self):
         return (
             "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger",
@@ -243,7 +255,7 @@ def serialize(self):
                 "region": self.region,
                 "job": self.job,
                 "request_id": self.request_id,
-                "retry": self.retry,
+                "retry": self._normalize_retry_value(self.retry),
                 "timeout": self.timeout,
                 "metadata": self.metadata,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -314,21 +326,18 @@ async def safe_to_cancel(self) -> bool:
             task_state = task_instance.state
         return task_state != TaskInstanceState.DEFERRED
 
-    def submit_job(self):
-        return self.get_async_hook().submit_job(
-            project_id=self.project_id,
-            region=self.region,
-            job=self.job,
-            request_id=self.request_id,
-            retry=self.retry,
-            timeout=self.timeout,
-            metadata=self.metadata,
-        )
-
     async def run(self):
         try:
             # Create a new Dataproc job
-            job = self.submit_job()
+            job = await self.get_async_hook().submit_job(
+                project_id=self.project_id,
+                region=self.region,
+                job=self.job,
+                request_id=self.request_id,
+                retry=self._normalize_retry_value(self.retry),
+                timeout=self.timeout,
+                metadata=self.metadata,
+            )
             self.job_id = job.reference.job_id
             while True:
                 job = await self.get_async_hook().get_job(
@@ -339,7 +348,7 @@ async def run(self):
                 if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
                     break
                 await asyncio.sleep(self.polling_interval_seconds)
-            yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
+            yield TriggerEvent({"job_id": self.job_id, "job_state": str(state), "job": str(job)})
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:

From a8e6e147467381583499e6b06ba62315b5317b35 Mon Sep 17 00:00:00 2001
From: kgw7401
Date: Sun, 20 Jul 2025 22:46:03 +0900
Subject: [PATCH 17/19] fix: test code

---
 .../google/cloud/triggers/dataproc.py      | 13 ++-
 .../google/cloud/triggers/test_dataproc.py | 81 ++++++-----------
 2 files changed, 43 insertions(+), 51 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index e4beaa628ea23..405daa2cf5f4e 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -234,6 +234,7 @@ def __init__(
         self.retry = retry
         self.timeout = timeout
         self.metadata = metadata
+        self.job_id = None  # Set in run() once the job has been submitted
 
     def _normalize_retry_value(self, retry_value):
         """
@@ -352,7 +353,12 @@ async def run(self):
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill and await self.safe_to_cancel():
+                if (
+                    hasattr(self, "job_id")
+                    and self.job_id
+                    and self.cancel_on_kill
+                    and await self.safe_to_cancel()
+                ):
                     self.log.info(
                         "Cancelling the job as it is safe to do so. Note that the airflow TaskInstance is not"
                         " in deferred state."
@@ -364,7 +370,10 @@ async def run(self): self.log.info("Job: %s is cancelled", self.job_id) yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING}) except Exception as e: - self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e)) + if hasattr(self, "job_id") and self.job_id: + self.log.error("Failed to cancel the job: %s with error: %s", self.job_id, str(e)) + else: + self.log.error("Failed to cancel the job (no job_id available) with error: %s", str(e)) raise e diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py index 887e3e8e7bacf..b32bdb6e38494 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py @@ -695,7 +695,7 @@ def test_submit_job_trigger_serialization(self, submit_job_trigger): "region": TEST_REGION, "job": {}, "request_id": None, - "retry": RETRY, + "retry": None, "timeout": None, "metadata": METADATA, "gcp_conn_id": TEST_GCP_CONN_ID, @@ -706,19 +706,14 @@ def test_submit_job_trigger_serialization(self, submit_job_trigger): @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") - @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") - async def test_submit_job_trigger_run_success( - self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger - ): + async def test_submit_job_trigger_run_success(self, mock_get_async_hook, submit_job_trigger): """Test the trigger correctly submits a job and handles job completion.""" - # Mock sync hook for job submission - mock_sync_hook = mock_get_sync_hook.return_value - mock_job = mock.MagicMock() + mock_async_hook = mock_get_async_hook.return_value + + mock_job = mock.AsyncMock() mock_job.reference.job_id = TEST_JOB_ID - mock_sync_hook.submit_job.return_value = mock_job + mock_async_hook.submit_job = mock.AsyncMock(return_value=mock_job) - # Mock async hook for job polling - mock_async_hook = mock_get_async_hook.return_value mock_async_hook.get_job = mock.AsyncMock( return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE)) ) @@ -726,13 +721,12 @@ async def test_submit_job_trigger_run_success( async_gen = submit_job_trigger.run() event = await async_gen.asend(None) - # Verify job was submitted - mock_sync_hook.submit_job.assert_called_once_with( + mock_async_hook.submit_job.assert_called_once_with( project_id=TEST_PROJECT_ID, region=TEST_REGION, job={}, request_id=None, - retry=RETRY, + retry=None, timeout=None, metadata=METADATA, ) @@ -747,27 +741,22 @@ async def test_submit_job_trigger_run_success( expected_event = TriggerEvent( { "job_id": TEST_JOB_ID, - "job_state": JobStatus.State.DONE, - "job": mock_async_hook.get_job.return_value, + "job_state": str(JobStatus.State.DONE), + "job": str(mock_async_hook.get_job.return_value), } ) assert event.payload == expected_event.payload @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") - @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") - async def test_submit_job_trigger_run_error( - self, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger - ): + async def test_submit_job_trigger_run_error(self, mock_get_async_hook, submit_job_trigger): """Test the trigger correctly
handles a job error.""" - # Mock sync hook for job submission - mock_sync_hook = mock_get_sync_hook.return_value - mock_job = mock.MagicMock() + mock_async_hook = mock_get_async_hook.return_value + + mock_job = mock.AsyncMock() mock_job.reference.job_id = TEST_JOB_ID - mock_sync_hook.submit_job.return_value = mock_job + mock_async_hook.submit_job = mock.AsyncMock(return_value=mock_job) - # Mock async hook for job polling - mock_async_hook = mock_get_async_hook.return_value mock_async_hook.get_job = mock.AsyncMock( return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.ERROR)) ) @@ -778,8 +767,8 @@ expected_event = TriggerEvent( { "job_id": TEST_JOB_ID, - "job_state": JobStatus.State.ERROR, - "job": mock_async_hook.get_job.return_value, + "job_state": str(JobStatus.State.ERROR), + "job": str(mock_async_hook.get_job.return_value), } ) assert event.payload == expected_event.payload @@ -800,7 +789,7 @@ async def test_submit_job_trigger_run_cancelled( """Test the trigger correctly handles an asyncio.CancelledError.""" mock_safe_to_cancel.return_value = is_safe_to_cancel mock_async_hook = mock_get_async_hook.return_value - mock_async_hook.get_job.side_effect = asyncio.CancelledError + mock_async_hook.submit_job.side_effect = asyncio.CancelledError mock_sync_hook = mock_get_sync_hook.return_value mock_sync_hook.cancel_job = mock.MagicMock() @@ -822,35 +811,25 @@ pytest.fail(f"Unexpected exception raised: {e}") # Check if cancel_job was correctly called - if submit_job_trigger.cancel_on_kill and is_safe_to_cancel: - mock_sync_hook.cancel_job.assert_called_once_with( - job_id=submit_job_trigger.job_id, - project_id=submit_job_trigger.project_id, - region=submit_job_trigger.region, - ) - else: - mock_sync_hook.cancel_job.assert_not_called() + # Since the cancellation happens during submit_job, job_id is None, so cancel_job should not be called + mock_sync_hook.cancel_job.assert_not_called() # Clean up the generator await async_gen.aclose() @pytest.mark.asyncio @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_async_hook") - @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitJobTrigger.get_sync_hook") @mock.patch.object(DataprocSubmitJobTrigger, "log") - async def test_submit_job_trigger_polling_loop( - self, mock_log, mock_get_sync_hook, mock_get_async_hook, submit_job_trigger - ): + async def test_submit_job_trigger_polling_loop(self, mock_log, mock_get_async_hook, submit_job_trigger): """Test the trigger correctly polls for job status until completion.""" - # Mock sync hook for job submission - mock_sync_hook = mock_get_sync_hook.return_value - mock_job = mock.MagicMock() - mock_job.reference.job_id = TEST_JOB_ID - mock_sync_hook.submit_job.return_value = mock_job - - # Mock async hook for job polling + # Mock async hook for job submission and polling mock_async_hook = mock_get_async_hook.return_value + # Mock job submission + mock_job = mock.AsyncMock() + mock_job.reference.job_id = TEST_JOB_ID + mock_async_hook.submit_job = mock.AsyncMock(return_value=mock_job) + # Create mock objects for the side effect mock_running_job = mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.RUNNING)) mock_done_job = mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE)) @@ -871,6 +850,10 @@ ) expected_event = TriggerEvent( - {"job_id": TEST_JOB_ID, "job_state":
JobStatus.State.DONE, "job": mock_done_job} + { + "job_id": TEST_JOB_ID, + "job_state": str(JobStatus.State.DONE), + "job": str(mock_done_job), + } ) assert event.payload == expected_event.payload From 56e4c432acefd8e0f28cae10702478461e9e9bf3 Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Tue, 22 Jul 2025 22:12:16 +0900 Subject: [PATCH 18/19] add: TEST_RUNNING_CLUSTER, TEST_ERROR_CLUSTER --- .../tests/unit/google/cloud/triggers/test_dataproc.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py index b32bdb6e38494..4aa30f4186b57 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py @@ -54,6 +54,14 @@ TEST_JOB_ID = "test-job-id" RETRY = mock.MagicMock(Retry) METADATA = [("key", "value")] +TEST_RUNNING_CLUSTER = Cluster( + cluster_name=TEST_CLUSTER_NAME, + status=ClusterStatus(state=ClusterStatus.State.RUNNING), +) +TEST_ERROR_CLUSTER = Cluster( + cluster_name=TEST_CLUSTER_NAME, + status=ClusterStatus(state=ClusterStatus.State.ERROR), +) @pytest.fixture From 7e7c2a53e29a7f9c0ff9703fb6d2f6517cbacddf Mon Sep 17 00:00:00 2001 From: kgw7401 Date: Tue, 22 Jul 2025 22:37:49 +0900 Subject: [PATCH 19/19] fix: mock_log --- .../google/tests/unit/google/cloud/triggers/test_dataproc.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py index fc1b4e8d23991..d6f92dde0d8b4 100644 --- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py +++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py @@ -214,8 +214,9 @@ async def test_async_cluster_triggers_on_success_should_execute_successfully( return_value=asyncio.Future(), ) @mock.patch("google.auth.default") + @mock.patch.object(DataprocClusterTrigger, "log") async def test_async_cluster_trigger_run_returns_error_event( - self, mock_auth, mock_delete_cluster, mock_fetch_cluster, cluster_trigger, async_get_cluster, caplog + self, mock_log, mock_auth, mock_delete_cluster, mock_fetch_cluster, cluster_trigger, caplog ): mock_credentials = mock.MagicMock() mock_credentials.universe_domain = "googleapis.com"