Skip to content

Commit

Permalink
Add necessary macros to schema test context namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Apr 16, 2021
1 parent cee0bfb commit a17eac6
Show file tree
Hide file tree
Showing 11 changed files with 248 additions and 32 deletions.
36 changes: 36 additions & 0 deletions core/dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,3 +647,39 @@ def _convert_function(

kwargs = deep_map(_convert_function, node.test_metadata.kwargs)
context[SCHEMA_TEST_KWARGS_NAME] = kwargs


def statically_extract_macro_calls(string, ctx):
    """Statically parse a jinja template string and return the names of
    calls that could be user-defined macro calls.

    Names of the built-in calls (source/ref/config) and names already
    present in ``ctx`` are excluded.
    """
    # 'capture_macros' lets undefined names parse without raising
    env = get_environment(None, capture_macros=True)
    template_ast = env.parse(string)

    # built-in call names that are never user macros
    builtin_call_names = ('source', 'ref', 'config')

    candidates = []
    for call_node in template_ast.find_all(jinja2.nodes.Call):
        target = getattr(call_node, 'node', None)
        func_name = getattr(target, 'name', None)
        if func_name is None:
            # Kludge to capture an adapter.dispatch('<macro_name>') call:
            # Call(node=Getattr(
            #     node=Name(name='adapter', ctx='load'), attr='dispatch', ctx='load'),
            #     args=[Const(value='get_snapshot_unique_id')], kwargs=[],
            #     dyn_args=None, dyn_kwargs=None)
            if getattr(target, 'attr', None) == 'dispatch':
                func_name = call_node.args[0].value
            else:
                continue
        if func_name in builtin_call_names:
            continue
        if ctx.get(func_name):
            # already provided by the rendering context
            continue
        candidates.append(func_name)

    return candidates
54 changes: 40 additions & 14 deletions core/dbt/context/macro_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
# so that higher precedence macros are found first.
# This functionality is also provided by the MacroNamespace,
# but the intention is to eventually replace that class.
# This enables us to get the macro unique_id without
# processing every macro in the project.
# Note: the root project macros override everything in the
# dbt internal projects. External projects (dependencies) will
# use their own macros first, then pull from the root project
# followed by dbt internal projects.
class MacroResolver:
def __init__(
self,
Expand Down Expand Up @@ -48,18 +52,29 @@ def _build_internal_packages_namespace(self):
self.internal_packages_namespace.update(
self.internal_packages[pkg])

# search order:
# local_namespace (package of particular node), not including
# the internal packages or the root package
# This means that within an extra package, it uses its own macros
# root package namespace
# non-internal packages (that aren't local or root)
# dbt internal packages
def _build_macros_by_name(self):
macros_by_name = {}
# search root package macros
for macro in self.root_package_macros.values():

# all internal packages (already in the right order)
for macro in self.internal_packages_namespace.values():
macros_by_name[macro.name] = macro
# search miscellaneous non-internal packages

# non-internal packages
for fnamespace in self.packages.values():
for macro in fnamespace.values():
macros_by_name[macro.name] = macro
# search all internal packages
for macro in self.internal_packages_namespace.values():

# root package macros
for macro in self.root_package_macros.values():
macros_by_name[macro.name] = macro

self.macros_by_name = macros_by_name

def _add_macro_to(
Expand Down Expand Up @@ -97,18 +112,26 @@ def add_macros(self):
for macro in self.macros.values():
self.add_macro(macro)

def get_macro_id(self, local_package, macro_name):
def get_macro(self, local_package, macro_name):
local_package_macros = {}
if (local_package not in self.internal_package_names and
local_package in self.packages):
local_package_macros = self.packages[local_package]
# First: search the local packages for this macro
if macro_name in local_package_macros:
return local_package_macros[macro_name].unique_id
return local_package_macros[macro_name]
# Now look up in the standard search order
if macro_name in self.macros_by_name:
return self.macros_by_name[macro_name].unique_id
return self.macros_by_name[macro_name]
return None

def get_macro_id(self, local_package, macro_name):
macro = self.get_macro(local_package, macro_name)
if macro is None:
return None
else:
return macro.unique_id


# Currently this is just used by test processing in the schema
# parser (in connection with the MacroResolver). Future work
Expand All @@ -127,10 +150,11 @@ def __init__(
local_namespace = {}
if depends_on_macros:
for macro_unique_id in depends_on_macros:
macro = self.manifest.macros[macro_unique_id]
local_namespace[macro.name] = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx,
)
if macro_unique_id in self.macro_resolver.macros:
macro = self.macro_resolver.macros[macro_unique_id]
local_namespace[macro.name] = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx,
)
self.local_namespace = local_namespace

def get_from_package(
Expand All @@ -141,12 +165,14 @@ def get_from_package(
macro = self.macro_resolver.macros_by_name.get(name)
elif package_name == GLOBAL_PROJECT_NAME:
macro = self.macro_resolver.internal_packages_namespace.get(name)
elif package_name in self.resolver.packages:
elif package_name in self.macro_resolver.packages:
macro = self.macro_resolver.packages[package_name].get(name)
else:
raise_compiler_error(
f"Could not find package '{package_name}'"
)
if not macro:
return None
macro_func = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx
)
Expand Down
34 changes: 20 additions & 14 deletions core/dbt/context/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,31 @@
# and provide the ability to flatten them into the ManifestContexts
# that are created for jinja, so that macro calls can be resolved.
# Creates special iterators and _keys methods to flatten the lists.
# When this class is created it has a static 'local_namespace' which
# depends on the package of the node, so it only works for one
# particular local package at a time for "flattening" into a context.
# 'get_by_package' should work for any macro.
class MacroNamespace(Mapping):
def __init__(
    self,
    global_namespace: FlatNamespace,  # root package macros
    local_namespace: FlatNamespace,  # packages for *this* node
    global_project_namespace: FlatNamespace,  # internal packages
    packages: Dict[str, FlatNamespace],  # non-internal packages
):
    """Store the four macro namespaces used for flattened lookup."""
    self.global_namespace: FlatNamespace = global_namespace
    self.local_namespace: FlatNamespace = local_namespace
    self.packages: Dict[str, FlatNamespace] = packages
    self.global_project_namespace: FlatNamespace = global_project_namespace

def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]:
    """Yield namespaces in macro-resolution precedence order."""
    yield self.local_namespace  # local package
    yield self.global_namespace  # root package
    yield self.packages  # non-internal packages
    yield {
        GLOBAL_PROJECT_NAME: self.global_project_namespace,  # dbt
    }
    yield self.global_project_namespace  # other internal projects besides dbt

# provides special keys method for MacroNamespace iterator
# returns keys from local_namespace, global_namespace, packages,
Expand Down Expand Up @@ -98,7 +102,9 @@ def __init__(
# internal packages comes from get_adapter_package_names
self.internal_package_names = set(internal_packages)
self.internal_package_names_order = internal_packages
# macro_func is added here if in root package
# macro_func is added here if in root package, since
# the root package acts as a "global" namespace, overriding
# everything else except local external package macro calls
self.globals: FlatNamespace = {}
# macro_func is added here if it's the package for this node
self.locals: FlatNamespace = {}
Expand Down Expand Up @@ -169,8 +175,8 @@ def build_namespace(
global_project_namespace.update(self.internal_packages[pkg])

return MacroNamespace(
global_namespace=self.globals,
local_namespace=self.locals,
global_project_namespace=global_project_namespace,
packages=self.packages,
global_namespace=self.globals, # root package macros
local_namespace=self.locals, # packages for *this* node
global_project_namespace=global_project_namespace, # internal packages
packages=self.packages, # non internal_packages
)
17 changes: 14 additions & 3 deletions core/dbt/context/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1408,7 +1408,12 @@ def __init__(
self.macro_resolver = macro_resolver
self.thread_ctx = MacroStack()
super().__init__(model, config, manifest, provider, context_config)
self._build_test_namespace
self._build_test_namespace()
# We need to rebuild this because it's already been built by
# the ProviderContext with the wrong namespace.
self.db_wrapper = self.provider.DatabaseWrapper(
self.adapter, self.namespace
)

def _build_namespace(self):
    # Override: return an empty namespace here; the test-specific
    # namespace is built separately (see _build_test_namespace).
    return {}
def _build_test_namespace(self):
    """Build the TestMacroNamespace for this test node from the macros
    it depends on, plus one level of those macros' own dependencies,
    so nested macro calls resolve in the restricted test namespace.
    """
    depends_on_macros = []
    if self.model.depends_on and self.model.depends_on.macros:
        depends_on_macros = self.model.depends_on.macros
    # NOTE(review): extend() below mutates self.model.depends_on.macros
    # in place (depends_on_macros aliases it) — confirm this is intended.
    lookup_macros = depends_on_macros.copy()
    for macro_unique_id in lookup_macros:
        lookup_macro = self.macro_resolver.macros.get(macro_unique_id)
        if lookup_macro:
            depends_on_macros.extend(lookup_macro.depends_on.macros)

    macro_namespace = TestMacroNamespace(
        self.macro_resolver, self._ctx, self.model, self.thread_ctx,
        depends_on_macros
    )
    self.namespace = macro_namespace


def generate_test_context(
Expand Down
29 changes: 28 additions & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
from dbt.adapters.factory import (
get_adapter,
get_relation_class_by_name,
get_adapter_package_names,
)
from dbt.helper_types import PathSet
from dbt.logger import GLOBAL_LOGGER as logger, DbtProcessState
from dbt.node_types import NodeType
from dbt.clients.jinja import get_rendered
from dbt.clients.jinja import get_rendered, statically_extract_macro_calls
from dbt.clients.system import make_directory
from dbt.config import Project, RuntimeConfig
from dbt.context.docs import generate_runtime_docs
from dbt.context.macro_resolver import MacroResolver
from dbt.context.base import generate_base_context
from dbt.contracts.files import FileHash, ParseFileType
from dbt.parser.read_files import read_files, load_source_file
from dbt.contracts.graph.compiled import ManifestNode
Expand Down Expand Up @@ -198,6 +201,7 @@ def load(self):
for search_key in parser_files['MacroParser']:
block = FileBlock(self.manifest.files[search_key])
self.parse_with_cache(block, parser)
self.reparse_macros()
# This is where a loop over self.manifest.macros should be performed
# to set the 'depends_on' information from static rendering.
self._perf_info.load_macros_elapsed = (time.perf_counter() - start_load_macros)
Expand Down Expand Up @@ -270,6 +274,29 @@ def parse_project(
self._perf_info.path_count + total_path_count
)

# Loop through macros in the manifest and statically parse
# the 'macro_sql' to find depends_on.macros
def reparse_macros(self):
    """Statically parse each manifest macro's macro_sql and record the
    macros it calls in its depends_on.macros."""
    package_names = get_adapter_package_names(
        self.root_project.credentials.type
    )
    resolver = MacroResolver(
        self.manifest.macros,
        self.root_project.project_name,
        package_names,
    )
    base_ctx = generate_base_context({})
    for macro in self.manifest.macros.values():
        for call_name in statically_extract_macro_calls(
                macro.macro_sql, base_ctx):
            # adapter.dispatch calls can generate a call with the same
            # name as the macro; it ought to be an adapter prefix
            # (postgres_) or default_, so skip self-references.
            if call_name == macro.name:
                continue
            dep_macro_id = resolver.get_macro_id(
                macro.package_name, call_name)
            if dep_macro_id:
                macro.depends_on.add_macro(dep_macro_id)  # will check for dupes

# This is where we use the partial-parse state from the
# pickle file (if it exists)
def parse_with_cache(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
{% macro test_type_one(model) %}
{# Custom schema test fixture: counts rows across the tested model
   unioned with model_b (exercises ref() resolution inside a test macro). #}

select count(*) from (

select * from {{ model }}
union all
select * from {{ ref('model_b') }}

) as Foo

{% endmacro %}

{% macro test_type_two(model) %}
{# Custom schema test fixture that sets its severity to WARN via an
   inline config() call (exercises config() inside a test macro). #}

{{ config(severity = "WARN") }}

select count(*) from {{ model }}

{% endmacro %}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 1 as fun
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 1 as notfun
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

# Schema file attaching the custom type_one and type_two tests to model_a.
version: 2

models:
  - name: model_a
    tests:
      - type_one
      - type_two
36 changes: 36 additions & 0 deletions test/integration/008_schema_tests_test/test_schema_v2_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,39 @@ def test_postgres_schema_uppercase_sql(self):
self.assertEqual(len(results), 1)


class TestSchemaTestContext(DBTIntegrationTest):
    """Integration test: custom schema test macros must render with the
    TestContext / TestMacroNamespace (ref() and config() resolve)."""

    @property
    def schema(self):
        return "schema_tests_008"

    @property
    def models(self):
        return "test-context-models"

    @property
    def project_config(self):
        return {
            'config-version': 2,
            "macro-paths": ["test-context-macros"],
        }

    @use_profile('postgres')
    def test_postgres_test_context_tests(self):
        # This tests that the TestContext and TestMacroNamespace
        # are working correctly
        results = self.run_dbt(strict=False)
        self.assertEqual(len(results), 2)

        results = self.run_dbt(['test'], expect_pass=False)
        self.assertEqual(len(results), 2)
        for result in results:
            if result.node.name == 'type_two_model_a_':
                # This will be WARN if the test macro was rendered correctly
                self.assertEqual(result.node.config.severity, 'WARN')
            elif result.node.name == 'type_one_model_a':
                # This will have correct compiled_sql if the test macro
                # was rendered correctly
                self.assertRegex(result.node.compiled_sql, r'union all')

Loading

0 comments on commit a17eac6

Please sign in to comment.