diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index dfa98f4de91..98bcc2d410c 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -231,7 +231,7 @@ def search_package_name(self): def add_macros_from( self, context: Dict[str, Any], - macros: Dict[str, ParsedMacro], + macros: Dict[str, ParsedMacro] ): global_macros: List[Dict[str, Callable]] = [] local_macros: List[Dict[str, Callable]] = [] @@ -248,9 +248,14 @@ def add_macros_from( # adapter packages are part of the global project space _add_macro_map(context, package_name, macro_map) - if package_name == self.search_package_name: + if package_name == self.config.project_name: + # If we're in the root project, allow global override + global_macros.append(macro_map) + elif package_name == self.search_package_name: + # If we're in the current project, allow local override local_macros.append(macro_map) elif package_name in PACKAGES: + # If it comes from a dbt package, allow global override global_macros.append(macro_map) # Load global macros before local macros -- local takes precedence diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index 7dcff01945b..9b4a743a8f0 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -212,16 +212,21 @@ def search_package_name(self): return self.model.package_name def add_provider_functions(self, context): - context['ref'] = self.provider.ref( - self.db_wrapper, self.model, self.config, self.manifest - ) - context['source'] = self.provider.source( - self.db_wrapper, self.model, self.config, self.manifest - ) - context['config'] = self.provider.Config( - self.model, self.source_config - ) - context['execute'] = self.provider.execute + # Generate the builtin functions + builtins = { + 'ref': self.provider.ref( + self.db_wrapper, self.model, self.config, self.manifest), + 'source': self.provider.source( + self.db_wrapper, self.model, self.config, self.manifest), + 'config': self.provider.Config( + self.model, self.source_config), + 'execute': self.provider.execute + } + # Install them at .builtins + context['builtins'] = builtins + # Install each of them directly in case they're not + # clobbered by a macro. + context.update(builtins) def add_exceptions(self, context): context['exceptions'] = dbt.exceptions.wrapped_exports(self.model) diff --git a/test/integration/055_ref_override_test/data/seed_1.csv b/test/integration/055_ref_override_test/data/seed_1.csv new file mode 100644 index 00000000000..4de2771bdac --- /dev/null +++ b/test/integration/055_ref_override_test/data/seed_1.csv @@ -0,0 +1,4 @@ +a,b +1,2 +2,4 +3,6 \ No newline at end of file diff --git a/test/integration/055_ref_override_test/data/seed_2.csv b/test/integration/055_ref_override_test/data/seed_2.csv new file mode 100644 index 00000000000..eeadef9495c --- /dev/null +++ b/test/integration/055_ref_override_test/data/seed_2.csv @@ -0,0 +1,4 @@ +a,b +6,2 +12,4 +18,6 \ No newline at end of file diff --git a/test/integration/055_ref_override_test/macros/ref_override_macro.sql b/test/integration/055_ref_override_test/macros/ref_override_macro.sql new file mode 100644 index 00000000000..a4a85b50324 --- /dev/null +++ b/test/integration/055_ref_override_test/macros/ref_override_macro.sql @@ -0,0 +1,4 @@ +-- Macro to override ref and always return the same result +{% macro ref(modelname) %} +{% do return(builtins.ref(modelname).replace_path(identifier='seed_2')) %} +{% endmacro %} \ No newline at end of file diff --git a/test/integration/055_ref_override_test/models/ref_override.sql b/test/integration/055_ref_override_test/models/ref_override.sql new file mode 100644 index 00000000000..3bbf936ae2e --- /dev/null +++ b/test/integration/055_ref_override_test/models/ref_override.sql @@ -0,0 +1,3 @@ +select + * +from {{ ref('seed_1') }} \ No newline at end of file diff --git a/test/integration/055_ref_override_test/test_ref_override.py b/test/integration/055_ref_override_test/test_ref_override.py new file mode 100644 index 00000000000..2d0fc3e068d --- /dev/null +++ b/test/integration/055_ref_override_test/test_ref_override.py @@ -0,0 +1,29 @@ +from test.integration.base import DBTIntegrationTest, use_profile + + +class TestRefOverride(DBTIntegrationTest): + @property + def schema(self): + return "dbt_ref_override_055" + + @property + def project_config(self): + return { + 'data-paths': ['data'], + "macro-paths": ["macros"], + 'seeds': { + 'quote_columns': False + } + } + + @property + def models(self): + return "models" + + @use_profile('postgres') + def test_postgres_ref_override(self): + self.run_dbt(['seed']) + self.run_dbt(['run']) + # We want it to equal seed_2 and not seed_1. If it's + # still pointing at seed_1 then the override hasn't worked. + self.assertTablesEqual('ref_override', 'seed_2')