Skip to content

Commit

Permalink
Merge pull request #1348 from fishtown-analytics/feature/rpc-with-macros
Browse files Browse the repository at this point in the history
RPC: macros
  • Loading branch information
beckjake authored Mar 12, 2019
2 parents fc22cb2 + 81426ae commit c1c09f3
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 30 deletions.
5 changes: 4 additions & 1 deletion core/dbt/parser/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,15 @@ def process_sources(cls, manifest, current_project):
return manifest

@classmethod
def add_new_refs(cls, manifest, current_project, node):
def add_new_refs(cls, manifest, current_project, node, macros):
"""Given a new node that is not in the manifest, copy the manifest and
insert the new node into it as if it were part of regular ref
processing
"""
manifest = manifest.deepcopy(config=current_project)
# it's ok for macros to silently override a local project macro name
manifest.macros.update(macros)

if node.unique_id in manifest.nodes:
# this should be _impossible_ due to the fact that rpc calls get
# a unique ID that starts with 'rpc'!
Expand Down
6 changes: 3 additions & 3 deletions core/dbt/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def from_error(cls, err):
return cls(err.code, err.message, err.data, err.data.get('logs'))


def invalid_params(err, logs):
def invalid_params(data):
return RPCException(
code=JSONRPCInvalidParams.code,
code=JSONRPCInvalidParams.CODE,
message=JSONRPCInvalidParams.MESSAGE,
data={'logs': logs}
data=data
)


Expand Down
30 changes: 23 additions & 7 deletions core/dbt/task/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dbt.node_runners import CompileRunner, RPCCompileRunner
from dbt.node_types import NodeType
from dbt.parser.analysis import RPCCallParser
from dbt.parser.macros import MacroParser
from dbt.parser.util import ParserUtils
import dbt.ui.printer

Expand Down Expand Up @@ -37,7 +38,6 @@ class RemoteCompileTask(CompileTask, RemoteCallable):

