Implement Async Big Query Table Existence Sensor #135

Merged · 10 commits · Mar 23, 2022
@@ -0,0 +1,97 @@
"""Example Airflow DAG for Google BigQuery Sensors."""
import os
from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.bigquery import (
    BigQueryCreateEmptyDatasetOperator,
    BigQueryCreateEmptyTableOperator,
    BigQueryDeleteDatasetOperator,
    BigQueryInsertJobOperator,
)
from airflow.providers.google.cloud.sensors.bigquery import (
    BigQueryTablePartitionExistenceSensor,
)

from astronomer.providers.google.cloud.sensors.bigquery import (
    BigQueryTableExistenceSensorAsync,
)

PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "astronomer-airflow-providers")
DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "astro_dataset")
GCP_CONN_ID = os.environ.get("GCP_CONN_ID", "google_cloud_default")
LOCATION = os.environ.get("GCP_LOCATION", "us")

TABLE_NAME = os.environ.get("TABLE_NAME", "partitioned_table")
INSERT_DATE = datetime.now().strftime("%Y-%m-%d")

PARTITION_NAME = "{{ ds_nodash }}"

INSERT_ROWS_QUERY = f"INSERT {DATASET_NAME}.{TABLE_NAME} VALUES " "(42, '{{ ds }}')"

SCHEMA = [
{"name": "value", "type": "INTEGER", "mode": "REQUIRED"},
{"name": "ds", "type": "DATE", "mode": "NULLABLE"},
]

dag_id = "example_bigquery_sensors"

with DAG(
    dag_id,
    schedule_interval=None,  # Override to match your needs
    start_date=datetime(2021, 1, 1),
    catchup=False,
    tags=["example", "async", "bigquery", "sensors"],
    default_args={"gcp_conn_id": GCP_CONN_ID},
) as dag:

    create_dataset = BigQueryCreateEmptyDatasetOperator(
        task_id="create-dataset", dataset_id=DATASET_NAME, project_id=PROJECT_ID
    )

    create_table = BigQueryCreateEmptyTableOperator(
        task_id="create_table",
        dataset_id=DATASET_NAME,
        table_id=TABLE_NAME,
        schema_fields=SCHEMA,
        time_partitioning={
            "type": "DAY",
            "field": "ds",
        },
    )
    # [START howto_sensor_bigquery_table]
    check_table_exists = BigQueryTableExistenceSensorAsync(
        task_id="check_table_exists", project_id=PROJECT_ID, dataset_id=DATASET_NAME, table_id=TABLE_NAME
    )
    # [END howto_sensor_bigquery_table]

    execute_insert_query = BigQueryInsertJobOperator(
        task_id="execute_insert_query",
        configuration={
            "query": {
                "query": INSERT_ROWS_QUERY,
                "useLegacySql": False,
            }
        },
    )

    # [START howto_sensor_bigquery_table_partition]
    check_table_partition_exists = BigQueryTablePartitionExistenceSensor(
        task_id="check_table_partition_exists",
        project_id=PROJECT_ID,
        dataset_id=DATASET_NAME,
        table_id=TABLE_NAME,
        partition_id=PARTITION_NAME,
    )
    # [END howto_sensor_bigquery_table_partition]

    delete_dataset = BigQueryDeleteDatasetOperator(
        task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True
    )

    create_dataset >> create_table
    create_table >> check_table_exists
    create_table >> execute_insert_query
    execute_insert_query >> check_table_partition_exists
    check_table_exists >> delete_dataset
    check_table_partition_exists >> delete_dataset
30 changes: 29 additions & 1 deletion astronomer/providers/google/cloud/hooks/bigquery.py
@@ -4,7 +4,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, _bq_cast
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from gcloud.aio.bigquery import Job
from gcloud.aio.bigquery import Job, Table
from google.cloud.bigquery import CopyJob, ExtractJob, LoadJob, QueryJob
from requests import Session

@@ -299,3 +299,31 @@ def interval_check(
raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")

self.log.info("All tests have passed")


class BigQueryTableHookAsync(GoogleBaseHookAsync):
    """Async hook for BigQuery tables."""

    sync_hook_class = BigQueryHook

    async def get_table_client(
        self, dataset: str, table_id: str, project_id: str, session: ClientSession
    ) -> Table:
        """
        Returns a Google BigQuery Table object.

        :param dataset: The name of the dataset in which to look for the table.
        :param table_id: The name of the table to check the existence of.
        :param project_id: The Google Cloud project in which to look for the table.
            The connection supplied to the hook must provide
            access to the specified project.
        :param session: aiohttp ClientSession
        """
        with await self.service_file_as_context() as file:
            return Table(
                dataset_name=dataset,
                table_name=table_id,
                project=project_id,
                service_file=file,
                session=cast(Session, session),
            )
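A brief hedged sketch of how a caller might use `get_table_client` to probe for a table; the `google_cloud_default` connection ID and the helper name `table_exists` are placeholders, and this mirrors the pattern the new trigger below follows:

```python
# Sketch only: use the async table client to check whether a table exists.
from aiohttp import ClientSession

from astronomer.providers.google.cloud.hooks.bigquery import BigQueryTableHookAsync


async def table_exists(dataset: str, table_id: str, project_id: str) -> bool:
    hook = BigQueryTableHookAsync(gcp_conn_id="google_cloud_default")
    async with ClientSession() as session:
        client = await hook.get_table_client(
            dataset=dataset, table_id=table_id, project_id=project_id, session=session
        )
        # client.get() returns the table resource; a missing table surfaces as an HTTP error.
        return bool(await client.get())
```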
78 changes: 78 additions & 0 deletions astronomer/providers/google/cloud/sensors/bigquery.py
@@ -0,0 +1,78 @@
"""This module contains Google Big Query sensors."""
from typing import Any, Dict, Optional

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor

from astronomer.providers.google.cloud.triggers.bigquery import (
BigQueryTableExistenceTrigger,
)


class BigQueryTableExistenceSensorAsync(BigQueryTableExistenceSensor):
"""
Checks for the existence of a table in Google Big Query.
:param project_id: The Google cloud project in which to look for the table.
The connection supplied to the hook must provide
access to the specified project.
:param dataset_id: The name of the dataset in which to look for the table.
storage bucket.
:param table_id: The name of the table to check the existence of.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param polling_interval: The interval in seconds to wait between checks table existence.
"""

def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
polling_interval: float = 5.0,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.polling_interval = polling_interval
self.gcp_conn_id = gcp_conn_id

def execute(self, context: Dict[str, Any]) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryTableExistenceTrigger(
dataset_id=self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
poll_interval=self.polling_interval,
gcp_conn_id=self.gcp_conn_id,
hook_params={
"delegate_to": self.delegate_to,
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)

def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
self.log.info("Sensor checks existence of table: %s", table_uri)
if event:
if event["status"] == "success":
return event["message"]
raise AirflowException(event["message"])
raise AirflowException("No event received in trigger callback")
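A minimal hedged sketch of the `execute_complete` contract (a success event returns the trigger's message, an error or missing event raises); the project, dataset, and table names are placeholders:

```python
# Sketch only: exercises execute_complete() directly with hand-built events.
import pytest
from airflow.exceptions import AirflowException

from astronomer.providers.google.cloud.sensors.bigquery import (
    BigQueryTableExistenceSensorAsync,
)


def _sensor() -> BigQueryTableExistenceSensorAsync:
    # Placeholder identifiers; any values satisfy the constructor.
    return BigQueryTableExistenceSensorAsync(
        task_id="check_table_exists",
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="my_table",
    )


def test_execute_complete_success():
    # A success event returns the trigger's message unchanged.
    event = {"status": "success", "message": "success"}
    assert _sensor().execute_complete(context={}, event=event) == "success"


def test_execute_complete_error():
    # An error event (or a missing event) surfaces as an AirflowException.
    with pytest.raises(AirflowException):
        _sensor().execute_complete(context={}, event={"status": "error", "message": "boom"})
```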
86 changes: 85 additions & 1 deletion astronomer/providers/google/cloud/triggers/bigquery.py
@@ -1,9 +1,14 @@
import asyncio
from typing import Any, AsyncIterator, Dict, Optional, SupportsAbs, Tuple, Union

from aiohttp import ClientSession
from aiohttp.client_exceptions import ClientResponseError
from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.google.cloud.hooks.bigquery import BigQueryHookAsync
from astronomer.providers.google.cloud.hooks.bigquery import (
    BigQueryHookAsync,
    BigQueryTableHookAsync,
)


class BigQueryInsertJobTrigger(BaseTrigger): # noqa: D101
@@ -378,3 +383,82 @@ async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return


class BigQueryTableExistenceTrigger(BaseTrigger):
    """Initialise the BigQuery Table Existence Trigger with needed parameters"""

    def __init__(
        self,
        project_id: str,
        dataset_id: str,
        table_id: str,
        gcp_conn_id: str,
        hook_params: Dict[str, Any],
        poll_interval: float = 4.0,
    ):
        self.dataset_id = dataset_id
        self.project_id = project_id
        self.table_id = table_id
        self.gcp_conn_id: str = gcp_conn_id
        self.poll_interval = poll_interval
        self.hook_params = hook_params

    def serialize(self) -> Tuple[str, Dict[str, Any]]:
        """Serializes BigQueryTableExistenceTrigger arguments and classpath."""
        return (
            "astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger",
            {
                "dataset_id": self.dataset_id,
                "project_id": self.project_id,
                "table_id": self.table_id,
                "gcp_conn_id": self.gcp_conn_id,
                "poll_interval": self.poll_interval,
                "hook_params": self.hook_params,
            },
        )

    def _get_async_hook(self) -> BigQueryTableHookAsync:
        return BigQueryTableHookAsync(gcp_conn_id=self.gcp_conn_id)

    async def run(self) -> AsyncIterator["TriggerEvent"]:  # type: ignore[override]
        """Will run until the table exists in Google BigQuery."""
        while True:
            try:
                hook = self._get_async_hook()
                response = await self._table_exists(
                    hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id
                )
                if response:
                    yield TriggerEvent({"status": "success", "message": "success"})
                    return
                await asyncio.sleep(self.poll_interval)
            except Exception as e:
                self.log.exception("Exception occurred while checking for table existence")
                yield TriggerEvent({"status": "error", "message": str(e)})
                return

    async def _table_exists(
        self, hook: BigQueryTableHookAsync, dataset: str, table_id: str, project_id: str
    ) -> bool:
        """
        Create a client session, call BigQueryTableHookAsync, and check for the table in Google BigQuery.

        :param hook: BigQueryTableHookAsync hook class
        :param dataset: The name of the dataset in which to look for the table.
        :param table_id: The name of the table to check the existence of.
        :param project_id: The Google Cloud project in which to look for the table.
            The connection supplied to the hook must provide
            access to the specified project.
        """
        async with ClientSession() as session:
            try:
                client = await hook.get_table_client(
                    dataset=dataset, table_id=table_id, project_id=project_id, session=session
                )
                response = await client.get()
                return True if response else False
            except ClientResponseError as err:
                if err.status == 404:
                    return False
                raise err
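As a hedged sketch, the trigger's `run()` generator can be driven directly in an event loop for local inspection; this assumes a configured `google_cloud_default` Airflow connection and uses placeholder project, dataset, and table names:

```python
# Sketch only: run the trigger outside Airflow to see the TriggerEvent it yields.
import asyncio

from astronomer.providers.google.cloud.triggers.bigquery import (
    BigQueryTableExistenceTrigger,
)


async def probe_table_once() -> None:
    trigger = BigQueryTableExistenceTrigger(
        project_id="my-project",    # placeholder
        dataset_id="my_dataset",    # placeholder
        table_id="my_table",        # placeholder
        gcp_conn_id="google_cloud_default",
        hook_params={"delegate_to": None, "impersonation_chain": None},
        poll_interval=4.0,
    )
    # run() is an async generator; it yields once the table exists or an error occurs.
    async for event in trigger.run():
        print(event)
        break


if __name__ == "__main__":
    asyncio.run(probe_table_once())
```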
14 changes: 13 additions & 1 deletion tests/google/cloud/hooks/test_bigquery.py
@@ -2,10 +2,11 @@

import pytest
from airflow.exceptions import AirflowException
from gcloud.aio.bigquery import Job
from gcloud.aio.bigquery import Job, Table

from astronomer.providers.google.cloud.hooks.bigquery import (
    BigQueryHookAsync,
    BigQueryTableHookAsync,
    _BigQueryHook,
)

@@ -259,3 +260,14 @@ def test_convert_to_float_if_possible(test_input, expected):
"""

assert BigQueryHookAsync._convert_to_float_if_possible(test_input) == expected


@pytest.mark.asyncio
@mock.patch("aiohttp.client.ClientSession")
async def test_get_table_client(mock_session):
    """Test get_table_client async function and check whether the return value is a Table instance object"""
    hook = BigQueryTableHookAsync()
    result = await hook.get_table_client(
        dataset=DATASET_ID, project_id=PROJECT_ID, table_id=TABLE_ID, session=mock_session
    )
    assert isinstance(result, Table)