From dfa300ead1ff1bb126ecdb18bf0b45013e128d13 Mon Sep 17 00:00:00 2001 From: Steinthor Palsson Date: Thu, 17 Oct 2024 12:19:46 -0400 Subject: [PATCH] Always add exclude condition, import sqlalchemy from common lib --- dlt/sources/sql_database/helpers.py | 21 +++++----- .../load/sources/sql_database/test_helpers.py | 38 +++++++++++++++++++ 2 files changed, 50 insertions(+), 9 deletions(-) diff --git a/dlt/sources/sql_database/helpers.py b/dlt/sources/sql_database/helpers.py index d6ac767438..0a57eff904 100644 --- a/dlt/sources/sql_database/helpers.py +++ b/dlt/sources/sql_database/helpers.py @@ -13,8 +13,6 @@ ) import operator -from sqlalchemy import and_, or_ - import dlt from dlt.common.configuration.specs import ( BaseConfiguration, @@ -38,7 +36,7 @@ TTypeAdapter, ) -from dlt.common.libs.sql_alchemy import Engine, CompileError, create_engine +from dlt.common.libs.sql_alchemy import Engine, CompileError, create_engine, sa TableBackend = Literal["sqlalchemy", "pyarrow", "pandas", "connectorx"] @@ -99,15 +97,20 @@ def _make_query(self) -> SelectAny: else: # Custom last_value, load everything and let incremental handle filtering return query # type: ignore[no-any-return] + where_clause = True if self.last_value is not None: - where_and_clauses = [filter_op(self.cursor_column, self.last_value)] + where_clause = filter_op(self.cursor_column, self.last_value) if self.end_value is not None: - where_and_clauses.append(filter_op_end(self.cursor_column, self.end_value)) - where_clause = and_(*where_and_clauses) + where_clause = sa.and_( + where_clause, filter_op_end(self.cursor_column, self.end_value) + ) + if self.on_cursor_value_missing == "include": - where_clause = or_(where_clause, self.cursor_column.is_(None)) - elif self.on_cursor_value_missing == "exclude": - where_clause = and_(where_clause, self.cursor_column.isnot(None)) + where_clause = sa.or_(where_clause, self.cursor_column.is_(None)) + if self.on_cursor_value_missing == "exclude": + where_clause = sa.and_(where_clause, self.cursor_column.isnot(None)) + + if where_clause is not True: query = query.where(where_clause) # generate order by from declared row order diff --git a/tests/load/sources/sql_database/test_helpers.py b/tests/load/sources/sql_database/test_helpers.py index 3fb62ad4f3..4748f226a9 100644 --- a/tests/load/sources/sql_database/test_helpers.py +++ b/tests/load/sources/sql_database/test_helpers.py @@ -147,6 +147,44 @@ class MockIncremental: assert query.compare(expected) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("cursor_value_missing", ["include", "exclude"]) +def test_make_query_incremental_on_cursor_value_missing_no_last_value( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + cursor_value_missing: str, +) -> None: + class MockIncremental: + last_value = None + last_value_func = max + cursor_path = "created_at" + row_order = "asc" + end_value = None + on_cursor_value_missing = cursor_value_missing + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + + if cursor_value_missing == "include": + # There is no where clause for include without last + expected = table.select().order_by(table.c.created_at.asc()) + else: + # exclude always has a where clause + expected = ( + table.select().order_by(table.c.created_at.asc()).where(table.c.created_at.isnot(None)) + ) + + assert query.compare(expected) + + @pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) def test_make_query_incremental_end_value( sql_source_db: SQLAlchemySourceDB, backend: TableBackend