def __init__(self, args, config):
super(RemoteCompileTask, self).__init__(args, config)
self.parser = None
self._base_manifest = GraphLoader.load_all(
config,
internal_manifest=get_adapter(config).check_internal_manifest()
Expand All @@ -56,15 +56,28 @@ def runtime_cleanup(self, selected_uids):
self._skipped_children = {}
self._raise_next_tick = None

def handle_request(self, name, sql):
self.parser = RPCCallParser(
def handle_request(self, name, sql, macros=None):
request_path = os.path.join(self.config.target_path, 'rpc', name)
all_projects = load_all_projects(self.config)
macro_overrides = {}
if macros is not None:
macros = self.decode_sql(macros)
macro_parser = MacroParser(self.config, all_projects)
macro_overrides.update(macro_parser.parse_macro_file(
macro_file_path='from remote system',
macro_file_contents=macros,
root_path=request_path,
package_name=self.config.project_name,
resource_type=NodeType.Macro
))

rpc_parser = RPCCallParser(
self.config,
all_projects=load_all_projects(self.config),
all_projects=all_projects,
macro_manifest=self._base_manifest
)

sql = self.decode_sql(sql)
request_path = os.path.join(self.config.target_path, 'rpc', name)
node_dict = {
'name': name,
'root_path': request_path,
Expand All @@ -74,13 +87,16 @@ def handle_request(self, name, sql):
'package_name': self.config.project_name,
'raw_sql': sql,
}
unique_id, node = self.parser.parse_sql_node(node_dict)

unique_id, node = rpc_parser.parse_sql_node(node_dict)

self.manifest = ParserUtils.add_new_refs(
manifest=self._base_manifest,
current_project=self.config,
node=node
node=node,
macros=macro_overrides
)

# don't write our new, weird manifest!
self.linker = compile_manifest(self.config, self.manifest, write=False)
selected_uids = [node.unique_id]
Expand Down
2 changes: 0 additions & 2 deletions core/dbt/task/runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,6 @@ def decode_sql(self, sql):
@staticmethod
def raise_invalid_base64(sql):
raise rpc.invalid_params(
code=JSONRPCInvalidParams.CODE,
message=JSONRPCInvalidParams.MESSAGE,
data={
'message': 'invalid base64-encoded sql input',
'sql': str(sql),
Expand Down
7 changes: 7 additions & 0 deletions test/integration/042_sources_test/macros/macro.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{% macro override_me() -%}
exceptions.raise_compiler_error('this is a bad macro')
{%- endmacro %}

{% macro happy_little_macro() -%}
{{ override_me() }}
{%- endmacro %}
100 changes: 83 additions & 17 deletions test/integration/042_sources_test/test_sources.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import unittest
from nose.plugins.attrib import attr
from datetime import datetime, timedelta
import json
import os

import multiprocessing
from base64 import standard_b64encode as b64
import requests
import socket
import time


from dbt.exceptions import CompilationException
from test.integration.base import DBTIntegrationTest, use_profile, AnyFloat, \
AnyStringWith
from dbt.main import handle_and_check


class BaseSourcesTest(DBTIntegrationTest):
@property
def schema(self):
Expand Down Expand Up @@ -260,16 +267,6 @@ def test_postgres_malformed_schema_strict_will_break_run(self):
self.run_dbt_with_vars(['run'], strict=True)


import multiprocessing
from base64 import standard_b64encode as b64
import json
import requests
import socket
import time
import os



class ServerProcess(multiprocessing.Process):
def __init__(self, cli_vars=None):
self.port = 22991
Expand Down Expand Up @@ -303,7 +300,7 @@ def start(self):
raise Exception('server never appeared!')


@unittest.skipIf(os.name=='nt', 'Windows not supported for now')
@unittest.skipIf(os.name == 'nt', 'Windows not supported for now')
class TestRPCServer(BaseSourcesTest):
def setUp(self):
super(TestRPCServer, self).setUp()
Expand All @@ -316,10 +313,20 @@ def tearDown(self):
self._server.terminate()
super(TestRPCServer, self).tearDown()

def build_query(self, method, kwargs, sql=None, test_request_id=1):
@property
def project_config(self):
return {
'data-paths': ['test/integration/042_sources_test/data'],
'quoting': {'database': True, 'schema': True, 'identifier': True},
'macro-paths': ['test/integration/042_sources_test/macros'],
}

def build_query(self, method, kwargs, sql=None, test_request_id=1, macros=None):
if sql is not None:
kwargs['sql'] = b64(sql.encode('utf-8')).decode('utf-8')

if macros is not None:
kwargs['macros'] = b64(macros.encode('utf-8')).decode('utf-8')
return {
'jsonrpc': '2.0',
'method': method,
Expand All @@ -333,8 +340,8 @@ def perform_query(self, query):
response = requests.post(url, headers=headers, data=json.dumps(query))
return response

def query(self, _method, _sql=None, _test_request_id=1, **kwargs):
built = self.build_query(_method, kwargs, _sql, _test_request_id)
def query(self, _method, _sql=None, _test_request_id=1, macros=None, **kwargs):
built = self.build_query(_method, kwargs, _sql, _test_request_id, macros)
return self.perform_query(built)

def assertResultHasTimings(self, result, *names):
Expand Down Expand Up @@ -425,7 +432,6 @@ def test_compile(self):
'select * from {{ source("test_source", "test_table") }}',
name='foo'
).json()

self.assertSuccessfulCompilationResult(
source,
'select * from {{ source("test_source", "test_table") }}',
Expand All @@ -434,6 +440,30 @@ def test_compile(self):
self.unique_schema())
)

macro = self.query(
'compile',
'select {{ my_macro() }}',
name='foo',
macros='{% macro my_macro() %}1 as id{% endmacro %}'
).json()
self.assertSuccessfulCompilationResult(
macro,
'select {{ my_macro() }}',
compiled_sql='select 1 as id'
)

macro_override = self.query(
'compile',
'select {{ happy_little_macro() }}',
name='foo',
macros='{% macro override_me() %}2 as id{% endmacro %}'
).json()
self.assertSuccessfulCompilationResult(
macro_override,
'select {{ happy_little_macro() }}',
compiled_sql='select 2 as id'
)

@use_profile('postgres')
def test_run(self):
# seed + run dbt to make models before using them!
Expand Down Expand Up @@ -470,7 +500,6 @@ def test_run(self):
'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1',
name='foo'
).json()

self.assertSuccessfulRunResult(
source,
'select * from {{ source("test_source", "test_table") }} order by updated_at limit 1',
Expand All @@ -483,6 +512,32 @@ def test_run(self):
}
)

macro = self.query(
'run',
'select {{ my_macro() }}',
name='foo',
macros='{% macro my_macro() %}1 as id{% endmacro %}'
).json()
self.assertSuccessfulRunResult(
macro,
raw_sql='select {{ my_macro() }}',
compiled_sql='select 1 as id',
table={'column_names': ['id'], 'rows': [[1.0]]}
)

macro_override = self.query(
'run',
'select {{ happy_little_macro() }}',
name='foo',
macros='{% macro override_me() %}2 as id{% endmacro %}'
).json()
self.assertSuccessfulRunResult(
macro_override,
raw_sql='select {{ happy_little_macro() }}',
compiled_sql='select 2 as id',
table={'column_names': ['id'], 'rows': [[2.0]]}
)

@use_profile('postgres')
def test_invalid_requests(self):
data = self.query(
Expand Down Expand Up @@ -526,6 +581,17 @@ def test_invalid_requests(self):
self.assertIn('logs', error_data)
self.assertTrue(len(error_data['logs']) > 0)

macro_no_override = self.query(
'run',
'select {{ happy_little_macro() }}',
name='foo',
).json()
self.assertIsErrorWithCode(macro_no_override, 10003)
self.assertEqual(error['message'], 'Database Error')
self.assertIn('data', error)
error_data = error['data']
self.assertEqual(error_data['type'], 'DatabaseException')

@use_profile('postgres')
def test_timeout(self):
data = self.query(
Expand Down

0 comments on commit c1c09f3

Please sign in to comment.