diff --git a/dbt/compat.py b/dbt/compat.py index 7f89e6d70c9..bbd76f5cb76 100644 --- a/dbt/compat.py +++ b/dbt/compat.py @@ -19,8 +19,10 @@ if WHICH_PYTHON == 2: from SimpleHTTPServer import SimpleHTTPRequestHandler + from SocketServer import TCPServer else: from http.server import SimpleHTTPRequestHandler + from socketserver import TCPServer def to_unicode(s): diff --git a/dbt/config.py b/dbt/config.py index 46b7e8d3866..843467c2fc0 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -189,6 +189,25 @@ def render_profile_data(self, as_parsed): ) +def _list_if_none(value): + if value is None: + value = [] + return value + + +def _dict_if_none(value): + if value is None: + value = {} + return value + + +def _list_if_none_or_string(value): + value = _list_if_none(value) + if isinstance(value, compat.basestring): + return [value] + return value + + class Project(object): def __init__(self, project_name, version, project_root, profile_name, source_paths, macro_paths, data_paths, test_paths, @@ -220,24 +239,25 @@ def __init__(self, project_name, version, project_root, profile_name, @staticmethod def _preprocess(project_dict): """Pre-process certain special keys to convert them from None values - into empty containers. + into empty containers, and to turn strings into arrays of strings. """ handlers = { - ('archive',): list, - ('on-run-start',): list, - ('on-run-end',): list, + ('archive',): _list_if_none, + ('on-run-start',): _list_if_none_or_string, + ('on-run-end',): _list_if_none_or_string, } + for k in ('models', 'seeds'): - handlers[(k,)] = dict - handlers[(k, 'vars')] = dict - handlers[(k, 'pre-hook')] = list - handlers[(k, 'post-hook')] = list - handlers[('seeds', 'column_types')] = dict + handlers[(k,)] = _dict_if_none + handlers[(k, 'vars')] = _dict_if_none + handlers[(k, 'pre-hook')] = _list_if_none_or_string + handlers[(k, 'post-hook')] = _list_if_none_or_string + handlers[('seeds', 'column_types')] = _dict_if_none def converter(value, keypath): - if value is None and keypath in handlers: + if keypath in handlers: handler = handlers[keypath] - return handler() + return handler(value) else: return value diff --git a/dbt/task/serve.py b/dbt/task/serve.py index 448ca2fcb93..3a94ffbafed 100644 --- a/dbt/task/serve.py +++ b/dbt/task/serve.py @@ -1,10 +1,9 @@ import shutil import os import webbrowser -from socketserver import TCPServer from dbt.include import DOCS_INDEX_FILE_PATH -from dbt.compat import SimpleHTTPRequestHandler +from dbt.compat import SimpleHTTPRequestHandler, TCPServer from dbt.logger import GLOBAL_LOGGER as logger from dbt.task.base_task import RunnableTask diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 649fbb2380c..bf36698de84 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -767,6 +767,23 @@ def test_all_overrides(self): str(project) json.dumps(project.to_project_config()) + def test_string_run_hooks(self): + self.default_project_data.update({ + 'on-run-start': '{{ logging.log_run_start_event() }}', + 'on-run-end': '{{ logging.log_run_end_event() }}', + }) + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + self.assertEqual( + project.on_run_start, + ['{{ logging.log_run_start_event() }}'] + ) + self.assertEqual( + project.on_run_end, + ['{{ logging.log_run_end_event() }}'] + ) + def test_invalid_project_name(self): self.default_project_data['name'] = 'invalid-project-name' with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: