diff --git a/dlt/destinations/impl/sqlalchemy/factory.py b/dlt/destinations/impl/sqlalchemy/factory.py
index 74214f8b75..1d35f4f49b 100644
--- a/dlt/destinations/impl/sqlalchemy/factory.py
+++ b/dlt/destinations/impl/sqlalchemy/factory.py
@@ -49,6 +49,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
         # Multiple concatenated statements are not supported by all engines, so leave them off by default
         caps.supports_multiple_statements = False
         caps.type_mapper = SqlalchemyTypeMapper
+        caps.supported_merge_strategies = ["delete-insert"]
 
         return caps
 
diff --git a/dlt/destinations/impl/sqlalchemy/load_jobs.py b/dlt/destinations/impl/sqlalchemy/load_jobs.py
index c8486dc0f0..d3f0afaefe 100644
--- a/dlt/destinations/impl/sqlalchemy/load_jobs.py
+++ b/dlt/destinations/impl/sqlalchemy/load_jobs.py
@@ -3,6 +3,7 @@
 
 import sqlalchemy as sa
 
+from dlt.common.schema.utils import get_columns_names_with_prop
 from dlt.common.destination.reference import (
     RunnableLoadJob,
     HasFollowupJobs,
@@ -10,9 +11,12 @@
 )
 from dlt.common.storages import FileStorage
 from dlt.common.json import json, PY_DATETIME_DECODERS
-from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams
+from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams, SqlMergeFollowupJob
 from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
+from dlt.destinations.impl.sqlalchemy.merge_job import (
+    SqlalchemyMergeFollowupJob as SqlalchemyMergeFollowupJob,
+)
 
 if TYPE_CHECKING:
     from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient
diff --git a/dlt/destinations/impl/sqlalchemy/merge_job.py b/dlt/destinations/impl/sqlalchemy/merge_job.py
new file mode 100644
index 0000000000..4c741fb372
--- /dev/null
+++ b/dlt/destinations/impl/sqlalchemy/merge_job.py
@@ -0,0 +1,296 @@
+from typing import Sequence, Tuple, Optional, List
+
+import sqlalchemy as sa
+
+from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlJobParams
+from dlt.common.destination.reference import PreparedTableSchema
+from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
+from dlt.common.schema.utils import (
+    get_columns_names_with_prop,
+    get_dedup_sort_tuple,
+    get_first_column_name_with_prop,
+    is_nested_table,
+)
+
+
+class SqlalchemyMergeFollowupJob(SqlMergeFollowupJob):
+    """Uses SQLAlchemy to generate merge SQL statements.
+    Result is equivalent to the SQL generated by `SqlMergeFollowupJob`
+    except we use concrete tables instead of temporary tables.
+    """
+
+    @classmethod
+    def generate_sql(
+        cls,
+        table_chain: Sequence[PreparedTableSchema],
+        sql_client: SqlalchemyClient,  # type: ignore[override]
+        params: Optional[SqlJobParams] = None,
+    ) -> List[str]:
+        root_table = table_chain[0]
+
+        root_table_obj = sql_client.get_existing_table(root_table["name"])
+        staging_root_table_obj = root_table_obj.to_metadata(
+            sql_client.metadata, schema=sql_client.staging_dataset_name
+        )
+
+        primary_key_names = get_columns_names_with_prop(root_table, "primary_key")
+        merge_key_names = get_columns_names_with_prop(root_table, "merge_key")
+
+        temp_metadata = sa.MetaData()
+
+        append_fallback = (len(primary_key_names) + len(merge_key_names)) == 0
+
+        if not append_fallback:
+            key_clause = cls._generate_key_table_clauses(
+                primary_key_names, merge_key_names, root_table_obj, staging_root_table_obj
+            )
+
+        sqla_statements = []
+
+        # Keep track of temp tables to drop at the end of the job
+        tables_to_drop: List[sa.Table] = []
+
+        # Generate the delete statements
+        if len(table_chain) == 1 and not cls.requires_temp_table_for_delete():
+            delete_statement = root_table_obj.delete().where(
+                sa.exists(
+                    sa.select([sa.literal(1)]).where(key_clause).select_from(staging_root_table_obj)
+                )
+            )
+            sqla_statements.append(delete_statement)
+        else:
+            row_key_col_name = cls._get_row_key_col(table_chain, sql_client, root_table)
+            row_key_col = root_table_obj.c[row_key_col_name]
+            # Use a real table cause sqlalchemy doesn't have TEMPORARY TABLE abstractions
+            delete_temp_table = sa.Table(
+                "delete_" + root_table_obj.name,
+                temp_metadata,
+                # Give this column a fixed name to be able to reference it later
+                sa.Column("_dlt_id", row_key_col.type),
+                schema=staging_root_table_obj.schema,
+            )
+            tables_to_drop.append(delete_temp_table)
+            # Add the CREATE TABLE statement
+            sqla_statements.append(sa.sql.ddl.CreateTable(delete_temp_table))
+            # Insert data into the "temporary" table
+            insert_statement = delete_temp_table.insert().from_select(
+                [row_key_col],
+                sa.select([row_key_col]).where(
+                    sa.exists(
+                        sa.select([sa.literal(1)])
+                        .where(key_clause)
+                        .select_from(staging_root_table_obj)
+                    )
+                ),
+            )
+            sqla_statements.append(insert_statement)
+
+            for table in table_chain[1:]:
+                chain_table_obj = sql_client.get_existing_table(table["name"])
+                root_key_name = cls._get_root_key_col(table_chain, sql_client, table)
+                root_key_col = chain_table_obj.c[root_key_name]
+
+                delete_statement = chain_table_obj.delete().where(
+                    root_key_col.in_(sa.select(delete_temp_table.c._dlt_id))
+                )
+
+                sqla_statements.append(delete_statement)
+
+            # Delete from root table
+            delete_statement = root_table_obj.delete().where(
+                row_key_col.in_(sa.select(delete_temp_table.c._dlt_id))
+            )
+            sqla_statements.append(delete_statement)
+
+        hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
+            root_table,
+            root_table_obj,
+            invert=True,
+        )
+
+        dedup_sort = get_dedup_sort_tuple(root_table)  # column_name, 'asc' | 'desc'
+
+        if len(table_chain) > 1 and (primary_key_names or hard_delete_col_name is not None):
+            condition_column_names = (
+                None if hard_delete_col_name is None else [hard_delete_col_name]
+            )
+            condition_columns = (
+                [staging_root_table_obj.c[col_name] for col_name in condition_column_names]
+                if condition_column_names is not None
+                else []
+            )
+
+            staging_row_key_col = staging_root_table_obj.c[row_key_col_name]
+            # Create the insert "temporary" table (but use a concrete table)
+            insert_temp_table = sa.Table(
+                "insert_" + root_table_obj.name,
+                temp_metadata,
+                sa.Column(row_key_col_name, staging_row_key_col.type),
+                schema=staging_root_table_obj.schema,
+            )
+            tables_to_drop.append(insert_temp_table)
+            create_insert_temp_table_statement = sa.sql.ddl.CreateTable(insert_temp_table)
+            sqla_statements.append(create_insert_temp_table_statement)
+            staging_primary_key_cols = [
+                staging_root_table_obj.c[col_name] for col_name in primary_key_names
+            ]
+
+            if primary_key_names:
+                if dedup_sort is not None:
+                    order_by_col = staging_root_table_obj.c[dedup_sort[0]]
+                    order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc
+                else:
+                    order_by_col = sa.select(sa.literal(None))
+                    order_dir_func = sa.asc
+                inner_cols = (
+                    condition_columns if condition_columns is not None else [staging_row_key_col]
+                )
+
+                inner_select = sa.select(
+                    sa.func.row_number()
+                    .over(
+                        partition_by=set(staging_primary_key_cols),
+                        order_by=order_dir_func(order_by_col),
+                    )
+                    .label("_dlt_dedup_rn"),
+                    *inner_cols
+                ).subquery()
+
+                select_for_temp_insert = (
+                    sa.select(staging_row_key_col)
+                    .select_from(inner_select)
+                    .where(inner_select.c._dlt_dedup_rn == 1)
+                )
+                if not_delete_cond is not None:
+                    select_for_temp_insert = select_for_temp_insert.where(not_delete_cond)
+            else:
+                select_for_temp_insert = sa.select(staging_row_key_col).where(not_delete_cond)
+
+            insert_into_temp_table = insert_temp_table.insert().from_select(
+                [row_key_col_name], select_for_temp_insert
+            )
+            sqla_statements.append(insert_into_temp_table)
+
+        # Insert from staging to dataset
+        for table in table_chain:
+            table_obj = sql_client.get_existing_table(table["name"])
+            staging_table_obj = table_obj.to_metadata(
+                sql_client.metadata, schema=sql_client.staging_dataset_name
+            )
+
+            insert_cond = not_delete_cond if hard_delete_col_name is not None else sa.true()
+
+            if (primary_key_names and len(table_chain) > 1) or (
+                not primary_key_names
+                and is_nested_table(table)
+                and hard_delete_col_name is not None
+            ):
+                uniq_column_name = root_key_name if is_nested_table(table) else row_key_col_name
+                uniq_column = staging_table_obj.c[uniq_column_name]
+                insert_cond = uniq_column.in_(
+                    sa.select(
+                        insert_temp_table.c[row_key_col_name].label(uniq_column_name)
+                    ).subquery()
+                )
+
+            select_sql = staging_table_obj.select().where(insert_cond)
+
+            if primary_key_names and len(table_chain) == 1:
+                staging_primary_key_cols = [
+                    staging_table_obj.c[col_name] for col_name in primary_key_names
+                ]
+                if dedup_sort is not None:
+                    order_by_col = staging_table_obj.c[dedup_sort[0]]
+                    order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc
+                else:
+                    order_by_col = sa.select(sa.literal(None))
+                    order_dir_func = sa.asc
+
+                inner_select = sa.select(
+                    staging_table_obj,
+                    sa.func.row_number()
+                    .over(
+                        partition_by=set(staging_primary_key_cols),
+                        order_by=order_dir_func(order_by_col),
+                    )
+                    .label("_dlt_dedup_rn"),
+                ).subquery()
+
+                select_sql = (
+                    sa.select(staging_table_obj)
+                    .select_from(inner_select)
+                    .where(inner_select.c._dlt_dedup_rn == 1)
+                )
+                if insert_cond is not None:
+                    select_sql = select_sql.where(insert_cond)
+
+            insert_statement = table_obj.insert().from_select(
+                [col.name for col in table_obj.columns], select_sql
+            )
+            sqla_statements.append(insert_statement)
+
+        # Drop all "temp" tables at the end
+        for table_obj in tables_to_drop:
+            sqla_statements.append(sa.sql.ddl.DropTable(table_obj))
+
+        return [
+            x + ";" if not x.endswith(";") else x
+            for x in (
+                str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True}))
+                for stmt in sqla_statements
+            )
+        ]
+
+    @classmethod
+    def _get_hard_delete_col_and_cond(  # type: ignore[override]
+        cls,
+        table: PreparedTableSchema,
+        table_obj: sa.Table,
+        invert: bool = False,
+    ) -> Tuple[Optional[str], Optional[sa.sql.elements.BinaryExpression]]:
+        col_name = get_first_column_name_with_prop(table, "hard_delete")
+        if col_name is None:
+            return None, None
+        col = table_obj.c[col_name]
+        if invert:
+            cond = col.is_(None)
+        else:
+            cond = col.isnot(None)
+        if table["columns"][col_name]["data_type"] == "bool":
+            if invert:
+                cond = sa.or_(cond, col.is_(False))
+            else:
+                cond = col.is_(True)
+        return col_name, cond
+
+    @classmethod
+    def _generate_key_table_clauses(
+        cls,
+        primary_keys: Sequence[str],
+        merge_keys: Sequence[str],
+        root_table_obj: sa.Table,
+        staging_root_table_obj: sa.Table,
+    ) -> sa.sql.ClauseElement:
+        # Returns an sqlalchemy or_ clause
+        clauses = []
+        if primary_keys or merge_keys:
+            for key in primary_keys:
+                clauses.append(
+                    sa.and_(
+                        *[
+                            root_table_obj.c[key] == staging_root_table_obj.c[key]
+                            for key in primary_keys
+                        ]
+                    )
+                )
+            for key in merge_keys:
+                clauses.append(
+                    sa.and_(
+                        *[
+                            root_table_obj.c[key] == staging_root_table_obj.c[key]
+                            for key in merge_keys
+                        ]
+                    )
+                )
+            return sa.or_(*clauses)  # type: ignore[no-any-return]
+        else:
+            return sa.true()  # type: ignore[no-any-return]
diff --git a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
index a2514a43e0..da438f6cef 100644
--- a/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
+++ b/dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -26,6 +26,7 @@
     SqlalchemyJsonLInsertJob,
     SqlalchemyParquetInsertJob,
     SqlalchemyStagingCopyJob,
+    SqlalchemyMergeFollowupJob,
 )
 
 
@@ -97,13 +98,10 @@ def _create_replace_followup_jobs(
     def _create_merge_followup_jobs(
         self, table_chain: Sequence[PreparedTableSchema]
     ) -> List[FollowupJobRequest]:
+        # Ensure all tables exist in metadata before generating sql job
         for table in table_chain:
             self._to_table_object(table)
-        return [
-            SqlalchemyStagingCopyJob.from_table_chain(
-                table_chain, self.sql_client, {"replace": False}
-            )
-        ]
+        return [SqlalchemyMergeFollowupJob.from_table_chain(table_chain, self.sql_client)]
 
     def create_load_job(
         self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
diff --git a/tests/load/utils.py b/tests/load/utils.py
index f443748f8e..75596b6f5f 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -326,13 +326,13 @@ def destinations_configs(
         destination_configs += [
             DestinationTestConfiguration(
                 destination_type="sqlalchemy",
-                supports_merge=False,
+                supports_merge=True,
                 supports_dbt=False,
                 destination_name="sqlalchemy_mysql",
             ),
             DestinationTestConfiguration(
                 destination_type="sqlalchemy",
-                supports_merge=False,
+                supports_merge=True,
                 supports_dbt=False,
                 destination_name="sqlalchemy_sqlite",
            ),
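
Note: the factory change above advertises only the "delete-insert" merge strategy, so a standard merge pipeline is enough to exercise SqlalchemyMergeFollowupJob. A minimal sketch of how this would be used from the pipeline side follows; the SQLite URL, table name and primary key are illustrative and not part of this change.

import dlt

# Hypothetical demo pipeline; any SQLAlchemy-supported database URL works the same way.
pipeline = dlt.pipeline(
    pipeline_name="sqlalchemy_merge_demo",
    destination=dlt.destinations.sqlalchemy("sqlite:///merge_demo.db"),
    dataset_name="demo",
)

# With write_disposition="merge" and a primary key, the followup job generated by
# SqlalchemyMergeFollowupJob deletes matching rows from the destination table(s)
# and re-inserts them from the staging dataset ("delete-insert").
pipeline.run(
    [{"id": 1, "value": "a"}, {"id": 2, "value": "b"}],
    table_name="items",
    write_disposition="merge",
    primary_key="id",
)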