Skip to content

Commit

Permalink
SQL Database: Support including/excluding NULL cursor values (#1946)
Browse files Browse the repository at this point in the history
* SQL Database: Support including NULL cursor values

* Support exclude option

* Test skip import

* Always add exclude condition, import sqlalchemy from common lib
  • Loading branch information
steinitzu authored Oct 20, 2024
1 parent 1fa6609 commit 213f82e
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 8 deletions.
19 changes: 16 additions & 3 deletions dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
TTypeAdapter,
)

from dlt.common.libs.sql_alchemy import Engine, CompileError, create_engine
from dlt.common.libs.sql_alchemy import Engine, CompileError, create_engine, sa


TableBackend = Literal["sqlalchemy", "pyarrow", "pandas", "connectorx"]
Expand Down Expand Up @@ -72,11 +72,13 @@ def __init__(
self.last_value = incremental.last_value
self.end_value = incremental.end_value
self.row_order: TSortOrder = self.incremental.row_order
self.on_cursor_value_missing = self.incremental.on_cursor_value_missing
else:
self.cursor_column = None
self.last_value = None
self.end_value = None
self.row_order = None
self.on_cursor_value_missing = None

def _make_query(self) -> SelectAny:
table = self.table
Expand All @@ -95,10 +97,21 @@ def _make_query(self) -> SelectAny:
else: # Custom last_value, load everything and let incremental handle filtering
return query # type: ignore[no-any-return]

where_clause = True
if self.last_value is not None:
query = query.where(filter_op(self.cursor_column, self.last_value))
where_clause = filter_op(self.cursor_column, self.last_value)
if self.end_value is not None:
query = query.where(filter_op_end(self.cursor_column, self.end_value))
where_clause = sa.and_(
where_clause, filter_op_end(self.cursor_column, self.end_value)
)

if self.on_cursor_value_missing == "include":
where_clause = sa.or_(where_clause, self.cursor_column.is_(None))
if self.on_cursor_value_missing == "exclude":
where_clause = sa.and_(where_clause, self.cursor_column.isnot(None))

if where_clause is not True:
query = query.where(where_clause)

# generate order by from declared row order
order_by = None
Expand Down
106 changes: 101 additions & 5 deletions tests/load/sources/sql_database/test_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest


import dlt
from dlt.common.typing import TDataItem

Expand All @@ -10,7 +11,8 @@
from dlt.sources.sql_database.helpers import TableLoader, TableBackend
from dlt.sources.sql_database.schema_types import table_to_columns
from tests.load.sources.sql_database.sql_source import SQLAlchemySourceDB
except MissingDependencyException:
import sqlalchemy as sa
except (MissingDependencyException, ModuleNotFoundError):
pytest.skip("Tests require sql alchemy", allow_module_level=True)


Expand Down Expand Up @@ -42,6 +44,7 @@ class MockIncremental:
cursor_path = "created_at"
row_order = "asc"
end_value = None
on_cursor_value_missing = "raise"

table = sql_source_db.get_table("chat_message")
loader = TableLoader(
Expand Down Expand Up @@ -72,6 +75,7 @@ class MockIncremental:
cursor_path = "created_at"
row_order = "desc"
end_value = None
on_cursor_value_missing = "raise"

table = sql_source_db.get_table("chat_message")
loader = TableLoader(
Expand All @@ -92,6 +96,95 @@ class MockIncremental:
assert query.compare(expected)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
@pytest.mark.parametrize("with_end_value", [True, False])
@pytest.mark.parametrize("cursor_value_missing", ["include", "exclude"])
def test_make_query_incremental_on_cursor_value_missing_set(
    sql_source_db: SQLAlchemySourceDB,
    backend: TableBackend,
    with_end_value: bool,
    cursor_value_missing: str,
) -> None:
    """Check the WHERE clause built when `on_cursor_value_missing` is "include" or "exclude".

    With a `last_value` set (and optionally an `end_value`), the range condition
    on the cursor column is expected to be combined with a NULL check:
    OR-ed with `IS NULL` for "include", AND-ed with `IS NOT NULL` for "exclude".
    """

    class MockIncremental:
        last_value = dlt.common.pendulum.now()
        last_value_func = max
        cursor_path = "created_at"
        row_order = "asc"
        end_value = dlt.common.pendulum.now().add(hours=1) if with_end_value else None
        on_cursor_value_missing = cursor_value_missing

    chat_message = sql_source_db.get_table("chat_message")
    query = TableLoader(
        sql_source_db.engine,
        backend,
        chat_message,
        table_to_columns(chat_message),
        incremental=MockIncremental(),  # type: ignore[arg-type]
    ).make_query()

    created_at = chat_message.c.created_at

    # Range part of the filter: lower bound always, upper bound only with an end value.
    range_cond = created_at >= MockIncremental.last_value
    if with_end_value:
        range_cond = sa.and_(range_cond, created_at < MockIncremental.end_value)

    # NULL handling wraps the whole range condition.
    if cursor_value_missing == "include":
        where_clause = sa.or_(range_cond, created_at.is_(None))
    else:
        where_clause = sa.and_(range_cond, created_at.isnot(None))

    expected = chat_message.select().order_by(created_at.asc()).where(where_clause)
    assert query.compare(expected)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
@pytest.mark.parametrize("cursor_value_missing", ["include", "exclude"])
def test_make_query_incremental_on_cursor_value_missing_no_last_value(
    sql_source_db: SQLAlchemySourceDB,
    backend: TableBackend,
    cursor_value_missing: str,
) -> None:
    """Check the query built when there is no `last_value` and no `end_value`.

    "include" is expected to produce no WHERE clause at all, while "exclude"
    is expected to still filter out rows whose cursor column is NULL.
    """

    class MockIncremental:
        last_value = None
        last_value_func = max
        cursor_path = "created_at"
        row_order = "asc"
        end_value = None
        on_cursor_value_missing = cursor_value_missing

    chat_message = sql_source_db.get_table("chat_message")
    query = TableLoader(
        sql_source_db.engine,
        backend,
        chat_message,
        table_to_columns(chat_message),
        incremental=MockIncremental(),  # type: ignore[arg-type]
    ).make_query()

    # Both variants share the ordered select; "include" without a last value adds nothing.
    expected = chat_message.select().order_by(chat_message.c.created_at.asc())
    if cursor_value_missing != "include":
        # exclude always has a where clause
        expected = expected.where(chat_message.c.created_at.isnot(None))

    assert query.compare(expected)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
def test_make_query_incremental_end_value(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend
Expand All @@ -104,6 +197,7 @@ class MockIncremental:
cursor_path = "created_at"
end_value = now.add(hours=1)
row_order = None
on_cursor_value_missing = "raise"

table = sql_source_db.get_table("chat_message")
loader = TableLoader(
Expand All @@ -115,10 +209,11 @@ class MockIncremental:
)

query = loader.make_query()
expected = (
table.select()
.where(table.c.created_at <= MockIncremental.last_value)
.where(table.c.created_at > MockIncremental.end_value)
expected = table.select().where(
sa.and_(
table.c.created_at <= MockIncremental.last_value,
table.c.created_at > MockIncremental.end_value,
)
)

assert query.compare(expected)
Expand All @@ -134,6 +229,7 @@ class MockIncremental:
cursor_path = "created_at"
row_order = "asc"
end_value = dlt.common.pendulum.now()
on_cursor_value_missing = "raise"

table = sql_source_db.get_table("chat_message")
loader = TableLoader(
Expand Down

0 comments on commit 213f82e

Please sign in to comment.