airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -38,6 +38,7 @@

from google.cloud.storage_transfer_v1 import (
ListTransferJobsRequest,
RunTransferJobRequest,
StorageTransferServiceAsyncClient,
TransferJob,
TransferOperation,
@@ -55,6 +56,7 @@
)

if TYPE_CHECKING:
from google.api_core import operation_async
from google.cloud.storage_transfer_v1.services.storage_transfer_service.pagers import (
ListTransferJobsAsyncPager,
)
@@ -712,3 +714,17 @@ async def operations_contain_expected_statuses(
f"Expected: {', '.join(expected_statuses_set)}"
)
return False

async def run_transfer_job(self, job_name: str) -> operation_async.AsyncOperation:
"""
Run a Google Storage Transfer Service job.

:param job_name: (Required) Name of the job to run.
"""
client = await self.get_conn()
request = RunTransferJobRequest(
job_name=job_name,
project_id=self.project_id,
)
operation = await client.run_transfer_job(request=request)
return operation
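
For orientation, a minimal sketch (not part of the diff) of how the new async hook method could be driven directly; the project ID and job name below are hypothetical:

import asyncio

from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
    CloudDataTransferServiceAsyncHook,
)


async def main() -> None:
    # Hypothetical project and job name, for illustration only.
    hook = CloudDataTransferServiceAsyncHook(project_id="my-project")
    # run_transfer_job returns a google.api_core AsyncOperation wrapping the
    # long-running transfer operation, so completion can be polled with done().
    operation = await hook.run_transfer_job(job_name="transferJobs/1234567890")
    while not await operation.done():
        await asyncio.sleep(10)


asyncio.run(main())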
airflow/providers/google/cloud/operators/cloud_storage_transfer_service.py
@@ -65,6 +65,7 @@
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
CloudDataTransferServiceRunJobTrigger,
CloudStorageTransferServiceCheckJobStatusTrigger,
)
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
@@ -468,6 +469,8 @@ class CloudDataTransferServiceRunJobOperator(GoogleCloudBaseOperator):
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param timeout: Time, in seconds, to wait for the operation to finish. Defaults to 60 seconds if not specified.
:param deferrable: Run operator in the deferrable mode.
"""

# [START gcp_transfer_job_run_template_fields]
@@ -489,6 +492,8 @@ def __init__(
api_version: str = "v1",
project_id: str = PROVIDE_PROJECT_ID,
google_impersonation_chain: str | Sequence[str] | None = None,
timeout: float | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -497,6 +502,8 @@
self.gcp_conn_id = gcp_conn_id
self.api_version = api_version
self.google_impersonation_chain = google_impersonation_chain
self.timeout = timeout
self.deferrable = deferrable

def _validate_inputs(self) -> None:
if not self.job_name:
@@ -518,8 +525,32 @@ def execute(self, context: Context) -> dict:
job_name=self.job_name,
)

if self.deferrable:
self.defer(
timeout=timedelta(seconds=self.timeout or 60),
trigger=CloudDataTransferServiceRunJobTrigger(
job_name=self.job_name,
project_id=project_id,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.google_impersonation_chain,
),
method_name="execute_complete",
)

return hook.run_transfer_job(job_name=self.job_name, project_id=project_id)

def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
"""
Act as a callback for when the trigger fires.

This method returns immediately. It relies on the trigger to raise an exception;
otherwise it assumes execution was successful.
"""
if event["status"] == "error":
raise AirflowException(event["message"])

return event["job_result"]


class CloudDataTransferServiceGetOperationOperator(GoogleCloudBaseOperator):
"""
airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py
@@ -23,6 +23,7 @@

from google.api_core.exceptions import GoogleAPIError
from google.cloud.storage_transfer_v1.types import TransferOperation
from google.protobuf.json_format import MessageToDict

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
@@ -231,3 +232,92 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
except Exception as e:
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})


class CloudDataTransferServiceRunJobTrigger(BaseTrigger):
"""
CloudDataTransferServiceRunJobTrigger runs on the trigger worker to run a Cloud Storage Transfer job.

:param job_name: The name of the transfer job.
:param project_id: The ID of the project that owns the Transfer Job.
:param poke_interval: Polling period in seconds to check for the status.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

def __init__(
self,
job_name: str,
project_id: str = PROVIDE_PROJECT_ID,
poke_interval: float = 10.0,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
):
super().__init__()
self.job_name = job_name
self.project_id = project_id
self.poke_interval = poke_interval
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize CloudDataTransferServiceRunJobTrigger arguments and classpath."""
return (
f"{self.__class__.__module__}.{self.__class__.__qualname__}",
{
"job_name": self.job_name,
"project_id": self.project_id,
"poke_interval": self.poke_interval,
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
},
)

def _get_async_hook(self) -> CloudDataTransferServiceAsyncHook:
return CloudDataTransferServiceAsyncHook(
project_id=self.project_id,
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Run the transfer job and yield a TriggerEvent."""
hook = self._get_async_hook()

try:
job_operation = await hook.run_transfer_job(self.job_name)
while True:
job_completed = await job_operation.done()
if job_completed:
yield TriggerEvent(
{
"status": "success",
"message": "Transfer operation run completed successfully",
"job_result": {
"name": job_operation.operation.name,
"metadata": MessageToDict(
job_operation.operation.metadata, preserving_proto_field_name=True
),
"response": MessageToDict(
job_operation.operation.response, preserving_proto_field_name=True
),
},
}
)
return

self.log.info(
"Sleeping for %s seconds.",
self.poke_interval,
)
await asyncio.sleep(self.poke_interval)
except Exception as e:
self.log.exception("Exception occurred while running transfer job")
yield TriggerEvent({"status": "error", "message": str(e)})
tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service.py
@@ -223,3 +223,21 @@ async def test_operations_contain_expected_statuses_green_path(self, statuses, e
operations, expected_norm
)
assert result is True

@pytest.mark.asyncio
@mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
@mock.patch(f"{TRANSFER_HOOK_PATH}.RunTransferJobRequest")
async def test_run_transfer_job(self, mock_run_transfer_job_request, mock_get_conn):
expected_job_result = AsyncMock()
mock_get_conn.return_value.run_transfer_job.side_effect = AsyncMock(return_value=expected_job_result)

expected_request = mock.MagicMock()
mock_run_transfer_job_request.return_value = expected_request

hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
job_name = "Job0"
operation = await hook.run_transfer_job(job_name=job_name)

assert operation == expected_job_result
mock_run_transfer_job_request.assert_called_once_with(project_id=TEST_PROJECT_ID, job_name=job_name)
mock_get_conn.return_value.run_transfer_job.assert_called_once_with(request=expected_request)
tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
@@ -21,13 +21,15 @@
import pytest
from google.api_core.exceptions import GoogleAPICallError
from google.cloud.storage_transfer_v1 import TransferOperation
from google.protobuf import struct_pb2

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
CloudDataTransferServiceAsyncHook,
GcpTransferOperationStatus,
)
from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service import (
CloudDataTransferServiceRunJobTrigger,
CloudStorageTransferServiceCheckJobStatusTrigger,
CloudStorageTransferServiceCreateJobsTrigger,
)
@@ -391,3 +393,94 @@ async def test_run_returns_exception_event(
actual_event = await trigger.run().asend(None)

assert actual_event == expected_event


class TestCloudDataTransferServiceRunJobTrigger:
@pytest.fixture
def trigger(self):
return CloudDataTransferServiceRunJobTrigger(
project_id=PROJECT_ID,
job_name=JOB_0,
poke_interval=POLL_INTERVAL,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)

def test_serialize(self, trigger):
class_path, serialized = trigger.serialize()
assert class_path == (
"airflow.providers.google.cloud.triggers.cloud_storage_transfer_service"
".CloudDataTransferServiceRunJobTrigger"
)
assert serialized == {
"project_id": PROJECT_ID,
"job_name": JOB_0,
"poke_interval": POLL_INTERVAL,
"gcp_conn_id": GCP_CONN_ID,
"impersonation_chain": IMPERSONATION_CHAIN,
}

@pytest.mark.parametrize(
"attr, expected_value",
[
("gcp_conn_id", GCP_CONN_ID),
("impersonation_chain", IMPERSONATION_CHAIN),
],
)
def test_get_async_hook(self, attr, expected_value, trigger):
hook = trigger._get_async_hook()
actual_value = hook._hook_kwargs.get(attr)
assert isinstance(hook, CloudDataTransferServiceAsyncHook)
assert hook._hook_kwargs is not None
assert actual_value == expected_value

@pytest.mark.asyncio
@mock.patch(ASYNC_HOOK_CLASS_PATH + ".run_transfer_job")
async def test_run_returns_success_event(
self,
run_transfer_job,
trigger,
):
test_metadata = struct_pb2.Struct()
test_metadata.update({"test": "test_metadata"})
test_response = struct_pb2.Struct()
test_response.update({"test": "test_response"})
test_operation = mock.Mock(metadata=test_metadata, response=test_response)
test_operation.name = "test_name"
run_transfer_job.return_value.operation = test_operation
# done() is awaited on the returned operation, so stub it there as an AsyncMock.
run_transfer_job.return_value.done = mock.AsyncMock(return_value=True)
expected_event = TriggerEvent(
{
"status": "success",
"message": "Transfer operation run completed successfully",
"job_result": {
"name": "test_name",
"metadata": {"test": "test_metadata"},
"response": {"test": "test_response"},
},
}
)

actual_event = await trigger.run().asend(None)

assert actual_event == expected_event
assert run_transfer_job.call_count == 1

@pytest.mark.asyncio
@mock.patch(ASYNC_HOOK_CLASS_PATH + ".run_transfer_job")
async def test_run_returns_exception_event(
self,
run_transfer_job,
trigger,
):
run_transfer_job.side_effect = Exception("Run transfer job operation failed")
expected_event = TriggerEvent(
{
"status": "error",
"message": "Run transfer job operation failed",
}
)

actual_event = await trigger.run().asend(None)

assert actual_event == expected_event