Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allows to pass default values when writing specs #2018

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dlt/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
from dlt.cli.reference import SupportsCliCommand
from dlt.cli.exceptions import CliCommandException

__all__ = ["SupportsCliCommand", "CliCommandException"]
22 changes: 18 additions & 4 deletions dlt/cli/config_toml_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, NamedTuple, Tuple, Iterable
from typing import Any, NamedTuple, Tuple, Iterable, Mapping
import tomlkit
from tomlkit.items import Table as TOMLTable
from tomlkit.container import Container as TOMLContainer
Expand Down Expand Up @@ -72,7 +72,7 @@ def write_value(
hint = extract_inner_hint(hint)
if is_base_configuration_inner_hint(hint):
inner_table = tomlkit.table(is_super_table=True)
write_spec(inner_table, hint(), overwrite_existing)
write_spec(inner_table, hint(), default_value, overwrite_existing)
if len(inner_table) > 0:
toml_table[name] = inner_table
else:
Expand All @@ -86,17 +86,31 @@ def write_value(
toml_table[name] = default_value


def write_spec(
    toml_table: TOMLTable,
    config: BaseConfiguration,
    initial_value: Mapping[str, Any],
    overwrite_existing: bool,
) -> None:
    """Write every resolvable field of `config` into `toml_table`.

    Args:
        toml_table: Table that receives the values through `write_value`.
        config: Configuration instance; its `get_resolvable_fields()` are iterated and its
            attributes supply the stored defaults.
        initial_value: Optional mapping of field name to a value that takes precedence over
            the default stored on `config`. May be None or empty, in which case only the
            stored defaults are used.
        overwrite_existing: Passed through unchanged to `write_value`.
    """
    for name, hint in config.get_resolvable_fields().items():
        # use initial value
        initial_ = initial_value.get(name) if initial_value else None
        # use default value stored in config
        default_value = getattr(config, name, None)

        # check if field is of particular interest and should be included if it has default
        is_default_of_interest = name in config.__config_gen_annotations__

        # if initial is different from default, it is of interest as well
        if initial_ is not None:
            is_default_of_interest = is_default_of_interest or (initial_ != default_value)

        write_value(
            toml_table,
            name,
            hint,
            overwrite_existing,
            # select on `is not None` rather than with `or`: a falsy initial value
            # (False, 0, "") is still an explicit choice and must not be replaced
            # by the default stored in the config
            default_value=initial_ if initial_ is not None else default_value,
            is_default_of_interest=is_default_of_interest,
        )

Expand Down
4 changes: 4 additions & 0 deletions dlt/common/configuration/providers/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def __init__(
it will additionally look for `file_name` in `dlt` global dir (home dir by default) and merge the content.
The "settings" (`settings_dir`) values overwrite the "global" values.

If the toml file under `settings_dir` is not found, it will look in the Google Colab userdata object for a value
named `file_name` and load the toml file from it.
If that one is not found, it will try to load the Streamlit `secrets.toml` file.

If none of the files exist, an empty provider is created.

Args:
Expand Down
12 changes: 9 additions & 3 deletions dlt/extract/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,9 @@ def __call__(
"""Makes dlt source"""
pass

# TODO: make factory to expose SourceReference with actual spec, name and section
# model after Destination, which also needs to be broken down into reference and factory

def with_args(
self,
*,
Expand All @@ -511,14 +514,17 @@ def with_args(
"""Overrides default decorator arguments that will be used when creating the DltSource instance and returns a modified clone."""


AnySourceFactory = SourceFactory[Any, DltSource]


class SourceReference:
"""Runtime information on the source/resource"""

SOURCES: ClassVar[Dict[str, "SourceReference"]] = {}
"""A registry of all the decorated sources and resources discovered when importing modules"""

SPEC: Type[BaseConfiguration]
f: SourceFactory[Any, DltSource]
f: AnySourceFactory
module: ModuleType
section: str
name: str
Expand All @@ -527,7 +533,7 @@ class SourceReference:
def __init__(
self,
SPEC: Type[BaseConfiguration],
f: SourceFactory[Any, DltSource],
f: AnySourceFactory,
module: ModuleType,
section: str,
name: str,
Expand Down Expand Up @@ -582,7 +588,7 @@ def find(cls, ref: str) -> "SourceReference":
raise KeyError(refs)

@classmethod
def from_reference(cls, ref: str) -> SourceFactory[Any, DltSource]:
def from_reference(cls, ref: str) -> AnySourceFactory:
"""Returns registered source factory or imports source module and returns a function.
Expands shorthand notation into section.name eg. "sql_database" is expanded into "sql_database.sql_database"
"""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dlt"
version = "1.3.1a1"
version = "1.3.1a2"
description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run."
authors = ["dltHub Inc. <services@dlthub.com>"]
maintainers = [ "Marcin Rudolf <marcin@dlthub.com>", "Adrian Brudaru <adrian@dlthub.com>", "Anton Burnashev <anton@dlthub.com>", "David Scharf <david@dlthub.com>" ]
Expand Down
96 changes: 95 additions & 1 deletion tests/cli/test_config_toml_writer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import Optional, Final
from typing import ClassVar, List, Optional, Final
import pytest
import tomlkit

from dlt.cli.config_toml_writer import write_value, WritableConfigValue, write_values
from dlt.common.configuration.specs import configspec
from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT

EXAMPLE_COMMENT = "# please set me up!"

Expand Down Expand Up @@ -159,3 +161,95 @@ def test_write_values_without_defaults(example_toml):

assert example_toml["genomic_info"]["gene_data"]["genes"] == {"key": "value"}
assert example_toml["genomic_info"]["gene_data"]["genes"].trivia.comment == EXAMPLE_COMMENT


def test_write_spec_without_defaults(example_toml) -> None:
    """Check `write_value` output for configuration specs, without and with initial values."""
    from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration
    from dlt.destinations.impl.filesystem.configuration import (
        FilesystemDestinationClientConfiguration,
    )

    def _render(doc, section, spec, **extra):
        # every case below writes without overwriting and flags the spec as "of interest"
        write_value(doc, section, spec, False, is_default_of_interest=True, **extra)
        return doc

    # nothing of interest in "snowflake"
    # host, database, username are required and will be included
    # "password", "warehouse", "role" are explicitly of interest
    expected_snowflake = """[snowflake.credentials]
database = "database" # please set me up!
password = "password" # please set me up!
username = "username" # please set me up!
host = "host" # please set me up!
warehouse = "warehouse" # please set me up!
role = "role" # please set me up!
"""
    snowflake_doc = _render(example_toml, "snowflake", SnowflakeClientConfiguration)
    assert snowflake_doc.as_string() == expected_snowflake

    # bucket_url is mandatory, same for aws credentials
    expected_filesystem = """[filesystem]
bucket_url = "bucket_url" # please set me up!

[filesystem.credentials]
aws_access_key_id = "aws_access_key_id" # please set me up!
aws_secret_access_key = "aws_secret_access_key" # please set me up!
"""
    filesystem_doc = _render(
        tomlkit.parse(""), "filesystem", FilesystemDestinationClientConfiguration
    )
    assert filesystem_doc.as_string() == expected_filesystem

    @configspec
    class SnowflakeDatabaseConfiguration(SnowflakeClientConfiguration):
        database: str = "dlt_db"

        __config_gen_annotations__: ClassVar[List[str]] = ["database"]

    # uses default value
    database_doc = _render(tomlkit.parse(""), "snowflake", SnowflakeDatabaseConfiguration)
    assert database_doc["snowflake"]["database"] == "dlt_db"

    # use initial values
    initial_values = {
        "bucket_url": "az://test-az-bucket",
        "layout": DEFAULT_FILE_LAYOUT,
        "credentials": {"region_name": "eu"},
    }
    initial_doc = _render(
        tomlkit.parse(""),
        "filesystem",
        FilesystemDestinationClientConfiguration,
        default_value=initial_values,
    )
    filesystem_section = initial_doc["filesystem"]
    assert filesystem_section["bucket_url"] == "az://test-az-bucket"
    # TODO: choose right credentials based on bucket_url
    assert filesystem_section["credentials"]["aws_access_key_id"] == "aws_access_key_id"
    # if initial value is different from the default then it is included
    assert filesystem_section["credentials"]["region_name"] == "eu"
    # this is same as default so not included
    assert "layout" not in filesystem_section

    # still here because marked specifically as of interest
    same_as_default_doc = _render(
        tomlkit.parse(""),
        "snowflake",
        SnowflakeDatabaseConfiguration,
        default_value={"database": "dlt_db"},
    )
    assert same_as_default_doc["snowflake"]["database"] == "dlt_db"
4 changes: 2 additions & 2 deletions tests/helpers/dbt_tests/test_runner_dbt_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,10 @@ def test_infer_venv_deps() -> None:
# provide version ranges
requirements = _create_dbt_deps(["duckdb"], dbt_version=">3")
# special duckdb dependency
assert requirements[:-1] == ["dbt-core>3", "dbt-duckdb", "duckdb==1.1.0"]
assert requirements[:-1] == ["dbt-core>3", "dbt-duckdb", "duckdb==1.1.2"]
# we do not validate version ranges, pip will do it and fail when creating venv
requirements = _create_dbt_deps(["motherduck"], dbt_version="y")
assert requirements[:-1] == ["dbt-corey", "dbt-duckdb", "duckdb==1.1.0"]
assert requirements[:-1] == ["dbt-corey", "dbt-duckdb", "duckdb==1.1.2"]


def test_default_profile_name() -> None:
Expand Down
Loading