Skip to content

Commit

Permalink
Always add exclude condition, import sqlalchemy from common lib
Browse files Browse the repository at this point in the history
  • Loading branch information
steinitzu committed Oct 17, 2024
1 parent a6f1c2c commit dfa300e
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
21 changes: 12 additions & 9 deletions dlt/sources/sql_database/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
)
import operator

from sqlalchemy import and_, or_

import dlt
from dlt.common.configuration.specs import (
BaseConfiguration,
Expand All @@ -38,7 +36,7 @@
TTypeAdapter,
)

from dlt.common.libs.sql_alchemy import Engine, CompileError, create_engine
from dlt.common.libs.sql_alchemy import Engine, CompileError, create_engine, sa


TableBackend = Literal["sqlalchemy", "pyarrow", "pandas", "connectorx"]
Expand Down Expand Up @@ -99,15 +97,20 @@ def _make_query(self) -> SelectAny:
else: # Custom last_value, load everything and let incremental handle filtering
return query # type: ignore[no-any-return]

where_clause = True
if self.last_value is not None:
where_and_clauses = [filter_op(self.cursor_column, self.last_value)]
where_clause = filter_op(self.cursor_column, self.last_value)
if self.end_value is not None:
where_and_clauses.append(filter_op_end(self.cursor_column, self.end_value))
where_clause = and_(*where_and_clauses)
where_clause = sa.and_(
where_clause, filter_op_end(self.cursor_column, self.end_value)
)

if self.on_cursor_value_missing == "include":
where_clause = or_(where_clause, self.cursor_column.is_(None))
elif self.on_cursor_value_missing == "exclude":
where_clause = and_(where_clause, self.cursor_column.isnot(None))
where_clause = sa.or_(where_clause, self.cursor_column.is_(None))
if self.on_cursor_value_missing == "exclude":
where_clause = sa.and_(where_clause, self.cursor_column.isnot(None))

if where_clause is not True:
query = query.where(where_clause)

# generate order by from declared row order
Expand Down
38 changes: 38 additions & 0 deletions tests/load/sources/sql_database/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,44 @@ class MockIncremental:
assert query.compare(expected)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
@pytest.mark.parametrize("cursor_value_missing", ["include", "exclude"])
def test_make_query_incremental_on_cursor_value_missing_no_last_value(
sql_source_db: SQLAlchemySourceDB,
backend: TableBackend,
cursor_value_missing: str,
) -> None:
class MockIncremental:
last_value = None
last_value_func = max
cursor_path = "created_at"
row_order = "asc"
end_value = None
on_cursor_value_missing = cursor_value_missing

table = sql_source_db.get_table("chat_message")
loader = TableLoader(
sql_source_db.engine,
backend,
table,
table_to_columns(table),
incremental=MockIncremental(), # type: ignore[arg-type]
)

query = loader.make_query()

if cursor_value_missing == "include":
# There is no where clause for include without last
expected = table.select().order_by(table.c.created_at.asc())
else:
# exclude always has a where clause
expected = (
table.select().order_by(table.c.created_at.asc()).where(table.c.created_at.isnot(None))
)

assert query.compare(expected)


@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"])
def test_make_query_incremental_end_value(
sql_source_db: SQLAlchemySourceDB, backend: TableBackend
Expand Down

0 comments on commit dfa300e

Please sign in to comment.