From 73ed51512c71af464177f6714c57c1d1b93d88b8 Mon Sep 17 00:00:00 2001 From: Julian Alves <28436330+donotpush@users.noreply.github.com> Date: Fri, 30 Aug 2024 11:40:57 +0200 Subject: [PATCH] Feat/1492 extend timestamp config (#1669) * feat: add timezone flag to configure timestamp data * fix: delete timezone init * test: add duckdb timestamps with timezone * test: fix resource hints for timestamp * test: correct duckdb timestamps * test: timezone tests for parquet files * exp: add notebook with timestamp exploration * test: refactor timestamp tests * test: simplified tests and extended experiments * exp: timestamp exp for duckdb and parquet * fix: add pyarrow reflection for timezone flag * fix lint errors * fix: CI/CD move tests pyarrow module * fix: pyarrow timezone defaults true * refactor: typemapper signatures * fix: duckdb timestamp config * docs: updated duckdb.md timestamps * fix: revert duckdb timestamp defaults * fix: restore duckdb timestamp default * fix: duckdb timestamp mapper * fix: delete notebook * docs: added timestamp and timezone section * refactor: duckdb precision exception message * feat: postgres timestamp timezone config * fix: postgres timestamp precision * fix: postgres timezone false case * feat: add snowflake timezone and precision flag * test: postgres invalid timestamp precision * test: unified timestamp invalid precision * test: unified column flag timezone * chore: add warn log for unsupported timezone or precision flag * docs: timezone and precision flags for timestamps * fix: none case error * docs: add duckdb default precision * fix: typing errors * rebase: formatted files from upstream devel * fix: warning message and reference TODO * test: delete duplicated input_data array * docs: moved timestamp config to data types section * fix: lint and format * fix: lint local errors --- dlt/common/libs/pyarrow.py | 9 +- dlt/common/schema/typing.py | 1 + dlt/destinations/impl/athena/athena.py | 14 +- dlt/destinations/impl/bigquery/bigquery.py | 8 +- .../impl/clickhouse/clickhouse.py | 6 +- .../impl/databricks/databricks.py | 15 +- dlt/destinations/impl/dremio/dremio.py | 12 +- dlt/destinations/impl/duckdb/duck.py | 43 +++-- .../impl/lancedb/lancedb_client.py | 24 +-- dlt/destinations/impl/mssql/mssql.py | 15 +- dlt/destinations/impl/postgres/postgres.py | 42 ++++- dlt/destinations/impl/redshift/redshift.py | 9 +- dlt/destinations/impl/snowflake/snowflake.py | 41 ++++- dlt/destinations/job_client_impl.py | 11 +- dlt/destinations/type_mapping.py | 56 +++++-- .../docs/dlt-ecosystem/destinations/duckdb.md | 39 ++++- .../dlt-ecosystem/destinations/postgres.md | 21 +++ .../dlt-ecosystem/destinations/snowflake.md | 21 +++ tests/load/pipeline/test_pipelines.py | 148 ++++++++++++++++++ tests/pipeline/test_pipeline.py | 17 +- tests/pipeline/test_pipeline_extra.py | 80 ++++++++++ 21 files changed, 536 insertions(+), 96 deletions(-) diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 9d3e97421c..e9dcfaf095 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -54,7 +54,10 @@ def get_py_arrow_datatype( elif column_type == "bool": return pyarrow.bool_() elif column_type == "timestamp": - return get_py_arrow_timestamp(column.get("precision") or caps.timestamp_precision, tz) + # sets timezone to None when timezone hint is false + timezone = tz if column.get("timezone", True) else None + precision = column.get("precision") or caps.timestamp_precision + return get_py_arrow_timestamp(precision, timezone) elif column_type == "bigint": return get_pyarrow_int(column.get("precision")) elif column_type == "binary": @@ -139,6 +142,10 @@ def get_column_type_from_py_arrow(dtype: pyarrow.DataType) -> TColumnType: precision = 6 else: precision = 9 + + if dtype.tz is None: + return dict(data_type="timestamp", precision=precision, timezone=False) + return dict(data_type="timestamp", precision=precision) elif pyarrow.types.is_date(dtype): return dict(data_type="date") diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index 284c55caac..a81e9046a9 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -94,6 +94,7 @@ class TColumnType(TypedDict, total=False): data_type: Optional[TDataType] precision: Optional[int] scale: Optional[int] + timezone: Optional[bool] class TColumnSchemaBase(TColumnType, total=False): diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index a5a8ae2562..c4a9bab212 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -104,9 +104,9 @@ class AthenaTypeMapper(TypeMapper): def __init__(self, capabilities: DestinationCapabilitiesContext): super().__init__(capabilities) - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") + table_format = table.get("table_format") if precision is None: return "bigint" if precision <= 8: @@ -403,9 +403,9 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(hive_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: return ( - f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table_format)}" + f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table)}" ) def _iceberg_partition_clause(self, partition_hints: Optional[Dict[str, str]]) -> str: @@ -429,9 +429,9 @@ def _get_table_update_sql( # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries # or if we are in iceberg mode, we create iceberg tables for all tables table = self.prepare_load_table(table_name, self.in_staging_mode) - table_format = table.get("table_format") + is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" - columns = ", ".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + columns = ", ".join([self._get_column_def_sql(c, table) for c in new_columns]) # create unique tag for iceberg table so it is never recreated in the same folder # athena requires some kind of special cleaning (or that is a bug) so we cannot refresh diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 1dd4c727be..9bc555bd0d 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -90,9 +90,9 @@ class BigQueryTypeMapper(TypeMapper): "TIME": "time", } - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> str: + def to_db_decimal_type(self, column: TColumnSchema) -> str: # Use BigQuery's BIGNUMERIC for large precision decimals - precision, scale = self.decimal_precision(precision, scale) + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) if precision > 38 or scale > 9: return "BIGNUMERIC(%i,%i)" % (precision, scale) return "NUMERIC(%i,%i)" % (precision, scale) @@ -417,10 +417,10 @@ def _get_info_schema_columns_query( return query, folded_table_names - def _get_column_def_sql(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, column: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(column["name"]) column_def_sql = ( - f"{name} {self.type_mapper.to_db_type(column, table_format)} {self._gen_not_null(column.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(column, table)} {self._gen_not_null(column.get('nullable', True))}" ) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 282fbaf338..038735a84b 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -293,7 +293,7 @@ def _create_merge_followup_jobs( ) -> List[FollowupJobRequest]: return [ClickHouseMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: # Build column definition. # The primary key and sort order definition is defined outside column specification. hints_ = " ".join( @@ -307,9 +307,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non # Alter table statements only accept `Nullable` modifiers. # JSON type isn't nullable in ClickHouse. type_with_nullability_modifier = ( - f"Nullable({self.type_mapper.to_db_type(c)})" + f"Nullable({self.type_mapper.to_db_type(c,table)})" if c.get("nullable", True) - else self.type_mapper.to_db_type(c) + else self.type_mapper.to_db_type(c, table) ) return ( diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 614e6e97c5..0c19984b4c 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -68,9 +68,8 @@ class DatabricksTypeMapper(TypeMapper): "wei": "DECIMAL(%i,%i)", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "BIGINT" if precision <= 8: @@ -323,10 +322,12 @@ def _create_merge_followup_jobs( return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause - return ["ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c) for c in new_columns)] + return [ + "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) + ] def _get_table_update_sql( self, @@ -351,10 +352,10 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def _get_storage_table_query_columns(self) -> List[str]: diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index 149d106dcd..91dc64f113 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -195,10 +195,10 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def _create_merge_followup_jobs( @@ -207,9 +207,13 @@ def _create_merge_followup_jobs( return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: - return ["ADD COLUMNS (" + ", ".join(self._get_column_def_sql(c) for c in new_columns) + ")"] + return [ + "ADD COLUMNS (" + + ", ".join(self._get_column_def_sql(c, table) for c in new_columns) + + ")" + ] def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 3d5905ff40..d5065f5bdd 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -62,9 +62,8 @@ class DuckDbTypeMapper(TypeMapper): "TIMESTAMP_NS": "timestamp", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "BIGINT" # Precision is number of bits @@ -83,19 +82,39 @@ def to_db_integer_type( ) def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone and precision is not None: + raise TerminalValueError( + f"DuckDB does not support both timezone and precision for column '{column_name}' in" + f" table '{table_name}'. To resolve this issue, either set timezone to False or" + " None, or use the default precision." + ) + + if timezone: + return "TIMESTAMP WITH TIME ZONE" + elif timezone is not None: # condition for when timezone is False given that none is falsy + return "TIMESTAMP" + if precision is None or precision == 6: - return super().to_db_datetime_type(precision, table_format) - if precision == 0: + return None + elif precision == 0: return "TIMESTAMP_S" - if precision == 3: + elif precision == 3: return "TIMESTAMP_MS" - if precision == 9: + elif precision == 9: return "TIMESTAMP_NS" + raise TerminalValueError( - f"timestamp with {precision} decimals after seconds cannot be mapped into duckdb" - " TIMESTAMP type" + f"DuckDB does not support precision '{precision}' for '{column_name}' in table" + f" '{table_name}'" ) def from_db_type( @@ -162,7 +181,7 @@ def create_load_job( job = DuckDbCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -170,7 +189,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _from_db_type( diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 78a37952b9..02240b8f93 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -41,7 +41,7 @@ LoadJob, ) from dlt.common.pendulum import timedelta -from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TColumnSchema from dlt.common.schema.typing import ( TColumnType, TTableFormat, @@ -105,21 +105,27 @@ class LanceDBTypeMapper(TypeMapper): pa.date32(): "date", } - def to_db_decimal_type( - self, precision: Optional[int], scale: Optional[int] - ) -> pa.Decimal128Type: - precision, scale = self.decimal_precision(precision, scale) + def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) return pa.decimal128(precision, scale) def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> pa.TimestampType: + column_name = column.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + if timezone is not None or precision is not None: + logger.warning( + "LanceDB does not currently support column flags for timezone or precision." + f" These flags were used in column '{column_name}'." + ) unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.timestamp(unit, "UTC") - def to_db_time_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> pa.Time64Type: + def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> pa.Time64Type: unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] return pa.time64(unit) diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index 750dc93a10..a7e796b2d8 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -59,9 +59,8 @@ class MsSqlTypeMapper(TypeMapper): "int": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" if precision <= 8: @@ -166,20 +165,18 @@ def _create_merge_followup_jobs( return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because mssql requires multiple columns in a single ADD COLUMN clause - return [ - "ADD \n" + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) - ] + return ["ADD \n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: sc_type = c["data_type"] if sc_type == "text" and c.get("unique"): # MSSQL does not allow index on large TEXT columns db_type = "nvarchar(%i)" % (c.get("precision") or 900) else: - db_type = self.type_mapper.to_db_type(c) + db_type = self.type_mapper.to_db_type(c, table) hints_str = " ".join( self.active_hints.get(h, "") diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index a832bfe07f..5777e46c90 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -66,9 +66,8 @@ class PostgresTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" # Precision is number of bits @@ -82,6 +81,39 @@ def to_db_integer_type( f"bigint with {precision} bits precision cannot be mapped into postgres integer type" ) + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is None and precision is None: + return None + + timestamp = "timestamp" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 6: + timestamp += f" ({precision})" + else: + raise TerminalValueError( + f"Postgres does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + # append timezone part + if timezone is None or timezone: # timezone True and None + timestamp += " with time zone" + else: # timezone is explicitly False + timestamp += " without time zone" + + return timestamp + def from_db_type( self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None ) -> TColumnType: @@ -233,7 +265,7 @@ def create_load_job( job = PostgresCsvCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -241,7 +273,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _create_replace_followup_jobs( diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 0e201dc4e0..9bba60af07 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -82,9 +82,8 @@ class RedshiftTypeMapper(TypeMapper): "integer": "bigint", } - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + precision = column.get("precision") if precision is None: return "bigint" if precision <= 16: @@ -243,7 +242,7 @@ def _create_merge_followup_jobs( ) -> List[FollowupJobRequest]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: hints_str = " ".join( HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() @@ -251,7 +250,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def create_load_job( diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 6688b5bc17..247b3233d0 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -18,7 +18,7 @@ from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat - +from dlt.common.exceptions import TerminalValueError from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.typing import TLoaderFileFormat @@ -77,6 +77,36 @@ def from_db_type( return dict(data_type="decimal", precision=precision, scale=scale) return super().from_db_type(db_type, precision, scale) + def to_db_datetime_type( + self, + column: TColumnSchema, + table: TTableSchema = None, + ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is None and precision is None: + return None + + timestamp = "TIMESTAMP_TZ" + + if timezone is not None and not timezone: # explicitaly handles timezone False + timestamp = "TIMESTAMP_NTZ" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 9: + timestamp += f"({precision})" + else: + raise TerminalValueError( + f"Snowflake does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + return timestamp + class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( @@ -289,12 +319,11 @@ def create_load_job( return job def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: # Override because snowflake requires multiple columns in a single ADD COLUMN clause return [ - "ADD COLUMN\n" - + ",\n".join(self._get_column_def_sql(c, table_format) for c in new_columns) + "ADD COLUMN\n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns) ] def _get_table_update_sql( @@ -320,10 +349,10 @@ def _from_db_type( ) -> TColumnType: return self.type_mapper.from_db_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 1d6403a2c8..3026baf753 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -517,10 +517,10 @@ def _build_schema_update_sql( return sql_updates, schema_update def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None + self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None ) -> List[str]: """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)""" - return [f"ADD COLUMN {self._get_column_def_sql(c, table_format)}" for c in new_columns] + return [f"ADD COLUMN {self._get_column_def_sql(c, table)}" for c in new_columns] def _make_create_table(self, qualified_name: str, table: TTableSchema) -> str: not_exists_clause = " " @@ -537,17 +537,16 @@ def _get_table_update_sql( # build sql qualified_name = self.sql_client.make_qualified_table_name(table_name) table = self.prepare_load_table(table_name) - table_format = table.get("table_format") sql_result: List[str] = [] if not generate_alter: # build CREATE sql = self._make_create_table(qualified_name, table) + " (\n" - sql += ",\n".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + sql += ",\n".join([self._get_column_def_sql(c, table) for c in new_columns]) sql += ")" sql_result.append(sql) else: sql_base = f"ALTER TABLE {qualified_name}\n" - add_column_statements = self._make_add_column_sql(new_columns, table_format) + add_column_statements = self._make_add_column_sql(new_columns, table) if self.capabilities.alter_add_multi_column: column_sql = ",\n" sql_result.append(sql_base + column_sql.join(add_column_statements)) @@ -582,7 +581,7 @@ def _get_table_update_sql( return sql_result @abstractmethod - def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: pass @staticmethod diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py index dcd938b33c..5ac43e4f1f 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -1,6 +1,13 @@ from typing import Tuple, ClassVar, Dict, Optional -from dlt.common.schema.typing import TColumnSchema, TDataType, TColumnType, TTableFormat +from dlt.common import logger +from dlt.common.schema.typing import ( + TColumnSchema, + TDataType, + TColumnType, + TTableFormat, + TTableSchema, +) from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.utils import without_none @@ -20,39 +27,54 @@ class TypeMapper: def __init__(self, capabilities: DestinationCapabilitiesContext) -> None: self.capabilities = capabilities - def to_db_integer_type( - self, precision: Optional[int], table_format: TTableFormat = None - ) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: # Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.) return self.sct_to_unbound_dbt["bigint"] def to_db_datetime_type( - self, precision: Optional[int], table_format: TTableFormat = None + self, + column: TColumnSchema, + table: TTableSchema = None, ) -> str: # Override in subclass if db supports other timestamp types (e.g. with different time resolutions) + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is not None or precision is not None: + message = ( + "Column flags for timezone or precision are not yet supported in this" + " destination. One or both of these flags were used in column" + f" '{column.get('name')}'." + ) + # TODO: refactor lancedb and wevavite to make table object required + if table: + message += f" in table '{table.get('name')}'." + + logger.warning(message) + return None - def to_db_time_type(self, precision: Optional[int], table_format: TTableFormat = None) -> str: + def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: # Override in subclass if db supports other time types (e.g. with different time resolutions) return None - def to_db_decimal_type(self, precision: Optional[int], scale: Optional[int]) -> str: - precision_tup = self.decimal_precision(precision, scale) + def to_db_decimal_type(self, column: TColumnSchema) -> str: + precision_tup = self.decimal_precision(column.get("precision"), column.get("scale")) if not precision_tup or "decimal" not in self.sct_to_dbt: return self.sct_to_unbound_dbt["decimal"] return self.sct_to_dbt["decimal"] % (precision_tup[0], precision_tup[1]) - def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) -> str: - precision, scale = column.get("precision"), column.get("scale") + # TODO: refactor lancedb and wevavite to make table object required + def to_db_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: sc_t = column["data_type"] if sc_t == "bigint": - db_t = self.to_db_integer_type(precision, table_format) + db_t = self.to_db_integer_type(column, table) elif sc_t == "timestamp": - db_t = self.to_db_datetime_type(precision, table_format) + db_t = self.to_db_datetime_type(column, table) elif sc_t == "time": - db_t = self.to_db_time_type(precision, table_format) + db_t = self.to_db_time_type(column, table) elif sc_t == "decimal": - db_t = self.to_db_decimal_type(precision, scale) + db_t = self.to_db_decimal_type(column) else: db_t = None if db_t: @@ -61,14 +83,16 @@ def to_db_type(self, column: TColumnSchema, table_format: TTableFormat = None) - bounded_template = self.sct_to_dbt.get(sc_t) if not bounded_template: return self.sct_to_unbound_dbt[sc_t] - precision_tuple = self.precision_tuple_or_default(sc_t, precision, scale) + precision_tuple = self.precision_tuple_or_default(sc_t, column) if not precision_tuple: return self.sct_to_unbound_dbt[sc_t] return self.sct_to_dbt[sc_t] % precision_tuple def precision_tuple_or_default( - self, data_type: TDataType, precision: Optional[int], scale: Optional[int] + self, data_type: TDataType, column: TColumnSchema ) -> Optional[Tuple[int, ...]]: + precision = column.get("precision") + scale = column.get("scale") if data_type in ("timestamp", "time"): if precision is None: return None # Use default which is usually the max diff --git a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md index 19cef92f9d..4b8ecec4ca 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/duckdb.md +++ b/docs/website/docs/dlt-ecosystem/destinations/duckdb.md @@ -35,6 +35,42 @@ All write dispositions are supported. ## Data loading `dlt` will load data using large INSERT VALUES statements by default. Loading is multithreaded (20 threads by default). If you are okay with installing `pyarrow`, we suggest switching to `parquet` as the file format. Loading is faster (and also multithreaded). +### Data types +`duckdb` supports various [timestamp types](https://duckdb.org/docs/sql/data_types/timestamp.html). These can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: supported precision values are 0, 3, 6, and 9 for fractional seconds. Note that `timezone` and `precision` cannot be used together; attempting to combine them will result in an error. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP WITH TIME ZONE` (`TIMESTAMPTZ`). + +#### Example precision: TIMESTAMP_MS + +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="duckdb") +pipeline.run(events()) +``` + +#### Example timezone: TIMESTAMP + +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + +pipeline = dlt.pipeline(destination="duckdb") +pipeline.run(events()) +``` + ### Names normalization `dlt` uses the standard **snake_case** naming convention to keep identical table and column identifiers across all destinations. If you want to use the **duckdb** wide range of characters (i.e., emojis) for table and column names, you can switch to the **duck_case** naming convention, which accepts almost any string as an identifier: * `\n` `\r` and `"` are translated to `_` @@ -77,7 +113,8 @@ to disable tz adjustments. ::: ## Supported column hints -`duckdb` may create unique indexes for all columns with `unique` hints, but this behavior **is disabled by default** because it slows the loading down significantly. + +`duckdb` can create unique indexes for columns with `unique` hints. However, **this feature is disabled by default** as it can significantly slow down data loading. ## Destination Configuration diff --git a/docs/website/docs/dlt-ecosystem/destinations/postgres.md b/docs/website/docs/dlt-ecosystem/destinations/postgres.md index 1281298312..e506eb79fe 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/postgres.md +++ b/docs/website/docs/dlt-ecosystem/destinations/postgres.md @@ -82,6 +82,27 @@ If you set the [`replace` strategy](../../general-usage/full-loading.md) to `sta ## Data loading `dlt` will load data using large INSERT VALUES statements by default. Loading is multithreaded (20 threads by default). +### Data types +`postgres` supports various timestamp types, which can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: allows you to specify the number of decimal places for fractional seconds, ranging from 0 to 6. It can be used in combination with the `timezone` flag. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP WITHOUT TIME ZONE`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP WITH TIME ZONE`. + +#### Example precision and timezone: TIMESTAMP (3) WITHOUT TIME ZONE +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3, "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="postgres") +pipeline.run(events()) +``` + ### Fast loading with arrow tables and csv You can use [arrow tables](../verified-sources/arrow-pandas.md) and [csv](../file-formats/csv.md) to quickly load tabular data. Pick the `csv` loader file format like below diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 57e6db311d..f4d5a53d36 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -143,6 +143,27 @@ The data is loaded using an internal Snowflake stage. We use the `PUT` command a keep_staged_files = false ``` +### Data types +`snowflake` supports various timestamp types, which can be configured using the column flags `timezone` and `precision` in the `dlt.resource` decorator or the `pipeline.run` method. + +- **Precision**: allows you to specify the number of decimal places for fractional seconds, ranging from 0 to 9. It can be used in combination with the `timezone` flag. +- **Timezone**: + - Setting `timezone=False` maps to `TIMESTAMP_NTZ`. + - Setting `timezone=True` (or omitting the flag, which defaults to `True`) maps to `TIMESTAMP_TZ`. + +#### Example precision and timezone: TIMESTAMP_NTZ(3) +```py +@dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": 3, "timezone": False}}, + primary_key="event_id", +) +def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123"}] + +pipeline = dlt.pipeline(destination="snowflake") +pipeline.run(events()) +``` + ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default * [parquet](../file-formats/parquet.md) is supported diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 81c9292570..2792cec085 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -17,6 +17,7 @@ from dlt.common.schema.utils import new_table from dlt.common.typing import TDataItem from dlt.common.utils import uniq_id +from dlt.common.exceptions import TerminalValueError from dlt.destinations.exceptions import DatabaseUndefinedRelation from dlt.destinations import filesystem, redshift @@ -1146,3 +1147,150 @@ def _data(): dataset_name=dataset_name, ) return p, _data + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "postgres", "snowflake"]), + ids=lambda x: x.name, +) +def test_dest_column_invalid_timestamp_precision( + destination_config: DestinationTestConfiguration, +) -> None: + invalid_precision = 10 + + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "precision": invalid_precision}}, + primary_key="event_id", + ) + def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + + pipeline = destination_config.setup_pipeline(uniq_id()) + + with pytest.raises((TerminalValueError, PipelineStepFailed)): + pipeline.run(events()) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["duckdb", "snowflake", "postgres"]), + ids=lambda x: x.name, +) +def test_dest_column_hint_timezone(destination_config: DestinationTestConfiguration) -> None: + destination = destination_config.destination + + input_data = [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + output_values = [ + "2024-07-30T10:00:00.123000", + "2024-07-30T08:00:00.123456", + "2024-07-30T10:00:00.123456", + ] + + output_map = { + "postgres": { + "tables": { + "events_timezone_off": { + "timestamp_type": "timestamp without time zone", + "timestamp_values": output_values, + }, + "events_timezone_on": { + "timestamp_type": "timestamp with time zone", + "timestamp_values": output_values, + }, + "events_timezone_unset": { + "timestamp_type": "timestamp with time zone", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='experiments'" + " AND table_name = '%s' AND column_name = 'event_tstamp'" + ), + }, + "snowflake": { + "tables": { + "EVENTS_TIMEZONE_OFF": { + "timestamp_type": "TIMESTAMP_NTZ", + "timestamp_values": output_values, + }, + "EVENTS_TIMEZONE_ON": { + "timestamp_type": "TIMESTAMP_TZ", + "timestamp_values": output_values, + }, + "EVENTS_TIMEZONE_UNSET": { + "timestamp_type": "TIMESTAMP_TZ", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='EXPERIMENTS'" + " AND table_name = '%s' AND column_name = 'EVENT_TSTAMP'" + ), + }, + "duckdb": { + "tables": { + "events_timezone_off": { + "timestamp_type": "TIMESTAMP", + "timestamp_values": output_values, + }, + "events_timezone_on": { + "timestamp_type": "TIMESTAMP WITH TIME ZONE", + "timestamp_values": output_values, + }, + "events_timezone_unset": { + "timestamp_type": "TIMESTAMP WITH TIME ZONE", + "timestamp_values": output_values, + }, + }, + "query_data_type": ( + "SELECT data_type FROM information_schema.columns WHERE table_schema ='experiments'" + " AND table_name = '%s' AND column_name = 'event_tstamp'" + ), + }, + } + + # table: events_timezone_off + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", + ) + def events_timezone_off(): + yield input_data + + # table: events_timezone_on + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True}}, + primary_key="event_id", + ) + def events_timezone_on(): + yield input_data + + # table: events_timezone_unset + @dlt.resource( + primary_key="event_id", + ) + def events_timezone_unset(): + yield input_data + + pipeline = destination_config.setup_pipeline( + f"{destination}_" + uniq_id(), dataset_name="experiments" + ) + + pipeline.run([events_timezone_off(), events_timezone_on(), events_timezone_unset()]) + + with pipeline.sql_client() as client: + for t in output_map[destination]["tables"].keys(): # type: ignore + # check data type + column_info = client.execute_sql(output_map[destination]["query_data_type"] % t) + assert column_info[0][0] == output_map[destination]["tables"][t]["timestamp_type"] # type: ignore + # check timestamp data + rows = client.execute_sql(f"SELECT event_tstamp FROM {t} ORDER BY event_id") + + values = [r[0].strftime("%Y-%m-%dT%H:%M:%S.%f") for r in rows] + assert values == output_map[destination]["tables"][t]["timestamp_values"] # type: ignore diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 027a2b4e72..918f9beab9 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -29,7 +29,7 @@ DestinationTerminalException, UnknownDestinationModule, ) -from dlt.common.exceptions import PipelineStateNotAvailable +from dlt.common.exceptions import PipelineStateNotAvailable, TerminalValueError from dlt.common.pipeline import LoadInfo, PipelineContext from dlt.common.runtime.collector import LogCollector from dlt.common.schema.exceptions import TableIdentifiersFrozen @@ -2729,3 +2729,18 @@ def assert_imported_file( extract_info.metrics[extract_info.loads_ids[0]][0]["table_metrics"][table_name].items_count == expected_rows ) + + +def test_duckdb_column_invalid_timestamp() -> None: + # DuckDB does not have timestamps with timezone and precision + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True, "precision": 3}}, + primary_key="event_id", + ) + def events(): + yield [{"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}] + + pipeline = dlt.pipeline(destination="duckdb") + + with pytest.raises((TerminalValueError, PipelineStepFailed)): + pipeline.run(events()) diff --git a/tests/pipeline/test_pipeline_extra.py b/tests/pipeline/test_pipeline_extra.py index d3e44198b4..c757959bec 100644 --- a/tests/pipeline/test_pipeline_extra.py +++ b/tests/pipeline/test_pipeline_extra.py @@ -22,6 +22,7 @@ class BaseModel: # type: ignore[no-redef] from dlt.common import json, pendulum from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.capabilities import TLoaderFileFormat +from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.common.runtime.collector import ( AliveCollector, EnlightenCollector, @@ -599,3 +600,82 @@ def test_pick_matching_file_format(test_storage: FileStorage) -> None: files = test_storage.list_folder_files("user_data_csv/object") assert len(files) == 1 assert files[0].endswith("csv") + + +def test_filesystem_column_hint_timezone() -> None: + import pyarrow.parquet as pq + import posixpath + + os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "_storage" + + # talbe: events_timezone_off + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": False}}, + primary_key="event_id", + ) + def events_timezone_off(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + # talbe: events_timezone_on + @dlt.resource( + columns={"event_tstamp": {"data_type": "timestamp", "timezone": True}}, + primary_key="event_id", + ) + def events_timezone_on(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + # talbe: events_timezone_unset + @dlt.resource( + primary_key="event_id", + ) + def events_timezone_unset(): + yield [ + {"event_id": 1, "event_tstamp": "2024-07-30T10:00:00.123+00:00"}, + {"event_id": 2, "event_tstamp": "2024-07-30T10:00:00.123456+02:00"}, + {"event_id": 3, "event_tstamp": "2024-07-30T10:00:00.123456"}, + ] + + pipeline = dlt.pipeline(destination="filesystem") + + pipeline.run( + [events_timezone_off(), events_timezone_on(), events_timezone_unset()], + loader_file_format="parquet", + ) + + client: FilesystemClient = pipeline.destination_client() # type: ignore[assignment] + + expected_results = { + "events_timezone_off": None, + "events_timezone_on": "UTC", + "events_timezone_unset": "UTC", + } + + for t in expected_results.keys(): + events_glob = posixpath.join(client.dataset_path, f"{t}/*") + events_files = client.fs_client.glob(events_glob) + + with open(events_files[0], "rb") as f: + table = pq.read_table(f) + + # convert the timestamps to strings + timestamps = [ + ts.as_py().strftime("%Y-%m-%dT%H:%M:%S.%f") for ts in table.column("event_tstamp") + ] + assert timestamps == [ + "2024-07-30T10:00:00.123000", + "2024-07-30T08:00:00.123456", + "2024-07-30T10:00:00.123456", + ] + + # check if the Parquet file contains timezone information + schema = table.schema + field = schema.field("event_tstamp") + assert field.type.tz == expected_results[t]