diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py
index 829fe8db82..7bc64240e1 100644
--- a/dlt/destinations/impl/sqlalchemy/db_api_client.py
+++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -238,16 +238,16 @@ def _sqlite_create_dataset(self, dataset_name: str) -> None:
         """Mimic multiple schemas in sqlite using ATTACH DATABASE to
         attach a new database file to the current connection.
         """
-        if dataset_name == "main":
-            # main always exists
-            return
         if self._sqlite_is_memory_db():
             new_db_fn = ":memory:"
         else:
             new_db_fn = self._sqlite_dataset_filename(dataset_name)
 
-        statement = "ATTACH DATABASE :fn AS :name"
-        self.execute_sql(statement, fn=new_db_fn, name=dataset_name)
+        if dataset_name != "main":  # main is the current file, it is always attached
+            statement = "ATTACH DATABASE :fn AS :name"
+            self.execute_sql(statement, fn=new_db_fn, name=dataset_name)
+        # WAL mode is applied to all currently attached databases
+        self.execute_sql("PRAGMA journal_mode=WAL")
         self._sqlite_attached_datasets.add(dataset_name)
 
     def _sqlite_drop_dataset(self, dataset_name: str) -> None:
diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py
index 360dd89192..bf05c42f08 100644
--- a/dlt/destinations/impl/sqlalchemy/factory.py
+++ b/dlt/destinations/impl/sqlalchemy/factory.py
@@ -1,5 +1,6 @@
 import typing as t
 
+from dlt.common import pendulum
 from dlt.common.destination import Destination, DestinationCapabilitiesContext
 from dlt.common.destination.capabilities import DataTypeMapper
 from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE
@@ -9,6 +10,7 @@
     SqlalchemyCredentials,
     SqlalchemyClientConfiguration,
 )
+from dlt.common.data_writers.escape import format_datetime_literal
 
 SqlalchemyTypeMapper: t.Type[DataTypeMapper]
 
@@ -24,6 +26,13 @@
     from sqlalchemy.engine import Engine
 
 
+def _format_mysql_datetime_literal(
+    v: pendulum.DateTime, precision: int = 6, no_tz: bool = False
+) -> str:
+    # Format without timezone to prevent tz conversion in SELECT
+    return format_datetime_literal(v, precision, no_tz=True)
+
+
 class sqlalchemy(Destination[SqlalchemyClientConfiguration, "SqlalchemyJobClient"]):
     spec = SqlalchemyClientConfiguration
 
@@ -50,6 +59,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
         caps.supports_multiple_statements = False
         caps.type_mapper = SqlalchemyTypeMapper
         caps.supported_replace_strategies = ["truncate-and-insert", "insert-from-staging"]
+        caps.supported_merge_strategies = ["delete-insert", "scd2"]
 
         return caps
 
@@ -67,6 +77,8 @@ def adjust_capabilities(
         caps.max_identifier_length = dialect.max_identifier_length
         caps.max_column_identifier_length = dialect.max_identifier_length
         caps.supports_native_boolean = dialect.supports_native_boolean
+        if dialect.name == "mysql":
+            caps.format_datetime_literal = _format_mysql_datetime_literal
 
         return caps
 
diff --git a/dlt/destinations/impl/sqlalchemy/load_jobs.py b/dlt/destinations/impl/sqlalchemy/load_jobs.py
index c8486dc0f0..3cfd6bd910 100644
--- a/dlt/destinations/impl/sqlalchemy/load_jobs.py
+++ b/dlt/destinations/impl/sqlalchemy/load_jobs.py
@@ -13,6 +13,7 @@
 from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams
 
 from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
+from dlt.destinations.impl.sqlalchemy.merge_job import SqlalchemyMergeFollowupJob
 
 if TYPE_CHECKING:
     from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient
@@ -134,3 +135,11 @@ def generate_sql(
             statements.append(stmt)
 
         return statements
+
+
+__all__ = [
+    "SqlalchemyJsonLInsertJob",
+    "SqlalchemyParquetInsertJob",
+    "SqlalchemyStagingCopyJob",
+    "SqlalchemyMergeFollowupJob",
+]
diff --git a/dlt/destinations/impl/sqlalchemy/merge_job.py b/dlt/destinations/impl/sqlalchemy/merge_job.py
new file mode 100644
index 0000000000..5360939ba0
--- /dev/null
+++ b/dlt/destinations/impl/sqlalchemy/merge_job.py
@@ -0,0 +1,441 @@
+from typing import Sequence, Tuple, Optional, List, Union
+import operator
+
+import sqlalchemy as sa
+
+from dlt.destinations.sql_jobs import SqlMergeFollowupJob
+from dlt.common.destination.reference import PreparedTableSchema, DestinationCapabilitiesContext
+from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
+from dlt.common.schema.utils import (
+    get_columns_names_with_prop,
+    get_dedup_sort_tuple,
+    get_first_column_name_with_prop,
+    is_nested_table,
+    get_validity_column_names,
+    get_active_record_timestamp,
+)
+from dlt.common.time import ensure_pendulum_datetime
+from dlt.common.storages.load_package import load_package as current_load_package
+
+
+class SqlalchemyMergeFollowupJob(SqlMergeFollowupJob):
+    """Uses SQLAlchemy to generate merge SQL statements.
+    Result is equivalent to the SQL generated by `SqlMergeFollowupJob`
+    except for delete-insert we use concrete tables instead of temporary tables.
+    """
+
+    @classmethod
+    def gen_merge_sql(
+        cls,
+        table_chain: Sequence[PreparedTableSchema],
+        sql_client: SqlalchemyClient,  # type: ignore[override]
+    ) -> List[str]:
+        root_table = table_chain[0]
+
+        root_table_obj = sql_client.get_existing_table(root_table["name"])
+        staging_root_table_obj = root_table_obj.to_metadata(
+            sql_client.metadata, schema=sql_client.staging_dataset_name
+        )
+
+        primary_key_names = get_columns_names_with_prop(root_table, "primary_key")
+        merge_key_names = get_columns_names_with_prop(root_table, "merge_key")
+
+        temp_metadata = sa.MetaData()
+
+        append_fallback = (len(primary_key_names) + len(merge_key_names)) == 0
+
+        sqla_statements = []
+        tables_to_drop: List[sa.Table] = (
+            []
+        )  # Keep track of temp tables to drop at the end of the job
+
+        if not append_fallback:
+            key_clause = cls._generate_key_table_clauses(
+                primary_key_names, merge_key_names, root_table_obj, staging_root_table_obj
+            )
+
+            # Generate the delete statements
+            if len(table_chain) == 1 and not cls.requires_temp_table_for_delete():
+                delete_statement = root_table_obj.delete().where(
+                    sa.exists(
+                        sa.select(sa.literal(1))
+                        .where(key_clause)
+                        .select_from(staging_root_table_obj)
+                    )
+                )
+                sqla_statements.append(delete_statement)
+            else:
+                row_key_col_name = cls._get_row_key_col(table_chain, sql_client, root_table)
+                row_key_col = root_table_obj.c[row_key_col_name]
+                # Use a real table cause sqlalchemy doesn't have TEMPORARY TABLE abstractions
+                delete_temp_table = sa.Table(
+                    "delete_" + root_table_obj.name,
+                    temp_metadata,
+                    # Give this column a fixed name to be able to reference it later
+                    sa.Column("_dlt_id", row_key_col.type),
+                    schema=staging_root_table_obj.schema,
+                )
+                tables_to_drop.append(delete_temp_table)
+                # Add the CREATE TABLE statement
+                sqla_statements.append(sa.sql.ddl.CreateTable(delete_temp_table))
+                # Insert data into the "temporary" table
+                insert_statement = delete_temp_table.insert().from_select(
+                    [row_key_col],
+                    sa.select(row_key_col).where(
+                        sa.exists(
+                            sa.select(sa.literal(1))
+                            .where(key_clause)
+                            .select_from(staging_root_table_obj)
+                        )
+                    ),
+                )
+                sqla_statements.append(insert_statement)
+
+                for table in table_chain[1:]:
+                    chain_table_obj = sql_client.get_existing_table(table["name"])
+                    root_key_name = cls._get_root_key_col(table_chain, sql_client, table)
+                    root_key_col = chain_table_obj.c[root_key_name]
+
+                    delete_statement = chain_table_obj.delete().where(
+                        root_key_col.in_(sa.select(delete_temp_table.c._dlt_id))
+                    )
+
+                    sqla_statements.append(delete_statement)
+
+                # Delete from root table
+                delete_statement = root_table_obj.delete().where(
+                    row_key_col.in_(sa.select(delete_temp_table.c._dlt_id))
+                )
+                sqla_statements.append(delete_statement)
+
+        hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
+            root_table,
+            root_table_obj,
+            invert=True,
+        )
+
+        dedup_sort = get_dedup_sort_tuple(root_table)  # column_name, 'asc' | 'desc'
+
+        if len(table_chain) > 1 and (primary_key_names or hard_delete_col_name is not None):
+            condition_column_names = (
+                None if hard_delete_col_name is None else [hard_delete_col_name]
+            )
+            condition_columns = (
+                [staging_root_table_obj.c[col_name] for col_name in condition_column_names]
+                if condition_column_names is not None
+                else []
+            )
+
+            staging_row_key_col = staging_root_table_obj.c[row_key_col_name]
+
+            # Create the insert "temporary" table (but use a concrete table)
+            insert_temp_table = sa.Table(
+                "insert_" + root_table_obj.name,
+                temp_metadata,
+                sa.Column(row_key_col_name, staging_row_key_col.type),
+                schema=staging_root_table_obj.schema,
+            )
+            tables_to_drop.append(insert_temp_table)
+            create_insert_temp_table_statement = sa.sql.ddl.CreateTable(insert_temp_table)
+            sqla_statements.append(create_insert_temp_table_statement)
+            staging_primary_key_cols = [
+                staging_root_table_obj.c[col_name] for col_name in primary_key_names
+            ]
+
+            inner_cols = [staging_row_key_col]
+
+            if primary_key_names:
+                if dedup_sort is not None:
+                    order_by_col = staging_root_table_obj.c[dedup_sort[0]]
+                    order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc
+                else:
+                    order_by_col = sa.select(sa.literal(None))
+                    order_dir_func = sa.asc
+                if condition_columns:
+                    inner_cols += condition_columns
+
+                inner_select = sa.select(
+                    sa.func.row_number()
+                    .over(
+                        partition_by=set(staging_primary_key_cols),
+                        order_by=order_dir_func(order_by_col),
+                    )
+                    .label("_dlt_dedup_rn"),
+                    *inner_cols,
+                ).subquery()
+
+                select_for_temp_insert = sa.select(inner_select.c[row_key_col_name]).where(
+                    inner_select.c._dlt_dedup_rn == 1
+                )
+                hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
+                    root_table,
+                    inner_select,
+                    invert=True,
+                )
+
+                if not_delete_cond is not None:
+                    select_for_temp_insert = select_for_temp_insert.where(not_delete_cond)
+            else:
+                hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
+                    root_table,
+                    staging_root_table_obj,
+                    invert=True,
+                )
+                select_for_temp_insert = sa.select(staging_row_key_col).where(not_delete_cond)
+
+            insert_into_temp_table = insert_temp_table.insert().from_select(
+                [row_key_col_name], select_for_temp_insert
+            )
+            sqla_statements.append(insert_into_temp_table)
+
+        # Insert from staging to dataset
+        for table in table_chain:
+            table_obj = sql_client.get_existing_table(table["name"])
+            staging_table_obj = table_obj.to_metadata(
+                sql_client.metadata, schema=sql_client.staging_dataset_name
+            )
+            select_sql = staging_table_obj.select()
+
+            if (primary_key_names and len(table_chain) > 1) or (
+                not primary_key_names
+                and is_nested_table(table)
+                and hard_delete_col_name is not None
+            ):
+                uniq_column_name = root_key_name if is_nested_table(table) else row_key_col_name
+                uniq_column = staging_table_obj.c[uniq_column_name]
+                select_sql = select_sql.where(
+                    uniq_column.in_(
+                        sa.select(
+                            insert_temp_table.c[row_key_col_name].label(uniq_column_name)
+                        ).subquery()
+                    )
+                )
+            elif primary_key_names and len(table_chain) == 1:
+                staging_primary_key_cols = [
+                    staging_table_obj.c[col_name] for col_name in primary_key_names
+                ]
+                if dedup_sort is not None:
+                    order_by_col = staging_table_obj.c[dedup_sort[0]]
+                    order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc
+                else:
+                    order_by_col = sa.select(sa.literal(None))
+                    order_dir_func = sa.asc
+
+                inner_select = sa.select(
+                    staging_table_obj,
+                    sa.func.row_number()
+                    .over(
+                        partition_by=set(staging_primary_key_cols),
+                        order_by=order_dir_func(order_by_col),
+                    )
+                    .label("_dlt_dedup_rn"),
+                ).subquery()
+
+                select_sql = sa.select(
+                    *[c for c in inner_select.c if c.name != "_dlt_dedup_rn"]
+                ).where(inner_select.c._dlt_dedup_rn == 1)
+
+                hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
+                    root_table, inner_select, invert=True
+                )
+
+                if hard_delete_col_name is not None:
+                    select_sql = select_sql.where(not_delete_cond)
+            else:
+                hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
+                    root_table, staging_root_table_obj, invert=True
+                )
+
+                if hard_delete_col_name is not None:
+                    select_sql = select_sql.where(not_delete_cond)
+
+            insert_statement = table_obj.insert().from_select(
+                [col.name for col in table_obj.columns], select_sql
+            )
+            sqla_statements.append(insert_statement)
+
+        # Drop all "temp" tables at the end
+        for table_obj in tables_to_drop:
+            sqla_statements.append(sa.sql.ddl.DropTable(table_obj))
+
+        return [
+            x + ";" if not x.endswith(";") else x
+            for x in (
+                str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True}))
+                for stmt in sqla_statements
+            )
+        ]
+
+    @classmethod
+    def _get_hard_delete_col_and_cond(  # type: ignore[override]
+        cls,
+        table: PreparedTableSchema,
+        table_obj: sa.Table,
+        invert: bool = False,
+    ) -> Tuple[Optional[str], Optional[sa.sql.elements.BinaryExpression]]:
+        col_name = get_first_column_name_with_prop(table, "hard_delete")
+        if col_name is None:
+            return None, None
+        col = table_obj.c[col_name]
+        if invert:
+            cond = col.is_(None)
+        else:
+            cond = col.isnot(None)
+        if table["columns"][col_name]["data_type"] == "bool":
+            if invert:
+                cond = sa.or_(cond, col.is_(False))
+            else:
+                cond = col.is_(True)
+        return col_name, cond
+
+    @classmethod
+    def _generate_key_table_clauses(
+        cls,
+        primary_keys: Sequence[str],
+        merge_keys: Sequence[str],
+        root_table_obj: sa.Table,
+        staging_root_table_obj: sa.Table,
+    ) -> sa.sql.ClauseElement:
+        # Returns an sqlalchemy or_ clause
+        clauses = []
+        if primary_keys or merge_keys:
+            for key in primary_keys:
+                clauses.append(
+                    sa.and_(
+                        *[
+                            root_table_obj.c[key] == staging_root_table_obj.c[key]
+                            for key in primary_keys
+                        ]
+                    )
+                )
+            for key in merge_keys:
+                clauses.append(
+                    sa.and_(
+                        *[
+                            root_table_obj.c[key] == staging_root_table_obj.c[key]
+                            for key in merge_keys
+                        ]
+                    )
+                )
+            return sa.or_(*clauses)  # type: ignore[no-any-return]
+        else:
+            return sa.true()  # type: ignore[no-any-return]
+
+    @classmethod
+    def _gen_concat_sqla(
+        cls, columns: Sequence[sa.Column]
+    ) -> Union[sa.sql.elements.BinaryExpression, sa.Column]:
+        # Use col1 + col2 + col3 ... to generate a dialect specific concat expression
+        result = columns[0]
+        if len(columns) == 1:
+            return result
+        # Cast because CONCAT is only generated for string columns
+        result = sa.cast(result, sa.String)
+        for col in columns[1:]:
+            result = operator.add(result, sa.cast(col, sa.String))
+        return result
+
+    @classmethod
+    def gen_scd2_sql(
+        cls,
+        table_chain: Sequence[PreparedTableSchema],
+        sql_client: SqlalchemyClient,  # type: ignore[override]
+    ) -> List[str]:
+        sqla_statements = []
+        root_table = table_chain[0]
+        root_table_obj = sql_client.get_existing_table(root_table["name"])
+        staging_root_table_obj = root_table_obj.to_metadata(
+            sql_client.metadata, schema=sql_client.staging_dataset_name
+        )
+
+        from_, to = get_validity_column_names(root_table)
+        hash_ = get_first_column_name_with_prop(root_table, "x-row-version")
+
+        caps = sql_client.capabilities
+
+        format_datetime_literal = caps.format_datetime_literal
+        if format_datetime_literal is None:
+            format_datetime_literal = (
+                DestinationCapabilitiesContext.generic_capabilities().format_datetime_literal
+            )
+
+        boundary_ts = ensure_pendulum_datetime(
+            root_table.get("x-boundary-timestamp", current_load_package()["state"]["created_at"])  # type: ignore[arg-type]
+        )
+
+        boundary_literal = format_datetime_literal(boundary_ts, caps.timestamp_precision)
+
+        active_record_timestamp = get_active_record_timestamp(root_table)
+
+        update_statement = (
+            root_table_obj.update()
+            .values({to: sa.text(boundary_literal)})
+            .where(root_table_obj.c[hash_].notin_(sa.select(staging_root_table_obj.c[hash_])))
+        )
+
+        if active_record_timestamp is None:
+            active_record_literal = None
+            root_is_active_clause = root_table_obj.c[to].is_(None)
+        else:
+            active_record_literal = format_datetime_literal(
+                active_record_timestamp, caps.timestamp_precision
+            )
+            root_is_active_clause = root_table_obj.c[to] == sa.text(active_record_literal)
+
+        update_statement = update_statement.where(root_is_active_clause)
+
+        merge_keys = get_columns_names_with_prop(root_table, "merge_key")
+        if merge_keys:
+            root_merge_key_cols = [root_table_obj.c[key] for key in merge_keys]
+            staging_merge_key_cols = [staging_root_table_obj.c[key] for key in merge_keys]
+
+            update_statement = update_statement.where(
+                cls._gen_concat_sqla(root_merge_key_cols).in_(
+                    sa.select(cls._gen_concat_sqla(staging_merge_key_cols))
+                )
+            )
+
+        sqla_statements.append(update_statement)
+
+        insert_statement = root_table_obj.insert().from_select(
+            [col.name for col in root_table_obj.columns],
+            sa.select(
+                sa.literal(boundary_literal.strip("'")).label(from_),
+                sa.literal(
+                    active_record_literal.strip("'") if active_record_literal is not None else None
+                ).label(to),
+                *[c for c in staging_root_table_obj.columns if c.name not in [from_, to]],
+            ).where(
+                staging_root_table_obj.c[hash_].notin_(
+                    sa.select(root_table_obj.c[hash_]).where(root_is_active_clause)
+                )
+            ),
+        )
+        sqla_statements.append(insert_statement)
+
+        nested_tables = table_chain[1:]
+        for table in nested_tables:
+            row_key_column = cls._get_root_key_col(table_chain, sql_client, table)
+
+            table_obj = sql_client.get_existing_table(table["name"])
+            staging_table_obj = table_obj.to_metadata(
+                sql_client.metadata, schema=sql_client.staging_dataset_name
+            )
+
+            insert_statement = table_obj.insert().from_select(
+                [col.name for col in table_obj.columns],
+                staging_table_obj.select().where(
+                    staging_table_obj.c[row_key_column].notin_(
+                        sa.select(table_obj.c[row_key_column])
+                    )
+                ),
+            )
+            sqla_statements.append(insert_statement)
+
+        return [
+            x + ";" if not x.endswith(";") else x
+            for x in (
+                str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True}))
+                for stmt in sqla_statements
+            )
+        ]
diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
index a2514a43e0..c5a6442d8a 100644
--- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
+++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -18,7 +18,11 @@
 from dlt.common.destination.capabilities import DestinationCapabilitiesContext
 from dlt.common.schema import Schema, TTableSchema, TColumnSchema, TSchemaTables
 from dlt.common.schema.typing import TColumnType, TTableSchemaColumns
-from dlt.common.schema.utils import pipeline_state_table, normalize_table_identifiers
+from dlt.common.schema.utils import (
+    pipeline_state_table,
+    normalize_table_identifiers,
+    is_complete_column,
+)
 from dlt.destinations.exceptions import DatabaseUndefinedRelation
 from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
 from dlt.destinations.impl.sqlalchemy.configuration import SqlalchemyClientConfiguration
@@ -26,6 +30,7 @@
     SqlalchemyJsonLInsertJob,
     SqlalchemyParquetInsertJob,
     SqlalchemyStagingCopyJob,
+    SqlalchemyMergeFollowupJob,
 )
 
 
@@ -65,6 +70,7 @@ def _to_table_object(self, schema_table: PreparedTableSchema) -> sa.Table:
             *[
                 self._to_column_object(col, schema_table)
                 for col in schema_table["columns"].values()
+                if is_complete_column(col)
             ],
             extend_existing=True,
             schema=self.sql_client.dataset_name,
@@ -97,13 +103,10 @@ def _create_replace_followup_jobs(
     def _create_merge_followup_jobs(
         self, table_chain: Sequence[PreparedTableSchema]
    ) -> List[FollowupJobRequest]:
+        # Ensure all tables exist in metadata before generating sql job
         for table in table_chain:
             self._to_table_object(table)
-        return [
-            SqlalchemyStagingCopyJob.from_table_chain(
-                table_chain, self.sql_client, {"replace": False}
-            )
-        ]
+        return [SqlalchemyMergeFollowupJob.from_table_chain(table_chain, self.sql_client)]
 
     def create_load_job(
         self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
diff --git a/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md b/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md
index b9014e0564..9f33c02337 100644
--- a/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md
+++ b/docs/website/docs/dlt-ecosystem/destinations/sqlalchemy.md
@@ -135,8 +135,7 @@ The following write dispositions are supported:
 
 - `append`
 - `replace` with `truncate-and-insert` and `insert-from-staging` replace strategies. `staging-optimized` falls back to `insert-from-staging`.
-
-The `merge` disposition is not supported and falls back to `append`.
+- `merge` with `delete-insert` and `scd2` merge strategies.
 
 ## Data loading
 
diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py
index 3e08b792ed..2a5b9ed296 100644
--- a/tests/load/pipeline/test_scd2.py
+++ b/tests/load/pipeline/test_scd2.py
@@ -52,13 +52,22 @@ def strip_timezone(ts: TAnyDateTime) -> pendulum.DateTime:
 
 
 def get_table(
-    pipeline: dlt.Pipeline, table_name: str, sort_column: str = None, include_root_id: bool = True
+    pipeline: dlt.Pipeline,
+    table_name: str,
+    sort_column: str = None,
+    include_root_id: bool = True,
+    ts_columns: Optional[List[str]] = None,
 ) -> List[Dict[str, Any]]:
     """Returns destination table contents as list of dictionaries."""
+    ts_columns = ts_columns or []
     table = [
         {
-            k: strip_timezone(v) if isinstance(v, datetime) else v
+            k: (
+                strip_timezone(v)
+                if isinstance(v, datetime) or (k in ts_columns and v is not None)
+                else v
+            )
             for k, v in r.items()
             if not k.startswith("_dlt")
             or k in DEFAULT_VALIDITY_COLUMN_NAMES
@@ -128,7 +137,7 @@ def r(data):
     # assert load results
     ts_1 = get_load_package_created_at(p, info)
     assert_load_info(info)
-    assert get_table(p, "dim_test", "c2__nc1") == [
+    assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [
         {
             from_: ts_1,
             to: None,
@@ -153,7 +162,7 @@ def r(data):
     info = p.run(r(dim_snap), **destination_config.run_kwargs)
     ts_2 = get_load_package_created_at(p, info)
     assert_load_info(info)
-    assert get_table(p, "dim_test", "c2__nc1") == [
+    assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [
         {
             from_: ts_1,
             to: None,
@@ -178,7 +187,7 @@ def r(data):
     info = p.run(r(dim_snap), **destination_config.run_kwargs)
     ts_3 = get_load_package_created_at(p, info)
     assert_load_info(info)
-    assert get_table(p, "dim_test", "c2__nc1") == [
+    assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [
         {from_: ts_1, to: ts_3, "nk": 2, "c1": "bar", "c2__nc1": "bar"},
         {from_: ts_1, to: ts_2, "nk": 1, "c1": "foo", "c2__nc1": "foo"},
         {
@@ -198,7 +207,7 @@ def r(data):
     info = p.run(r(dim_snap), **destination_config.run_kwargs)
     ts_4 = get_load_package_created_at(p, info)
     assert_load_info(info)
-    assert get_table(p, "dim_test", "c2__nc1") == [
+    assert get_table(p, "dim_test", "c2__nc1", ts_columns=[from_, to]) == [
         {from_: ts_1, to: ts_3, "nk": 2, "c1": "bar", "c2__nc1": "bar"},
         {
             from_: ts_4,
@@ -242,7 +251,7 @@ def r(data):
     info = p.run(r(dim_snap), **destination_config.run_kwargs)
     ts_1 = get_load_package_created_at(p, info)
     assert_load_info(info)
-    assert get_table(p, "dim_test", "c1") == [
+    assert get_table(p, "dim_test", "c1", ts_columns=[FROM, TO]) == [
         {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"},
         {FROM: ts_1, TO: None, "nk": 1, "c1": "foo"},
     ]
@@ -261,7 +270,7 @@ def r(data):
     info = p.run(r(dim_snap), **destination_config.run_kwargs)
     ts_2 = get_load_package_created_at(p, info)
     assert_load_info(info)
-    assert get_table(p, "dim_test", "c1") == [
+    assert get_table(p, "dim_test", "c1", ts_columns=[FROM, TO]) == [
         {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"},
         {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"},  # updated
         {FROM: ts_2, TO: None, "nk": 1, "c1": "foo_updated"},  # new
@@ -289,7 +298,7 @@ def r(data):
     ts_3 = get_load_package_created_at(p, info)
     assert_load_info(info)
     assert_records_as_set(
-        get_table(p, "dim_test"),
+        get_table(p, "dim_test", ts_columns=[FROM, TO]),
         [
             {FROM: ts_1, TO: None, "nk": 2, "c1": "bar"},
             {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"},
@@ -315,7 +324,7 @@ def r(data):
     ts_4 = get_load_package_created_at(p, info)
     assert_load_info(info)
     assert_records_as_set(
-        get_table(p, "dim_test"),
+        get_table(p, "dim_test", ts_columns=[FROM, TO]),
         [
             {FROM: ts_1, TO: ts_4, "nk": 2, "c1": "bar"},  # updated
             {FROM: ts_1, TO: ts_2, "nk": 1, "c1": "foo"},
@@ -336,7 +345,7 @@ def r(data):
     ts_5 = get_load_package_created_at(p, info)
     assert_load_info(info)
     assert_records_as_set(
-        get_table(p, "dim_test"),
+        get_table(p, "dim_test", ts_columns=[FROM, TO]),
         [
             {FROM: ts_1, TO: ts_4, "nk": 2, "c1": "bar"},
             {FROM: ts_5, TO: None, "nk": 3, "c1": "baz"},  # new
@@ -502,7 +511,7 @@ def r(data):
         {**{FROM: ts_3, TO: None}, **r1_no_child},
         {**{FROM: ts_1, TO: None}, **r2_no_child},
     ]
-    assert_records_as_set(get_table(p, "dim_test"), expected)
+    assert_records_as_set(get_table(p, "dim_test", ts_columns=[FROM, TO]), expected)
 
     # assert child records
     expected = [
@@ -739,7 +748,10 @@ def dim_test(data):
     assert load_table_counts(p, "dim_test")["dim_test"] == 3
     ts3 = get_load_package_created_at(p, info)
     # natural key 1 should now have two records (one retired, one active)
-    actual = [{k: v for k, v in row.items() if k in ("nk", TO)} for row in get_table(p, "dim_test")]
+    actual = [
+        {k: v for k, v in row.items() if k in ("nk", TO)}
+        for row in get_table(p, "dim_test", ts_columns=[FROM, TO])
+    ]
     expected = [{"nk": 1, TO: ts3}, {"nk": 1, TO: None}, {"nk": 2, TO: None}]
     assert_records_as_set(actual, expected)  # type: ignore[arg-type]
 
@@ -753,7 +765,10 @@ def dim_test(data):
     assert load_table_counts(p, "dim_test")["dim_test"] == 4
     ts4 = get_load_package_created_at(p, info)
     # natural key 1 should now have three records (two retired, one active)
-    actual = [{k: v for k, v in row.items() if k in ("nk", TO)} for row in get_table(p, "dim_test")]
+    actual = [
+        {k: v for k, v in row.items() if k in ("nk", TO)}
+        for row in get_table(p, "dim_test", ts_columns=[FROM, TO])
+    ]
     expected = [{"nk": 1, TO: ts3}, {"nk": 1, TO: ts4}, {"nk": 1, TO: None}, {"nk": 2, TO: None}]
     assert_records_as_set(actual, expected)  # type: ignore[arg-type]
 
@@ -805,7 +820,7 @@ def dim_test_compound(data):
     # "Doe" should now have two records (one retired, one active)
     actual = [
         {k: v for k, v in row.items() if k in ("first_name", "last_name", TO)}
-        for row in get_table(p, "dim_test_compound")
+        for row in get_table(p, "dim_test_compound", ts_columns=[FROM, TO])
     ]
     expected = [
         {"first_name": first_name, "last_name": "Doe", TO: ts3},
@@ -869,7 +884,7 @@ def dim_test(data):
     ts2 = get_load_package_created_at(p, info)
     actual = [
         {k: v for k, v in row.items() if k in ("date", "name", TO)}
-        for row in get_table(p, "dim_test")
+        for row in get_table(p, "dim_test", ts_columns=[TO])
     ]
     expected = [
         {"date": "2024-01-01", "name": "a", TO: None},
diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py
index f7d915903e..fad244fa71 100644
--- a/tests/load/pipeline/test_write_disposition_changes.py
+++ b/tests/load/pipeline/test_write_disposition_changes.py
@@ -128,11 +128,15 @@ def source():
     # schemaless destinations allow adding of root key without the pipeline failing
     # they do not mind adding NOT NULL columns to tables with existing data (id NOT NULL is supported at all)
     # doing this will result in somewhat useless behavior
-    destination_allows_adding_root_key = destination_config.destination_type in [
-        "dremio",
-        "clickhouse",
-        "athena",
-    ]
+    destination_allows_adding_root_key = (
+        destination_config.destination_type
+        in [
+            "dremio",
+            "clickhouse",
+            "athena",
+        ]
+        or destination_config.destination_name == "sqlalchemy_mysql"
+    )
 
     if destination_allows_adding_root_key and not with_root_key:
         pipeline.run(
diff --git a/tests/load/utils.py b/tests/load/utils.py
index 19601f2cf1..9cfb6984a5 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -331,13 +331,13 @@ def destinations_configs(
         destination_configs += [
             DestinationTestConfiguration(
                 destination_type="sqlalchemy",
-                supports_merge=False,
+                supports_merge=True,
                 supports_dbt=False,
                 destination_name="sqlalchemy_mysql",
             ),
             DestinationTestConfiguration(
                 destination_type="sqlalchemy",
-                supports_merge=False,
+                supports_merge=True,
                 supports_dbt=False,
                 destination_name="sqlalchemy_sqlite",
            ),
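
Editor's note: the sketch below is not part of the patch. It is a minimal, hedged usage example of the merge support the patch enables for the sqlalchemy destination, written against dlt's public resource/pipeline API as documented; the resource names, columns, and the sqlite URL are hypothetical.

    import dlt

    # Delete-insert merge: rows are matched on the primary key "id" (demo data only).
    @dlt.resource(primary_key="id", write_disposition="merge")
    def customers():
        yield [
            {"id": 1, "name": "alice", "city": "Berlin"},
            {"id": 2, "name": "bob", "city": "Paris"},
        ]

    # SCD2 merge: dlt maintains validity columns and retires rows whose content changed.
    @dlt.resource(write_disposition={"disposition": "merge", "strategy": "scd2"})
    def customers_history():
        yield [{"id": 1, "name": "alice", "city": "Berlin"}]

    pipeline = dlt.pipeline(
        pipeline_name="sqlalchemy_merge_demo",
        destination=dlt.destinations.sqlalchemy("sqlite:///dlt_demo.sqlite"),
        dataset_name="demo",
    )
    pipeline.run(customers())
    pipeline.run(customers_history())

Re-running either resource with changed rows should exercise the new SqlalchemyMergeFollowupJob paths (delete-insert and scd2 respectively) instead of the previous fallback to append.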