diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 5d21db61b6e..3af51b1d24e 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -20,6 +20,7 @@ from dbt import deprecations from dbt.clients.agate_helper import empty_table, merge_tables +from dbt.clients.jinja import MacroGenerator from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedSeedNode @@ -996,7 +997,7 @@ def execute_macro( ) macro_context.update(context_override) - macro_function = macro.generator(macro_context) + macro_function = MacroGenerator(macro, macro_context) try: result = macro_function(**kwargs) diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index 7180fb79f71..37c2052601b 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -2,8 +2,8 @@ import linecache import os import tempfile +import threading from contextlib import contextmanager -from dataclasses import dataclass from typing import ( List, Union, Set, Optional, Dict, Any, Iterator, Type, NoReturn ) @@ -15,12 +15,16 @@ import jinja2.parser import jinja2.sandbox -import dbt.exceptions -import dbt.utils +from dbt.utils import ( + get_dbt_macro_name, get_docs_macro_name, get_materialization_macro_name +) from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag +from dbt.exceptions import ( + InternalException, raise_compiler_error, CompilationException, + invalid_materialization_argument, MacroReturn +) from dbt.flags import MACRO_DEBUGGING - from dbt.logger import GLOBAL_LOGGER as logger # noqa @@ -64,7 +68,7 @@ def parse_macro(self): # modified to fuzz macros defined in the same file. this way # dbt can understand the stack of macros being called. # - @cmcarthur - node.name = dbt.utils.get_dbt_macro_name( + node.name = get_dbt_macro_name( self.parse_assign_target(name_only=True).name) self.parse_signature(node) @@ -138,7 +142,7 @@ def get_macro(self): # make the module. previously we set both vars and local, but that's # redundant: They both end up in the same place module = template.make_module(vars=self.context, shared=False) - macro = module.__dict__[dbt.utils.get_dbt_macro_name(name)] + macro = module.__dict__[get_dbt_macro_name(name)] module.__dict__.update(self.context) return macro @@ -147,11 +151,11 @@ def exception_handler(self) -> Iterator[None]: try: yield except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e: - dbt.exceptions.raise_compiler_error(str(e)) + raise_compiler_error(str(e)) def call_macro(self, *args, **kwargs): if self.context is None: - raise dbt.exceptions.InternalException( + raise InternalException( 'Context is still None in call_macro!' ) assert self.context is not None @@ -161,46 +165,76 @@ def call_macro(self, *args, **kwargs): with self.exception_handler(): try: return macro(*args, **kwargs) - except dbt.exceptions.MacroReturn as e: + except MacroReturn as e: return e.value -@dataclass -class MacroProxy: - generator: 'MacroGenerator' +class MacroStack(threading.local): + def __init__(self): + super().__init__() + self.call_stack = [] @property - def node(self): - return self.generator.node + def depth(self) -> int: + return len(self.call_stack) - def __call__(self, *args, **kwargs): - return self.generator.call_macro(*args, **kwargs) + def push(self, name): + self.call_stack.append(name) + + def pop(self, name): + got = self.call_stack.pop() + if got != name: + raise InternalException(f'popped {got}, expected {name}') class MacroGenerator(BaseMacroGenerator): - def __init__(self, node, context: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, + macro, + context: Optional[Dict[str, Any]] = None, + node: Optional[Any] = None, + stack: Optional[MacroStack] = None + ) -> None: super().__init__(context) + self.macro = macro self.node = node + self.stack = stack def get_template(self): - return template_cache.get_node_template(self.node) + return template_cache.get_node_template(self.macro) def get_name(self) -> str: - return self.node.name + return self.macro.name @contextmanager def exception_handler(self) -> Iterator[None]: try: yield except (TypeError, jinja2.exceptions.TemplateRuntimeError) as e: - dbt.exceptions.raise_compiler_error(str(e), self.node) - except dbt.exceptions.CompilationException as e: - e.stack.append(self.node) + raise_compiler_error(str(e), self.macro) + except CompilationException as e: + e.stack.append(self.macro) raise e - def __call__(self, context: Dict[str, Any]) -> MacroProxy: - self.context = context - return MacroProxy(self) + @contextmanager + def track_call(self): + if self.stack is None or self.node is None: + yield + else: + unique_id = self.macro.unique_id + depth = self.stack.depth + # only mark depth=0 as a dependency + if depth == 0: + self.node.depends_on.add_macro(unique_id) + self.stack.push(unique_id) + try: + yield + finally: + self.stack.pop(unique_id) + + def __call__(self, *args, **kwargs): + with self.track_call(): + return self.call_macro(*args, **kwargs) class QueryStringGenerator(BaseMacroGenerator): @@ -250,11 +284,13 @@ def parse(self, parser): adapter_name = value.value else: - dbt.exceptions.invalid_materialization_argument( - materialization_name, target.name) + invalid_materialization_argument( + materialization_name, target.name + ) - node.name = dbt.utils.get_materialization_macro_name( - materialization_name, adapter_name) + node.name = get_materialization_macro_name( + materialization_name, adapter_name + ) node.body = parser.parse_statements(('name:endmaterialization',), drop_needle=True) @@ -271,7 +307,7 @@ def parse(self, parser): node.args = [] node.defaults = [] - node.name = dbt.utils.get_docs_macro_name(docs_name) + node.name = get_docs_macro_name(docs_name) node.body = parser.parse_statements(('name:enddocs',), drop_needle=True) return node @@ -308,7 +344,7 @@ def __deepcopy__(self, memo): self.node.name, path)) # match jinja's message - dbt.exceptions.raise_compiler_error( + raise_compiler_error( "{!r} is undefined".format(self.name), node=self.node ) @@ -356,9 +392,9 @@ def catch_jinja(node=None) -> Iterator[None]: yield except jinja2.exceptions.TemplateSyntaxError as e: e.translated = False - raise dbt.exceptions.CompilationException(str(e), node) from e + raise CompilationException(str(e), node) from e except jinja2.exceptions.UndefinedError as e: - raise dbt.exceptions.CompilationException(str(e), node) from e + raise CompilationException(str(e), node) from e def parse(string): diff --git a/core/dbt/context/configured.py b/core/dbt/context/configured.py index 9acdd1c1d24..b979ae849e8 100644 --- a/core/dbt/context/configured.py +++ b/core/dbt/context/configured.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, Iterable, Union +from typing import Any, Dict, Iterable, Union, Optional -from dbt.clients.jinja import MacroProxy +from dbt.clients.jinja import MacroGenerator, MacroStack from dbt.contracts.connection import AdapterRequiredConfig from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedMacro @@ -25,20 +25,32 @@ def project_name(self) -> str: return self.config.project_name -Namespace = Union[Dict[str, MacroProxy], MacroProxy] +FlatNamespace = Dict[str, MacroGenerator] +NamespaceMember = Union[FlatNamespace, MacroGenerator] +FullNamespace = Dict[str, NamespaceMember] -class _MacroNamespace: - def __init__(self, root_package, search_package): +class MacroNamespace: + def __init__( + self, + root_package: str, + search_package: str, + thread_ctx: MacroStack, + node: Optional[Any] = None, + ) -> None: self.root_package = root_package self.search_package = search_package - self.globals: Dict[str, MacroProxy] = {} - self.locals: Dict[str, MacroProxy] = {} - self.packages: Dict[str, Dict[str, MacroProxy]] = {} + self.globals: FlatNamespace = {} + self.locals: FlatNamespace = {} + self.packages: Dict[str, FlatNamespace] = {} + self.thread_ctx = thread_ctx + self.node = node def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]): macro_name: str = macro.name - macro_func: MacroProxy = macro.generator(ctx) + macro_func: MacroGenerator = MacroGenerator( + macro, ctx, self.node, self.thread_ctx + ) # put plugin macros into the root namespace if macro.package_name in PACKAGES: @@ -47,11 +59,11 @@ def add_macro(self, macro: ParsedMacro, ctx: Dict[str, Any]): namespace = macro.package_name if namespace not in self.packages: - value: Dict[str, MacroProxy] = {} + value: Dict[str, MacroGenerator] = {} self.packages[namespace] = value if macro_name in self.packages[namespace]: - raise_duplicate_macro_name(macro_func.node, macro, namespace) + raise_duplicate_macro_name(macro_func.macro, macro, namespace) self.packages[namespace][macro_name] = macro_func if namespace == self.search_package: @@ -63,8 +75,8 @@ def add_macros(self, macros: Iterable[ParsedMacro], ctx: Dict[str, Any]): for macro in macros: self.add_macro(macro, ctx) - def get_macro_dict(self) -> Dict[str, Any]: - root_namespace: Dict[str, Namespace] = {} + def get_macro_dict(self) -> FullNamespace: + root_namespace: FullNamespace = {} root_namespace.update(self.packages) root_namespace.update(self.globals) @@ -89,9 +101,18 @@ def __init__( super().__init__(config) self.manifest = manifest self.search_package = search_package + self.macro_stack = MacroStack() + + def _get_namespace(self): + return MacroNamespace( + self.config.project_name, + self.search_package, + self.macro_stack, + None, + ) def get_macros(self) -> Dict[str, Any]: - nsp = _MacroNamespace(self.config.project_name, self.search_package) + nsp = self._get_namespace() nsp.add_macros(self.manifest.macros.values(), self._ctx) return nsp.get_macro_dict() diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 900ae2741a1..6fafe2ed247 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -14,7 +14,7 @@ from dbt.context.base import ( contextmember, contextproperty, Var ) -from dbt.context.configured import ManifestContext +from dbt.context.configured import ManifestContext, MacroNamespace from dbt.contracts.graph.manifest import Manifest, Disabled from dbt.contracts.graph.compiled import ( NonSourceNode, CompiledSeedNode @@ -451,6 +451,14 @@ def __init__(self, model, config, manifest, provider, source_config): self.adapter = get_adapter(self.config) self.db_wrapper = self.provider.DatabaseWrapper(self.adapter) + def _get_namespace(self): + return MacroNamespace( + self.config.project_name, + self.search_package, + self.macro_stack, + self.model, + ) + @contextproperty def _sql_results(self) -> Dict[str, AttrDict]: return self.sql_results diff --git a/core/dbt/contracts/graph/parsed.py b/core/dbt/contracts/graph/parsed.py index 9695ebaec25..d1411f368ce 100644 --- a/core/dbt/contracts/graph/parsed.py +++ b/core/dbt/contracts/graph/parsed.py @@ -17,7 +17,6 @@ StrEnum, register_pattern ) -from dbt.clients.jinja import MacroGenerator from dbt.clients.system import write_file import dbt.flags from dbt.contracts.graph.unparsed import ( @@ -147,10 +146,23 @@ class HasUniqueID(JsonSchemaMixin, Replaceable): @dataclass -class DependsOn(JsonSchemaMixin, Replaceable): - nodes: List[str] = field(default_factory=list) +class MacroDependsOn(JsonSchemaMixin, Replaceable): macros: List[str] = field(default_factory=list) + # 'in' on lists is O(n) so this is O(n^2) for # of macros + def add_macro(self, value: str): + if value not in self.macros: + self.macros.append(value) + + +@dataclass +class DependsOn(MacroDependsOn): + nodes: List[str] = field(default_factory=list) + + def add_node(self, value: str): + if value not in self.nodes: + self.nodes.append(value) + @dataclass class HasRelationMetadata(JsonSchemaMixin, Replaceable): @@ -484,11 +496,6 @@ class ParsedMacroPatch(ParsedPatch): pass -@dataclass -class MacroDependsOn(JsonSchemaMixin, Replaceable): - macros: List[str] = field(default_factory=list) - - @dataclass class ParsedMacro(UnparsedMacro, HasUniqueID): name: str @@ -505,13 +512,6 @@ class ParsedMacro(UnparsedMacro, HasUniqueID): def local_vars(self): return {} - @property - def generator(self) -> MacroGenerator: - """ - Returns a function that can be called to render the macro results. - """ - return MacroGenerator(self) - def patch(self, patch: ParsedMacroPatch): self.patch_path: Optional[str] = patch.original_file_path self.description = patch.description diff --git a/core/dbt/node_runners.py b/core/dbt/node_runners.py index 3356e78993c..f2fbfbca711 100644 --- a/core/dbt/node_runners.py +++ b/core/dbt/node_runners.py @@ -6,19 +6,19 @@ from dbt import deprecations from dbt.adapters.base import BaseRelation -from dbt.logger import GLOBAL_LOGGER as logger +from dbt.clients.jinja import MacroGenerator +from dbt.compilation import compile_node +from dbt.context.providers import generate_runtime_model +from dbt.contracts.graph.manifest import Manifest +from dbt.contracts.results import ( + RunModelResult, collect_timing_info, SourceFreshnessResult, PartialResult, +) from dbt.exceptions import ( NotImplementedException, CompilationException, RuntimeException, InternalException, missing_materialization ) +from dbt.logger import GLOBAL_LOGGER as logger from dbt.node_types import NodeType -from dbt.contracts.graph.manifest import Manifest -from dbt.contracts.results import ( - RunModelResult, collect_timing_info, SourceFreshnessResult, PartialResult, -) -from dbt.compilation import compile_node - -from dbt.context.providers import generate_runtime_model import dbt.exceptions import dbt.tracking import dbt.ui.printer @@ -447,7 +447,7 @@ def execute(self, model, manifest): hook_ctx = self.adapter.pre_model_hook(context_config) try: - result = materialization_macro.generator(context)() + result = MacroGenerator(materialization_macro, context)() finally: self.adapter.post_model_hook(context_config, hook_ctx) diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py index 1f302682469..c90eef150c6 100644 --- a/core/dbt/parser/base.py +++ b/core/dbt/parser/base.py @@ -7,6 +7,8 @@ from hologram import ValidationError +from dbt.clients.jinja import MacroGenerator +from dbt.clients.system import load_file_contents from dbt.context.providers import generate_parser_model, generate_parser_macro import dbt.flags from dbt import deprecations @@ -26,7 +28,6 @@ from dbt.source_config import SourceConfig from dbt.parser.results import ParseResult, ManifestNodes from dbt.parser.search import FileBlock -from dbt.clients.system import load_file_contents # internally, the parser may store a less-restrictive type that will be # transformed into the final type. But it will have to be derived from @@ -132,7 +133,7 @@ def _build_generate_macro_function(self, macro: ParsedMacro) -> Callable: root_context = generate_parser_macro( macro, self.root_project, self.macro_manifest, None ) - return macro.generator(root_context) + return MacroGenerator(macro, root_context) def get_schema_func(self) -> RelationUpdate: """The get_schema function is set by a few different things: diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index 5a8652ad425..6a95cdacc18 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -1055,7 +1055,10 @@ def expected_seeded_manifest(self, model_database=None): 'severity': 'ERROR', }, 'sources': [], - 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, + 'depends_on': { + 'macros': ['macro.dbt.test_not_null'], + 'nodes': ['model.test.model'], + }, 'description': '', 'fqn': ['test', 'schema_test', 'not_null_model_id'], 'name': 'not_null_model_id', @@ -1103,7 +1106,10 @@ def expected_seeded_manifest(self, model_database=None): 'severity': 'ERROR', }, 'sources': [], - 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, + 'depends_on': { + 'macros': ['macro.test.test_nothing'], + 'nodes': ['model.test.model'], + }, 'description': '', 'fqn': ['test', 'schema_test', 'test_nothing_model_'], 'name': 'test_nothing_model_', @@ -1151,7 +1157,10 @@ def expected_seeded_manifest(self, model_database=None): 'severity': 'ERROR', }, 'sources': [], - 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, + 'depends_on': { + 'macros': ['macro.dbt.test_unique'], + 'nodes': ['model.test.model'], + }, 'description': '', 'fqn': ['test', 'schema_test', 'unique_model_id'], 'name': 'unique_model_id', @@ -2916,7 +2925,10 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'severity': 'ERROR', }, 'sources': [], - 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, + 'depends_on': { + 'macros': ['macro.dbt.test_not_null'], + 'nodes': ['model.test.model'], + }, 'description': '', 'docrefs': [], 'extra_ctes': [], @@ -2974,7 +2986,10 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'severity': 'ERROR', }, 'database': self.default_database, - 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, + 'depends_on': { + 'macros': ['macro.test.test_nothing'], + 'nodes': ['model.test.model'], + }, 'description': '', 'docrefs': [], 'extra_ctes': [], @@ -3032,7 +3047,10 @@ def expected_run_results(self, quote_schema=True, quote_model=False, 'severity': 'ERROR', }, 'database': self.default_database, - 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, + 'depends_on': { + 'macros': ['macro.dbt.test_unique'], + 'nodes': ['model.test.model'], + }, 'description': '', 'docrefs': [], 'extra_ctes': [], diff --git a/test/integration/047_dbt_ls_test/test_ls.py b/test/integration/047_dbt_ls_test/test_ls.py index cece73b1f3c..da0a4bbd13b 100644 --- a/test/integration/047_dbt_ls_test/test_ls.py +++ b/test/integration/047_dbt_ls_test/test_ls.py @@ -268,7 +268,7 @@ def expect_test_output(self): { 'name': 'not_null_outer_id', 'package_name': 'test', - 'depends_on': {'nodes': ['model.test.outer'], 'macros': []}, + 'depends_on': {'nodes': ['model.test.outer'], 'macros': ['macro.dbt.test_not_null']}, 'tags': ['schema'], 'config': { 'enabled': True, @@ -309,7 +309,7 @@ def expect_test_output(self): { 'name': 'unique_outer_id', 'package_name': 'test', - 'depends_on': {'nodes': ['model.test.outer'], 'macros': []}, + 'depends_on': {'nodes': ['model.test.outer'], 'macros': ['macro.dbt.test_unique']}, 'tags': ['schema'], 'config': { 'enabled': True, diff --git a/test/unit/test_context.py b/test/unit/test_context.py index 8375c4226b8..3c9f71ede6b 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -7,6 +7,7 @@ # make sure 'postgres' is in PACKAGES from dbt.adapters import postgres # noqa +from dbt.clients.jinja import MacroStack from dbt.contracts.graph.parsed import ( ParsedModelNode, NodeConfig, DependsOn, ParsedMacro ) @@ -397,7 +398,7 @@ def test_docs_runtime_context(config): def test_macro_namespace(config, manifest): - mn = configured._MacroNamespace('root', 'search') + mn = configured.MacroNamespace('root', 'search', MacroStack()) mn.add_macros(manifest.macros.values(), {}) # same pkg, same name