Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Async Big Query Table Existence Sensor #135

Merged
merged 10 commits into from
Mar 23, 2022
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""
Example Airflow DAG for Google BigQuery Sensors.
"""
import os
from datetime import datetime

from airflow import models
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryCreateEmptyDatasetOperator,
BigQueryCreateEmptyTableOperator,
BigQueryDeleteDatasetOperator,
BigQueryInsertJobOperator,
)
from airflow.providers.google.cloud.sensors.bigquery import (
BigQueryTablePartitionExistenceSensor,
)

from astronomer.providers.google.cloud.sensors.bigquery import (
BigQueryTableExistenceSensorAsync
)

PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "astronomer-airflow-providers")
DATASET_NAME = os.environ.get("GCP_BIGQUERY_DATASET_NAME", "astro_dataset")
LOCATION = "us"

TABLE_NAME = "partitioned_table"
INSERT_DATE = datetime.now().strftime("%Y-%m-%d")

PARTITION_NAME = "{{ ds_nodash }}"

INSERT_ROWS_QUERY = f"INSERT {DATASET_NAME}.{TABLE_NAME} VALUES " "(42, '{{ ds }}')"

SCHEMA = [
{"name": "value", "type": "INTEGER", "mode": "REQUIRED"},
{"name": "ds", "type": "DATE", "mode": "NULLABLE"},
]

dag_id = "example_bigquery_sensors"

with models.DAG(
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
dag_id,
schedule_interval='@once', # Override to match your needs
start_date=datetime(2021, 1, 1),
catchup=False,
tags=["example"],
user_defined_macros={"DATASET": DATASET_NAME, "TABLE": TABLE_NAME},
default_args={"project_id": PROJECT_ID},
) as dag_with_locations:
create_dataset = BigQueryCreateEmptyDatasetOperator(
task_id="create-dataset", dataset_id=DATASET_NAME, project_id=PROJECT_ID
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
)

bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
create_table = BigQueryCreateEmptyTableOperator(
task_id="create_table",
dataset_id=DATASET_NAME,
table_id=TABLE_NAME,
schema_fields=SCHEMA,
time_partitioning={
"type": "DAY",
"field": "ds",
},
)
# [START howto_sensor_bigquery_table]
check_table_exists = BigQueryTableExistenceSensorAsync(
task_id="check_table_exists", project_id=PROJECT_ID, dataset_id=DATASET_NAME, table_id=TABLE_NAME
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
)
# [END howto_sensor_bigquery_table]

execute_insert_query = BigQueryInsertJobOperator(
task_id="execute_insert_query",
configuration={
"query": {
"query": INSERT_ROWS_QUERY,
"useLegacySql": False,
}
},
)

# [START howto_sensor_bigquery_table_partition]
check_table_partition_exists = BigQueryTablePartitionExistenceSensor(
task_id="check_table_partition_exists",
project_id=PROJECT_ID,
dataset_id=DATASET_NAME,
table_id=TABLE_NAME,
partition_id=PARTITION_NAME,
)
# [END howto_sensor_bigquery_table_partition]

delete_dataset = BigQueryDeleteDatasetOperator(
task_id="delete_dataset", dataset_id=DATASET_NAME, delete_contents=True
)

create_dataset >> create_table
create_table >> check_table_exists
create_table >> execute_insert_query
execute_insert_query >> check_table_partition_exists
check_table_exists >> delete_dataset
check_table_partition_exists >> delete_dataset
21 changes: 20 additions & 1 deletion astronomer/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook, _bq_cast
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from gcloud.aio.bigquery import Job
from gcloud.aio.bigquery import Job, Table
from google.cloud.bigquery import CopyJob, ExtractJob, LoadJob, QueryJob
from requests import Session

Expand Down Expand Up @@ -299,3 +299,22 @@ def interval_check(
raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")

self.log.info("All tests have passed")


class BigQueryTableHookAsync(GoogleBaseHookAsync):
"""Class to get async hook for Bigquery Table Async"""

sync_hook_class = BigQueryHook

async def get_table_client(
self, dataset: str, table_id: str, project_id: str, session: ClientSession
) -> Table:
"""Returns a Google Big Query Table object."""
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
with await self.service_file_as_context() as file:
return Table(
dataset_name=dataset,
table_name=table_id,
project=project_id,
service_file=file,
session=cast(Session, session),
)
85 changes: 85 additions & 0 deletions astronomer/providers/google/cloud/sensors/bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""This module contains Google Big Query sensors."""
from typing import Any, Dict, Optional
import warnings
from airflow.exceptions import AirflowException

from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor
from astronomer.providers.google.cloud.triggers.bigquery import BigQueryTableExistenceTrigger


class BigQueryTableExistenceSensorAsync(BigQueryTableExistenceSensor):
"""
Checks for the existence of a table in Google Big Query.
:param project_id: The Google cloud project in which to look for the table.
The connection supplied to the hook must provide
access to the specified project.
:param dataset_id: The name of the dataset in which to look for the table.
storage bucket.
:param table_id: The name of the table to check the existence of.
:param gcp_conn_id: The connection ID used to connect to Google Cloud.
:param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud.
This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
:param delegate_to: The account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:param polling_interval: The interval in seconds to wait between checks table existence.
"""

