diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
index 73cf777a51b03..370ba6e062e5f 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -38,6 +38,7 @@
 from google.cloud.storage_transfer_v1 import (
     ListTransferJobsRequest,
+    RunTransferJobRequest,
     StorageTransferServiceAsyncClient,
     TransferJob,
     TransferOperation,
@@ -55,6 +56,7 @@
 )
 
 if TYPE_CHECKING:
+    from google.api_core import operation_async
     from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
         ListTransferJobsAsyncPager,
     )
@@ -712,3 +714,17 @@ async def operations_contain_expected_statuses(
                 f"Expected: {', '.join(expected_statuses_set)}"
             )
         return False
+
+    async def run_transfer_job(self, job_name: str) -> operation_async.AsyncOperation:
+        """
+        Run a Google Storage Transfer Service job.
+
+        :param job_name: (Required) Name of the job to run.
+        """
+        client = await self.get_conn()
+        request = RunTransferJobRequest(
+            job_name=job_name,
+            project_id=self.project_id,
+        )
+        operation = await client.run_transfer_job(request=request)
+        return operation
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py b/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
index 788a5822a71be..a6dda319744b0 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -65,6 +65,7 @@
 )
 from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
 from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
+    CloudDataTransferServiceRunJobTrigger,
     CloudStorageTransferServiceCheckJobStatusTrigger,
 )
 from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
@@ -468,6 +469,8 @@ class CloudDataTransferServiceRunJobOperator(GoogleCloudBaseOperator):
         If set as a sequence, the identities from the list must grant
         Service Account Token Creator IAM role to the directly preceding identity, with first
         account from the list granting this role to the originating account (templated).
+    :param timeout: Time to wait for the operation to end in seconds. Defaults to 60 seconds if not specified.
+    :param deferrable: Run operator in the deferrable mode.
""" # [START gcp_transfer_job_run_template_fields] @@ -489,6 +492,8 @@ def __init__( api_version: str = "v1", project_id: str = PROVIDE_PROJECT_ID, google_impersonation_chain: str | Sequence[str] | None = None, + timeout: float | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -497,6 +502,8 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.api_version = api_version self.google_impersonation_chain = google_impersonation_chain + self.timeout = timeout + self.deferrable = deferrable def _validate_inputs(self) -> None: if not self.job_name: @@ -518,8 +525,32 @@ def execute(self, context: Context) -> dict: job_name=self.job_name, ) + if self.deferrable: + self.defer( + timeout=timedelta(seconds=self.timeout or 60), + trigger=CloudDataTransferServiceRunJobTrigger( + job_name=self.job_name, + project_id=project_id, + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.google_impersonation_chain, + ), + method_name="execute_complete", + ) + return hook.run_transfer_job(job_name=self.job_name, project_id=project_id) + def execute_complete(self, context: Context, event: dict[str, Any]) -> Any: + """ + Act as a callback for when the trigger fires. + + This returns immediately. It relies on trigger to throw an exception, + otherwise it assumes execution was successful. + """ + if event["status"] == "error": + raise AirflowException(event["message"]) + + return event["job_result"] + class CloudDataTransferServiceGetOperationOperator(GoogleCloudBaseOperator): """ diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py index 89e3ec2e4457c..3c68abefc04f0 100644 --- a/providers/google/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +++ b/providers/google/src/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py @@ -23,6 +23,7 @@ from google.api_core.exceptions import GoogleAPIError from google.cloud.storage_transfer_v1.types import TransferOperation +from google.protobuf.json_format import MessageToDict from airflow.exceptions import AirflowException from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import ( @@ -231,3 +232,92 @@ async def run(self) -> AsyncIterator[TriggerEvent]: except Exception as e: self.log.exception("Exception occurred while checking for query completion") yield TriggerEvent({"status": "error", "message": str(e)}) + + +class CloudDataTransferServiceRunJobTrigger(BaseTrigger): + """ + CloudDataTransferServiceRunJobTrigger run on the trigger worker to run Cloud Storage Transfer job. + + :param job_name: The name of the transfer job + :param project_id: The ID of the project that owns the Transfer Job. + :param poke_interval: Polling period in seconds to check for the status + :param gcp_conn_id: The connection ID used to connect to Google Cloud. + :param impersonation_chain: Optional service account to impersonate using short-term + credentials, or chained list of accounts required to get the access_token + of the last account in the list, which will be impersonated in the request. + If set as a string, the account must grant the originating account + the Service Account Token Creator IAM role. 
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding identity, with first
+        account from the list granting this role to the originating account (templated).
+    """
+
+    def __init__(
+        self,
+        job_name: str,
+        project_id: str = PROVIDE_PROJECT_ID,
+        poke_interval: float = 10.0,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        self.job_name = job_name
+        self.project_id = project_id
+        self.poke_interval = poke_interval
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize CloudDataTransferServiceRunJobTrigger arguments and classpath."""
+        return (
+            f"{self.__class__.__module__}.{self.__class__.__qualname__}",
+            {
+                "job_name": self.job_name,
+                "project_id": self.project_id,
+                "poke_interval": self.poke_interval,
+                "gcp_conn_id": self.gcp_conn_id,
+                "impersonation_chain": self.impersonation_chain,
+            },
+        )
+
+    def _get_async_hook(self) -> CloudDataTransferServiceAsyncHook:
+        return CloudDataTransferServiceAsyncHook(
+            project_id=self.project_id,
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Run the transfer job and yield a TriggerEvent."""
+        hook = self._get_async_hook()
+
+        try:
+            job_operation = await hook.run_transfer_job(self.job_name)
+            while True:
+                job_completed = await job_operation.done()
+                if job_completed:
+                    yield TriggerEvent(
+                        {
+                            "status": "success",
+                            "message": "Transfer operation run completed successfully",
+                            "job_result": {
+                                "name": job_operation.operation.name,
+                                "metadata": MessageToDict(
+                                    job_operation.operation.metadata, preserving_proto_field_name=True
+                                ),
+                                "response": MessageToDict(
+                                    job_operation.operation.response, preserving_proto_field_name=True
+                                ),
+                            },
+                        }
+                    )
+                    return
+
+                self.log.info(
+                    "Sleeping for %s seconds.",
+                    self.poke_interval,
+                )
+                await asyncio.sleep(self.poke_interval)
+        except Exception as e:
+            self.log.exception("Exception occurred while running transfer job")
+            yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_cloud_storage_transfer_service_async.py b/providers/google/tests/unit/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
index 7f18363338d84..56517a00bb680 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
@@ -223,3 +223,21 @@ async def test_operations_contain_expected_statuses_green_path(self, statuses, e
             operations, expected_norm
         )
         assert result is True
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.RunTransferJobRequest")
+    async def test_run_transfer_job(self, mock_run_transfer_job_request, mock_get_conn):
+        expected_job_result = AsyncMock()
+        mock_get_conn.return_value.run_transfer_job.side_effect = AsyncMock(return_value=expected_job_result)
+
+        expected_request = mock.MagicMock()
+        mock_run_transfer_job_request.return_value = expected_request
+
+        hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
+        job_name = "Job0"
+        operation = await hook.run_transfer_job(job_name=job_name)
+
+        assert operation == expected_job_result
+
+        mock_run_transfer_job_request.assert_called_once_with(project_id=TEST_PROJECT_ID, job_name=job_name)
+        mock_get_conn.return_value.run_transfer_job.assert_called_once_with(request=expected_request)
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_cloud_storage_transfer_service.py b/providers/google/tests/unit/google/cloud/triggers/test_cloud_storage_transfer_service.py
index 6bb9100b4236e..51c8f6c92d5a3 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_cloud_storage_transfer_service.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_cloud_storage_transfer_service.py
@@ -21,6 +21,7 @@
 import pytest
 from google.api_core.exceptions import GoogleAPICallError
 from google.cloud.storage_transfer_v1 import TransferOperation
+from google.protobuf import struct_pb2
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
@@ -28,6 +29,7 @@
     GcpTransferOperationStatus,
 )
 from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
+    CloudDataTransferServiceRunJobTrigger,
     CloudStorageTransferServiceCheckJobStatusTrigger,
     CloudStorageTransferServiceCreateJobsTrigger,
 )
@@ -391,3 +393,94 @@
         actual_event = await trigger.run().asend(None)
 
         assert actual_event == expected_event
+
+
+class TestCloudDataTransferServiceRunJobTrigger:
+    @pytest.fixture
+    def trigger(self):
+        return CloudDataTransferServiceRunJobTrigger(
+            project_id=PROJECT_ID,
+            job_name=JOB_0,
+            poke_interval=POLL_INTERVAL,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+    def test_serialize(self, trigger):
+        class_path, serialized = trigger.serialize()
+        assert class_path == (
+            "airflow.providers.google.cloud.triggers.cloud_storage_transfer_service"
+            ".CloudDataTransferServiceRunJobTrigger"
+        )
+        assert serialized == {
+            "project_id": PROJECT_ID,
+            "job_name": JOB_0,
+            "poke_interval": POLL_INTERVAL,
+            "gcp_conn_id": GCP_CONN_ID,
+            "impersonation_chain": IMPERSONATION_CHAIN,
+        }
+
+    @pytest.mark.parametrize(
+        "attr, expected_value",
+        [
+            ("gcp_conn_id", GCP_CONN_ID),
+            ("impersonation_chain", IMPERSONATION_CHAIN),
+        ],
+    )
+    def test_get_async_hook(self, attr, expected_value, trigger):
+        hook = trigger._get_async_hook()
+        actual_value = hook._hook_kwargs.get(attr)
+        assert isinstance(hook, CloudDataTransferServiceAsyncHook)
+        assert hook._hook_kwargs is not None
+        assert actual_value == expected_value
+
+    @pytest.mark.asyncio
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".run_transfer_job")
+    async def test_run_returns_success_event(
+        self,
+        run_transfer_job,
+        trigger,
+    ):
+        test_metadata = struct_pb2.Struct()
+        test_metadata.update({"test": "test_metadata"})
+        test_response = struct_pb2.Struct()
+        test_response.update({"test": "test_response"})
+        test_operation = mock.Mock(metadata=test_metadata, response=test_response)
+        test_operation.name = "test_name"
+        run_transfer_job.return_value.operation = test_operation
+        run_transfer_job.return_value.done = mock.AsyncMock(return_value=True)
+        expected_event = TriggerEvent(
+            {
+                "status": "success",
+                "message": "Transfer operation run completed successfully",
+                "job_result": {
+                    "name": "test_name",
+                    "metadata": {"test": "test_metadata"},
+                    "response": {"test": "test_response"},
+                },
+            }
+        )
+
+        actual_event = await trigger.run().asend(None)
+
+        assert actual_event == expected_event
+        assert run_transfer_job.call_count == 1
+
+    @pytest.mark.asyncio
+    @mock.patch(ASYNC_HOOK_CLASS_PATH + ".run_transfer_job")
+    async def test_run_returns_exception_event(
+        self,
+        run_transfer_job,
+        trigger,
+    ):
+        run_transfer_job.side_effect = Exception("Run transfer job operation failed")
+        expected_event = TriggerEvent(
+            {
+                "status": "error",
+                "message": "Run transfer job operation failed",
+            }
+        )
+
+        actual_event = await trigger.run().asend(None)
+
+        assert actual_event == expected_event
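
---

Note for reviewers: a minimal DAG sketch of how the new deferrable path would be exercised. This is illustrative only; the DAG id, project id, and "transferJobs/example-job" name are placeholders, not values from this change.

    from __future__ import annotations

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
        CloudDataTransferServiceRunJobOperator,
    )

    with DAG(
        dag_id="example_sts_run_job_deferrable",  # placeholder DAG id
        start_date=datetime(2024, 1, 1),
        schedule=None,
        catchup=False,
    ):
        # With deferrable=True the task defers instead of calling the sync hook;
        # CloudDataTransferServiceRunJobTrigger then starts the job on the
        # triggerer and polls the long-running operation until done() is true.
        run_job = CloudDataTransferServiceRunJobOperator(
            task_id="run_transfer_job",
            job_name="transferJobs/example-job",  # placeholder job name
            project_id="example-project",  # placeholder project
            deferrable=True,
            timeout=300,  # seconds to stay deferred; falls back to 60 when unset
        )

On success, execute_complete returns event["job_result"] (the operation name plus metadata and response as dicts), so downstream tasks can read it via XCom.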
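The trigger's polling loop reduces to the following standalone sketch of the new async hook surface, assuming Application Default Credentials; the project and job names are again placeholders.

    import asyncio

    from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
        CloudDataTransferServiceAsyncHook,
    )


    async def wait_for_job() -> None:
        hook = CloudDataTransferServiceAsyncHook(project_id="example-project")
        # run_transfer_job returns a google.api_core.operation_async.AsyncOperation;
        # its done() method is a coroutine, which is why the trigger awaits it.
        operation = await hook.run_transfer_job(job_name="transferJobs/example-job")
        while not await operation.done():
            await asyncio.sleep(10)  # mirrors the trigger's default poke_interval


    asyncio.run(wait_for_job())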