diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index 7462c469dbd..1909d9627a3 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -14,8 +14,9 @@ from dbt.contracts.connection import ( Connection, Identifier, ConnectionState, AdapterRequiredConfig, LazyHandle ) +from dbt.contracts.graph.manifest import Manifest from dbt.adapters.base.query_headers import ( - QueryStringSetter, MacroQueryStringSetter, + MacroQueryStringSetter, ) from dbt.logger import GLOBAL_LOGGER as logger @@ -39,13 +40,10 @@ def __init__(self, profile: AdapterRequiredConfig): self.profile = profile self.thread_connections: Dict[Hashable, Connection] = {} self.lock: RLock = dbt.flags.MP_CONTEXT.RLock() - self.query_header = QueryStringSetter(self.profile) + self.query_header: Optional[MacroQueryStringSetter] = None - def set_query_header(self, manifest=None) -> None: - if manifest is not None: - self.query_header = MacroQueryStringSetter(self.profile, manifest) - else: - self.query_header = QueryStringSetter(self.profile) + def set_query_header(self, manifest: Manifest) -> None: + self.query_header = MacroQueryStringSetter(self.profile, manifest) @staticmethod def get_thread_identifier() -> Hashable: @@ -285,6 +283,8 @@ def commit_if_has_connection(self) -> None: self.commit() def _add_query_comment(self, sql: str) -> str: + if self.query_header is None: + return sql return self.query_header.add(sql) @abc.abstractmethod diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 297713e2c7b..a97743494dc 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -245,12 +245,14 @@ def connection_named( self, name: str, node: Optional[CompileResultNode] = None ) -> Iterator[None]: try: - self.connections.query_header.set(name, node) + if self.connections.query_header is not None: + self.connections.query_header.set(name, node) self.acquire_connection(name) yield finally: self.release_connection() - self.connections.query_header.reset() + if self.connections.query_header is not None: + self.connections.query_header.reset() @contextmanager def connection_for( @@ -980,12 +982,12 @@ def execute_macro( 'dbt could not find a macro with the name "{}" in {}' .format(macro_name, package_name) ) - # This causes a reference cycle, as dbt.context.runtime.generate() + # This causes a reference cycle, as generate_runtime_macro() # ends up calling get_adapter, so the import has to be here. 
- import dbt.context.operation - macro_context = dbt.context.operation.generate( - model=macro, - runtime_config=self.config, + from dbt.context.providers import generate_runtime_macro + macro_context = generate_runtime_macro( + macro=macro, + config=self.config, manifest=manifest, package_name=project ) diff --git a/core/dbt/adapters/base/plugin.py b/core/dbt/adapters/base/plugin.py index d9f128ad33d..4085d991248 100644 --- a/core/dbt/adapters/base/plugin.py +++ b/core/dbt/adapters/base/plugin.py @@ -1,6 +1,7 @@ from typing import List, Optional, Type from dbt.adapters.base import BaseAdapter, Credentials +from dbt.exceptions import CompilationException class AdapterPlugin: @@ -23,8 +24,12 @@ def __init__( self.adapter: Type[BaseAdapter] = adapter self.credentials: Type[Credentials] = credentials self.include_path: str = include_path - project = Project.from_project_root(include_path, {}) - self.project_name: str = project.project_name + partial = Project.partial_load(include_path) + if partial.project_name is None: + raise CompilationException( + f'Invalid project at {include_path}: name not set!' + ) + self.project_name: str = partial.project_name self.dependencies: List[str] if dependencies is None: self.dependencies = [] diff --git a/core/dbt/adapters/base/query_headers.py b/core/dbt/adapters/base/query_headers.py index b7a6c8fdd40..f474ade902c 100644 --- a/core/dbt/adapters/base/query_headers.py +++ b/core/dbt/adapters/base/query_headers.py @@ -3,8 +3,7 @@ from dbt.clients.jinja import QueryStringGenerator -# this generates an import cycle, as usual -from dbt.context.base import QueryHeaderContext +from dbt.context.configured import generate_query_header_context from dbt.contracts.connection import AdapterRequiredConfig from dbt.contracts.graph.compiled import CompileResultNode from dbt.contracts.graph.manifest import Manifest @@ -68,9 +67,9 @@ def set(self, comment: Optional[str]): QueryStringFunc = Callable[[str, Optional[NodeWrapper]], str] -class QueryStringSetter: - """The base query string setter. 
This is only used once.""" - def __init__(self, config: AdapterRequiredConfig): +class MacroQueryStringSetter: + def __init__(self, config: AdapterRequiredConfig, manifest: Manifest): + self.manifest = manifest self.config = config comment_macro = self._get_comment_macro() @@ -88,17 +87,22 @@ def __init__(self, config: AdapterRequiredConfig): self.comment = _QueryComment(None) self.reset() - def _get_context(self): - return QueryHeaderContext(self.config).to_dict() - - def _get_comment_macro(self) -> Optional[str]: + def _get_comment_macro(self): + if ( + self.config.query_comment != NoValue() and + self.config.query_comment + ): + return self.config.query_comment # if the query comment is null/empty string, there is no comment at all - if not self.config.query_comment: + elif not self.config.query_comment: return None else: # else, the default return DEFAULT_QUERY_COMMENT + def _get_context(self) -> Dict[str, Any]: + return generate_query_header_context(self.config, self.manifest) + def add(self, sql: str) -> str: return self.comment.add(sql) @@ -111,21 +115,3 @@ def set(self, name: str, node: Optional[CompileResultNode]): wrapped = NodeWrapper(node) comment_str = self.generator(name, wrapped) self.comment.set(comment_str) - - -class MacroQueryStringSetter(QueryStringSetter): - def __init__(self, config: AdapterRequiredConfig, manifest: Manifest): - self.manifest = manifest - super().__init__(config) - - def _get_comment_macro(self): - if ( - self.config.query_comment != NoValue() and - self.config.query_comment - ): - return self.config.query_comment - else: - return super()._get_comment_macro() - - def _get_context(self) -> Dict[str, Any]: - return QueryHeaderContext(self.config).to_dict(self.manifest.macros) diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index 85c19548d82..e254d3a0780 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -4,7 +4,7 @@ import tempfile from contextlib import contextmanager from typing import ( - List, Union, Set, Optional, Dict, Any, Callable, Iterator, Type + List, Union, Set, Optional, Dict, Any, Callable, Iterator, Type, NoReturn ) import jinja2 @@ -374,7 +374,7 @@ def get_rendered(string, ctx, node=None, return render_template(template, ctx, node) -def undefined_error(msg): +def undefined_error(msg) -> NoReturn: raise jinja2.exceptions.UndefinedError(msg) diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index 34a412337e9..e011fb20c03 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -10,7 +10,7 @@ from dbt.node_types import NodeType from dbt.linker import Linker -import dbt.context.runtime +from dbt.context.providers import generate_runtime_model import dbt.contracts.project import dbt.exceptions import dbt.flags @@ -146,8 +146,9 @@ def compile_node(self, node, manifest, extra_context=None): }) compiled_node = _compiled_type_for(node).from_dict(data) - context = dbt.context.runtime.generate( - compiled_node, self.config, manifest) + context = generate_runtime_model( + compiled_node, self.config, manifest + ) context.update(extra_context) compiled_node.compiled_sql = dbt.clients.jinja.get_rendered( @@ -253,13 +254,11 @@ def compile_node(adapter, config, node, manifest, extra_context, write=True): logger.debug('Writing injected SQL for node "{}"'.format( node.unique_id)) - written_path = dbt.writer.write_node( - node, + node.build_path = node.write_node( config.target_path, 'compiled', - node.injected_sql) - - node.build_path = written_path + node.injected_sql + ) return node diff 
--git a/core/dbt/config/contexts.py b/core/dbt/config/contexts.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/core/dbt/config/profile.py b/core/dbt/config/profile.py index 7d2446b3f14..adaf4d35192 100644 --- a/core/dbt/config/profile.py +++ b/core/dbt/config/profile.py @@ -6,7 +6,7 @@ from dbt.clients.system import load_file_contents from dbt.clients.yaml_helper import load_yaml_text -from dbt.contracts.connection import Credentials +from dbt.contracts.connection import Credentials, HasCredentials from dbt.contracts.project import ProfileConfig, UserConfig from dbt.exceptions import DbtProfileError from dbt.exceptions import DbtProjectError @@ -14,7 +14,7 @@ from dbt.exceptions import RuntimeException from dbt.exceptions import validator_error_message from dbt.logger import GLOBAL_LOGGER as logger -from dbt.utils import parse_cli_vars, coerce_dict_str +from dbt.utils import coerce_dict_str from .renderer import ConfigRenderer @@ -73,7 +73,7 @@ def read_user_config(directory: str) -> UserConfig: @dataclass -class Profile: +class Profile(HasCredentials): profile_name: str target_name: str config: UserConfig @@ -217,13 +217,11 @@ def render_profile( raw_profile: Dict[str, Any], profile_name: str, target_override: Optional[str], - cli_vars: Dict[str, Any], + renderer: ConfigRenderer, ) -> Tuple[str, Dict[str, Any]]: """This is a containment zone for the hateful way we're rendering profiles. """ - renderer = ConfigRenderer(cli_vars=cli_vars) - # rendering profiles is a bit complex. Two constraints cause trouble: # 1) users should be able to use environment/cli variables to specify # the target in their profile. @@ -255,7 +253,7 @@ def from_raw_profile_info( cls, raw_profile: Dict[str, Any], profile_name: str, - cli_vars: Dict[str, Any], + renderer: ConfigRenderer, user_cfg: Optional[Dict[str, Any]] = None, target_override: Optional[str] = None, threads_override: Optional[int] = None, @@ -267,8 +265,7 @@ def from_raw_profile_info( :param raw_profile: The profile data for a single profile, from disk as yaml and its values rendered with jinja. :param profile_name: The profile name used. - :param cli_vars: The command-line variables passed as arguments, - as a dict. + :param renderer: The config renderer. :param user_cfg: The global config for the user, if it was present. :param target_override: The target to use, if provided on @@ -285,7 +282,7 @@ def from_raw_profile_info( # TODO: should it be, and the values coerced to bool? target_name, profile_data = cls.render_profile( - raw_profile, profile_name, target_override, cli_vars + raw_profile, profile_name, target_override, renderer ) # valid connections never include the number of threads, but it's @@ -311,15 +308,14 @@ def from_raw_profiles( cls, raw_profiles: Dict[str, Any], profile_name: str, - cli_vars: Dict[str, Any], + renderer: ConfigRenderer, target_override: Optional[str] = None, threads_override: Optional[int] = None, ) -> 'Profile': """ :param raw_profiles: The profile data, from disk as yaml. :param profile_name: The profile name to use. - :param cli_vars: The command-line variables passed as arguments, as a - dict. + :param renderer: The config renderer. :param target_override: The target to use, if provided on the command line. 
:param threads_override: The thread count to use, if provided on the @@ -344,17 +340,18 @@ def from_raw_profiles( return cls.from_raw_profile_info( raw_profile=raw_profile, profile_name=profile_name, - cli_vars=cli_vars, + renderer=renderer, user_cfg=user_cfg, target_override=target_override, threads_override=threads_override, ) @classmethod - def from_args( + def render_from_args( cls, args: Any, - project_profile_name: Optional[str] = None, + renderer: ConfigRenderer, + project_profile_name: Optional[str], ) -> 'Profile': """Given the raw profiles as read from disk and the name of the desired profile if specified, return the profile component of the runtime @@ -370,7 +367,6 @@ def from_args( target could not be found. :returns Profile: The new Profile object. """ - cli_vars = parse_cli_vars(getattr(args, 'vars', '{}')) threads_override = getattr(args, 'threads', None) target_override = getattr(args, 'target', None) raw_profiles = read_profile(args.profiles_dir) @@ -380,7 +376,7 @@ def from_args( return cls.from_raw_profiles( raw_profiles=raw_profiles, profile_name=profile_name, - cli_vars=cli_vars, + renderer=renderer, target_override=target_override, threads_override=threads_override ) diff --git a/core/dbt/config/project.py b/core/dbt/config/project.py index 620aa0773bc..1c9a7f39c15 100644 --- a/core/dbt/config/project.py +++ b/core/dbt/config/project.py @@ -20,7 +20,6 @@ from dbt.version import get_installed_version from dbt.ui import printer from dbt.utils import deep_map -from dbt.utils import parse_cli_vars from dbt.source_config import SourceConfig from dbt.contracts.graph.manifest import ManifestMetadata @@ -166,6 +165,37 @@ def value_or(value: Optional[T], default: T) -> T: return value +def _raw_project_from(project_root: str) -> Dict[str, Any]: + + project_root = os.path.normpath(project_root) + project_yaml_filepath = os.path.join(project_root, 'dbt_project.yml') + + # get the project.yml contents + if not path_exists(project_yaml_filepath): + raise DbtProjectError( + 'no dbt_project.yml found at expected path {}' + .format(project_yaml_filepath) + ) + + project_dict = _load_yaml(project_yaml_filepath) + return project_dict + + +@dataclass +class PartialProject: + profile_name: Optional[str] + project_name: Optional[str] + project_root: str + project_dict: Dict[str, Any] + + def render(self, renderer): + return Project.render_from_dict( + self.project_root, + self.project_dict, + renderer, + ) + + @dataclass class Project: project_name: str @@ -401,42 +431,41 @@ def validate(self): raise DbtProjectError(validator_error_message(e)) from e @classmethod - def from_project_root(cls, project_root, cli_vars): - """Create a project from a root directory. Reads in dbt_project.yml and - packages.yml, if it exists. - - :param project_root str: The path to the project root to load. - :raises DbtProjectError: If the project is missing or invalid, or if - the packages file exists and is invalid. - :returns Project: The project, with defaults populated. 
- """ - project_root = os.path.normpath(project_root) - project_yaml_filepath = os.path.join(project_root, 'dbt_project.yml') - - # get the project.yml contents - if not path_exists(project_yaml_filepath): - raise DbtProjectError( - 'no dbt_project.yml found at expected path {}' - .format(project_yaml_filepath) - ) - - if isinstance(cli_vars, str): - cli_vars = parse_cli_vars(cli_vars) - renderer = ConfigRenderer(cli_vars) - - project_dict = _load_yaml(project_yaml_filepath) + def render_from_dict( + cls, + project_root: str, + project_dict: Dict[str, Any], + renderer: ConfigRenderer, + ) -> 'Project': rendered_project = renderer.render_project(project_dict) rendered_project['project-root'] = project_root packages_dict = package_data_from_root(project_root) - return cls.from_project_config(rendered_project, packages_dict) + rendered_packages = renderer.render_packages_data(packages_dict) + return cls.from_project_config(rendered_project, rendered_packages) @classmethod - def from_current_directory(cls, cli_vars): - return cls.from_project_root(os.getcwd(), cli_vars) + def partial_load( + cls, project_root: str + ) -> PartialProject: + project_root = os.path.normpath(project_root) + project_dict = _raw_project_from(project_root) + + project_name = project_dict.get('name') + profile_name = project_dict.get('profile') + + return PartialProject( + profile_name=profile_name, + project_name=project_name, + project_root=project_root, + project_dict=project_dict, + ) @classmethod - def from_args(cls, args): - return cls.from_current_directory(getattr(args, 'vars', '{}')) + def from_project_root( + cls, project_root: str, renderer: ConfigRenderer + ) -> 'Project': + partial = cls.partial_load(project_root) + return partial.render(renderer) def hashed_name(self): return hashlib.md5(self.project_name.encode('utf-8')).hexdigest() diff --git a/core/dbt/config/renderer.py b/core/dbt/config/renderer.py index 51bb5a58fdc..3b3881c267d 100644 --- a/core/dbt/config/renderer.py +++ b/core/dbt/config/renderer.py @@ -1,5 +1,7 @@ +from typing import Dict, Any + from dbt.clients.jinja import get_rendered -from dbt.context.base import ConfigRenderContext + from dbt.exceptions import DbtProfileError from dbt.exceptions import DbtProjectError from dbt.exceptions import RecursionException @@ -10,8 +12,8 @@ class ConfigRenderer: """A renderer provides configuration rendering for a given set of cli variables and a render type. 
""" - def __init__(self, cli_vars): - self.context = ConfigRenderContext(cli_vars).to_dict() + def __init__(self, context: Dict[str, Any]): + self.context = context @staticmethod def _is_deferred_render(keypath): @@ -107,3 +109,12 @@ def render_schema_source(self, as_parsed): 'Cycle detected: schema.yml input has a reference to itself', project=as_parsed ) + + def render_packages_data(self, as_parsed): + try: + return deep_map(self.render_value, as_parsed) + except RecursionException: + raise DbtProfileError( + 'Cycle detected: schema.yml input has a reference to itself', + project=as_parsed + ) diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index 3a190792ef5..4117cc7e0db 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -1,21 +1,25 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Dict, Any, Optional +import os +from typing import Dict, Any from .profile import Profile from .project import Project +from .renderer import ConfigRenderer from dbt.utils import parse_cli_vars +from dbt.context.base import generate_base_context +from dbt.context.target import generate_target_context +from dbt.contracts.connection import AdapterRequiredConfig from dbt.contracts.project import Configuration from dbt.exceptions import DbtProjectError from dbt.exceptions import validator_error_message from dbt.adapters.factory import get_relation_class_by_name - from hologram import ValidationError @dataclass -class RuntimeConfig(Project, Profile): +class RuntimeConfig(Project, Profile, AdapterRequiredConfig): args: Any cli_vars: Dict[str, Any] @@ -86,8 +90,11 @@ def new_project(self, project_root: str) -> 'RuntimeConfig': # copy profile profile = Profile(**self.to_profile_info()) profile.validate() + # load the new project and its packages. Don't pass cli variables. - project = Project.from_project_root(project_root, {}) + renderer = ConfigRenderer(generate_target_context(profile, {})) + + project = Project.from_project_root(project_root, renderer) cfg = self.from_parts( project=project, @@ -126,9 +133,7 @@ def validate(self): self.validate_version() @classmethod - def from_args( - cls, args: Any, project_profile_name: Optional[str] = None - ) -> 'RuntimeConfig': + def from_args(cls, args: Any) -> 'RuntimeConfig': """Given arguments, read in dbt_project.yml from the current directory, read in packages.yml if it exists, and use them to find the profile to load. @@ -138,18 +143,21 @@ def from_args( :raises DbtProfileError: If the profile is invalid or missing. :raises ValidationException: If the cli variables are invalid. 
""" - # project_profile_name is ignored, we just need it to appease mypy - # (Profile.from_args uses it) + # profile_name from the project + partial = Project.partial_load(os.getcwd()) - # build the project and read in packages.yml - project = Project.from_args(args) - - # build the profile - profile = Profile.from_args( - args=args, - project_profile_name=project.profile_name + # build the profile using the base renderer and the one fact we know + cli_vars: Dict[str, Any] = parse_cli_vars(getattr(args, 'vars', '{}')) + renderer = ConfigRenderer(generate_base_context(cli_vars=cli_vars)) + profile = Profile.render_from_args( + args, renderer, partial.profile_name ) + # get a new renderer using our target information and render the + # project + renderer = ConfigRenderer(generate_target_context(profile, cli_vars)) + project = partial.render(renderer) + return cls.from_parts( project=project, profile=profile, diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 1aa66d11b3c..2b2989f3b0d 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -1,19 +1,17 @@ -import itertools import json import os -from typing import Callable, Any, Dict, List, Optional, Mapping - -import dbt.tracking -from dbt.clients.jinja import undefined_error -from dbt.contracts.graph.parsed import ParsedMacro -from dbt.exceptions import MacroReturn, raise_compiler_error -from dbt.include.global_project import PACKAGES -from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME +from typing import ( + Any, Dict, NoReturn, Optional +) + +from dbt import flags +from dbt import tracking +from dbt.clients.jinja import undefined_error, get_rendered +from dbt.exceptions import raise_compiler_error, MacroReturn from dbt.logger import GLOBAL_LOGGER as logger +from dbt.utils import merge from dbt.version import __version__ as dbt_version -from dbt.node_types import NodeType - import yaml # These modules are added to the context. 
Consider alternative # approaches which will extend well to potentially many modules @@ -21,21 +19,78 @@ import datetime -def env_var(var, default=None): - if var in os.environ: - return os.environ[var] - elif default is not None: - return default - else: - msg = "Env var required but not provided: '{}'".format(var) - undefined_error(msg) +def get_pytz_module_context() -> Dict[str, Any]: + context_exports = pytz.__all__ # type: ignore + + return { + name: getattr(pytz, name) for name in context_exports + } + + +def get_datetime_module_context() -> Dict[str, Any]: + context_exports = [ + 'date', + 'datetime', + 'time', + 'timedelta', + 'tzinfo' + ] + + return { + name: getattr(datetime, name) for name in context_exports + } + + +def get_context_modules() -> Dict[str, Dict[str, Any]]: + return { + 'pytz': get_pytz_module_context(), + 'datetime': get_datetime_module_context(), + } + + +class ContextMember: + def __init__(self, value, name=None): + self.name = name + self.inner = value + + def key(self, default): + if self.name is None: + return default + return self.name + + +def contextmember(value): + if isinstance(value, str): + return lambda v: ContextMember(v, name=value) + return ContextMember(value) + + +def contextproperty(value): + if isinstance(value, str): + return lambda v: ContextMember(property(v), name=value) + return ContextMember(property(value)) -def debug_here(): - import sys - import ipdb # type: ignore - frame = sys._getframe(3) - ipdb.set_trace(frame) +class ContextMeta(type): + def __new__(mcls, name, bases, dct): + context_members = {} + context_attrs = {} + new_dct = {} + + for base in bases: + context_members.update(getattr(base, '_context_members_', {})) + context_attrs.update(getattr(base, '_context_attrs_', {})) + + for key, value in dct.items(): + if isinstance(value, ContextMember): + context_key = value.key(key) + context_members[context_key] = value.inner + context_attrs[context_key] = key + value = value.inner + new_dct[key] = value + new_dct['_context_members_'] = context_members + new_dct['_context_attrs_'] = context_attrs + return type.__new__(mcls, name, bases, new_dct) class Var: @@ -59,7 +114,7 @@ def __init__(self, model, context, overrides): self.model_name = model.name local_vars = model.local_vars() - self.local_vars = dbt.utils.merge(local_vars, overrides) + self.local_vars = merge(local_vars, overrides) def pretty_dict(self, data): return json.dumps(data, sort_keys=True, indent=4) @@ -81,7 +136,7 @@ def get_rendered_var(self, var_name): if not isinstance(raw, str): return raw - return dbt.clients.jinja.get_rendered(raw, self.context) + return get_rendered(raw, self.context) def __call__(self, var_name, default=_VAR_NOTSET): if var_name in self.local_vars: @@ -92,202 +147,131 @@ def __call__(self, var_name, default=_VAR_NOTSET): return self.get_missing_var(var_name) -def get_pytz_module_context() -> Dict[str, Any]: - context_exports = pytz.__all__ # type: ignore - - return { - name: getattr(pytz, name) for name in context_exports - } - - -def get_datetime_module_context() -> Dict[str, Any]: - context_exports = [ - 'date', - 'datetime', - 'time', - 'timedelta', - 'tzinfo' - ] - - return { - name: getattr(datetime, name) for name in context_exports - } - - -def get_context_modules() -> Dict[str, Dict[str, Any]]: - return { - 'pytz': get_pytz_module_context(), - 'datetime': get_datetime_module_context(), - } - - -def _return(value): - raise MacroReturn(value) - - -def fromjson(string, default=None): - try: - return json.loads(string) - except 
ValueError: - return default - - -def tojson(value, default=None, sort_keys=False): - try: - return json.dumps(value, sort_keys=sort_keys) - except ValueError: - return default - - -def fromyaml(value, default=None): - try: - return yaml.safe_load(value) - except (AttributeError, ValueError, yaml.YAMLError): - return default +class BaseContext(metaclass=ContextMeta): + def __init__(self, cli_vars): + self._ctx = {} + self.cli_vars = cli_vars + def generate_builtins(self): + builtins: Dict[str, Any] = {} + for key, value in self._context_members_.items(): + if hasattr(value, '__get__'): + # handle properties, bound methods, etc + value = value.__get__(self) + builtins[key] = value + return builtins + + def to_dict(self): + self._ctx['context'] = self._ctx + builtins = self.generate_builtins() + self._ctx['builtins'] = builtins + self._ctx.update(builtins) + return self._ctx + + @contextproperty + def dbt_version(self) -> str: + return dbt_version + + @contextproperty + def var(self) -> Var: + return Var(None, self._ctx, self.cli_vars) + + @contextmember + @staticmethod + def env_var(var: str, default: Optional[str] = None) -> str: + if var in os.environ: + return os.environ[var] + elif default is not None: + return default + else: + msg = f"Env var required but not provided: '{var}'" + undefined_error(msg) + + if os.environ.get('DBT_MACRO_DEBUGGING'): + @contextmember + @staticmethod + def debug_here(): + import sys + import ipdb # type: ignore + frame = sys._getframe(3) + ipdb.set_trace(frame) + return '' + + @contextmember('return') + @staticmethod + def _return(value: Any) -> NoReturn: + raise MacroReturn(value) + + @contextmember + @staticmethod + def fromjson(string: str, default: Any = None) -> Any: + try: + return json.loads(string) + except ValueError: + return default -# safe_dump defaults to sort_keys=True, but act like json.dumps (the opposite) -def toyaml(value, default=None, sort_keys=False): - try: - return yaml.safe_dump(data=value, sort_keys=sort_keys) - except (ValueError, yaml.YAMLError): - return default + @contextmember + @staticmethod + def tojson( + value: Any, default: Any = None, sort_keys: bool = False + ) -> Any: + try: + return json.dumps(value, sort_keys=sort_keys) + except ValueError: + return default + @contextmember + @staticmethod + def fromyaml(value: str, default: Any = None) -> Any: + try: + return yaml.safe_load(value) + except (AttributeError, ValueError, yaml.YAMLError): + return default -def log(msg, info=False): - if info: - logger.info(msg) - else: - logger.debug(msg) - return '' + # safe_dump defaults to sort_keys=True, but we act like json.dumps (the + # opposite) + @contextmember + @staticmethod + def toyaml( + value: Any, default: Optional[str] = None, sort_keys: bool = False + ) -> Optional[str]: + try: + return yaml.safe_dump(data=value, sort_keys=sort_keys) + except (ValueError, yaml.YAMLError): + return default + @contextmember + @staticmethod + def log(msg: str, info: bool = False) -> str: + if info: + logger.info(msg) + else: + logger.debug(msg) + return '' -class BaseContext: - def to_dict(self) -> Dict[str, Any]: - run_started_at = None - invocation_id = None + @contextproperty + def run_started_at(self) -> Optional[datetime.datetime]: + if tracking.active_user is not None: + return tracking.active_user.run_started_at + else: + return None - if dbt.tracking.active_user is not None: - run_started_at = dbt.tracking.active_user.run_started_at - invocation_id = dbt.tracking.active_user.invocation_id + @contextproperty + def 
invocation_id(self) -> Optional[str]: + if tracking.active_user is not None: + return tracking.active_user.invocation_id + else: + return None - context: Dict[str, Any] = { - 'env_var': env_var, - 'modules': get_context_modules(), - 'run_started_at': run_started_at, - 'invocation_id': invocation_id, - 'return': _return, - 'fromjson': fromjson, - 'tojson': tojson, - 'fromyaml': fromyaml, - 'toyaml': toyaml, - 'log': log, - } - if os.environ.get('DBT_MACRO_DEBUGGING'): - context['debug'] = debug_here - return context + @contextproperty + def modules(self) -> Dict[str, Any]: + return get_context_modules() + @contextproperty + def flags(self) -> Any: + return flags -class ConfigRenderContext(BaseContext): - def __init__(self, cli_vars): - self.cli_vars = cli_vars - def make_var(self, context) -> Var: - return Var(None, context, self.cli_vars) - - def to_dict(self) -> Dict[str, Any]: - context = super().to_dict() - context['var'] = self.make_var(context) - return context - - -def _add_macro_map( - context: Dict[str, Any], package_name: str, macro_map: Dict[str, Callable] -): - """Update an existing context in-place, adding the given macro map to the - appropriate package namespace. Adapter packages get inserted into the - global namespace. - """ - key = package_name - if package_name in PACKAGES: - key = GLOBAL_PROJECT_NAME - if key not in context: - value: Dict[str, Callable] = {} - context[key] = value - - context[key].update(macro_map) - - -class HasCredentialsContext(ConfigRenderContext): - def __init__(self, config): - # sometimes we only have a profile object and end up here. In those - # cases, we never want the actual cli vars passed, so we can do this. - cli_vars = getattr(config, 'cli_vars', {}) - super().__init__(cli_vars=cli_vars) - self.config = config - - def get_target(self) -> Dict[str, Any]: - target = dict( - self.config.credentials.connection_info(with_aliases=True) - ) - target.update({ - 'type': self.config.credentials.type, - 'threads': self.config.threads, - 'name': self.config.target_name, - # not specified, but present for compatibility - 'target_name': self.config.target_name, - 'profile_name': self.config.profile_name, - 'config': self.config.config.to_dict(), - }) - return target - - @property - def search_package_name(self): - return self.config.project_name - - def add_macros_from( - self, - context: Dict[str, Any], - macros: Mapping[str, ParsedMacro], - ): - global_macros: List[Dict[str, Callable]] = [] - local_macros: List[Dict[str, Callable]] = [] - - for unique_id, macro in macros.items(): - if macro.resource_type != NodeType.Macro: - continue - package_name = macro.package_name - - macro_map: Dict[str, Callable] = { - macro.name: macro.generator(context) - } - - # adapter packages are part of the global project space - _add_macro_map(context, package_name, macro_map) - - if package_name == self.search_package_name: - # If we're in the current project, allow local override - local_macros.append(macro_map) - elif package_name == self.config.project_name: - # If we're in the root project, allow global override - global_macros.append(macro_map) - elif package_name in PACKAGES: - # If it comes from a dbt package, allow global override - global_macros.append(macro_map) - - # Load global macros before local macros -- local takes precedence - for macro_map in itertools.chain(global_macros, local_macros): - context.update(macro_map) - - -class QueryHeaderContext(HasCredentialsContext): - def __init__(self, config): - super().__init__(config) - - def to_dict(self, 
macros: Optional[Mapping[str, ParsedMacro]] = None): - context = super().to_dict() - context['target'] = self.get_target() - context['dbt_version'] = dbt_version - if macros is not None: - self.add_macros_from(context, macros) - return context +def generate_base_context(cli_vars: Dict[str, Any]) -> Dict[str, Any]: + ctx = BaseContext(cli_vars) + return ctx.to_dict() diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py deleted file mode 100644 index d2502978e6b..00000000000 --- a/core/dbt/context/common.py +++ /dev/null @@ -1,385 +0,0 @@ -import agate -import os -from typing_extensions import Protocol -from typing import Union, Callable, Any, Dict, TypeVar, Type, Optional - -from dbt.clients import agate_helper -from dbt.contracts.graph.compiled import CompiledSeedNode, CompileResultNode -from dbt.contracts.graph.parsed import ParsedSeedNode, ParsedMacro -import dbt.exceptions -import dbt.flags -import dbt.tracking -import dbt.utils -import dbt.writer -from dbt.adapters.factory import get_adapter -from dbt.node_types import NodeType -from dbt.clients.jinja import get_rendered -from dbt.config import RuntimeConfig -from dbt.context.base import Var, HasCredentialsContext -from dbt.contracts.graph.manifest import Manifest - - -class RelationProxy: - def __init__(self, adapter): - self.quoting_config = adapter.config.quoting - self.relation_type = adapter.Relation - - def __getattr__(self, key): - return getattr(self.relation_type, key) - - def create_from_source(self, *args, **kwargs): - # bypass our create when creating from source so as not to mess up - # the source quoting - return self.relation_type.create_from_source(*args, **kwargs) - - def create(self, *args, **kwargs): - kwargs['quote_policy'] = dbt.utils.merge( - self.quoting_config, - kwargs.pop('quote_policy', {}) - ) - return self.relation_type.create(*args, **kwargs) - - -class BaseDatabaseWrapper: - """ - Wrapper for runtime database interaction. Applies the runtime quote policy - via a relation proxy. - """ - def __init__(self, adapter): - self.adapter = adapter - self.Relation = RelationProxy(adapter) - - def __getattr__(self, name): - raise NotImplementedError('subclasses need to implement this') - - @property - def config(self): - return self.adapter.config - - def type(self): - return self.adapter.type() - - def commit(self): - return self.adapter.commit_if_has_connection() - - -class BaseResolver: - def __init__(self, db_wrapper, model, config, manifest): - self.db_wrapper = db_wrapper - self.model = model - self.config = config - self.manifest = manifest - - @property - def current_project(self): - return self.config.project_name - - @property - def Relation(self): - return self.db_wrapper.Relation - - -class Config(Protocol): - def __init__(self, model, source_config): - ... 
- - -class Provider(Protocol): - execute: bool - Config: Type[Config] - DatabaseWrapper: Type[BaseDatabaseWrapper] - Var: Type[Var] - ref: Type[BaseResolver] - source: Type[BaseResolver] - - -class ManifestParsedContext(HasCredentialsContext): - """A context available after the manifest has been parsed.""" - def __init__(self, config, manifest): - super().__init__(config) - self.manifest = manifest - - def add_macros(self, context): - self.add_macros_from(context, self.manifest.macros) - - -def _store_result(sql_results): - def call(name, status, agate_table=None): - if agate_table is None: - agate_table = agate_helper.empty_table() - - sql_results[name] = dbt.utils.AttrDict({ - 'status': status, - 'data': agate_helper.as_matrix(agate_table), - 'table': agate_table - }) - return '' - - return call - - -def _load_result(sql_results): - def call(name): - return sql_results.get(name) - - return call - - -T = TypeVar('T') - - -def get_validation() -> dbt.utils.AttrDict: - def validate_any(*args) -> Callable[[T], None]: - def inner(value: T) -> None: - for arg in args: - if isinstance(arg, type) and isinstance(value, arg): - return - elif value == arg: - return - raise dbt.exceptions.ValidationException( - 'Expected value "{}" to be one of {}' - .format(value, ','.join(map(str, args)))) - return inner - - return dbt.utils.AttrDict({ - 'any': validate_any, - }) - - -def add_sql_handlers(context: Dict[str, Any]) -> None: - sql_results: Dict[str, Any] = {} - context['_sql_results'] = sql_results - context['store_result'] = _store_result(sql_results) - context['load_result'] = _load_result(sql_results) - - -def write(node, target_path, subdirectory): - def fn(payload): - node.build_path = dbt.writer.write_node( - node, target_path, subdirectory, payload) - return '' - - return fn - - -def render(context, node): - def fn(string): - return get_rendered(string, context, node) - - return fn - - -def try_or_compiler_error(model): - def impl(message_if_exception, func, *args, **kwargs): - try: - return func(*args, **kwargs) - except Exception: - dbt.exceptions.raise_compiler_error(message_if_exception, model) - return impl - - -# Base context collection, used for parsing configs. 
- - -def _build_load_agate_table( - model: Union[ParsedSeedNode, CompiledSeedNode] -) -> Callable[[], agate.Table]: - def load_agate_table(): - path = model.seed_file_path - column_types = model.config.column_types - try: - table = agate_helper.from_csv(path, text_columns=column_types) - except ValueError as e: - dbt.exceptions.raise_compiler_error(str(e)) - table.original_abspath = os.path.abspath(path) - return table - return load_agate_table - - -class ProviderContext(ManifestParsedContext): - def __init__(self, model, config, manifest, provider, source_config): - if provider is None: - raise dbt.exceptions.InternalException( - "Invalid provider given to context: {}".format(provider)) - self.model = model - super().__init__(config, manifest) - self.source_config = source_config - self.provider = provider - self.adapter = get_adapter(self.config) - self.db_wrapper = self.provider.DatabaseWrapper(self.adapter) - - @property - def search_package_name(self): - return self.model.package_name - - def add_provider_functions(self, context): - # Generate the builtin functions - builtins = { - 'ref': self.provider.ref( - self.db_wrapper, self.model, self.config, self.manifest), - 'source': self.provider.source( - self.db_wrapper, self.model, self.config, self.manifest), - 'config': self.provider.Config( - self.model, self.source_config), - 'execute': self.provider.execute - } - # Install them at .builtins - context['builtins'] = builtins - # Install each of them directly in case they're not - # clobbered by a macro. - context.update(builtins) - - def add_exceptions(self, context): - context['exceptions'] = dbt.exceptions.wrapped_exports(self.model) - - def add_default_schema_info(self, context): - context['database'] = getattr( - self.model, 'database', self.config.credentials.database - ) - context['schema'] = getattr( - self.model, 'schema', self.config.credentials.schema - ) - - def make_var(self, context) -> Var: - return self.provider.Var( - self.model, context=context, overrides=self.config.cli_vars - ) - - def insert_model_information(self, context: Dict[str, Any]) -> None: - """By default, the model information is not added to the context""" - pass - - def modify_generated_context(self, context: Dict[str, Any]) -> None: - context['validation'] = get_validation() - add_sql_handlers(context) - self.add_macros(context) - - context["write"] = write(self.model, self.config.target_path, 'run') - context["render"] = render(context, self.model) - context['context'] = context - - def to_dict(self): - target = self.get_target() - - context = super().to_dict() - - self.add_provider_functions(context) - self.add_exceptions(context) - self.add_default_schema_info(context) - - context.update({ - "adapter": self.db_wrapper, - "api": { - "Relation": self.db_wrapper.Relation, - "Column": self.adapter.Column, - }, - "column": self.adapter.Column, - 'env': target, - 'target': target, - "flags": dbt.flags, - "load_agate_table": _build_load_agate_table(self.model), - "graph": self.manifest.flat_graph, - "model": self.model.to_dict(), - "post_hooks": None, - "pre_hooks": None, - "sql": None, - "sql_now": self.adapter.date_function(), - "try_or_compiler_error": try_or_compiler_error(self.model) - }) - - self.insert_model_information(context) - - self.modify_generated_context(context) - - return context - - -class ExecuteMacroContext(ProviderContext): - """Internally, macros can be executed like nodes, with some restrictions: - - - they don't have have all values available that nodes do: - - 'this', 'pre_hooks', 
'post_hooks', and 'sql' are missing - - 'schema' does not use any 'model' information - - they can't be configured with config() directives - - the search packge is the root project, unless the macro was executed by - fully-qualified name, in which case it's the chosen package. - """ - def __init__( - self, - model: ParsedMacro, - config: RuntimeConfig, - manifest: Manifest, - provider, - search_package_name: Optional[str] - ) -> None: - super().__init__(model, config, manifest, provider, None) - if search_package_name is None: - # if the search package name isn't specified, use the root project - self._search_package_name = config.project_name - else: - self._search_package_name = search_package_name - - @property - def search_package_name(self): - return self._search_package_name - - -class ModelContext(ProviderContext): - def get_this(self): - return self.db_wrapper.Relation.create_from(self.config, self.model) - - def add_hooks(self, context): - context['pre_hooks'] = [ - h.to_dict() for h in self.model.config.pre_hook - ] - context['post_hooks'] = [ - h.to_dict() for h in self.model.config.post_hook - ] - - def insert_model_information(self, context): - # operations (hooks) don't get a 'this' - if self.model.resource_type != NodeType.Operation: - context['this'] = self.get_this() - # overwrite schema/database if we have them, and hooks + sql - # the hooks should come in as dicts, at least for the `run_hooks` macro - # TODO: do we have to preserve this as backwards a compatibility thing? - self.add_default_schema_info(context) - self.add_hooks(context) - context['sql'] = getattr(self.model, 'injected_sql', None) - - -def generate_execute_macro( - model: ParsedMacro, - config: RuntimeConfig, - manifest: Manifest, - provider: Provider, - package_name: Optional[str], -) -> Dict[str, Any]: - """Internally, macros can be executed like nodes, with some restrictions: - - - they don't have have all values available that nodes do: - - 'this', 'pre_hooks', 'post_hooks', and 'sql' are missing - - 'schema' does not use any 'model' information - - they can't be configured with config() directives - """ - ctx = ExecuteMacroContext( - model, config, manifest, provider, package_name - ) - return ctx.to_dict() - - -def generate( - model: CompileResultNode, - config: RuntimeConfig, - manifest: Manifest, - provider: Provider, - source_config=None, -) -> Dict[str, Any]: - """ - Not meant to be called directly. 
Call with either:
-        dbt.context.parser.generate
-    or
-        dbt.context.runtime.generate
-    """
-    ctx = ModelContext(model, config, manifest, provider, source_config)
-    return ctx.to_dict()
diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py
new file mode 100644
index 00000000000..6260a96cd13
--- /dev/null
+++ b/core/dbt/context/configured.py
@@ -0,0 +1,111 @@
+from typing import Callable, Any, Dict, Iterable, Union
+
+from dbt.contracts.connection import AdapterRequiredConfig
+from dbt.contracts.graph.manifest import Manifest
+from dbt.contracts.graph.parsed import ParsedMacro
+from dbt.include.global_project import PACKAGES
+from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME
+
+from dbt.context.base import contextproperty
+from dbt.context.target import TargetContext
+
+
+class ConfiguredContext(TargetContext):
+    config: AdapterRequiredConfig
+
+    def __init__(
+        self, config: AdapterRequiredConfig
+    ) -> None:
+        super().__init__(config, config.cli_vars)
+
+    @contextproperty
+    def project_name(self) -> str:
+        return self.config.project_name
+
+
+class _MacroNamespace:
+    def __init__(self, root_package, search_package):
+        self.root_package = root_package
+        self.search_package = search_package
+        self.globals: Dict[str, Callable] = {}
+        self.locals: Dict[str, Callable] = {}
+        self.packages: Dict[str, Dict[str, Callable]] = {}
+
+    def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
+        macro_name: str = macro.name
+        macro_func: Callable = macro.generator(ctx)
+
+        # put plugin macros into the namespace
+        if macro.package_name in PACKAGES:
+            namespace: str = GLOBAL_PROJECT_NAME
+        else:
+            namespace = macro.package_name
+
+        if namespace not in self.packages:
+            value: Dict[str, Callable] = {}
+            self.packages[namespace] = value
+
+        # TODO: if macro_name exists already, that means you had duplicate
+        # names in the same namespace (probably a plugin vs core/multiple
+        # plugins issue). That should be an error, right?
+        self.packages[namespace][macro_name] = macro_func
+
+        if namespace == self.search_package:
+            self.locals[macro_name] = macro_func
+        elif namespace in {self.root_package, GLOBAL_PROJECT_NAME}:
+            self.globals[macro_name] = macro_func
+
+    def add_macros(self, macros: Iterable[ParsedMacro], ctx: Dict[str, Any]):
+        for macro in macros:
+            self.add_macro(macro, ctx)
+
+    def get_macro_dict(self) -> Dict[str, Any]:
+        root_namespace: Dict[str, Union[Dict[str, Callable], Callable]] = {}
+
+        root_namespace.update(self.packages)
+        root_namespace.update(self.globals)
+        root_namespace.update(self.locals)
+
+        return root_namespace
+
+
+class ManifestContext(ConfiguredContext):
+    """The Macro context has everything in the target context, plus the macros
+    in the manifest.
+
+    The given macros can override any previous context values, which will be
+    available as if they were accessed relative to the package name.
+ """ + def __init__( + self, + config: AdapterRequiredConfig, + manifest: Manifest, + search_package: str, + ) -> None: + super().__init__(config) + self.manifest = manifest + self.search_package = search_package + + def get_macros(self) -> Dict[str, Any]: + nsp = _MacroNamespace(self.config.project_name, self.search_package) + nsp.add_macros(self.manifest.macros.values(), self._ctx) + return nsp.get_macro_dict() + + def to_dict(self) -> Dict[str, Any]: + dct = super().to_dict() + dct.update(self.get_macros()) + return dct + + +class QueryHeaderContext(ManifestContext): + def __init__( + self, config: AdapterRequiredConfig, manifest: Manifest + ) -> None: + super().__init__(config, manifest, config.project_name) + + +def generate_query_header_context( + config: AdapterRequiredConfig, manifest: Manifest +): + ctx = QueryHeaderContext(config, manifest) + return ctx.to_dict() diff --git a/core/dbt/context/docs.py b/core/dbt/context/docs.py new file mode 100644 index 00000000000..6bed0600b09 --- /dev/null +++ b/core/dbt/context/docs.py @@ -0,0 +1,105 @@ +from typing import ( + Any, Optional, List, Dict, Union +) + +from dbt.exceptions import ( + doc_invalid_args, + doc_target_not_found, +) +from dbt.config.runtime import RuntimeConfig +from dbt.contracts.graph.compiled import CompileResultNode +from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.parsed import Docref, ParsedMacro + +from dbt.context.base import contextmember +from dbt.context.configured import ConfiguredContext + + +class DocsParseContext(ConfiguredContext): + def __init__( + self, + config: RuntimeConfig, + node: Any, + docrefs: List[Docref], + column_name: Optional[str], + ) -> None: + super().__init__(config) + self.node = node + self.docrefs = docrefs + self.column_name = column_name + + @contextmember + def doc(self, *args: str) -> str: + # when you call doc(), this is what happens at parse time + if len(args) != 1 and len(args) != 2: + doc_invalid_args(self.node, args) + doc_package_name = '' + doc_name = args[0] + if len(args) == 2: + doc_package_name = args[1] + + docref = Docref(documentation_package=doc_package_name, + documentation_name=doc_name, + column_name=self.column_name) + self.docrefs.append(docref) + + # At parse time, nothing should care about what doc() returns + return '' + + +class DocsRuntimeContext(ConfiguredContext): + def __init__( + self, + config: RuntimeConfig, + node: Union[ParsedMacro, CompileResultNode], + manifest: Manifest, + current_project: str, + ) -> None: + super().__init__(config) + self.node = node + self.manifest = manifest + self.current_project = current_project + + @contextmember + def doc(self, *args: str) -> str: + # when you call doc(), this is what happens at runtime + if len(args) == 1: + doc_package_name = None + doc_name = args[0] + elif len(args) == 2: + doc_package_name, doc_name = args + else: + doc_invalid_args(self.node, args) + + target_doc = self.manifest.resolve_doc( + doc_name, + doc_package_name, + self.current_project, + self.node.package_name, + ) + + if target_doc is None: + doc_target_not_found(self.node, doc_name, doc_package_name) + + return target_doc.block_contents + + +def generate_parser_docs( + config: RuntimeConfig, + unparsed: Any, + docrefs: List[Docref], + column_name: Optional[str] = None, +) -> Dict[str, Any]: + + ctx = DocsParseContext(config, unparsed, docrefs, column_name) + return ctx.to_dict() + + +def generate_runtime_docs( + config: RuntimeConfig, + target: Any, + manifest: Manifest, + current_project: str, +) -> 
Dict[str, Any]: + ctx = DocsRuntimeContext(config, target, manifest, current_project) + return ctx.to_dict() diff --git a/core/dbt/context/operation.py b/core/dbt/context/operation.py deleted file mode 100644 index 97a461fdfe0..00000000000 --- a/core/dbt/context/operation.py +++ /dev/null @@ -1,39 +0,0 @@ -from typing import Optional, Dict, Any - -from dbt.context import runtime -from dbt.context.common import generate_execute_macro -from dbt.exceptions import raise_compiler_error -from dbt.config import RuntimeConfig -from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.parsed import ParsedMacro - - -class RefResolver(runtime.RefResolver): - def __call__(self, *args): - # When you call ref(), this is what happens at operation runtime - target_model, name = self.resolve(args) - return self.create_relation(target_model, name) - - def create_ephemeral_relation(self, target_model, name): - # In operations, we can't ref() ephemeral nodes, because ParsedMacros - # do not support set_cte - raise_compiler_error( - 'Operations can not ref() ephemeral nodes, but {} is ephemeral' - .format(target_model.name), - self.model - ) - - -class Provider(runtime.Provider): - ref = RefResolver - - -def generate( - model: ParsedMacro, - runtime_config: RuntimeConfig, - manifest: Manifest, - package_name: Optional[str] -) -> Dict[str, Any]: - return generate_execute_macro( - model, runtime_config, manifest, Provider(), package_name - ) diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py deleted file mode 100644 index c7b1ed73bfd..00000000000 --- a/core/dbt/context/parser.py +++ /dev/null @@ -1,150 +0,0 @@ -import dbt.exceptions - -import dbt.context.common -from dbt.adapters.factory import get_adapter -from dbt.contracts.graph.parsed import Docref - - -def docs(unparsed, docrefs, column_name=None): - - def do_docs(*args): - if len(args) != 1 and len(args) != 2: - dbt.exceptions.doc_invalid_args(unparsed, args) - doc_package_name = '' - doc_name = args[0] - if len(args) == 2: - doc_package_name = args[1] - - docref = Docref(documentation_package=doc_package_name, - documentation_name=doc_name, - column_name=column_name) - docrefs.append(docref) - - # At parse time, nothing should care about what doc() returns - return '' - - return do_docs - - -class Config: - def __init__(self, model, source_config): - self.model = model - self.source_config = source_config - - def _transform_config(self, config): - for oldkey in ('pre_hook', 'post_hook'): - if oldkey in config: - newkey = oldkey.replace('_', '-') - if newkey in config: - dbt.exceptions.raise_compiler_error( - 'Invalid config, has conflicting keys "{}" and "{}"' - .format(oldkey, newkey), - self.model - ) - config[newkey] = config.pop(oldkey) - return config - - def __call__(self, *args, **kwargs): - if len(args) == 1 and len(kwargs) == 0: - opts = args[0] - elif len(args) == 0 and len(kwargs) > 0: - opts = kwargs - else: - dbt.exceptions.raise_compiler_error( - "Invalid inline model config", - self.model) - - opts = self._transform_config(opts) - - self.source_config.update_in_model_config(opts) - return '' - - def set(self, name, value): - return self.__call__({name: value}) - - def require(self, name, validator=None): - return '' - - def get(self, name, validator=None, default=None): - return '' - - -class DatabaseWrapper(dbt.context.common.BaseDatabaseWrapper): - """The parser subclass of the database wrapper applies any explicit - parse-time overrides. 
- """ - def __getattr__(self, name): - override = (name in self.adapter._available_ and - name in self.adapter._parse_replacements_) - - if override: - return self.adapter._parse_replacements_[name] - elif name in self.adapter._available_: - return getattr(self.adapter, name) - else: - raise AttributeError( - "'{}' object has no attribute '{}'".format( - self.__class__.__name__, name - ) - ) - - -class Var(dbt.context.base.Var): - def get_missing_var(self, var_name): - # in the parser, just always return None. - return None - - -class RefResolver(dbt.context.common.BaseResolver): - def __call__(self, *args): - # When you call ref(), this is what happens at parse time - if len(args) == 1 or len(args) == 2: - self.model.refs.append(list(args)) - - else: - dbt.exceptions.ref_invalid_args(self.model, args) - - return self.Relation.create_from(self.config, self.model) - - -class SourceResolver(dbt.context.common.BaseResolver): - def __call__(self, *args): - # When you call source(), this is what happens at parse time - if len(args) == 2: - self.model.sources.append(list(args)) - - else: - dbt.exceptions.raise_compiler_error( - "source() takes exactly two arguments ({} given)" - .format(len(args)), self.model) - - return self.Relation.create_from(self.config, self.model) - - -class Provider(dbt.context.common.Provider): - execute = False - Config = Config - DatabaseWrapper = DatabaseWrapper - Var = Var - ref = RefResolver - source = SourceResolver - - -def generate(model, runtime_config, manifest, source_config): - # during parsing, we don't have a connection, but we might need one, so we - # have to acquire it. - # In the future, it would be nice to lazily open the connection, as in some - # projects it would be possible to parse without connecting to the db - with get_adapter(runtime_config).connection_for(model): - return dbt.context.common.generate( - model, runtime_config, manifest, Provider(), source_config - ) - - -def generate_macro(model, runtime_config, manifest, package_name): - # parser.generate_macro is called by the get_${attr}_func family of Parser - # methods, which preparse and cache the generate_${attr}_name family of - # macros for use during parsing - return dbt.context.common.generate_execute_macro( - model, runtime_config, manifest, Provider(), package_name - ) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py new file mode 100644 index 00000000000..2cd90b86d89 --- /dev/null +++ b/core/dbt/context/providers.py @@ -0,0 +1,735 @@ +import abc +import os +from typing import ( + Callable, Any, Dict, Optional, Union, List, TypeVar, Type +) +from typing_extensions import Protocol + + +from dbt.adapters.base.column import Column +from dbt.adapters.factory import get_adapter +from dbt.clients import agate_helper +from dbt.clients.jinja import get_rendered +from dbt.config import RuntimeConfig +from dbt.context.base import ( + contextmember, contextproperty, Var +) +from dbt.context.configured import ManifestContext +from dbt.contracts.graph.manifest import Manifest, Disabled +from dbt.contracts.graph.compiled import ( + NonSourceNode, CompiledSeedNode +) +from dbt.contracts.graph.parsed import ( + ParsedMacro, ParsedSourceDefinition, ParsedSeedNode +) +from dbt.exceptions import ( + InternalException, + ValidationException, + missing_config, + raise_compiler_error, + ref_invalid_args, + ref_target_not_found, + ref_bad_context, + source_target_not_found, + wrapped_exports, +) +from dbt.logger import GLOBAL_LOGGER as logger # noqa +from dbt.node_types import 
NodeType +from dbt.source_config import SourceConfig + +from dbt.utils import ( + get_materialization, add_ephemeral_model_prefix, merge, AttrDict +) + +import agate + + +_MISSING = object() + + +# base classes +class RelationProxy: + def __init__(self, adapter): + self.quoting_config = adapter.config.quoting + self.relation_type = adapter.Relation + + def __getattr__(self, key): + return getattr(self.relation_type, key) + + def create_from_source(self, *args, **kwargs): + # bypass our create when creating from source so as not to mess up + # the source quoting + return self.relation_type.create_from_source(*args, **kwargs) + + def create(self, *args, **kwargs): + kwargs['quote_policy'] = merge( + self.quoting_config, + kwargs.pop('quote_policy', {}) + ) + return self.relation_type.create(*args, **kwargs) + + +class BaseDatabaseWrapper: + """ + Wrapper for runtime database interaction. Applies the runtime quote policy + via a relation proxy. + """ + def __init__(self, adapter): + self.adapter = adapter + self.Relation = RelationProxy(adapter) + + def __getattr__(self, name): + raise NotImplementedError('subclasses need to implement this') + + @property + def config(self): + return self.adapter.config + + def type(self): + return self.adapter.type() + + def commit(self): + return self.adapter.commit_if_has_connection() + + +class BaseResolver(metaclass=abc.ABCMeta): + def __init__(self, db_wrapper, model, config, manifest): + self.db_wrapper = db_wrapper + self.model = model + self.config = config + self.manifest = manifest + + @property + def current_project(self): + return self.config.project_name + + @property + def Relation(self): + return self.db_wrapper.Relation + + @abc.abstractmethod + def __call__(self, *args: str) -> Union[str, RelationProxy]: + pass + + +class BaseRefResolver(BaseResolver): + @abc.abstractmethod + def resolve( + self, name: str, package: Optional[str] = None + ) -> RelationProxy: + ... + + def _repack_args( + self, name: str, package: Optional[str] + ) -> List[str]: + if package is None: + return [name] + else: + return [package, name] + + def __call__(self, *args: str) -> RelationProxy: + name: str + package: Optional[str] = None + + if len(args) == 1: + name = args[0] + elif len(args) == 2: + package, name = args + else: + ref_invalid_args(self.model, args) + return self.resolve(name, package) + + +class BaseSourceResolver(BaseResolver): + @abc.abstractmethod + def resolve(self, source_name: str, table_name: str): + pass + + def __call__(self, *args: str) -> RelationProxy: + if len(args) != 2: + raise_compiler_error( + f"source() takes exactly two arguments ({len(args)} given)", + self.model + ) + return self.resolve(args[0], args[1]) + + +class Config(Protocol): + def __init__(self, model, source_config): + ... 
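# A hedged sketch of how these protocols are meant to compose. The concrete
# bundles are not visible in this hunk, so `ParseProvider` below is an assumed
# name, by analogy with the deleted dbt.context.parser.Provider:
#
#     class ParseProvider:
#         execute = False
#         Config = ParseConfigObject
#         DatabaseWrapper = ParseDatabaseWrapper
#         Var = ParseVar
#         ref = ParseRefResolver
#         source = ParseSourceResolver
#
# A provider context then instantiates provider.ref(db_wrapper, model,
# config, manifest) and exposes the result as the Jinja `ref` callable, so
# `{{ ref('my_model') }}` records a dependency at parse time and resolves to
# a relation at runtime.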
+ + +class Provider(Protocol): + execute: bool + Config: Type[Config] + DatabaseWrapper: Type[BaseDatabaseWrapper] + Var: Type[Var] + ref: Type[BaseRefResolver] + source: Type[BaseSourceResolver] + + +# `config` implementations +class ParseConfigObject(Config): + def __init__(self, model, source_config): + self.model = model + self.source_config = source_config + + def _transform_config(self, config): + for oldkey in ('pre_hook', 'post_hook'): + if oldkey in config: + newkey = oldkey.replace('_', '-') + if newkey in config: + raise_compiler_error( + 'Invalid config, has conflicting keys "{}" and "{}"' + .format(oldkey, newkey), + self.model + ) + config[newkey] = config.pop(oldkey) + return config + + def __call__(self, *args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + opts = args[0] + elif len(args) == 0 and len(kwargs) > 0: + opts = kwargs + else: + raise_compiler_error( + "Invalid inline model config", + self.model) + + opts = self._transform_config(opts) + + self.source_config.update_in_model_config(opts) + return '' + + def set(self, name, value): + return self.__call__({name: value}) + + def require(self, name, validator=None): + return '' + + def get(self, name, validator=None, default=None): + return '' + + +class RuntimeConfigObject(Config): + def __init__(self, model, source_config=None): + self.model = model + # we never use or get a source config, only the parser cares + + def __call__(self, *args, **kwargs): + return '' + + def set(self, name, value): + return self.__call__({name: value}) + + def _validate(self, validator, value): + validator(value) + + def _lookup(self, name, default=_MISSING): + config = self.model.config + + if hasattr(config, name): + return getattr(config, name) + elif name in config.extra: + return config.extra[name] + elif default is not _MISSING: + return default + else: + missing_config(self.model, name) + + def require(self, name, validator=None): + to_return = self._lookup(name) + + if validator is not None: + self._validate(validator, to_return) + + return to_return + + def get(self, name, validator=None, default=None): + to_return = self._lookup(name, default) + + if validator is not None and default is not None: + self._validate(validator, to_return) + + return to_return + + +# `adapter` implementations +class ParseDatabaseWrapper(BaseDatabaseWrapper): + """The parser subclass of the database wrapper applies any explicit + parse-time overrides. + """ + def __getattr__(self, name): + override = (name in self.adapter._available_ and + name in self.adapter._parse_replacements_) + + if override: + return self.adapter._parse_replacements_[name] + elif name in self.adapter._available_: + return getattr(self.adapter, name) + else: + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, name + ) + ) + + +class RuntimeDatabaseWrapper(BaseDatabaseWrapper): + """The runtime database wrapper exposes everything the adapter marks + available. 
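+    Unlike the parse-time wrapper, no parse replacements are applied here.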
+ """ + def __getattr__(self, name): + if name in self.adapter._available_: + return getattr(self.adapter, name) + else: + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, name + ) + ) + + +# `ref` implementations +class ParseRefResolver(BaseRefResolver): + def resolve( + self, name: str, package: Optional[str] = None + ) -> RelationProxy: + self.model.refs.append(self._repack_args(name, package)) + + return self.Relation.create_from(self.config, self.model) + + +ResolveRef = Union[Disabled, NonSourceNode] + + +class RuntimeRefResolver(BaseRefResolver): + def resolve( + self, target_name: str, target_package: Optional[str] = None + ) -> RelationProxy: + target_model = self.manifest.resolve_ref( + target_name, + target_package, + self.current_project, + self.model.package_name, + ) + + if target_model is None or isinstance(target_model, Disabled): + ref_target_not_found( + self.model, + target_name, + target_package, + ) + self.validate(target_model, target_name, target_package) + return self.create_relation(target_model, target_name) + + def create_ephemeral_relation( + self, target_model: NonSourceNode, name: str + ) -> RelationProxy: + self.model.set_cte(target_model.unique_id, None) + return self.Relation.create( + type=self.Relation.CTE, + identifier=add_ephemeral_model_prefix(name) + ).quote(identifier=False) + + def create_relation( + self, target_model: NonSourceNode, name: str + ) -> RelationProxy: + if get_materialization(target_model) == 'ephemeral': + return self.create_ephemeral_relation(target_model, name) + else: + return self.Relation.create_from(self.config, target_model) + + def validate( + self, + resolved: NonSourceNode, + target_name: str, + target_package: Optional[str] + ) -> None: + if resolved.unique_id not in self.model.depends_on.nodes: + args = self._repack_args(target_name, target_package) + ref_bad_context(self.model, args) + + +class OperationRefResolver(RuntimeRefResolver): + def validate( + self, + resolved: NonSourceNode, + target_name: str, + target_package: Optional[str], + ) -> None: + pass + + def create_ephemeral_relation( + self, target_model: NonSourceNode, name: str + ) -> RelationProxy: + # In operations, we can't ref() ephemeral nodes, because ParsedMacros + # do not support set_cte + raise_compiler_error( + 'Operations can not ref() ephemeral nodes, but {} is ephemeral' + .format(target_model.name), + self.model + ) + + +# `source` implementations +class ParseSourceResolver(BaseSourceResolver): + def resolve(self, source_name: str, table_name: str): + # When you call source(), this is what happens at parse time + self.model.sources.append([source_name, table_name]) + return self.Relation.create_from(self.config, self.model) + + +class RuntimeSourceResolver(BaseSourceResolver): + def resolve(self, source_name: str, table_name: str): + target_source = self.manifest.resolve_source( + source_name, + table_name, + self.current_project, + self.model.package_name, + ) + + if target_source is None: + source_target_not_found( + self.model, + source_name, + table_name, + ) + return self.Relation.create_from_source(target_source) + + +# `var` implementations. +class ParseVar(Var): + def get_missing_var(self, var_name): + # in the parser, just always return None. 
+ return None + + +class RuntimeVar(Var): + pass + + +# Providers +class ParseProvider(Provider): + execute = False + Config = ParseConfigObject + DatabaseWrapper = ParseDatabaseWrapper + Var = ParseVar + ref = ParseRefResolver + source = ParseSourceResolver + + +class RuntimeProvider(Provider): + execute = True + Config = RuntimeConfigObject + DatabaseWrapper = RuntimeDatabaseWrapper + Var = RuntimeVar + ref = RuntimeRefResolver + source = RuntimeSourceResolver + + +class OperationProvider(RuntimeProvider): + ref = OperationRefResolver + + +T = TypeVar('T') + + +# Base context collection, used for parsing configs. +class ProviderContext(ManifestContext): + def __init__(self, model, config, manifest, provider, source_config): + if provider is None: + raise InternalException( + f"Invalid provider given to context: {provider}" + ) + super().__init__(config, manifest, model.package_name) + self.sql_results: Dict[str, AttrDict] = {} + self.model: Union[ParsedMacro, NonSourceNode] = model + self.source_config = source_config + self.provider: Provider = provider + self.adapter = get_adapter(self.config) + self.db_wrapper = self.provider.DatabaseWrapper(self.adapter) + + @contextproperty + def _sql_results(self) -> Dict[str, AttrDict]: + return self.sql_results + + @contextmember + def load_result(self, name: str) -> Optional[AttrDict]: + return self.sql_results.get(name) + + @contextmember + def store_result( + self, name: str, status: Any, agate_table: Optional[agate.Table] = None + ) -> str: + if agate_table is None: + agate_table = agate_helper.empty_table() + + self.sql_results[name] = AttrDict({ + 'status': status, + 'data': agate_helper.as_matrix(agate_table), + 'table': agate_table + }) + return '' + + @contextproperty + def validation(self): + def validate_any(*args) -> Callable[[T], None]: + def inner(value: T) -> None: + for arg in args: + if isinstance(arg, type) and isinstance(value, arg): + return + elif value == arg: + return + raise ValidationException( + 'Expected value "{}" to be one of {}' + .format(value, ','.join(map(str, args)))) + return inner + + return AttrDict({ + 'any': validate_any, + }) + + @contextmember + def write(self, payload: str) -> str: + # macros/source defs aren't 'writeable'. 
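+        # only real nodes have a build_path/write_node to persist output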
+        if isinstance(self.model, (ParsedMacro, ParsedSourceDefinition)):
+            raise_compiler_error(
+                'cannot "write" macros or sources'
+            )
+        self.model.build_path = self.model.write_node(
+            self.config.target_path, 'run', payload
+        )
+        return ''
+
+    @contextmember
+    def render(self, string: str) -> str:
+        return get_rendered(string, self._ctx, self.model)
+
+    @contextmember
+    def try_or_compiler_error(
+        self, message_if_exception: str, func: Callable, *args, **kwargs
+    ) -> Any:
+        try:
+            return func(*args, **kwargs)
+        except Exception:
+            raise_compiler_error(
+                message_if_exception, self.model
+            )
+
+    @contextmember
+    def load_agate_table(self) -> agate.Table:
+        if not isinstance(self.model, (ParsedSeedNode, CompiledSeedNode)):
+            raise_compiler_error(
+                'can only load_agate_table for seeds (got a {})'
+                .format(self.model.resource_type)
+            )
+        path = self.model.seed_file_path
+        column_types = self.model.config.column_types
+        try:
+            table = agate_helper.from_csv(path, text_columns=column_types)
+        except ValueError as e:
+            raise_compiler_error(str(e))
+        table.original_abspath = os.path.abspath(path)
+        return table
+
+    @contextproperty
+    def ref(self) -> Callable:
+        return self.provider.ref(
+            self.db_wrapper, self.model, self.config, self.manifest
+        )
+
+    @contextproperty
+    def source(self) -> Callable:
+        return self.provider.source(
+            self.db_wrapper, self.model, self.config, self.manifest
+        )
+
+    @contextproperty('config')
+    def ctx_config(self) -> Config:
+        return self.provider.Config(self.model, self.source_config)
+
+    @contextproperty
+    def execute(self) -> bool:
+        return self.provider.execute
+
+    @contextproperty
+    def exceptions(self) -> Dict[str, Any]:
+        return wrapped_exports(self.model)
+
+    @contextproperty
+    def database(self) -> str:
+        return self.config.credentials.database
+
+    @contextproperty
+    def schema(self) -> str:
+        return self.config.credentials.schema
+
+    @contextproperty
+    def var(self) -> Var:
+        return self.provider.Var(
+            self.model, context=self._ctx, overrides=self.config.cli_vars
+        )
+
+    @contextproperty('adapter')
+    def ctx_adapter(self) -> BaseDatabaseWrapper:
+        return self.db_wrapper
+
+    @contextproperty
+    def api(self) -> Dict[str, Any]:
+        return {
+            'Relation': self.db_wrapper.Relation,
+            'Column': self.adapter.Column,
+        }
+
+    @contextproperty
+    def column(self) -> Type[Column]:
+        return self.adapter.Column
+
+    @contextproperty
+    def env(self) -> Dict[str, Any]:
+        return self.target
+
+    @contextproperty
+    def graph(self) -> Dict[str, Any]:
+        return self.manifest.flat_graph
+
+    @contextproperty('model')
+    def ctx_model(self) -> Dict[str, Any]:
+        return self.model.to_dict()
+
+    @contextproperty
+    def pre_hooks(self) -> Optional[List[Dict[str, Any]]]:
+        return None
+
+    @contextproperty
+    def post_hooks(self) -> Optional[List[Dict[str, Any]]]:
+        return None
+
+    @contextproperty
+    def sql(self) -> Optional[str]:
+        return None
+
+    @contextproperty
+    def sql_now(self) -> str:
+        return self.adapter.date_function()
+
+
+class MacroContext(ProviderContext):
+    """Internally, macros can be executed like nodes, with some restrictions:
+
+    - they don't have all the values available that nodes do:
+        - 'this', 'pre_hooks', 'post_hooks', and 'sql' are missing
+        - 'schema' does not use any 'model' information
+    - they can't be configured with config() directives
+    """
+    def __init__(
+        self,
+        model: ParsedMacro,
+        config: RuntimeConfig,
+        manifest: Manifest,
+        provider: Provider,
+        search_package: Optional[str],
+    ) -> None:
+        super().__init__(model, config, manifest, provider,
None)
+        # override the model-based package with the given one
+        if search_package is None:
+            # if the search package name isn't specified, use the root project
+            self._search_package = config.project_name
+        else:
+            self._search_package = search_package
+
+
+class ModelContext(ProviderContext):
+    model: NonSourceNode
+
+    @contextproperty
+    def pre_hooks(self) -> List[Dict[str, Any]]:
+        if isinstance(self.model, ParsedSourceDefinition):
+            return []
+        return [
+            h.to_dict() for h in self.model.config.pre_hook
+        ]
+
+    @contextproperty
+    def post_hooks(self) -> List[Dict[str, Any]]:
+        if isinstance(self.model, ParsedSourceDefinition):
+            return []
+        return [
+            h.to_dict() for h in self.model.config.post_hook
+        ]
+
+    @contextproperty
+    def sql(self) -> Optional[str]:
+        return getattr(self.model, 'injected_sql', None)
+
+    @contextproperty
+    def database(self) -> str:
+        return getattr(
+            self.model, 'database', self.config.credentials.database
+        )
+
+    @contextproperty
+    def schema(self) -> str:
+        return getattr(
+            self.model, 'schema', self.config.credentials.schema
+        )
+
+    @contextproperty
+    def this(self) -> Optional[RelationProxy]:
+        if self.model.resource_type == NodeType.Operation:
+            return None
+        return self.db_wrapper.Relation.create_from(self.config, self.model)
+
+
+def generate_parser_model(
+    model: NonSourceNode,
+    config: RuntimeConfig,
+    manifest: Manifest,
+    source_config: SourceConfig,
+) -> Dict[str, Any]:
+    # during parsing, we don't have a connection, but we might need one, so
+    # the caller is responsible for acquiring one around this call (see
+    # render_with_context in dbt.parser.base).
+    # In the future, it would be nice to lazily open the connection, as in
+    # some projects it would be possible to parse without connecting to the db
+    ctx = ModelContext(
+        model, config, manifest, ParseProvider(), source_config
+    )
+    return ctx.to_dict()
+
+
+def generate_parser_macro(
+    macro: ParsedMacro,
+    config: RuntimeConfig,
+    manifest: Manifest,
+    package_name: Optional[str],
+) -> Dict[str, Any]:
+    ctx = MacroContext(
+        macro, config, manifest, ParseProvider(), package_name
+    )
+    return ctx.to_dict()
+
+
+def generate_runtime_model(
+    model: NonSourceNode,
+    config: RuntimeConfig,
+    manifest: Manifest,
+) -> Dict[str, Any]:
+    ctx = ModelContext(
+        model, config, manifest, RuntimeProvider(), None
+    )
+    return ctx.to_dict()
+
+
+def generate_runtime_macro(
+    macro: ParsedMacro,
+    config: RuntimeConfig,
+    manifest: Manifest,
+    package_name: Optional[str],
+) -> Dict[str, Any]:
+    ctx = MacroContext(
+        macro, config, manifest, OperationProvider(), package_name
+    )
+    return ctx.to_dict()
diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py
deleted file mode 100644
index 39fdeac9ec6..00000000000
--- a/core/dbt/context/runtime.py
+++ /dev/null
@@ -1,169 +0,0 @@
-from typing import Dict, Any
-
-
-import dbt.clients.jinja
-import dbt.context.base
-import dbt.context.common
-import dbt.flags
-
-from dbt.config import RuntimeConfig
-from dbt.contracts.graph.compiled import CompileResultNode
-from dbt.contracts.graph.manifest import Manifest
-from dbt.logger import GLOBAL_LOGGER as logger  # noqa
-from dbt.parser.util import ParserUtils
-from dbt.utils import get_materialization, add_ephemeral_model_prefix
-
-
-class RefResolver(dbt.context.common.BaseResolver):
-    def resolve(self, args):
-        name = None
-        package = None
-
-        if len(args) == 1:
-            name = args[0]
-        elif len(args) == 2:
-            package, name = args
-        else:
-            dbt.exceptions.ref_invalid_args(self.model, args)
-
-        target_model = ParserUtils.resolve_ref(
-            self.manifest,
-            name,
-            package,
-            self.current_project,
- self.model.package_name) - - if target_model is None or target_model is ParserUtils.DISABLED: - dbt.exceptions.ref_target_not_found( - self.model, - name, - package) - return target_model, name - - def create_ephemeral_relation(self, target_model, name): - self.model.set_cte(target_model.unique_id, None) - return self.Relation.create( - type=self.Relation.CTE, - identifier=add_ephemeral_model_prefix(name) - ).quote(identifier=False) - - def create_relation(self, target_model, name): - if get_materialization(target_model) == 'ephemeral': - return self.create_ephemeral_relation(target_model, name) - else: - return self.Relation.create_from(self.config, target_model) - - def validate(self, resolved, args): - if resolved.unique_id not in self.model.depends_on.nodes: - dbt.exceptions.ref_bad_context(self.model, args) - - def __call__(self, *args): - # When you call ref(), this is what happens at runtime - target_model, name = self.resolve(args) - self.validate(target_model, args) - return self.create_relation(target_model, name) - - -class SourceResolver(dbt.context.common.BaseResolver): - def resolve(self, source_name, table_name): - target_source = ParserUtils.resolve_source( - self.manifest, - source_name, - table_name, - self.current_project, - self.model.package_name - ) - - if target_source is None: - dbt.exceptions.source_target_not_found( - self.model, - source_name, - table_name) - return target_source - - def __call__(self, source_name, table_name): - """When you call source(), this is what happens at runtime""" - target_source = self.resolve(source_name, table_name) - return self.Relation.create_from_source(target_source) - - -_MISSING = object() - - -class Config: - def __init__(self, model, source_config=None): - self.model = model - # we never use or get a source config, only the parser cares - - def __call__(self, *args, **kwargs): - return '' - - def set(self, name, value): - return self.__call__({name: value}) - - def _validate(self, validator, value): - validator(value) - - def _lookup(self, name, default=_MISSING): - config = self.model.config - - if hasattr(config, name): - return getattr(config, name) - elif name in config.extra: - return config.extra[name] - elif default is not _MISSING: - return default - else: - dbt.exceptions.missing_config(self.model, name) - - def require(self, name, validator=None): - to_return = self._lookup(name) - - if validator is not None: - self._validate(validator, to_return) - - return to_return - - def get(self, name, validator=None, default=None): - to_return = self._lookup(name, default) - - if validator is not None and default is not None: - self._validate(validator, to_return) - - return to_return - - -class DatabaseWrapper(dbt.context.common.BaseDatabaseWrapper): - """The runtime database wrapper exposes everything the adapter marks - available. 
- """ - def __getattr__(self, name): - if name in self.adapter._available_: - return getattr(self.adapter, name) - else: - raise AttributeError( - "'{}' object has no attribute '{}'".format( - self.__class__.__name__, name - ) - ) - - -class Var(dbt.context.base.Var): - pass - - -class Provider(dbt.context.common.Provider): - execute = True - Config = Config - DatabaseWrapper = DatabaseWrapper - Var = Var - ref = RefResolver - source = SourceResolver - - -def generate( - model: CompileResultNode, runtime_config: RuntimeConfig, manifest: Manifest -) -> Dict[str, Any]: - return dbt.context.common.generate( - model, runtime_config, manifest, Provider(), None - ) diff --git a/core/dbt/context/target.py b/core/dbt/context/target.py new file mode 100644 index 00000000000..489ac42ab26 --- /dev/null +++ b/core/dbt/context/target.py @@ -0,0 +1,36 @@ +from typing import Any, Dict + +from dbt.contracts.connection import HasCredentials + +from dbt.context.base import ( + BaseContext, contextproperty +) + + +class TargetContext(BaseContext): + def __init__(self, config: HasCredentials, cli_vars: Dict[str, Any]): + super().__init__(cli_vars=cli_vars) + self.config = config + + @contextproperty + def target(self) -> Dict[str, Any]: + target = dict( + self.config.credentials.connection_info(with_aliases=True) + ) + target.update({ + 'type': self.config.credentials.type, + 'threads': self.config.threads, + 'name': self.config.target_name, + # not specified, but present for compatibility + 'target_name': self.config.target_name, + 'profile_name': self.config.profile_name, + 'config': self.config.config.to_dict(), + }) + return target + + +def generate_target_context( + config: HasCredentials, cli_vars: Dict[str, Any] +) -> Dict[str, Any]: + ctx = TargetContext(config, cli_vars) + return ctx.to_dict() diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index 4dad8a97dcf..e6214771537 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -2,7 +2,8 @@ import itertools from dataclasses import dataclass, field from typing import ( - Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable + Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List, Callable, + Union ) from typing_extensions import Protocol @@ -13,6 +14,7 @@ from dbt.contracts.util import Replaceable from dbt.exceptions import InternalException +from dbt.helper_types import NoValue from dbt.utils import translate_aliases @@ -150,9 +152,31 @@ def to_dict(self, omit_none=True, validate=False, *, with_aliases=False): return serialized +class UserConfigContract(Protocol): + send_anonymous_usage_stats: bool + use_colors: bool + partial_parse: Optional[bool] + printer_width: Optional[int] + + def set_values(self, cookie_dir: str) -> None: + ... + + def to_dict( + self, omit_none: bool = True, validate: bool = False + ) -> Dict[str, Any]: + ... 
+ + class HasCredentials(Protocol): credentials: Credentials + profile_name: str + config: UserConfigContract + target_name: str + threads: int -class AdapterRequiredConfig(HasCredentials): - query_comment: Optional[str] +class AdapterRequiredConfig(HasCredentials, Protocol): + project_name: str + query_comment: Optional[Union[str, NoValue]] + cli_vars: Dict[str, Any] + target_path: str diff --git a/core/dbt/contracts/graph/compiled.py b/core/dbt/contracts/graph/compiled.py index a5f9637a66c..de2f52517d5 100644 --- a/core/dbt/contracts/graph/compiled.py +++ b/core/dbt/contracts/graph/compiled.py @@ -212,10 +212,8 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource: return cls.from_dict(compiled.to_dict(), validate=False) -# We allow either parsed or compiled nodes, or parsed sources, as some -# 'compile()' calls in the runner actually just return the original parsed -# node they were given. -CompileResultNode = Union[ +# This is anything that can be in manifest.nodes and isn't a Source. +NonSourceNode = Union[ CompiledAnalysisNode, CompiledModelNode, CompiledHookNode, @@ -229,6 +227,13 @@ def parsed_instance_for(compiled: CompiledNode) -> ParsedResource: ParsedRPCNode, ParsedSeedNode, ParsedSnapshotNode, - ParsedSourceDefinition, ParsedTestNode, ] + +# We allow either parsed or compiled nodes, or parsed sources, as some +# 'compile()' calls in the runner actually just return the original parsed +# node they were given. +CompileResultNode = Union[ + NonSourceNode, + ParsedSourceDefinition, +] diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py index 6d95455e98f..ed7732c88b8 100644 --- a/core/dbt/contracts/graph/manifest.py +++ b/core/dbt/contracts/graph/manifest.py @@ -17,7 +17,7 @@ ParsedNode, ParsedMacro, ParsedDocumentation, ParsedNodePatch, ParsedMacroPatch, ParsedSourceDefinition ) -from dbt.contracts.graph.compiled import CompileResultNode +from dbt.contracts.graph.compiled import CompileResultNode, NonSourceNode from dbt.contracts.util import Writable, Replaceable from dbt.exceptions import ( raise_duplicate_resource_name, InternalException, raise_compiler_error, @@ -362,6 +362,11 @@ def search(self, haystack: Iterable[N]) -> Optional[N]: return None +@dataclass +class Disabled: + target: ParsedNode + + @dataclass class Manifest: """The manifest for the full graph, after parsing and during compilation. @@ -444,14 +449,17 @@ def find_docs_by_name( def find_refable_by_name( self, name: str, package: Optional[str] - ) -> Optional[CompileResultNode]: + ) -> Optional[NonSourceNode]: """Find any valid target for "ref()" in the graph by its name and package name, or None for any package. 
""" searcher: NameSearcher = NameSearcher( name, package, NodeType.refable() ) - return searcher.search(self.nodes.values()) + result = searcher.search(self.nodes.values()) + if result is not None: + assert not isinstance(result, ParsedSourceDefinition) + return result def find_source_by_name( self, source_name: str, table_name: str, package: Optional[str] @@ -727,6 +735,93 @@ def expect(self, unique_id: str) -> CompileResultNode: ) return self.nodes[unique_id] + def resolve_ref( + self, + target_model_name: str, + target_model_package: Optional[str], + current_project: str, + node_package: str, + ) -> Optional[Union[NonSourceNode, Disabled]]: + if target_model_package is not None: + return self.find_refable_by_name( + target_model_name, + target_model_package) + + target_model = None + disabled_target = None + + # first pass: look for models in the current_project + # second pass: look for models in the node's package + # final pass: look for models in any package + # todo: exclude the packages we have already searched. overriding + # a package model in another package doesn't necessarily work atm + candidates = [current_project, node_package, None] + for candidate in candidates: + target_model = self.find_refable_by_name( + target_model_name, + candidate) + + if target_model is not None and target_model.config.enabled: + return target_model + + # it's possible that the node is disabled + if disabled_target is None: + disabled_target = self.find_disabled_by_name( + target_model_name, candidate + ) + + if disabled_target is not None: + return Disabled(disabled_target) + return None + + def resolve_source( + self, + target_source_name: str, + target_table_name: str, + current_project: str, + node_package: str + ) -> Optional[ParsedSourceDefinition]: + candidate_targets = [current_project, node_package, None] + target_source = None + for candidate in candidate_targets: + target_source = self.find_source_by_name( + target_source_name, + target_table_name, + candidate + ) + if target_source is not None: + return target_source + + return None + + def resolve_doc( + self, + name: str, + package: Optional[str], + current_project: str, + node_package: str, + ) -> Optional[ParsedDocumentation]: + """Resolve the given documentation. This follows the same algorithm as + resolve_ref except the is_enabled checks are unnecessary as docs are + always enabled. 
+ """ + if package is not None: + return self.find_docs_by_name( + name, package + ) + + candidate_targets = [ + current_project, + node_package, + None, + ] + target_doc = None + for candidate in candidate_targets: + target_doc = self.find_docs_by_name(name, candidate) + if target_doc is not None: + break + return target_doc + @dataclass class WritableManifest(JsonSchemaMixin, Writable): diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 432bb642539..2f66394f878 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -1,3 +1,4 @@ +import os from dataclasses import dataclass, field, Field from typing import ( Optional, @@ -18,6 +19,7 @@ ) from dbt.clients.jinja import MacroGenerator +from dbt.clients.system import write_file import dbt.flags from dbt.contracts.graph.unparsed import ( UnparsedNode, UnparsedMacro, UnparsedDocumentationFile, Quoting, @@ -227,6 +229,14 @@ class ParsedNodeDefaults(ParsedNodeMandatory): patch_path: Optional[str] = None build_path: Optional[str] = None + def write_node(self, target_path: str, subdirectory: str, payload: str): + full_path = os.path.join( + target_path, subdirectory, self.package_name, self.path + ) + + write_file(full_path, payload) + return full_path + @dataclass class ParsedNode(ParsedNodeDefaults, ParsedNodeMixins): diff --git a/core/dbt/contracts/project.py b/core/dbt/contracts/project.py index c6ce52a422b..0dc514ca2f5 100644 --- a/core/dbt/contracts/project.py +++ b/core/dbt/contracts/project.py @@ -1,4 +1,5 @@ from dbt.contracts.util import Replaceable, Mergeable, list_str +from dbt.contracts.connection import UserConfigContract from dbt.logger import GLOBAL_LOGGER as logger # noqa from dbt import tracking from dbt.ui import printer @@ -175,7 +176,7 @@ def from_dict(cls, data, validate=True): @dataclass -class UserConfig(ExtensibleJsonSchemaMixin, Replaceable): +class UserConfig(ExtensibleJsonSchemaMixin, Replaceable, UserConfigContract): send_anonymous_usage_stats: bool = DEFAULT_SEND_ANONYMOUS_USAGE_STATS use_colors: bool = DEFAULT_USE_COLORS partial_parse: Optional[bool] = None @@ -193,13 +194,6 @@ def set_values(self, cookie_dir): if self.printer_width: printer.printer_width(self.printer_width) - @classmethod - def from_maybe_dict(cls, value: Optional[Dict[str, Any]]) -> 'UserConfig': - if value is None: - return cls() - else: - return cls.from_dict(value) - @dataclass class ProfileConfig(HyphenatedJsonSchemaMixin, Replaceable): diff --git a/core/dbt/deps/base.py b/core/dbt/deps/base.py index 15670ab04bd..0baa03ed86d 100644 --- a/core/dbt/deps/base.py +++ b/core/dbt/deps/base.py @@ -69,7 +69,7 @@ def get_version(self) -> Optional[str]: raise NotImplementedError @abc.abstractmethod - def _fetch_metadata(self, project): + def _fetch_metadata(self, project, renderer): raise NotImplementedError @abc.abstractmethod @@ -80,17 +80,17 @@ def install(self, project): def nice_version_name(self): raise NotImplementedError - def fetch_metadata(self, project): + def fetch_metadata(self, project, renderer): if not self._cached_metadata: - self._cached_metadata = self._fetch_metadata(project) + self._cached_metadata = self._fetch_metadata(project, renderer) return self._cached_metadata - def get_project_name(self, project): - metadata = self.fetch_metadata(project) + def get_project_name(self, project, renderer): + metadata = self.fetch_metadata(project, renderer) return metadata.name - def get_installation_path(self, project): - dest_dirname = self.get_project_name(project) + 
def get_installation_path(self, project, renderer): + dest_dirname = self.get_project_name(project, renderer) return os.path.join(project.modules_path, dest_dirname) diff --git a/core/dbt/deps/git.py b/core/dbt/deps/git.py index 52a5fc000ff..238ac754703 100644 --- a/core/dbt/deps/git.py +++ b/core/dbt/deps/git.py @@ -70,7 +70,7 @@ def _checkout(self): raise return os.path.join(get_downloads_path(), dir_) - def _fetch_metadata(self, project) -> ProjectPackageMetadata: + def _fetch_metadata(self, project, renderer) -> ProjectPackageMetadata: path = self._checkout() if self.revision == 'master' and self.warn_unpinned: warn_or_error( @@ -79,11 +79,11 @@ def _fetch_metadata(self, project) -> ProjectPackageMetadata: .format(self.git, PIN_PACKAGE_URL), log_fmt=printer.yellow('WARNING: {}') ) - loaded = Project.from_project_root(path, {}) + loaded = Project.from_project_root(path, renderer) return ProjectPackageMetadata.from_project(loaded) - def install(self, project): - dest_path = self.get_installation_path(project) + def install(self, project, renderer): + dest_path = self.get_installation_path(project, renderer) if os.path.exists(dest_path): if system.path_is_symlink(dest_path): system.remove_file(dest_path) diff --git a/core/dbt/deps/local.py b/core/dbt/deps/local.py index 175c03e434b..9b7b21c3792 100644 --- a/core/dbt/deps/local.py +++ b/core/dbt/deps/local.py @@ -38,13 +38,15 @@ def resolve_path(self, project): project.project_root, ) - def _fetch_metadata(self, project): - loaded = project.from_project_root(self.resolve_path(project), {}) + def _fetch_metadata(self, project, renderer): + loaded = project.from_project_root( + self.resolve_path(project), renderer + ) return ProjectPackageMetadata.from_project(loaded) - def install(self, project): + def install(self, project, renderer): src_path = self.resolve_path(project) - dest_path = self.get_installation_path(project) + dest_path = self.get_installation_path(project, renderer) can_create_symlink = system.supports_symlinks() diff --git a/core/dbt/deps/registry.py b/core/dbt/deps/registry.py index e9200f90e56..5f6b05f6d1b 100644 --- a/core/dbt/deps/registry.py +++ b/core/dbt/deps/registry.py @@ -47,12 +47,12 @@ def get_version(self): def nice_version_name(self): return 'version {}'.format(self.version) - def _fetch_metadata(self, project) -> RegistryPackageMetadata: + def _fetch_metadata(self, project, renderer) -> RegistryPackageMetadata: dct = registry.package_version(self.package, self.version) return RegistryPackageMetadata.from_dict(dct) - def install(self, project): - metadata = self.fetch_metadata(project) + def install(self, project, renderer): + metadata = self.fetch_metadata(project, renderer) tar_name = '{}.{}.tar.gz'.format(self.package, self.version) tar_path = os.path.realpath( @@ -63,7 +63,7 @@ def install(self, project): download_url = metadata.downloads.tarball system.download(download_url, tar_path) deps_path = project.modules_path - package_name = self.get_project_name(project) + package_name = self.get_project_name(project, renderer) system.untar_package(tar_path, deps_path, package_name) diff --git a/core/dbt/deps/resolver.py b/core/dbt/deps/resolver.py index b499f407f18..8f571d35712 100644 --- a/core/dbt/deps/resolver.py +++ b/core/dbt/deps/resolver.py @@ -3,7 +3,8 @@ from dbt.exceptions import raise_dependency_error, InternalException -from dbt.config import Project +from dbt.context.target import generate_target_context +from dbt.config import Project, ConfigRenderer, RuntimeConfig from dbt.deps.base import 
BasePackage, PinnedPackage, UnpinnedPackage from dbt.deps.local import LocalUnpinnedPackage from dbt.deps.git import GitUnpinnedPackage @@ -96,11 +97,11 @@ def __iter__(self) -> Iterator[UnpinnedPackage]: def _check_for_duplicate_project_names( - final_deps: List[PinnedPackage], config: Project + final_deps: List[PinnedPackage], config: Project, renderer: ConfigRenderer ): seen: Set[str] = set() for package in final_deps: - project_name = package.get_project_name(config) + project_name = package.get_project_name(config, renderer) if project_name in seen: raise_dependency_error( f'Found duplicate project "{project_name}". This occurs when ' @@ -117,20 +118,22 @@ def _check_for_duplicate_project_names( def resolve_packages( - packages: List[PackageContract], config: Project + packages: List[PackageContract], config: RuntimeConfig ) -> List[PinnedPackage]: pending = PackageListing.from_contracts(packages) final = PackageListing() + renderer = ConfigRenderer(generate_target_context(config, config.cli_vars)) + while pending: next_pending = PackageListing() # resolve the dependency in question for package in pending: final.incorporate(package) - target = final[package].resolved().fetch_metadata(config) + target = final[package].resolved().fetch_metadata(config, renderer) next_pending.update_from(target.packages) pending = next_pending resolved = final.resolved() - _check_for_duplicate_project_names(resolved, config) + _check_for_duplicate_project_names(resolved, config, renderer) return resolved diff --git a/core/dbt/exceptions.py b/core/dbt/exceptions.py index 73a23ae8a47..3255f433501 100644 --- a/core/dbt/exceptions.py +++ b/core/dbt/exceptions.py @@ -372,7 +372,7 @@ def raise_dependency_error(msg) -> NoReturn: def invalid_type_error(method_name, arg_name, got_value, expected_type, - version='0.13.0'): + version='0.13.0') -> NoReturn: """Raise a CompilationException when an adapter method available to macros has changed. 
""" @@ -385,13 +385,13 @@ def invalid_type_error(method_name, arg_name, got_value, expected_type, got_value=got_value, got_type=got_type)) -def ref_invalid_args(model, args): +def ref_invalid_args(model, args) -> NoReturn: raise_compiler_error( "ref() takes at most two arguments ({} given)".format(len(args)), model) -def ref_bad_context(model, args): +def ref_bad_context(model, args) -> NoReturn: ref_args = ', '.join("'{}'".format(a) for a in args) ref_string = '{{{{ ref({}) }}}}'.format(ref_args) @@ -419,7 +419,7 @@ def ref_bad_context(model, args): raise_compiler_error(error_msg, model) -def doc_invalid_args(model, args): +def doc_invalid_args(model, args) -> NoReturn: raise_compiler_error( "doc() takes at most two arguments ({} given)".format(len(args)), model) @@ -497,7 +497,7 @@ def source_disabled_message(model, target_name, target_table_name): target_table_name)) -def source_target_not_found(model, target_name, target_table_name): +def source_target_not_found(model, target_name, target_table_name) -> NoReturn: msg = source_disabled_message(model, target_name, target_table_name) raise_compiler_error(msg, model) diff --git a/core/dbt/graph/selector.py b/core/dbt/graph/selector.py index 147b6a2ef65..7a788498eb1 100644 --- a/core/dbt/graph/selector.py +++ b/core/dbt/graph/selector.py @@ -5,7 +5,7 @@ import networkx as nx # type: ignore from dbt.logger import GLOBAL_LOGGER as logger -from dbt.utils import is_enabled, coalesce +from dbt.utils import coalesce from dbt.node_types import NodeType import dbt.exceptions @@ -351,7 +351,7 @@ def _is_graph_member(self, node_name): node = self.manifest.nodes[node_name] if node.resource_type == NodeType.Source: return True - return not node.empty and is_enabled(node) + return not node.empty and node.config.enabled def _is_match(self, node_name, resource_types, tags, required): node = self.manifest.nodes[node_name] diff --git a/core/dbt/node_runners.py b/core/dbt/node_runners.py index da881cd8704..876133d5207 100644 --- a/core/dbt/node_runners.py +++ b/core/dbt/node_runners.py @@ -18,7 +18,7 @@ ) from dbt.compilation import compile_node -import dbt.context.runtime +from dbt.context.providers import generate_runtime_model import dbt.exceptions import dbt.utils import dbt.tracking @@ -429,8 +429,9 @@ def _materialization_relations( raise CompilationException(msg, node=model) def execute(self, model, manifest): - context = dbt.context.runtime.generate( - model, self.config, manifest) + context = generate_runtime_model( + model, self.config, manifest + ) materialization_macro = manifest.find_materialization_macro_by_name( self.config.project_name, diff --git a/core/dbt/parser/__init__.py b/core/dbt/parser/__init__.py index b5855f1fead..c509e357f75 100644 --- a/core/dbt/parser/__init__.py +++ b/core/dbt/parser/__init__.py @@ -9,9 +9,8 @@ from .schemas import SchemaParser # noqa from .seeds import SeedParser # noqa from .snapshots import SnapshotParser # noqa -from .util import ParserUtils # noqa from . 
import ( # noqa analysis, base, data_test, docs, hooks, macros, models, results, schemas, - snapshots, util + snapshots ) diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index a44e7eeff99..4f48be019b9 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -7,10 +7,11 @@ from hologram import ValidationError -import dbt.context.parser +from dbt.context.providers import generate_parser_model, generate_parser_macro import dbt.flags from dbt import deprecations from dbt import hooks +from dbt.adapters.factory import get_adapter from dbt.clients.jinja import get_rendered from dbt.config import Project, RuntimeConfig from dbt.contracts.graph.manifest import ( @@ -128,13 +129,10 @@ def default_database(self): return self.root_project.credentials.database def _build_generate_macro_function(self, macro: ParsedMacro) -> Callable: - context = dbt.context.parser.generate_macro( - model=macro, - runtime_config=self.root_project, - manifest=self.macro_manifest, - package_name=None, + root_context = generate_parser_macro( + macro, self.root_project, self.macro_manifest, None ) - return macro.generator(context) + return macro.generator(root_context) def get_schema_func(self) -> RelationUpdate: """The get_schema function is set by a few different things: @@ -277,7 +275,7 @@ def _create_parsetime_node( def _context_for( self, parsed_node: IntermediateNode, config: SourceConfig ) -> Dict[str, Any]: - return dbt.context.parser.generate( + return generate_parser_model( parsed_node, self.root_project, self.macro_manifest, config ) @@ -289,10 +287,11 @@ def render_with_context( Note: this mutates the config object when config() calls are rendered. """ - context = self._context_for(parsed_node, config) + with get_adapter(self.root_project).connection_for(parsed_node): + context = self._context_for(parsed_node, config) - get_rendered(parsed_node.raw_sql, context, parsed_node, - capture_macros=True) + get_rendered(parsed_node.raw_sql, context, parsed_node, + capture_macros=True) def update_parsed_node_schema( self, parsed_node: IntermediateNode, config_dict: Dict[str, Any] diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 7b816dcb872..2988ad2d599 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -2,7 +2,7 @@ import os import pickle from datetime import datetime -from typing import Dict, Optional, Mapping, Callable, Any, List, Type +from typing import Dict, Optional, Mapping, Callable, Any, List, Type, Union from dbt.include.global_project import PACKAGES import dbt.exceptions @@ -12,8 +12,12 @@ from dbt.node_types import NodeType from dbt.clients.system import make_directory from dbt.config import Project, RuntimeConfig -from dbt.contracts.graph.compiled import CompileResultNode -from dbt.contracts.graph.manifest import Manifest, FilePath, FileHash +from dbt.context.docs import generate_runtime_docs +from dbt.contracts.graph.compiled import CompileResultNode, NonSourceNode +from dbt.contracts.graph.manifest import Manifest, FilePath, FileHash, Disabled +from dbt.contracts.graph.parsed import ( + ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo +) from dbt.exceptions import raise_compiler_error from dbt.parser.base import BaseParser, Parser from dbt.parser.analysis import AnalysisParser @@ -27,7 +31,6 @@ from dbt.parser.search import FileBlock from dbt.parser.seeds import SeedParser from dbt.parser.snapshots import SnapshotParser -from dbt.parser.util import ParserUtils from dbt.version import __version__ @@ -292,6 +295,12 
@@ def read_parse_results(self) -> Optional[ParseResult]:
         return None
 
+    def process_manifest(self, manifest: Manifest):
+        project_name = self.root_project.project_name
+        process_sources(manifest, project_name)
+        process_refs(manifest, project_name)
+        process_docs(manifest, self.root_project)
+
     def create_manifest(self) -> Manifest:
         nodes: Dict[str, CompileResultNode] = {}
         nodes.update(self.results.nodes)
@@ -310,15 +319,7 @@ def create_manifest(self) -> Manifest:
         )
         manifest.patch_nodes(self.results.patches)
         manifest.patch_macros(self.results.macro_patches)
-        manifest = ParserUtils.process_sources(
-            manifest, self.root_project.project_name
-        )
-        manifest = ParserUtils.process_refs(
-            manifest, self.root_project.project_name
-        )
-        manifest = ParserUtils.process_docs(
-            manifest, self.root_project.project_name
-        )
+        self.process_manifest(manifest)
         return manifest
 
     @classmethod
@@ -420,7 +421,208 @@ def _project_directories(config):
         yield full_obj
 
 
-def load_all_projects(config) -> Mapping[str, Project]:
+def _get_node_column(node, column_name):
+    """Given a ParsedNode, return a reference to the ColumnInfo for the
+    given column, creating it if it doesn't yet exist.
+    """
+    if column_name in node.columns:
+        column = node.columns[column_name]
+    else:
+        column = ColumnInfo(name=column_name)
+        node.columns[column_name] = column
+
+    return column
+
+
+DocsContextCallback = Callable[
+    [Union[ParsedNode, ParsedSourceDefinition]],
+    Dict[str, Any]
+]
+
+
+def _process_docs_for_node(
+    context: Dict[str, Any],
+    node: NonSourceNode,
+):
+    for docref in node.docrefs:
+        column_name = docref.column_name
+
+        if column_name is None:
+            obj = node
+        else:
+            obj = _get_node_column(node, column_name)
+
+        raw = obj.description or ''
+        # At this point, we know that our documentation string has a
+        # 'docs("...")' pointing at it. We want to render it.
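+        # obj is either the node itself or one of its columns, so the
+        # rendered description lands in the right place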
+ obj.description = dbt.clients.jinja.get_rendered(raw, context) + + +def _process_docs_for_source( + context: Dict[str, Any], + source: ParsedSourceDefinition, +): + table_description = source.description + source_description = source.source_description + table_description = dbt.clients.jinja.get_rendered(table_description, + context) + source_description = dbt.clients.jinja.get_rendered(source_description, + context) + source.description = table_description + source.source_description = source_description + + for column in source.columns.values(): + column_desc = column.description + column_desc = dbt.clients.jinja.get_rendered(column_desc, context) + column.description = column_desc + + +def _process_docs_for_macro( + context: Dict[str, Any], macro: ParsedMacro +) -> None: + for docref in macro.docrefs: + raw = macro.description or '' + macro.description = dbt.clients.jinja.get_rendered(raw, context) + + +def process_docs(manifest: Manifest, config: RuntimeConfig): + for node in manifest.nodes.values(): + ctx = generate_runtime_docs( + config, + node, + manifest, + config.project_name, + ) + if node.resource_type == NodeType.Source: + assert isinstance(node, ParsedSourceDefinition) # appease mypy + _process_docs_for_source(ctx, node) + else: + assert not isinstance(node, ParsedSourceDefinition) + _process_docs_for_node(ctx, node) + for macro in manifest.macros.values(): + ctx = generate_runtime_docs( + config, + macro, + manifest, + config.project_name, + ) + _process_docs_for_macro(ctx, macro) + + +def _process_refs_for_node( + manifest: Manifest, current_project: str, node: NonSourceNode +): + """Given a manifest and a node in that manifest, process its refs""" + for ref in node.refs: + target_model: Optional[Union[Disabled, NonSourceNode]] = None + target_model_name: str + target_model_package: Optional[str] = None + + if len(ref) == 1: + target_model_name = ref[0] + elif len(ref) == 2: + target_model_package, target_model_name = ref + else: + raise dbt.exceptions.InternalException( + f'Refs should always be 1 or 2 arguments - got {len(ref)}' + ) + + target_model = manifest.resolve_ref( + target_model_name, + target_model_package, + current_project, + node.package_name, + ) + + if target_model is None or isinstance(target_model, Disabled): + # This may raise. 
Even if it doesn't, we don't want to add
+            # this node to the graph b/c there is no destination node
+            node.config.enabled = False
+            dbt.utils.invalid_ref_fail_unless_test(
+                node, target_model_name, target_model_package,
+                disabled=(isinstance(target_model, Disabled))
+            )
+
+            continue
+
+        target_model_id = target_model.unique_id
+
+        node.depends_on.nodes.append(target_model_id)
+        # TODO: I think this is extraneous, node should already be the same
+        # as manifest.nodes[node.unique_id] (we're mutating node here, not
+        # making a new one)
+        manifest.update_node(node)
+
+
+def process_refs(manifest: Manifest, current_project: str):
+    for node in manifest.nodes.values():
+        if node.resource_type == NodeType.Source:
+            continue
+        assert not isinstance(node, ParsedSourceDefinition)
+        _process_refs_for_node(manifest, current_project, node)
+    return manifest
+
+
+def _process_sources_for_node(
+    manifest: Manifest, current_project: str, node: NonSourceNode
+):
+    target_source = None
+    for source_name, table_name in node.sources:
+        target_source = manifest.resolve_source(
+            source_name,
+            table_name,
+            current_project,
+            node.package_name,
+        )
+
+        if target_source is None:
+            # this follows the same pattern as refs
+            node.config.enabled = False
+            dbt.utils.invalid_source_fail_unless_test(
+                node,
+                source_name,
+                table_name)
+            continue
+        target_source_id = target_source.unique_id
+        node.depends_on.nodes.append(target_source_id)
+        manifest.update_node(node)
+
+
+def process_sources(manifest: Manifest, current_project: str):
+    for node in manifest.nodes.values():
+        if node.resource_type == NodeType.Source:
+            continue
+        assert not isinstance(node, ParsedSourceDefinition)
+        _process_sources_for_node(manifest, current_project, node)
+    return manifest
+
+
+def process_macro(
+    config: RuntimeConfig, manifest: Manifest, macro: ParsedMacro
+) -> None:
+    ctx = generate_runtime_docs(
+        config,
+        macro,
+        manifest,
+        config.project_name,
+    )
+    _process_docs_for_macro(ctx, macro)
+
+
+def process_node(
+    config: RuntimeConfig, manifest: Manifest, node: NonSourceNode
+):
+    _process_sources_for_node(
+        manifest, config.project_name, node
+    )
+    _process_refs_for_node(manifest, config.project_name, node)
+    ctx = generate_runtime_docs(config, node, manifest, config.project_name)
+    _process_docs_for_node(ctx, node)
+
+
+def load_all_projects(config: RuntimeConfig) -> Mapping[str, Project]:
     all_projects = {config.project_name: config}
     project_paths = itertools.chain(
         internal_project_names(),
diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py
index 63b8c729b00..f23a627ca4a 100644
--- a/core/dbt/parser/schemas.py
+++ b/core/dbt/parser/schemas.py
@@ -8,11 +8,12 @@
 
 from hologram import ValidationError
 
-from dbt.context.base import ConfigRenderContext
 from dbt.clients.jinja import get_rendered
 from dbt.clients.yaml_helper import load_yaml_text
-from dbt.config.renderer import ConfigRenderer
+from dbt.config import RuntimeConfig, ConfigRenderer
+from dbt.context.docs import generate_parser_docs
+from dbt.context.target import generate_target_context
 from dbt.contracts.graph.manifest import SourceFile
 from dbt.contracts.graph.parsed import (
     ParsedNodePatch,
@@ -27,7 +28,6 @@
     UnparsedMacroUpdate, UnparsedAnalysisUpdate,
     UnparsedSourceTableDefinition, FreshnessThreshold,
 )
-from dbt.context.parser import docs
 from dbt.exceptions import (
     validator_error_message, JSONValidationException,
     raise_invalid_schema_yml_version, ValidationException, CompilationException
@@ -89,12 +89,13 @@ def add(self, column:
UnparsedColumn, description, data_type, meta): def collect_docrefs( + config: RuntimeConfig, target: UnparsedSchemaYaml, refs: ParserRef, column_name: Optional[str], *descriptions: str, ) -> None: - context = {'doc': docs(target, refs.docrefs, column_name)} + context = generate_parser_docs(config, target, refs.docrefs, column_name) for description in descriptions: get_rendered(description, context) @@ -186,7 +187,9 @@ def parse_node(self, block: SchemaTestBlock) -> ParsedTestNode: builds the initial node to be parsed, but rendering is basically the same """ - render_ctx = ConfigRenderContext(self.root_project.cli_vars).to_dict() + render_ctx = generate_target_context( + self.root_project, self.root_project.cli_vars + ) builder = TestBuilder[Target]( test=block.test, target=block.target, @@ -361,7 +364,13 @@ def parse_docs(self, block: TargetBlock) -> ParserRef: description = column.description data_type = column.data_type meta = column.meta - collect_docrefs(block.target, refs, column_name, description) + collect_docrefs( + self.root_project, + block.target, + refs, + column_name, + description, + ) refs.add(column, description, data_type, meta) return refs @@ -441,7 +450,9 @@ def parse_docs(self, block: TargetBlock) -> ParserRef: description = column.description data_type = column.data_type meta = column.meta - collect_docrefs(block.target, refs, column_name, description) + collect_docrefs( + self.root_project, block.target, refs, column_name, description + ) refs.add(column, description, data_type, meta) return refs @@ -472,7 +483,11 @@ def parse_patch( class SourceParser(YamlDocsReader[SourceTarget, ParsedSourceDefinition]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._renderer = ConfigRenderer(self.root_project.cli_vars) + self._renderer = ConfigRenderer( + generate_target_context( + self.root_project, self.root_project.cli_vars + ) + ) def get_block(self, node: SourceTarget) -> TestBlock: return TestBlock.from_yaml_block(self.yaml, node) @@ -518,7 +533,10 @@ def parse_patch( description = table.description or '' meta = table.meta or {} source_description = source.description or '' - collect_docrefs(source, refs, None, description, source_description) + collect_docrefs( + self.root_project, source, refs, None, description, + source_description + ) loaded_at_field = table.loaded_at_field or source.loaded_at_field @@ -566,7 +584,9 @@ def collect_docrefs( self, block: TargetBlock[NonSourceTarget], refs: ParserRef ) -> str: description = block.target.description - collect_docrefs(block.target, refs, None, description) + collect_docrefs( + self.root_project, block.target, refs, None, description + ) return description @abstractmethod diff --git a/core/dbt/parser/util.py b/core/dbt/parser/util.py deleted file mode 100644 index 7e31a3e5135..00000000000 --- a/core/dbt/parser/util.py +++ /dev/null @@ -1,307 +0,0 @@ -from typing import Optional, Union - -import dbt.exceptions -import dbt.utils -from dbt.node_types import NodeType -from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.graph.parsed import ( - ColumnInfo, ParsedNode, ParsedMacro, ParsedSourceDefinition, - ParsedDocumentation, -) -from dbt.config import RuntimeConfig -from dbt.flags import SINGLE_THREADED_HANDLER - - -def docs(node, manifest, current_project: str, column_name=None): - """Return a function that will process `doc()` references in jinja, look - them up in the manifest, and return the appropriate block contents. 
- """ - def do_docs(*args: str): - if len(args) == 1: - doc_package_name = None - doc_name = args[0] - elif len(args) == 2: - doc_package_name, doc_name = args - else: - dbt.exceptions.doc_invalid_args(node, args) - - target_doc = ParserUtils.resolve_doc( - manifest, doc_name, doc_package_name, current_project, - node.package_name - ) - - if target_doc is None: - dbt.exceptions.doc_target_not_found(node, doc_name, - doc_package_name) - - return target_doc.block_contents - - return do_docs - - -class Disabled: - pass - - -class ParserUtils: - DISABLED = Disabled() - - @classmethod - def resolve_source( - cls, manifest: Manifest, target_source_name: str, - target_table_name: str, current_project: str, - node_package: str - ) -> Optional[ParsedSourceDefinition]: - candidate_targets = [current_project, node_package, None] - target_source = None - for candidate in candidate_targets: - target_source = manifest.find_source_by_name( - target_source_name, - target_table_name, - candidate - ) - if target_source is not None: - return target_source - - return None - - @classmethod - def resolve_ref( - cls, manifest, target_model_name: Optional[str], - target_model_package: Optional[str], current_project: str, - node_package: str - ) -> Optional[Union[ParsedNode, Disabled]]: - if target_model_package is not None: - return manifest.find_refable_by_name( - target_model_name, - target_model_package) - - target_model = None - disabled_target = None - - # first pass: look for models in the current_project - # second pass: look for models in the node's package - # final pass: look for models in any package - # todo: exclude the packages we have already searched. overriding - # a package model in another package doesn't necessarily work atm - candidates = [current_project, node_package, None] - for candidate in candidates: - target_model = manifest.find_refable_by_name( - target_model_name, - candidate) - - if target_model is not None and dbt.utils.is_enabled(target_model): - return target_model - - # it's possible that the node is disabled - if disabled_target is None: - disabled_target = manifest.find_disabled_by_name( - target_model_name, candidate - ) - - if disabled_target is not None: - return cls.DISABLED - return None - - @classmethod - def resolve_doc( - cls, manifest, target_doc_name: str, target_doc_package: Optional[str], - current_project: str, node_package: str - ) -> Optional[ParsedDocumentation]: - """Resolve the given documentation. This follows the same algorithm as - resolve_ref except the is_enabled checks are unnecessary as docs are - always enabled. - """ - if target_doc_package is not None: - return manifest.find_docs_by_name(target_doc_name, - target_doc_package) - - candidate_targets = [current_project, node_package, None] - target_doc = None - for candidate in candidate_targets: - target_doc = manifest.find_docs_by_name(target_doc_name, candidate) - if target_doc is not None: - break - return target_doc - - @classmethod - def _get_node_column(cls, node, column_name): - """Given a ParsedNode, add some fields that might be missing. Return a - reference to the dict that refers to the given column, creating it if - it doesn't yet exist. 
- """ - if column_name in node.columns: - column = node.columns[column_name] - else: - node.columns[column_name] = ColumnInfo(name=column_name) - node.columns[column_name] = column - - return column - - @classmethod - def process_docs_for_node( - cls, - manifest: Manifest, - current_project: str, - node: ParsedNode, - ) -> None: - for docref in node.docrefs: - column_name = docref.column_name - - if column_name is None: - obj = node - else: - obj = cls._get_node_column(node, column_name) - - context = { - 'doc': docs(node, manifest, current_project, column_name), - } - - raw = obj.description or '' - # At this point, we know that our documentation string has a - # 'docs("...")' pointing at it. We want to render it. - obj.description = dbt.clients.jinja.get_rendered(raw, context) - - @classmethod - def process_docs_for_macro( - cls, - manifest: Manifest, - current_project: str, - macro: ParsedMacro, - ) -> None: - for docref in macro.docrefs: - context = { - 'doc': docs(macro, manifest, current_project) - } - raw = macro.description or '' - macro.description = dbt.clients.jinja.get_rendered(raw, context) - - @classmethod - def process_docs_for_source( - cls, - manifest: Manifest, - current_project: str, - source: ParsedSourceDefinition - ) -> None: - context = { - 'doc': docs(source, manifest, current_project), - } - table_description = source.description - source_description = source.source_description - table_description = dbt.clients.jinja.get_rendered(table_description, - context) - source_description = dbt.clients.jinja.get_rendered(source_description, - context) - source.description = table_description - source.source_description = source_description - - for column in source.columns.values(): - column_desc = column.description - column_desc = dbt.clients.jinja.get_rendered(column_desc, context) - column.description = column_desc - - @classmethod - def process_docs(cls, manifest, current_project: str): - for node in manifest.nodes.values(): - if node.resource_type == NodeType.Source: - cls.process_docs_for_source(manifest, current_project, node) - else: - cls.process_docs_for_node(manifest, current_project, node) - for macro in manifest.macros.values(): - cls.process_docs_for_macro(manifest, current_project, macro) - return manifest - - @classmethod - def process_refs_for_node(cls, manifest, current_project: str, node): - """Given a manifest and a node in that manifest, process its refs""" - for ref in node.refs: - target_model = None - target_model_name = None - target_model_package = None - - if len(ref) == 1: - target_model_name = ref[0] - elif len(ref) == 2: - target_model_package, target_model_name = ref - - target_model = cls.resolve_ref( - manifest, - target_model_name, - target_model_package, - current_project, - node.package_name) - - if target_model is None or isinstance(target_model, Disabled): - # This may raise. 
Even if it doesn't, we don't want to add - # this node to the graph b/c there is no destination node - node.config.enabled = False - dbt.utils.invalid_ref_fail_unless_test( - node, target_model_name, target_model_package, - disabled=isinstance(target_model, Disabled) - ) - - continue - - target_model_id = target_model.unique_id - - node.depends_on.nodes.append(target_model_id) - # TODO: I think this is extraneous, node should already be the same - # as manifest.nodes[node.unique_id] (we're mutating node here, not - # making a new one) - manifest.update_node(node) - - @classmethod - def process_refs(cls, manifest, current_project: str): - # process_refs_for_node will mutate this - all_nodes = list(manifest.nodes.values()) - for node in all_nodes: - cls.process_refs_for_node(manifest, current_project, node) - return manifest - - @classmethod - def process_sources_for_node(cls, manifest, current_project: str, node): - target_source = None - for source_name, table_name in node.sources: - target_source = cls.resolve_source( - manifest, - source_name, - table_name, - current_project, - node.package_name) - - if target_source is None: - # this folows the same pattern as refs - node.config.enabled = False - dbt.utils.invalid_source_fail_unless_test( - node, - source_name, - table_name) - continue - target_source_id = target_source.unique_id - node.depends_on.nodes.append(target_source_id) - manifest.update_node(node) - - @classmethod - def process_sources(cls, manifest, current_project: str): - all_nodes = list(manifest.nodes.values()) - for node in all_nodes: - cls.process_sources_for_node(manifest, current_project, node) - return manifest - - @classmethod - def add_new_refs(cls, manifest, config: RuntimeConfig, node, macros): - """Given a new node that is not in the manifest, insert the new node - into it as if it were part of regular ref processing. 
- """ - if config.args.single_threaded or SINGLE_THREADED_HANDLER: - manifest = manifest.deepcopy() - # it's ok for macros to silently override a local project macro name - manifest.macros.update(macros) - - manifest.add_nodes({node.unique_id: node}) - cls.process_sources_for_node( - manifest, config.project_name, node - ) - cls.process_refs_for_node(manifest, config.project_name, node) - cls.process_docs_for_node(manifest, config.project_name, node) - return manifest diff --git a/core/dbt/perf_utils.py b/core/dbt/perf_utils.py index e785315629b..21f9b7045b3 100644 --- a/core/dbt/perf_utils.py +++ b/core/dbt/perf_utils.py @@ -17,7 +17,7 @@ def get_full_manifest(config: RuntimeConfig) -> Manifest: adapter = get_adapter(config) # type: ignore internal: Manifest = adapter.load_internal_manifest() - def set_header(manifest): + def set_header(manifest: Manifest) -> None: adapter.connections.set_query_header(manifest) return load_manifest(config, internal, set_header) diff --git a/core/dbt/task/base.py b/core/dbt/task/base.py index 99722681bc0..2a113c6b831 100644 --- a/core/dbt/task/base.py +++ b/core/dbt/task/base.py @@ -128,20 +128,14 @@ def move_to_nearest_project_dir(args): os.chdir(nearest_project_dir) -class RequiresProjectTask(BaseTask): - @classmethod - def from_args(cls, args): - move_to_nearest_project_dir(args) - return super().from_args(args) - - -class ConfiguredTask(RequiresProjectTask): +class ConfiguredTask(BaseTask): ConfigType = RuntimeConfig def __init__(self, args, config): super().__init__(args, config) register_adapter(self.config) - -class ProjectOnlyTask(RequiresProjectTask): - ConfigType = Project + @classmethod + def from_args(cls, args): + move_to_nearest_project_dir(args) + return super().from_args(args) diff --git a/core/dbt/task/clean.py b/core/dbt/task/clean.py index ad4c3d1edaf..9a50787f4f1 100644 --- a/core/dbt/task/clean.py +++ b/core/dbt/task/clean.py @@ -2,11 +2,11 @@ import os import shutil -from dbt.task.base import ProjectOnlyTask +from dbt.task.base import ConfiguredTask from dbt.logger import GLOBAL_LOGGER as logger -class CleanTask(ProjectOnlyTask): +class CleanTask(ConfiguredTask): def __is_project_path(self, path): proj_path = os.path.abspath('.') diff --git a/core/dbt/task/debug.py b/core/dbt/task/debug.py index d8851a740d8..e2a92398928 100644 --- a/core/dbt/task/debug.py +++ b/core/dbt/task/debug.py @@ -12,7 +12,9 @@ from dbt.links import ProfileConfigDocs from dbt.adapters.factory import get_adapter, register_adapter from dbt.version import get_installed_version -from dbt.config import Project, Profile +from dbt.config import Project, Profile, ConfigRenderer +from dbt.context.base import generate_base_context +from dbt.context.target import generate_target_context from dbt.clients.yaml_helper import load_yaml_text from dbt.ui.printer import green, red @@ -126,9 +128,16 @@ def _load_project(self): self.project_fail_details = FILE_NOT_FOUND return red('ERROR not found') + if self.profile is None: + ctx = generate_base_context(self.cli_vars) + else: + ctx = generate_target_context(self.profile, self.cli_vars) + + renderer = ConfigRenderer(ctx) + try: self.project = Project.from_project_root(self.project_dir, - self.cli_vars) + renderer) except dbt.exceptions.DbtConfigError as exc: self.project_fail_details = str(exc) return red('ERROR invalid') @@ -161,14 +170,16 @@ def _target_found(self): return green('OK found') def _choose_profile_name(self): - assert self.project or self.project_fail_details, \ - '_load_project() required' + project_profile: 
Optional[str] = None + if os.path.exists(self.project_path): + try: + project_profile = load_yaml_text( + dbt.clients.system.load_file_contents(self.project_path) + ).get('profile') + except dbt.exceptions.Exception: + pass - project_profile = None - if self.project: - project_profile = self.project.profile_name - - args_profile = getattr(self.args, 'profile', None) + args_profile: Optional[str] = getattr(self.args, 'profile', None) try: return Profile.pick_profile_name(args_profile, project_profile) @@ -192,19 +203,23 @@ def _choose_profile_name(self): def _choose_target_name(self): has_raw_profile = (self.raw_profile_data and self.profile_name and self.profile_name in self.raw_profile_data) + if not has_raw_profile: + return None + # mypy appeasement, we checked just above assert self.raw_profile_data is not None assert self.profile_name is not None + raw_profile = self.raw_profile_data[self.profile_name] - if has_raw_profile: - raw_profile = self.raw_profile_data[self.profile_name] + renderer = ConfigRenderer(generate_base_context(self.cli_vars)) - target_name, _ = Profile.render_profile( - raw_profile, self.profile_name, - getattr(self.args, 'target', None), self.cli_vars - ) - return target_name - return None + target_name, _ = Profile.render_profile( + raw_profile, + self.profile_name, + getattr(self.args, 'target', None), + renderer + ) + return target_name def _load_profile(self): if not os.path.exists(self.profile_path): @@ -226,9 +241,10 @@ def _load_profile(self): self.profile_name = self._choose_profile_name() self.target_name = self._choose_target_name() + renderer = ConfigRenderer(generate_base_context(self.cli_vars)) try: - self.profile = QueryCommentedProfile.from_args( - self.args, self.profile_name + self.profile = QueryCommentedProfile.render_from_args( + self.args, renderer, self.profile_name ) except dbt.exceptions.DbtConfigError as exc: self.profile_fail_details = str(exc) @@ -250,8 +266,8 @@ def test_dependencies(self): print('') def test_configuration(self): - project_status = self._load_project() profile_status = self._load_profile() + project_status = self._load_project() print('Configuration:') print(' profiles.yml file [{}]'.format(profile_status)) print(' dbt_project.yml file [{}]'.format(project_status)) @@ -335,7 +351,7 @@ def validate_connection(cls, target_dict): raw_profile=profile_data, profile_name='', target_override=target_name, - cli_vars={}, + renderer=ConfigRenderer(generate_base_context({})), ) result = cls.attempt_connection(profile) if result is not None: diff --git a/core/dbt/task/deps.py b/core/dbt/task/deps.py index 17e33e58fdc..68dbf0e6976 100644 --- a/core/dbt/task/deps.py +++ b/core/dbt/task/deps.py @@ -1,24 +1,25 @@ -from typing import Optional - import dbt.utils import dbt.deprecations import dbt.exceptions -from dbt.config import Project +from dbt.config import RuntimeConfig, ConfigRenderer +from dbt.context.target import generate_target_context from dbt.deps.base import downloads_directory from dbt.deps.resolver import resolve_packages from dbt.logger import GLOBAL_LOGGER as logger from dbt.clients import system -from dbt.task.base import ProjectOnlyTask +from dbt.task.base import ConfiguredTask -class DepsTask(ProjectOnlyTask): - def __init__(self, args, config: Optional[Project] = None): +class DepsTask(ConfiguredTask): + def __init__(self, args, config: RuntimeConfig): super().__init__(args=args, config=config) - def track_package_install(self, package_name, source_type, version): + def track_package_install( + self, package_name: 
str, source_type: str, version: str + ) -> None: version = 'local' if source_type == 'local' else version h_package_name = dbt.utils.md5(package_name) @@ -40,9 +41,13 @@ def run(self): with downloads_directory(): final_deps = resolve_packages(packages, self.config) + renderer = ConfigRenderer(generate_target_context( + self.config, self.config.cli_vars + )) + for package in final_deps: logger.info('Installing {}', package) - package.install(self.config) + package.install(self.config, renderer) logger.info(' Installed from {}\n', package.nice_version_name()) diff --git a/core/dbt/task/rpc/sql_commands.py b/core/dbt/task/rpc/sql_commands.py index c9e4c09b971..06cb6bf87f1 100644 --- a/core/dbt/task/rpc/sql_commands.py +++ b/core/dbt/task/rpc/sql_commands.py @@ -1,18 +1,23 @@ import base64 -from datetime import datetime import signal import threading +from datetime import datetime +from typing import Dict, Any +from dbt import flags from dbt.adapters.factory import get_adapter from dbt.clients.jinja import extract_toplevel_blocks from dbt.compilation import compile_manifest, compile_node +from dbt.config.runtime import RuntimeConfig +from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.graph.parsed import ParsedRPCNode from dbt.contracts.rpc import RPCExecParameters from dbt.contracts.rpc import RemoteExecutionResult from dbt.exceptions import RPCKilledException, InternalException from dbt.logger import GLOBAL_LOGGER as logger from dbt.parser.results import ParseResult +from dbt.parser.manifest import process_node, process_macro from dbt.parser.rpc import RPCCallParser, RPCMacroParser -from dbt.parser.util import ParserUtils from dbt.rpc.error import invalid_params from dbt.rpc.node_runners import RPCCompileRunner, RPCExecuteRunner from dbt.task.compile import CompileTask @@ -21,6 +26,27 @@ from .base import RPCTask +def add_new_refs( + manifest: Manifest, + config: RuntimeConfig, + node: ParsedRPCNode, + macros: Dict[str, Any] +) -> None: + """Given a new node that is not in the manifest, insert the new node + into it as if it were part of regular ref processing. + """ + if config.args.single_threaded or flags.SINGLE_THREADED_HANDLER: + manifest = manifest.deepcopy() + # it's ok for macros to silently override a local project macro name + manifest.macros.update(macros) + + for macro in macros.values(): + process_macro(config, manifest, macro) + + manifest.add_nodes({node.unique_id: node}) + process_node(config, manifest, node) + + class RemoteRunSQLTask(RPCTask[RPCExecParameters]): def runtime_cleanup(self, selected_uids): """Do some pre-run cleanup that is usually performed in Task __init__. 
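A note on the new module-level add_new_refs() above: it returns None, and in single-threaded mode it rebinds its local manifest name to a deepcopy, so the node and macros are applied to that copy rather than to the caller's object. That hinges on Python's name-binding rules; a minimal standalone sketch (plain dicts standing in for Manifest, not dbt code):

    def add_entry(registry: dict, copy_first: bool) -> None:
        if copy_first:
            # Rebinds the *local* name only; the caller's dict is untouched.
            registry = dict(registry)
        registry['node'] = 'added'

    real = {}
    add_entry(real, copy_first=True)
    print(real)  # {} -- the mutation happened on the copy
    add_entry(real, copy_first=False)
    print(real)  # {'node': 'added'}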
@@ -117,7 +143,7 @@ def _get_exec_node(self): macro_manifest=self.manifest, ) rpc_node = rpc_parser.parse_remote(sql, self.args.name) - self.manifest = ParserUtils.add_new_refs( + add_new_refs( manifest=self.manifest, config=self.config, node=rpc_node, diff --git a/core/dbt/task/serve.py b/core/dbt/task/serve.py index 25666f5f572..fb7a2e8bec4 100644 --- a/core/dbt/task/serve.py +++ b/core/dbt/task/serve.py @@ -7,10 +7,10 @@ from socketserver import TCPServer from dbt.logger import GLOBAL_LOGGER as logger -from dbt.task.base import ProjectOnlyTask +from dbt.task.base import ConfiguredTask -class ServeTask(ProjectOnlyTask): +class ServeTask(ConfiguredTask): def run(self): os.chdir(self.config.target_path) diff --git a/core/dbt/writer.py b/core/dbt/writer.py deleted file mode 100644 index e36519b2c60..00000000000 --- a/core/dbt/writer.py +++ /dev/null @@ -1,14 +0,0 @@ -import os.path - -import dbt.clients.system - - -def write_node(node, target_path, subdirectory, payload): - node_path = node.path - - full_path = os.path.join(target_path, subdirectory, node.package_name, - node_path) - - dbt.clients.system.write_file(full_path, payload) - - return full_path diff --git a/test/integration/033_event_tracking_test/test_events.py b/test/integration/033_event_tracking_test/test_events.py index 00205b035b6..0489695e679 100644 --- a/test/integration/033_event_tracking_test/test_events.py +++ b/test/integration/033_event_tracking_test/test_events.py @@ -256,9 +256,9 @@ def test__postgres_event_tracking_deps(self): ] expected_contexts = [ - self.build_context('deps', 'start', adapter_type=None), + self.build_context('deps', 'start'), package_context, - self.build_context('deps', 'end', result_type='ok', adapter_type=None) + self.build_context('deps', 'end', result_type='ok') ] self.run_event_test(["deps"], expected_calls, expected_contexts) diff --git a/test/integration/base.py b/test/integration/base.py index c5cbc902c72..6b2cf545f10 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -19,7 +19,7 @@ from dbt.adapters.factory import get_adapter, reset_adapters, register_adapter from dbt.clients.jinja import template_cache from dbt.config import RuntimeConfig -from dbt.context import common +from dbt.context import providers from dbt.logger import GLOBAL_LOGGER as logger, log_manager @@ -748,7 +748,7 @@ def get_connection(self, name=None): """ if name is None: name = '__test' - with patch.object(common, 'get_adapter', return_value=self.adapter): + with patch.object(providers, 'get_adapter', return_value=self.adapter): with self.adapter.connection_named(name): conn = self.adapter.connections.get_thread_connection() yield conn diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index 18247050472..05e738c9a09 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -10,6 +10,7 @@ from dbt.adapters.bigquery import BigQueryAdapter from dbt.adapters.bigquery import BigQueryRelation from dbt.adapters.bigquery.connections import BigQueryConnectionManager +from dbt.adapters.base.query_headers import MacroQueryStringSetter import dbt.exceptions from dbt.logger import GLOBAL_LOGGER as logger # noqa @@ -81,6 +82,8 @@ def get_adapter(self, target): ) adapter = BigQueryAdapter(config) + adapter.connections.query_header = MacroQueryStringSetter(config, MagicMock(macros={})) + self.qh_patch = patch.object(adapter.connections.query_header, 'add') self.mock_query_header_add = self.qh_patch.start() 
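This bigquery setup, and the matching postgres and snowflake test changes further down, repeat one stubbing pattern. A condensed sketch of it, assuming only the names visible in this diff (the mocked manifest just needs a macros attribute):

    from unittest.mock import MagicMock, patch

    from dbt.adapters.base.query_headers import MacroQueryStringSetter

    def install_stub_query_header(adapter, config):
        # query_header now defaults to None, so tests that expect a dbt
        # comment install a MacroQueryStringSetter themselves, then patch
        # its add() method to prepend a fixed marker.
        adapter.connections.query_header = MacroQueryStringSetter(
            config, MagicMock(macros={})
        )
        qh_patch = patch.object(adapter.connections.query_header, 'add')
        mock_add = qh_patch.start()
        mock_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)
        return qh_patch  # the caller stops this in tearDown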
self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q) diff --git a/test/unit/test_config.py b/test/unit/test_config.py index eea3e25ddec..2db240a392b 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -13,6 +13,7 @@ import dbt.exceptions from dbt.adapters.postgres import PostgresCredentials from dbt.adapters.redshift import RedshiftCredentials +from dbt.context.base import generate_base_context from dbt.contracts.project import PackageConfig, LocalPackage, GitPackage from dbt.semver import VersionSpecifier from dbt.task.run_operation import RunOperationTask @@ -33,6 +34,10 @@ def temp_cd(path): os.chdir(current_path) +def empty_renderer(): + return dbt.config.ConfigRenderer(generate_base_context({})) + + model_config = { 'my_package_name': { 'enabled': True, @@ -214,8 +219,9 @@ def setUp(self): super().setUp() def from_raw_profiles(self): + renderer = empty_renderer() return dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default', {} + self.default_profile_data, 'default', renderer ) def test_from_raw_profiles(self): @@ -291,9 +297,10 @@ def test_missing_target(self): self.assertEqual(profile.credentials.type, 'postgres') def test_profile_invalid_project(self): + renderer = empty_renderer() with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'invalid-profile', {} + self.default_profile_data, 'invalid-profile', renderer ) self.assertEqual(exc.exception.result_type, 'invalid_project') @@ -301,9 +308,10 @@ def test_profile_invalid_project(self): self.assertIn('invalid-profile', str(exc.exception)) def test_profile_invalid_target(self): + renderer = empty_renderer() with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default', {}, + self.default_profile_data, 'default', renderer, target_override='nope' ) @@ -313,9 +321,11 @@ def test_profile_invalid_target(self): self.assertIn('- with-vars', str(exc.exception)) def test_no_outputs(self): + renderer = empty_renderer() + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: dbt.config.Profile.from_raw_profiles( - {'some-profile': {'target': 'blah'}}, 'some-profile', {} + {'some-profile': {'target': 'blah'}}, 'some-profile', renderer ) self.assertIn('outputs not specified', str(exc.exception)) self.assertIn('some-profile', str(exc.exception)) @@ -325,23 +335,25 @@ def test_neq(self): self.assertNotEqual(profile, object()) def test_eq(self): + renderer = empty_renderer() profile = dbt.config.Profile.from_raw_profiles( - deepcopy(self.default_profile_data), 'default', {} + deepcopy(self.default_profile_data), 'default', renderer ) other = dbt.config.Profile.from_raw_profiles( - deepcopy(self.default_profile_data), 'default', {} + deepcopy(self.default_profile_data), 'default', renderer ) self.assertEqual(profile, other) def test_invalid_env_vars(self): self.env_override['env_value_port'] = 'hello' + renderer = empty_renderer() with mock.patch.dict(os.environ, self.env_override): with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: dbt.config.Profile.from_raw_profile_info( self.default_profile_data['default'], 'default', - {}, + renderer, target_override='with-vars' ) self.assertIn("not of type 'integer'", str(exc.exception)) @@ -355,10 +367,11 @@ def setUp(self): def from_raw_profile_info(self, raw_profile=None, profile_name='default', **kwargs): if raw_profile is None: raw_profile = 
self.default_profile_data['default'] + renderer = empty_renderer() kw = { 'raw_profile': raw_profile, 'profile_name': profile_name, - 'cli_vars': {}, + 'renderer': renderer, } kw.update(kwargs) return dbt.config.Profile.from_raw_profile_info(**kw) @@ -367,10 +380,10 @@ def from_args(self, project_profile_name='default', **kwargs): kw = { 'args': self.args, 'project_profile_name': project_profile_name, + 'renderer': empty_renderer() } kw.update(kwargs) - return dbt.config.Profile.from_args(**kw) - + return dbt.config.Profile.render_from_args(**kw) def test_profile_simple(self): profile = self.from_args() @@ -492,11 +505,12 @@ def test_invalid_env_vars(self): def test_cli_and_env_vars(self): self.args.target = 'cli-and-env-vars' self.args.vars = '{"cli_value_host": "cli-postgres-host"}' + renderer = dbt.config.ConfigRenderer(generate_base_context({'cli_value_host': 'cli-postgres-host'})) with mock.patch.dict(os.environ, self.env_override): - profile = self.from_args() + profile = self.from_args(renderer=renderer) from_raw = self.from_raw_profile_info( target_override='cli-and-env-vars', - cli_vars={'cli_value_host': 'cli-postgres-host'}, + renderer=renderer, ) self.assertEqual(profile.profile_name, 'default') @@ -526,7 +540,7 @@ def setUp(self): def test_defaults(self): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual(project.project_name, 'my_test_project') self.assertEqual(project.version, '0.0.1') @@ -556,16 +570,16 @@ def test_defaults(self): def test_eq(self): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) other = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual(project, other) def test_neq(self): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertNotEqual(project, object()) @@ -575,14 +589,14 @@ def test_implicit_overrides(self): 'target-path': 'other-target', }) project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual(project.docs_paths, ['other-models', 'data', 'snapshots']) self.assertEqual(project.clean_targets, ['other-target']) def test_hashed_name(self): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual(project.hashed_name(), '754cd47eac1d6f50a5f7cd399ec43da4') @@ -709,7 +723,7 @@ def test_string_run_hooks(self): 'on-run-end': '{{ logging.log_run_end_event() }}', }) project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual( project.on_run_start, @@ -723,25 +737,26 @@ def test_string_run_hooks(self): def test_invalid_project_name(self): self.default_project_data['name'] = 'invalid-project-name' with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: - dbt.config.Project.from_project_config(self.default_project_data) + dbt.config.Project.from_project_config(self.default_project_data, None) self.assertIn('invalid-project-name', str(exc.exception)) def test_no_project(self): + renderer = empty_renderer() with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: - dbt.config.Project.from_project_root(self.project_dir, {}) + dbt.config.Project.from_project_root(self.project_dir, renderer) self.assertIn('no dbt_project.yml', str(exc.exception)) def 
test_invalid_version(self): self.default_project_data['require-dbt-version'] = 'hello!' with self.assertRaises(dbt.exceptions.DbtProjectError): - dbt.config.Project.from_project_config(self.default_project_data) + dbt.config.Project.from_project_config(self.default_project_data, None) def test_unsupported_version(self): self.default_project_data['require-dbt-version'] = '>99999.0.0' # allowed, because the RuntimeConfig checks, not the Project itself - dbt.config.Project.from_project_config(self.default_project_data) + dbt.config.Project.from_project_config(self.default_project_data, None) def test__no_unused_resource_config_paths(self): self.default_project_data.update({ @@ -749,7 +764,7 @@ def test__no_unused_resource_config_paths(self): 'seeds': {}, }) project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) resource_fqns = {'models': model_fqns} @@ -762,7 +777,7 @@ def test__unused_resource_config_paths(self): 'seeds': {}, }) project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) resource_fqns = {'models': model_fqns} @@ -771,7 +786,7 @@ def test__unused_resource_config_paths(self): def test__get_unused_resource_config_paths_empty(self): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) unused = project.get_unused_resource_config_paths({'models': frozenset(( ('my_test_project', 'foo', 'bar'), @@ -781,7 +796,7 @@ def test__get_unused_resource_config_paths_empty(self): def test__warn_for_unused_resource_config_paths_empty(self): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) dbt.flags.WARN_ERROR = True try: @@ -800,7 +815,7 @@ def test_none_values(self): 'on-run-start': None, }) project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual(project.models, {}) self.assertEqual(project.on_run_start, []) @@ -813,7 +828,7 @@ def test_nested_none_values(self): 'seeds': {'vars': None, 'pre-hook': None, 'post-hook': None, 'column_types': None}, }) project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual(project.models, {'vars': {}, 'pre-hook': [], 'post-hook': []}) self.assertEqual(project.seeds, {'vars': {}, 'pre-hook': [], 'post-hook': [], 'column_types': {}}) @@ -826,7 +841,7 @@ def test_cycle(self): }) with self.assertRaises(dbt.exceptions.DbtProjectError): dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) @@ -857,7 +872,7 @@ def setUp(self): def test__get_unused_resource_config_paths(self): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) unused = project.get_unused_resource_config_paths(self.used, []) self.assertEqual(len(unused), 1) @@ -866,14 +881,14 @@ def test__get_unused_resource_config_paths(self): @mock.patch.object(dbt.config.project, 'warn_or_error') def test__warn_for_unused_resource_config_paths(self, warn_or_error): project = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) unused = project.warn_for_unused_resource_config_paths(self.used, []) warn_or_error.assert_called_once() def test__warn_for_unused_resource_config_paths_disabled(self): project = dbt.config.Project.from_project_config( - self.default_project_data + 
self.default_project_data, None ) unused = project.get_unused_resource_config_paths( self.used, @@ -891,18 +906,20 @@ def setUp(self): self.default_project_data['project-root'] = self.project_dir def test_from_project_root(self): - project = dbt.config.Project.from_project_root(self.project_dir, {}) + renderer = empty_renderer() + project = dbt.config.Project.from_project_root(self.project_dir, renderer) from_config = dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) self.assertEqual(project, from_config) self.assertEqual(project.version, "0.0.1") self.assertEqual(project.project_name, 'my_test_project') def test_with_invalid_package(self): + renderer = empty_renderer() self.write_packages({'invalid': ['not a package of any kind']}) with self.assertRaises(dbt.exceptions.DbtProjectError): - dbt.config.Project.from_project_root(self.project_dir, {}) + dbt.config.Project.from_project_root(self.project_dir, renderer) class TestRunOperationTask(BaseFileTest): @@ -940,11 +957,11 @@ def setUp(self): self.default_project_data['project-root'] = self.project_dir def test_cli_and_env_vars(self): - cli_vars = '{"cli_version": "0.1.2"}' + renderer = dbt.config.ConfigRenderer(generate_base_context({'cli_version': '0.1.2'})) with mock.patch.dict(os.environ, self.env_override): project = dbt.config.Project.from_project_root( self.project_dir, - cli_vars + renderer, ) self.assertEqual(project.version, "0.1.2") @@ -960,12 +977,13 @@ def setUp(self): def get_project(self): return dbt.config.Project.from_project_config( - self.default_project_data + self.default_project_data, None ) def get_profile(self): + renderer = empty_renderer() return dbt.config.Profile.from_raw_profiles( - self.default_profile_data, self.default_project_data['profile'], {} + self.default_profile_data, self.default_project_data['profile'], renderer ) def from_parts(self, exc=None): diff --git a/test/unit/test_context.py b/test/unit/test_context.py index 0b96adfd1a9..00b71de5e65 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -2,7 +2,7 @@ from unittest import mock from dbt.contracts.graph.parsed import ParsedModelNode, NodeConfig, DependsOn -from dbt.context import parser, runtime +from dbt.context import providers from dbt.node_types import NodeType import dbt.exceptions from .mock_adapter import adapter_factory @@ -44,36 +44,36 @@ def setUp(self): self.context = mock.MagicMock() def test_var_default_something(self): - var = runtime.Var(self.model, self.context, overrides={'foo': 'baz'}) + var = providers.RuntimeVar(self.model, self.context, overrides={'foo': 'baz'}) self.assertEqual(var('foo'), 'baz') self.assertEqual(var('foo', 'bar'), 'baz') def test_var_default_none(self): - var = runtime.Var(self.model, self.context, overrides={'foo': None}) + var = providers.RuntimeVar(self.model, self.context, overrides={'foo': None}) self.assertEqual(var('foo'), None) self.assertEqual(var('foo', 'bar'), None) def test_var_not_defined(self): - var = runtime.Var(self.model, self.context, overrides={}) + var = providers.RuntimeVar(self.model, self.context, overrides={}) self.assertEqual(var('foo', 'bar'), 'bar') with self.assertRaises(dbt.exceptions.CompilationException): var('foo') def test_parser_var_default_something(self): - var = parser.Var(self.model, self.context, overrides={'foo': 'baz'}) + var = providers.ParseVar(self.model, self.context, overrides={'foo': 'baz'}) self.assertEqual(var('foo'), 'baz') self.assertEqual(var('foo', 'bar'), 'baz') def 
test_parser_var_default_none(self): - var = parser.Var(self.model, self.context, overrides={'foo': None}) + var = providers.ParseVar(self.model, self.context, overrides={'foo': None}) self.assertEqual(var('foo'), None) self.assertEqual(var('foo', 'bar'), None) def test_parser_var_not_defined(self): # at parse-time, we should not raise if we encounter a missing var # that way disabled models don't get parse errors - var = parser.Var(self.model, self.context, overrides={}) + var = providers.ParseVar(self.model, self.context, overrides={}) self.assertEqual(var('foo', 'bar'), 'bar') self.assertEqual(var('foo'), None) @@ -84,7 +84,7 @@ def setUp(self): self.mock_config = mock.MagicMock() adapter_class = adapter_factory() self.mock_adapter = adapter_class(self.mock_config) - self.wrapper = parser.DatabaseWrapper(self.mock_adapter) + self.wrapper = providers.ParseDatabaseWrapper(self.mock_adapter) self.responder = self.mock_adapter.responder def test_unwrapped_method(self): @@ -103,7 +103,7 @@ def setUp(self): self.mock_config.quoting = {'database': True, 'schema': True, 'identifier': True} adapter_class = adapter_factory() self.mock_adapter = adapter_class(self.mock_config) - self.wrapper = runtime.DatabaseWrapper(self.mock_adapter) + self.wrapper = providers.RuntimeDatabaseWrapper(self.mock_adapter) self.responder = self.mock_adapter.responder def test_unwrapped_method(self): diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 95c0e196ed3..697f301b7bc 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -4,7 +4,6 @@ import dbt.clients.system import dbt.compilation -import dbt.context.parser import dbt.exceptions import dbt.flags import dbt.linker @@ -33,7 +32,7 @@ def tearDown(self): self.load_projects_patcher.stop() self.file_system_patcher.stop() self.get_adapter_patcher.stop() - self.get_adapter_patcher_cmn.stop() + self.get_adapter_patcher_parser.stop() self.mock_filesystem_constructor.stop() self.mock_hook_constructor.stop() self.load_patch.stop() @@ -51,12 +50,12 @@ def setUp(self): self.hook_patcher = patch.object( dbt.parser.hooks.HookParser, '__new__' ) - self.get_adapter_patcher = patch('dbt.context.parser.get_adapter') + self.get_adapter_patcher = patch('dbt.context.providers.get_adapter') self.factory = self.get_adapter_patcher.start() # also patch this one - self.get_adapter_patcher_cmn = patch('dbt.context.common.get_adapter') - self.factory_cmn = self.get_adapter_patcher_cmn.start() + self.get_adapter_patcher_parser = patch('dbt.parser.base.get_adapter') + self.factory_cmn = self.get_adapter_patcher_parser.start() def mock_write_gpickle(graph, outfile): diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index d86becc501e..21810d3ced0 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -8,14 +8,15 @@ import dbt.parser from dbt.exceptions import CompilationException from dbt.parser import ( - ModelParser, MacroParser, DataTestParser, SchemaParser, ParserUtils, - ParseResult, SnapshotParser, AnalysisParser + ModelParser, MacroParser, DataTestParser, SchemaParser, ParseResult, + SnapshotParser, AnalysisParser ) from dbt.parser.schemas import ( TestablePatchParser, SourceParser, AnalysisPatchParser, MacroPatchParser ) from dbt.parser.search import FileBlock from dbt.parser.schema_test_builders import YamlBlock +from dbt.parser.manifest import process_docs, process_sources, process_refs from dbt.node_types import NodeType from dbt.contracts.graph.manifest import ( @@ -89,15 +90,16 @@ def setUp(self): 'root': 
self.root_project_config, 'snowplow': self.snowplow_project_config } - self.patcher = mock.patch('dbt.context.parser.get_adapter') + self.patcher = mock.patch('dbt.context.providers.get_adapter') self.factory = self.patcher.start() - self.patcher_cmn = mock.patch('dbt.context.common.get_adapter') - self.factory_cmn = self.patcher_cmn.start() + + self.parser_patcher = mock.patch('dbt.parser.base.get_adapter') + self.factory_parser = self.parser_patcher.start() self.macro_manifest = Manifest.from_macros() def tearDown(self): - self.patcher_cmn.stop() + self.parser_patcher.stop() self.patcher.stop() def file_block_for(self, data: str, filename: str, searched: str): @@ -682,8 +684,9 @@ def test_basic(self): self.assertEqual(self.parser.results.files[path].nodes, ['analysis.snowplow.analysis_1']) -class ParserUtilsTest(unittest.TestCase): +class ProcessingTest(BaseParserTest): def setUp(self): + super().setUp() x_depends_on = mock.MagicMock() y_depends_on = mock.MagicMock() x_uid = 'model.project.x' @@ -751,19 +754,15 @@ def setUp(self): nodes=nodes, macros={}, docs=docs, disabled=[], files={}, generated_at=mock.MagicMock() ) - def test_resolve_docs(self): - # no error. TODO: real test - result = ParserUtils.process_docs(self.manifest, 'project') - self.assertIs(result, self.manifest) + def test_process_docs(self): + process_docs(self.manifest, self.root_project_config) self.assertEqual(self.x_node.description, 'other_project: some docs') self.assertEqual(self.y_node.description, 'some docs') - def test_resolve_sources(self): - result = ParserUtils.process_sources(self.manifest, 'project') - self.assertIs(result, self.manifest) + def test_process_sources(self): + process_sources(self.manifest, 'project') self.x_node.depends_on.nodes.append.assert_called_once_with('source.thirdproject.src.tbl') - def test_resolve_refs(self): - result = ParserUtils.process_refs(self.manifest, 'project') - self.assertIs(result, self.manifest) + def test_process_refs(self): + process_refs(self.manifest, 'project') self.y_node.depends_on.nodes.append.assert_called_once_with('model.project.x') diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 5b5c6795dce..0f6b079f754 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -4,7 +4,9 @@ import dbt.flags as flags from dbt.task.debug import DebugTask +from dbt.adapters.base.query_headers import MacroQueryStringSetter from dbt.adapters.postgres import PostgresAdapter +from dbt.config import ConfigRenderer from dbt.exceptions import ValidationException, DbtConfigError from dbt.logger import GLOBAL_LOGGER as logger # noqa from dbt.parser.results import ParseResult @@ -244,6 +246,8 @@ def setUp(self): self.psycopg2.connect.return_value = self.handle self.adapter = PostgresAdapter(self.config) + self.adapter.connections.query_header = MacroQueryStringSetter(self.config, mock.MagicMock(macros={})) + self.qh_patch = mock.patch.object(self.adapter.connections.query_header, 'add') self.mock_query_header_add = self.qh_patch.start() self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q) diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index a54d6921289..95bf593e294 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -4,8 +4,8 @@ import dbt.flags as flags -import dbt.parser.manifest from dbt.adapters.snowflake import SnowflakeAdapter +from dbt.adapters.base.query_headers import MacroQueryStringSetter from 
dbt.logger import GLOBAL_LOGGER as logger  # noqa
 from dbt.parser.results import ParseResult
 from snowflake import connector as snowflake_connector
@@ -61,6 +61,8 @@ def setUp(self):
         self.snowflake.return_value = self.handle
 
         self.adapter = SnowflakeAdapter(self.config)
+        self.adapter.connections.query_header = MacroQueryStringSetter(self.config, mock.MagicMock(macros={}))
+
         self.qh_patch = mock.patch.object(self.adapter.connections.query_header, 'add')
         self.mock_query_header_add = self.qh_patch.start()
         self.mock_query_header_add.side_effect = lambda q: '/* dbt */\n{}'.format(q)
diff --git a/test/unit/test_source_config.py b/test/unit/test_source_config.py
index 25147f7c131..5b9cac11875 100644
--- a/test/unit/test_source_config.py
+++ b/test/unit/test_source_config.py
@@ -59,7 +59,7 @@ def setUp(self):
             'root': self.root_project_config,
             'snowplow': self.snowplow_project_config
         }
-        self.patcher = mock.patch('dbt.context.parser.get_adapter')
+        self.patcher = mock.patch('dbt.context.providers.get_adapter')
         self.factory = self.patcher.start()
 
     def tearDown(self):
diff --git a/test/unit/utils.py b/test/unit/utils.py
index 1be0078313e..f392cb7cd4c 100644
--- a/test/unit/utils.py
+++ b/test/unit/utils.py
@@ -34,17 +34,25 @@ def mock_connection(name):
 
 def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'):
-    from dbt.config import Project, Profile, RuntimeConfig
+    from dbt.config import Project, Profile, RuntimeConfig, ConfigRenderer
+    from dbt.context.base import generate_base_context
     from dbt.utils import parse_cli_vars
     from copy import deepcopy
+
     if not isinstance(cli_vars, dict):
         cli_vars = parse_cli_vars(cli_vars)
 
+    renderer = ConfigRenderer(generate_base_context(cli_vars))
+
     if not isinstance(project, Project):
         project = Project.from_project_config(deepcopy(project), packages)
+
     if not isinstance(profile, Profile):
-        profile = Profile.from_raw_profile_info(deepcopy(profile),
-                                                project.profile_name,
-                                                cli_vars)
+        profile = Profile.from_raw_profile_info(
+            deepcopy(profile),
+            project.profile_name,
+            renderer,
+        )
+
     args = Obj()
     args.vars = repr(cli_vars)
     args.profile_dir = '/dev/null'
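The through-line in all of these test updates: entry points that used to take a raw cli_vars dict now take a ConfigRenderer built from a rendering context. A minimal sketch of the new wiring, using only the signatures visible in this diff:

    from dbt.config import ConfigRenderer
    from dbt.context.base import generate_base_context

    cli_vars = {'cli_value_host': 'cli-postgres-host'}
    renderer = ConfigRenderer(generate_base_context(cli_vars))

    # The same renderer object then feeds each entry point this diff touches:
    #   Profile.from_raw_profiles(raw_profiles, 'default', renderer)
    #   Project.from_project_root(project_dir, renderer)
    #   Profile.render_from_args(args, renderer, profile_name)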