@@ -36,13 +36,13 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any

from google.api_core import protobuf_helpers
from google.cloud.storage_transfer_v1 import (
ListTransferJobsRequest,
StorageTransferServiceAsyncClient,
TransferJob,
TransferOperation,
)
from google.protobuf.json_format import MessageToDict
from googleapiclient.discovery import Resource, build
from googleapiclient.errors import HttpError

@@ -603,7 +603,7 @@ async def list_transfer_operations(
self,
request_filter: dict | None = None,
**kwargs,
) -> list[TransferOperation]:
) -> list[dict[str, Any]]:
"""
List transfer operations in Google Storage Transfer Service.

@@ -660,7 +660,12 @@ async def list_transfer_operations(
)

transfer_operations = [
protobuf_helpers.from_any_pb(TransferOperation, op.metadata) for op in operations
MessageToDict(
getattr(op, "_pb", op),
preserving_proto_field_name=True,
use_integers_for_enums=True,
)
for op in operations
]

return transfer_operations
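A quick sketch of what the new comprehension above does, using an illustrative `TransferOperation` built locally rather than one returned by the API. The `getattr(op, "_pb", op)` unwrapping assumes operations may arrive either as proto-plus wrappers (which keep the raw protobuf under `_pb`) or as plain protobuf messages, and `MessageToDict` only accepts the latter:

```python
from google.cloud.storage_transfer_v1 import TransferOperation
from google.protobuf.json_format import MessageToDict

# Illustrative proto-plus message; the raw protobuf lives under `_pb`.
op = TransferOperation(name="transferOperations/op-1", project_id="my-project")

# Unwrap `_pb` when present, otherwise pass the object through unchanged.
as_dict = MessageToDict(
    getattr(op, "_pb", op),
    preserving_proto_field_name=True,  # keep snake_case field names
    use_integers_for_enums=True,       # keep enum fields as raw integers
)
print(as_dict)  # {'name': 'transferOperations/op-1', 'project_id': 'my-project'}
```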
@@ -677,7 +682,7 @@ async def _inject_project_id(self, body: dict, param_name: str, target_key: str)

@staticmethod
async def operations_contain_expected_statuses(
operations: list[TransferOperation], expected_statuses: set[str] | str
operations: list[dict[str, Any]], expected_statuses: set[str] | str
) -> bool:
"""
Check whether an operation exists with the expected status.
@@ -696,7 +701,7 @@ async def operations_contain_expected_statuses(
if not operations:
return False

current_statuses = {operation.status.name for operation in operations}
current_statuses = {TransferOperation.Status(op["metadata"]["status"]).name for op in operations}

if len(current_statuses - expected_statuses_set) != len(current_statuses):
return True
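For context, a small sketch of the lookup that `current_statuses` now performs, assuming the operations were serialized with `use_integers_for_enums=True` so the nested `metadata.status` value arrives as an integer:

```python
from google.cloud.storage_transfer_v1 import TransferOperation

# Shaped like one serialized operation; the integer is the raw enum value.
operation = {"metadata": {"status": int(TransferOperation.Status.SUCCESS)}}

# The enum constructor maps the integer back to a member, and .name recovers
# the string compared against the expected statuses.
assert TransferOperation.Status(operation["metadata"]["status"]).name == "SUCCESS"
```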
@@ -1082,11 +1082,11 @@ def execute_complete(self, context, event=None) -> None:
"""
job_state = event["job_state"]
job_id = event["job_id"]
if job_state == DataScanJob.State.FAILED:
if job_state == DataScanJob.State.FAILED.name: # type: ignore
raise AirflowException(f"Job failed:\n{job_id}")
if job_state == DataScanJob.State.CANCELLED:
if job_state == DataScanJob.State.CANCELLED.name: # type: ignore
raise AirflowException(f"Job was cancelled:\n{job_id}")
if job_state == DataScanJob.State.SUCCEEDED:
if job_state == DataScanJob.State.SUCCEEDED.name: # type: ignore
job = event["job"]
if not job["data_quality_result"]["passed"]:
if self.fail_on_dq_failure:
@@ -1260,11 +1260,11 @@ def execute_complete(self, context, event=None) -> None:
job_state = event["job_state"]
job_id = event["job_id"]
job = event["job"]
if job_state == DataScanJob.State.FAILED:
if job_state == DataScanJob.State.FAILED.name: # type: ignore
raise AirflowException(f"Job failed:\n{job_id}")
if job_state == DataScanJob.State.CANCELLED:
if job_state == DataScanJob.State.CANCELLED.name: # type: ignore
raise AirflowException(f"Job was cancelled:\n{job_id}")
if job_state == DataScanJob.State.SUCCEEDED:
if job_state == DataScanJob.State.SUCCEEDED.name: # type: ignore
if not job["data_quality_result"]["passed"]:
if self.fail_on_dq_failure:
raise AirflowDataQualityScanException(
@@ -1639,12 +1639,12 @@ def execute(self, context: Context) -> dict:
result_timeout=self.result_timeout,
)

if job.state == DataScanJob.State.FAILED:
if job.state == DataScanJob.State.FAILED.name: # type: ignore
raise AirflowException(f"Data Profile job failed: {job_id}")
if job.state == DataScanJob.State.SUCCEEDED:
if job.state == DataScanJob.State.SUCCEEDED.name: # type: ignore
self.log.info("Data Profile job executed successfully.")
else:
self.log.info("Data Profile job execution returned status: %s", job.status)
self.log.info("Data Profile job execution returned status: %s", job.state)

return job_id

@@ -1657,11 +1657,11 @@ def execute_complete(self, context, event=None) -> None:
"""
job_state = event["job_state"]
job_id = event["job_id"]
if job_state == DataScanJob.State.FAILED:
if job_state == DataScanJob.State.FAILED.name: # type: ignore
raise AirflowException(f"Job failed:\n{job_id}")
if job_state == DataScanJob.State.CANCELLED:
if job_state == DataScanJob.State.CANCELLED.name: # type: ignore
raise AirflowException(f"Job was cancelled:\n{job_id}")
if job_state == DataScanJob.State.SUCCEEDED:
if job_state == DataScanJob.State.SUCCEEDED.name: # type: ignore
self.log.info("Data Profile job executed successfully.")
return job_id
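A hedged sketch of the contract these `.name` comparisons rely on: the Dataplex trigger further down in this diff converts the enum into its string name before placing it in the `TriggerEvent`, so `execute_complete` compares strings against `DataScanJob.State.<MEMBER>.name` rather than enum members. The event below is illustrative, not taken from the PR:

```python
from google.cloud.dataplex_v1 import DataScanJob

# Illustrative payload, shaped like the event the trigger yields.
event = {"job_id": "job-123", "job_state": DataScanJob.State.SUCCEEDED.name}

# String-to-string comparison; the raw enum member would never equal "SUCCEEDED".
if event["job_state"] == DataScanJob.State.FAILED.name:
    raise RuntimeError(f"Job failed:\n{event['job_id']}")
if event["job_state"] == DataScanJob.State.SUCCEEDED.name:
    print("Data Profile job executed successfully.")
```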

@@ -1952,9 +1952,9 @@ def execute_complete(self, context, event=None) -> None:
job_state = event["job_state"]
job_id = event["job_id"]
job = event["job"]
if job_state == JobStatus.State.ERROR:
if job_state == JobStatus.State.ERROR.name: # type: ignore
raise AirflowException(f"Job {job_id} failed:\n{job}")
if job_state == JobStatus.State.CANCELLED:
if job_state == JobStatus.State.CANCELLED.name: # type: ignore
raise AirflowException(f"Job {job_id} was cancelled:\n{job}")
self.log.info("%s completed successfully.", self.task_id)
return job_id
@@ -2462,7 +2462,7 @@ def execute(self, context: Context):
if not self.hook.check_error_for_resource_is_not_ready_msg(batch.state_message):
break

self.handle_batch_status(context, batch.state, batch_id, batch.state_message)
self.handle_batch_status(context, batch.state.name, batch_id, batch.state_message)
return Batch.to_dict(batch)

@cached_property
@@ -2487,19 +2487,19 @@ def on_kill(self):
self.operation.cancel()

def handle_batch_status(
self, context: Context, state: Batch.State, batch_id: str, state_message: str | None = None
self, context: Context, state: str, batch_id: str, state_message: str | None = None
) -> None:
# The existing batch may be in a number of states other than 'SUCCEEDED'.
# wait_for_operation doesn't fail if the job is cancelled, so we check for it here, which also
# catches a cancelling, cancelled, or unspecified job from wait_for_batch or the deferred trigger.
link = DATAPROC_BATCH_LINK.format(region=self.region, project_id=self.project_id, batch_id=batch_id)
if state == Batch.State.FAILED:
if state == Batch.State.FAILED.name: # type: ignore
raise AirflowException(
f"Batch job {batch_id} failed with error: {state_message}.\nDriver logs: {link}"
)
if state in (Batch.State.CANCELLED, Batch.State.CANCELLING):
if state in (Batch.State.CANCELLED.name, Batch.State.CANCELLING.name): # type: ignore
raise AirflowException(f"Batch job {batch_id} was cancelled.\nDriver logs: {link}")
if state == Batch.State.STATE_UNSPECIFIED:
if state == Batch.State.STATE_UNSPECIFIED.name: # type: ignore
raise AirflowException(f"Batch job {batch_id} unspecified.\nDriver logs: {link}")
self.log.info("Batch job %s completed.\nDriver logs: %s", batch_id, link)

@@ -103,7 +103,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.polling_interval_seconds,
)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": self._convert_to_dict(job)})
yield TriggerEvent(
{
"job_id": self.job_id,
"job_state": DataScanJob.State(state).name,
"job": self._convert_to_dict(job),
}
)

def _convert_to_dict(self, job: DataScanJob) -> dict:
"""Return a representation of a DataScanJob instance as a dict."""
@@ -185,7 +191,13 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
self.polling_interval_seconds,
)
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": self._convert_to_dict(job)})
yield TriggerEvent(
{
"job_id": self.job_id,
"job_state": DataScanJob.State(state).name,
"job": self._convert_to_dict(job),
}
)

def _convert_to_dict(self, job: DataScanJob) -> dict:
"""Return a representation of a DataScanJob instance as a dict."""
@@ -27,7 +27,7 @@

from asgiref.sync import sync_to_async
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, Job, JobStatus

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
@@ -194,7 +194,9 @@ async def run(self):
if state in (JobStatus.State.DONE, JobStatus.State.CANCELLED, JobStatus.State.ERROR):
break
await asyncio.sleep(self.polling_interval_seconds)
yield TriggerEvent({"job_id": self.job_id, "job_state": state, "job": job})
yield TriggerEvent(
{"job_id": self.job_id, "job_state": JobStatus.State(state).name, "job": Job.to_dict(job)}
)
except asyncio.CancelledError:
self.log.info("Task got cancelled.")
try:
@@ -212,7 +214,12 @@ async def run(self):
job_id=self.job_id, project_id=self.project_id, region=self.region
)
self.log.info("Job: %s is cancelled", self.job_id)
yield TriggerEvent({"job_id": self.job_id, "job_state": ClusterStatus.State.DELETING})
yield TriggerEvent(
{
"job_id": self.job_id,
"job_state": ClusterStatus.State.DELETING.name,
}
)
except Exception as e:
self.log.error("Failed to cancel the job: %s with error : %s", self.job_id, str(e))
raise e
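For reference, a hedged sketch of the two serialization helpers the Dataproc job trigger now leans on: `Job.to_dict` turns the proto-plus message into a plain dict, and the `State(...).name` calls turn enums into strings. The job contents are illustrative:

```python
from google.cloud.dataproc_v1 import ClusterStatus, Job, JobStatus

# Illustrative job; a real one comes from the hook's get_job(...) call.
job = Job(status=JobStatus(state=JobStatus.State.DONE))

event = {
    "job_id": "job-123",
    "job_state": JobStatus.State(job.status.state).name,  # "DONE"
    "job": Job.to_dict(job),  # JSON-friendly dict instead of a proto message
}

# The cancellation branch only needs a state name, no proto payload.
cancel_event = {"job_id": "job-123", "job_state": ClusterStatus.State.DELETING.name}
```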
@@ -322,7 +329,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
yield TriggerEvent(
{
"cluster_name": self.cluster_name,
"cluster_state": ClusterStatus.State(ClusterStatus.State.DELETING).name,
"cluster_state": ClusterStatus.State.DELETING.name, # type: ignore
"cluster": Cluster.to_dict(cluster),
}
)
@@ -428,12 +435,16 @@ async def run(self):

if state in (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED):
break
self.log.info("Current state is %s", state)
self.log.info("Current state is %s", Batch.State(state).name)
self.log.info("Sleeping for %s seconds.", self.polling_interval_seconds)
await asyncio.sleep(self.polling_interval_seconds)

yield TriggerEvent(
{"batch_id": self.batch_id, "batch_state": state, "batch_state_message": batch.state_message}
{
"batch_id": self.batch_id,
"batch_state": Batch.State(state).name,
"batch_state_message": batch.state_message,
}
)
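A condensed sketch of the polling pattern above, with a locally built `Batch` standing in for the polled resource: comparisons inside the loop stay on enum members, and only the final event switches to the string name:

```python
from google.cloud.dataproc_v1 import Batch

TERMINAL_STATES = (Batch.State.FAILED, Batch.State.SUCCEEDED, Batch.State.CANCELLED)


def make_event(batch: Batch, batch_id: str) -> dict:
    # Internal checks use enum members; only the yielded payload uses the name.
    assert batch.state in TERMINAL_STATES
    return {
        "batch_id": batch_id,
        "batch_state": Batch.State(batch.state).name,
        "batch_state_message": batch.state_message,
    }


print(make_event(Batch(state=Batch.State.SUCCEEDED), "batch-123"))
```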


@@ -18,10 +18,12 @@
from __future__ import annotations

import json
from types import SimpleNamespace
from unittest import mock
from unittest.mock import AsyncMock

import pytest
from google.cloud.storage_transfer_v1.types.transfer_types import TransferOperation

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import (
@@ -124,23 +126,28 @@ async def test_get_last_operation_none(self, mock_deserialize, mock_conn, hook_a

@pytest.mark.asyncio
@mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
@mock.patch("google.api_core.protobuf_helpers.from_any_pb")
async def test_list_transfer_operations(self, from_any_pb, mock_conn, hook_async):
expected_operations = [mock.MagicMock(), mock.MagicMock()]
from_any_pb.side_effect = expected_operations

mock_conn.return_value.list_operations.side_effect = [
mock.MagicMock(next_page_token="token", operations=[mock.MagicMock()]),
mock.MagicMock(next_page_token=None, operations=[mock.MagicMock()]),
]
@mock.patch(f"{TRANSFER_HOOK_PATH}.MessageToDict")
async def test_list_transfer_operations(self, message_to_dict, mock_conn, hook_async):
expected = [{"name": "op1"}, {"name": "op2"}]
message_to_dict.side_effect = expected

op_with_pb = SimpleNamespace(_pb=mock.sentinel.pb1)
op_without_pb = object()

first_page = mock.MagicMock(next_page_token="token", operations=[op_with_pb])
second_page = mock.MagicMock(next_page_token=None, operations=[op_without_pb])
mock_conn.return_value.list_operations.side_effect = [first_page, second_page]

actual_operations = await hook_async.list_transfer_operations(
request_filter={
"project_id": TEST_PROJECT_ID,
},
actual = await hook_async.list_transfer_operations(
request_filter={"project_id": TEST_PROJECT_ID},
)
assert actual_operations == expected_operations

assert actual == expected
assert mock_conn.return_value.list_operations.call_count == 2
assert message_to_dict.call_args_list == [
mock.call(mock.sentinel.pb1, preserving_proto_field_name=True, use_integers_for_enums=True),
mock.call(op_without_pb, preserving_proto_field_name=True, use_integers_for_enums=True),
]
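A brief sketch of the two fixture shapes the rewritten test pushes through the hook, assuming the production-side `getattr(op, "_pb", op)` fallback: one operation exposes `_pb` (the proto-plus case) and one does not (the plain protobuf case), so both branches are exercised:

```python
from types import SimpleNamespace
from unittest import mock

# Proto-plus style: MessageToDict should receive the wrapped `_pb`.
op_with_pb = SimpleNamespace(_pb=mock.sentinel.pb1)
assert getattr(op_with_pb, "_pb", op_with_pb) is mock.sentinel.pb1

# Plain protobuf style: no `_pb`, so the object itself is passed through.
op_without_pb = object()
assert getattr(op_without_pb, "_pb", op_without_pb) is op_without_pb
```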

@pytest.mark.asyncio
@pytest.mark.parametrize(
@@ -158,14 +165,23 @@ async def test_list_transfer_operations(self, from_any_pb, mock_conn, hook_async
],
)
async def test_operations_contain_expected_statuses_red_path(self, statuses, expected_statuses):
operations = [mock.MagicMock(**{"status.name": status}) for status in statuses]
def to_name(x):
return x.name if hasattr(x, "name") else x

def proto_int(name: str) -> int:
return int(getattr(TransferOperation.Status, name))

operations = [{"metadata": {"status": proto_int(to_name(s))}} for s in statuses]

expected_names = tuple(to_name(s) for s in expected_statuses)

with pytest.raises(
AirflowException,
match=f"An unexpected operation status was encountered. Expected: {', '.join(expected_statuses)}",
match=f"An unexpected operation status was encountered. Expected: {', '.join(expected_names)}",
):
await CloudDataTransferServiceAsyncHook.operations_contain_expected_statuses(
operations, GcpTransferOperationStatus.IN_PROGRESS
operations,
GcpTransferOperationStatus.IN_PROGRESS,
)

@pytest.mark.asyncio
@@ -193,10 +209,17 @@ async def test_operations_contain_expected_statuses_red_path(self, statuses, exp
],
)
async def test_operations_contain_expected_statuses_green_path(self, statuses, expected_statuses):
operations = [mock.MagicMock(**{"status.name": status}) for status in statuses]
def to_name(x):
    return x.name if hasattr(x, "name") else x

def name_to_proto_int(name: str) -> int:
    return int(getattr(TransferOperation.Status, name))

operations = [{"metadata": {"status": name_to_proto_int(to_name(s))}} for s in statuses]

if isinstance(expected_statuses, (list, tuple, set)):
expected_norm = {to_name(s) for s in expected_statuses}
else:
expected_norm = to_name(expected_statuses)

result = await CloudDataTransferServiceAsyncHook.operations_contain_expected_statuses(
operations, expected_statuses
operations, expected_norm
)

assert result
assert result is True
@@ -3314,7 +3314,7 @@ def test_execute_batch_already_exists_succeeds(self, mock_hook, mock_log):
)
mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED)
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.SUCCEEDED.name)

op.execute(context=MagicMock())
mock_hook.return_value.wait_for_batch.assert_called_once_with(
@@ -3357,7 +3357,7 @@ def test_execute_batch_already_exists_fails(self, mock_hook, mock_log):
)
mock_hook.return_value.create_batch.side_effect = AlreadyExists("")
mock_hook.return_value.create_batch.return_value.metadata.batch = f"prefix/{BATCH_ID}"
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.FAILED)
mock_hook.return_value.wait_for_batch.return_value = Batch(state=Batch.State.FAILED.name)

with pytest.raises(AirflowException) as exc:
op.execute(context=MagicMock())
@@ -92,7 +92,7 @@ async def test_async_dataplex_job_triggers_on_success_should_execute_successfull
expected_event = TriggerEvent(
{
"job_id": TEST_JOB_ID,
"job_state": DataScanJob.State.SUCCEEDED,
"job_state": DataScanJob.State.SUCCEEDED.name,
"job": {},
}
)
@@ -113,7 +113,7 @@ async def test_async_dataplex_job_trigger_run_returns_error_event(
await asyncio.sleep(0.5)

expected_event = TriggerEvent(
{"job_id": TEST_JOB_ID, "job_state": DataScanJob.State.FAILED, "job": {}}
{"job_id": TEST_JOB_ID, "job_state": DataScanJob.State.FAILED.name, "job": {}}
)
assert expected_event == actual_event

Expand Down
Loading