Add deferred pagination mode to GenericTransfer #44809

Draft: wants to merge 25 commits into main

Commits (25)
4c557c5
refactor: Refactored the GenericTransfer operator to support paginate…
davidblain-infrabel Dec 10, 2024
01dde54
refactor: updated provider dependencies
dabla Dec 10, 2024
4965c0c
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 10, 2024
a503b4d
refactor: Added TestSQLExecuteQueryTrigger and moved test code which …
davidblain-infrabel Dec 10, 2024
f8c96cc
refactor: Fixed static checks
davidblain-infrabel Dec 10, 2024
c4768a2
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 10, 2024
7d890d2
refactor: Fixed static checks
davidblain-infrabel Dec 10, 2024
7692225
refactor: Fixed static checks
davidblain-infrabel Dec 10, 2024
c2def9e
refactor: Reformatted GenericTransfer
davidblain-infrabel Dec 10, 2024
fbb9c72
refactor: Moved source and destination hooks into cached properties
davidblain-infrabel Dec 10, 2024
7aeebc0
refactor: Moved imports to type checking block
davidblain-infrabel Dec 11, 2024
971d58a
refactor: Fixed execute method of GenericTransfer
davidblain-infrabel Dec 11, 2024
86f399b
refactor: Refactored get_hook method of GenericTransfer which checks …
davidblain-infrabel Dec 11, 2024
f297752
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 11, 2024
bded9f4
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 11, 2024
6d9000f
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 13, 2024
8733636
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 13, 2024
0f4b029
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 13, 2024
60d7966
refactor: Remove white lines from mock_context
davidblain-infrabel Dec 11, 2024
63ec0ee
refactor: Reformatted get_hook in GenericTransfer operator
davidblain-infrabel Dec 11, 2024
02c5098
refactor: Added sql.pyi for SQLExecuteQueryTrigger
davidblain-infrabel Dec 16, 2024
c44a333
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 16, 2024
305ffe0
refactor: Reformatted SQLExecuteQueryTrigger definition
davidblain-infrabel Dec 16, 2024
df39fb9
refactor: Added alias in SQLExecuteQueryTrigger definition
davidblain-infrabel Dec 16, 2024
9eb1493
Merge branch 'main' into feature/paginated-generic-transfer
dabla Dec 16, 2024
4 changes: 3 additions & 1 deletion generated/provider_dependencies.json
@@ -1293,7 +1293,9 @@
],
"devel-deps": [],
"plugins": [],
"cross-providers-deps": [],
"cross-providers-deps": [
"common.sql"
],
"excluded-python-versions": [],
"state": "ready"
},
5 changes: 5 additions & 0 deletions providers/src/airflow/providers/common/sql/provider.yaml
@@ -98,6 +98,11 @@ hooks:
- airflow.providers.common.sql.hooks.handlers
- airflow.providers.common.sql.hooks.sql

triggers:
- integration-name: Common SQL
python-modules:
- airflow.providers.common.sql.triggers.sql

sensors:
- integration-name: Common SQL
python-modules:
16 changes: 16 additions & 0 deletions providers/src/airflow/providers/common/sql/triggers/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
96 changes: 96 additions & 0 deletions providers/src/airflow/providers/common/sql/triggers/sql.py
@@ -0,0 +1,96 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from collections.abc import AsyncIterator
from typing import Any


class SQLExecuteQueryTrigger(BaseTrigger):
"""
A trigger that executes SQL code in async mode.

:param sql: the sql statement to be executed (str) or a list of sql statements to execute
:param conn_id: the connection ID used to connect to the database
:param hook_params: hook parameters
"""

def __init__(
self,
sql: str | list[str],
conn_id: str,
hook_params: dict | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.sql = sql
self.conn_id = conn_id
self.hook_params = hook_params

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize the SQLExecuteQueryTrigger arguments and classpath."""
return (
f"{self.__class__.__module__}.{self.__class__.__name__}",
{
"sql": self.sql,
"conn_id": self.conn_id,
"hook_params": self.hook_params,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
try:
hook = BaseHook.get_hook(self.conn_id, hook_params=self.hook_params)

if not isinstance(hook, DbApiHook):
raise AirflowException(
f"You are trying to use `common-sql` with {hook.__class__.__name__},"
" but its provider does not support it. Please upgrade the provider to a version that"
" supports `common-sql`. The hook class should be a subclass of"
" `airflow.providers.common.sql.hooks.sql.DbApiHook`."
f" Got {hook.__class__.__name__} Hook with class hierarchy: {hook.__class__.mro()}"
)

self.log.info("Extracting data from %s", self.conn_id)
self.log.info("Executing: \n %s", self.sql)

get_records = getattr(hook, "get_records", None)

if not callable(get_records):
raise RuntimeError(
f"Hook for connection {self.conn_id!r} "
f"({type(hook).__name__}) has no `get_records` method"
)
else:
self.log.info("Reading records from %s", self.conn_id)
results = get_records(self.sql)
self.log.info("Reading records from %s done!", self.conn_id)

self.log.debug("results: %s", results)
yield TriggerEvent({"status": "success", "results": results})
except Exception as e:
self.log.exception("An error occurred: %s", e)
yield TriggerEvent({"status": "failure", "message": str(e)})
47 changes: 47 additions & 0 deletions providers/src/airflow/providers/common/sql/triggers/sql.pyi
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# This is an automatically generated stub for the `common.sql` provider
#
# This file is generated automatically by the `update-common-sql-api stubs` pre-commit
# and the .pyi file represents part of the "public" API that the
# `common.sql` provider exposes to other providers.
#
# Any potentially breaking change in the stubs will require deliberate manual action from the contributor
# making a change to the `common.sql` provider. Those stubs are also used by MyPy automatically when checking
# that only the public API of the common.sql provider is used by all the other providers.
#
# You can read more in the README_API.md file
#
"""
Definition of the public interface for airflow.providers.common.sql.triggers.sql
isort:skip_file
"""
from airflow.triggers.base import BaseTrigger, TriggerEvent as TriggerEvent

from collections.abc import AsyncIterator
from typing import Any


class SQLExecuteQueryTrigger(BaseTrigger):
def __init__(
self, sql: str | list[str], conn_id: str, hook_params: dict | None = None, **kwargs,
) -> None: ...

def serialize(self) -> tuple[str, dict[str, Any]]: ...

async def run(self) -> AsyncIterator[TriggerEvent]: ...
151 changes: 116 additions & 35 deletions providers/src/airflow/providers/standard/operators/generic_transfer.py
@@ -18,12 +18,18 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger

if TYPE_CHECKING:
import jinja2

from airflow.utils.context import Context


@@ -40,10 +40,13 @@ class GenericTransfer(BaseOperator):
:param sql: SQL query to execute against the source database. (templated)
:param destination_table: target table. (templated)
:param source_conn_id: source connection. (templated)
:param source_hook_params: source hook parameters.
:param destination_conn_id: destination connection. (templated)
:param destination_hook_params: destination hook parameters.
:param preoperator: sql statement or list of statements to be
executed prior to loading the data. (templated)
:param insert_args: extra params for `insert_rows` method.
:param chunk_size: number of records to be read in paginated mode (optional).
"""

template_fields: Sequence[str] = (
@@ -72,6 +81,7 @@ def __init__(
destination_hook_params: dict | None = None,
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
chunk_size: int | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -83,52 +93,123 @@ def __init__(
self.destination_hook_params = destination_hook_params
self.preoperator = preoperator
self.insert_args = insert_args or {}
self.chunk_size = chunk_size
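# Default pagination format yields e.g. "SELECT ... LIMIT 1000 OFFSET 0";
# a "paginated_sql_statement_format" entry in kwargs is intended to override it.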
self._paginated_sql_statement_format = kwargs.get(
"paginated_sql_statement_format", "{} LIMIT {} OFFSET {}"
)

@classmethod
def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook:
def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook:
"""
Return default hook for this connection id.
Return DbApiHook for this connection id.

:param conn_id: connection id
:param hook_params: hook parameters
:return: default hook for this connection
:return: DbApiHook for this connection
"""
connection = BaseHook.get_connection(conn_id)
return connection.get_hook(hook_params=hook_params)
hook = connection.get_hook(hook_params=hook_params)
if not isinstance(hook, DbApiHook):
raise RuntimeError(f"Hook for connection {conn_id!r} must be of type {DbApiHook.__name__}")
return hook

def execute(self, context: Context):
source_hook = self.get_hook(conn_id=self.source_conn_id, hook_params=self.source_hook_params)
destination_hook = self.get_hook(
conn_id=self.destination_conn_id, hook_params=self.destination_hook_params
)
@cached_property
def source_hook(self) -> DbApiHook:
return self.get_hook(conn_id=self.source_conn_id, hook_params=self.source_hook_params)

self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)
get_records = getattr(source_hook, "get_records", None)
if not callable(get_records):
raise RuntimeError(
f"Hook for connection {self.source_conn_id!r} "
f"({type(source_hook).__name__}) has no `get_records` method"
)
else:
results = get_records(self.sql)
@cached_property
def destination_hook(self) -> DbApiHook:
return self.get_hook(conn_id=self.destination_conn_id, hook_params=self.destination_hook_params)

def get_paginated_sql(self, offset: int) -> str:
"""Format the paginated SQL statement using the current format."""
return self._paginated_sql_statement_format.format(self.sql, self.chunk_size, offset)

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
super().render_template_fields(context=context, jinja_env=jinja_env)

# Make sure strings are converted to integers
if isinstance(self.chunk_size, str):
self.chunk_size = int(self.chunk_size)
commit_every = self.insert_args.get("commit_every")
if isinstance(commit_every, str):
self.insert_args["commit_every"] = int(commit_every)

def execute(self, context: Context):
if self.preoperator:
run = getattr(destination_hook, "run", None)
if not callable(run):
raise RuntimeError(
f"Hook for connection {self.destination_conn_id!r} "
f"({type(destination_hook).__name__}) has no `run` method"
)
self.log.info("Running preoperator")
self.log.info(self.preoperator)
run(self.preoperator)

insert_rows = getattr(destination_hook, "insert_rows", None)
if not callable(insert_rows):
raise RuntimeError(
f"Hook for connection {self.destination_conn_id!r} "
f"({type(destination_hook).__name__}) has no `insert_rows` method"
self.destination_hook.run(self.preoperator)

if self.chunk_size and isinstance(self.sql, str):
self.defer(
trigger=SQLExecuteQueryTrigger(
conn_id=self.source_conn_id,
hook_params=self.source_hook_params,
sql=self.get_paginated_sql(0),
),
method_name=self.execute_complete.__name__,
)
self.log.info("Inserting rows into %s", self.destination_conn_id)
insert_rows(table=self.destination_table, rows=results, **self.insert_args)
else:
self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)

results = self.source_hook.get_records(self.sql)

self.log.info("Inserting rows into %s", self.destination_conn_id)
self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)

def execute_complete(
self,
context: Context,
event: dict[Any, Any] | None = None,
) -> Any:
if event:
if event.get("status") == "failure":
raise AirflowException(event.get("message"))

results = event.get("results")

if results:
map_index = context["ti"].map_index
offset = (
context["ti"].xcom_pull(
key="offset",
task_ids=self.task_id,
dag_id=self.dag_id,
map_indexes=map_index,
default=0,
)
+ self.chunk_size
)

self.log.info("Offset increased to %d", offset)
self.xcom_push(context=context, key="offset", value=offset)

self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id)
self.destination_hook.insert_rows(
table=self.destination_table, rows=results, **self.insert_args
)
self.log.info(
"Inserting %d rows into %s done!",
len(results),
self.destination_conn_id,
)

self.defer(
trigger=SQLExecuteQueryTrigger(
conn_id=self.source_conn_id,
hook_params=self.source_hook_params,
sql=self.get_paginated_sql(offset),
),
method_name=self.execute_complete.__name__,
)
else:
self.log.info(
"No more rows to fetch into %s; ending transfer.",
self.destination_table,
)
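
To round out the diff, a minimal sketch of the new deferred pagination mode in a DAG, assuming the operator lands as shown above; connection IDs and table names are illustrative:

```python
from airflow import DAG
from airflow.providers.standard.operators.generic_transfer import GenericTransfer

with DAG(dag_id="paginated_transfer_example", schedule=None) as dag:
    # With chunk_size set and sql given as a plain string, execute() defers to
    # SQLExecuteQueryTrigger; execute_complete() inserts each page and
    # re-defers with the offset advanced by chunk_size until a page comes back empty.
    GenericTransfer(
        task_id="transfer_users",
        sql="SELECT * FROM users",
        source_conn_id="source_db",  # assumed connection IDs
        destination_conn_id="dest_db",
        destination_table="users_copy",
        chunk_size=1000,
    )
```

With the default statement format, the first deferred query is "SELECT * FROM users LIMIT 1000 OFFSET 0"; each successful page is written with insert_rows before the next deferral.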