From a17eac6f1ade48e10f490bd5de097c1e15608020 Mon Sep 17 00:00:00 2001 From: Gerda Shank Date: Mon, 12 Apr 2021 19:02:00 -0400 Subject: [PATCH] Add necessary macros to schema test context namespace --- core/dbt/clients/jinja.py | 36 +++++++++++++ core/dbt/context/macro_resolver.py | 54 ++++++++++++++----- core/dbt/context/macros.py | 34 +++++++----- core/dbt/context/providers.py | 17 ++++-- core/dbt/parser/manifest.py | 29 +++++++++- .../custom_schema_tests.sql | 19 +++++++ .../test-context-models/model_a.sql | 1 + .../test-context-models/model_b.sql | 1 + .../test-context-models/schema.yml | 8 +++ .../test_schema_v2_tests.py | 36 +++++++++++++ test/unit/test_macro_calls.py | 45 ++++++++++++++++ 11 files changed, 248 insertions(+), 32 deletions(-) create mode 100644 test/integration/008_schema_tests_test/test-context-macros/custom_schema_tests.sql create mode 100644 test/integration/008_schema_tests_test/test-context-models/model_a.sql create mode 100644 test/integration/008_schema_tests_test/test-context-models/model_b.sql create mode 100644 test/integration/008_schema_tests_test/test-context-models/schema.yml create mode 100644 test/unit/test_macro_calls.py diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index f6f7993436a..98ca4c51f1a 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -647,3 +647,39 @@ def _convert_function( kwargs = deep_map(_convert_function, node.test_metadata.kwargs) context[SCHEMA_TEST_KWARGS_NAME] = kwargs + + +def statically_extract_macro_calls(string, ctx): + # set 'capture_macros' to capture undefined + env = get_environment(None, capture_macros=True) + parsed = env.parse(string) + + standard_calls = { + 'source': [], + 'ref': [], + 'config': [], + } + + possible_macro_calls = [] + for func_call in parsed.find_all(jinja2.nodes.Call): + if hasattr(func_call, 'node') and hasattr(func_call.node, 'name'): + func_name = func_call.node.name + else: + # This is a kludge to capture an 
adapter.dispatch('') call. + # Call(node=Getattr( + # node=Name(name='adapter', ctx='load'), attr='dispatch', ctx='load'), + # args=[Const(value='get_snapshot_unique_id')], kwargs=[], + # dyn_args=None, dyn_kwargs=None) + if (hasattr(func_call, 'node') and hasattr(func_call.node, 'attr') and + func_call.node.attr == 'dispatch'): + func_name = func_call.args[0].value + else: + continue + if func_name in standard_calls: + continue + elif ctx.get(func_name): + continue + else: + possible_macro_calls.append(func_name) + + return possible_macro_calls diff --git a/core/dbt/context/macro_resolver.py b/core/dbt/context/macro_resolver.py index aae83185337..ed0388e050b 100644 --- a/core/dbt/context/macro_resolver.py +++ b/core/dbt/context/macro_resolver.py @@ -14,8 +14,12 @@ # so that higher precedence macros are found first. # This functionality is also provided by the MacroNamespace, # but the intention is to eventually replace that class. -# This enables us to get the macor unique_id without +# This enables us to get the macro unique_id without # processing every macro in the project. +# Note: the root project macros override everything in the +# dbt internal projects. External projects (dependencies) will +# use their own macros first, then pull from the root project +# followed by dbt internal projects. 
class MacroResolver: def __init__( self, @@ -48,18 +52,29 @@ def _build_internal_packages_namespace(self): self.internal_packages_namespace.update( self.internal_packages[pkg]) + # search order: + # local_namespace (package of particular node), not including + # the internal packages or the root package + # This means that within an extra package, it uses its own macros + # root package namespace + # non-internal packages (that aren't local or root) + # dbt internal packages def _build_macros_by_name(self): macros_by_name = {} - # search root package macros - for macro in self.root_package_macros.values(): + + # all internal packages (already in the right order) + for macro in self.internal_packages_namespace.values(): macros_by_name[macro.name] = macro - # search miscellaneous non-internal packages + + # non-internal packages for fnamespace in self.packages.values(): for macro in fnamespace.values(): macros_by_name[macro.name] = macro - # search all internal packages - for macro in self.internal_packages_namespace.values(): + + # root package macros + for macro in self.root_package_macros.values(): macros_by_name[macro.name] = macro + self.macros_by_name = macros_by_name def _add_macro_to( @@ -97,18 +112,26 @@ def add_macros(self): for macro in self.macros.values(): self.add_macro(macro) - def get_macro_id(self, local_package, macro_name): + def get_macro(self, local_package, macro_name): local_package_macros = {} if (local_package not in self.internal_package_names and local_package in self.packages): local_package_macros = self.packages[local_package] # First: search the local packages for this macro if macro_name in local_package_macros: - return local_package_macros[macro_name].unique_id + return local_package_macros[macro_name] + # Now look up in the standard search order if macro_name in self.macros_by_name: - return self.macros_by_name[macro_name].unique_id + return self.macros_by_name[macro_name] return None + def get_macro_id(self, local_package, 
macro_name): + macro = self.get_macro(local_package, macro_name) + if macro is None: + return None + else: + return macro.unique_id + # Currently this is just used by test processing in the schema # parser (in connection with the MacroResolver). Future work @@ -127,10 +150,11 @@ def __init__( local_namespace = {} if depends_on_macros: for macro_unique_id in depends_on_macros: - macro = self.manifest.macros[macro_unique_id] - local_namespace[macro.name] = MacroGenerator( - macro, self.ctx, self.node, self.thread_ctx, - ) + if macro_unique_id in self.macro_resolver.macros: + macro = self.macro_resolver.macros[macro_unique_id] + local_namespace[macro.name] = MacroGenerator( + macro, self.ctx, self.node, self.thread_ctx, + ) self.local_namespace = local_namespace def get_from_package( @@ -141,12 +165,14 @@ def get_from_package( macro = self.macro_resolver.macros_by_name.get(name) elif package_name == GLOBAL_PROJECT_NAME: macro = self.macro_resolver.internal_packages_namespace.get(name) - elif package_name in self.resolver.packages: + elif package_name in self.macro_resolver.packages: macro = self.macro_resolver.packages[package_name].get(name) else: raise_compiler_error( f"Could not find package '{package_name}'" ) + if not macro: + return None macro_func = MacroGenerator( macro, self.ctx, self.node, self.thread_ctx ) diff --git a/core/dbt/context/macros.py b/core/dbt/context/macros.py index 6332fb967ca..7d7caa1c14f 100644 --- a/core/dbt/context/macros.py +++ b/core/dbt/context/macros.py @@ -19,13 +19,17 @@ # and provide the ability to flatten them into the ManifestContexts # that are created for jinja, so that macro calls can be resolved. # Creates special iterators and _keys methods to flatten the lists. +# When this class is created it has a static 'local_namespace' which +# depends on the package of the node, so it only works for one +# particular local package at a time for "flattening" into a context. +# 'get_by_package' should work for any macro. 
class MacroNamespace(Mapping): def __init__( self, - global_namespace: FlatNamespace, - local_namespace: FlatNamespace, - global_project_namespace: FlatNamespace, - packages: Dict[str, FlatNamespace], + global_namespace: FlatNamespace, # root package macros + local_namespace: FlatNamespace, # packages for *this* node + global_project_namespace: FlatNamespace, # internal packages + packages: Dict[str, FlatNamespace], # non-internal packages ): self.global_namespace: FlatNamespace = global_namespace self.local_namespace: FlatNamespace = local_namespace @@ -33,13 +37,13 @@ def __init__( self.global_project_namespace: FlatNamespace = global_project_namespace def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]: - yield self.local_namespace - yield self.global_namespace - yield self.packages + yield self.local_namespace # local package + yield self.global_namespace # root package + yield self.packages # non-internal packages yield { - GLOBAL_PROJECT_NAME: self.global_project_namespace, + GLOBAL_PROJECT_NAME: self.global_project_namespace, # dbt } - yield self.global_project_namespace + yield self.global_project_namespace # other internal project besides dbt # provides special keys method for MacroNamespace iterator # returns keys from local_namespace, global_namespace, packages, @@ -98,7 +102,9 @@ def __init__( # internal packages comes from get_adapter_package_names self.internal_package_names = set(internal_packages) self.internal_package_names_order = internal_packages - # macro_func is added here if in root package + # macro_func is added here if in root package, since + # the root package acts as a "global" namespace, overriding + # everything else except local external package macro calls self.globals: FlatNamespace = {} # macro_func is added here if it's the package for this node self.locals: FlatNamespace = {} @@ -169,8 +175,8 @@ def build_namespace( global_project_namespace.update(self.internal_packages[pkg]) return MacroNamespace( - 
global_namespace=self.globals, - local_namespace=self.locals, - global_project_namespace=global_project_namespace, - packages=self.packages, + global_namespace=self.globals, # root package macros + local_namespace=self.locals, # packages for *this* node + global_project_namespace=global_project_namespace, # internal packages + packages=self.packages, # non internal_packages ) diff --git a/core/dbt/context/providers.py b/core/dbt/context/providers.py index 2794b296bbf..7f85db245ad 100644 --- a/core/dbt/context/providers.py +++ b/core/dbt/context/providers.py @@ -1408,7 +1408,12 @@ def __init__( self.macro_resolver = macro_resolver self.thread_ctx = MacroStack() super().__init__(model, config, manifest, provider, context_config) - self._build_test_namespace + self._build_test_namespace() + # We need to rebuild this because it's already been built by + # the ProviderContext with the wrong namespace. + self.db_wrapper = self.provider.DatabaseWrapper( + self.adapter, self.namespace + ) def _build_namespace(self): return {} @@ -1421,11 +1426,17 @@ def _build_test_namespace(self): depends_on_macros = [] if self.model.depends_on and self.model.depends_on.macros: depends_on_macros = self.model.depends_on.macros + lookup_macros = depends_on_macros.copy() + for macro_unique_id in lookup_macros: + lookup_macro = self.macro_resolver.macros.get(macro_unique_id) + if lookup_macro: + depends_on_macros.extend(lookup_macro.depends_on.macros) + macro_namespace = TestMacroNamespace( - self.macro_resolver, self.ctx, self.node, self.thread_ctx, + self.macro_resolver, self._ctx, self.model, self.thread_ctx, depends_on_macros ) - self._namespace = macro_namespace + self.namespace = macro_namespace def generate_test_context( diff --git a/core/dbt/parser/manifest.py b/core/dbt/parser/manifest.py index 1b7db61dfc7..9e4704c152e 100644 --- a/core/dbt/parser/manifest.py +++ b/core/dbt/parser/manifest.py @@ -14,14 +14,17 @@ from dbt.adapters.factory import ( get_adapter, 
get_relation_class_by_name, + get_adapter_package_names, ) from dbt.helper_types import PathSet from dbt.logger import GLOBAL_LOGGER as logger, DbtProcessState from dbt.node_types import NodeType -from dbt.clients.jinja import get_rendered +from dbt.clients.jinja import get_rendered, statically_extract_macro_calls from dbt.clients.system import make_directory from dbt.config import Project, RuntimeConfig from dbt.context.docs import generate_runtime_docs +from dbt.context.macro_resolver import MacroResolver +from dbt.context.base import generate_base_context from dbt.contracts.files import FileHash, ParseFileType from dbt.parser.read_files import read_files, load_source_file from dbt.contracts.graph.compiled import ManifestNode @@ -198,6 +201,7 @@ def load(self): for search_key in parser_files['MacroParser']: block = FileBlock(self.manifest.files[search_key]) self.parse_with_cache(block, parser) + self.reparse_macros() # This is where a loop over self.manifest.macros should be performed # to set the 'depends_on' information from static rendering. 
self._perf_info.load_macros_elapsed = (time.perf_counter() - start_load_macros) @@ -270,6 +274,29 @@ def parse_project( self._perf_info.path_count + total_path_count ) + # Loop through macros in the manifest and statically parse + # the 'macro_sql' to find depends_on.macros + def reparse_macros(self): + internal_package_names = get_adapter_package_names( + self.root_project.credentials.type + ) + macro_resolver = MacroResolver( + self.manifest.macros, + self.root_project.project_name, + internal_package_names + ) + base_ctx = generate_base_context({}) + for macro in self.manifest.macros.values(): + possible_macro_calls = statically_extract_macro_calls(macro.macro_sql, base_ctx) + for macro_name in possible_macro_calls: + # adapter.dispatch calls can generate a call with the same name as the macro + # it ought to be an adapter prefix (postgres_) or default_ + if macro_name == macro.name: + continue + dep_macro_id = macro_resolver.get_macro_id(macro.package_name, macro_name) + if dep_macro_id: + macro.depends_on.add_macro(dep_macro_id) # will check for dupes + # This is where we use the partial-parse state from the # pickle file (if it exists) def parse_with_cache( diff --git a/test/integration/008_schema_tests_test/test-context-macros/custom_schema_tests.sql b/test/integration/008_schema_tests_test/test-context-macros/custom_schema_tests.sql new file mode 100644 index 00000000000..ca11d2a6fcf --- /dev/null +++ b/test/integration/008_schema_tests_test/test-context-macros/custom_schema_tests.sql @@ -0,0 +1,19 @@ +{% macro test_type_one(model) %} + + select count(*) from ( + + select * from {{ model }} + union all + select * from {{ ref('model_b') }} + + ) as Foo + +{% endmacro %} + +{% macro test_type_two(model) %} + + {{ config(severity = "WARN") }} + + select count(*) from {{ model }} + +{% endmacro %} diff --git a/test/integration/008_schema_tests_test/test-context-models/model_a.sql b/test/integration/008_schema_tests_test/test-context-models/model_a.sql new file 
mode 100644 index 00000000000..3bd54a4c1b6 --- /dev/null +++ b/test/integration/008_schema_tests_test/test-context-models/model_a.sql @@ -0,0 +1 @@ +select 1 as fun diff --git a/test/integration/008_schema_tests_test/test-context-models/model_b.sql b/test/integration/008_schema_tests_test/test-context-models/model_b.sql new file mode 100644 index 00000000000..01f38b0698e --- /dev/null +++ b/test/integration/008_schema_tests_test/test-context-models/model_b.sql @@ -0,0 +1 @@ +select 1 as notfun diff --git a/test/integration/008_schema_tests_test/test-context-models/schema.yml b/test/integration/008_schema_tests_test/test-context-models/schema.yml new file mode 100644 index 00000000000..cf221dec670 --- /dev/null +++ b/test/integration/008_schema_tests_test/test-context-models/schema.yml @@ -0,0 +1,8 @@ + +version: 2 + +models: + - name: model_a + tests: + - type_one + - type_two diff --git a/test/integration/008_schema_tests_test/test_schema_v2_tests.py b/test/integration/008_schema_tests_test/test_schema_v2_tests.py index 712805516fa..f7dcff0d3b5 100644 --- a/test/integration/008_schema_tests_test/test_schema_v2_tests.py +++ b/test/integration/008_schema_tests_test/test_schema_v2_tests.py @@ -390,3 +390,39 @@ def test_postgres_schema_uppercase_sql(self): self.assertEqual(len(results), 1) +class TestSchemaTestContext(DBTIntegrationTest): + @property + def schema(self): + return "schema_tests_008" + + @property + def models(self): + return "test-context-models" + + @property + def project_config(self): + return { + 'config-version': 2, + "macro-paths": ["test-context-macros"], + } + + @use_profile('postgres') + def test_postgres_test_context_tests(self): + # This test tests that the TestContext and TestMacroNamespace + # are working correctly + results = self.run_dbt(strict=False) + self.assertEqual(len(results), 2) + + results = self.run_dbt(['test'], expect_pass=False) + self.assertEqual(len(results), 2) + result0 = results[0] + result1 = results[1] + for result in 
results: + if result.node.name == 'type_two_model_a_': + # This will be WARN if the test macro was rendered correctly + self.assertEqual(result.node.config.severity, 'WARN') + elif result.node.name == 'type_one_model_a': + # This will have correct compiled_sql if the test macro + # was rendered correctly + self.assertRegex(result.node.compiled_sql, r'union all') + diff --git a/test/unit/test_macro_calls.py b/test/unit/test_macro_calls.py new file mode 100644 index 00000000000..bdde2c41fd2 --- /dev/null +++ b/test/unit/test_macro_calls.py @@ -0,0 +1,45 @@ +import os +import unittest +from unittest.mock import MagicMock, patch + +from dataclasses import dataclass, field +from typing import Dict, Any + +from dbt.clients.jinja import statically_extract_macro_calls +from dbt.context.base import generate_base_context + + +class MacroCalls(unittest.TestCase): + + def setUp(self): + self.macro_strings = [ + "{% macro parent_macro() %} {% do return(nested_macro()) %} {% endmacro %}", + "{% macro lr_macro() %} {{ return(load_result('relations').table) }} {% endmacro %}", + "{% macro get_snapshot_unique_id() -%} {{ return(adapter.dispatch('get_snapshot_unique_id')()) }} {%- endmacro %}", + "{% macro get_columns_in_query(select_sql) -%} {{ return(adapter.dispatch('get_columns_in_query')(select_sql)) }} {% endmacro %}", + """{% macro test_mutually_exclusive_ranges(model) %} + with base as ( + select {{ get_snapshot_unique_id() }} as dbt_unique_id, + * + from {{ model }} ) + {% endmacro %}""", + ] + + self.possible_macro_calls = [ + ['nested_macro'], + ['load_result'], + ['get_snapshot_unique_id'], + ['get_columns_in_query'], + ['get_snapshot_unique_id'], + ] + + def test_macro_calls(self): + ctx = generate_base_context({}) + + index = 0 + for macro_string in self.macro_strings: + possible_macro_calls = statically_extract_macro_calls(macro_string, ctx) + self.assertEqual(possible_macro_calls, self.possible_macro_calls[index]) + index = index + 1 + +