add support for delta tables
sh-rp committed Oct 1, 2024
1 parent 0ec1656 commit 8ef63b7
Showing 2 changed files with 93 additions and 41 deletions.
79 changes: 42 additions & 37 deletions dlt/destinations/impl/filesystem/sql_client.py
@@ -2,6 +2,8 @@

import os

import dlt

import duckdb

import sqlglot
@@ -172,60 +174,63 @@ def create_views_for_tables(self, tables: Dict[str, str]) -> None:
if table_name not in self.fs_client.schema.tables:
# unknown tables will not be created
continue
self._existing_views.append(view_name)

# discover file type
schema_table = cast(PreparedTableSchema, self.fs_client.schema.tables[table_name])
self._existing_views.append(view_name)
folder = self.fs_client.get_table_dir(table_name)
files = self.fs_client.list_table_files(table_name)
first_file_type = os.path.splitext(files[0])[1][1:]

# build files string
supports_wildcard_notation = self.fs_client.config.protocol != "abfss"
protocol = (
"" if self.fs_client.is_local_filesystem else f"{self.fs_client.config.protocol}://"
)
resolved_folder = f"{protocol}{folder}"
resolved_files_string = f"'{resolved_folder}/**/*.{first_file_type}'"
if not supports_wildcard_notation:
resolved_files_string = ",".join(map(lambda f: f"'{protocol}{f}'", files))

# discover tables files
file_type = os.path.splitext(files[0])[1][1:]
columns_string = ""
if file_type == "jsonl":
read_command = "read_json"
# for json we need to provide types
type_mapper = self.capabilities.get_type_mapper()
schema_table = cast(PreparedTableSchema, self.fs_client.schema.tables[table_name])
columns = map(
# build columns definition
type_mapper = self.capabilities.get_type_mapper()
columns = ",".join(
map(
lambda c: (
f'{self.escape_column_name(c["name"])}:'
f' "{type_mapper.to_destination_type(c, schema_table)}"'
),
self.fs_client.schema.tables[table_name]["columns"].values(),
)
columns_string = ",columns = {" + ",".join(columns) + "}"
)

elif file_type == "parquet":
read_command = "read_parquet"
# discover whether compression is enabled
compression = (
""
if dlt.config.get("data_writer.disable_compression")
else ", compression = 'gzip'"
)

# create from statement
from_statement = ""
if schema_table.get("table_format") == "delta":
from_statement = f"delta_scan('{resolved_folder}')"
elif first_file_type == "parquet":
from_statement = f"read_parquet([{resolved_files_string}])"
elif first_file_type == "jsonl":
from_statement = (
f"read_json([{resolved_files_string}], columns = {{{columns}}}{compression})"
)
else:
raise NotImplementedError(
f"Unknown filetype {file_type} for table {table_name}. Currently only jsonl and"
" parquet files are supported."
f"Unknown filetype {first_file_type} for table {table_name}. Currently only"
" jsonl and parquet files as well as delta tables are supported."
)

# build files string
protocol = (
"" if self.fs_client.is_local_filesystem else f"{self.fs_client.config.protocol}://"
)
supports_wildcard_notation = self.fs_client.config.protocol != "abfss"
files_string = f"'{protocol}{folder}/**/*.{file_type}'"
if not supports_wildcard_notation:
files_string = ",".join(map(lambda f: f"'{protocol}{f}'", files))

# create table
view_name = self.make_qualified_table_name(view_name)
create_table_sql_base = (
f"CREATE VIEW {view_name} AS SELECT * FROM"
f" {read_command}([{files_string}] {columns_string})"
)
create_table_sql_gzipped = (
f"CREATE VIEW {view_name} AS SELECT * FROM"
f" {read_command}([{files_string}] {columns_string} , compression = 'gzip')"
)
try:
self._conn.execute(create_table_sql_base)
except (duckdb.InvalidInputException, duckdb.IOException):
# try to load non gzipped files
self._conn.execute(create_table_sql_gzipped)
create_table_sql_base = f"CREATE VIEW {view_name} AS SELECT * FROM {from_statement}"
self._conn.execute(create_table_sql_base)

@contextmanager
@raise_database_error
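Taken together, the rewrite above replaces the old try-plain-then-retry-gzipped view creation with a single CREATE VIEW whose FROM clause depends on the table format. A minimal sketch of what that resolves to, with a hypothetical items table and placeholder bucket path and column type (none of these literals come from the commit):

# A minimal sketch (not part of the commit) of the view DDL the new branches emit.
# The table name, bucket path, and column type below are hypothetical placeholders.
table_format = "delta"          # what schema_table.get("table_format") would return
first_file_type = "parquet"     # extension of the first data file in the table folder
resolved_folder = "s3://my-bucket/dataset/items"
resolved_files_string = f"'{resolved_folder}/**/*.{first_file_type}'"

if table_format == "delta":
    # the whole table folder is handed to duckdb's delta extension
    from_statement = f"delta_scan('{resolved_folder}')"
elif first_file_type == "parquet":
    from_statement = f"read_parquet([{resolved_files_string}])"
elif first_file_type == "jsonl":
    # jsonl needs explicit column types and, when enabled, a gzip compression hint
    from_statement = f"read_json([{resolved_files_string}], columns = {{id: 'BIGINT'}})"

print(f"CREATE VIEW items AS SELECT * FROM {from_statement}")
# prints: CREATE VIEW items AS SELECT * FROM delta_scan('s3://my-bucket/dataset/items')
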
55 changes: 51 additions & 4 deletions tests/load/test_read_interfaces.py
@@ -16,12 +16,17 @@
DestinationTestConfiguration,
GCS_BUCKET,
SFTP_BUCKET,
FILE_BUCKET,
MEMORY_BUCKET,
)
from dlt.destinations import filesystem


def _run_dataset_checks(
pipeline: Pipeline, destination_config: DestinationTestConfiguration
pipeline: Pipeline,
destination_config: DestinationTestConfiguration,
table_format: str = None,
alternate_access_pipeline: Pipeline = None,
) -> None:
destination_type = pipeline.destination_client().config.destination_type

@@ -47,12 +52,13 @@ def _run_dataset_checks(
@dlt.source()
def source():
@dlt.resource(
table_format=table_format,
columns={
"id": {"data_type": "bigint"},
# we add a decimal with precision to see whether the hints are preserved
"decimal": {"data_type": "decimal", "precision": 10, "scale": 3},
"other_decimal": {"data_type": "decimal", "precision": 12, "scale": 3},
}
},
)
def items():
yield from [
@@ -66,10 +72,11 @@ def items():
]

@dlt.resource(
table_format=table_format,
columns={
"id": {"data_type": "bigint"},
"double_id": {"data_type": "bigint"},
}
},
)
def double_items():
yield from [
@@ -86,6 +93,9 @@ def double_items():
s = source()
pipeline.run(s, loader_file_format=destination_config.file_format)

if alternate_access_pipeline:
pipeline = alternate_access_pipeline

# access via key
table_relationship = pipeline._dataset()["items"]

@@ -314,7 +324,44 @@ def test_read_interfaces_filesystem(destination_config: DestinationTestConfigura
"read_pipeline", dataset_name="read_test", dev_mode=True, destination=gcp_bucket
)
_run_dataset_checks(pipeline, destination_config)
assert pipeline.destination_client().config.credentials


@pytest.mark.essential
@pytest.mark.parametrize(
"destination_config",
destinations_configs(
table_format_filesystem_configs=True,
with_table_format="delta",
bucket_exclude=[SFTP_BUCKET, MEMORY_BUCKET],
),
ids=lambda x: x.name,
)
def test_delta_tables(destination_config: DestinationTestConfiguration) -> None:
os.environ["DATA_WRITER__FILE_MAX_ITEMS"] = "700"

pipeline = destination_config.setup_pipeline(
"read_pipeline",
dataset_name="read_test",
)

# in case of gcs we use the s3 compat layer for reading;
# for writing we still need to use google cloud authentication, as delta_rs seems to use
# methods on the s3 interface that are not implemented by gcs
access_pipeline = pipeline
if destination_config.bucket_url == GCS_BUCKET:
gcp_bucket = filesystem(
GCS_BUCKET.replace("gs://", "s3://"), destination_name="filesystem_s3_gcs_comp"
)
access_pipeline = destination_config.setup_pipeline(
"read_pipeline", dataset_name="read_test", destination=gcp_bucket
)

_run_dataset_checks(
pipeline,
destination_config,
table_format="delta",
alternate_access_pipeline=access_pipeline,
)
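
For orientation, _run_dataset_checks then reads these tables back through the dataset interface used earlier in this file (pipeline._dataset()["items"]). A hedged sketch of that access path, continuing from the pipelines set up in the test above; the .df() accessor is an assumption and is not shown in this hunk:

# Hedged sketch: read the delta-backed "items" view through the dataset interface.
# `access_pipeline` comes from the test above; .df() is assumed to return a pandas
# DataFrame materialized from the duckdb view (delta_scan(...) for delta tables).
dataset = access_pipeline._dataset()   # dataset bound to the filesystem sql_client
items_relation = dataset["items"]      # lazy relation over the "items" view
frame = items_relation.df()            # runs SELECT * FROM the delta_scan-backed view
assert len(frame) > 0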


@pytest.mark.essential
