Cloud Run (Google provider): rework how the ``transport`` option is applied.

Summary of the hunks in this patch:

* docs/spelling_wordlist.txt: add "asyncio".
* cloud_run.rst: add a "Transport" section describing ``transport="grpc"``
  and ``transport="rest"``, including deferrable-mode notes.
* hooks/cloud_run.py, CloudRunHook.get_conn: pass ``transport`` to JobsClient
  only when it is set.
* hooks/cloud_run.py, CloudRunAsyncHook: stop forwarding ``transport`` to the
  base hook; get_conn builds a synchronous JobsClient(transport="rest") when
  transport == "rest" and a JobsAsyncClient without an explicit transport
  otherwise; get_operation runs the REST call through asyncio.to_thread.
* triggers/cloud_run.py, _get_async_hook: pass ``self.transport`` through
  unchanged instead of falling back to "grpc".
* system test: add job4 (deferrable=False, transport="rest") and job5
  (deferrable=True, transport="rest") with create/execute/delete tasks.
* unit tests: split the sync-hook transport test in two, add async-hook tests
  for client selection and for the to_thread path, and drop the
  mock_get_operation / _dummy_get_credentials helpers.

NOTE(review): line breaks were restored from the hunk line counts, which all
balance. Indentation of context lines is a best-effort reconstruction --
confirm against the target files before applying.

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 078b02b1dca5c..50e8d3b225a74 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -110,6 +110,7 @@ ast
 astroid
 Async
 async
+asyncio
 AsyncResult
 athena
 Atlassian
diff --git a/providers/google/docs/operators/cloud/cloud_run.rst b/providers/google/docs/operators/cloud/cloud_run.rst
index 0cdc6ee0af649..94291f25718b7 100644
--- a/providers/google/docs/operators/cloud/cloud_run.rst
+++ b/providers/google/docs/operators/cloud/cloud_run.rst
@@ -126,6 +126,33 @@ or you can define the same operator in the deferrable mode:
     :start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode]
     :end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode]
 
