diff --git a/astronomer/providers/google/cloud/example_dags/example_bigquery_sensors.py b/astronomer/providers/google/cloud/example_dags/example_bigquery_sensors.py
new file mode 100644
index 000000000..ca2b31785
--- /dev/null
+++ b/astronomer/providers/google/cloud/example_dags/example_bigquery_sensors.py
@@ -0,0 +1,97 @@
+"""Example Airflow DAG for Google BigQuery Sensors."""
+import os
+from datetime import datetime
+
+from airflow import DAG
+from airflow.providers.google.cloud.operators.bigquery import (
+    BigQueryCreateEmptyDatasetOperator,
+    BigQueryCreateEmptyTableOperator,
+    BigQueryDeleteDatasetOperator,
+    BigQueryInsertJobOperator,
+)
+from airflow.providers.google.cloud.sensors.bigquery import (
+    BigQueryTablePartitionExistenceSensor,
+)
+
+from astronomer.providers.google.cloud.sensors.bigquery import (
+    BigQueryTableExistenceSensorAsync,
+)
+
+PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "astronomer-airflow-providers")
+DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "astro_dataset")
+GCP_CONN_ID = os.environ.get("GCP_CONN_ID", "google_cloud_default")
+LOCATION = os.environ.get("GCP_LOCATION", "us")
+
+TABLE_NAME = os.environ.get("TABLE_NAME", "partitioned_table")
+INSERT_DATE = datetime.now().strftime("%Y-%m-%d")
+
+PARTITION_NAME = "{{ ds_nodash }}"
+
+INSERT_ROWS_QUERY = f"INSERT {DATASET_NAME}.{TABLE_NAME} VALUES " "(42, '{{ ds }}')"
+
+SCHEMA = [
+    {"name": "value", "type": "INTEGER", "mode": "REQUIRED"},
+    {"name": "ds", "type": "DATE", "mode": "NULLABLE"},
+]
+
+dag_id = "example_bigquery_sensors"
+
+with DAG(
+    dag_id,
+    schedule_interval=None,  # Override to match your needs
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+    tags=["example", "async", "bigquery", "sensors"],
+    default_args={"gcp_conn_id": GCP_CONN_ID},
+) as dag:
+
+    create_dataset = BigQueryCreateEmptyDatasetOperator(
+        task_id="create-dataset", dataset_id=DATASET_NAME, project_id=PROJECT_ID
+    )
+
+    create_table = BigQueryCreateEmptyTableOperator(
+        task_id="create_table",
+        dataset_id=DATASET_NAME,
+        table_id=TABLE_NAME,
+        schema_fields=SCHEMA,
+        time_partitioning={
+            "type": "DAY",
+            "field": "ds",
+        },
+    )
+    # [START howto_sensor_bigquery_table]
+    check_table_exists = BigQueryTableExistenceSensorAsync(
+        task_id="check_table_exists", project_id=PROJECT_ID, dataset_id=DATASET_NAME, table_id=TABLE_NAME
+    )
+    # [END howto_sensor_bigquery_table]
+
+    execute_insert_query = BigQueryInsertJobOperator(
+        task_id="execute_insert_query",
+        configuration={
+            "query": {
+                "query": INSERT_ROWS_QUERY,
+                "useLegacySql": False,
+            }
+        },
+    )
+
+    # [START howto_sensor_bigquery_table_partition]
+    check_table_partition_exists = BigQueryTablePartitionExistenceSensor(
+        task_id="check_table_partition_exists",
+        project_id=PROJECT_ID,
+        dataset_id=DATASET_NAME,
+        table_id=TABLE_NAME,
+        partition_id=PARTITION_NAME,
+    )
+    # [END howto_sensor_bigquery_table_partition]
+
+    delete_dataset = BigQueryDeleteDatasetOperator(
+        task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True
+    )
+
+    create_dataset >> create_table
+    create_table >> check_table_exists
+    create_table >> execute_insert_query
+    execute_insert_query >> check_table_partition_exists
+    check_table_exists >> delete_dataset
+    check_table_partition_exists >> delete_dataset
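A note on `INSERT_ROWS_QUERY` above: the f-string part (dataset and table names) is resolved once at DAG-parse time, while `{{ ds }}` is left for Airflow's template engine to fill in per run. Assuming the default environment values, a run with logical date 2021-01-01 would hand the operator roughly the following (an illustration, not part of this diff):

```python
# Hypothetical rendered query for logical date 2021-01-01 with the default
# DATASET_NAME ("astro_dataset") and TABLE_NAME ("partitioned_table"):
rendered_query = "INSERT astro_dataset.partitioned_table VALUES (42, '2021-01-01')"
```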
diff --git a/astronomer/providers/google/cloud/hooks/bigquery.py b/astronomer/providers/google/cloud/hooks/bigquery.py
index 85659ab9e..d05f190a0 100644
--- a/astronomer/providers/google/cloud/hooks/bigquery.py
+++ b/astronomer/providers/google/cloud/hooks/bigquery.py
@@ -4,7 +4,7 @@
 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, _bq_cast
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
-from gcloud.aio.bigquery import Job
+from gcloud.aio.bigquery import Job, Table
 from google.cloud.bigquery import CopyJob, ExtractJob, LoadJob, QueryJob
 from requests import Session
 
@@ -299,3 +299,31 @@ def interval_check(
             raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
 
         self.log.info("All tests have passed")
+
+
+class BigQueryTableHookAsync(GoogleBaseHookAsync):
+    """Async hook providing a client for BigQuery tables."""
+
+    sync_hook_class = BigQueryHook
+
+    async def get_table_client(
+        self, dataset: str, table_id: str, project_id: str, session: ClientSession
+    ) -> Table:
+        """
+        Returns a gcloud-aio BigQuery Table client.
+
+        :param dataset: The name of the dataset in which to look for the table.
+        :param table_id: The name of the table to check the existence of.
+        :param project_id: The Google Cloud project in which to look for the table.
+            The connection supplied to the hook must provide
+            access to the specified project.
+        :param session: aiohttp ClientSession
+        """
+        with await self.service_file_as_context() as file:
+            return Table(
+                dataset_name=dataset,
+                table_name=table_id,
+                project=project_id,
+                service_file=file,
+                session=cast(Session, session),
+            )
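For orientation, here is a minimal sketch of driving the new hook by hand, outside of a trigger. It assumes an Airflow connection named `google_cloud_default` with valid GCP credentials; the project, dataset, and table identifiers are placeholders, not part of this diff:

```python
import asyncio

from aiohttp import ClientSession

from astronomer.providers.google.cloud.hooks.bigquery import BigQueryTableHookAsync


async def fetch_table_metadata() -> dict:
    # Build the async hook from a (hypothetical) Airflow connection.
    hook = BigQueryTableHookAsync(gcp_conn_id="google_cloud_default")
    async with ClientSession() as session:
        table = await hook.get_table_client(
            dataset="astro_dataset",
            table_id="partitioned_table",
            project_id="my-gcp-project",  # placeholder project id
            session=session,
        )
        # Table.get() issues a tables.get call against the BigQuery REST API
        # and returns the raw response dict.
        return await table.get()


print(asyncio.run(fetch_table_metadata()))
```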
diff --git a/astronomer/providers/google/cloud/sensors/bigquery.py b/astronomer/providers/google/cloud/sensors/bigquery.py
new file mode 100644
index 000000000..12643155f
--- /dev/null
+++ b/astronomer/providers/google/cloud/sensors/bigquery.py
@@ -0,0 +1,78 @@
+"""This module contains Google BigQuery sensors."""
+from typing import Any, Dict, Optional
+
+from airflow.exceptions import AirflowException
+from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor
+
+from astronomer.providers.google.cloud.triggers.bigquery import (
+    BigQueryTableExistenceTrigger,
+)
+
+
+class BigQueryTableExistenceSensorAsync(BigQueryTableExistenceSensor):
+    """
+    Checks for the existence of a table in Google BigQuery.
+
+    :param project_id: The Google Cloud project in which to look for the table.
+        The connection supplied to the hook must provide
+        access to the specified project.
+    :param dataset_id: The name of the dataset in which to look for the table.
+    :param table_id: The name of the table to check the existence of.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
+    :param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
+        This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
+    :param delegate_to: The account to impersonate using domain-wide delegation of authority,
+        if any. For this to work, the service account making the request must have
+        domain-wide delegation enabled.
+    :param impersonation_chain: Optional service account to impersonate using short-term
+        credentials, or chained list of accounts required to get the access_token
+        of the last account in the list, which will be impersonated in the request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant the
+        Service Account Token Creator IAM role to the directly preceding identity, with the first
+        account in the list granting this role to the originating account (templated).
+    :param polling_interval: The interval in seconds to wait between checks of table existence.
+    """
+
+    def __init__(
+        self,
+        gcp_conn_id: str = "google_cloud_default",
+        polling_interval: float = 5.0,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.polling_interval = polling_interval
+        self.gcp_conn_id = gcp_conn_id
+
+    def execute(self, context: Dict[str, Any]) -> None:
+        """Airflow runs this method on the worker and defers using the trigger."""
+        self.defer(
+            timeout=self.execution_timeout,
+            trigger=BigQueryTableExistenceTrigger(
+                dataset_id=self.dataset_id,
+                table_id=self.table_id,
+                project_id=self.project_id,
+                poll_interval=self.polling_interval,
+                gcp_conn_id=self.gcp_conn_id,
+                hook_params={
+                    "delegate_to": self.delegate_to,
+                    "impersonation_chain": self.impersonation_chain,
+                },
+            ),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str:
+        """
+        Callback for when the trigger fires; returns immediately.
+        Relies on the trigger to throw an exception, otherwise it assumes execution was
+        successful.
+        """
+        table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
+        self.log.info("Sensor checks existence of table: %s", table_uri)
+        if event:
+            if event["status"] == "success":
+                return event["message"]
+            raise AirflowException(event["message"])
+        raise AirflowException("No event received in trigger callback")
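The round trip between sensor and trigger is an event dict with `status` and `message` keys. A quick illustration of how `execute_complete` treats each shape (placeholder identifiers; a sketch, not part of this diff):

```python
from airflow.exceptions import AirflowException

from astronomer.providers.google.cloud.sensors.bigquery import (
    BigQueryTableExistenceSensorAsync,
)

sensor = BigQueryTableExistenceSensorAsync(
    task_id="check_table_exists",
    project_id="my-gcp-project",  # placeholder values
    dataset_id="astro_dataset",
    table_id="partitioned_table",
)

# A success event returns the trigger's message ...
assert sensor.execute_complete(context={}, event={"status": "success", "message": "success"}) == "success"

# ... while an error event, or a missing event, raises AirflowException.
try:
    sensor.execute_complete(context={}, event={"status": "error", "message": "boom"})
except AirflowException as err:
    print(err)  # -> boom
```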
+ "gcp_conn_id": self.gcp_conn_id, + "poll_interval": self.poll_interval, + "hook_params": self.hook_params, + }, + ) + + def _get_async_hook(self) -> BigQueryTableHookAsync: + return BigQueryTableHookAsync(gcp_conn_id=self.gcp_conn_id) + + async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override] + """Will run until the table exists in the Google Big Query.""" + while True: + try: + hook = self._get_async_hook() + response = await self._table_exists( + hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id + ) + if response: + yield TriggerEvent({"status": "success", "message": "success"}) + return + await asyncio.sleep(self.poll_interval) + except Exception as e: + self.log.exception("Exception occurred while checking for Table existence") + yield TriggerEvent({"status": "error", "message": str(e)}) + return + + async def _table_exists( + self, hook: BigQueryTableHookAsync, dataset: str, table_id: str, project_id: str + ) -> bool: + """ + Create client session and make call to BigQueryTableHookAsync and check for the table in Google Big Query. + + :param hook: BigQueryTableHookAsync Hook class + :param dataset: The name of the dataset in which to look for the table storage bucket. + :param table_id: The name of the table to check the existence of. + :param project_id: The Google cloud project in which to look for the table. + The connection supplied to the hook must provide + access to the specified project. + """ + async with ClientSession() as session: + try: + client = await hook.get_table_client( + dataset=dataset, table_id=table_id, project_id=project_id, session=session + ) + response = await client.get() + return True if response else False + except ClientResponseError as err: + if err.status == 404: + return False + raise err diff --git a/tests/google/cloud/hooks/test_bigquery.py b/tests/google/cloud/hooks/test_bigquery.py index 5cc342d09..836b5268e 100644 --- a/tests/google/cloud/hooks/test_bigquery.py +++ b/tests/google/cloud/hooks/test_bigquery.py @@ -2,10 +2,11 @@ import pytest from airflow.exceptions import AirflowException -from gcloud.aio.bigquery import Job +from gcloud.aio.bigquery import Job, Table from astronomer.providers.google.cloud.hooks.bigquery import ( BigQueryHookAsync, + BigQueryTableHookAsync, _BigQueryHook, ) @@ -259,3 +260,14 @@ def test_convert_to_float_if_possible(test_input, expected): """ assert BigQueryHookAsync._convert_to_float_if_possible(test_input) == expected + + +@pytest.mark.asyncio +@mock.patch("aiohttp.client.ClientSession") +async def test_get_table_client(mock_session): + """Test get_table_client async function and check whether the return value is a Table instance object""" + hook = BigQueryTableHookAsync() + result = await hook.get_table_client( + dataset=DATASET_ID, project_id=PROJECT_ID, table_id=TABLE_ID, session=mock_session + ) + assert isinstance(result, Table) diff --git a/tests/google/cloud/sensors/test_bigquery.py b/tests/google/cloud/sensors/test_bigquery.py new file mode 100644 index 000000000..fd3e45589 --- /dev/null +++ b/tests/google/cloud/sensors/test_bigquery.py @@ -0,0 +1,80 @@ +from unittest import mock + +import pytest +from airflow.exceptions import AirflowException, TaskDeferred + +from astronomer.providers.google.cloud.sensors.bigquery import ( + BigQueryTableExistenceSensorAsync, +) +from astronomer.providers.google.cloud.triggers.bigquery import ( + BigQueryTableExistenceTrigger, +) + +PROJECT_ID = "test-astronomer-airflow-providers" +DATASET_NAME = 
"test-astro_dataset" +TABLE_NAME = "test-partitioned_table" + + +@pytest.fixture() +def context(): + """ + Creates an empty context. + """ + context = {} + yield context + + +def test_big_query_table_existence_sensor_async(): + """ + Asserts that a task is deferred and a BigQueryTableExistenceTrigger will be fired + when the BigQueryTableExistenceSensorAsync is executed. + """ + task = BigQueryTableExistenceSensorAsync( + task_id="check_table_exists", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + ) + with pytest.raises(TaskDeferred) as exc: + task.execute(context) + assert isinstance( + exc.value.trigger, BigQueryTableExistenceTrigger + ), "Trigger is not a BigQueryTableExistenceTrigger" + + +def test_big_query_table_existence_sensor_async_execute_failure(context): + """Tests that an AirflowException is raised in case of error event""" + task = BigQueryTableExistenceSensorAsync( + task_id="task-id", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + ) + with pytest.raises(AirflowException): + task.execute_complete(context=None, event={"status": "error", "message": "test failure message"}) + + +def test_big_query_table_existence_sensor_async_execute_complete(): + """Asserts that logging occurs as expected""" + task = BigQueryTableExistenceSensorAsync( + task_id="task-id", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + ) + table_uri = f"{PROJECT_ID}:{DATASET_NAME}.{TABLE_NAME}" + with mock.patch.object(task.log, "info") as mock_log_info: + task.execute_complete(context=None, event={"status": "success", "message": "Job completed"}) + mock_log_info.assert_called_with("Sensor checks existence of table: %s", table_uri) + + +def test_redshift_sensor_async_execute_complete_event_none(): + """Asserts that logging occurs as expected""" + task = BigQueryTableExistenceSensorAsync( + task_id="task-id", + project_id=PROJECT_ID, + dataset_id=DATASET_NAME, + table_id=TABLE_NAME, + ) + with pytest.raises(AirflowException): + task.execute_complete(context=None, event=None) diff --git a/tests/google/cloud/triggers/test_bigquery.py b/tests/google/cloud/triggers/test_bigquery.py index c958869c3..547bc8ae9 100644 --- a/tests/google/cloud/triggers/test_bigquery.py +++ b/tests/google/cloud/triggers/test_bigquery.py @@ -3,13 +3,19 @@ from unittest import mock import pytest +from aiohttp import ClientResponseError, RequestInfo from airflow.triggers.base import TriggerEvent +from gcloud.aio.bigquery import Table +from multidict import CIMultiDict +from yarl import URL +from astronomer.providers.google.cloud.hooks.bigquery import BigQueryTableHookAsync from astronomer.providers.google.cloud.triggers.bigquery import ( BigQueryCheckTrigger, BigQueryGetDataTrigger, BigQueryInsertJobTrigger, BigQueryIntervalCheckTrigger, + BigQueryTableExistenceTrigger, BigQueryValueCheckTrigger, ) @@ -32,6 +38,8 @@ TEST_DAYS_BACK = -7 TEST_RATIO_FORMULA = "max_over_min" TEST_IGNORE_ZERO = True +TEST_GCP_CONN_ID = "TEST_GCP_CONN_ID" +TEST_HOOK_PARAMS = {} def test_bigquery_insert_job_op_trigger_serialization(): @@ -752,3 +760,179 @@ async def test_bigquery_value_check_trigger_exception(mock_job_status): assert len(task) == 1 assert TriggerEvent({"status": "error", "message": "Test exception"}) in task + + +def test_big_query_table_existence_trigger_serialization(): + """ + Asserts that the BigQueryTableExistenceTrigger correctly serializes its arguments + and classpath. 
+ """ + trigger = BigQueryTableExistenceTrigger( + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + TEST_GCP_CONN_ID, + TEST_HOOK_PARAMS, + POLLING_PERIOD_SECONDS, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger" + assert kwargs == { + "dataset_id": TEST_DATASET_ID, + "project_id": TEST_GCP_PROJECT_ID, + "table_id": TEST_TABLE_ID, + "gcp_conn_id": TEST_GCP_CONN_ID, + "poll_interval": POLLING_PERIOD_SECONDS, + "hook_params": TEST_HOOK_PARAMS, + } + + +@pytest.mark.asyncio +@mock.patch("astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists") +async def test_big_query_table_existence_trigger_success(mock_table_exists): + """ + Tests success case BigQueryTableExistenceTrigger + """ + mock_table_exists.return_value = True + + trigger = BigQueryTableExistenceTrigger( + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + TEST_GCP_CONN_ID, + TEST_HOOK_PARAMS, + POLLING_PERIOD_SECONDS, + ) + + task = [i async for i in trigger.run()] + assert len(task) == 1 + assert TriggerEvent({"status": "success", "message": "success"}) in task + + +@pytest.mark.asyncio +@mock.patch("astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists") +async def test_big_query_table_existence_trigger_pending(mock_table_exists): + """ + Test that BigQueryTableExistenceTrigger is in loop till the table exist. + """ + mock_table_exists.return_value = False + + trigger = BigQueryTableExistenceTrigger( + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + TEST_GCP_CONN_ID, + TEST_HOOK_PARAMS, + POLLING_PERIOD_SECONDS, + ) + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was not returned + assert task.done() is False + asyncio.get_event_loop().stop() + + +@pytest.mark.asyncio +@mock.patch("astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger._table_exists") +async def test_big_query_table_existence_trigger_exception(mock_table_exists): + """ + Test BigQueryTableExistenceTrigger throws exception if any error. 
+ """ + mock_table_exists.side_effect = mock.AsyncMock(side_effect=Exception("Test exception")) + + trigger = BigQueryTableExistenceTrigger( + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + TEST_GCP_CONN_ID, + TEST_HOOK_PARAMS, + POLLING_PERIOD_SECONDS, + ) + task = [i async for i in trigger.run()] + assert len(task) == 1 + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "mock_get_table_client_value, expected_value", + [ + ( + Table, + True, + ) + ], +) +@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.BigQueryTableHookAsync.get_table_client") +async def test_table_exists(mock_get_table_client, mock_get_table_client_value, expected_value): + """Test BigQueryTableExistenceTrigger._table_exists async function with mocked value and mocked return value""" + hook = mock.AsyncMock(BigQueryTableHookAsync) + mock_get_table_client.return_value = mock_get_table_client_value + trigger = BigQueryTableExistenceTrigger( + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + TEST_GCP_CONN_ID, + TEST_HOOK_PARAMS, + POLLING_PERIOD_SECONDS, + ) + res = await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID) + assert res == expected_value + + +@pytest.mark.asyncio +@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.BigQueryTableHookAsync.get_table_client") +async def test_table_exists_exception(mock_get_table_client): + """Test BigQueryTableExistenceTrigger._table_exists async function with exception and return False""" + hook = BigQueryTableHookAsync() + mock_get_table_client.side_effect = ClientResponseError( + history=(), + request_info=RequestInfo( + headers=CIMultiDict(), + real_url=URL("https://example.com"), + method="GET", + url=URL("https://example.com"), + ), + status=404, + message="Not Found", + ) + trigger = BigQueryTableExistenceTrigger( + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + TEST_GCP_CONN_ID, + TEST_HOOK_PARAMS, + POLLING_PERIOD_SECONDS, + ) + res = await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID) + expected_response = False + assert res == expected_response + + +@pytest.mark.asyncio +@mock.patch("astronomer.providers.google.cloud.hooks.bigquery.BigQueryTableHookAsync.get_table_client") +async def test_table_exists_raise_exception(mock_get_table_client): + """Test BigQueryTableExistenceTrigger._table_exists async function with raise exception""" + hook = BigQueryTableHookAsync() + mock_get_table_client.side_effect = ClientResponseError( + history=(), + request_info=RequestInfo( + headers=CIMultiDict(), + real_url=URL("https://example.com"), + method="GET", + url=URL("https://example.com"), + ), + status=400, + message="Not Found", + ) + trigger = BigQueryTableExistenceTrigger( + TEST_GCP_PROJECT_ID, + TEST_DATASET_ID, + TEST_TABLE_ID, + TEST_GCP_CONN_ID, + TEST_HOOK_PARAMS, + POLLING_PERIOD_SECONDS, + ) + with pytest.raises(ClientResponseError): + await trigger._table_exists(hook, TEST_DATASET_ID, TEST_TABLE_ID, TEST_GCP_PROJECT_ID)