From 81426ae800abf896bf583d031a48dcf3f4267674 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 11 Mar 2019 18:16:52 -0600 Subject: [PATCH] add optional "macros" parameter to dbt rpc calls --- core/dbt/parser/util.py | 5 +- core/dbt/rpc.py | 6 +- core/dbt/task/compile.py | 30 ++++-- core/dbt/task/runnable.py | 2 - .../042_sources_test/macros/macro.sql | 7 ++ .../042_sources_test/test_sources.py | 100 +++++++++++++++--- 6 files changed, 120 insertions(+), 30 deletions(-) create mode 100644 test/integration/042_sources_test/macros/macro.sql diff --git a/core/dbt/parser/util.py b/core/dbt/parser/util.py index 92c90fe68a0..09b03185818 100644 --- a/core/dbt/parser/util.py +++ b/core/dbt/parser/util.py @@ -231,12 +231,15 @@ def process_sources(cls, manifest, current_project): return manifest @classmethod - def add_new_refs(cls, manifest, current_project, node): + def add_new_refs(cls, manifest, current_project, node, macros): """Given a new node that is not in the manifest, copy the manifest and insert the new node into it as if it were part of regular ref processing """ manifest = manifest.deepcopy(config=current_project) + # it's ok for macros to silently override a local project macro name + manifest.macros.update(macros) + if node.unique_id in manifest.nodes: # this should be _impossible_ due to the fact that rpc calls get # a unique ID that starts with 'rpc'! diff --git a/core/dbt/rpc.py b/core/dbt/rpc.py index ce598b8866d..f20f2ab32fb 100644 --- a/core/dbt/rpc.py +++ b/core/dbt/rpc.py @@ -31,11 +31,11 @@ def from_error(cls, err): return cls(err.code, err.message, err.data, err.data.get('logs')) -def invalid_params(err, logs): +def invalid_params(data): return RPCException( - code=JSONRPCInvalidParams.code, + code=JSONRPCInvalidParams.CODE, message=JSONRPCInvalidParams.MESSAGE, - data={'logs': logs} + data=data ) diff --git a/core/dbt/task/compile.py b/core/dbt/task/compile.py index f223d8534a2..3cbbe78f880 100644 --- a/core/dbt/task/compile.py +++ b/core/dbt/task/compile.py @@ -6,6 +6,7 @@ from dbt.node_runners import CompileRunner, RPCCompileRunner from dbt.node_types import NodeType from dbt.parser.analysis import RPCCallParser +from dbt.parser.macros import MacroParser from dbt.parser.util import ParserUtils import dbt.ui.printer @@ -37,7 +38,6 @@ class RemoteCompileTask(CompileTask, RemoteCallable): def __init__(self, args, config): super(RemoteCompileTask, self).__init__(args, config) - self.parser = None self._base_manifest = GraphLoader.load_all( config, internal_manifest=get_adapter(config).check_internal_manifest() @@ -56,15 +56,28 @@ def runtime_cleanup(self, selected_uids): self._skipped_children = {} self._raise_next_tick = None - def handle_request(self, name, sql): - self.parser = RPCCallParser( + def handle_request(self, name, sql, macros=None): + request_path = os.path.join(self.config.target_path, 'rpc', name) + all_projects = load_all_projects(self.config) + macro_overrides = {} + if macros is not None: + macros = self.decode_sql(macros) + macro_parser = MacroParser(self.config, all_projects) + macro_overrides.update(macro_parser.parse_macro_file( + macro_file_path='from remote system', + macro_file_contents=macros, + root_path=request_path, + package_name=self.config.project_name, + resource_type=NodeType.Macro + )) + + rpc_parser = RPCCallParser( self.config, - all_projects=load_all_projects(self.config), + all_projects=all_projects, macro_manifest=self._base_manifest ) sql = self.decode_sql(sql) - request_path = os.path.join(self.config.target_path, 'rpc', name) node_dict = { 'name': name, 'root_path': request_path, @@ -74,13 +87,16 @@ def handle_request(self, name, sql): 'package_name': self.config.project_name, 'raw_sql': sql, } - unique_id, node = self.parser.parse_sql_node(node_dict) + + unique_id, node = rpc_parser.parse_sql_node(node_dict) self.manifest = ParserUtils.add_new_refs( manifest=self._base_manifest, current_project=self.config, - node=node + node=node, + macros=macro_overrides ) + # don't write our new, weird manifest! self.linker = compile_manifest(self.config, self.manifest, write=False) selected_uids = [node.unique_id] diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 35a346ec97b..d06bf83c35c 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -369,8 +369,6 @@ def decode_sql(self, sql): @staticmethod def raise_invalid_base64(sql): raise rpc.invalid_params( - code=JSONRPCInvalidParams.CODE, - message=JSONRPCInvalidParams.MESSAGE, data={ 'message': 'invalid base64-encoded sql input', 'sql': str(sql), diff --git a/test/integration/042_sources_test/macros/macro.sql b/test/integration/042_sources_test/macros/macro.sql new file mode 100644 index 00000000000..c1d3b1f47bb --- /dev/null +++ b/test/integration/042_sources_test/macros/macro.sql @@ -0,0 +1,7 @@ +{% macro override_me() -%} + exceptions.raise_compiler_error('this is a bad macro') +{%- endmacro %} + +{% macro happy_little_macro() -%} + {{ override_me() }} +{%- endmacro %} diff --git a/test/integration/042_sources_test/test_sources.py b/test/integration/042_sources_test/test_sources.py index fd6867b84ab..d4f0edc81cf 100644 --- a/test/integration/042_sources_test/test_sources.py +++ b/test/integration/042_sources_test/test_sources.py @@ -1,14 +1,21 @@ import unittest -from nose.plugins.attrib import attr from datetime import datetime, timedelta import json import os +import multiprocessing +from base64 import standard_b64encode as b64 +import requests +import socket +import time + + from dbt.exceptions import CompilationException from test.integration.base import DBTIntegrationTest, use_profile, AnyFloat, \ AnyStringWith from dbt.main import handle_and_check + class BaseSourcesTest(DBTIntegrationTest): @property def schema(self): @@ -260,16 +267,6 @@ def test_postgres_malformed_schema_strict_will_break_run(self): self.run_dbt_with_vars(['run'], strict=True) -import multiprocessing -from base64 import standard_b64encode as b64 -import json -import requests -import socket -import time -import os - - - class ServerProcess(multiprocessing.Process): def __init__(self, cli_vars=None): self.port = 22991 @@ -303,7 +300,7 @@ def start(self): raise Exception('server never appeared!') -@unittest.skipIf(os.name=='nt', 'Windows not supported for now') +@unittest.skipIf(os.name == 'nt', 'Windows not supported for now') class TestRPCServer(BaseSourcesTest): def setUp(self): super(TestRPCServer, self).setUp() @@ -316,10 +313,20 @@ def tearDown(self): self._server.terminate() super(TestRPCServer, self).tearDown() - def build_query(self, method, kwargs, sql=None, test_request_id=1): + @property + def project_config(self): + return { + 'data-paths': ['test/integration/042_sources_test/data'], + 'quoting': {'database': True, 'schema': True, 'identifier': True}, + 'macro-paths': ['test/integration/042_sources_test/macros'], + } + + def build_query(self, method, kwargs, sql=None, test_request_id=1, macros=None): if sql is not None: kwargs['sql'] = b64(sql.encode('utf-8')).decode('utf-8') + if macros is not None: + kwargs['macros'] = b64(macros.encode('utf-8')).decode('utf-8') return { 'jsonrpc': '2.0', 'method': method, @@ -333,8 +340,8 @@ def perform_query(self, query): response = requests.post(url, headers=headers, data=json.dumps(query)) return response - def query(self, _method, _sql=None, _test_request_id=1, **kwargs): - built = self.build_query(_method, kwargs, _sql, _test_request_id) + def query(self, _method, _sql=None, _test_request_id=1, macros=None, **kwargs): + built = self.build_query(_method, kwargs, _sql, _test_request_id, macros) return self.perform_query(built) def assertResultHasTimings(self, result, *names): @@ -425,7 +432,6 @@ def test_compile(self): 'select * from {{ source("test_source", "test_table") }}', name='foo' ).json() - self.assertSuccessfulCompilationResult( source, 'select * from {{ source("test_source", "test_table") }}', @@ -434,6 +440,30 @@ def test_compile(self): self.unique_schema()) ) + macro = self.query( + 'compile', + 'select {{ my_macro() }}', + name='foo', + macros='{% macro my_macro() %}1 as id{% endmacro %}' + ).json() + self.assertSuccessfulCompilationResult( + macro, + 'select {{ my_macro() }}', + compiled_sql='select 1 as id' + ) + + macro_override = self.query( + 'compile', + 'select {{ happy_little_macro() }}', + name='foo', + macros='{% macro override_me() %}2 as id{% endmacro %}' + ).json() + self.assertSuccessfulCompilationResult( + macro_override, + 'select {{ happy_little_macro() }}', + compiled_sql='select 2 as id' + ) + @use_profile('postgres') def test_run(self): # seed + run dbt to make models before using them! @@ -470,7 +500,6 @@ def test_run(self): 'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1', name='foo' ).json() - self.assertSuccessfulRunResult( source, 'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1', @@ -483,6 +512,32 @@ def test_run(self): } ) + macro = self.query( + 'run', + 'select {{ my_macro() }}', + name='foo', + macros='{% macro my_macro() %}1 as id{% endmacro %}' + ).json() + self.assertSuccessfulRunResult( + macro, + raw_sql='select {{ my_macro() }}', + compiled_sql='select 1 as id', + table={'column_names': ['id'], 'rows': [[1.0]]} + ) + + macro_override = self.query( + 'run', + 'select {{ happy_little_macro() }}', + name='foo', + macros='{% macro override_me() %}2 as id{% endmacro %}' + ).json() + self.assertSuccessfulRunResult( + macro_override, + raw_sql='select {{ happy_little_macro() }}', + compiled_sql='select 2 as id', + table={'column_names': ['id'], 'rows': [[2.0]]} + ) + @use_profile('postgres') def test_invalid_requests(self): data = self.query( @@ -526,6 +581,17 @@ def test_invalid_requests(self): self.assertIn('logs', error_data) self.assertTrue(len(error_data['logs']) > 0) + macro_no_override = self.query( + 'run', + 'select {{ happy_little_macro() }}', + name='foo', + ).json() + self.assertIsErrorWithCode(macro_no_override, 10003) + self.assertEqual(error['message'], 'Database Error') + self.assertIn('data', error) + error_data = error['data'] + self.assertEqual(error_data['type'], 'DatabaseException') + @use_profile('postgres') def test_timeout(self): data = self.query(