Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
dc7b9f6
refactor: Fixed docstring for rows_processor parameter in SQLInsertRo…
dabla Jan 27, 2026
d37a16c
refactor: Added rows_processor parameter in GenericTransfer
dabla Jan 27, 2026
18e0560
refactor: Make SQLExecuteQueryTrigger non-blocking from Airflow 3.2+
dabla Jan 27, 2026
2fdba24
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Jan 27, 2026
2a47f9a
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Jan 27, 2026
8edc348
refactor: Make inserting of rows in sync and deferred mode more DRY
dabla Jan 27, 2026
5e4c530
refactor: Updated GenericTransfer type
dabla Jan 27, 2026
18d9c6d
refactor: Ignore mypy for _rows_processor
dabla Jan 28, 2026
bca6b89
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Jan 28, 2026
9541a74
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Jan 28, 2026
bb87f54
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Jan 30, 2026
7c3a2e6
refactor: Removed _process_rows and made rows_processor optional in c…
dabla Jan 31, 2026
1f31246
refactor: Removed _process_rows and made rows_processor optional in c…
dabla Jan 31, 2026
d5c9017
refactor: Made _get_records in trigger protected
dabla Jan 31, 2026
a670485
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Jan 31, 2026
4930c9a
refactor: Updated typing of rows_processor in GenericTransfer interface
dabla Feb 1, 2026
64eb707
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 1, 2026
b5fbe45
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 1, 2026
2c2164a
refactor: Fixed typing of rows parameter in _insert_rows method
dabla Feb 1, 2026
68df300
refactor: Refactored unit test for SQLInsertRowsOperator
dabla Feb 1, 2026
b973c03
refactor: Fixed typing of _insert_rows method
dabla Feb 1, 2026
aa8aafd
refactor: Updated docstring for rows_processor parameter
dabla Feb 1, 2026
9544569
refactor: Changed typing of context parameter from Context to Any in …
Feb 2, 2026
c0cbd84
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 2, 2026
9d3b2ef
refactor: Changed typing of rows_processor
dabla Feb 2, 2026
23869fe
refactor: Refactored _insert_rows method in SQLInsertRowsOperator and…
dabla Feb 2, 2026
ec1ff73
refactor: Removed unused import of jinja2
dabla Feb 2, 2026
e6f52a7
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 2, 2026
a3eeb84
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 2, 2026
1828cc9
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 3, 2026
1627210
docs: Added comment to explain why rows_processor is typed like this
dabla Feb 3, 2026
80fe566
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 3, 2026
55e91da
Update providers/common/sql/src/airflow/providers/common/sql/operator…
dabla Feb 6, 2026
6980f9f
refactor: Removed get_records
dabla Feb 6, 2026
2477b5b
refactor: Evaluate XComArgs first before evaluating rows
dabla Feb 6, 2026
51be7a5
Merge branch 'main' into feature/add-rows-processor-generic-transfer
dabla Feb 6, 2026
c1e2b33
refactor: Fixed check on rows in SQLInsertRowsOperator
dabla Feb 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -47,6 +47,9 @@ class GenericTransfer(BaseOperator):
:param source_hook_params: source hook parameters.
:param destination_conn_id: destination connection. (templated)
:param destination_hook_params: destination hook parameters.
:param rows_processor: (optional) A callable applied once per batch of rows before insertion.
It receives the full list of rows and the task context, and must return a list of rows compatible with
the underlying hook's `insert_rows` method.
:param preoperator: sql statement or list of statements to be
executed prior to loading the data. (templated)
:param insert_args: extra params for `insert_rows` method.
Expand Down Expand Up @@ -80,6 +83,9 @@ def __init__(
source_hook_params: dict | None = None,
destination_conn_id: str,
destination_hook_params: dict | None = None,
rows_processor: Callable[..., list[Any]] | None = None,
# rows_processor is called as rows_processor(rows, **context);
# context keys vary, so Callable[..., list[Any]] is intentional.
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
page_size: int | None = None,
Expand All @@ -93,6 +99,7 @@ def __init__(
self.source_hook_params = source_hook_params
self.destination_conn_id = destination_conn_id
self.destination_hook_params = destination_hook_params
self._rows_processor = rows_processor
self.preoperator = preoperator
self.insert_args = insert_args or {}
self.page_size = page_size
Expand Down Expand Up @@ -139,6 +146,14 @@ def render_template_fields(
if isinstance(commit_every, str):
self.insert_args["commit_every"] = int(commit_every)

def _insert_rows(self, rows: list[Any], context: Context):
    """Insert one batch of rows into the destination table.

    If a ``rows_processor`` callable was configured, it is applied to the
    batch first — invoked as ``rows_processor(rows, **context)`` — and its
    return value replaces the original batch before insertion.
    """
    processor = self._rows_processor
    if processor:
        # Processor receives the batch plus the task context as keyword args.
        rows = processor(rows, **context)
    self.log.info("Inserting %d rows into %s", len(rows), self.destination_conn_id)
    self.destination_hook.insert_rows(table=self.destination_table, rows=rows, **self.insert_args)

def execute(self, context: Context):
if self.preoperator:
self.log.info("Running preoperator")
Expand All @@ -162,12 +177,8 @@ def execute(self, context: Context):
for sql in self.sql:
self.log.info("Executing: \n %s", sql)

results = self.source_hook.get_records(sql)

self.log.info("Inserting rows into %s", self.destination_conn_id)
self.destination_hook.insert_rows(
table=self.destination_table, rows=results, **self.insert_args
)
rows = self.source_hook.get_records(sql)
self._insert_rows(rows=rows, context=context)

def execute_complete(
self,
Expand All @@ -178,9 +189,9 @@ def execute_complete(
if event.get("status") == "failure":
raise AirflowException(event.get("message"))

results = event.get("results")
rows = event.get("results")

if results:
if rows:
map_index = context["ti"].map_index
offset = (
context["ti"].xcom_pull(
Expand All @@ -196,15 +207,7 @@ def execute_complete(
self.log.info("Offset increased to %d", offset)
context["ti"].xcom_push(key="offset", value=offset)

self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id)
self.destination_hook.insert_rows(
table=self.destination_table, rows=results, **self.insert_args
)
self.log.info(
"Inserting %d rows into %s done!",
len(results),
self.destination_conn_id,
)
self._insert_rows(rows=rows, context=context)

self.defer(
trigger=SQLExecuteQueryTrigger(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#
"""Definition of the public interface for airflow.providers.common.sql.operators.generic_transfer."""

from collections.abc import Sequence
from collections.abc import Callable, Sequence
from functools import cached_property as cached_property
from typing import Any, ClassVar

Expand All @@ -51,6 +51,7 @@ class GenericTransfer(BaseOperator):
source_hook_params: Incomplete
destination_conn_id: Incomplete
destination_hook_params: Incomplete
rows_processor: Incomplete
preoperator: Incomplete
insert_args: Incomplete
page_size: Incomplete
Expand All @@ -63,6 +64,7 @@ class GenericTransfer(BaseOperator):
source_hook_params: dict | None = None,
destination_conn_id: str,
destination_hook_params: dict | None = None,
rows_processor: Callable[..., list[Any]] | None = None,
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
page_size: int | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
from airflow.utils.helpers import merge_dicts

if TYPE_CHECKING:
import jinja2

from airflow.providers.common.compat.sdk import Context
from airflow.providers.openlineage.extractors import OperatorLineage

Expand Down Expand Up @@ -1307,7 +1305,9 @@ class SQLInsertRowsOperator(BaseSQLOperator):
:param rows: the rows to insert into the table. Rows can be a list of tuples or a list of dictionaries.
When a list of dictionaries is provided, the column names are inferred from the dictionary keys and
will be matched with the column names, ignored columns will be filtered out.
:rows_processor: (optional) a function that will be applied to the rows before inserting them into the table.
:param rows_processor: (optional) A callable applied once per batch of rows before insertion.
It receives the full list of rows and the task context, and must return a list of rows compatible with
the underlying hook.
:param preoperator: sql statement or list of statements to be executed prior to loading the data. (templated)
:param postoperator: sql statement or list of statements to be executed after loading the data. (templated)
:param insert_args: (optional) dictionary of additional arguments passed to the underlying hook's
Expand Down Expand Up @@ -1343,7 +1343,9 @@ def __init__(
columns: Iterable[str] | None = None,
ignored_columns: Iterable[str] | None = None,
rows: list[Any] | XComArg | None = None,
rows_processor: Callable[[Any, Context], Any] = lambda rows, **context: rows,
rows_processor: Callable[..., list[Any]] | None = None,
# rows_processor is called as rows_processor(rows, **context);
# context keys vary, so Callable[..., list[Any]] is intentional.
preoperator: str | list[str] | None = None,
postoperator: str | list[str] | None = None,
hook_params: dict | None = None,
Expand All @@ -1367,16 +1369,6 @@ def __init__(
self.insert_args = insert_args or {}
self.do_xcom_push = False

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
super().render_template_fields(context=context, jinja_env=jinja_env)

if isinstance(self.rows, XComArg):
self.rows = self.rows.resolve(context=context)

@property
def table_name_with_schema(self) -> str:
if self.schema is not None:
Expand All @@ -1395,11 +1387,23 @@ def column_names(self) -> list[str]:
return [column for column in self.columns if column not in self.ignored_columns]
return self.columns

def _process_rows(self, context: Context):
return self._rows_processor(self.rows, **context) # type: ignore
def _insert_rows(self, rows: list[Any], context: Context):
    """Apply the optional ``rows_processor`` to *rows*, then insert them.

    The processor, when configured, is called as
    ``rows_processor(rows, **context)`` and must return the batch that will
    actually be inserted via the underlying hook.
    """
    if self._rows_processor:
        rows = self._rows_processor(rows, **context)
    self.log.info("Inserting %d rows into %s", len(rows), self.conn_id)
    hook = self.get_db_hook()
    hook.insert_rows(
        table=self.table_name_with_schema,
        rows=rows,
        target_fields=self.column_names,
        **self.insert_args,
    )

def execute(self, context: Context) -> Any:
if not self.rows:
rows = self.rows.resolve(context=context) if isinstance(self.rows, XComArg) else self.rows

if not rows:
raise AirflowSkipException(f"Skipping task {self.task_id} because no rows.")

self.log.debug("Table: %s", self.table_name_with_schema)
Expand All @@ -1408,13 +1412,7 @@ def execute(self, context: Context) -> Any:
self.log.debug("Running preoperator")
self.log.debug(self.preoperator)
self.get_db_hook().run(self.preoperator)
rows = self._process_rows(context=context)
self.get_db_hook().insert_rows(
table=self.table_name_with_schema,
rows=rows,
target_fields=self.column_names,
**self.insert_args,
)
self._insert_rows(rows=rows, context=context)
if self.postoperator:
self.log.debug("Running postoperator")
self.log.debug(self.postoperator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import TYPE_CHECKING

from airflow.providers.common.compat.sdk import AirflowException, BaseHook
from airflow.providers.common.compat.version_compat import AIRFLOW_V_3_2_PLUS
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

Expand Down Expand Up @@ -78,15 +79,23 @@ def get_hook(self) -> DbApiHook:
)
return hook

async def _get_records(self) -> Any:
    """Fetch the query results without blocking the triggerer event loop.

    On Airflow 3.2+ the blocking ``get_records`` call is offloaded to a
    thread via ``sync_to_async``; older versions fall back to a direct
    (blocking) call because CommsDecoder only gained async support in 3.2.
    """
    from asgiref.sync import sync_to_async

    hook = self.get_hook()
    if not AIRFLOW_V_3_2_PLUS:
        # Pre-3.2: no async support in CommsDecoder, so run synchronously.
        return hook.get_records(self.sql)
    return await sync_to_async(hook.get_records)(self.sql)

async def run(self) -> AsyncIterator[TriggerEvent]:
try:
hook = self.get_hook()

self.log.info("Extracting data from %s", self.conn_id)
self.log.info("Executing: \n %s", self.sql)
self.log.info("Reading records from %s", self.conn_id)

results = hook.get_records(self.sql)
results = await self._get_records()

self.log.info("Reading records from %s done!", self.conn_id)
self.log.debug("results: %s", results)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@
from collections.abc import AsyncIterator
from typing import Any

from airflow.providers.common.sql.hooks.sql import DbApiHook as DbApiHook
from airflow.triggers.base import BaseTrigger as BaseTrigger, TriggerEvent as TriggerEvent

class SQLExecuteQueryTrigger(BaseTrigger):
    """Public interface stub for the SQL execute-query trigger (signatures only; see the implementation module)."""
    def __init__(
        self, sql: str | list[str], conn_id: str, hook_params: dict | None = None, **kwargs
    ) -> None: ...
    def serialize(self) -> tuple[str, dict[str, Any]]: ...
    def get_hook(self) -> DbApiHook: ...
    # type: ignore — async generator return shape diverges from the base class declaration.
    async def run(self) -> AsyncIterator[TriggerEvent]: ...  # type: ignore
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,10 @@ def get_connection(cls, conn_id: str):
mocked_conn.get_hook.return_value = mocked_hook
return mocked_conn

@classmethod
def convert_to_tuples(cls, rows, **context):
    """Row processor used by the tests: coerce every row to a tuple, ignoring the context."""
    return list(map(tuple, rows))

def setup_method(self):
# Reset mock states before each test
self.mocked_source_hook.reset_mock()
Expand Down Expand Up @@ -289,6 +293,30 @@ def test_non_paginated_read(self):
**{"rows": [[1, 2], [11, 12], [3, 4], [13, 14], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
}

def test_non_paginated_read_with_rows_processor(self):
    """A configured rows_processor must transform the batch before it reaches insert_rows."""
    with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection):
        with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook):
            op = GenericTransfer(
                task_id="transfer_table",
                source_conn_id="my_source_conn_id",
                destination_conn_id="my_destination_conn_id",
                sql="SELECT * FROM HR.EMPLOYEES",
                destination_table="NEW_HR.EMPLOYEES",
                insert_args=INSERT_ARGS,
                execution_timeout=timedelta(hours=1),
                rows_processor=self.convert_to_tuples,
            )

            op.execute(context=mock_context(task=op))

    read_calls = self.mocked_source_hook.get_records.call_args_list
    assert len(read_calls) == 1
    assert read_calls[0].args[0] == "SELECT * FROM HR.EMPLOYEES"

    insert_calls = self.mocked_destination_hook.insert_rows.call_args_list
    assert len(insert_calls) == 1
    # The processor turned each list-row into a tuple before insertion.
    assert insert_calls[0].kwargs == {
        **INSERT_ARGS,
        "rows": [(1, 2), (11, 12), (3, 4), (13, 14), (3, 4), (13, 14)],
        "table": "NEW_HR.EMPLOYEES",
    }

def test_non_paginated_read_for_multiple_sql_statements(self):
with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection):
with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook):
Expand Down Expand Up @@ -321,6 +349,39 @@ def test_non_paginated_read_for_multiple_sql_statements(self):
"table": "NEW_HR.EMPLOYEES",
}

def test_non_paginated_read_for_multiple_sql_statements_with_rows_processor(self):
    """With several SQL statements, the rows_processor must be applied to every batch."""
    with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection):
        with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook):
            op = GenericTransfer(
                task_id="transfer_table",
                source_conn_id="my_source_conn_id",
                destination_conn_id="my_destination_conn_id",
                sql=["SELECT * FROM HR.EMPLOYEES", "SELECT * FROM HR.PEOPLE"],
                destination_table="NEW_HR.EMPLOYEES",
                insert_args=INSERT_ARGS,
                execution_timeout=timedelta(hours=1),
                rows_processor=self.convert_to_tuples,
            )

            op.execute(context=mock_context(task=op))

    executed_sql = [call.args[0] for call in self.mocked_source_hook.get_records.call_args_list]
    assert executed_sql == ["SELECT * FROM HR.EMPLOYEES", "SELECT * FROM HR.PEOPLE"]

    # Both batches are identical in the fixture, so both inserts carry the
    # same processed (tuple-ified) rows.
    expected_kwargs = {
        **INSERT_ARGS,
        "rows": [(1, 2), (11, 12), (3, 4), (13, 14), (3, 4), (13, 14)],
        "table": "NEW_HR.EMPLOYEES",
    }
    insert_calls = self.mocked_destination_hook.insert_rows.call_args_list
    assert len(insert_calls) == 2
    assert all(call.kwargs == expected_kwargs for call in insert_calls)

def test_paginated_read(self):
"""
This unit test is based on the example described in the medium article:
Expand Down
Loading