Skip to content

Commit

Permalink
Merge pull request dbt-labs#1
Browse files Browse the repository at this point in the history
Add dynamic pointers to Python modules
  • Loading branch information
Bilbottom authored Aug 1, 2023
2 parents 87ea28f + d99337b commit 38c9aff
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 3 deletions.
43 changes: 42 additions & 1 deletion core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
Project as ProjectContract,
SemverString,
)
from dbt.contracts.project import PackageConfig, ProjectPackageMetadata
from dbt.contracts.project import PackageConfig, ModuleConfig, ProjectPackageMetadata
from dbt.dataclass_schema import ValidationError
from .renderer import DbtProjectYamlRenderer, PackageRenderer
from .selectors import (
Expand Down Expand Up @@ -75,6 +75,14 @@
{error}
"""

MALFORMED_MODULE_ERROR = """\
The modules.yml file in this project is malformed. Please double check
the contents of this file and fix any errors before retrying.
Validator Error:
{error}
"""

MISSING_DBT_PROJECT_ERROR = """\
No dbt_project.yml found at expected path {path}
Verify that each entry within packages.yml (and their transitive dependencies) contains a file named dbt_project.yml
Expand Down Expand Up @@ -103,6 +111,16 @@ def package_data_from_root(project_root):
return packages_dict


def module_data_from_root(project_root):
module_filepath = resolve_path_from_base("modules.yml", project_root)

if path_exists(module_filepath):
modules_dict = _load_yaml(module_filepath)
else:
modules_dict = None
return modules_dict


def package_config_from_data(packages_data: Dict[str, Any]):
if not packages_data:
packages_data = {"packages": []}
Expand All @@ -115,6 +133,18 @@ def package_config_from_data(packages_data: Dict[str, Any]):
return packages


def module_config_from_data(modules_data: Dict[str, Any]):
if not modules_data:
modules_data = {"modules": []}

try:
ModuleConfig.validate(modules_data)
modules = ModuleConfig.from_dict(modules_data)
except ValidationError as e:
raise DbtProjectError(MALFORMED_MODULE_ERROR.format(error=str(e.message))) from e
return modules


def _parse_versions(versions: Union[List[str], str]) -> List[VersionSpecifier]:
"""Parse multiple versions as read from disk. The versions value may be any
one of:
Expand Down Expand Up @@ -239,6 +269,7 @@ def _get_required_version(
class RenderComponents:
project_dict: Dict[str, Any] = field(metadata=dict(description="The project dictionary"))
packages_dict: Dict[str, Any] = field(metadata=dict(description="The packages dictionary"))
modules_dict: Dict[str, Any] = field(metadata=dict(description="The modules dictionary"))
selectors_dict: Dict[str, Any] = field(metadata=dict(description="The selectors dictionary"))


Expand Down Expand Up @@ -273,11 +304,13 @@ def get_rendered(

rendered_project = renderer.render_project(self.project_dict, self.project_root)
rendered_packages = renderer.render_packages(self.packages_dict)
rendered_modules = renderer.render_modules(self.modules_dict)
rendered_selectors = renderer.render_selectors(self.selectors_dict)

return RenderComponents(
project_dict=rendered_project,
packages_dict=rendered_packages,
modules_dict=rendered_modules,
selectors_dict=rendered_selectors,
)

Expand Down Expand Up @@ -324,6 +357,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
unrendered = RenderComponents(
project_dict=self.project_dict,
packages_dict=self.packages_dict,
modules_dict=self.modules_dict,
selectors_dict=self.selectors_dict,
)
dbt_version = _get_required_version(
Expand Down Expand Up @@ -425,6 +459,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
query_comment = _query_comment_from_cfg(cfg.query_comment)

packages = package_config_from_data(rendered.packages_dict)
modules = module_config_from_data(rendered.modules_dict)
selectors = selector_config_from_data(rendered.selectors_dict)
manifest_selectors: Dict[str, Any] = {}
if rendered.selectors_dict and rendered.selectors_dict["selectors"]:
Expand Down Expand Up @@ -459,6 +494,7 @@ def create_project(self, rendered: RenderComponents) -> "Project":
snapshots=snapshots,
dbt_version=dbt_version,
packages=packages,
modules=modules,
manifest_selectors=manifest_selectors,
selectors=selectors,
query_comment=query_comment,
Expand All @@ -481,6 +517,7 @@ def from_dicts(
project_root: str,
project_dict: Dict[str, Any],
packages_dict: Dict[str, Any],
modules_dict: Dict[str, Any],
selectors_dict: Dict[str, Any],
*,
verify_version: bool = False,
Expand All @@ -495,6 +532,7 @@ def from_dicts(
project_root=project_root,
project_dict=project_dict,
packages_dict=packages_dict,
modules_dict=modules_dict,
selectors_dict=selectors_dict,
verify_version=verify_version,
)
Expand All @@ -506,12 +544,14 @@ def from_project_root(
project_root = os.path.normpath(project_root)
project_dict = load_raw_project(project_root)
packages_dict = package_data_from_root(project_root)
modules_dict = module_data_from_root(project_root)
selectors_dict = selector_data_from_root(project_root)
return cls.from_dicts(
project_root=project_root,
project_dict=project_dict,
selectors_dict=selectors_dict,
packages_dict=packages_dict,
modules_dict=modules_dict,
verify_version=verify_version,
)

Expand Down Expand Up @@ -566,6 +606,7 @@ class Project:
vars: VarProvider
dbt_version: List[VersionSpecifier]
packages: Dict[str, Any]
modules: Dict[str, Any]
manifest_selectors: Dict[str, Any]
selectors: SelectorConfig
query_comment: QueryComment
Expand Down
4 changes: 4 additions & 0 deletions core/dbt/config/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ def render_packages(self, packages: Dict[str, Any]):
package_renderer = self.get_package_renderer()
return package_renderer.render_data(packages)

def render_modules(self, modules: Dict[str, Any]):
"""Render the given modules dict"""
return self.render_data(modules)

def render_selectors(self, selectors: Dict[str, Any]):
return self.render_data(selectors)

Expand Down
1 change: 1 addition & 0 deletions core/dbt/config/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def from_parts(
snapshots=project.snapshots,
dbt_version=project.dbt_version,
packages=project.packages,
modules=project.modules,
manifest_selectors=project.manifest_selectors,
selectors=project.selectors,
query_comment=project.query_comment,
Expand Down
35 changes: 33 additions & 2 deletions core/dbt/context/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import importlib
import json
import os
import sys
from typing import Any, Dict, NoReturn, Optional, Mapping, Iterable, Set, List

from dbt.flags import get_flags
Expand All @@ -8,7 +10,9 @@
from dbt import utils
from dbt.clients.jinja import get_rendered
from dbt.clients.yaml_helper import yaml, safe_load, SafeLoader, Loader, Dumper # noqa: F401
# from dbt.config.runtime import RuntimeConfig
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.contracts.project import Module
from dbt.contracts.graph.nodes import Resource
from dbt.exceptions import (
SecretEnvVarLocationError,
Expand Down Expand Up @@ -75,7 +79,7 @@ def get_itertools_module_context() -> Dict[str, Any]:
return {name: getattr(itertools, name) for name in context_exports}


def get_context_modules() -> Dict[str, Dict[str, Any]]:
def get_default_context_modules() -> Dict[str, Dict[str, Any]]:
return {
"pytz": get_pytz_module_context(),
"datetime": get_datetime_module_context(),
Expand All @@ -84,6 +88,30 @@ def get_context_modules() -> Dict[str, Dict[str, Any]]:
}


def get_module_context(module: Module) -> Dict[str, Any]:
if module.location is not None:
sys.path.append(module.location)

py_module = importlib.import_module(module.package)

return {name: getattr(py_module, name) for name in module.exports}


def get_context_modules(modules: List[Module]) -> Dict[str, Dict[str, Any]]:
default_modules = get_default_context_modules()
custom_modules = {
module.package: get_module_context(module=module)
for module in modules
}

# Overwrite default modules with custom modules if there are any
# conflicts, with the defaults kept for backwards compatibility
return {
**default_modules,
**custom_modules,
}


class ContextMember:
def __init__(self, value, name=None):
self.name = name
Expand Down Expand Up @@ -619,7 +647,10 @@ def modules(self) -> Dict[str, Any]:
{% set dt_local = modules.pytz.timezone('US/Eastern').localize(dt) %}
{{ dt_local }}
""" # noqa
return get_context_modules()
if type(self).__name__ not in ("SecretContext", "TargetContext"):
return get_context_modules(modules=self.config.modules.modules)

return get_default_context_modules()

@contextproperty
def flags(self) -> Any:
Expand Down
16 changes: 16 additions & 0 deletions core/dbt/contracts/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,22 @@ def validate(cls, data):
super().validate(data)


@dataclass
class Module(Replaceable, HyphenatedDbtClassMixin):
package: str
exports: List[str] # __all__ is not allowed since not all modules implement this
location: Optional[str] = None # For local modules only (since we need to add them to the path)


@dataclass
class ModuleConfig(dbtClassMixin, Replaceable):
modules: List[Module]

@classmethod
def validate(cls, *args, **kwargs):
pass


@dataclass
class ProjectPackageMetadata:
name: str
Expand Down

0 comments on commit 38c9aff

Please sign in to comment.