Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: node depends_on.macros (#2082) #2103

Merged
merged 1 commit into from
Feb 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from dbt import deprecations
from dbt.clients.agate_helper import empty_table, merge_tables
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedSeedNode
Expand Down Expand Up @@ -996,7 +997,7 @@ def execute_macro(
)
macro_context.update(context_override)

macro_function = macro.generator(macro_context)
macro_function = MacroGenerator(macro, macro_context)

try:
result = macro_function(**kwargs)
Expand Down
102 changes: 69 additions & 33 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import linecache
import os
import tempfile
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
List, Union, Set, Optional, Dict, Any, Iterator, Type, NoReturn
)
Expand All @@ -15,12 +15,16 @@
import jinja2.parser
import jinja2.sandbox

import dbt.exceptions
import dbt.utils
from dbt.utils import (
get_dbt_macro_name, get_docs_macro_name, get_materialization_macro_name
)

from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.exceptions import (
InternalException, raise_compiler_error, CompilationException,
invalid_materialization_argument, MacroReturn
)
from dbt.flags import MACRO_DEBUGGING

from dbt.logger import GLOBAL_LOGGER as logger # noqa


Expand Down Expand Up @@ -64,7 +68,7 @@ def parse_macro(self):
# modified to fuzz macros defined in the same file. this way
# dbt can understand the stack of macros being called.
# - @cmcarthur
node.name = dbt.utils.get_dbt_macro_name(
node.name = get_dbt_macro_name(
self.parse_assign_target(name_only=True).name)

self.parse_signature(node)
Expand Down Expand Up @@ -138,7 +142,7 @@ def get_macro(self):
# make the module. previously we set both vars and local, but that's
# redundant: They both end up in the same place
module = template.make_module(vars=self.context, shared=False)
macro = module.__dict__[dbt.utils.get_dbt_macro_name(name)]
macro = module.__dict__[get_dbt_macro_name(name)]
module.__dict__.update(self.context)
return macro

Expand All @@ -147,11 +151,11 @@ def exception_handler(self) -> Iterator[None]:
try:
yield
except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e:
dbt.exceptions.raise_compiler_error(str(e))
raise_compiler_error(str(e))

def call_macro(self, *args, **kwargs):
if self.context is None:
raise dbt.exceptions.InternalException(
raise InternalException(
'Context is still None in call_macro!'
)
assert self.context is not None
Expand All @@ -161,46 +165,76 @@ def call_macro(self, *args, **kwargs):
with self.exception_handler():
try:
return macro(*args, **kwargs)
except dbt.exceptions.MacroReturn as e:
except MacroReturn as e:
return e.value


@dataclass
class MacroProxy:
generator: 'MacroGenerator'
class MacroStack(threading.local):
def __init__(self):
super().__init__()
self.call_stack = []

@property
def node(self):
return self.generator.node
def depth(self) -> int:
return len(self.call_stack)

def __call__(self, *args, **kwargs):
return self.generator.call_macro(*args, **kwargs)
def push(self, name):
self.call_stack.append(name)

def pop(self, name):
got = self.call_stack.pop()
if got != name:
raise InternalException(f'popped {got}, expected {name}')


class MacroGenerator(BaseMacroGenerator):
def __init__(self, node, context: Optional[Dict[str, Any]] = None) -> None:
def __init__(
self,
macro,
context: Optional[Dict[str, Any]] = None,
node: Optional[Any] = None,
stack: Optional[MacroStack] = None
) -> None:
super().__init__(context)
self.macro = macro
self.node = node
self.stack = stack

def get_template(self):
return template_cache.get_node_template(self.node)
return template_cache.get_node_template(self.macro)

def get_name(self) -> str:
return self.node.name
return self.macro.name

@contextmanager
def exception_handler(self) -> Iterator[None]:
try:
yield
except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e:
dbt.exceptions.raise_compiler_error(str(e), self.node)
except dbt.exceptions.CompilationException as e:
e.stack.append(self.node)
raise_compiler_error(str(e), self.macro)
except CompilationException as e:
e.stack.append(self.macro)
raise e

def __call__(self, context: Dict[str, Any]) -> MacroProxy:
self.context = context
return MacroProxy(self)
@contextmanager
def track_call(self):
if self.stack is None or self.node is None:
yield
else:
unique_id = self.macro.unique_id
depth = self.stack.depth
# only mark depth=0 as a dependency
if depth == 0:
self.node.depends_on.add_macro(unique_id)
self.stack.push(unique_id)
try:
yield
finally:
self.stack.pop(unique_id)

def __call__(self, *args, **kwargs):
with self.track_call():
return self.call_macro(*args, **kwargs)


class QueryStringGenerator(BaseMacroGenerator):
Expand Down Expand Up @@ -250,11 +284,13 @@ def parse(self, parser):
adapter_name = value.value

else:
dbt.exceptions.invalid_materialization_argument(
materialization_name, target.name)
invalid_materialization_argument(
materialization_name, target.name
)

node.name = dbt.utils.get_materialization_macro_name(
materialization_name, adapter_name)
node.name = get_materialization_macro_name(
materialization_name, adapter_name
)

node.body = parser.parse_statements(('name:endmaterialization',),
drop_needle=True)
Expand All @@ -271,7 +307,7 @@ def parse(self, parser):

node.args = []
node.defaults = []
node.name = dbt.utils.get_docs_macro_name(docs_name)
node.name = get_docs_macro_name(docs_name)
node.body = parser.parse_statements(('name:enddocs',),
drop_needle=True)
return node
Expand Down Expand Up @@ -308,7 +344,7 @@ def __deepcopy__(self, memo):
self.node.name, path))

