diff --git a/dlt/cli/_dlt.py b/dlt/cli/_dlt.py index af4f2f66e9..7c6526c0a2 100644 --- a/dlt/cli/_dlt.py +++ b/dlt/cli/_dlt.py @@ -164,7 +164,7 @@ def schema_command_wrapper(file_path: str, format_: str, remove_defaults: bool) schema_str = json.dumps(s.to_dict(remove_defaults=remove_defaults), pretty=True) else: schema_str = s.to_pretty_yaml(remove_defaults=remove_defaults) - print(schema_str) + fmt.echo(schema_str) return 0 diff --git a/dlt/cli/pipeline_command.py b/dlt/cli/pipeline_command.py index d66d884ff2..6aa479a398 100644 --- a/dlt/cli/pipeline_command.py +++ b/dlt/cli/pipeline_command.py @@ -8,7 +8,12 @@ from dlt.common.destination.reference import TDestinationReferenceArg from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout -from dlt.common.schema.utils import group_tables_by_resource, remove_defaults +from dlt.common.schema.utils import ( + group_tables_by_resource, + has_table_seen_data, + is_complete_column, + remove_defaults, +) from dlt.common.storages import FileStorage, PackageStorage from dlt.pipeline.helpers import DropCommand from dlt.pipeline.exceptions import CannotRestorePipelineException @@ -180,6 +185,35 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.bold(str(res_state_slots)), ) ) + if verbosity > 0: + for table in tables: + incomplete_columns = len( + [ + col + for col in table["columns"].values() + if not is_complete_column(col) + ] + ) + fmt.echo( + "\t%s table %s column(s) %s %s" + % ( + fmt.bold(table["name"]), + fmt.bold(str(len(table["columns"]))), + ( + fmt.style("received data", fg="green") + if has_table_seen_data(table) + else fmt.style("not yet received data", fg="yellow") + ), + ( + fmt.style( + f"{incomplete_columns} incomplete column(s)", + fg="yellow", + ) + if incomplete_columns > 0 + else "" + ), + ) + ) fmt.echo() fmt.echo("Working dir content:") _display_pending_packages() @@ -272,7 +306,7 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.echo(package_info.asstr(verbosity)) if len(package_info.schema_update) > 0: if verbosity == 0: - print("Add -v option to see schema update. Note that it could be large.") + fmt.echo("Add -v option to see schema update. 
Note that it could be large.") else: tables = remove_defaults({"tables": package_info.schema_update}) # type: ignore fmt.echo(fmt.bold("Schema update:")) @@ -316,7 +350,7 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: fmt.echo( "About to drop the following data in dataset %s in destination %s:" % ( - fmt.bold(drop.info["dataset_name"]), + fmt.bold(p.dataset_name), fmt.bold(p.destination.destination_name), ) ) @@ -329,6 +363,10 @@ def _display_pending_packages() -> Tuple[Sequence[str], Sequence[str]]: ) ) fmt.echo("%s: %s" % (fmt.style("Table(s) to drop", fg="green"), drop.info["tables"])) + fmt.echo( + "%s: %s" + % (fmt.style("\twith data in destination", fg="green"), drop.info["tables_with_data"]) + ) fmt.echo( "%s: %s" % ( diff --git a/dlt/common/configuration/paths.py b/dlt/common/configuration/paths.py index 89494ba6bd..9d0b47f8b6 100644 --- a/dlt/common/configuration/paths.py +++ b/dlt/common/configuration/paths.py @@ -1,16 +1,16 @@ import os import tempfile -# dlt settings folder -DOT_DLT = ".dlt" +from dlt.common import known_env + -# dlt data dir is by default not set, see get_dlt_data_dir for details -DLT_DATA_DIR: str = None +# dlt settings folder +DOT_DLT = os.environ.get(known_env.DLT_CONFIG_FOLDER, ".dlt") def get_dlt_project_dir() -> str: """The dlt project dir is the current working directory but may be overridden by DLT_PROJECT_DIR env variable.""" - return os.environ.get("DLT_PROJECT_DIR", ".") + return os.environ.get(known_env.DLT_PROJECT_DIR, ".") def get_dlt_settings_dir() -> str: @@ -27,14 +27,14 @@ def make_dlt_settings_path(path: str) -> str: def get_dlt_data_dir() -> str: - """Gets default directory where pipelines' data will be stored - 1. in user home directory: ~/.dlt/ - 2. if current user is root: in /var/dlt/ - 3. if current user does not have a home directory: in /tmp/dlt/ - 4. if DLT_DATA_DIR is set in env then it is used + """Gets default directory where pipelines' data (working directories) will be stored + 1. if DLT_DATA_DIR is set in env then it is used + 2. in user home directory: ~/.dlt/ + 3. if current user is root: in /var/dlt/ + 4. 
if current user does not have a home directory: in /tmp/dlt/ """ - if "DLT_DATA_DIR" in os.environ: - return os.environ["DLT_DATA_DIR"] + if known_env.DLT_DATA_DIR in os.environ: + return os.environ[known_env.DLT_DATA_DIR] # geteuid not available on Windows if hasattr(os, "geteuid") and os.geteuid() == 0: diff --git a/dlt/common/json/__init__.py b/dlt/common/json/__init__.py index cf68e5d3d4..00d8dcc430 100644 --- a/dlt/common/json/__init__.py +++ b/dlt/common/json/__init__.py @@ -12,6 +12,7 @@ except ImportError: PydanticBaseModel = None # type: ignore[misc] +from dlt.common import known_env from dlt.common.pendulum import pendulum from dlt.common.arithmetics import Decimal from dlt.common.wei import Wei @@ -80,7 +81,7 @@ def custom_encode(obj: Any) -> str: # use PUA range to encode additional types -PUA_START = int(os.environ.get("DLT_JSON_TYPED_PUA_START", "0xf026"), 16) +PUA_START = int(os.environ.get(known_env.DLT_JSON_TYPED_PUA_START, "0xf026"), 16) _DECIMAL = chr(PUA_START) _DATETIME = chr(PUA_START + 1) @@ -191,7 +192,7 @@ def may_have_pua(line: bytes) -> bool: # pick the right impl json: SupportsJson = None -if os.environ.get("DLT_USE_JSON") == "simplejson": +if os.environ.get(known_env.DLT_USE_JSON) == "simplejson": from dlt.common.json import _simplejson as _json_d json = _json_d # type: ignore[assignment] diff --git a/dlt/common/known_env.py b/dlt/common/known_env.py new file mode 100644 index 0000000000..7ac36d252d --- /dev/null +++ b/dlt/common/known_env.py @@ -0,0 +1,25 @@ +"""Defines env variables that `dlt` uses independently of its configuration system""" + +DLT_PROJECT_DIR = "DLT_PROJECT_DIR" +"""The dlt project dir is the current working directory, '.' (current working dir) by default""" + +DLT_DATA_DIR = "DLT_DATA_DIR" +"""Gets default directory where pipelines' data (working directories) will be stored""" + +DLT_CONFIG_FOLDER = "DLT_CONFIG_FOLDER" +"""A folder (path relative to DLT_PROJECT_DIR) where config and secrets are stored""" + +DLT_DEFAULT_NAMING_NAMESPACE = "DLT_DEFAULT_NAMING_NAMESPACE" +"""Python namespace default where naming modules reside, defaults to dlt.common.normalizers.naming""" + +DLT_DEFAULT_NAMING_MODULE = "DLT_DEFAULT_NAMING_MODULE" +"""A module name with the default naming convention, defaults to snake_case""" + +DLT_DLT_ID_LENGTH_BYTES = "DLT_DLT_ID_LENGTH_BYTES" +"""The length of the _dlt_id identifier, before base64 encoding""" + +DLT_USE_JSON = "DLT_USE_JSON" +"""Type of json parser to use, defaults to orjson, may be simplejson""" + +DLT_JSON_TYPED_PUA_START = "DLT_JSON_TYPED_PUA_START" +"""Start of the unicode block within the PUA used to encode types in typed json""" diff --git a/dlt/common/normalizers/utils.py b/dlt/common/normalizers/utils.py index beacf03e4e..d852cfb7d9 100644 --- a/dlt/common/normalizers/utils.py +++ b/dlt/common/normalizers/utils.py @@ -1,9 +1,11 @@ +import os from importlib import import_module from types import ModuleType from typing import Any, Dict, Optional, Type, Tuple, cast, List import dlt from dlt.common import logger +from dlt.common import known_env from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs import known_sections from dlt.common.destination import DestinationCapabilitiesContext @@ -24,9 +26,11 @@ from dlt.common.typing import is_subclass from dlt.common.utils import get_full_class_name, uniq_id_base64, many_uniq_ids_base64 -DEFAULT_NAMING_NAMESPACE = "dlt.common.normalizers.naming" -DLT_ID_LENGTH_BYTES = 10 -DEFAULT_NAMING_MODULE = "snake_case" 
+DEFAULT_NAMING_NAMESPACE = os.environ.get( + known_env.DLT_DEFAULT_NAMING_NAMESPACE, "dlt.common.normalizers.naming" +) +DEFAULT_NAMING_MODULE = os.environ.get(known_env.DLT_DEFAULT_NAMING_MODULE, "snake_case") +DLT_ID_LENGTH_BYTES = int(os.environ.get(known_env.DLT_DLT_ID_LENGTH_BYTES, 10)) def _section_for_schema(kwargs: Dict[str, Any]) -> Tuple[str, ...]: diff --git a/dlt/common/runners/stdout.py b/dlt/common/runners/stdout.py index 6a92838342..bb5251764c 100644 --- a/dlt/common/runners/stdout.py +++ b/dlt/common/runners/stdout.py @@ -21,11 +21,11 @@ def exec_to_stdout(f: AnyFun) -> Iterator[Any]: rv = f() yield rv except Exception as ex: - print(encode_obj(ex), file=sys.stderr, flush=True) + print(encode_obj(ex), file=sys.stderr, flush=True) # noqa raise finally: if rv is not None: - print(encode_obj(rv), flush=True) + print(encode_obj(rv), flush=True) # noqa def iter_std( @@ -126,6 +126,6 @@ def iter_stdout_with_result( if isinstance(exception, Exception): raise exception from cpe else: - print(cpe.stderr, file=sys.stderr) + sys.stderr.write(cpe.stderr) # otherwise reraise cpe raise diff --git a/dlt/common/runtime/collector.py b/dlt/common/runtime/collector.py index e00bca576e..95117b70cc 100644 --- a/dlt/common/runtime/collector.py +++ b/dlt/common/runtime/collector.py @@ -230,7 +230,7 @@ def _log(self, log_level: int, log_message: str) -> None: if isinstance(self.logger, (logging.Logger, logging.LoggerAdapter)): self.logger.log(log_level, log_message) else: - print(log_message, file=self.logger or sys.stdout) + print(log_message, file=self.logger or sys.stdout) # noqa def _start(self, step: str) -> None: self.counters = defaultdict(int) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index 52f8545587..9ef638e289 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -442,7 +442,7 @@ def drop_tables( """Drops tables from the schema and returns the dropped tables""" result = [] for table_name in table_names: - table = self.tables.get(table_name) + table = self.get_table(table_name) if table and (not seen_data_only or utils.has_table_seen_data(table)): result.append(self._schema_tables.pop(table_name)) return result @@ -555,9 +555,16 @@ def data_tables( ) ] - def data_table_names(self) -> List[str]: + def data_table_names( + self, seen_data_only: bool = False, include_incomplete: bool = False + ) -> List[str]: """Returns list of table table names. 
Excludes dlt table names.""" - return [t["name"] for t in self.data_tables()] + return [ + t["name"] + for t in self.data_tables( + seen_data_only=seen_data_only, include_incomplete=include_incomplete + ) + ] def dlt_tables(self) -> List[TTableSchema]: """Gets dlt tables""" @@ -728,6 +735,14 @@ def update_normalizers(self) -> None: self._configure_normalizers(explicit_normalizers(schema_name=self._schema_name)) self._compile_settings() + def will_update_normalizers(self) -> bool: + """Checks if schema has any pending normalizer updates due to configuration or destination capabilities""" + # import desired modules + _, to_naming, _ = import_normalizers( + explicit_normalizers(schema_name=self._schema_name), self._normalizers_config + ) + return type(to_naming) is not type(self.naming) # noqa + def set_schema_contract(self, settings: TSchemaContract) -> None: if not settings: self._settings.pop("schema_contract", None) @@ -967,42 +982,91 @@ def _verify_update_normalizers( from_naming: NamingConvention, ) -> TSchemaTables: """Verifies if normalizers can be updated before schema is changed""" - # print(f"{self.name}: {type(to_naming)} {type(naming_module)}") - if from_naming and type(from_naming) is not type(to_naming): + allow_ident_change = normalizers_config.get( + "allow_identifier_change_on_table_with_data", False + ) + + def _verify_identifiers(table: TTableSchema, norm_table: TTableSchema) -> None: + if not allow_ident_change: + # make sure no identifier got changed in table + if norm_table["name"] != table["name"]: + raise TableIdentifiersFrozen( + self.name, + table["name"], + to_naming, + from_naming, + f"Attempt to rename table name to {norm_table['name']}.", + ) + # if len(norm_table["columns"]) != len(table["columns"]): + # print(norm_table["columns"]) + # raise TableIdentifiersFrozen( + # self.name, + # table["name"], + # to_naming, + # from_naming, + # "Number of columns changed after normalization. 
Some columns must have" + # " merged.", + # ) + col_diff = set(norm_table["columns"].keys()).symmetric_difference( + table["columns"].keys() + ) + if len(col_diff) > 0: + raise TableIdentifiersFrozen( + self.name, + table["name"], + to_naming, + from_naming, + f"Some columns got renamed to {col_diff}.", + ) + + naming_changed = from_naming and type(from_naming) is not type(to_naming) + if naming_changed: schema_tables = {} - for table in self._schema_tables.values(): + # check dlt tables + schema_seen_data = any( + utils.has_table_seen_data(t) for t in self._schema_tables.values() + ) + # modify dlt tables using original naming + orig_dlt_tables = [ + (self.version_table_name, utils.version_table()), + (self.loads_table_name, utils.loads_table()), + (self.state_table_name, utils.pipeline_state_table(add_dlt_id=True)), + ] + for existing_table_name, original_table in orig_dlt_tables: + table = self._schema_tables.get(existing_table_name) + # state table is optional + if table: + table = copy(table) + # keep all attributes of the schema table, copy only what we need to normalize + table["columns"] = original_table["columns"] + norm_table = utils.normalize_table_identifiers(table, to_naming) + table_seen_data = utils.has_table_seen_data(norm_table) + if schema_seen_data: + _verify_identifiers(table, norm_table) + schema_tables[norm_table["name"]] = norm_table + + schema_seen_data = False + for table in self.data_tables(include_incomplete=True): + # TODO: when lineage is fully implemented we should use source identifiers + # not `table` which was already normalized norm_table = utils.normalize_table_identifiers(table, to_naming) - if utils.has_table_seen_data(norm_table) and not normalizers_config.get( - "allow_identifier_change_on_table_with_data", False - ): - # make sure no identifier got changed in table - if norm_table["name"] != table["name"]: - raise TableIdentifiersFrozen( - self.name, - table["name"], - to_naming, - from_naming, - f"Attempt to rename table name to {norm_table['name']}.", - ) - if len(norm_table["columns"]) != len(table["columns"]): - raise TableIdentifiersFrozen( - self.name, - table["name"], - to_naming, - from_naming, - "Number of columns changed after normalization. Some columns must have" - " merged.", - ) - col_diff = set(norm_table["columns"].keys()).difference(table["columns"].keys()) - if len(col_diff) > 0: - raise TableIdentifiersFrozen( - self.name, - table["name"], - to_naming, - from_naming, - f"Some columns got renamed to {col_diff}.", - ) + table_seen_data = utils.has_table_seen_data(norm_table) + if table_seen_data: + _verify_identifiers(table, norm_table) schema_tables[norm_table["name"]] = norm_table + schema_seen_data |= table_seen_data + if schema_seen_data and not allow_ident_change: + # if any of the tables has seen data, fail naming convention change + # NOTE: this will be dropped with full identifier lineage. currently we cannot detect + # strict schemas being changed to lax + raise TableIdentifiersFrozen( + self.name, + "-", + to_naming, + from_naming, + "Schema contains tables that received data. 
As a precaution changing naming" + " conventions is disallowed until full identifier lineage is implemented.", + ) # re-index the table names return schema_tables else: diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index f5765be351..cd0cc5aa63 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -802,23 +802,26 @@ def loads_table() -> TTableSchema: return table -def pipeline_state_table() -> TTableSchema: +def pipeline_state_table(add_dlt_id: bool = False) -> TTableSchema: # NOTE: always add new columns at the end of the table so we have identical layout # after an update of existing tables (always at the end) # set to nullable so we can migrate existing tables # WARNING: do not reorder the columns + columns: List[TColumnSchema] = [ + {"name": "version", "data_type": "bigint", "nullable": False}, + {"name": "engine_version", "data_type": "bigint", "nullable": False}, + {"name": "pipeline_name", "data_type": "text", "nullable": False}, + {"name": "state", "data_type": "text", "nullable": False}, + {"name": "created_at", "data_type": "timestamp", "nullable": False}, + {"name": "version_hash", "data_type": "text", "nullable": True}, + {"name": "_dlt_load_id", "data_type": "text", "nullable": False}, + ] + if add_dlt_id: + columns.append({"name": "_dlt_id", "data_type": "text", "nullable": False, "unique": True}) table = new_table( PIPELINE_STATE_TABLE_NAME, write_disposition="append", - columns=[ - {"name": "version", "data_type": "bigint", "nullable": False}, - {"name": "engine_version", "data_type": "bigint", "nullable": False}, - {"name": "pipeline_name", "data_type": "text", "nullable": False}, - {"name": "state", "data_type": "text", "nullable": False}, - {"name": "created_at", "data_type": "timestamp", "nullable": False}, - {"name": "version_hash", "data_type": "text", "nullable": True}, - {"name": "_dlt_load_id", "data_type": "text", "nullable": False}, - ], + columns=columns, # always use caps preferred file format for processing file_format="preferred", ) diff --git a/dlt/common/storages/fsspec_filesystem.py b/dlt/common/storages/fsspec_filesystem.py index f419baed03..be9ae2bbb1 100644 --- a/dlt/common/storages/fsspec_filesystem.py +++ b/dlt/common/storages/fsspec_filesystem.py @@ -319,7 +319,7 @@ def glob_files( rel_path = pathlib.Path(file).relative_to(root_dir).as_posix() file_url = FilesystemConfiguration.make_file_uri(file) else: - rel_path = posixpath.relpath(file, root_dir) + rel_path = posixpath.relpath(file.lstrip("/"), root_dir) file_url = bucket_url_parsed._replace( path=posixpath.join(bucket_url_parsed.path, rel_path) ).geturl() diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index 9e3185221d..4d84094427 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -72,7 +72,14 @@ class TPipelineStateDoc(TypedDict, total=False): _dlt_load_id: NotRequired[str] -class TLoadPackageState(TVersionedState, total=False): +class TLoadPackageDropTablesState(TypedDict): + dropped_tables: NotRequired[List[TTableSchema]] + """List of tables that are to be dropped from the schema and destination (i.e. when `refresh` mode is used)""" + truncated_tables: NotRequired[List[TTableSchema]] + """List of tables that are to be truncated in the destination (i.e. 
when `refresh='drop_data'` mode is used)""" + + +class TLoadPackageState(TVersionedState, TLoadPackageDropTablesState, total=False): created_at: DateTime """Timestamp when the load package was created""" pipeline_state: NotRequired[TPipelineStateDoc] @@ -82,11 +89,6 @@ class TLoadPackageState(TVersionedState, total=False): destination_state: NotRequired[Dict[str, Any]] """private space for destinations to store state relevant only to the load package""" - dropped_tables: NotRequired[List[TTableSchema]] - """List of tables that are to be dropped from the schema and destination (i.e. when `refresh` mode is used)""" - truncated_tables: NotRequired[List[TTableSchema]] - """List of tables that are to be truncated in the destination (i.e. when `refresh='drop_data'` mode is used)""" - class TLoadPackage(TypedDict, total=False): load_id: str diff --git a/dlt/common/utils.py b/dlt/common/utils.py index 8e89556c39..7109daf497 100644 --- a/dlt/common/utils.py +++ b/dlt/common/utils.py @@ -137,7 +137,7 @@ def flatten_list_of_str_or_dicts(seq: Sequence[Union[StrAny, str]]) -> DictStrAn else: key = str(e) if key in o: - raise KeyError(f"Cannot flatten with duplicate key {k}") + raise KeyError(f"Cannot flatten with duplicate key {key}") o[key] = None return o diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 8d0ffb1d0c..2b76ca782e 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -17,6 +17,7 @@ import re from contextlib import contextmanager +from fsspec import AbstractFileSystem from pendulum.datetime import DateTime, Date from datetime import datetime # noqa: I251 @@ -33,7 +34,8 @@ from dlt.common import logger from dlt.common.exceptions import TerminalValueError -from dlt.common.utils import without_none +from dlt.common.storages.fsspec_filesystem import fsspec_from_config +from dlt.common.utils import uniq_id, without_none from dlt.common.schema import TColumnSchema, Schema, TTableSchema from dlt.common.schema.typing import ( TTableSchema, @@ -425,8 +427,12 @@ def _get_table_update_sql( is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" columns = ", ".join([self._get_column_def_sql(c, table_format) for c in new_columns]) + # create unique tag for iceberg table so it is never recreated in the same folder + # athena requires some kind of special cleaning (or that is a bug) so we cannot refresh + # iceberg tables without it + location_tag = uniq_id(6) if is_iceberg else "" # this will fail if the table prefix is not properly defined - table_prefix = self.table_prefix_layout.format(table_name=table_name) + table_prefix = self.table_prefix_layout.format(table_name=table_name + location_tag) location = f"{bucket}/{dataset}/{table_prefix}" # use qualified table names diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index c3a1be4174..d0052c22f0 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -20,7 +20,6 @@ SupportsStagingDestination, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.exceptions import UnknownTableException from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat from dlt.common.schema.utils import get_inherited_table_hint from dlt.common.schema.utils import table_schema_has_type @@ -29,9 +28,10 @@ from dlt.destinations.job_impl import DestinationJsonlLoadJob, 
DestinationParquetLoadJob from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.exceptions import ( + DatabaseTransientException, + DatabaseUndefinedRelation, DestinationSchemaWillNotUpdate, DestinationTerminalException, - DestinationTransientException, LoadJobNotExistsException, LoadJobTerminalException, ) @@ -226,7 +226,7 @@ def restore_file_load(self, file_path: str) -> LoadJob: file_path, f"The server reason was: {reason}" ) from gace else: - raise DestinationTransientException(gace) from gace + raise DatabaseTransientException(gace) from gace return job def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob: @@ -271,7 +271,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> reason = BigQuerySqlClient._get_reason_from_errors(gace) if reason == "notFound": # google.api_core.exceptions.NotFound: 404 – table not found - raise UnknownTableException(self.schema.name, table["name"]) from gace + raise DatabaseUndefinedRelation(gace) from gace elif ( reason == "duplicate" ): # google.api_core.exceptions.Conflict: 409 PUT – already exists @@ -282,7 +282,7 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> file_path, f"The server reason was: {reason}" ) from gace else: - raise DestinationTransientException(gace) from gace + raise DatabaseTransientException(gace) from gace return job diff --git a/dlt/destinations/impl/bigquery/sql_client.py b/dlt/destinations/impl/bigquery/sql_client.py index 45e9379af5..e6aee1fc43 100644 --- a/dlt/destinations/impl/bigquery/sql_client.py +++ b/dlt/destinations/impl/bigquery/sql_client.py @@ -193,20 +193,21 @@ def create_dataset(self) -> None: dataset = bigquery.Dataset(self.fully_qualified_dataset_name(escape=False)) dataset.location = self.location dataset.is_case_insensitive = not self.capabilities.has_case_sensitive_identifiers - self._client.create_dataset( - dataset, - retry=self._default_retry, - timeout=self.http_timeout, - ) - - def drop_dataset(self) -> None: - self._client.delete_dataset( - self.fully_qualified_dataset_name(escape=False), - not_found_ok=True, - delete_contents=True, - retry=self._default_retry, - timeout=self.http_timeout, - ) + try: + self._client.create_dataset( + dataset, + retry=self._default_retry, + timeout=self.http_timeout, + ) + except api_core_exceptions.GoogleAPICallError as gace: + reason = BigQuerySqlClient._get_reason_from_errors(gace) + if reason == "notFound": + # google.api_core.exceptions.NotFound: 404 – table not found + raise DatabaseUndefinedRelation(gace) from gace + elif reason in BQ_TERMINAL_REASONS: + raise DatabaseTerminalException(gace) from gace + else: + raise DatabaseTransientException(gace) from gace def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any diff --git a/dlt/destinations/impl/clickhouse/sql_client.py b/dlt/destinations/impl/clickhouse/sql_client.py index ee013ea123..8544643017 100644 --- a/dlt/destinations/impl/clickhouse/sql_client.py +++ b/dlt/destinations/impl/clickhouse/sql_client.py @@ -107,16 +107,34 @@ def create_dataset(self) -> None: ) def drop_dataset(self) -> None: + # always try to drop sentinel table + sentinel_table_name = self.make_qualified_table_name( + self.credentials.dataset_sentinel_table_name + ) + # drop a sentinel table + self.execute_sql(f"DROP TABLE {sentinel_table_name} SYNC") + # Since ClickHouse doesn't have schemas, we need to drop all tables in our virtual schema, # or collection of tables, that has the `dataset_name` as a prefix. 
- to_drop_results = self._list_tables() + to_drop_results = [ + f"{self.catalog_name()}.{self.capabilities.escape_identifier(table)}" + for table in self._list_tables() + ] for table in to_drop_results: # The "DROP TABLE" clause is discarded if we allow clickhouse_driver to handle parameter substitution. # This is because the driver incorrectly substitutes the entire query string, causing the "DROP TABLE" keyword to be omitted. # To resolve this, we are forced to provide the full query string here. - self.execute_sql( - f"""DROP TABLE {self.catalog_name()}.{self.capabilities.escape_identifier(table)} SYNC""" - ) + self.execute_sql(f"DROP TABLE {table} SYNC") + + def drop_tables(self, *tables: str) -> None: + """Drops a set of tables if they exist""" + if not tables: + return + statements = [ + f"DROP TABLE IF EXISTS {self.make_qualified_table_name(table)} SYNC;" + for table in tables + ] + self.execute_many(statements) def _list_tables(self) -> List[str]: catalog_name, table_name = self.make_qualified_table_name_path("%", escape=False) diff --git a/dlt/destinations/impl/databricks/sql_client.py b/dlt/destinations/impl/databricks/sql_client.py index da91402803..2af27020ee 100644 --- a/dlt/destinations/impl/databricks/sql_client.py +++ b/dlt/destinations/impl/databricks/sql_client.py @@ -87,9 +87,6 @@ def rollback_transaction(self) -> None: def native_connection(self) -> "DatabricksSqlConnection": return self._conn - def drop_dataset(self) -> None: - self.execute_sql("DROP SCHEMA IF EXISTS %s CASCADE;" % self.fully_qualified_dataset_name()) - def drop_tables(self, *tables: str) -> None: # Tables are drop with `IF EXISTS`, but databricks raises when the schema doesn't exist. # Multi statement exec is safe and the error can be ignored since all tables are in the same schema. diff --git a/dlt/destinations/impl/destination/factory.py b/dlt/destinations/impl/destination/factory.py index 69bb0daa13..e307b651fb 100644 --- a/dlt/destinations/impl/destination/factory.py +++ b/dlt/destinations/impl/destination/factory.py @@ -78,8 +78,10 @@ def __init__( if callable(destination_callable): pass elif destination_callable: + if "." 
not in destination_callable: + raise ValueError("str destination reference must be of format 'module.function'") + module_path, attr_name = destination_callable.rsplit(".", 1) try: - module_path, attr_name = destination_callable.rsplit(".", 1) dest_module = import_module(module_path) except ModuleNotFoundError as e: raise ConfigurationValueError( diff --git a/dlt/destinations/impl/dremio/sql_client.py b/dlt/destinations/impl/dremio/sql_client.py index fac65e7fd0..929aa2a0d8 100644 --- a/dlt/destinations/impl/dremio/sql_client.py +++ b/dlt/destinations/impl/dremio/sql_client.py @@ -32,6 +32,7 @@ def df(self, chunk_size: int = None, **kwargs: Any) -> Optional[DataFrame]: class DremioSqlClient(SqlClientBase[pydremio.DremioConnection]): dbapi: ClassVar[DBApi] = pydremio + SENTINEL_TABLE_NAME: ClassVar[str] = "_dlt_sentinel_table" def __init__( self, @@ -134,7 +135,9 @@ def is_dbapi_exception(ex: Exception) -> bool: return isinstance(ex, (pyarrow.lib.ArrowInvalid, pydremio.MalformedQueryError)) def create_dataset(self) -> None: - pass + # We create a sentinel table which defines wether we consider the dataset created + sentinel_table_name = self.make_qualified_table_name(self.SENTINEL_TABLE_NAME) + self.execute_sql(f"CREATE TABLE {sentinel_table_name} (_dlt_id BIGINT);") def _get_table_names(self) -> List[str]: query = """ @@ -147,6 +150,11 @@ def _get_table_names(self) -> List[str]: return [table[0] for table in tables] def drop_dataset(self) -> None: + # drop sentinel table + sentinel_table_name = self.make_qualified_table_name(self.SENTINEL_TABLE_NAME) + # must exist or we get undefined relation exception + self.execute_sql(f"DROP TABLE {sentinel_table_name}") + table_names = self._get_table_names() for table_name in table_names: full_table_name = self.make_qualified_table_name(table_name) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 00b990d4fa..bf443e061f 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -9,6 +9,7 @@ import dlt from dlt.common import logger, time, json, pendulum +from dlt.common.storages.fsspec_filesystem import glob_files from dlt.common.typing import DictStrAny from dlt.common.schema import Schema, TSchemaTables, TTableSchema from dlt.common.storages import FileStorage, fsspec_from_config @@ -224,7 +225,7 @@ def drop_tables(self, *tables: str, delete_schema: bool = True) -> None: self._delete_file(filename) def truncate_tables(self, table_names: List[str]) -> None: - """Truncate table with given name""" + """Truncate a set of tables with given `table_names`""" table_dirs = set(self.get_table_dirs(table_names)) table_prefixes = [self.get_table_prefix(t) for t in table_names] for table_dir in table_dirs: @@ -302,18 +303,19 @@ def list_table_files(self, table_name: str) -> List[str]: def list_files_with_prefixes(self, table_dir: str, prefixes: List[str]) -> List[str]: """returns all files in a directory that match given prefixes""" result = [] - for current_dir, _dirs, files in self.fs_client.walk(table_dir, detail=False, refresh=True): - for file in files: - # skip INIT files - if file == INIT_FILE_NAME: - continue - filepath = self.pathlib.join( - path_utils.normalize_path_sep(self.pathlib, current_dir), file - ) - for p in prefixes: - if filepath.startswith(p): - result.append(filepath) - break + # we fallback to our own glob implementation that is tested to return consistent results for + # filesystems we support. 
we were not able to use `find` or `walk` because they were selecting + # files wrongly (on azure walk on path1/path2/ would also select files from path1/path2_v2/ but returning wrong dirs) + for details in glob_files(self.fs_client, self.make_remote_uri(table_dir), "**"): + file = details["file_name"] + filepath = self.pathlib.join(table_dir, details["relative_path"]) + # skip INIT files + if file == INIT_FILE_NAME: + continue + for p in prefixes: + if filepath.startswith(p): + result.append(filepath) + break return result def is_storage_initialized(self) -> bool: diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 128e2c7e7e..79a5de7f77 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -12,6 +12,7 @@ Optional, Dict, Sequence, + TYPE_CHECKING, ) import lancedb # type: ignore @@ -71,6 +72,11 @@ from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.type_mapping import TypeMapper +if TYPE_CHECKING: + NDArray = ndarray[Any, Any] +else: + NDArray = ndarray + TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} @@ -292,9 +298,7 @@ def delete_table(self, table_name: str) -> None: def query_table( self, table_name: str, - query: Union[ - List[Any], ndarray[Any, Any], Array, ChunkedArray, str, Tuple[Any], None - ] = None, + query: Union[List[Any], NDArray, Array, ChunkedArray, str, Tuple[Any], None] = None, ) -> LanceQueryBuilder: """Query a LanceDB table. @@ -408,8 +412,6 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] field: TArrowField for field in arrow_schema: name = self.schema.naming.normalize_identifier(field.name) - print(field.type) - print(field.name) table_schema[name] = { "name": name, **self.type_mapper.from_db_type(field.type), @@ -453,8 +455,6 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None: for table_name in only_tables or self.schema.tables: exists, existing_columns = self.get_storage_table(table_name) new_columns = self.schema.get_new_table_columns(table_name, existing_columns) - print(table_name) - print(new_columns) embedding_fields: List[str] = get_columns_names_with_prop( self.schema.get_table(table_name), VECTORIZE_HINT ) @@ -520,7 +520,6 @@ def update_schema_in_storage(self) -> None: write_disposition = self.schema.get_table(self.schema.version_table_name).get( "write_disposition" ) - print("UPLOAD") upload_batch( records, db_client=self.db_client, diff --git a/dlt/destinations/impl/mssql/sql_client.py b/dlt/destinations/impl/mssql/sql_client.py index a360670e77..988b461fa7 100644 --- a/dlt/destinations/impl/mssql/sql_client.py +++ b/dlt/destinations/impl/mssql/sql_client.py @@ -129,7 +129,7 @@ def _drop_views(self, *tables: str) -> None: self.execute_many(statements) def _drop_schema(self) -> None: - self.execute_sql("DROP SCHEMA IF EXISTS %s;" % self.fully_qualified_dataset_name()) + self.execute_sql("DROP SCHEMA %s;" % self.fully_qualified_dataset_name()) def execute_sql( self, sql: AnyStr, *args: Any, **kwargs: Any diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 2a5671b7e7..b0786e9ed6 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -172,13 +172,13 @@ def __init__( # decide on source format, stage_file_path 
will either be a local file or a bucket path if file_name.endswith("jsonl"): source_format = "( TYPE = 'JSON', BINARY_FORMAT = 'BASE64' )" - if file_name.endswith("parquet"): + elif file_name.endswith("parquet"): source_format = ( "(TYPE = 'PARQUET', BINARY_AS_TEXT = FALSE, USE_LOGICAL_TYPE = TRUE)" # TODO: USE_VECTORIZED_SCANNER inserts null strings into VARIANT JSON # " USE_VECTORIZED_SCANNER = TRUE)" ) - if file_name.endswith("csv"): + elif file_name.endswith("csv"): # empty strings are NULL, no data is NULL, missing columns (ERROR_ON_COLUMN_COUNT_MISMATCH) are NULL csv_format = config.csv_format or CsvFormatConfiguration() source_format = ( @@ -192,6 +192,8 @@ def __init__( column_match_clause = "" if csv_format.on_error_continue: on_error_clause = "ON_ERROR = CONTINUE" + else: + raise ValueError(file_name) with client.begin_transaction(): # PUT and COPY in one tx if local file, otherwise only copy diff --git a/dlt/destinations/impl/synapse/sql_client.py b/dlt/destinations/impl/synapse/sql_client.py index db1b3e7cf6..cd9a929901 100644 --- a/dlt/destinations/impl/synapse/sql_client.py +++ b/dlt/destinations/impl/synapse/sql_client.py @@ -1,12 +1,6 @@ -from typing import ClassVar from contextlib import suppress -from dlt.common.destination import DestinationCapabilitiesContext - from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient -from dlt.destinations.impl.mssql.configuration import MsSqlCredentials -from dlt.destinations.impl.synapse.configuration import SynapseCredentials - from dlt.destinations.exceptions import DatabaseUndefinedRelation @@ -17,9 +11,6 @@ def drop_tables(self, *tables: str) -> None: # Synapse does not support DROP TABLE IF EXISTS. # Workaround: use DROP TABLE and suppress non-existence errors. statements = [f"DROP TABLE {self.make_qualified_table_name(table)};" for table in tables] - with suppress(DatabaseUndefinedRelation): - self.execute_fragments(statements) - - def _drop_schema(self) -> None: - # Synapse does not support DROP SCHEMA IF EXISTS. 
- self.execute_sql("DROP SCHEMA %s;" % self.fully_qualified_dataset_name()) + for statement in statements: + with suppress(DatabaseUndefinedRelation): + self.execute_sql(statement) diff --git a/dlt/destinations/sql_client.py b/dlt/destinations/sql_client.py index 7912ac4561..f74f1b9224 100644 --- a/dlt/destinations/sql_client.py +++ b/dlt/destinations/sql_client.py @@ -98,6 +98,7 @@ def truncate_tables(self, *tables: str) -> None: self.execute_many(statements) def drop_tables(self, *tables: str) -> None: + """Drops a set of tables if they exist""" if not tables: return statements = [ diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index ad10ef3ad3..1eccd86aad 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -192,11 +192,7 @@ def decorator( # source name is passed directly or taken from decorated function name effective_name = name or get_callable_name(f) - if not schema: - # load the schema from file with name_schema.yaml/json from the same directory, the callable resides OR create new default schema - schema = _maybe_load_schema_for_callable(f, effective_name) or Schema(effective_name) - - if name and name != schema.name: + if schema and name and name != schema.name: raise ExplicitSourceNameInvalid(name, schema.name) # wrap source extraction function in configuration with section @@ -224,12 +220,19 @@ def _eval_rv(_rv: Any, schema_copy: Schema) -> TDltSourceImpl: s.root_key = root_key return s + def _make_schema() -> Schema: + if not schema: + # load the schema from file with name_schema.yaml/json from the same directory, the callable resides OR create new default schema + return _maybe_load_schema_for_callable(f, effective_name) or Schema(effective_name) + else: + # clone the schema passed to decorator, update normalizers, remove processing hints + # NOTE: source may be called several times in many different settings + return schema.clone(update_normalizers=True, remove_processing_hints=True) + @wraps(conf_f) def _wrap(*args: Any, **kwargs: Any) -> TDltSourceImpl: """Wrap a regular function, injection context must be a part of the wrap""" - # clone the schema passed to decorator, update normalizers, remove processing hints - # NOTE: source may be called several times in many different settings - schema_copy = schema.clone(update_normalizers=True, remove_processing_hints=True) + schema_copy = _make_schema() with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] @@ -249,9 +252,7 @@ async def _wrap_coro(*args: Any, **kwargs: Any) -> TDltSourceImpl: """In case of co-routine we must wrap the whole injection context in awaitable, there's no easy way to avoid some code duplication """ - # clone the schema passed to decorator, update normalizers, remove processing hints - # NOTE: source may be called several times in many different settings - schema_copy = schema.clone(update_normalizers=True, remove_processing_hints=True) + schema_copy = _make_schema() with Container().injectable_context(SourceSchemaInjectableContext(schema_copy)): # configurations will be accessed in this section in the source proxy = Container()[PipelineContext] diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 5769be1a8d..7a24b7f225 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -31,7 +31,7 @@ from dlt.common.storages.load_package import ( ParsedLoadJobFileName, LoadPackageStateInjectableContext, - TPipelineStateDoc, + 
TLoadPackageState, commit_load_package_state, ) from dlt.common.utils import get_callable_name, get_full_class_name @@ -45,7 +45,6 @@ from dlt.extract.storage import ExtractStorage from dlt.extract.extractors import ObjectExtractor, ArrowExtractor, Extractor from dlt.extract.utils import get_data_item_format -from dlt.pipeline.drop import drop_resources def data_to_sources( @@ -371,7 +370,7 @@ def extract( source: DltSource, max_parallel_items: int, workers: int, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, ) -> str: # generate load package to be able to commit all the sources together later load_id = self.extract_storage.create_load_package( @@ -394,7 +393,7 @@ def extract( ) ): if load_package_state_update: - load_package.state.update(load_package_state_update) # type: ignore[typeddict-item] + load_package.state.update(load_package_state_update) # reset resource states, the `extracted` list contains all the explicit resources and all their parents for resource in source.resources.extracted.values(): diff --git a/dlt/extract/incremental/__init__.py b/dlt/extract/incremental/__init__.py index bc25c6fee1..11f989e0b2 100644 --- a/dlt/extract/incremental/__init__.py +++ b/dlt/extract/incremental/__init__.py @@ -356,6 +356,7 @@ def _join_external_scheduler(self) -> None: f"Specified Incremental last value type {param_type} is not supported. Please use" f" DateTime, Date, float, int or str to join external schedulers.({ex})" ) + return if param_type is Any: logger.warning( diff --git a/dlt/extract/source.py b/dlt/extract/source.py index 9953b56117..7732c4f056 100644 --- a/dlt/extract/source.py +++ b/dlt/extract/source.py @@ -215,7 +215,7 @@ def from_data(cls, schema: Schema, section: str, data: Any) -> Self: def name(self) -> str: return self._schema.name - # TODO: 4 properties below must go somewhere else ie. into RelationalSchema which is Schema + Relational normalizer. + # TODO: max_table_nesting/root_key below must go somewhere else ie. into RelationalSchema which is Schema + Relational normalizer. @property def max_table_nesting(self) -> int: """A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON.""" @@ -223,25 +223,12 @@ def max_table_nesting(self) -> int: @max_table_nesting.setter def max_table_nesting(self, value: int) -> None: - RelationalNormalizer.update_normalizer_config(self._schema, {"max_nesting": value}) - - @property - def schema_contract(self) -> TSchemaContract: - return self.schema.settings["schema_contract"] - - @schema_contract.setter - def schema_contract(self, settings: TSchemaContract) -> None: - self.schema.set_schema_contract(settings) - - @property - def exhausted(self) -> bool: - """check all selected pipes wether one of them has started. 
if so, the source is exhausted.""" - for resource in self._resources.extracted.values(): - item = resource._pipe.gen - if inspect.isgenerator(item): - if inspect.getgeneratorstate(item) != "GEN_CREATED": - return True - return False + if value is None: + # this also check the normalizer type + config = RelationalNormalizer.get_normalizer_config(self._schema) + config.pop("max_nesting", None) + else: + RelationalNormalizer.update_normalizer_config(self._schema, {"max_nesting": value}) @property def root_key(self) -> bool: @@ -280,6 +267,24 @@ def root_key(self, value: bool) -> None: propagation_config = config["propagation"] propagation_config["root"].pop(data_normalizer.c_dlt_id) + @property + def schema_contract(self) -> TSchemaContract: + return self.schema.settings.get("schema_contract") + + @schema_contract.setter + def schema_contract(self, settings: TSchemaContract) -> None: + self.schema.set_schema_contract(settings) + + @property + def exhausted(self) -> bool: + """check all selected pipes wether one of them has started. if so, the source is exhausted.""" + for resource in self._resources.extracted.values(): + item = resource._pipe.gen + if inspect.isgenerator(item): + if inspect.getgeneratorstate(item) != "GEN_CREATED": + return True + return False + @property def resources(self) -> DltResourceDict: """A dictionary of all resources present in the source, where the key is a resource name.""" diff --git a/dlt/helpers/airflow_helper.py b/dlt/helpers/airflow_helper.py index 7d7302aab6..8494d3bba3 100644 --- a/dlt/helpers/airflow_helper.py +++ b/dlt/helpers/airflow_helper.py @@ -11,6 +11,7 @@ RetryCallState, ) +from dlt.common.known_env import DLT_DATA_DIR, DLT_PROJECT_DIR from dlt.common.exceptions import MissingDependencyException try: @@ -121,7 +122,7 @@ def __init__( dags_folder = conf.get("core", "dags_folder") # set the dlt project folder to dags - os.environ["DLT_PROJECT_DIR"] = dags_folder + os.environ[DLT_PROJECT_DIR] = dags_folder # check if /data mount is available if use_data_folder and os.path.exists("/home/airflow/gcs/data"): @@ -129,7 +130,7 @@ def __init__( else: # create random path data_dir = os.path.join(local_data_folder or gettempdir(), f"dlt_{uniq_id(8)}") - os.environ["DLT_DATA_DIR"] = data_dir + os.environ[DLT_DATA_DIR] = data_dir # delete existing config providers in container, they will get reloaded on next use if ConfigProvidersContext in Container(): @@ -400,7 +401,7 @@ def add_run( """ # make sure that pipeline was created after dag was initialized - if not pipeline.pipelines_dir.startswith(os.environ["DLT_DATA_DIR"]): + if not pipeline.pipelines_dir.startswith(os.environ[DLT_DATA_DIR]): raise ValueError( "Please create your Pipeline instance after AirflowTasks are created. The dlt" " pipelines directory is not set correctly." 
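The Airflow helper above now drives dlt through the environment variables collected in dlt/common/known_env.py. A minimal sketch of using the same mechanism outside Airflow (e.g. in CI); the folder values, the pipeline name and the duckdb destination are illustrative assumptions, and the variables must be set before dlt is imported because DOT_DLT is resolved when dlt.common.configuration.paths is first imported:

import os

# the known_env constants carry their own names as values (see dlt/common/known_env.py),
# so plain strings can be used here before anything from dlt is imported
os.environ["DLT_PROJECT_DIR"] = "/workspace/repo"  # project dir the settings folder is resolved against
os.environ["DLT_CONFIG_FOLDER"] = ".dlt_ci"        # settings folder, relative to the project dir
os.environ["DLT_DATA_DIR"] = "/tmp/dlt_ci_data"    # where pipelines' working directories are created

import dlt  # noqa: E402

# pipeline name and destination are illustrative
pipeline = dlt.pipeline(pipeline_name="known_env_demo", destination="duckdb")
print(pipeline.pipelines_dir)  # expected to live under /tmp/dlt_ci_data

This mirrors what the Airflow helper does above: it points DLT_PROJECT_DIR at the DAGs folder and gives each deployment its own DLT_DATA_DIR.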
diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index c68931d7db..266581c785 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -1,3 +1,4 @@ +import sys import os from subprocess import CalledProcessError import giturlparse @@ -154,12 +155,12 @@ def _run_dbt_command( try: i = iter_stdout_with_result(self.venv, "python", "-c", script) while True: - print(next(i).strip()) + sys.stdout.write(next(i).strip()) except StopIteration as si: # return result from generator return si.value # type: ignore except CalledProcessError as cpe: - print(cpe.stderr) + sys.stderr.write(cpe.stderr) raise def run( diff --git a/dlt/load/load.py b/dlt/load/load.py index 76b4806694..0e78650a84 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -382,8 +382,12 @@ def complete_package(self, load_id: str, schema: Schema, aborted: bool = False) def load_single_package(self, load_id: str, schema: Schema) -> None: new_jobs = self.get_new_jobs_info(load_id) + # get dropped and truncated tables that were added in the extract step if refresh was requested + # NOTE: if naming convention was updated those names correspond to the old naming convention + # and they must be like that in order to drop existing tables dropped_tables = current_load_package()["state"].get("dropped_tables", []) truncated_tables = current_load_package()["state"].get("truncated_tables", []) + # initialize analytical storage ie. create dataset required by passed schema with self.get_destination_client(schema) as job_client: if (expected_update := self.load_storage.begin_schema_update(load_id)) is not None: diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 7db05674fa..67a813f5f2 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -113,12 +113,15 @@ def init_client( ) ) + # get tables to drop + drop_table_names = {table["name"] for table in drop_tables} if drop_tables else set() + applied_update = _init_dataset_and_update_schema( job_client, expected_update, tables_with_jobs | dlt_tables, truncate_table_names, - drop_tables=drop_tables, + drop_tables=drop_table_names, ) # update the staging dataset if client supports this @@ -138,6 +141,7 @@ def init_client( staging_tables | {schema.version_table_name}, # keep only schema version staging_tables, # all eligible tables must be also truncated staging_info=True, + drop_tables=drop_table_names, # try to drop all the same tables on staging ) return applied_update @@ -149,7 +153,7 @@ def _init_dataset_and_update_schema( update_tables: Iterable[str], truncate_tables: Iterable[str] = None, staging_info: bool = False, - drop_tables: Optional[List[TTableSchema]] = None, + drop_tables: Iterable[str] = None, ) -> TSchemaTables: staging_text = "for staging dataset" if staging_info else "" logger.info( @@ -158,12 +162,17 @@ def _init_dataset_and_update_schema( ) job_client.initialize_storage() if drop_tables: - drop_table_names = [table["name"] for table in drop_tables] if hasattr(job_client, "drop_tables"): logger.info( - f"Client for {job_client.config.destination_type} will drop tables {staging_text}" + f"Client for {job_client.config.destination_type} will drop tables" + f" {drop_tables} {staging_text}" + ) + job_client.drop_tables(*drop_tables, delete_schema=True) + else: + logger.warning( + f"Client for {job_client.config.destination_type} does not implement drop table." 
+ f" Following tables {drop_tables} will not be dropped {staging_text}" ) - job_client.drop_tables(*drop_table_names, delete_schema=True) logger.info( f"Client for {job_client.config.destination_type} will update schema to package schema" diff --git a/dlt/normalize/worker.py b/dlt/normalize/worker.py index d5d4a028d9..cd50c56e09 100644 --- a/dlt/normalize/worker.py +++ b/dlt/normalize/worker.py @@ -46,6 +46,7 @@ def group_worker_files(files: Sequence[str], no_groups: int) -> List[Sequence[st remainder_l = len(chunk_files) - no_groups l_idx = 0 while remainder_l > 0: + idx = 0 for idx, file in enumerate(reversed(chunk_files.pop())): chunk_files[-l_idx - idx - remainder_l].append(file) # type: ignore remainder_l -= 1 diff --git a/dlt/pipeline/drop.py b/dlt/pipeline/drop.py index 486bead2f4..cd982cf676 100644 --- a/dlt/pipeline/drop.py +++ b/dlt/pipeline/drop.py @@ -17,6 +17,7 @@ group_tables_by_resource, compile_simple_regexes, compile_simple_regex, + has_table_seen_data, ) from dlt.common import jsonpath from dlt.common.typing import REPattern @@ -24,11 +25,11 @@ class _DropInfo(TypedDict): tables: List[str] + tables_with_data: List[str] resource_states: List[str] resource_names: List[str] state_paths: List[str] schema_name: str - dataset_name: Optional[str] drop_all: bool resource_pattern: Optional[REPattern] warnings: List[str] @@ -39,7 +40,7 @@ class _DropResult: schema: Schema state: TPipelineState info: _DropInfo - dropped_tables: List[TTableSchema] + modified_tables: List[TTableSchema] def _create_modified_state( @@ -85,12 +86,12 @@ def drop_resources( """Generate a new schema and pipeline state with the requested resources removed. Args: - schema: The schema to modify. - state: The pipeline state to modify. + schema: The schema to modify. Note that schema is changed in place. + state: The pipeline state to modify. Note that state is changed in place. resources: Resource name(s) or regex pattern(s) matching resource names to drop. If empty, no resources will be dropped unless `drop_all` is True. state_paths: JSON path(s) relative to the source state to drop. - drop_all: If True, all resources will be dropped (supeseeds `resources`). + drop_all: If True, all resources will be dropped (supersedes `resources`). 
state_only: If True, only modify the pipeline state, not schema sources: Only wipe state for sources matching the name(s) or regex pattern(s) in this list If not set all source states will be modified according to `state_paths` and `resources` @@ -112,9 +113,6 @@ def drop_resources( state_paths = jsonpath.compile_paths(state_paths) - schema = schema.clone() - state = deepcopy(state) - resources = set(resources) if drop_all: resource_pattern = compile_simple_regex(TSimpleRegex("re:.*")) # Match everything @@ -128,28 +126,28 @@ def drop_resources( source_pattern = compile_simple_regex(TSimpleRegex("re:.*")) # Match everything if resource_pattern: - data_tables = { - t["name"]: t for t in schema.data_tables(seen_data_only=True) - } # Don't remove _dlt tables + # (1) Don't remove _dlt tables (2) Drop all selected tables from the schema + # (3) Mark tables that seen data to be dropped in destination + data_tables = {t["name"]: t for t in schema.data_tables(include_incomplete=True)} resource_tables = group_tables_by_resource(data_tables, pattern=resource_pattern) resource_names = list(resource_tables.keys()) - # TODO: If drop_tables - if not state_only: - tables_to_drop = list(chain.from_iterable(resource_tables.values())) - tables_to_drop.reverse() - else: - tables_to_drop = [] + tables_to_drop_from_schema = list(chain.from_iterable(resource_tables.values())) + tables_to_drop_from_schema.reverse() + tables_to_drop_from_schema_names = [t["name"] for t in tables_to_drop_from_schema] + tables_to_drop_from_dest = [t for t in tables_to_drop_from_schema if has_table_seen_data(t)] else: - tables_to_drop = [] + tables_to_drop_from_schema_names = [] + tables_to_drop_from_dest = [] + tables_to_drop_from_schema = [] resource_names = [] info: _DropInfo = dict( - tables=[t["name"] for t in tables_to_drop], + tables=tables_to_drop_from_schema_names if not state_only else [], + tables_with_data=[t["name"] for t in tables_to_drop_from_dest] if not state_only else [], resource_states=[], state_paths=[], - resource_names=resource_names, + resource_names=resource_names if not state_only else [], schema_name=schema.name, - dataset_name=None, drop_all=drop_all, resource_pattern=resource_pattern, warnings=[], @@ -158,7 +156,7 @@ def drop_resources( new_state, info = _create_modified_state( state, resource_pattern, source_pattern, state_paths, info ) - info["resource_names"] = resource_names + # info["resource_names"] = resource_names if not state_only else [] if resource_pattern and not resource_tables: info["warnings"].append( @@ -167,5 +165,7 @@ def drop_resources( f" {list(group_tables_by_resource(data_tables).keys())}" ) - dropped_tables = schema.drop_tables([t["name"] for t in tables_to_drop], seen_data_only=True) - return _DropResult(schema, new_state, info, dropped_tables) + if not state_only: + # drop only the selected tables + schema.drop_tables(tables_to_drop_from_schema_names) + return _DropResult(schema, new_state, info, tables_to_drop_from_dest) diff --git a/dlt/pipeline/helpers.py b/dlt/pipeline/helpers.py index 0defbc14eb..ce81b81433 100644 --- a/dlt/pipeline/helpers.py +++ b/dlt/pipeline/helpers.py @@ -12,8 +12,10 @@ from dlt.common.jsonpath import TAnyJsonPath from dlt.common.exceptions import TerminalException +from dlt.common.schema.schema import Schema from dlt.common.schema.typing import TSimpleRegex from dlt.common.pipeline import pipeline_state as current_pipeline_state, TRefreshMode +from dlt.common.storages.load_package import TLoadPackageDropTablesState from dlt.pipeline.exceptions import ( 
PipelineNeverRan, PipelineStepFailed, @@ -83,24 +85,24 @@ def __init__( if not pipeline.default_schema_name: raise PipelineNeverRan(pipeline.pipeline_name, pipeline.pipelines_dir) + # clone schema to keep it as original in case we need to restore pipeline schema self.schema = pipeline.schemas[schema_name or pipeline.default_schema_name].clone() drop_result = drop_resources( - # self._drop_schema, self._new_state, self.info = drop_resources( - self.schema, - pipeline.state, + # create clones to have separate schemas and state + self.schema.clone(), + deepcopy(pipeline.state), resources, state_paths, drop_all, state_only, ) - + # get modified schema and state self._new_state = drop_result.state - self.info = drop_result.info self._new_schema = drop_result.schema - self._dropped_tables = drop_result.dropped_tables - self.drop_tables = not state_only and bool(self._dropped_tables) - + self.info = drop_result.info + self._modified_tables = drop_result.modified_tables + self.drop_tables = not state_only and bool(self._modified_tables) self.drop_state = bool(drop_all or resources or state_paths) @property @@ -130,7 +132,9 @@ def __call__(self) -> None: self.pipeline._save_and_extract_state_and_schema( new_state, schema=self._new_schema, - load_package_state_update={"dropped_tables": self._dropped_tables}, + load_package_state_update=( + {"dropped_tables": self._modified_tables} if self.drop_tables else None + ), ) self.pipeline.normalize() @@ -159,30 +163,33 @@ def drop( def refresh_source( pipeline: "Pipeline", source: DltSource, refresh: TRefreshMode -) -> Dict[str, Any]: - """Run the pipeline's refresh mode on the given source, updating the source's schema and state. +) -> TLoadPackageDropTablesState: + """Run the pipeline's refresh mode on the given source, updating the provided `schema` and pipeline state. Returns: The new load package state containing tables that need to be dropped/truncated. 
""" - if pipeline.first_run: - return {} pipeline_state, _ = current_pipeline_state(pipeline._container) _resources_to_drop = list(source.resources.extracted) if refresh != "drop_sources" else [] + only_truncate = refresh == "drop_data" + drop_result = drop_resources( + # do not cline the schema, change in place source.schema, + # do not clone the state, change in place pipeline_state, resources=_resources_to_drop, drop_all=refresh == "drop_sources", state_paths="*" if refresh == "drop_sources" else [], + state_only=only_truncate, sources=source.name, ) - load_package_state = {} - if drop_result.dropped_tables: - key = "dropped_tables" if refresh != "drop_data" else "truncated_tables" - load_package_state[key] = drop_result.dropped_tables - if refresh != "drop_data": # drop_data is only data wipe, keep original schema - source.schema = drop_result.schema - if "sources" in drop_result.state: - pipeline_state["sources"] = drop_result.state["sources"] + load_package_state: TLoadPackageDropTablesState = {} + if drop_result.modified_tables: + if only_truncate: + load_package_state["truncated_tables"] = drop_result.modified_tables + else: + load_package_state["dropped_tables"] = drop_result.modified_tables + # if any tables should be dropped, we force state to extract + force_state_extract(pipeline_state) return load_package_state diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 2bfee3fd29..ac5d3b90e4 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -1,6 +1,7 @@ import contextlib import os from contextlib import contextmanager +from copy import deepcopy, copy from functools import wraps from typing import ( Any, @@ -157,10 +158,8 @@ def _wrap(self: "Pipeline", *args: Any, **kwargs: Any) -> Any: # backup and restore state should_extract_state = may_extract_state and self.config.restore_from_destination - with self.managed_state(extract_state=should_extract_state) as state: - # add the state to container as a context - with self._container.injectable_context(StateInjectableContext(state=state)): - return f(self, *args, **kwargs) + with self.managed_state(extract_state=should_extract_state): + return f(self, *args, **kwargs) return _wrap # type: ignore @@ -438,12 +437,12 @@ def extract( workers, refresh=refresh or self.refresh, ) - # extract state - if self.config.restore_from_destination: - # this will update state version hash so it will not be extracted again by with_state_sync - self._bump_version_and_extract_state( - self._container[StateInjectableContext].state, True, extract_step - ) + # this will update state version hash so it will not be extracted again by with_state_sync + self._bump_version_and_extract_state( + self._container[StateInjectableContext].state, + self.config.restore_from_destination, + extract_step, + ) # commit load packages with state extract_step.commit_packages() return self._get_step_info(extract_step) @@ -1107,8 +1106,9 @@ def _extract_source( max_parallel_items: int, workers: int, refresh: Optional[TRefreshMode] = None, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, ) -> str: + load_package_state_update = copy(load_package_state_update or {}) # discover the existing pipeline schema try: # all live schemas are initially committed and during the extract will accumulate changes in memory @@ -1116,19 +1116,34 @@ def _extract_source( # this will (1) look for import schema if present # (2) load import schema an overwrite pipeline schema if import 
            # (3) load pipeline schema if no import schema is present
-            pipeline_schema = self.schemas[source.schema.name]
-            pipeline_schema = pipeline_schema.clone()  # use clone until extraction complete
-            # apply all changes in the source schema to pipeline schema
-            # NOTE: we do not apply contracts to changes done programmatically
-            pipeline_schema.update_schema(source.schema)
-            # replace schema in the source
-            source.schema = pipeline_schema
+
+            # keep schema created by the source so we can apply changes from it later
+            source_schema = source.schema
+            # use existing pipeline schema as the source schema, clone until extraction complete
+            source.schema = self.schemas[source.schema.name].clone()
+            # refresh the pipeline schema, i.e. drop certain tables before any normalizer changes are applied
+            if refresh:
+                # NOTE: we use original pipeline schema to detect dropped/truncated tables so we can drop
+                # the original names, before an eventual new naming convention is applied
+                load_package_state_update.update(deepcopy(refresh_source(self, source, refresh)))
+                if refresh == "drop_sources":
+                    # replace the whole source AFTER we got tables to drop
+                    source.schema = source_schema
+            # NOTE: we do pass programmatic changes from source schema to pipeline schema except settings below
+            # TODO: enable when we have full identifier lineage and we are able to merge table identifiers
+            if type(source.schema.naming) is not type(source_schema.naming):  # noqa
+                source.schema_contract = source_schema.settings.get("schema_contract")
+            else:
+                source.schema.update_schema(source_schema)
        except FileNotFoundError:
-            pass
+            if refresh is not None:
+                logger.info(
+                    f"Refresh flag {refresh} has no effect on source {source.name} because the"
+                    " source is extracted for the first time"
+                )

-        load_package_state_update = dict(load_package_state_update or {})
-        if refresh:
-            load_package_state_update.update(refresh_source(self, source, refresh))
+        # update the normalizers to detect any conflicts early
+        source.schema.update_normalizers()

        # extract into pipeline schema
        load_id = extract.extract(
@@ -1335,9 +1350,9 @@ def _set_destinations(
    def _maybe_destination_capabilities(
        self,
    ) -> Iterator[DestinationCapabilitiesContext]:
+        caps: DestinationCapabilitiesContext = None
+        injected_caps: ContextManager[DestinationCapabilitiesContext] = None
        try:
-            caps: DestinationCapabilitiesContext = None
-            injected_caps: ContextManager[DestinationCapabilitiesContext] = None
            if self.destination:
                destination_caps = self._get_destination_capabilities()
                stage_caps = self._get_staging_capabilities()
@@ -1504,11 +1519,15 @@ def _get_schemas_from_destination(
    @contextmanager
    def managed_state(self, *, extract_state: bool = False) -> Iterator[TPipelineState]:
-        # load or restore state
+        """Puts the pipeline state in managed mode, where yielded state changes will be persisted or fully rolled back on exception.
+ + Makes the state to be available via StateInjectableContext + """ state = self._get_state() - # TODO: we should backup schemas here try: - yield state + # add the state to container as a context + with self._container.injectable_context(StateInjectableContext(state=state)): + yield state except Exception: backup_state = self._get_state() # restore original pipeline props @@ -1576,7 +1595,7 @@ def _save_and_extract_state_and_schema( self, state: TPipelineState, schema: Schema, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, ) -> None: """Save given state + schema and extract creating a new load package @@ -1601,7 +1620,7 @@ def _bump_version_and_extract_state( state: TPipelineState, extract_state: bool, extract: Extract = None, - load_package_state_update: Optional[Dict[str, Any]] = None, + load_package_state_update: Optional[TLoadPackageState] = None, schema: Optional[Schema] = None, ) -> None: """Merges existing state into `state` and extracts state using `storage` if extract_state is True. diff --git a/dlt/sources/helpers/requests/retry.py b/dlt/sources/helpers/requests/retry.py index 3f9d7d559e..7d7d6493ec 100644 --- a/dlt/sources/helpers/requests/retry.py +++ b/dlt/sources/helpers/requests/retry.py @@ -119,8 +119,8 @@ def _make_retry( retry_conds = [retry_if_status(status_codes), retry_if_exception_type(tuple(exceptions))] if condition is not None: if callable(condition): - retry_condition = [condition] - retry_conds.extend([retry_if_predicate(c) for c in retry_condition]) + condition = [condition] + retry_conds.extend([retry_if_predicate(c) for c in condition]) wait_cls = wait_exponential_retry_after if respect_retry_after_header else wait_exponential return Retrying( diff --git a/poetry.lock b/poetry.lock index 323b2188d3..a7d754f5a8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2813,6 +2813,21 @@ files = [ [package.dependencies] flake8 = ">=3.8.4" +[[package]] +name = "flake8-print" +version = "5.0.0" +description = "print statement checker plugin for flake8" +optional = false +python-versions = ">=3.7" +files = [ + {file = "flake8-print-5.0.0.tar.gz", hash = "sha256:76915a2a389cc1c0879636c219eb909c38501d3a43cc8dae542081c9ba48bdf9"}, + {file = "flake8_print-5.0.0-py3-none-any.whl", hash = "sha256:84a1a6ea10d7056b804221ac5e62b1cee1aefc897ce16f2e5c42d3046068f5d8"}, +] + +[package.dependencies] +flake8 = ">=3.0" +pycodestyle = "*" + [[package]] name = "flake8-tidy-imports" version = "4.10.0" @@ -9643,4 +9658,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "bb75ee485742aa176ad726fd468832642096145fff0543472b998e04b8b053d0" +content-hash = "1205791c3a090cf55617833ef566f1d55e6fcfa7209079bca92277f217130549" diff --git a/pyproject.toml b/pyproject.toml index 099850b6bf..6f21d17be7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -154,6 +154,7 @@ ruff = "^0.3.2" pyjwt = "^2.8.0" pytest-mock = "^3.14.0" types-regex = "^2024.5.15.20240519" +flake8-print = "^5.0.0" [tool.poetry.group.pipeline] optional = true diff --git a/tests/cli/common/test_cli_invoke.py b/tests/cli/common/test_cli_invoke.py index d367a97261..f856162479 100644 --- a/tests/cli/common/test_cli_invoke.py +++ b/tests/cli/common/test_cli_invoke.py @@ -6,6 +6,7 @@ from unittest.mock import patch import dlt +from dlt.common.known_env import DLT_DATA_DIR from dlt.common.configuration.paths import get_dlt_data_dir from dlt.common.runners.venv import Venv from dlt.common.utils 
import custom_environ, set_working_dir @@ -62,7 +63,7 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: shutil.copytree("tests/cli/cases/deploy_pipeline", TEST_STORAGE_ROOT, dirs_exist_ok=True) with set_working_dir(TEST_STORAGE_ROOT): - with custom_environ({"COMPETED_PROB": "1.0", "DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({"COMPETED_PROB": "1.0", DLT_DATA_DIR: get_dlt_data_dir()}): venv = Venv.restore_current() venv.run_script("dummy_pipeline.py") # we check output test_pipeline_command else @@ -96,7 +97,7 @@ def test_invoke_pipeline(script_runner: ScriptRunner) -> None: def test_invoke_init_chess_and_template(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): result = script_runner.run(["dlt", "init", "chess", "dummy"]) assert "Verified source chess was added to your project!" in result.stdout assert result.returncode == 0 @@ -116,7 +117,7 @@ def test_invoke_list_verified_sources(script_runner: ScriptRunner) -> None: def test_invoke_deploy_project(script_runner: ScriptRunner) -> None: with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): result = script_runner.run( ["dlt", "deploy", "debug_pipeline.py", "github-action", "--schedule", "@daily"] ) diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index ccc73a30c0..5271c68633 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -10,6 +10,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.inject import with_config from dlt.common.configuration.exceptions import LookupTrace +from dlt.common.known_env import DLT_DATA_DIR, DLT_PROJECT_DIR from dlt.common.configuration.providers.toml import ( SECRETS_TOML, CONFIG_TOML, @@ -257,8 +258,8 @@ def test_toml_global_config() -> None: assert config._add_global_config is False # type: ignore[attr-defined] # set dlt data and settings dir - os.environ["DLT_DATA_DIR"] = "./tests/common/cases/configuration/dlt_home" - os.environ["DLT_PROJECT_DIR"] = "./tests/common/cases/configuration/" + os.environ[DLT_DATA_DIR] = "./tests/common/cases/configuration/dlt_home" + os.environ[DLT_PROJECT_DIR] = "./tests/common/cases/configuration/" # create instance with global toml enabled config = ConfigTomlProvider(add_global_config=True) assert config._add_global_config is True diff --git a/tests/common/schema/test_normalize_identifiers.py b/tests/common/schema/test_normalize_identifiers.py index 60f8c04604..646a693ea6 100644 --- a/tests/common/schema/test_normalize_identifiers.py +++ b/tests/common/schema/test_normalize_identifiers.py @@ -352,7 +352,9 @@ def test_raise_on_change_identifier_table_with_data() -> None: os.environ["SCHEMA__NAMING"] = "tests.common.cases.normalizers.sql_upper" with pytest.raises(TableIdentifiersFrozen) as fr_ex: schema.update_normalizers() - assert fr_ex.value.table_name == "issues" + # _dlt_version is the first table to be normalized, and since there are tables + # that have seen data, we consider _dlt_version also be materialized + assert fr_ex.value.table_name == "_dlt_version" assert isinstance(fr_ex.value.from_naming, 
snake_case.NamingConvention) assert isinstance(fr_ex.value.to_naming, sql_upper.NamingConvention) # try again, get exception (schema was not partially modified) diff --git a/tests/common/test_json.py b/tests/common/test_json.py index 79037ebf93..b7d25589a7 100644 --- a/tests/common/test_json.py +++ b/tests/common/test_json.py @@ -6,6 +6,7 @@ from dlt.common import json, Decimal, pendulum from dlt.common.arithmetics import numeric_default_context +from dlt.common import known_env from dlt.common.json import ( _DECIMAL, _WEI, @@ -306,7 +307,7 @@ def test_garbage_pua_string(json_impl: SupportsJson) -> None: def test_change_pua_start() -> None: import inspect - os.environ["DLT_JSON_TYPED_PUA_START"] = "0x0FA179" + os.environ[known_env.DLT_JSON_TYPED_PUA_START] = "0x0FA179" from importlib import reload try: @@ -316,7 +317,7 @@ def test_change_pua_start() -> None: assert MOD_PUA_START == int("0x0FA179", 16) finally: # restore old start - os.environ["DLT_JSON_TYPED_PUA_START"] = hex(PUA_START) + os.environ[known_env.DLT_JSON_TYPED_PUA_START] = hex(PUA_START) from importlib import reload reload(inspect.getmodule(SupportsJson)) diff --git a/tests/extract/test_sources.py b/tests/extract/test_sources.py index 8287da69d4..a170c6977d 100644 --- a/tests/extract/test_sources.py +++ b/tests/extract/test_sources.py @@ -39,6 +39,39 @@ def switch_to_fifo(): del os.environ["EXTRACT__NEXT_ITEM_MODE"] +def test_basic_source() -> None: + def basic_gen(): + yield 1 + + schema = Schema("test") + s = DltSource.from_data(schema, "section", basic_gen) + assert s.name == "test" + assert s.section == "section" + assert s.max_table_nesting is None + assert s.root_key is False + assert s.schema_contract is None + assert s.exhausted is False + assert s.schema is schema + assert len(s.resources) == 1 + assert s.resources == s.selected_resources + + # set some props + s.max_table_nesting = 10 + assert s.max_table_nesting == 10 + s.root_key = True + assert s.root_key is True + s.schema_contract = "evolve" + assert s.schema_contract == "evolve" + + s.max_table_nesting = None + s.root_key = False + s.schema_contract = None + + assert s.max_table_nesting is None + assert s.root_key is False + assert s.schema_contract is None + + def test_call_data_resource() -> None: with pytest.raises(TypeError): DltResource.from_data([1], name="t")() diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py index a7b1371f9f..3cad7dda2c 100644 --- a/tests/load/filesystem/test_filesystem_common.py +++ b/tests/load/filesystem/test_filesystem_common.py @@ -1,9 +1,10 @@ import os import posixpath -from typing import Union, Dict +from typing import Tuple, Union, Dict from urllib.parse import urlparse +from fsspec import AbstractFileSystem import pytest from tenacity import retry, stop_after_attempt, wait_fixed @@ -21,7 +22,7 @@ FilesystemDestinationClientConfiguration, ) from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders -from tests.common.storages.utils import assert_sample_files +from tests.common.storages.utils import TEST_SAMPLE_FILES, assert_sample_files from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AWS_BUCKET from tests.utils import autouse_test_storage from .utils import self_signed_cert @@ -98,27 +99,26 @@ def check_file_changed(file_url_: str): @pytest.mark.parametrize("load_content", (True, False)) @pytest.mark.parametrize("glob_filter", ("**", "**/*.csv", "*.txt", "met_csv/A803/*.csv")) -def test_filesystem_dict( - with_gdrive_buckets_env: str, load_content: 
bool, glob_filter: str
-) -> None:
+def test_glob_files(with_gdrive_buckets_env: str, load_content: bool, glob_filter: str) -> None:
    bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"]
-    config = get_config()
-    # enable caches
-    config.read_only = True
-    if config.protocol in ["memory", "file"]:
-        pytest.skip(f"{config.protocol} not supported in this test")
-    glob_folder = "standard_source/samples"
-    # may contain query string
-    bucket_url_parsed = urlparse(bucket_url)
-    bucket_url = bucket_url_parsed._replace(
-        path=posixpath.join(bucket_url_parsed.path, glob_folder)
-    ).geturl()
-    filesystem, _ = fsspec_from_config(config)
+    bucket_url, config, filesystem = glob_test_setup(bucket_url, "standard_source/samples")
    # use glob to get data
    all_file_items = list(glob_files(filesystem, bucket_url, glob_filter))
+    # assert len(all_file_items) == 0
    assert_sample_files(all_file_items, filesystem, config, load_content, glob_filter)
+
+def test_glob_overlapping_path_files(with_gdrive_buckets_env: str) -> None:
+    bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"]
+    # "standard_source/sample" overlaps with a real existing "standard_source/samples". A walk operation on azure
+    # will return all files from "standard_source/samples" and report the wrong "standard_source/sample" path to the user.
+    # here we test that we do not have this problem with our glob
+    bucket_url, _, filesystem = glob_test_setup(bucket_url, "standard_source/sample")
+    # use glob to get data
+    all_file_items = list(glob_files(filesystem, bucket_url))
+    assert len(all_file_items) == 0
+
+
 @pytest.mark.skipif("s3" not in ALL_FILESYSTEM_DRIVERS, reason="s3 destination not configured")
 def test_filesystem_instance_from_s3_endpoint(environment: Dict[str, str]) -> None:
     """Test that fsspec instance is correctly configured when using endpoint URL.
@@ -264,3 +264,26 @@ def test_filesystem_destination_passed_parameters_override_config_values() -> No bound_config = filesystem_destination.configuration(filesystem_config) assert bound_config.current_datetime == config_now assert bound_config.extra_placeholders == config_extra_placeholders + + +def glob_test_setup( + bucket_url: str, glob_folder: str +) -> Tuple[str, FilesystemConfiguration, AbstractFileSystem]: + config = get_config() + # enable caches + config.read_only = True + if config.protocol in ["file"]: + pytest.skip(f"{config.protocol} not supported in this test") + + # may contain query string + bucket_url_parsed = urlparse(bucket_url) + bucket_url = bucket_url_parsed._replace( + path=posixpath.join(bucket_url_parsed.path, glob_folder) + ).geturl() + filesystem, _ = fsspec_from_config(config) + if config.protocol == "memory": + mem_path = os.path.join("m", "standard_source") + if not filesystem.isdir(mem_path): + filesystem.mkdirs(mem_path) + filesystem.upload(TEST_SAMPLE_FILES, mem_path, recursive=True) + return bucket_url, config, filesystem diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index a89153f629..e817a2f6c8 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -374,7 +374,7 @@ def some_data() -> Generator[List[DictStrAny], Any, None]: def test_merge_github_nested() -> None: - pipe = dlt.pipeline(destination="lancedb", dataset_name="github1", full_refresh=True) + pipe = dlt.pipeline(destination="lancedb", dataset_name="github1", dev_mode=True) assert pipe.dataset_name.startswith("github1_202") with open( @@ -422,7 +422,7 @@ def test_merge_github_nested() -> None: def test_empty_dataset_allowed() -> None: # dataset_name is optional so dataset name won't be autogenerated when not explicitly passed. 
- pipe = dlt.pipeline(destination="lancedb", full_refresh=True) + pipe = dlt.pipeline(destination="lancedb", dev_mode=True) client: LanceDBClient = pipe.destination_client() # type: ignore[assignment] assert pipe.dataset_name is None diff --git a/tests/load/pipeline/test_bigquery.py b/tests/load/pipeline/test_bigquery.py index 0618ff9d3d..f4fdef8665 100644 --- a/tests/load/pipeline/test_bigquery.py +++ b/tests/load/pipeline/test_bigquery.py @@ -15,7 +15,7 @@ ids=lambda x: x.name, ) def test_bigquery_numeric_types(destination_config: DestinationTestConfiguration) -> None: - pipeline = destination_config.setup_pipeline("test_bigquery_numeric_types") + pipeline = destination_config.setup_pipeline("test_bigquery_numeric_types", dev_mode=True) columns = [ {"name": "col_big_numeric", "data_type": "decimal", "precision": 47, "scale": 9}, diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 210ad76b8a..3f0352cab7 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -601,9 +601,11 @@ def _collect_files(p) -> List[str]: found.append(os.path.join(basedir, file).replace(client.dataset_path, "")) return found - def _collect_table_counts(p) -> Dict[str, int]: + def _collect_table_counts(p, *items: str) -> Dict[str, int]: + expected_items = set(items).intersection({"items", "items2", "items3"}) + print(expected_items) return load_table_counts( - p, "items", "items2", "items3", "_dlt_loads", "_dlt_version", "_dlt_pipeline_state" + p, *expected_items, "_dlt_loads", "_dlt_version", "_dlt_pipeline_state" ) # generate 4 loads from 2 pipelines, store load ids @@ -616,7 +618,7 @@ def _collect_table_counts(p) -> Dict[str, int]: # first two loads p1.run([1, 2, 3], table_name="items").loads_ids[0] load_id_2_1 = p2.run([4, 5, 6], table_name="items").loads_ids[0] - assert _collect_table_counts(p1) == { + assert _collect_table_counts(p1, "items") == { "items": 6, "_dlt_loads": 2, "_dlt_pipeline_state": 2, @@ -643,7 +645,7 @@ def some_data(): p2.run([4, 5, 6], table_name="items").loads_ids[0] # no migration here # 4 loads for 2 pipelines, one schema and state change on p2 changes so 3 versions and 3 states - assert _collect_table_counts(p1) == { + assert _collect_table_counts(p1, "items", "items2") == { "items": 9, "items2": 3, "_dlt_loads": 4, diff --git a/tests/load/pipeline/test_refresh_modes.py b/tests/load/pipeline/test_refresh_modes.py index de557ba118..f4bf3b0311 100644 --- a/tests/load/pipeline/test_refresh_modes.py +++ b/tests/load/pipeline/test_refresh_modes.py @@ -2,21 +2,30 @@ import pytest import dlt +from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.common.pipeline import resource_state -from dlt.destinations.sql_client import DBApiCursor -from dlt.pipeline.state_sync import load_pipeline_state_from_destination +from dlt.common.utils import uniq_id from dlt.common.typing import DictStrAny from dlt.common.pipeline import pipeline_state as current_pipeline_state +from dlt.destinations.sql_client import DBApiCursor +from dlt.extract.source import DltSource +from dlt.pipeline.state_sync import load_pipeline_state_from_destination + from tests.utils import clean_test_storage from tests.pipeline.utils import ( + _is_filesystem, assert_load_info, + load_table_counts, load_tables_to_dicts, assert_only_table_columns, table_exists, ) from tests.load.utils import destinations_configs, DestinationTestConfiguration +# mark all tests as essential, do not 
remove +pytestmark = pytest.mark.essential + def assert_source_state_is_wiped(state: DictStrAny) -> None: # Keys contains only "resources" or is empty @@ -66,7 +75,7 @@ def some_data_2(): yield {"id": 7} yield {"id": 8} - @dlt.resource + @dlt.resource(primary_key="id", write_disposition="merge") def some_data_3(): if first_run: dlt.state()["source_key_3"] = "source_value_3" @@ -103,7 +112,6 @@ def test_refresh_drop_sources(destination_config: DestinationTestConfiguration): # First run pipeline so destination so tables are created info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) assert_load_info(info) - # Second run of pipeline with only selected resources info = pipeline.run( refresh_source(first_run=False, drop_sources=True).with_resources( @@ -114,8 +122,6 @@ def test_refresh_drop_sources(destination_config: DestinationTestConfiguration): assert set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) == { "some_data_1", "some_data_2", - # Table has never seen data and is not dropped - "some_data_4", } # No "name" column should exist as table was dropped and re-created without it @@ -163,7 +169,7 @@ def test_existing_schema_hash(destination_config: DestinationTestConfiguration): new_table_names = set( t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True) ) - assert new_table_names == {"some_data_1", "some_data_2", "some_data_4"} + assert new_table_names == {"some_data_1", "some_data_2"} # Run again with all tables to ensure they are re-created # The new schema in this case should match the schema of the first run exactly @@ -430,10 +436,76 @@ def test_refresh_argument_to_extract(destination_config: DestinationTestConfigur tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) # All other data tables removed - assert tables == {"some_data_3", "some_data_4"} + assert tables == {"some_data_3"} # Run again without refresh to confirm refresh option doesn't persist on pipeline pipeline.extract(refresh_source(first_run=False).with_resources("some_data_2")) tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) - assert tables == {"some_data_2", "some_data_3", "some_data_4"} + assert tables == {"some_data_2", "some_data_3"} + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, default_staging_configs=True, all_buckets_filesystem_configs=True + ), + ids=lambda x: x.name, +) +def test_refresh_staging_dataset(destination_config: DestinationTestConfiguration): + data = [ + {"id": 1, "pop": 1}, + {"id": 2, "pop": 3}, + {"id": 2, "pop": 4}, # duplicate + ] + + pipeline = destination_config.setup_pipeline("test_refresh_staging_dataset" + uniq_id()) + + source = DltSource( + dlt.Schema("data_x"), + "data_section", + [ + dlt.resource(data, name="data_1", primary_key="id", write_disposition="merge"), + dlt.resource(data, name="data_2", primary_key="id", write_disposition="append"), + ], + ) + # create two tables so two tables need to be dropped + info = pipeline.run(source) + assert_load_info(info) + + # make data so inserting on mangled tables is not possible + data_i = [ + {"id": "A", "pop": 0.1}, + {"id": "B", "pop": 0.3}, + {"id": "A", "pop": 0.4}, + ] + source_i = DltSource( + dlt.Schema("data_x"), + "data_section", + [ + dlt.resource(data_i, name="data_1", primary_key="id", write_disposition="merge"), + dlt.resource(data_i, name="data_2", primary_key="id", write_disposition="append"), + ], + ) + 
info = pipeline.run(source_i, refresh="drop_resources") + assert_load_info(info) + + # now replace the whole source and load different tables + source_i = DltSource( + dlt.Schema("data_x"), + "data_section", + [ + dlt.resource(data_i, name="data_1_v2", primary_key="id", write_disposition="merge"), + dlt.resource(data_i, name="data_2_v2", primary_key="id", write_disposition="append"), + ], + ) + info = pipeline.run(source_i, refresh="drop_sources") + assert_load_info(info) + + # tables got dropped + if _is_filesystem(pipeline): + assert load_table_counts(pipeline, "data_1", "data_2") == {} + else: + with pytest.raises(DestinationUndefinedEntity): + load_table_counts(pipeline, "data_1", "data_2") + load_table_counts(pipeline, "data_1_v2", "data_1_v2") diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 35b988d46e..69f6bd4cc4 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -338,10 +338,11 @@ def test_drop_tables(client: SqlJobClientBase) -> None: # Drop tables from the first schema client.schema = schema tables_to_drop = ["event_slot", "event_user"] - for tbl in tables_to_drop: - del schema.tables[tbl] + schema.drop_tables(tables_to_drop) schema._bump_version() - client.drop_tables(*tables_to_drop) + + # add one fake table to make sure one table can be ignored + client.drop_tables(tables_to_drop[0], "not_exists", *tables_to_drop[1:]) client._update_schema_in_storage(schema) # Schema was deleted, load it in again if isinstance(client, WithStagingDataset): with contextlib.suppress(DatabaseUndefinedRelation): diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py index fa31f1db65..8d4e146034 100644 --- a/tests/load/test_sql_client.py +++ b/tests/load/test_sql_client.py @@ -112,6 +112,28 @@ def test_malformed_query_parameters(client: SqlJobClientBase) -> None: assert client.sql_client.is_dbapi_exception(term_ex.value.dbapi_exception) +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) +def test_has_dataset(client: SqlJobClientBase) -> None: + with client.sql_client.with_alternative_dataset_name("not_existing"): + assert not client.sql_client.has_dataset() + client.update_stored_schema() + assert client.sql_client.has_dataset() + + +@pytest.mark.parametrize( + "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name +) +def test_create_drop_dataset(client: SqlJobClientBase) -> None: + # client.sql_client.create_dataset() + with pytest.raises(DatabaseException): + client.sql_client.create_dataset() + client.sql_client.drop_dataset() + with pytest.raises(DatabaseUndefinedRelation): + client.sql_client.drop_dataset() + + @pytest.mark.parametrize( "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name ) diff --git a/tests/load/utils.py b/tests/load/utils.py index 9ee933a07a..95083b7d31 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -718,7 +718,8 @@ def yield_client_with_storage( ) as client: client.initialize_storage() yield client - client.sql_client.drop_dataset() + if client.is_storage_initialized(): + client.sql_client.drop_dataset() if isinstance(client, WithStagingDataset): with client.with_staging_dataset(): if client.is_storage_initialized(): diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index ba7c0b9db8..979bdd0e37 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ 
-3,10 +3,12 @@ import pytest import tempfile import shutil +from unittest.mock import patch from importlib.metadata import version as pkg_version import dlt from dlt.common import json, pendulum +from dlt.common.known_env import DLT_DATA_DIR from dlt.common.json import custom_pua_decode from dlt.common.runners import Venv from dlt.common.storages.exceptions import StorageMigrationError @@ -24,9 +26,49 @@ from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient -from tests.pipeline.utils import load_table_counts +from tests.pipeline.utils import airtable_emojis, load_table_counts from tests.utils import TEST_STORAGE_ROOT, test_storage + +def test_simulate_default_naming_convention_change() -> None: + # checks that (future) change in the naming convention won't affect existing pipelines + pipeline = dlt.pipeline("simulated_snake_case", destination="duckdb") + assert pipeline.naming.name() == "snake_case" + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + info.raise_on_failed_jobs() + # normalized names + assert pipeline.last_trace.last_normalize_info.row_counts["_schedule"] == 3 + assert "_schedule" in pipeline.default_schema.tables + + # mock the mod + # from dlt.common.normalizers import utils + + with patch("dlt.common.normalizers.utils.DEFAULT_NAMING_MODULE", "duck_case"): + duck_pipeline = dlt.pipeline("simulated_duck_case", destination="duckdb") + assert duck_pipeline.naming.name() == "duck_case" + print(airtable_emojis().schema.naming.name()) + + # run new and old pipelines + info = duck_pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + info.raise_on_failed_jobs() + print(duck_pipeline.last_trace.last_normalize_info.row_counts) + assert duck_pipeline.last_trace.last_normalize_info.row_counts["📆 Schedule"] == 3 + assert "📆 Schedule" in duck_pipeline.default_schema.tables + + # old pipeline should keep its naming convention + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + info.raise_on_failed_jobs() + # normalized names + assert pipeline.last_trace.last_normalize_info.row_counts["_schedule"] == 3 + assert pipeline.naming.name() == "snake_case" + + if sys.version_info >= (3, 12): pytest.skip("Does not run on Python 3.12 and later", allow_module_level=True) @@ -41,7 +83,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} @@ -175,7 +217,7 @@ def test_filesystem_pipeline_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): # create virtual env with (0.4.9) where filesystem started to store state with Venv.create(tempfile.mkdtemp(), ["dlt==0.4.9"]) as venv: try: @@ -247,7 +289,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): 
# store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} @@ -322,7 +364,7 @@ def test_normalize_package_with_dlt_update(test_storage: FileStorage) -> None: # execute in test storage with set_working_dir(TEST_STORAGE_ROOT): # store dlt data in test storage (like patch_home_dir) - with custom_environ({"DLT_DATA_DIR": get_dlt_data_dir()}): + with custom_environ({DLT_DATA_DIR: get_dlt_data_dir()}): # save database outside of pipeline dir with custom_environ( {"DESTINATION__DUCKDB__CREDENTIALS": "duckdb:///test_github_3.duckdb"} diff --git a/tests/pipeline/test_drop_helpers.py b/tests/pipeline/test_drop_helpers.py new file mode 100644 index 0000000000..9a09d9f866 --- /dev/null +++ b/tests/pipeline/test_drop_helpers.py @@ -0,0 +1,209 @@ +import pytest +from copy import deepcopy + +import dlt +from dlt.common.schema.typing import LOADS_TABLE_NAME, PIPELINE_STATE_TABLE_NAME, VERSION_TABLE_NAME +from dlt.common.versioned_state import decompress_state +from dlt.pipeline.drop import drop_resources +from dlt.pipeline.helpers import DropCommand, refresh_source + +from tests.pipeline.utils import airtable_emojis, assert_load_info + + +@pytest.mark.parametrize("seen_data", [True, False], ids=["seen_data", "no_data"]) +def test_drop_helper_utils(seen_data: bool) -> None: + pipeline = dlt.pipeline("test_drop_helpers_no_table_drop", destination="duckdb") + # extract first which should produce tables that didn't seen data + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + if seen_data: + pipeline.run(source) + else: + pipeline.extract(source) + + # drop nothing + drop_info = drop_resources(pipeline.default_schema.clone(), pipeline.state) + assert drop_info.modified_tables == [] + assert drop_info.info["tables"] == [] + + # drop all resources + drop_info = drop_resources(pipeline.default_schema.clone(), pipeline.state, drop_all=True) + # no tables to drop + tables_to_drop = ( + {"_schedule", "_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else set() + ) + tables_to_drop_schema = ( + tables_to_drop if seen_data else {"_schedule", "_peacock", "_wide_peacock"} + ) + assert {t["name"] for t in drop_info.modified_tables} == tables_to_drop + # no state mods + assert drop_info.state["sources"]["airtable_emojis"] == {"resources": {}} + assert set(drop_info.info["tables"]) == tables_to_drop_schema + assert set(drop_info.info["tables_with_data"]) == tables_to_drop + # all tables got dropped + assert drop_info.schema.data_tables(include_incomplete=True) == [] + # dlt tables still there + assert set(drop_info.schema.dlt_table_names()) == { + VERSION_TABLE_NAME, + LOADS_TABLE_NAME, + PIPELINE_STATE_TABLE_NAME, + } + # same but with refresh + source_clone = source.clone() + source_clone.schema = pipeline.default_schema.clone() + with pipeline.managed_state() as state: + emoji_state = deepcopy(state["sources"]["airtable_emojis"]) + package_state = refresh_source(pipeline, source_clone, refresh="drop_sources") + # managed state modified + assert state["sources"]["airtable_emojis"] == {"resources": {}} + # restore old state for next tests + state["sources"]["airtable_emojis"] = emoji_state + if seen_data: + assert {t["name"] for t in package_state["dropped_tables"]} == tables_to_drop 
+ else: + assert package_state == {} + assert source_clone.schema.data_tables(include_incomplete=True) == [] + + # drop only selected resources + tables_to_drop = {"_schedule"} if seen_data else set() + # seen_data means full run so we generate child tables in that case + left_in_schema = ( + {"_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else {"_peacock", "_wide_peacock"} + ) + drop_info = drop_resources( + pipeline.default_schema.clone(), pipeline.state, resources=["📆 Schedule"] + ) + assert set(t["name"] for t in drop_info.modified_tables) == tables_to_drop + # no changes in state + assert drop_info.state == pipeline.state + assert set(drop_info.info["tables"]) == {"_schedule"} + assert set(drop_info.schema.data_table_names(include_incomplete=True)) == left_in_schema + source_clone = source_clone.with_resources("📆 Schedule") + source_clone.schema = pipeline.default_schema.clone() + with pipeline.managed_state() as state: + package_state = refresh_source(pipeline, source_clone, refresh="drop_resources") + # state not modified + assert state["sources"]["airtable_emojis"] == {"resources": {"🦚Peacock": {"🦚🦚🦚": "🦚"}}} + if seen_data: + assert {t["name"] for t in package_state["dropped_tables"]} == tables_to_drop + else: + assert package_state == {} + assert set(source_clone.schema.data_table_names(include_incomplete=True)) == left_in_schema + + # truncate only + tables_to_truncate = ( + {"_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else set() + ) + all_in_schema = ( + {"_schedule", "_peacock", "_wide_peacock", "_peacock__peacock", "_wide_peacock__peacock"} + if seen_data + else {"_schedule", "_peacock", "_wide_peacock"} + ) + drop_info = drop_resources( + pipeline.default_schema.clone(), + pipeline.state, + resources=["🦚Peacock", "🦚WidePeacock"], + state_only=True, + ) + assert set(t["name"] for t in drop_info.modified_tables) == tables_to_truncate + # state is modified + assert drop_info.state["sources"]["airtable_emojis"] == {"resources": {}} + assert drop_info.info["tables"] == [] + # no tables with data will be dropped + assert drop_info.info["tables_with_data"] == [] + assert set(drop_info.schema.data_table_names(include_incomplete=True)) == all_in_schema + source_clone = source_clone.with_resources("🦚Peacock", "🦚WidePeacock") + source_clone.schema = pipeline.default_schema.clone() + with pipeline.managed_state() as state: + package_state = refresh_source(pipeline, source_clone, refresh="drop_data") + # state modified + assert state["sources"]["airtable_emojis"] == {"resources": {}} + if seen_data: + assert {t["name"] for t in package_state["truncated_tables"]} == tables_to_truncate + else: + assert package_state == {} + assert set(source_clone.schema.data_table_names(include_incomplete=True)) == all_in_schema + + +def test_drop_unknown_resource() -> None: + pipeline = dlt.pipeline("test_drop_unknown_resource", destination="duckdb") + # extract first which should produce tables that didn't seen data + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + info = pipeline.run(source) + assert_load_info(info) + drop = DropCommand(pipeline, resources=["💰Budget"]) + assert drop.is_empty + + source.schema = pipeline.default_schema + package_state = refresh_source( + pipeline, source.with_resources("💰Budget"), refresh="drop_resources" + ) + assert package_state == {} + + info = pipeline.run(source.with_resources("💰Budget"), refresh="drop_resources") + 
# nothing loaded + assert_load_info(info, 0) + + +def test_modified_state_in_package() -> None: + pipeline = dlt.pipeline("test_modified_state_in_package", destination="duckdb") + # extract first which should produce tables that didn't seen data + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + pipeline.extract(source) + # run again to change peacock state again + info = pipeline.extract(source) + normalize_storage = pipeline._get_normalize_storage() + package_state = normalize_storage.extracted_packages.get_load_package_state(info.loads_ids[0]) + pipeline_state = decompress_state(package_state["pipeline_state"]["state"]) + assert pipeline_state["sources"]["airtable_emojis"] == { + "resources": {"🦚Peacock": {"🦚🦚🦚": "🦚🦚"}} + } + + # remove state + info = pipeline.extract(airtable_emojis().with_resources("🦚Peacock"), refresh="drop_resources") + normalize_storage = pipeline._get_normalize_storage() + package_state = normalize_storage.extracted_packages.get_load_package_state(info.loads_ids[0]) + # nothing to drop + assert "dropped_tables" not in package_state + pipeline_state = decompress_state(package_state["pipeline_state"]["state"]) + # the state was reset to the original + assert pipeline_state["sources"]["airtable_emojis"] == { + "resources": {"🦚Peacock": {"🦚🦚🦚": "🦚"}} + } + + +def test_drop_tables_force_extract_state() -> None: + # if any tables will be dropped, state must be extracted even if it is not changed + pipeline = dlt.pipeline("test_drop_tables_force_extract_state", destination="duckdb") + source = airtable_emojis().with_resources( + "📆 Schedule", "🦚Peacock", "🦚WidePeacock", "💰Budget" + ) + info = pipeline.run(source) + assert_load_info(info) + # dropping schedule should not change the state + info = pipeline.run(airtable_emojis().with_resources("📆 Schedule"), refresh="drop_resources") + assert_load_info(info) + storage = pipeline._get_load_storage() + package_state = storage.get_load_package_state(info.loads_ids[0]) + assert package_state["dropped_tables"][0]["name"] == "_schedule" + assert "pipeline_state" in package_state + + # here we drop and set state to original, so without forcing state extract state would not be present + info = pipeline.run(airtable_emojis().with_resources("🦚Peacock"), refresh="drop_resources") + assert_load_info(info) + storage = pipeline._get_load_storage() + package_state = storage.get_load_package_state(info.loads_ids[0]) + # child table also dropped + assert len(package_state["dropped_tables"]) == 2 + assert "pipeline_state" in package_state diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 6a6bf4bde1..328119970a 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -2271,21 +2271,57 @@ def test_change_naming_convention_name_collision() -> None: os.environ["SOURCES__AIRTABLE_EMOJIS__SCHEMA__NAMING"] = "sql_ci_v1" with pytest.raises(PipelineStepFailed) as pip_ex: pipeline.run(airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock")) + # see conflicts early + assert pip_ex.value.step == "extract" assert isinstance(pip_ex.value.__cause__, TableIdentifiersFrozen) # all good if we drop tables - # info = pipeline.run( - # airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), - # refresh="drop_resources", - # ) - # assert_load_info(info) - # assert load_data_table_counts(pipeline) == { - # "📆 Schedule": 3, - # "🦚Peacock": 1, - # "🦚WidePeacock": 1, - # "🦚Peacock__peacock": 3, - # 
"🦚WidePeacock__Peacock": 3, - # } + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + refresh="drop_resources", + ) + assert_load_info(info) + # case insensitive normalization + assert load_data_table_counts(pipeline) == { + "_schedule": 3, + "_peacock": 1, + "_widepeacock": 1, + "_peacock__peacock": 3, + "_widepeacock__peacock": 3, + } + + +def test_change_to_more_lax_naming_convention_name_collision() -> None: + # use snake_case which is strict and then change to duck_case which accepts snake_case names without any changes + # still we want to detect collisions + pipeline = dlt.pipeline( + "test_change_to_more_lax_naming_convention_name_collision", destination="duckdb" + ) + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock") + ) + assert_load_info(info) + assert "_peacock" in pipeline.default_schema.tables + + # use duck case to load data into duckdb so casing and emoji are preserved + duck_ = dlt.destinations.duckdb(naming_convention="duck_case") + + # changing destination to one with a separate naming convention raises immediately + with pytest.raises(TableIdentifiersFrozen): + pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + destination=duck_, + ) + + # refresh on the source level will work + info = pipeline.run( + airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), + destination=duck_, + refresh="drop_sources", + ) + assert_load_info(info) + # make sure that emojis got in + assert "🦚Peacock" in pipeline.default_schema.tables def test_change_naming_convention_column_collision() -> None: diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index c10618a7cc..f2e0058891 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -6,6 +6,7 @@ import dlt from dlt.common import json, sleep +from dlt.common.destination.exceptions import DestinationUndefinedEntity from dlt.common.pipeline import LoadInfo from dlt.common.schema.utils import get_table_format from dlt.common.typing import DictStrAny @@ -47,7 +48,9 @@ def budget(): @dlt.resource(name="🦚Peacock", selected=False, primary_key="🔑id") def peacock(): - dlt.current.resource_state()["🦚🦚🦚"] = "🦚" + r_state = dlt.current.resource_state() + r_state.setdefault("🦚🦚🦚", "") + r_state["🦚🦚🦚"] += "🦚" yield [{"peacock": [1, 2, 3], "🔑id": 1}] @dlt.resource(name="🦚WidePeacock", selected=False) diff --git a/tox.ini b/tox.ini index ed6c69c585..059f6a586a 100644 --- a/tox.ini +++ b/tox.ini @@ -7,3 +7,6 @@ banned-modules = datetime = use dlt.common.pendulum open = use dlt.common.open pendulum = use dlt.common.pendulum extend-immutable-calls = dlt.sources.incremental +per-file-ignores = + tests/*: T20 + docs/*: T20 \ No newline at end of file