+Transport
+^^^^^^^^^
+
+The :class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator` accepts an optional ``transport``
+parameter to choose the underlying API transport.
+
+* ``transport="grpc"`` (default): use gRPC transport. If ``transport`` is not set, gRPC is used.
+* ``transport="rest"``: use REST/HTTP transport.
+
+In deferrable mode, when using gRPC (explicitly or by default), the trigger uses an async gRPC client internally; for
+non-deferrable execution, the operator uses the regular (synchronous) gRPC client.
+
+In general, it is better to use gRPC (or leave ``transport`` unset) unless there is a specific reason you must use REST (for example,
+if gRPC is not available or fails in your environment).
+
+.. rubric:: Deferrable mode considerations
+
+When using deferrable mode, the operator defers to an async trigger that polls the long-running operation status.
+
+* With gRPC (explicitly or by default), the trigger uses the native async gRPC client internally. The ``grpc_asyncio`` transport is
+  an implementation detail of the Google client library and is not a user-facing ``transport`` value.
+* With REST, the REST transport is synchronous-only in the Google Cloud library. To remain compatible with deferrable mode, the
+  trigger performs REST calls using the synchronous client wrapped in a background thread.
+
+REST can be used with deferrable mode, but it may be less efficient than gRPC and is generally best reserved for cases where gRPC
+cannot be used.
+
 You can also specify overrides that allow you to give a new entrypoint command to the job and more:
 
 :class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
index f2f3ba417b1c0..24fe0f200c880 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_run.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import itertools
 from collections.abc import Iterable, Sequence
 from typing import TYPE_CHECKING, Any, Literal
@@ -93,8 +94,9 @@ def get_conn(self):
             client_kwargs = {
                 "credentials": self.get_credentials(),
                 "client_info": CLIENT_INFO,
-                "transport": self.transport,
             }
+            if self.transport:
+                client_kwargs["transport"] = self.transport
             self._client = JobsClient(**client_kwargs)
         return self._client
 
@@ -187,8 +189,9 @@ class CloudRunAsyncHook(GoogleBaseAsyncHook):
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account.
     :param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
-        If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
-        or fails in your environment (e.g., Docker containers with certain network configurations).
+        When set to 'rest', uses the synchronous REST client wrapped with
+        ``asyncio.to_thread()`` for compatibility with async triggers.
+        When None or 'grpc', uses the native async gRPC transport (grpc_asyncio).
     """
 
     sync_hook_class = CloudRunHook
@@ -200,27 +203,38 @@ def __init__(
         transport: Literal["rest", "grpc"] | None = None,
         **kwargs,
     ):
-        self._client: JobsAsyncClient | None = None
+        self._client: JobsAsyncClient | JobsClient | None = None
         self.transport = transport
-        super().__init__(
-            gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, transport=transport, **kwargs
-        )
+        super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)
 
     async def get_conn(self):
         if self._client is None:
             sync_hook = await self.get_sync_hook()
-            client_kwargs = {
-                "credentials": sync_hook.get_credentials(),
-                "client_info": CLIENT_INFO,
-                "transport": self.transport,
-            }
-            self._client = JobsAsyncClient(**client_kwargs)
+            credentials = sync_hook.get_credentials()
+            if self.transport == "rest":
+                # REST transport is synchronous-only. Use the sync JobsClient here;
+                # get_operation() wraps calls with asyncio.to_thread() for async compat.
+                self._client = JobsClient(
+                    credentials=credentials,
+                    client_info=CLIENT_INFO,
+                    transport="rest",
+                )
+            else:
+                # Default: use JobsAsyncClient which picks grpc_asyncio transport.
+                self._client = JobsAsyncClient(
+                    credentials=credentials,
+                    client_info=CLIENT_INFO,
+                )
 
         return self._client
 
     async def get_operation(self, operation_name: str) -> operations_pb2.Operation:
         conn = await self.get_conn()
-        return await conn.get_operation(operations_pb2.GetOperationRequest(name=operation_name), timeout=120)
+        request = operations_pb2.GetOperationRequest(name=operation_name)
+        if self.transport == "rest":
+            # REST client is synchronous — run in a thread to avoid blocking the event loop.
+            return await asyncio.to_thread(conn.get_operation, request, timeout=120)
+        return await conn.get_operation(request, timeout=120)
 
 
 class CloudRunServiceHook(GoogleBaseHook):
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
index 8261edd416a3e..746ca9e883815 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_run.py
@@ -149,5 +149,5 @@ def _get_async_hook(self) -> CloudRunAsyncHook:
         return CloudRunAsyncHook(
             gcp_conn_id=self.gcp_conn_id,
             impersonation_chain=self.impersonation_chain,
-            transport=self.transport or "grpc",
+            transport=self.transport,
         )
diff --git a/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py b/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py
index cf82bbf45b3b9..789f128dbca1a 100644
--- a/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py
+++ b/providers/google/tests/system/google/cloud/cloud_run/example_cloud_run.py
@@ -51,14 +51,20 @@
 job1_name = f"{job_name_prefix}1-{ENV_ID}"
 job2_name = f"{job_name_prefix}2-{ENV_ID}"
 job3_name = f"{job_name_prefix}3-{ENV_ID}"
+job4_name = f"{job_name_prefix}4-{ENV_ID}"
+job5_name = f"{job_name_prefix}5-{ENV_ID}"
 
 create1_task_name = "create-job1"
 create2_task_name = "create-job2"
 create3_task_name = "create-job3"
+create4_task_name = "create-job4"
+create5_task_name = "create-job5"
 
 execute1_task_name = "execute-job1"
 execute2_task_name = "execute-job2"
 execute3_task_name = "execute-job3"
+execute4_task_name = "execute-job4"
+execute5_task_name = "execute-job5"
 
 update_job1_task_name = "update-job1"
 
@@ -82,6 +88,12 @@ def _assert_executed_jobs_xcom(ti):
     job3_dict = ti.xcom_pull(execute3_task_name)
     assert job3_name in job3_dict["name"]
 
+    job4_dict = ti.xcom_pull(execute4_task_name)
+    assert job4_name in job4_dict["name"]
+
+    job5_dict = ti.xcom_pull(execute5_task_name)
+    assert job5_name in job5_dict["name"]
+
 
 def _assert_created_jobs_xcom(ti):
     job1_dict = ti.xcom_pull(create1_task_name)
@@ -93,6 +105,12 @@ def _assert_created_jobs_xcom(ti):
     job3_dict = ti.xcom_pull(create3_task_name)
     assert job3_name in job3_dict["name"]
 
+    job4_dict = ti.xcom_pull(create4_task_name)
+    assert job4_name in job4_dict["name"]
+
+    job5_dict = ti.xcom_pull(create5_task_name)
+    assert job5_name in job5_dict["name"]
+
 
 def _assert_updated_job(ti):
     job_dict = ti.xcom_pull(update_job1_task_name)
@@ -104,9 +122,13 @@ def _assert_jobs(ti):
 
     job1_exists = any(job1_name in job["name"] for job in job_list)
     job2_exists = any(job2_name in job["name"] for job in job_list)
+    job4_exists = any(job4_name in job["name"] for job in job_list)
+    job5_exists = any(job5_name in job["name"] for job in job_list)
 
     assert job1_exists
     assert job2_exists
+    assert job4_exists
+    assert job5_exists
 
 
 def _assert_one_job(ti):
@@ -234,6 +256,24 @@ def _create_job_instance_with_label():
         dag=dag,
     )
 
+    create4 = CloudRunCreateJobOperator(
+        task_id=create4_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job4_name,
+        job=_create_job_dict(),
+        dag=dag,
+    )
+
+    create5 = CloudRunCreateJobOperator(
+        task_id=create5_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job5_name,
+        job=_create_job_dict(),
+        dag=dag,
+    )
+
     assert_created_jobs = PythonOperator(
         task_id="assert-created-jobs", python_callable=_assert_created_jobs_xcom, dag=dag
     )
@@ -285,6 +325,26 @@ def _create_job_instance_with_label():
     )
     # [END howto_operator_cloud_run_execute_job_with_overrides]
 
+    execute4 = CloudRunExecuteJobOperator(
+        task_id=execute4_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job4_name,
+        dag=dag,
+        deferrable=False,
+        transport="rest",
+    )
+
+    execute5 = CloudRunExecuteJobOperator(
+        task_id=execute5_task_name,
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job5_name,
+        dag=dag,
+        deferrable=True,
+        transport="rest",
+    )
+
     assert_executed_jobs = PythonOperator(
         task_id="assert-executed-jobs", python_callable=_assert_executed_jobs_xcom, dag=dag
     )
@@ -347,10 +407,28 @@ def _create_job_instance_with_label():
         trigger_rule=TriggerRule.ALL_DONE,
     )
 
+    delete_job4 = CloudRunDeleteJobOperator(
+        task_id="delete-job4",
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job4_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    delete_job5 = CloudRunDeleteJobOperator(
+        task_id="delete-job5",
+        project_id=PROJECT_ID,
+        region=region,
+        job_name=job5_name,
+        dag=dag,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
     (
-        (create1, create2, create3)
+        (create1, create2, create3, create4, create5)
         >> assert_created_jobs
-        >> (execute1, execute2, execute3)
+        >> (execute1, execute2, execute3, execute4, execute5)
         >> assert_executed_jobs
         >> list_jobs_limit
         >> assert_jobs_limit
@@ -358,7 +436,7 @@ def _create_job_instance_with_label():
         >> assert_jobs
         >> update_job1
         >> assert_job_updated
-        >> (delete_job1, delete_job2, delete_job3)
+        >> (delete_job1, delete_job2, delete_job3, delete_job4, delete_job5)
     )
 
     from tests_common.test_utils.watcher import watcher
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
index 4a8150459c4a0..bf7ff46123313 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
@@ -28,6 +28,7 @@
     GetJobRequest,
     GetServiceRequest,
     Job,
+    JobsClient,
     ListJobsRequest,
     RunJobRequest,
     Service,
@@ -264,16 +265,30 @@ def test_delete_job(self, mock_batch_service_client, cloud_run_hook):
         new=mock_base_gcp_hook_default_project_id,
     )
     @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
-    @pytest.mark.parametrize(("transport", "expected_transport"), [("rest", "rest"), (None, None)])
-    def test_get_conn_with_transport(self, mock_jobs_client, transport, expected_transport):
+    def test_get_conn_with_transport(self, mock_jobs_client):
         """Test that transport parameter is passed to JobsClient."""
-        hook = CloudRunHook(transport=transport)
+        hook = CloudRunHook(transport="rest")
         hook.get_credentials = self.dummy_get_credentials
         hook.get_conn()
 
         mock_jobs_client.assert_called_once()
         call_kwargs = mock_jobs_client.call_args[1]
-        assert call_kwargs["transport"] == expected_transport
+        assert call_kwargs["transport"] == "rest"
+
+    @mock.patch(
+        "airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
+        new=mock_base_gcp_hook_default_project_id,
+    )
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    def test_get_conn_omits_transport_when_none(self, mock_jobs_client):
+        """Test that transport is not passed to JobsClient when None."""
+        hook = CloudRunHook(transport=None)
+        hook.get_credentials = self.dummy_get_credentials
+        hook.get_conn()
+
+        mock_jobs_client.assert_called_once()
+        call_kwargs = mock_jobs_client.call_args[1]
+        assert "transport" not in call_kwargs
 
     def _mock_pager(self, number_of_jobs):
         mock_pager = []
@@ -293,13 +308,56 @@ async def test_get_operation(self):
             operations_pb2.GetOperationRequest(name=OPERATION_NAME), timeout=120
         )
 
-    def mock_get_operation(self, expected_operation):
-        get_operation_mock = mock.AsyncMock()
-        get_operation_mock.return_value = expected_operation
-        return get_operation_mock
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize("transport", [None, "grpc"])
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsAsyncClient")
+    async def test_get_conn_uses_async_client_by_default(self, mock_async_client, transport):
+        """Test that get_conn uses JobsAsyncClient (grpc_asyncio) when transport is None or grpc."""
+        hook = CloudRunAsyncHook(transport=transport)
+        mock_sync_hook = mock.MagicMock(spec=CloudRunHook)
+        mock_sync_hook.get_credentials.return_value = "credentials"
+        hook.get_sync_hook = mock.AsyncMock(return_value=mock_sync_hook)
 
-    def _dummy_get_credentials(self):
-        pass
+        await hook.get_conn()
+
+        mock_async_client.assert_called_once()
+        call_kwargs = mock_async_client.call_args[1]
+        assert "transport" not in call_kwargs
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
+    async def test_get_conn_uses_sync_client_for_rest(self, mock_sync_client):
+        """Test that get_conn uses sync JobsClient with REST transport."""
+        hook = CloudRunAsyncHook(transport="rest")
+        mock_sync_hook = mock.MagicMock(spec=CloudRunHook)
+        mock_sync_hook.get_credentials.return_value = "credentials"
+        hook.get_sync_hook = mock.AsyncMock(return_value=mock_sync_hook)
+
+        await hook.get_conn()
+
+        mock_sync_client.assert_called_once()
+        call_kwargs = mock_sync_client.call_args[1]
+        assert call_kwargs["transport"] == "rest"
+
+    @pytest.mark.asyncio
+    @mock.patch("asyncio.to_thread")
+    async def test_get_operation_rest_uses_to_thread(self, mock_to_thread):
+        """Test that get_operation uses asyncio.to_thread for REST transport."""
+        expected_operation = operations_pb2.Operation(name=OPERATION_NAME)
+        mock_to_thread.return_value = expected_operation
+
+        hook = CloudRunAsyncHook(transport="rest")
+        mock_conn = mock.MagicMock(spec=JobsClient)  # sync client
+        hook.get_conn = mock.AsyncMock(return_value=mock_conn)
+
+        result = await hook.get_operation(operation_name=OPERATION_NAME)
+
+        mock_to_thread.assert_called_once_with(
+            mock_conn.get_operation,
+            operations_pb2.GetOperationRequest(name=OPERATION_NAME),
+            timeout=120,
+        )
+        assert result == expected_operation
 
 
 @pytest.mark.db_test