# match jinja's message
dbt.exceptions.raise_compiler_error(
raise_compiler_error(
"{!r} is undefined".format(self.name),
node=self.node
)
Expand Down Expand Up @@ -356,9 +392,9 @@ def catch_jinja(node=None) -> Iterator[None]:
yield
except jinja2.exceptions.TemplateSyntaxError as e:
e.translated = False
raise dbt.exceptions.CompilationException(str(e), node) from e
raise CompilationException(str(e), node) from e
except jinja2.exceptions.UndefinedError as e:
raise dbt.exceptions.CompilationException(str(e), node) from e
raise CompilationException(str(e), node) from e


def parse(string):
Expand Down
49 changes: 35 additions & 14 deletions core/dbt/context/configured.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, Iterable, Union
from typing import Any, Dict, Iterable, Union, Optional

from dbt.clients.jinja import MacroProxy
from dbt.clients.jinja import MacroGenerator, MacroStack
from dbt.contracts.connection import AdapterRequiredConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.graph.parsed import ParsedMacro
Expand All @@ -25,20 +25,32 @@ def project_name(self) -> str:
return self.config.project_name


Namespace = Union[Dict[str, MacroProxy], MacroProxy]
FlatNamespace = Dict[str, MacroGenerator]
NamespaceMember = Union[FlatNamespace, MacroGenerator]
FullNamespace = Dict[str, NamespaceMember]


class _MacroNamespace:
def __init__(self, root_package, search_package):
class MacroNamespace:
def __init__(
self,
root_package: str,
search_package: str,
thread_ctx: MacroStack,
node: Optional[Any] = None,
) -> None:
self.root_package = root_package
self.search_package = search_package
self.globals: Dict[str, MacroProxy] = {}
self.locals: Dict[str, MacroProxy] = {}
self.packages: Dict[str, Dict[str, MacroProxy]] = {}
self.globals: FlatNamespace = {}
self.locals: FlatNamespace = {}
self.packages: Dict[str, FlatNamespace] = {}
self.thread_ctx = thread_ctx
self.node = node

def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
macro_name: str = macro.name
macro_func: MacroProxy = macro.generator(ctx)
macro_func: MacroGenerator = MacroGenerator(
macro, ctx, self.node, self.thread_ctx
)

# put plugin macros into the root namespace
if macro.package_name in PACKAGES:
Expand All @@ -47,11 +59,11 @@ def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]):
namespace = macro.package_name

if namespace not in self.packages:
value: Dict[str, MacroProxy] = {}
value: Dict[str, MacroGenerator] = {}
self.packages[namespace] = value

