Init sqlalchemy merge job
steinitzu committed Sep 19, 2024
1 parent bcafd62 commit 8d2f997
Showing 5 changed files with 307 additions and 8 deletions.
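
This commit enables the "delete-insert" merge strategy for the sqlalchemy destination. A minimal sketch of how the new capability would be exercised from a pipeline (the resource, pipeline, and dataset names below are hypothetical, not part of this commit):

import dlt

@dlt.resource(primary_key="id", write_disposition="merge")
def users():
    # Rows whose primary key already exists in the destination are replaced on each run
    yield [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}]

pipeline = dlt.pipeline(
    pipeline_name="merge_demo",
    destination="sqlalchemy",  # e.g. backed by a SQLite or MySQL connection string
    dataset_name="demo_data",
)
pipeline.run(users())

On repeat runs, the follow-up job introduced below deletes matching rows from the destination tables and re-inserts them from the staging dataset.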
1 change: 1 addition & 0 deletions dlt/destinations/impl/sqlalchemy/factory.py
@@ -49,6 +49,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext:
# Multiple concatenated statements are not supported by all engines, so leave them off by default
caps.supports_multiple_statements = False
caps.type_mapper = SqlalchemyTypeMapper
+caps.supported_merge_strategies = ["delete-insert"]

return caps

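A "delete-insert" merge removes destination rows that have a key match in the staging dataset and then inserts the staged rows. An illustrative shape only, with made-up table and column names (the concrete statements are produced by the follow-up job in merge_job.py below):

delete_insert_sketch = [
    # 1. delete destination rows that have a staged replacement (matched on primary/merge key)
    "DELETE FROM demo_data.users WHERE EXISTS ("
    " SELECT 1 FROM demo_data_staging.users s WHERE s.id = demo_data.users.id)",
    # 2. copy the staged rows into the destination table
    "INSERT INTO demo_data.users SELECT * FROM demo_data_staging.users",
]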
6 changes: 5 additions & 1 deletion dlt/destinations/impl/sqlalchemy/load_jobs.py
@@ -3,16 +3,20 @@

import sqlalchemy as sa

+from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.destination.reference import (
RunnableLoadJob,
HasFollowupJobs,
PreparedTableSchema,
)
from dlt.common.storages import FileStorage
from dlt.common.json import json, PY_DATETIME_DECODERS
-from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams
+from dlt.destinations.sql_jobs import SqlFollowupJob, SqlJobParams, SqlMergeFollowupJob

from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
+from dlt.destinations.impl.sqlalchemy.merge_job import (
+SqlalchemyMergeFollowupJob as SqlalchemyMergeFollowupJob,
+)

if TYPE_CHECKING:
from dlt.destinations.impl.sqlalchemy.sqlalchemy_job_client import SqlalchemyJobClient
296 changes: 296 additions & 0 deletions dlt/destinations/impl/sqlalchemy/merge_job.py
@@ -0,0 +1,296 @@
from typing import Sequence, Tuple, Optional, List

import sqlalchemy as sa

from dlt.destinations.sql_jobs import SqlMergeFollowupJob, SqlJobParams
from dlt.common.destination.reference import PreparedTableSchema
from dlt.destinations.impl.sqlalchemy.db_api_client import SqlalchemyClient
from dlt.common.schema.utils import (
get_columns_names_with_prop,
get_dedup_sort_tuple,
get_first_column_name_with_prop,
is_nested_table,
)


class SqlalchemyMergeFollowupJob(SqlMergeFollowupJob):
"""Uses SQLAlchemy to generate merge SQL statements.
The result is equivalent to the SQL generated by `SqlMergeFollowupJob`,
except that concrete tables are used instead of temporary tables.
"""
@classmethod
def generate_sql(
cls,
table_chain: Sequence[PreparedTableSchema],
sql_client: SqlalchemyClient, # type: ignore[override]
params: Optional[SqlJobParams] = None,
) -> List[str]:
root_table = table_chain[0]

root_table_obj = sql_client.get_existing_table(root_table["name"])
staging_root_table_obj = root_table_obj.to_metadata(
sql_client.metadata, schema=sql_client.staging_dataset_name
)

primary_key_names = get_columns_names_with_prop(root_table, "primary_key")
merge_key_names = get_columns_names_with_prop(root_table, "merge_key")

temp_metadata = sa.MetaData()

append_fallback = (len(primary_key_names) + len(merge_key_names)) == 0

if not append_fallback:
key_clause = cls._generate_key_table_clauses(
primary_key_names, merge_key_names, root_table_obj, staging_root_table_obj
)

sqla_statements = []

tables_to_drop: List[sa.Table] = [] # Keep track of temp tables to drop at the end of the job

# Generate the delete statements
if len(table_chain) == 1 and not cls.requires_temp_table_for_delete():
delete_statement = root_table_obj.delete().where(
sa.exists(
sa.select([sa.literal(1)]).where(key_clause).select_from(staging_root_table_obj)
)
)
sqla_statements.append(delete_statement)
else:
row_key_col_name = cls._get_row_key_col(table_chain, sql_client, root_table)
row_key_col = root_table_obj.c[row_key_col_name]
# Use a real table because SQLAlchemy doesn't have a TEMPORARY TABLE abstraction
delete_temp_table = sa.Table(
"delete_" + root_table_obj.name,
temp_metadata,
# Give this column a fixed name to be able to reference it later
sa.Column("_dlt_id", row_key_col.type),
schema=staging_root_table_obj.schema,
)
tables_to_drop.append(delete_temp_table)
# Add the CREATE TABLE statement
sqla_statements.append(sa.sql.ddl.CreateTable(delete_temp_table))
# Insert data into the "temporary" table
insert_statement = delete_temp_table.insert().from_select(
[row_key_col],
sa.select([row_key_col]).where(
sa.exists(
sa.select([sa.literal(1)])
.where(key_clause)
.select_from(staging_root_table_obj)
)
),
)
sqla_statements.append(insert_statement)
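# At this point delete_temp_table holds the _dlt_id of every root-table row that has a
# key match in staging; the child-table and root-table deletes below reference this set.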

for table in table_chain[1:]:
chain_table_obj = sql_client.get_existing_table(table["name"])
root_key_name = cls._get_root_key_col(table_chain, sql_client, table)
root_key_col = chain_table_obj.c[root_key_name]

delete_statement = chain_table_obj.delete().where(
root_key_col.in_(sa.select(delete_temp_table.c._dlt_id))
)

sqla_statements.append(delete_statement)

# Delete from root table
delete_statement = root_table_obj.delete().where(
row_key_col.in_(sa.select(delete_temp_table.c._dlt_id))
)
sqla_statements.append(delete_statement)
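# Child tables were handled first above; the root table itself is cleared last,
# driven by the same captured key set.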

hard_delete_col_name, not_delete_cond = cls._get_hard_delete_col_and_cond(
root_table,
root_table_obj,
invert=True,
)

dedup_sort = get_dedup_sort_tuple(root_table) # column_name, 'asc' | 'desc'

if len(table_chain) > 1 and (primary_key_names or hard_delete_col_name is not None):
condition_column_names = (
None if hard_delete_col_name is None else [hard_delete_col_name]
)
condition_columns = (
[staging_root_table_obj.c[col_name] for col_name in condition_column_names]
if condition_column_names is not None
else []
)

staging_row_key_col = staging_root_table_obj.c[row_key_col_name]
# Create the insert "temporary" table (but use a concrete table)

insert_temp_table = sa.Table(
"insert_" + root_table_obj.name,
temp_metadata,
sa.Column(row_key_col_name, staging_row_key_col.type),
schema=staging_root_table_obj.schema,
)
tables_to_drop.append(insert_temp_table)
create_insert_temp_table_statement = sa.sql.ddl.CreateTable(insert_temp_table)
sqla_statements.append(create_insert_temp_table_statement)
staging_primary_key_cols = [
staging_root_table_obj.c[col_name] for col_name in primary_key_names
]

if primary_key_names:
if dedup_sort is not None:
order_by_col = staging_root_table_obj.c[dedup_sort[0]]
order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc
else:
order_by_col = sa.select(sa.literal(None))
order_dir_func = sa.asc
inner_cols = (
condition_columns if condition_columns is not None else [staging_row_key_col]
)

inner_select = sa.select(
sa.func.row_number()
.over(
partition_by=set(staging_primary_key_cols),
order_by=order_dir_func(order_by_col),
)
.label("_dlt_dedup_rn"),
*inner_cols
).subquery()

select_for_temp_insert = (
sa.select(staging_row_key_col)
.select_from(inner_select)
.where(inner_select.c._dlt_dedup_rn == 1)
)
if not_delete_cond is not None:
select_for_temp_insert = select_for_temp_insert.where(not_delete_cond)
else:
select_for_temp_insert = sa.select(staging_row_key_col).where(not_delete_cond)

insert_into_temp_table = insert_temp_table.insert().from_select(
[row_key_col_name], select_for_temp_insert
)
sqla_statements.append(insert_into_temp_table)

# Insert from staging to dataset
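# Each table in the chain gets an INSERT ... SELECT from its staging counterpart,
# restricted by the hard-delete condition and, where applicable, by the deduplicated
# row keys collected in insert_temp_table.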
for table in table_chain:
table_obj = sql_client.get_existing_table(table["name"])
staging_table_obj = table_obj.to_metadata(
sql_client.metadata, schema=sql_client.staging_dataset_name
)

insert_cond = not_delete_cond if hard_delete_col_name is not None else sa.true()

if (primary_key_names and len(table_chain) > 1) or (
not primary_key_names
and is_nested_table(table)
and hard_delete_col_name is not None
):
uniq_column_name = root_key_name if is_nested_table(table) else row_key_col_name
uniq_column = staging_table_obj.c[uniq_column_name]
insert_cond = uniq_column.in_(
sa.select(
insert_temp_table.c[row_key_col_name].label(uniq_column_name)
).subquery()
)

select_sql = staging_table_obj.select().where(insert_cond)

if primary_key_names and len(table_chain) == 1:
staging_primary_key_cols = [
staging_table_obj.c[col_name] for col_name in primary_key_names
]
if dedup_sort is not None:
order_by_col = staging_table_obj.c[dedup_sort[0]]
order_dir_func = sa.asc if dedup_sort[1] == "asc" else sa.desc
else:
order_by_col = sa.select(sa.literal(None))
order_dir_func = sa.asc

inner_select = sa.select(
staging_table_obj,
sa.func.row_number()
.over(
partition_by=set(staging_primary_key_cols),
order_by=order_dir_func(order_by_col),
)
.label("_dlt_dedup_rn"),
).subquery()

select_sql = (
sa.select(staging_table_obj)
.select_from(inner_select)
.where(inner_select.c._dlt_dedup_rn == 1)
)
if insert_cond is not None:
select_sql = select_sql.where(insert_cond)

insert_statement = table_obj.insert().from_select(
[col.name for col in table_obj.columns], select_sql
)
sqla_statements.append(insert_statement)

# Drop all "temp" tables at the end
for table_obj in tables_to_drop:
sqla_statements.append(sa.sql.ddl.DropTable(table_obj))

return [
x + ";" if not x.endswith(";") else x
for x in (
str(stmt.compile(sql_client.engine, compile_kwargs={"literal_binds": True}))
for stmt in sqla_statements
)
]
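# Note: literal_binds inlines parameter values so every statement can be shipped as plain
# SQL text. A standalone illustration of the compile step (hypothetical table and engine):
#   stmt = users_table.delete().where(users_table.c.id == 1)
#   str(stmt.compile(engine, compile_kwargs={"literal_binds": True}))
#   -> 'DELETE FROM users WHERE users.id = 1'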

@classmethod
def _get_hard_delete_col_and_cond( # type: ignore[override]
cls,
table: PreparedTableSchema,
table_obj: sa.Table,
invert: bool = False,
) -> Tuple[Optional[str], Optional[sa.sql.elements.BinaryExpression]]:
col_name = get_first_column_name_with_prop(table, "hard_delete")
if col_name is None:
return None, None
col = table_obj.c[col_name]
if invert:
cond = col.is_(None)
else:
cond = col.isnot(None)
if table["columns"][col_name]["data_type"] == "bool":
if invert:
cond = sa.or_(cond, col.is_(False))
else:
cond = col.is_(True)
return col_name, cond
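# Illustrative semantics: for a boolean hard-delete column "deleted", invert=True yields
# "deleted IS NULL OR deleted IS false" (rows to keep), while invert=False yields
# "deleted IS true" (rows to delete).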

@classmethod
def _generate_key_table_clauses(
cls,
primary_keys: Sequence[str],
merge_keys: Sequence[str],
root_table_obj: sa.Table,
staging_root_table_obj: sa.Table,
) -> sa.sql.ClauseElement:
# Returns a SQLAlchemy or_ clause
clauses = []
if primary_keys or merge_keys:
for key in primary_keys:
clauses.append(
sa.and_(
*[
root_table_obj.c[key] == staging_root_table_obj.c[key]
for key in primary_keys
]
)
)
for key in merge_keys:
clauses.append(
sa.and_(
*[
root_table_obj.c[key] == staging_root_table_obj.c[key]
for key in merge_keys
]
)
)
return sa.or_(*clauses) # type: ignore[no-any-return]
else:
return sa.true() # type: ignore[no-any-return]
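# Illustrative output: with primary key ["id"] and merge key ["batch_id"], the clause
# compiles to roughly (root.id = staging.id) OR (root.batch_id = staging.batch_id);
# with no keys configured the method falls back to sa.true().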
8 changes: 3 additions & 5 deletions dlt/destinations/impl/sqlalchemy/sqlalchemy_job_client.py
@@ -26,6 +26,7 @@
SqlalchemyJsonLInsertJob,
SqlalchemyParquetInsertJob,
SqlalchemyStagingCopyJob,
+SqlalchemyMergeFollowupJob,
)


@@ -97,13 +98,10 @@ def _create_merge_followup_jobs(
def _create_merge_followup_jobs(
self, table_chain: Sequence[PreparedTableSchema]
) -> List[FollowupJobRequest]:
# Ensure all tables exist in metadata before generating sql job
for table in table_chain:
self._to_table_object(table)
-return [
-SqlalchemyStagingCopyJob.from_table_chain(
-table_chain, self.sql_client, {"replace": False}
-)
-]
+return [SqlalchemyMergeFollowupJob.from_table_chain(table_chain, self.sql_client)]

def create_load_job(
self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False
4 changes: 2 additions & 2 deletions tests/load/utils.py
@@ -326,13 +326,13 @@ def destinations_configs(
destination_configs += [
DestinationTestConfiguration(
destination_type="sqlalchemy",
-supports_merge=False,
+supports_merge=True,
supports_dbt=False,
destination_name="sqlalchemy_mysql",
),
DestinationTestConfiguration(
destination_type="sqlalchemy",
-supports_merge=False,
+supports_merge=True,
supports_dbt=False,
destination_name="sqlalchemy_sqlite",
),