def __init__(
self,
gcp_conn_id: str = 'google_cloud_default',
polling_interval: float = 5.0,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.polling_interval = polling_interval
if self.bigquery_conn_id:
warnings.warn(
"The bigquery_conn_id parameter has been deprecated. You should pass "
"the gcp_conn_id parameter.",
DeprecationWarning,
stacklevel=3,
)
gcp_conn_id = self.bigquery_conn_id
self.gcp_conn_id = gcp_conn_id

def execute(self, context: Dict[str, Any]) -> None:
"""Airflow runs this method on the worker and defers using the trigger."""
self.defer(
timeout=self.execution_timeout,
trigger=BigQueryTableExistenceTrigger(
dataset_id = self.dataset_id,
table_id=self.table_id,
project_id=self.project_id,
poll_interval=self.polling_interval,
google_cloud_conn_id=self.gcp_conn_id,
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
hook_params={
"delegate_to": self.delegate_to,
"impersonation_chain": self.impersonation_chain,
},
),
method_name="execute_complete",
)

def execute_complete(self, context: Dict[str, Any], event: Optional[Dict[str, str]] = None) -> str:
"""
Callback for when the trigger fires - returns immediately.
Relies on trigger to throw an exception, otherwise it assumes execution was
successful.
"""
table_uri = f'{self.project_id}:{self.dataset_id}.{self.table_id}'
self.log.info('Sensor checks existence of table: %s', table_uri)
if event:
if event["status"] == "success":
return event["message"]
raise AirflowException(event["message"])
raise AirflowException("No event received in trigger callback")

86 changes: 85 additions & 1 deletion astronomer/providers/google/cloud/triggers/bigquery.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import asyncio
from typing import Any, AsyncIterator, Dict, Optional, SupportsAbs, Tuple, Union

from aiohttp import ClientSession
from airflow.triggers.base import BaseTrigger, TriggerEvent

from astronomer.providers.google.cloud.hooks.bigquery import BigQueryHookAsync
from astronomer.providers.google.cloud.hooks.bigquery import (
BigQueryHookAsync,
BigQueryTableHookAsync,
)


class BigQueryInsertJobTrigger(BaseTrigger): # noqa: D101
Expand Down Expand Up @@ -378,3 +382,83 @@ async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
self.log.exception("Exception occurred while checking for query completion")
yield TriggerEvent({"status": "error", "message": str(e)})
return


class BigQueryTableExistenceTrigger(BaseTrigger):
"""Initialise the BigQuery Table Existence Trigger with needed parameters"""

def __init__(
self,
project_id: str,
dataset_id: str,
table_id: str,
google_cloud_conn_id: str,
hook_params: Dict[str, Any],
poll_interval: float = 4.0,
):
super().__init__()
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
self.dataset_id = dataset_id
self.project_id = project_id
self.table_id = table_id
self.google_cloud_conn_id: str = google_cloud_conn_id
self.poll_interval = poll_interval
self.hook_params = hook_params

def serialize(self) -> Tuple[str, Dict[str, Any]]:
"""Serializes BigQueryTableExistenceTrigger arguments and classpath."""
return (
"astronomer.providers.google.cloud.triggers.bigquery.BigQueryTableExistenceTrigger",
{
"dataset_id": self.dataset_id,
"project_id": self.project_id,
"table_id": self.table_id,
"google_cloud_conn_id": self.google_cloud_conn_id,
"poll_interval": self.poll_interval,
"hook_params": self.hook_params,
},
)

def _get_async_hook(self) -> BigQueryTableHookAsync:
return BigQueryTableHookAsync(gcp_conn_id=self.google_cloud_conn_id)
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved

async def run(self) -> AsyncIterator["TriggerEvent"]: # type: ignore[override]
"""Simple loop until the table exists in the google big query."""
while True:
try:
hook = self._get_async_hook()
response = await self._table_exists(
hook=hook, dataset=self.dataset_id, table_id=self.table_id, project_id=self.project_id
)
if response:
yield TriggerEvent({"status": "success", "message": "success"})
return
await asyncio.sleep(self.poll_interval)
except Exception as e:
self.log.exception("Exception occurred while checking for Table existence")
yield TriggerEvent({"status": "error", "message": str(e)})
return

async def _table_exists(
self, hook: BigQueryTableHookAsync, dataset: str, table_id: str, project_id: str
) -> bool:
"""
Checks if the object in the bucket is updated.

:param hook: BigQueryTableHookAsync Hook class
:param dataset: The Google Cloud Storage bucket where the object is.
:param table_id: The name of the blob_name to check in the Google cloud.
:param project_id: context datetime to compare with blob object updated time
"""
async with ClientSession() as session:
try:
client = await hook.get_table_client(
dataset=dataset, table_id=table_id, project_id=project_id, session=session
)
response = await client.get()
return True if response else False
except Exception as e:
# when url is not found it returns 404 error it should poll and wait
# until the table id is found
if e.status == 404: # type: ignore[attr-defined]
return False
bharanidharan14 marked this conversation as resolved.
Show resolved Hide resolved
raise e