if macro_name in self.packages[namespace]:
raise_duplicate_macro_name(macro_func.node, macro, namespace)
raise_duplicate_macro_name(macro_func.macro, macro, namespace)
self.packages[namespace][macro_name] = macro_func

if namespace == self.search_package:
Expand All @@ -63,8 +75,8 @@ def add_macros(self, macros: Iterable[ParsedMacro], ctx: Dict[str, Any]):
for macro in macros:
self.add_macro(macro, ctx)

def get_macro_dict(self) -> Dict[str, Any]:
root_namespace: Dict[str, Namespace] = {}
def get_macro_dict(self) -> FullNamespace:
root_namespace: FullNamespace = {}

root_namespace.update(self.packages)
root_namespace.update(self.globals)
Expand All @@ -89,9 +101,18 @@ def __init__(
super().__init__(config)
self.manifest = manifest
self.search_package = search_package
self.macro_stack = MacroStack()

def _get_namespace(self):
return MacroNamespace(
self.config.project_name,
self.search_package,
self.macro_stack,
None,
)

def get_macros(self) -> Dict[str, Any]:
nsp = _MacroNamespace(self.config.project_name, self.search_package)
nsp = self._get_namespace()
nsp.add_macros(self.manifest.macros.values(), self._ctx)
return nsp.get_macro_dict()

Expand Down
10 changes: 9 additions & 1 deletion core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from dbt.context.base import (
contextmember, contextproperty, Var
)
from dbt.context.configured import ManifestContext
from dbt.context.configured import ManifestContext, MacroNamespace
from dbt.contracts.graph.manifest import Manifest, Disabled
from dbt.contracts.graph.compiled import (
NonSourceNode, CompiledSeedNode
Expand Down Expand Up @@ -451,6 +451,14 @@ def __init__(self, model, config, manifest, provider, source_config):
self.adapter = get_adapter(self.config)
self.db_wrapper = self.provider.DatabaseWrapper(self.adapter)

def _get_namespace(self):
return MacroNamespace(
self.config.project_name,
self.search_package,
self.macro_stack,
self.model,
)

@contextproperty
def _sql_results(self) -> Dict[str, AttrDict]:
return self.sql_results
Expand Down
30 changes: 15 additions & 15 deletions core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
StrEnum, register_pattern
)

from dbt.clients.jinja import MacroGenerator
from dbt.clients.system import write_file
import dbt.flags
from dbt.contracts.graph.unparsed import (
Expand Down Expand Up @@ -147,10 +146,23 @@ class HasUniqueID(JsonSchemaMixin, Replaceable):


@dataclass
class DependsOn(JsonSchemaMixin, Replaceable):
nodes: List[str] = field(default_factory=list)
class MacroDependsOn(JsonSchemaMixin, Replaceable):
macros: List[str] = field(default_factory=list)

# 'in' on lists is O(n) so this is O(n^2) for # of macros
def add_macro(self, value: str):
if value not in self.macros:
self.macros.append(value)


@dataclass
class DependsOn(MacroDependsOn):
nodes: List[str] = field(default_factory=list)

def add_node(self, value: str):
if value not in self.nodes:
self.nodes.append(value)


@dataclass
class HasRelationMetadata(JsonSchemaMixin, Replaceable):
Expand Down Expand Up @@ -484,11 +496,6 @@ class ParsedMacroPatch(ParsedPatch):
pass


@dataclass
class MacroDependsOn(JsonSchemaMixin, Replaceable):
macros: List[str] = field(default_factory=list)


@dataclass
class ParsedMacro(UnparsedMacro, HasUniqueID):
name: str
Expand All @@ -505,13 +512,6 @@ class ParsedMacro(UnparsedMacro, HasUniqueID):
def local_vars(self):
return {}

@property
def generator(self) -> MacroGenerator:
"""
Returns a function that can be called to render the macro results.
"""
return MacroGenerator(self)

def patch(self, patch: ParsedMacroPatch):
self.patch_path: Optional[str] = patch.original_file_path
self.description = patch.description
Expand Down
Loading