Add necessary macros to schema test context namespace #3272

Merged · 1 commit · Apr 20, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@
- Add feature to add `_n` alias to same column names in SQL query ([#3147](https://github.com/fishtown-analytics/dbt/issues/3147), [#3158](https://github.com/fishtown-analytics/dbt/pull/3158))
- Raise a proper error message if dbt parses a macro twice due to macro duplication or misconfiguration. ([#2449](https://github.com/fishtown-analytics/dbt/issues/2449), [#3165](https://github.com/fishtown-analytics/dbt/pull/3165))
- Fix exposures missing in graph context variable. ([#3241](https://github.com/fishtown-analytics/dbt/issues/3241))
- Ensure that schema test macros are properly processed ([#3229](https://github.com/fishtown-analytics/dbt/issues/3229), [#3272](https://github.com/fishtown-analytics/dbt/pull/3272))

### Features
- Add optional configs for `require_partition_filter` and `partition_expiration_days` in BigQuery ([#1843](https://github.com/fishtown-analytics/dbt/issues/1843), [#2928](https://github.com/fishtown-analytics/dbt/pull/2928))
36 changes: 36 additions & 0 deletions core/dbt/clients/jinja.py
@@ -647,3 +647,39 @@ def _convert_function(

kwargs = deep_map(_convert_function, node.test_metadata.kwargs)
context[SCHEMA_TEST_KWARGS_NAME] = kwargs


def statically_extract_macro_calls(string, ctx):
# set 'capture_macros' to capture undefined macro calls
env = get_environment(None, capture_macros=True)
parsed = env.parse(string)

standard_calls = {
'source': [],
'ref': [],
'config': [],
}
Comment on lines +657 to +661

Contributor: are the empty lists from a previous approach? can standard_calls just be a set of strings?

Contributor Author: That structure was there from an experiment that Drew wrote. It could be a set of strings right now, but I think I'll leave it like this as a connection to that previous experiment and to leave space for a possible future enhancement.


possible_macro_calls = []
for func_call in parsed.find_all(jinja2.nodes.Call):
if hasattr(func_call, 'node') and hasattr(func_call.node, 'name'):
func_name = func_call.node.name
else:
# This is a kludge to capture an adapter.dispatch('<macro_name>') call.
# Call(node=Getattr(
# node=Name(name='adapter', ctx='load'), attr='dispatch', ctx='load'),
# args=[Const(value='get_snapshot_unique_id')], kwargs=[],
# dyn_args=None, dyn_kwargs=None)
if (hasattr(func_call, 'node') and hasattr(func_call.node, 'attr') and
func_call.node.attr == 'dispatch'):
func_name = func_call.args[0].value
else:
continue
if func_name in standard_calls:
continue
elif ctx.get(func_name):
continue
else:
possible_macro_calls.append(func_name)

return possible_macro_calls
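
A rough usage sketch of the new static extraction (the macro body and the `my_custom_helper` name below are invented for illustration): parsing a macro's source finds candidate macro calls without rendering it, while standard `source`/`ref`/`config` calls and names already present in the context are filtered out.

```python
# Illustrative only; mirrors how manifest.py calls these helpers below.
from dbt.clients.jinja import statically_extract_macro_calls
from dbt.context.base import generate_base_context

macro_sql = """
{% macro test_type_one(model) %}
  select count(*) from {{ model }}
  union all
  select * from {{ ref('model_b') }}
  where {{ my_custom_helper() }}
{% endmacro %}
"""

base_ctx = generate_base_context({})
calls = statically_extract_macro_calls(macro_sql, base_ctx)
# 'ref' is dropped as a standard call; 'my_custom_helper' is unknown to the
# base context, so it survives as a possible macro call:
print(calls)  # ['my_custom_helper']
```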
54 changes: 40 additions & 14 deletions core/dbt/context/macro_resolver.py
@@ -14,8 +14,12 @@
# so that higher precedence macros are found first.
# This functionality is also provided by the MacroNamespace,
# but the intention is to eventually replace that class.
# This enables us to get the macor unique_id without
# This enables us to get the macro unique_id without
# processing every macro in the project.
# Note: the root project macros override everything in the
# dbt internal projects. External projects (dependencies) will
# use their own macros first, then pull from the root project
# followed by dbt internal projects.
class MacroResolver:
def __init__(
self,
@@ -48,18 +52,29 @@ def _build_internal_packages_namespace(self):
self.internal_packages_namespace.update(
self.internal_packages[pkg])

# search order:
# local_namespace (package of particular node), not including
# the internal packages or the root package
# This means that within an extra package, it uses its own macros
# root package namespace
# non-internal packages (that aren't local or root)
# dbt internal packages
def _build_macros_by_name(self):
macros_by_name = {}
# search root package macros
for macro in self.root_package_macros.values():

# all internal packages (already in the right order)
for macro in self.internal_packages_namespace.values():
macros_by_name[macro.name] = macro
# search miscellaneous non-internal packages

# non-internal packages
for fnamespace in self.packages.values():
for macro in fnamespace.values():
macros_by_name[macro.name] = macro
# search all internal packages
for macro in self.internal_packages_namespace.values():

# root package macros
for macro in self.root_package_macros.values():
macros_by_name[macro.name] = macro

self.macros_by_name = macros_by_name
Comment on lines 62 to 78

Contributor: I think I'm having a little trouble following this change, was this ordering wrong before?

Contributor Author: Yes. It never actually worked, so we didn't notice. It needs to be built from lowest precedence to highest, and it was the opposite before.
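
A toy sketch of the precedence point above (package and macro names invented): because flattening into one dict lets later writes win, the namespaces must be merged from lowest precedence (dbt internal packages) to highest (the root project).

```python
# Not dbt code; illustrates why _build_macros_by_name merges in this order.
internal_packages = {'default__my_test': 'macro.dbt.default__my_test'}
other_packages = {'default__my_test': 'macro.some_dep.default__my_test'}
root_package = {'default__my_test': 'macro.my_project.default__my_test'}

macros_by_name = {}
for namespace in (internal_packages, other_packages, root_package):
    macros_by_name.update(namespace)  # later update wins on name clashes

assert macros_by_name['default__my_test'] == 'macro.my_project.default__my_test'
```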


def _add_macro_to(
@@ -97,18 +112,26 @@ def add_macros(self):
for macro in self.macros.values():
self.add_macro(macro)

def get_macro_id(self, local_package, macro_name):
def get_macro(self, local_package, macro_name):
local_package_macros = {}
if (local_package not in self.internal_package_names and
local_package in self.packages):
local_package_macros = self.packages[local_package]
# First: search the local packages for this macro
if macro_name in local_package_macros:
return local_package_macros[macro_name].unique_id
return local_package_macros[macro_name]
# Now look up in the standard search order
if macro_name in self.macros_by_name:
return self.macros_by_name[macro_name].unique_id
return self.macros_by_name[macro_name]
return None

def get_macro_id(self, local_package, macro_name):
macro = self.get_macro(local_package, macro_name)
if macro is None:
return None
else:
return macro.unique_id


# Currently this is just used by test processing in the schema
# parser (in connection with the MacroResolver). Future work
@@ -127,10 +150,11 @@ def __init__(
local_namespace = {}
if depends_on_macros:
for macro_unique_id in depends_on_macros:
macro = self.manifest.macros[macro_unique_id]
local_namespace[macro.name] = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx,
)
if macro_unique_id in self.macro_resolver.macros:
macro = self.macro_resolver.macros[macro_unique_id]
local_namespace[macro.name] = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx,
)
self.local_namespace = local_namespace

def get_from_package(
@@ -141,12 +165,14 @@
macro = self.macro_resolver.macros_by_name.get(name)
elif package_name == GLOBAL_PROJECT_NAME:
macro = self.macro_resolver.internal_packages_namespace.get(name)
elif package_name in self.resolver.packages:
elif package_name in self.macro_resolver.packages:
macro = self.macro_resolver.packages[package_name].get(name)
else:
raise_compiler_error(
f"Could not find package '{package_name}'"
)
if not macro:
return None
macro_func = MacroGenerator(
macro, self.ctx, self.node, self.thread_ctx
)
34 changes: 20 additions & 14 deletions core/dbt/context/macros.py
@@ -19,27 +19,31 @@
# and provide the ability to flatten them into the ManifestContexts
# that are created for jinja, so that macro calls can be resolved.
# Creates special iterators and _keys methods to flatten the lists.
# When this class is created it has a static 'local_namespace' which
# depends on the package of the node, so it only works for one
# particular local package at a time for "flattening" into a context.
# 'get_by_package' should work for any macro.
class MacroNamespace(Mapping):
def __init__(
self,
global_namespace: FlatNamespace,
local_namespace: FlatNamespace,
global_project_namespace: FlatNamespace,
packages: Dict[str, FlatNamespace],
global_namespace: FlatNamespace, # root package macros
local_namespace: FlatNamespace, # packages for *this* node
global_project_namespace: FlatNamespace, # internal packages
packages: Dict[str, FlatNamespace], # non-internal packages
):
self.global_namespace: FlatNamespace = global_namespace
self.local_namespace: FlatNamespace = local_namespace
self.packages: Dict[str, FlatNamespace] = packages
self.global_project_namespace: FlatNamespace = global_project_namespace

def _search_order(self) -> Iterable[Union[FullNamespace, FlatNamespace]]:
yield self.local_namespace
yield self.global_namespace
yield self.packages
yield self.local_namespace # local package
yield self.global_namespace # root package
yield self.packages # non-internal packages
yield {
GLOBAL_PROJECT_NAME: self.global_project_namespace,
GLOBAL_PROJECT_NAME: self.global_project_namespace, # dbt
}
yield self.global_project_namespace
yield self.global_project_namespace # other internal projects besides dbt
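
Assuming lookups walk the layers in the order yielded above (a sketch of the idea, not the class's literal Mapping implementation), name resolution looks roughly like:

```python
# Hypothetical resolver: the first layer containing the name wins, so a
# local macro shadows a root-package macro, which shadows dbt's internals.
def resolve(namespace, name):
    for layer in namespace._search_order():
        if name in layer:
            return layer[name]
    raise KeyError(name)
```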

# provides special keys method for MacroNamespace iterator
# returns keys from local_namespace, global_namespace, packages,
@@ -98,7 +102,9 @@ def __init__(
# internal packages comes from get_adapter_package_names
self.internal_package_names = set(internal_packages)
self.internal_package_names_order = internal_packages
# macro_func is added here if in root package
# macro_func is added here if in root package, since
# the root package acts as a "global" namespace, overriding
# everything else except local external package macro calls
self.globals: FlatNamespace = {}
# macro_func is added here if it's the package for this node
self.locals: FlatNamespace = {}
@@ -169,8 +175,8 @@ def build_namespace(
global_project_namespace.update(self.internal_packages[pkg])

return MacroNamespace(
global_namespace=self.globals,
local_namespace=self.locals,
global_project_namespace=global_project_namespace,
packages=self.packages,
global_namespace=self.globals, # root package macros
local_namespace=self.locals, # packages for *this* node
global_project_namespace=global_project_namespace, # internal packages
packages=self.packages, # non-internal packages
)
17 changes: 14 additions & 3 deletions core/dbt/context/providers.py
@@ -1408,7 +1412,12 @@ def __init__(
self.macro_resolver = macro_resolver
self.thread_ctx = MacroStack()
super().__init__(model, config, manifest, provider, context_config)
self._build_test_namespace
Contributor: oof! we missed this line before

Contributor Author: Yep. One of the reasons that it didn't work :-)
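
A minimal illustration of the one-character bug discussed above (not part of the diff):

```python
class Example:
    def _build_test_namespace(self):
        print("building namespace")

e = Example()
e._build_test_namespace    # no-op: looks up the bound method, discards it
e._build_test_namespace()  # actually runs the method
```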

self._build_test_namespace()
# We need to rebuild this because it's already been built by
# the ProviderContext with the wrong namespace.
self.db_wrapper = self.provider.DatabaseWrapper(
self.adapter, self.namespace
)

def _build_namespace(self):
return {}
@@ -1421,11 +1426,17 @@ def _build_test_namespace(self):
depends_on_macros = []
if self.model.depends_on and self.model.depends_on.macros:
depends_on_macros = self.model.depends_on.macros
lookup_macros = depends_on_macros.copy()
for macro_unique_id in lookup_macros:
lookup_macro = self.macro_resolver.macros.get(macro_unique_id)
if lookup_macro:
depends_on_macros.extend(lookup_macro.depends_on.macros)

macro_namespace = TestMacroNamespace(
self.macro_resolver, self.ctx, self.node, self.thread_ctx,
self.macro_resolver, self._ctx, self.model, self.thread_ctx,
depends_on_macros
)
self._namespace = macro_namespace
self.namespace = macro_namespace
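
The copy-then-extend loop above pulls in one level of transitive macro dependencies, sketched here with invented unique ids:

```python
# Invented ids; sketches the copy-then-extend loop in _build_test_namespace.
static_deps = {
    'macro.my_project.test_type_one': ['macro.my_project.my_custom_helper'],
}
depends_on_macros = ['macro.my_project.test_type_one']
for macro_unique_id in depends_on_macros.copy():
    depends_on_macros.extend(static_deps.get(macro_unique_id, []))

assert 'macro.my_project.my_custom_helper' in depends_on_macros
```

Note that dependencies more than one level deep are not expanded, since the copy is taken before the list grows.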


def generate_test_context(
29 changes: 28 additions & 1 deletion core/dbt/parser/manifest.py
@@ -14,14 +14,17 @@
from dbt.adapters.factory import (
get_adapter,
get_relation_class_by_name,
get_adapter_package_names,
)
from dbt.helper_types import PathSet
from dbt.logger import GLOBAL_LOGGER as logger, DbtProcessState
from dbt.node_types import NodeType
from dbt.clients.jinja import get_rendered
from dbt.clients.jinja import get_rendered, statically_extract_macro_calls
from dbt.clients.system import make_directory
from dbt.config import Project, RuntimeConfig
from dbt.context.docs import generate_runtime_docs
from dbt.context.macro_resolver import MacroResolver
from dbt.context.base import generate_base_context
from dbt.contracts.files import FileHash, ParseFileType
from dbt.parser.read_files import read_files, load_source_file
from dbt.contracts.graph.compiled import ManifestNode
@@ -198,6 +201,7 @@ def load(self):
for search_key in parser_files['MacroParser']:
block = FileBlock(self.manifest.files[search_key])
self.parse_with_cache(block, parser)
self.reparse_macros()
# This is where a loop over self.manifest.macros should be performed
# to set the 'depends_on' information from static rendering.
self._perf_info.load_macros_elapsed = (time.perf_counter() - start_load_macros)
@@ -270,6 +274,29 @@ def parse_project(
self._perf_info.path_count + total_path_count
)

# Loop through macros in the manifest and statically parse
# the 'macro_sql' to find depends_on.macros
def reparse_macros(self):
internal_package_names = get_adapter_package_names(
self.root_project.credentials.type
)
macro_resolver = MacroResolver(
self.manifest.macros,
self.root_project.project_name,
internal_package_names
)
base_ctx = generate_base_context({})
for macro in self.manifest.macros.values():
possible_macro_calls = statically_extract_macro_calls(macro.macro_sql, base_ctx)
for macro_name in possible_macro_calls:
# adapter.dispatch calls can generate a call with the same name as the macro
# it ought to be an adapter prefix (postgres_) or default_
if macro_name == macro.name:
continue
dep_macro_id = macro_resolver.get_macro_id(macro.package_name, macro_name)
if dep_macro_id:
macro.depends_on.add_macro(dep_macro_id) # will check for dupes
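
One subtlety worth a sketch (grounded in the adapter.dispatch example in clients/jinja.py above): a dispatching macro statically yields its own name as a candidate, which is why the self-name check exists.

```python
# A dispatch wrapper's body yields the macro's own name as a candidate.
from dbt.clients.jinja import statically_extract_macro_calls
from dbt.context.base import generate_base_context

sql = """
{% macro get_snapshot_unique_id() %}
  {{ adapter.dispatch('get_snapshot_unique_id')() }}
{% endmacro %}
"""
calls = statically_extract_macro_calls(sql, generate_base_context({}))
# calls == ['get_snapshot_unique_id'] -- equal to macro.name, so
# reparse_macros skips it instead of recording a self-dependency.
```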

# This is where we use the partial-parse state from the
# pickle file (if it exists)
def parse_with_cache(
@@ -0,0 +1,19 @@
{% macro test_type_one(model) %}

select count(*) from (

select * from {{ model }}
union all
select * from {{ ref('model_b') }}

) as Foo

{% endmacro %}

{% macro test_type_two(model) %}

{{ config(severity = "WARN") }}

select count(*) from {{ model }}

{% endmacro %}
@@ -0,0 +1 @@
select 1 as fun
@@ -0,0 +1 @@
select 1 as notfun
@@ -0,0 +1,8 @@

version: 2

models:
- name: model_a
tests:
- type_one
- type_two
36 changes: 36 additions & 0 deletions test/integration/008_schema_tests_test/test_schema_v2_tests.py
@@ -390,3 +390,39 @@ def test_postgres_schema_uppercase_sql(self):
self.assertEqual(len(results), 1)


class TestSchemaTestContext(DBTIntegrationTest):
@property
def schema(self):
return "schema_tests_008"

@property
def models(self):
return "test-context-models"

@property
def project_config(self):
return {
'config-version': 2,
"macro-paths": ["test-context-macros"],
}

@use_profile('postgres')
def test_postgres_test_context_tests(self):
# This test verifies that the TestContext and TestMacroNamespace
# are working correctly
results = self.run_dbt(strict=False)
self.assertEqual(len(results), 2)

results = self.run_dbt(['test'], expect_pass=False)
self.assertEqual(len(results), 2)
for result in results:
if result.node.name == 'type_two_model_a_':
# This will be WARN if the test macro was rendered correctly
self.assertEqual(result.node.config.severity, 'WARN')
elif result.node.name == 'type_one_model_a':
# This will have correct compiled_sql if the test macro
# was rendered correctly
self.assertRegex(result.node.compiled_sql, r'union all')
