Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ ast
astroid
Async
async
asyncio
AsyncResult
athena
Atlassian
Expand Down
27 changes: 27 additions & 0 deletions providers/google/docs/operators/cloud/cloud_run.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,33 @@ or you can define the same operator in the deferrable mode:
:start-after: [START howto_operator_cloud_run_execute_job_deferrable_mode]
:end-before: [END howto_operator_cloud_run_execute_job_deferrable_mode]

Transport
^^^^^^^^^

The :class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator` accepts an optional ``transport``
parameter to choose the underlying API transport.

* ``transport="grpc"`` (default): use gRPC transport; this is also the behavior when ``transport`` is left unset.
* ``transport="rest"``: use REST/HTTP transport.

In deferrable mode, when using gRPC (explicitly or by default), the trigger uses an async gRPC client internally; for
non-deferrable execution, the operator uses the regular (synchronous) gRPC client.

Prefer gRPC (or leave ``transport`` unset) unless you have a specific reason to use REST — for example, when gRPC is
unavailable or fails in your environment.

.. rubric:: Deferrable mode considerations

When using deferrable mode, the operator defers to an async trigger that polls the long-running operation status.

* With gRPC (explicitly or by default), the trigger uses the native async gRPC client internally. The ``grpc_asyncio`` transport is
an implementation detail of the Google client library and is not a user-facing ``transport`` value.
* With REST, the REST transport is synchronous-only in the Google Cloud library. To remain compatible with deferrable mode, the
trigger performs REST calls using the synchronous client wrapped in a background thread.

REST can be used with deferrable mode, but it may be less efficient than gRPC and is generally best reserved for cases where gRPC
cannot be used.

You can also specify overrides that allow you to give a new entrypoint command to the job and more:

:class:`~airflow.providers.google.cloud.operators.cloud_run.CloudRunExecuteJobOperator`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import asyncio
import itertools
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -93,8 +94,9 @@ def get_conn(self):
client_kwargs = {
"credentials": self.get_credentials(),
"client_info": CLIENT_INFO,
"transport": self.transport,
}
if self.transport:
client_kwargs["transport"] = self.transport
self._client = JobsClient(**client_kwargs)
return self._client

Expand Down Expand Up @@ -187,8 +189,9 @@ class CloudRunAsyncHook(GoogleBaseAsyncHook):
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account.
:param transport: Optional. The transport to use for API requests. Can be 'rest' or 'grpc'.
If set to None, a transport is chosen automatically. Use 'rest' if gRPC is not available
or fails in your environment (e.g., Docker containers with certain network configurations).
When set to 'rest', uses the synchronous REST client wrapped with
``asyncio.to_thread()`` for compatibility with async triggers.
When None or 'grpc', uses the native async gRPC transport (grpc_asyncio).
"""

sync_hook_class = CloudRunHook
Expand All @@ -200,27 +203,38 @@ def __init__(
transport: Literal["rest", "grpc"] | None = None,
**kwargs,
):
self._client: JobsAsyncClient | None = None
self._client: JobsAsyncClient | JobsClient | None = None
self.transport = transport
super().__init__(
gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, transport=transport, **kwargs
)
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain, **kwargs)

async def get_conn(self):
if self._client is None:
sync_hook = await self.get_sync_hook()
client_kwargs = {
"credentials": sync_hook.get_credentials(),
"client_info": CLIENT_INFO,
"transport": self.transport,
}
self._client = JobsAsyncClient(**client_kwargs)
credentials = sync_hook.get_credentials()
if self.transport == "rest":
# REST transport is synchronous-only. Use the sync JobsClient here;
# get_operation() wraps calls with asyncio.to_thread() for async compat.
self._client = JobsClient(
credentials=credentials,
client_info=CLIENT_INFO,
transport="rest",
)
else:
# Default: use JobsAsyncClient which picks grpc_asyncio transport.
self._client = JobsAsyncClient(
credentials=credentials,
client_info=CLIENT_INFO,
)

return self._client

async def get_operation(self, operation_name: str) -> operations_pb2.Operation:
conn = await self.get_conn()
return await conn.get_operation(operations_pb2.GetOperationRequest(name=operation_name), timeout=120)
request = operations_pb2.GetOperationRequest(name=operation_name)
if self.transport == "rest":
# REST client is synchronous — run in a thread to avoid blocking the event loop.
return await asyncio.to_thread(conn.get_operation, request, timeout=120)
return await conn.get_operation(request, timeout=120)


class CloudRunServiceHook(GoogleBaseHook):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,5 @@ def _get_async_hook(self) -> CloudRunAsyncHook:
return CloudRunAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
transport=self.transport or "grpc",
transport=self.transport,
)
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,20 @@
job1_name = f"{job_name_prefix}1-{ENV_ID}"
job2_name = f"{job_name_prefix}2-{ENV_ID}"
job3_name = f"{job_name_prefix}3-{ENV_ID}"
job4_name = f"{job_name_prefix}4-{ENV_ID}"
job5_name = f"{job_name_prefix}5-{ENV_ID}"

create1_task_name = "create-job1"
create2_task_name = "create-job2"
create3_task_name = "create-job3"
create4_task_name = "create-job4"
create5_task_name = "create-job5"

execute1_task_name = "execute-job1"
execute2_task_name = "execute-job2"
execute3_task_name = "execute-job3"
execute4_task_name = "execute-job4"
execute5_task_name = "execute-job5"

update_job1_task_name = "update-job1"

Expand All @@ -82,6 +88,12 @@ def _assert_executed_jobs_xcom(ti):
job3_dict = ti.xcom_pull(execute3_task_name)
assert job3_name in job3_dict["name"]

job4_dict = ti.xcom_pull(execute4_task_name)
assert job4_name in job4_dict["name"]

job5_dict = ti.xcom_pull(execute5_task_name)
assert job5_name in job5_dict["name"]


def _assert_created_jobs_xcom(ti):
job1_dict = ti.xcom_pull(create1_task_name)
Expand All @@ -93,6 +105,12 @@ def _assert_created_jobs_xcom(ti):
job3_dict = ti.xcom_pull(create3_task_name)
assert job3_name in job3_dict["name"]

job4_dict = ti.xcom_pull(create4_task_name)
assert job4_name in job4_dict["name"]

job5_dict = ti.xcom_pull(create5_task_name)
assert job5_name in job5_dict["name"]


def _assert_updated_job(ti):
job_dict = ti.xcom_pull(update_job1_task_name)
Expand All @@ -104,9 +122,13 @@ def _assert_jobs(ti):

job1_exists = any(job1_name in job["name"] for job in job_list)
job2_exists = any(job2_name in job["name"] for job in job_list)
job4_exists = any(job4_name in job["name"] for job in job_list)
job5_exists = any(job5_name in job["name"] for job in job_list)

assert job1_exists
assert job2_exists
assert job4_exists
assert job5_exists


def _assert_one_job(ti):
Expand Down Expand Up @@ -234,6 +256,24 @@ def _create_job_instance_with_label():
dag=dag,
)

create4 = CloudRunCreateJobOperator(
task_id=create4_task_name,
project_id=PROJECT_ID,
region=region,
job_name=job4_name,
job=_create_job_dict(),
dag=dag,
)

create5 = CloudRunCreateJobOperator(
task_id=create5_task_name,
project_id=PROJECT_ID,
region=region,
job_name=job5_name,
job=_create_job_dict(),
dag=dag,
)

assert_created_jobs = PythonOperator(
task_id="assert-created-jobs", python_callable=_assert_created_jobs_xcom, dag=dag
)
Expand Down Expand Up @@ -285,6 +325,26 @@ def _create_job_instance_with_label():
)
# [END howto_operator_cloud_run_execute_job_with_overrides]

execute4 = CloudRunExecuteJobOperator(
task_id=execute4_task_name,
project_id=PROJECT_ID,
region=region,
job_name=job4_name,
dag=dag,
deferrable=False,
transport="rest",
)

execute5 = CloudRunExecuteJobOperator(
task_id=execute5_task_name,
project_id=PROJECT_ID,
region=region,
job_name=job5_name,
dag=dag,
deferrable=True,
transport="rest",
)

assert_executed_jobs = PythonOperator(
task_id="assert-executed-jobs", python_callable=_assert_executed_jobs_xcom, dag=dag
)
Expand Down Expand Up @@ -347,18 +407,36 @@ def _create_job_instance_with_label():
trigger_rule=TriggerRule.ALL_DONE,
)

delete_job4 = CloudRunDeleteJobOperator(
task_id="delete-job4",
project_id=PROJECT_ID,
region=region,
job_name=job4_name,
dag=dag,
trigger_rule=TriggerRule.ALL_DONE,
)

delete_job5 = CloudRunDeleteJobOperator(
task_id="delete-job5",
project_id=PROJECT_ID,
region=region,
job_name=job5_name,
dag=dag,
trigger_rule=TriggerRule.ALL_DONE,
)

(
(create1, create2, create3)
(create1, create2, create3, create4, create5)
>> assert_created_jobs
>> (execute1, execute2, execute3)
>> (execute1, execute2, execute3, execute4, execute5)
>> assert_executed_jobs
>> list_jobs_limit
>> assert_jobs_limit
>> list_jobs
>> assert_jobs
>> update_job1
>> assert_job_updated
>> (delete_job1, delete_job2, delete_job3)
>> (delete_job1, delete_job2, delete_job3, delete_job4, delete_job5)
)

from tests_common.test_utils.watcher import watcher
Expand Down
78 changes: 68 additions & 10 deletions providers/google/tests/unit/google/cloud/hooks/test_cloud_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
GetJobRequest,
GetServiceRequest,
Job,
JobsClient,
ListJobsRequest,
RunJobRequest,
Service,
Expand Down Expand Up @@ -264,16 +265,30 @@ def test_delete_job(self, mock_batch_service_client, cloud_run_hook):
new=mock_base_gcp_hook_default_project_id,
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
@pytest.mark.parametrize(("transport", "expected_transport"), [("rest", "rest"), (None, None)])
def test_get_conn_with_transport(self, mock_jobs_client, transport, expected_transport):
def test_get_conn_with_transport(self, mock_jobs_client):
"""Test that transport parameter is passed to JobsClient."""
hook = CloudRunHook(transport=transport)
hook = CloudRunHook(transport="rest")
hook.get_credentials = self.dummy_get_credentials
hook.get_conn()

mock_jobs_client.assert_called_once()
call_kwargs = mock_jobs_client.call_args[1]
assert call_kwargs["transport"] == expected_transport
assert call_kwargs["transport"] == "rest"

@mock.patch(
"airflow.providers.google.common.hooks.base_google.GoogleBaseHook.__init__",
new=mock_base_gcp_hook_default_project_id,
)
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
def test_get_conn_omits_transport_when_none(self, mock_jobs_client):
"""Test that transport is not passed to JobsClient when None."""
hook = CloudRunHook(transport=None)
hook.get_credentials = self.dummy_get_credentials
hook.get_conn()

mock_jobs_client.assert_called_once()
call_kwargs = mock_jobs_client.call_args[1]
assert "transport" not in call_kwargs

def _mock_pager(self, number_of_jobs):
mock_pager = []
Expand All @@ -293,13 +308,56 @@ async def test_get_operation(self):
operations_pb2.GetOperationRequest(name=OPERATION_NAME), timeout=120
)

def mock_get_operation(self, expected_operation):
get_operation_mock = mock.AsyncMock()
get_operation_mock.return_value = expected_operation
return get_operation_mock
@pytest.mark.asyncio
@pytest.mark.parametrize("transport", [None, "grpc"])
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsAsyncClient")
async def test_get_conn_uses_async_client_by_default(self, mock_async_client, transport):
"""Test that get_conn uses JobsAsyncClient (grpc_asyncio) when transport is None or grpc."""
hook = CloudRunAsyncHook(transport=transport)
mock_sync_hook = mock.MagicMock(spec=CloudRunHook)
mock_sync_hook.get_credentials.return_value = "credentials"
hook.get_sync_hook = mock.AsyncMock(return_value=mock_sync_hook)

def _dummy_get_credentials(self):
pass
await hook.get_conn()

mock_async_client.assert_called_once()
call_kwargs = mock_async_client.call_args[1]
assert "transport" not in call_kwargs

@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.cloud_run.JobsClient")
async def test_get_conn_uses_sync_client_for_rest(self, mock_sync_client):
"""Test that get_conn uses sync JobsClient with REST transport."""
hook = CloudRunAsyncHook(transport="rest")
mock_sync_hook = mock.MagicMock(spec=CloudRunHook)
mock_sync_hook.get_credentials.return_value = "credentials"
hook.get_sync_hook = mock.AsyncMock(return_value=mock_sync_hook)

await hook.get_conn()

mock_sync_client.assert_called_once()
call_kwargs = mock_sync_client.call_args[1]
assert call_kwargs["transport"] == "rest"

@pytest.mark.asyncio
@mock.patch("asyncio.to_thread")
async def test_get_operation_rest_uses_to_thread(self, mock_to_thread):
"""Test that get_operation uses asyncio.to_thread for REST transport."""
expected_operation = operations_pb2.Operation(name=OPERATION_NAME)
mock_to_thread.return_value = expected_operation

hook = CloudRunAsyncHook(transport="rest")
mock_conn = mock.MagicMock(spec=JobsClient) # sync client
hook.get_conn = mock.AsyncMock(return_value=mock_conn)

result = await hook.get_operation(operation_name=OPERATION_NAME)

mock_to_thread.assert_called_once_with(
mock_conn.get_operation,
operations_pb2.GetOperationRequest(name=OPERATION_NAME),
timeout=120,
)
assert result == expected_operation


@pytest.mark.db_test
Expand Down