From 78e235350b4f1ec8d3d9b14a98f1bb64bff7b9d3 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sun, 19 Feb 2017 21:54:20 -0500 Subject: [PATCH 01/25] first pass at wiring in parser --- dbt/compilation.py | 22 +- dbt/model.py | 2 +- dbt/parser.py | 147 +++++++++++ dbt/utils.py | 2 + test/unit/test_parser.py | 555 +++++++++++++++++++++++++++++++++++++++ 5 files changed, 725 insertions(+), 3 deletions(-) create mode 100644 dbt/parser.py create mode 100644 test/unit/test_parser.py diff --git a/dbt/compilation.py b/dbt/compilation.py index 1069d164fac..0ad22810574 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -41,6 +41,7 @@ class Compiler(object): def __init__(self, project, args): self.project = project self.args = args + self.parsed_models = None self.macro_generator = None @@ -538,9 +539,26 @@ def get_models(self): return all_models + def get_parsed_models(self): + root_project = self.project + all_projects = [root_project] + all_projects.extend(dbt.utils.dependency_projects(self.project)) + + all_models = [] + for project in all_projects: + all_models.extend( + dbt.parser.load_and_parse_models( + package_name=project.get('name'), + root_dir=root_project.get('project-root'), + relative_dirs=project.get('source_paths', []))) + + return all_models + def compile(self): linker = Linker() + parsed_models = self.get_parsed_models() + all_models = self.get_models() all_macros = self.get_macros(this_project=self.project) @@ -552,8 +570,8 @@ def compile(self): self.macro_generator = self.generate_macros(all_macros) enabled_models = [ - model for model in all_models - if model.is_enabled and not model.is_empty + model for model in parsed_models + if model.get('enabled') == True and model.get('empty') == True ] compiled_models, written_models = self.compile_models( diff --git a/dbt/model.py b/dbt/model.py index 2feccfa9260..2ed20e244cc 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -169,7 +169,7 @@ def get_project_config(self, project): for k in SourceConfig.ExtendDictFields: config[k] = {} - model_configs = project['models'] + model_configs = project.get('models') if model_configs is None: return config diff --git a/dbt/parser.py b/dbt/parser.py new file mode 100644 index 00000000000..d69337659a9 --- /dev/null +++ b/dbt/parser.py @@ -0,0 +1,147 @@ +import copy +import jinja2 +import jinja2.sandbox +import os + +import dbt.model +import dbt.utils + + +class SilentUndefined(jinja2.Undefined): + """ + Don't fail to parse because of undefined things. This allows us to parse + models before macros, since we aren't guaranteed to know about macros + before models. 
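+
+    Illustrative only (this example is not part of the patch): a template
+    that calls a macro the parser has not seen yet still renders:
+
+        env = jinja2.sandbox.SandboxedEnvironment(
+            undefined=SilentUndefined)
+        # renders without raising jinja2.UndefinedError
+        env.from_string("select {{ not_yet_defined() }} from t").render()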
+ """ + def _fail_with_undefined_error(self, *args, **kwargs): + return None + + __add__ = __radd__ = __mul__ = __rmul__ = __div__ = __rdiv__ = \ + __truediv__ = __rtruediv__ = __floordiv__ = __rfloordiv__ = \ + __mod__ = __rmod__ = __pos__ = __neg__ = __call__ = \ + __getitem__ = __lt__ = __le__ = __gt__ = __ge__ = __int__ = \ + __float__ = __complex__ = __pow__ = __rpow__ = \ + _fail_with_undefined_error + + +def get_path(resource_type, package_name, resource_name): + return "{}.{}.{}".format(resource_type, package_name, resource_name) + +def get_model_path(package_name, resource_name): + return get_path('models', package_name, resource_name) + +def get_macro_path(package_name, resource_name): + return get_path('macros', package_name, resource_name) + +def __ref(model): + + def ref(*args): + model_path = None + + if len(args) == 1: + model_path = get_model_path(model.get('package_name'), args[0]) + elif len(args) == 2: + model_path = get_model_path(args[0], args[1]) + else: + dbt.utils.compiler_error( + model.get('name'), + "ref() takes at most two arguments ({} given)".format( + len(args))) + + model['depends_on'].append(model_path) + + return ref + + +def __config(model, cfg): + + def config(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0: + opts = args[0] + elif len(args) == 0 and len(kwargs) > 0: + opts = kwargs + else: + dbt.utils.compiler_error( + model.get('name'), + "Invalid model config given inline in {}".format(model)) + + cfg.update_in_model_config(opts) + + return config + + +def parse_model(model, root_project_config, package_project_config): + parsed_model = copy.deepcopy(model) + + parsed_model.update({ + 'depends_on': [], + }) + + parts = dbt.utils.split_path(model.get('path', '')) + name, _ = os.path.splitext(parts[-1]) + fqn = ([package_project_config.get('name')] + + parts[1:-1] + + [model.get('name')]) + + config = dbt.model.SourceConfig( + root_project_config, package_project_config, fqn) + + context = { + 'ref': __ref(parsed_model), + 'config': __config(parsed_model, config), + } + + env = jinja2.sandbox.SandboxedEnvironment( + undefined=SilentUndefined) + + env.from_string(model.get('raw_sql')).render(context) + + parsed_model['config'] = config.config + parsed_model['empty'] = (len(model.get('raw_sql').strip()) == 0) + + return parsed_model + + +def parse_models(models, projects): + to_return = {} + + for model in models: + package_name = model.get('package_name', 'root') + + model_path = get_model_path(package_name, model.get('name')) + + to_return[model_path] = parse_model(model, + projects.get('root'), + projects.get(package_name)) + + return to_return + + +def load_and_parse_files(package_name, root_dir, relative_dirs, extension, + resource_type): + file_matches = dbt.clients.system.find_matching( + root_dir, + relative_dirs, + extension) + + models = [] + + for file_match in file_matches: + file_contents = dbt.clients.system.load_file_contents( + file_match.get('absolute_path')) + + # TODO: support more than just models + models.append({ + 'name': os.path.basename(file_match.get('absolute_path')), + 'path': file_match.get('relative_path'), + 'package_name': package_name, + 'raw_sql': file_contents + }) + + return parse_models(models) + + +def load_and_parse_models(package_name, root_dir, relative_dirs): + return load_and_parse_files(package_name, root_dir, relative_dirs, + extension="[!.#~]*.sql", + resource_type='models') diff --git a/dbt/utils.py b/dbt/utils.py index 3be15b85e0a..dee66157d8d 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -34,6 
+34,8 @@ def __repr__(self): def compiler_error(model, msg): if model is None: name = '' + elif model is str: + name = model else: name = model.nice_name diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py new file mode 100644 index 00000000000..8c40d48bc6d --- /dev/null +++ b/test/unit/test_parser.py @@ -0,0 +1,555 @@ +from mock import MagicMock +import unittest + +import os + +import dbt.parser + + +class ParserTest(unittest.TestCase): + + def find_input_by_name(self, models, name): + return next( + (model for model in models if model.get('name') == name), + {}) + + def setUp(self): + self.maxDiff = None + + self.root_project_config = { + 'name': 'root_project', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + } + + self.snowplow_project_config = { + 'name': 'snowplow', + 'version': '0.1', + 'project-root': os.path.abspath('./dbt_modules/snowplow'), + } + + self.model_config = { + 'enabled': True, + 'materialized': 'view', + 'post-hook': [], + 'pre-hook': [], + 'vars': {}, + } + + def test__single_model(self): + models = [{ + 'name': 'model_one', + 'package_name': 'root', + 'raw_sql': ("select * from events"), + }] + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'models.root.model_one': { + 'name': 'model_one', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + + def test__empty_model(self): + models = [{ + 'name': 'model_one', + 'package_name': 'root', + 'raw_sql': (" "), + }] + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config}), + { + 'models.root.model_one': { + 'name': 'model_one', + 'empty': True, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + + def test__simple_dependency(self): + models = [{ + 'name': 'base', + 'package_name': 'root', + 'raw_sql': 'select * from events' + }, { + 'name': 'events_tx', + 'package_name': 'root', + 'raw_sql': "select * from {{ref('base')}}" + }] + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'models.root.base': { + 'name': 'base', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'base').get('raw_sql') + }, + 'models.root.events_tx': { + 'name': 'events_tx', + 'empty': False, + 'package_name': 'root', + 'depends_on': ['models.root.base'], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'events_tx').get('raw_sql') + } + } + ) + + def test__multiple_dependencies(self): + models = [{ + 'name': 'events', + 'package_name': 'root', + 'raw_sql': 'select * from base.events', + }, { + 'name': 'sessions', + 'package_name': 'root', + 'raw_sql': 'select * from base.sessions', + }, { + 'name': 'events_tx', + 'package_name': 'root', + 'raw_sql': ("with events as (select * from {{ref('events')}}) " + "select * from events"), + }, { + 'name': 'sessions_tx', + 'package_name': 'root', + 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " + "select * from sessions"), + }, { + 'name': 'multi', + 'package_name': 'root', + 'raw_sql': ("with s as (select * from {{ref('sessions_tx')}}), " + "e as 
(select * from {{ref('events_tx')}}) " + "select * from e left join s on s.id = e.sid"), + }] + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'models.root.events': { + 'name': 'events', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'events').get('raw_sql') + }, + 'models.root.sessions': { + 'name': 'sessions', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'sessions').get('raw_sql') + }, + 'models.root.events_tx': { + 'name': 'events_tx', + 'empty': False, + 'package_name': 'root', + 'depends_on': ['models.root.events'], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'events_tx').get('raw_sql') + }, + 'models.root.sessions_tx': { + 'name': 'sessions_tx', + 'empty': False, + 'package_name': 'root', + 'depends_on': ['models.root.sessions'], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'sessions_tx').get('raw_sql') + }, + 'models.root.multi': { + 'name': 'multi', + 'empty': False, + 'package_name': 'root', + 'depends_on': ['models.root.sessions_tx', + 'models.root.events_tx'], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'multi').get('raw_sql') + } + } + ) + + def test__multiple_dependencies__packages(self): + models = [{ + 'name': 'events', + 'package_name': 'snowplow', + 'raw_sql': 'select * from base.events', + }, { + 'name': 'sessions', + 'package_name': 'snowplow', + 'raw_sql': 'select * from base.sessions', + }, { + 'name': 'events_tx', + 'package_name': 'snowplow', + 'raw_sql': ("with events as (select * from {{ref('events')}}) " + "select * from events"), + }, { + 'name': 'sessions_tx', + 'package_name': 'snowplow', + 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " + "select * from sessions"), + }, { + 'name': 'multi', + 'package_name': 'root', + 'raw_sql': ("with s as (select * from {{ref('snowplow', 'sessions_tx')}}), " + "e as (select * from {{ref('snowplow', 'events_tx')}}) " + "select * from e left join s on s.id = e.sid"), + }] + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'models.snowplow.events': { + 'name': 'events', + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'events').get('raw_sql') + }, + 'models.snowplow.sessions': { + 'name': 'sessions', + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'sessions').get('raw_sql') + }, + 'models.snowplow.events_tx': { + 'name': 'events_tx', + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': ['models.snowplow.events'], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'events_tx').get('raw_sql') + }, + 'models.snowplow.sessions_tx': { + 'name': 'sessions_tx', + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': ['models.snowplow.sessions'], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'sessions_tx').get('raw_sql') + }, + 'models.root.multi': { + 'name': 'multi', + 'empty': False, + 'package_name': 'root', + 'depends_on': 
['models.snowplow.sessions_tx', + 'models.snowplow.events_tx'], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'multi').get('raw_sql') + } + } + ) + + def test__in_model_config(self): + models = [{ + 'name': 'model_one', + 'package_name': 'root', + 'raw_sql': ("{{config({'materialized':'table'})}}" + "select * from events"), + }] + + self.model_config.update({ + 'materialized': 'table' + }) + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'models.root.model_one': { + 'name': 'model_one', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + + def test__root_project_config(self): + self.root_project_config = { + 'name': 'root_project', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + 'models': { + 'materialized': 'ephemeral', + 'root_project': { + 'view': { + 'materialized': 'view' + } + } + } + } + + models = [{ + 'name': 'table', + 'package_name': 'root', + 'path': 'table.sql', + 'raw_sql': ("{{config({'materialized':'table'})}}" + "select * from events"), + }, { + 'name': 'ephemeral', + 'package_name': 'root', + 'path': 'ephemeral.sql', + 'raw_sql': ("select * from events"), + }, { + 'name': 'view', + 'package_name': 'root', + 'path': 'view.sql', + 'raw_sql': ("select * from events"), + }] + + self.model_config.update({ + 'materialized': 'table' + }) + + ephemeral_config = self.model_config.copy() + ephemeral_config.update({ + 'materialized': 'ephemeral' + }) + + view_config = self.model_config.copy() + view_config.update({ + 'materialized': 'view' + }) + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'models.root.table': { + 'name': 'table', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'table.sql', + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'table').get('raw_sql') + }, + 'models.root.ephemeral': { + 'name': 'ephemeral', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'ephemeral.sql', + 'config': ephemeral_config, + 'raw_sql': self.find_input_by_name( + models, 'ephemeral').get('raw_sql') + }, + 'models.root.view': { + 'name': 'view', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'view.sql', + 'config': view_config, + 'raw_sql': self.find_input_by_name( + models, 'ephemeral').get('raw_sql') + } + } + + ) + + def test__other_project_config(self): + self.root_project_config = { + 'name': 'root_project', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + 'models': { + 'materialized': 'ephemeral', + 'root_project': { + 'view': { + 'materialized': 'view' + } + }, + 'snowplow': { + 'enabled': False, + 'views': { + 'materialized': 'view', + } + } + } + } + + self.snowplow_project_config = { + 'name': 'snowplow', + 'version': '0.1', + 'project-root': os.path.abspath('./dbt_modules/snowplow'), + 'models': { + 'enabled': False, + 'views': { + 'materialized': 'table', + 'sort': 'timestamp' + } + } + } + + models = [{ + 'name': 'table', + 'package_name': 'root', + 'path': 'table.sql', + 'raw_sql': ("{{config({'materialized':'table'})}}" + "select * from events"), + }, { + 'name': 'ephemeral', + 'package_name': 'root', + 'path': 'ephemeral.sql', + 'raw_sql': 
("select * from events"), + }, { + 'name': 'view', + 'package_name': 'root', + 'path': 'view.sql', + 'raw_sql': ("select * from events"), + }, { + 'name': 'disabled', + 'package_name': 'snowplow', + 'path': 'disabled.sql', + 'raw_sql': ("select * from events"), + }, { + 'name': 'package', + 'package_name': 'snowplow', + 'path': 'models/views/package.sql', + 'raw_sql': ("select * from events"), + }] + + self.model_config.update({ + 'materialized': 'table' + }) + + ephemeral_config = self.model_config.copy() + ephemeral_config.update({ + 'materialized': 'ephemeral' + }) + + view_config = self.model_config.copy() + view_config.update({ + 'materialized': 'view' + }) + + disabled_config = self.model_config.copy() + disabled_config.update({ + 'enabled': False, + 'materialized': 'ephemeral' + }) + + sort_config = self.model_config.copy() + sort_config.update({ + 'enabled': False, + 'materialized': 'view' + }) + + self.assertEquals( + dbt.parser.parse_models( + models, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'models.root.table': { + 'name': 'table', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'table.sql', + 'config': self.model_config, + 'raw_sql': self.find_input_by_name( + models, 'table').get('raw_sql') + }, + 'models.root.ephemeral': { + 'name': 'ephemeral', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'ephemeral.sql', + 'config': ephemeral_config, + 'raw_sql': self.find_input_by_name( + models, 'ephemeral').get('raw_sql') + }, + 'models.root.view': { + 'name': 'view', + 'empty': False, + 'package_name': 'root', + 'depends_on': [], + 'path': 'view.sql', + 'config': view_config, + 'raw_sql': self.find_input_by_name( + models, 'view').get('raw_sql') + }, + 'models.snowplow.disabled': { + 'name': 'disabled', + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'path': 'disabled.sql', + 'config': disabled_config, + 'raw_sql': self.find_input_by_name( + models, 'disabled').get('raw_sql') + }, + 'models.snowplow.package': { + 'name': 'package', + 'empty': False, + 'package_name': 'snowplow', + 'depends_on': [], + 'path': 'models/views/package.sql', + 'config': sort_config, + 'raw_sql': self.find_input_by_name( + models, 'package').get('raw_sql') + } + } + ) From 39e2c1c5e83c107067ee2baa38f3fd0f538e3f8a Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 20 Feb 2017 09:40:50 -0500 Subject: [PATCH 02/25] add fqn to model representation, rewiring some of the compiler --- dbt/compilation.py | 91 +++++++++++++++++++--------------------- dbt/parser.py | 8 +++- dbt/utils.py | 17 ++++++++ test/unit/test_parser.py | 23 ++++++++++ 4 files changed, 89 insertions(+), 50 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index 0ad22810574..0ac43eef9f3 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -149,21 +149,17 @@ def model_can_reference(self, src_model, other_model): def __ref(self, linker, ctx, model, all_models): schema = ctx['env']['schema'] - source_model = tuple(model.fqn) - linker.add_node(source_model) + model_id = model.get('unique_id') + linker.add_node(model_id) def do_ref(*args): + target_model_name = None + target_model_package = None + if len(args) == 1: - other_model_name = args[0] - other_model = find_model_by_name(all_models, other_model_name) + target_model_name = args[0] elif len(args) == 2: - other_model_package, other_model_name = args - - other_model = find_model_by_name( - all_models, - other_model_name, - 
package_namespace=other_model_package - ) + target_model_package, target_model_name = args else: compiler_error( model, @@ -172,43 +168,38 @@ def do_ref(*args): ) ) - other_model_fqn = tuple(other_model.fqn[:-1] + [other_model_name]) - src_fqn = ".".join(source_model) - ref_fqn = ".".join(other_model_fqn) + target_model = find_model_by_unique_id(all_models, + target_model_name, + target_model_package) + target_model_id = target_model.get('unique_id') - if not other_model.is_enabled: - raise RuntimeError( + if target_model.get('enabled') == False: + compiler_error( "Model '{}' depends on model '{}' which is disabled in " - "the project config".format(src_fqn, ref_fqn) - ) + "the project config".format(model.get('unique_id'), + target_model.get('unique_id'))) # this creates a trivial cycle -- should this be a compiler error? # we can still interpolate the name w/o making a self-cycle - if source_model == other_model_fqn: + if model_id == target_model_id: pass else: - linker.dependency(source_model, other_model_fqn) + linker.dependency(model_id, target_model_id) - if other_model.is_ephemeral: - linker.inject_cte(model, other_model) - return other_model.cte_name + if other_model.get('ephemeral') == True: + model['extra_ctes'].append(target_model_id) + return '__dbt__CTE__{}'.format(target_model.get('name')) else: - return '"{}"."{}"'.format(schema, other_model_name) + return '"{}"."{}"'.format(schema, target_model.get('name')) def wrapped_do_ref(*args): try: return do_ref(*args) except RuntimeError as e: - root = os.path.relpath( - model.root_dir, - model.project['project-root'] - ) - - filepath = os.path.join(root, model.rel_filepath) - logger.info("Compiler error in {}".format(filepath)) + logger.info("Compiler error in {}".format(model.get('path'))) logger.info("Enabled models:") for m in all_models: - logger.info(" - {}".format(".".join(m.fqn))) + logger.info(" - {}".format(".".join(m.get('fqn')))) raise e return wrapped_do_ref @@ -251,22 +242,26 @@ def get_context(self, linker, model, models): def compile_model(self, linker, model, models): try: - fs_loader = jinja2.FileSystemLoader(searchpath=model.root_dir) - jinja = jinja2.Environment(loader=fs_loader) - - template_contents = dbt.clients.system.load_file_contents( - model.absolute_path) + compiled_model = model.copy() + compiled_model.update({ + 'compiled': False, + 'extra_ctes': [], + 'compiled_sql': None + }) - template = jinja.from_string(template_contents) context = self.get_context(linker, model, models) - rendered = template.render(context) + env = jinja2.sandbox.SandboxedEnvironment() + + compiled_model['compiled_sql'] = env.from_string( + model.get('raw_sql')).render(context) + compiled_model['compiled'] = True except jinja2.exceptions.TemplateSyntaxError as e: compiler_error(model, str(e)) except jinja2.exceptions.UndefinedError as e: compiler_error(model, str(e)) - return rendered + return compiled_model def write_graph_file(self, linker): filename = graph_file_name @@ -380,16 +375,16 @@ def remove_node_from_graph(self, linker, model, models): ) def compile_models(self, linker, models): - compiled_models = {model: self.compile_model(linker, model, models) + compiled_models = {self.compile_model(linker, model, models) for model in models} - sorted_models = [find_model_by_fqn(models, fqn) - for fqn in linker.as_topological_ordering()] + sorted_models = [models[path] + for path in linker.as_topological_ordering()] written_models = [] for model in sorted_models: # in-model configs were just evaluated. 
Evict anything that is # newly-disabled - if not model.is_enabled: + if model.get('enabled') == False: self.remove_node_from_graph(linker, model, models) continue @@ -400,13 +395,11 @@ def compile_models(self, linker, models): context = self.get_context(linker, model, models) wrapped_stmt = model.compile(injected_stmt, self.project, context) - serialized = model.serialize() - linker.update_node_data(tuple(model.fqn), serialized) - - if model.is_ephemeral: + if model.get('ephemeral') == True: continue - self.__write(model.build_path(), wrapped_stmt) + build_path = os.path.join('build', model.get('path')) + self.__write(build_path, wrapped_stmt) written_models.append(model) return compiled_models, written_models diff --git a/dbt/parser.py b/dbt/parser.py index d69337659a9..5482755e0c8 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -70,7 +70,8 @@ def config(*args, **kwargs): return config -def parse_model(model, root_project_config, package_project_config): +def parse_model(model, model_path, root_project_config, + package_project_config): parsed_model = copy.deepcopy(model) parsed_model.update({ @@ -96,8 +97,10 @@ def parse_model(model, root_project_config, package_project_config): env.from_string(model.get('raw_sql')).render(context) + parsed_model['unique_id'] = model_path parsed_model['config'] = config.config parsed_model['empty'] = (len(model.get('raw_sql').strip()) == 0) + parsed_model['fqn'] = fqn return parsed_model @@ -110,7 +113,9 @@ def parse_models(models, projects): model_path = get_model_path(package_name, model.get('name')) + # TODO if this is set, raise a compiler error to_return[model_path] = parse_model(model, + model_path, projects.get('root'), projects.get(package_name)) @@ -133,6 +138,7 @@ def load_and_parse_files(package_name, root_dir, relative_dirs, extension, # TODO: support more than just models models.append({ 'name': os.path.basename(file_match.get('absolute_path')), + 'root_path': root_dir, 'path': file_match.get('relative_path'), 'package_name': package_name, 'raw_sql': file_contents diff --git a/dbt/utils.py b/dbt/utils.py index dee66157d8d..c4df2c42b07 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -90,6 +90,23 @@ def __call__(self, var_name, default=None): return default +def model_cte_name(model): + return '__dbt__CTE__{}'.format(model.get('name')) + + +def find_model_by_unique_id(all_models, target_model_name, + target_model_package): + + for name, model in all_models.items(): + resource_type, package_name, model_name = name.split('.') + + if ((target_model_name == model_name) and \ + (target_model_package is None or + target_model_package == package_name)): + return model + + return None + def find_model_by_name(models, name, package_namespace=None): found = [] for model in models: diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index 8c40d48bc6d..6c5f795aad4 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -52,6 +52,7 @@ def test__single_model(self): { 'models.root.model_one': { 'name': 'model_one', + 'fqn': ['root_project', 'model_one'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -76,6 +77,7 @@ def test__empty_model(self): { 'models.root.model_one': { 'name': 'model_one', + 'fqn': ['root_project', 'model_one'], 'empty': True, 'package_name': 'root', 'depends_on': [], @@ -105,6 +107,7 @@ def test__simple_dependency(self): { 'models.root.base': { 'name': 'base', + 'fqn': ['root_project', 'base'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -114,6 +117,7 @@ def 
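The fqn values asserted in the tests below follow from parse_model
above: the package's configured project name, then intermediate path
segments, then the model name. A worked example (derived from the
PATCH 01/02 logic, not test data):

    # package config name 'snowplow', path 'models/views/package.sql':
    #   parts       == ['models', 'views', 'package.sql']
    #   parts[1:-1] == ['views']
    #   fqn         == ['snowplow', 'views', 'package']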
test__simple_dependency(self): }, 'models.root.events_tx': { 'name': 'events_tx', + 'fqn': ['root_project', 'events_tx'], 'empty': False, 'package_name': 'root', 'depends_on': ['models.root.base'], @@ -159,6 +163,7 @@ def test__multiple_dependencies(self): { 'models.root.events': { 'name': 'events', + 'fqn': ['root_project', 'events'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -168,6 +173,7 @@ def test__multiple_dependencies(self): }, 'models.root.sessions': { 'name': 'sessions', + 'fqn': ['root_project', 'sessions'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -177,6 +183,7 @@ def test__multiple_dependencies(self): }, 'models.root.events_tx': { 'name': 'events_tx', + 'fqn': ['root_project', 'events_tx'], 'empty': False, 'package_name': 'root', 'depends_on': ['models.root.events'], @@ -186,6 +193,7 @@ def test__multiple_dependencies(self): }, 'models.root.sessions_tx': { 'name': 'sessions_tx', + 'fqn': ['root_project', 'sessions_tx'], 'empty': False, 'package_name': 'root', 'depends_on': ['models.root.sessions'], @@ -195,6 +203,7 @@ def test__multiple_dependencies(self): }, 'models.root.multi': { 'name': 'multi', + 'fqn': ['root_project', 'multi'], 'empty': False, 'package_name': 'root', 'depends_on': ['models.root.sessions_tx', @@ -241,6 +250,7 @@ def test__multiple_dependencies__packages(self): { 'models.snowplow.events': { 'name': 'events', + 'fqn': ['snowplow', 'events'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], @@ -250,6 +260,7 @@ def test__multiple_dependencies__packages(self): }, 'models.snowplow.sessions': { 'name': 'sessions', + 'fqn': ['snowplow', 'sessions'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], @@ -259,6 +270,7 @@ def test__multiple_dependencies__packages(self): }, 'models.snowplow.events_tx': { 'name': 'events_tx', + 'fqn': ['snowplow', 'events_tx'], 'empty': False, 'package_name': 'snowplow', 'depends_on': ['models.snowplow.events'], @@ -268,6 +280,7 @@ def test__multiple_dependencies__packages(self): }, 'models.snowplow.sessions_tx': { 'name': 'sessions_tx', + 'fqn': ['snowplow', 'sessions_tx'], 'empty': False, 'package_name': 'snowplow', 'depends_on': ['models.snowplow.sessions'], @@ -277,6 +290,7 @@ def test__multiple_dependencies__packages(self): }, 'models.root.multi': { 'name': 'multi', + 'fqn': ['root_project', 'multi'], 'empty': False, 'package_name': 'root', 'depends_on': ['models.snowplow.sessions_tx', @@ -308,6 +322,7 @@ def test__in_model_config(self): { 'models.root.model_one': { 'name': 'model_one', + 'fqn': ['root_project', 'model_one'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -374,6 +389,7 @@ def test__root_project_config(self): { 'models.root.table': { 'name': 'table', + 'fqn': ['root_project', 'table'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -384,6 +400,7 @@ def test__root_project_config(self): }, 'models.root.ephemeral': { 'name': 'ephemeral', + 'fqn': ['root_project', 'ephemeral'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -394,6 +411,7 @@ def test__root_project_config(self): }, 'models.root.view': { 'name': 'view', + 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -503,6 +521,7 @@ def test__other_project_config(self): { 'models.root.table': { 'name': 'table', + 'fqn': ['root_project', 'table'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -513,6 +532,7 @@ def test__other_project_config(self): }, 'models.root.ephemeral': { 'name': 'ephemeral', + 'fqn': 
['root_project', 'ephemeral'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -523,6 +543,7 @@ def test__other_project_config(self): }, 'models.root.view': { 'name': 'view', + 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -533,6 +554,7 @@ def test__other_project_config(self): }, 'models.snowplow.disabled': { 'name': 'disabled', + 'fqn': ['snowplow', 'disabled'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], @@ -543,6 +565,7 @@ def test__other_project_config(self): }, 'models.snowplow.package': { 'name': 'package', + 'fqn': ['snowplow', 'views', 'package'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], From beedfd5b735bdae9d703ee767a8d57670f826434 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Wed, 22 Feb 2017 21:54:33 -0500 Subject: [PATCH 03/25] almost there --- dbt/compilation.py | 378 +++++++++++++++++++++++++------- dbt/compiled_model.py | 1 + dbt/contracts/graph/__init__.py | 0 dbt/contracts/graph/compiled.py | 44 ++++ dbt/contracts/graph/parsed.py | 66 ++++++ dbt/contracts/graph/unparsed.py | 27 +++ dbt/contracts/project.py | 28 +++ dbt/model.py | 8 +- dbt/parser.py | 51 +++-- dbt/runner.py | 1 + dbt/utils.py | 48 ++-- test/unit/test_graph.py | 6 +- test/unit/test_parser.py | 118 +++++++++- 13 files changed, 632 insertions(+), 144 deletions(-) create mode 100644 dbt/contracts/graph/__init__.py create mode 100644 dbt/contracts/graph/compiled.py create mode 100644 dbt/contracts/graph/parsed.py create mode 100644 dbt/contracts/graph/unparsed.py create mode 100644 dbt/contracts/project.py diff --git a/dbt/compilation.py b/dbt/compilation.py index 0ac43eef9f3..59f40f644e6 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -8,12 +8,19 @@ import dbt.project import dbt.utils +from dbt.model import Model from dbt.source import Source from dbt.utils import find_model_by_fqn, find_model_by_name, \ split_path, This, Var, compiler_error, to_string from dbt.linker import Linker from dbt.runtime import RuntimeContext + +import dbt.contracts.graph.compiled +import dbt.contracts.graph.parsed +import dbt.contracts.project +import dbt.flags +import dbt.parser import dbt.templates from dbt.adapters.factory import get_adapter @@ -37,6 +44,95 @@ def compile_string(string, ctx): compiler_error(None, str(e)) +def prepend_ctes(model, all_models): + model, _, all_models = recursively_prepend_ctes(model, all_models) + + return (model, all_models) + + +def recursively_prepend_ctes(model, all_models): + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.compiled.validate_one(model) + dbt.contracts.graph.compiled.validate(all_models) + + model = model.copy() + prepend_ctes = [] + + if model.get('all_ctes_injected') == True: + return (model, model.get('extra_cte_ids'), all_models) + + for cte_id in model.get('extra_cte_ids'): + cte_to_add = all_models.get(cte_id) + cte_to_add, new_prepend_ctes, all_models = recursively_prepend_ctes( + cte_to_add, all_models) + + prepend_ctes = new_prepend_ctes + prepend_ctes + new_cte_name = '__dbt__CTE__{}'.format(cte_to_add.get('name')) + prepend_ctes.append(' {} as (\n{}\n)'.format( + new_cte_name, + cte_to_add.get('compiled_sql'))) + + model['extra_ctes_injected'] = True + model['extra_cte_sql'] = prepend_ctes + model['injected_sql'] = inject_ctes_into_sql( + model.get('compiled_sql'), + model.get('extra_cte_sql')) + + all_models[model.get('unique_id')] = model + + return (model, prepend_ctes, all_models) + + +def inject_ctes_into_sql(sql, ctes): + """ + `ctes` is a list 
of CTEs in the form: + + [ "__dbt__CTE__ephemeral as (select * from table)", + "__dbt__CTE__events as (select id, type from events)" ] + + Given `sql` like: + + "with internal_cte as (select * from sessions) + select * from internal_cte" + + This will spit out: + + "with __dbt__CTE__ephemeral as (select * from table), + __dbt__CTE__events as (select id, type from events), + with internal_cte as (select * from sessions) + select * from internal_cte" + + (Whitespace enhanced for readability.) + """ + if len(ctes) == 0: + return sql + + parsed_stmts = sqlparse.parse(sql) + parsed = parsed_stmts[0] + + with_stmt = None + for token in parsed.tokens: + if token.is_keyword and token.normalized == 'WITH': + with_stmt = token + break + + if with_stmt is None: + # no with stmt, add one, and inject CTEs right at the beginning + first_token = parsed.token_first() + with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') + parsed.insert_before(first_token, with_stmt) + else: + # stmt exists, add a comma (which will come after injected CTEs) + trailing_comma = sqlparse.sql.Token(sqlparse.tokens.Punctuation, ',') + parsed.insert_after(with_stmt, trailing_comma) + + parsed.insert_after( + with_stmt, + sqlparse.sql.Token(sqlparse.tokens.Keyword, ", ".join(ctes))) + + return str(parsed) + + class Compiler(object): def __init__(self, project, args): self.project = project @@ -114,23 +210,8 @@ def __write(self, build_filepath, payload): def __model_config(self, model, linker): def do_config(*args, **kwargs): - if len(args) == 1 and len(kwargs) == 0: - opts = args[0] - elif len(args) == 0 and len(kwargs) > 0: - opts = kwargs - else: - raise RuntimeError( - "Invalid model config given inline in {}".format(model) - ) + return '' - if type(opts) != dict: - raise RuntimeError( - "Invalid model config given inline in {}".format(model) - ) - - model.update_in_model_config(opts) - model.add_to_prologue("Config specified in model: {}".format(opts)) - return "" return do_config def model_can_reference(self, src_model, other_model): @@ -146,11 +227,8 @@ def model_can_reference(self, src_model, other_model): src_model.own_project['name'] == src_model.project['name'] ) - def __ref(self, linker, ctx, model, all_models): - schema = ctx['env']['schema'] - - model_id = model.get('unique_id') - linker.add_node(model_id) + def __ref(self, ctx, model, all_models): + schema = ctx.get('env', {}).get('schema') def do_ref(*args): target_model_name = None @@ -168,26 +246,35 @@ def do_ref(*args): ) ) - target_model = find_model_by_unique_id(all_models, - target_model_name, - target_model_package) + target_model = dbt.utils.find_model_by_name( + all_models, + target_model_name, + target_model_package) + + if target_model is None: + compiler_error( + model, + "Model '{}' depends on model '{}' which was not found." + .format(model.get('unique_id'), target_model_name)) + target_model_id = target_model.get('unique_id') - if target_model.get('enabled') == False: + if target_model.get('config', {}) \ + .get('enabled') == False: compiler_error( + model, "Model '{}' depends on model '{}' which is disabled in " "the project config".format(model.get('unique_id'), target_model.get('unique_id'))) - # this creates a trivial cycle -- should this be a compiler error? 
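A quick check of inject_ctes_into_sql above (illustrative; exact
whitespace depends on sqlparse's token rendering):

    inject_ctes_into_sql(
        "select * from events",
        ["__dbt__CTE__a as (select 1 as id)"])
    # roughly: 'with __dbt__CTE__a as (select 1 as id) select * from events'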
- # we can still interpolate the name w/o making a self-cycle - if model_id == target_model_id: - pass - else: - linker.dependency(model_id, target_model_id) + print('dependency {} to {}'.format(model.get('unique_id'), + target_model.get('unique_id'))) + + model['depends_on'].append(target_model_id) + if target_model.get('config', {}) \ + .get('materialized') == 'ephemeral': - if other_model.get('ephemeral') == True: - model['extra_ctes'].append(target_model_id) + model['extra_cte_ids'].append(target_model_id) return '__dbt__CTE__{}'.format(target_model.get('name')) else: return '"{}"."{}"'.format(schema, target_model.get('name')) @@ -198,19 +285,55 @@ def wrapped_do_ref(*args): except RuntimeError as e: logger.info("Compiler error in {}".format(model.get('path'))) logger.info("Enabled models:") - for m in all_models: + for n,m in all_models.items(): logger.info(" - {}".format(".".join(m.get('fqn')))) raise e return wrapped_do_ref + def get_compiler_context(self, linker, model, models): + runtime = RuntimeContext(model=model) + + context = self.project.context() + + # built-ins + context['ref'] = self.__ref(context, model, models) + context['config'] = self.__model_config(model, linker) + #context['this'] = This( + # context['env']['schema'], model.immediate_name, model.name + #) + context['var'] = Var(model, context=context) + context['target'] = self.project.get_target() + + # these get re-interpolated at runtime! + context['run_started_at'] = '{{ run_started_at }}' + context['invocation_id'] = '{{ invocation_id }}' + + adapter = get_adapter(self.project.run_environment()) + context['sql_now'] = adapter.date_function + + runtime.update_global(context) + + # add in macros (can we cache these somehow?) + for macro_data in self.macro_generator(context): + macro = macro_data["macro"] + macro_name = macro_data["name"] + project = macro_data["project"] + + runtime.update_package(project['name'], {macro_name: macro}) + + if project['name'] == self.project['name']: + runtime.update_global({macro_name: macro}) + + return runtime + def get_context(self, linker, model, models): runtime = RuntimeContext(model=model) context = self.project.context() # built-ins - context['ref'] = self.__ref(linker, context, model, models) + context['ref'] = self.__ref(context, model, models) context['config'] = self.__model_config(model, linker) context['this'] = This( context['env']['schema'], model.immediate_name, model.name @@ -245,16 +368,20 @@ def compile_model(self, linker, model, models): compiled_model = model.copy() compiled_model.update({ 'compiled': False, - 'extra_ctes': [], - 'compiled_sql': None + 'compiled_sql': None, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': None, }) - context = self.get_context(linker, model, models) + context = self.get_compiler_context(linker, compiled_model, models) env = jinja2.sandbox.SandboxedEnvironment() compiled_model['compiled_sql'] = env.from_string( model.get('raw_sql')).render(context) + compiled_model['compiled'] = True except jinja2.exceptions.TemplateSyntaxError as e: compiler_error(model, str(e)) @@ -329,6 +456,35 @@ def __recursive_add_ctes(self, linker, model): return models_to_add + def new_add_cte_to_rendered_query(self, linker, primary_model, + compiled_models): + + fqn_to_model = {tuple(model.fqn): model for model in compiled_models} + sorted_nodes = linker.as_topological_ordering() + + models_to_add = self.__recursive_add_ctes(linker, primary_model) + + required_ctes = [] + for node in sorted_nodes: + + if 
node not in fqn_to_model: + continue + + model = fqn_to_model[node] + # add these in topological sort order -- significant for CTEs + if model.is_ephemeral and model in models_to_add: + required_ctes.append(model) + + query = compiled_models[primary_model] + if len(required_ctes) == 0: + return query + else: + compiled_query = self.combine_query_with_ctes( + primary_model, query, required_ctes, compiled_models + ) + return compiled_query + + def add_cte_to_rendered_query( self, linker, primary_model, compiled_models ): @@ -375,32 +531,80 @@ def remove_node_from_graph(self, linker, model, models): ) def compile_models(self, linker, models): - compiled_models = {self.compile_model(linker, model, models) - for model in models} - sorted_models = [models[path] - for path in linker.as_topological_ordering()] + all_projects = {'root': self.project} + dependency_projects = dbt.utils.dependency_projects(self.project) + for project in dependency_projects: + name = project.cfg.get('name', 'unknown') + all_projects[name] = project + + compiled_models = {} + injected_models = {} + wrapped_models = {} written_models = [] - for model in sorted_models: - # in-model configs were just evaluated. Evict anything that is - # newly-disabled - if model.get('enabled') == False: - self.remove_node_from_graph(linker, model, models) - continue - injected_stmt = self.add_cte_to_rendered_query( - linker, model, compiled_models - ) + for name, model in models.items(): + compiled_models[name] = self.compile_model(linker, model, models) - context = self.get_context(linker, model, models) - wrapped_stmt = model.compile(injected_stmt, self.project, context) + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.compiled.validate(compiled_models) - if model.get('ephemeral') == True: - continue + for name, model in compiled_models.items(): + model, compiled_models = prepend_ctes(model, compiled_models) + injected_models[name] = model - build_path = os.path.join('build', model.get('path')) - self.__write(build_path, wrapped_stmt) - written_models.append(model) + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.compiled.validate(injected_models) + + for name, injected_model in injected_models.items(): + # now turn a model back into the old-style model object + model = Model( + self.project, + injected_model.get('root_path'), + injected_model.get('path'), + all_projects[injected_model.get('package_name')]) + model._config = injected_model.get('config', {}) + + context = self.get_context(linker, model, injected_models) + + wrapped_stmt = model.compile( + injected_model.get('injected_sql'), self.project, context) + + injected_model['wrapped_sql'] = wrapped_stmt + wrapped_model = injected_model + wrapped_models[name] = wrapped_model + + build_path = os.path.join('build', injected_model.get('path')) + if injected_model.get('config', {}) \ + .get('materialized') != 'ephemeral': + self.__write(build_path, wrapped_stmt) + written_models.append(model) + + linker.add_node(tuple(wrapped_model.get('fqn'))) + project = all_projects[wrapped_model.get('package_name')] + + linker.update_node_data( + tuple(wrapped_model.get('fqn')), + { + 'materialized': (wrapped_model.get('config', {}) + .get('materialized')), + 'dbt_run_type': dbt.model.NodeType.Model, + 'enabled': (wrapped_model.get('config', {}) + .get('enabled')), + 'build_path': os.path.join(project['target-path'], + build_path), + 'name': wrapped_model.get('name'), + + # I think we always use tmp_name. 
- Connor + 'tmp_name': model.tmp_name(), + 'project_name': project.cfg.get('name') + }) + + for dependency in wrapped_model.get('depends_on'): + if wrapped_models.get(dependency): + linker.dependency( + tuple(wrapped_model.get('fqn')), + tuple(wrapped_models.get(dependency).get('fqn'))) return compiled_models, written_models @@ -465,14 +669,15 @@ def compile_schema_tests(self, linker, models): # show a warning if the model being tested doesn't exist try: source_model = find_model_by_name(models, - schema_test.model_name) + schema_test.model_name, + None) except RuntimeError as e: dbt.utils.compiler_warning(schema_test, str(e)) continue serialized = schema_test.serialize() - model_node = tuple(source_model.fqn) + model_node = tuple(source_model.get('fqn')) test_node = tuple(schema_test.fqn) linker.dependency(test_node, model_node) @@ -532,27 +737,42 @@ def get_models(self): return all_models - def get_parsed_models(self): - root_project = self.project - all_projects = [root_project] - all_projects.extend(dbt.utils.dependency_projects(self.project)) + def get_all_projects(self): + root_project = self.project.cfg + all_projects = {'root': root_project} + dependency_projects = dbt.utils.dependency_projects(self.project) - all_models = [] - for project in all_projects: - all_models.extend( + for project in dependency_projects: + name = project.cfg.get('name', 'unknown') + all_projects[name] = project.cfg + + if dbt.flags.STRICT_MODE: + dbt.contracts.project.validate_list(all_projects) + + return all_projects + + + def get_parsed_models(self, root_project, all_projects): + parsed_models = {} + + for name, project in all_projects.items(): + parsed_models.update( dbt.parser.load_and_parse_models( - package_name=project.get('name'), - root_dir=root_project.get('project-root'), - relative_dirs=project.get('source_paths', []))) + package_name=name, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('source-paths', []))) - return all_models + return parsed_models def compile(self): linker = Linker() - parsed_models = self.get_parsed_models() + root_project = self.project.cfg + all_projects = self.get_all_projects() + + parsed_models = self.get_parsed_models(root_project, all_projects) - all_models = self.get_models() all_macros = self.get_macros(this_project=self.project) for project in dbt.utils.dependency_projects(self.project): @@ -562,10 +782,14 @@ def compile(self): self.macro_generator = self.generate_macros(all_macros) - enabled_models = [ - model for model in parsed_models - if model.get('enabled') == True and model.get('empty') == True - ] + enabled_models = {} + + # TODO: don't worry about testing enabled here. 
do it after + # linking + for name, model in parsed_models.items(): + if model.get('config', {}).get('enabled') == True and \ + model.get('empty') == False: + enabled_models[model.get('unique_id')] = model compiled_models, written_models = self.compile_models( linker, enabled_models diff --git a/dbt/compiled_model.py b/dbt/compiled_model.py index 2d95a356ef6..bea4d29d072 100644 --- a/dbt/compiled_model.py +++ b/dbt/compiled_model.py @@ -53,6 +53,7 @@ def contents(self): if self._contents is None: with open(self.data['build_path']) as fh: self._contents = to_unicode(fh.read(), 'utf-8') + return self._contents def compile(self, context, profile, existing): diff --git a/dbt/contracts/graph/__init__.py b/dbt/contracts/graph/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/dbt/contracts/graph/compiled.py b/dbt/contracts/graph/compiled.py new file mode 100644 index 00000000000..3ffc646256a --- /dev/null +++ b/dbt/contracts/graph/compiled.py @@ -0,0 +1,44 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length +from voluptuous.error import Invalid, MultipleInvalid + +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + +from dbt.contracts.graph.parsed import parsed_graph_item_contract + +compiled_graph_item_contract = parsed_graph_item_contract.extend({ + # compiled fields + Required('compiled'): bool, + Required('compiled_sql'): Any(str, None), + + # injected fields + Required('extra_ctes_injected'): bool, + Required('extra_cte_ids'): All(list, [str]), + Required('extra_cte_sql'): All(list, [str]), + Required('injected_sql'): Any(str, None), +}) + + +def validate_one(compiled_graph_item): + try: + compiled_graph_item_contract(compiled_graph_item) + + except Invalid as e: + logger.info(e) + raise ValidationException(str(e)) + + +def validate(compiled_graph): + try: + for k, v in compiled_graph.items(): + compiled_graph_item_contract(v) + + if v.get('unique_id') != k: + error_msg = 'unique_id must match key name in compiled graph!' 
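+                # the key itself is the expected unique_id, e.g. a node
+                # stored under 'models.root.events' must carry that id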
+ logger.info(error_msg) + raise ValidationException(error_msg) + + except Invalid as e: + logger.info(e) + raise ValidationException(str(e)) diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py new file mode 100644 index 00000000000..5cebd06d171 --- /dev/null +++ b/dbt/contracts/graph/parsed.py @@ -0,0 +1,66 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length +from voluptuous.error import Invalid, MultipleInvalid + +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + +from dbt.contracts.graph.unparsed import unparsed_graph_item_contract + +config_contract = { + Required('enabled'): bool, + Required('materialized'): Any('table', 'view', 'ephemeral', 'incremental'), + Required('post-hook'): list, + Required('pre-hook'): list, + Required('vars'): dict, + Optional('sql_where'): str, + Optional('unique_key'): str, +} + +parsed_graph_item_contract = unparsed_graph_item_contract.extend({ + # identifiers + Required('unique_id'): All(str, Length(min=1, max=255)), + Required('fqn'): All(list, [All(str)]), + + # parsed fields + Required('depends_on'): All(list, [All(str, Length(min=1, max=255))]), + Required('empty'): bool, + Required('config'): config_contract, +}) + +def validate_one(parsed_graph_item): + try: + parsed_graph_item_contract(parsed_graph_item) + + except Invalid as e: + logger.info(e) + raise ValidationException(str(e)) + + materialization = parsed_graph_item.get('config', {}) \ + .get('materialized') + + if materialization == 'incremental' and \ + parsed_graph_item.get('config', {}).get('sql_where') is None: + raise ValidationException( + 'missing `sql_where` for an incremental model') + elif materialization != 'incremental' and \ + parsed_graph_item.get('config', {}).get('sql_where') is not None: + raise ValidationException( + 'invalid field `sql_where` for a non-incremental model') + + +def validate(parsed_graph): + try: + for k, v in parsed_graph.items(): + parsed_graph_item_contract(v) + + if v.get('unique_id') != k: + error_msg = ('unique_id must match key name in parsed graph!' 
+ 'key: {}, model: {}' + .format(k, v)) + logger.info(error_msg) + raise ValidationException(error_msg) + + except Invalid as e: + logger.info(e) + raise ValidationException(str(e)) diff --git a/dbt/contracts/graph/unparsed.py b/dbt/contracts/graph/unparsed.py new file mode 100644 index 00000000000..74ab7ac6ddb --- /dev/null +++ b/dbt/contracts/graph/unparsed.py @@ -0,0 +1,27 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length +from voluptuous.error import Invalid, MultipleInvalid + +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + +unparsed_graph_item_contract = Schema({ + # identifiers + Required('name'): All(str, Length(min=1, max=63)), + Required('package_name'): str, + + # filesystem + Required('root_path'): str, + Required('path'): str, + Required('raw_sql'): str, +}) + + +def validate(unparsed_graph): + try: + for item in unparsed_graph: + unparsed_graph_item_contract(item) + + except Invalid as e: + logger.info(e) + raise ValidationException(str(e)) diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py new file mode 100644 index 00000000000..4e88f41bfcb --- /dev/null +++ b/dbt/contracts/project.py @@ -0,0 +1,28 @@ +from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ + Length, ALLOW_EXTRA +from voluptuous.error import Invalid, MultipleInvalid + +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + +project_contract = Schema({ + Required('name'): str +}, extra=ALLOW_EXTRA) + +projects_list_contract = Schema({str: project_contract}) + +def validate(project): + try: + project_contract(project) + + except Invalid as e: + logger.info(e) + raise ValidationException(str(e)) + +def validate_list(projects): + try: + projects_list_contract(projects) + + except Invalid as e: + logger.info(e) + raise ValidationException(str(e)) diff --git a/dbt/model.py b/dbt/model.py index 2ed20e244cc..7dd750657b2 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -42,6 +42,7 @@ class SourceConfig(object): ] def __init__(self, active_project, own_project, fqn): + self._config = None self.active_project = active_project self.own_project = own_project self.fqn = fqn @@ -101,7 +102,8 @@ def config(self): return cfg def is_full_refresh(self): - if hasattr(self.active_project.args, 'full_refresh'): + if hasattr(self.active_project, 'args') and \ + hasattr(self.active_project.args, 'full_refresh'): return self.active_project.args.full_refresh else: return False @@ -208,6 +210,7 @@ class DBTSource(object): dbt_run_type = NodeType.Base def __init__(self, project, top_dir, rel_filepath, own_project): + self._config = None self.project = project self.own_project = own_project @@ -256,6 +259,9 @@ def contents(self): @property def config(self): + if self._config is not None: + return self._config + return self.source_config.config def update_in_model_config(self, config): diff --git a/dbt/parser.py b/dbt/parser.py index 5482755e0c8..26f18cdf1cf 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -3,15 +3,19 @@ import jinja2.sandbox import os +import dbt.flags import dbt.model import dbt.utils +import dbt.contracts.graph.parsed +import dbt.contracts.graph.unparsed +import dbt.contracts.project class SilentUndefined(jinja2.Undefined): """ - Don't fail to parse because of undefined things. This allows us to parse - models before macros, since we aren't guaranteed to know about macros - before models. 
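The unparsed contract above admits exactly the keys the loader
produces. A minimal node that validates (a sketch, not data from this
series):

    import dbt.contracts.graph.unparsed as unparsed

    unparsed.validate([{
        'name': 'events',
        'package_name': 'root',
        'root_path': '/usr/src/app',
        'path': 'events.sql',
        'raw_sql': 'select * from events',
    }])
    # passes; a missing or extra key raises ValidationException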
+ This class sets up the parser to just ignore undefined jinja2 calls. So, + for example, `env` is not defined here, but will not make the parser fail + with a fatal error. """ def _fail_with_undefined_error(self, *args, **kwargs): return None @@ -36,19 +40,7 @@ def get_macro_path(package_name, resource_name): def __ref(model): def ref(*args): - model_path = None - - if len(args) == 1: - model_path = get_model_path(model.get('package_name'), args[0]) - elif len(args) == 2: - model_path = get_model_path(args[0], args[1]) - else: - dbt.utils.compiler_error( - model.get('name'), - "ref() takes at most two arguments ({} given)".format( - len(args))) - - model['depends_on'].append(model_path) + pass return ref @@ -108,6 +100,9 @@ def parse_model(model, model_path, root_project_config, def parse_models(models, projects): to_return = {} + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.unparsed.validate(models) + for model in models: package_name = model.get('package_name', 'root') @@ -119,11 +114,14 @@ def parse_models(models, projects): projects.get('root'), projects.get(package_name)) + if dbt.flags.STRICT_MODE: + dbt.contracts.graph.parsed.validate(to_return) + return to_return -def load_and_parse_files(package_name, root_dir, relative_dirs, extension, - resource_type): +def load_and_parse_files(package_name, all_projects, root_dir, relative_dirs, + extension, resource_type): file_matches = dbt.clients.system.find_matching( root_dir, relative_dirs, @@ -135,19 +133,28 @@ def load_and_parse_files(package_name, root_dir, relative_dirs, extension, file_contents = dbt.clients.system.load_file_contents( file_match.get('absolute_path')) + parts = dbt.utils.split_path(file_match.get('relative_path', '')) + name, _ = os.path.splitext(parts[-1]) + # TODO: support more than just models models.append({ - 'name': os.path.basename(file_match.get('absolute_path')), + 'name': name, 'root_path': root_dir, 'path': file_match.get('relative_path'), 'package_name': package_name, 'raw_sql': file_contents }) - return parse_models(models) + return parse_models(models, all_projects) + +def load_and_parse_models(package_name, all_projects, root_dir, relative_dirs): + if dbt.flags.STRICT_MODE: + dbt.contracts.project.validate_list(all_projects) -def load_and_parse_models(package_name, root_dir, relative_dirs): - return load_and_parse_files(package_name, root_dir, relative_dirs, + return load_and_parse_files(package_name, + all_projects, + root_dir, + relative_dirs, extension="[!.#~]*.sql", resource_type='models') diff --git a/dbt/runner.py b/dbt/runner.py index 64044fdf6eb..511fc1bb25c 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -614,6 +614,7 @@ def get_nodes_to_run(self, graph, include_spec, exclude_spec, model_type): def get_compiled_models(self, linker, nodes, node_type): compiled_models = [] + for fqn in nodes: compiled_model = make_compiled_model(fqn, linker.get_node(fqn)) diff --git a/dbt/utils.py b/dbt/utils.py index c4df2c42b07..8093a6a6619 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -34,8 +34,10 @@ def __repr__(self): def compiler_error(model, msg): if model is None: name = '' - elif model is str: + elif isinstance(model, str): name = model + elif isinstance(model, dict): + name = model.get('name') else: name = model.nice_name @@ -61,7 +63,11 @@ class Var(object): def __init__(self, model, context): self.model = model self.context = context - self.local_vars = model.config.get('vars', {}) + + if isinstance(model, dict) and model.get('unique_id'): + self.local_vars = model.get('config', 
{}).get('vars') + else: + self.local_vars = model.config.get('vars', {}) def pretty_dict(self, data): return json.dumps(data, sort_keys=True, indent=4) @@ -94,46 +100,20 @@ def model_cte_name(model): return '__dbt__CTE__{}'.format(model.get('name')) -def find_model_by_unique_id(all_models, target_model_name, - target_model_package): +def find_model_by_name(all_models, target_model_name, + target_model_package): for name, model in all_models.items(): resource_type, package_name, model_name = name.split('.') - if ((target_model_name == model_name) and \ - (target_model_package is None or - target_model_package == package_name)): + if (resource_type == 'models' and \ + ((target_model_name == model_name) and \ + (target_model_package is None or + target_model_package == package_name))): return model return None -def find_model_by_name(models, name, package_namespace=None): - found = [] - for model in models: - if model.name == name: - if package_namespace is None: - found.append(model) - elif (package_namespace is not None and - package_namespace == model.project['name']): - found.append(model) - - nice_package_name = 'ANY' if package_namespace is None \ - else package_namespace - if len(found) == 0: - raise RuntimeError( - "Can't find a model named '{}' in package '{}' -- does it exist?" - .format(name, nice_package_name) - ) - elif len(found) == 1: - return found[0] - else: - raise RuntimeError( - "Model specification is ambiguous: model='{}' package='{}' -- " - "{} models match criteria: {}" - .format(name, nice_package_name, len(found), found) - ) - - def find_model_by_fqn(models, fqn): for model in models: if tuple(model.fqn) == tuple(fqn): diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index a29526dbf91..ba8b06cfc55 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -4,11 +4,13 @@ import unittest import dbt.compilation +import dbt.exceptions +import dbt.flags +import dbt.linker import dbt.model import dbt.project import dbt.templates import dbt.utils -import dbt.linker import networkx as nx from test.integration.base import FakeArgs @@ -25,6 +27,8 @@ def tearDown(self): dbt.clients.system.load_file_contents = self.real_load_file_contents def setUp(self): + dbt.flags.STRICT_MODE = True + def mock_write_yaml(graph, outfile): self.graph_result = graph diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index 6c5f795aad4..b65293f56dd 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -3,6 +3,7 @@ import os +import dbt.flags import dbt.parser @@ -14,6 +15,8 @@ def find_input_by_name(self, models, name): {}) def setUp(self): + dbt.flags.STRICT_MODE = True + self.maxDiff = None self.root_project_config = { @@ -41,6 +44,8 @@ def test__single_model(self): models = [{ 'name': 'model_one', 'package_name': 'root', + 'root_path': '/usr/src/app', + 'path': 'model_one.sql', 'raw_sql': ("select * from events"), }] @@ -52,11 +57,14 @@ def test__single_model(self): { 'models.root.model_one': { 'name': 'model_one', + 'unique_id': 'models.root.model_one', 'fqn': ['root_project', 'model_one'], 'empty': False, 'package_name': 'root', + 'root_path': '/usr/src/app', 'depends_on': [], 'config': self.model_config, + 'path': 'model_one.sql', 'raw_sql': self.find_input_by_name( models, 'model_one').get('raw_sql') } @@ -67,6 +75,8 @@ def test__empty_model(self): models = [{ 'name': 'model_one', 'package_name': 'root', + 'path': 'model_one.sql', + 'root_path': '/usr/src/app', 'raw_sql': (" "), }] @@ -77,11 +87,14 @@ def test__empty_model(self): { 
'models.root.model_one': { 'name': 'model_one', + 'unique_id': 'models.root.model_one', 'fqn': ['root_project', 'model_one'], 'empty': True, 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'path': 'model_one.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'model_one').get('raw_sql') } @@ -92,10 +105,14 @@ def test__simple_dependency(self): models = [{ 'name': 'base', 'package_name': 'root', + 'path': 'base.sql', + 'root_path': '/usr/src/app', 'raw_sql': 'select * from events' }, { 'name': 'events_tx', 'package_name': 'root', + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': "select * from {{ref('base')}}" }] @@ -107,21 +124,27 @@ def test__simple_dependency(self): { 'models.root.base': { 'name': 'base', + 'unique_id': 'models.root.base', 'fqn': ['root_project', 'base'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'path': 'base.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'base').get('raw_sql') }, 'models.root.events_tx': { 'name': 'events_tx', + 'unique_id': 'models.root.events_tx', 'fqn': ['root_project', 'events_tx'], 'empty': False, 'package_name': 'root', - 'depends_on': ['models.root.base'], + 'depends_on': [], 'config': self.model_config, + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'events_tx').get('raw_sql') } @@ -132,24 +155,34 @@ def test__multiple_dependencies(self): models = [{ 'name': 'events', 'package_name': 'root', + 'path': 'events.sql', + 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.events', }, { 'name': 'sessions', 'package_name': 'root', + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.sessions', }, { 'name': 'events_tx', 'package_name': 'root', + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("with events as (select * from {{ref('events')}}) " "select * from events"), }, { 'name': 'sessions_tx', 'package_name': 'root', + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " "select * from sessions"), }, { 'name': 'multi', 'package_name': 'root', + 'path': 'multi.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("with s as (select * from {{ref('sessions_tx')}}), " "e as (select * from {{ref('events_tx')}}) " "select * from e left join s on s.id = e.sid"), @@ -163,52 +196,66 @@ def test__multiple_dependencies(self): { 'models.root.events': { 'name': 'events', + 'unique_id': 'models.root.events', 'fqn': ['root_project', 'events'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'path': 'events.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'events').get('raw_sql') }, 'models.root.sessions': { 'name': 'sessions', + 'unique_id': 'models.root.sessions', 'fqn': ['root_project', 'sessions'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'sessions').get('raw_sql') }, 'models.root.events_tx': { 'name': 'events_tx', + 'unique_id': 'models.root.events_tx', 'fqn': ['root_project', 'events_tx'], 'empty': False, 'package_name': 'root', - 'depends_on': ['models.root.events'], + 'depends_on': [], 'config': self.model_config, + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( 
models, 'events_tx').get('raw_sql') }, 'models.root.sessions_tx': { 'name': 'sessions_tx', + 'unique_id': 'models.root.sessions_tx', 'fqn': ['root_project', 'sessions_tx'], 'empty': False, 'package_name': 'root', - 'depends_on': ['models.root.sessions'], + 'depends_on': [], 'config': self.model_config, + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'sessions_tx').get('raw_sql') }, 'models.root.multi': { 'name': 'multi', + 'unique_id': 'models.root.multi', 'fqn': ['root_project', 'multi'], 'empty': False, 'package_name': 'root', - 'depends_on': ['models.root.sessions_tx', - 'models.root.events_tx'], + 'depends_on': [], 'config': self.model_config, + 'path': 'multi.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'multi').get('raw_sql') } @@ -219,24 +266,34 @@ def test__multiple_dependencies__packages(self): models = [{ 'name': 'events', 'package_name': 'snowplow', + 'path': 'events.sql', + 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.events', }, { 'name': 'sessions', 'package_name': 'snowplow', + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.sessions', }, { 'name': 'events_tx', 'package_name': 'snowplow', + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("with events as (select * from {{ref('events')}}) " "select * from events"), }, { 'name': 'sessions_tx', 'package_name': 'snowplow', + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("with sessions as (select * from {{ref('sessions')}}) " "select * from sessions"), }, { 'name': 'multi', 'package_name': 'root', + 'path': 'multi.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("with s as (select * from {{ref('snowplow', 'sessions_tx')}}), " "e as (select * from {{ref('snowplow', 'events_tx')}}) " "select * from e left join s on s.id = e.sid"), @@ -250,52 +307,66 @@ def test__multiple_dependencies__packages(self): { 'models.snowplow.events': { 'name': 'events', + 'unique_id': 'models.snowplow.events', 'fqn': ['snowplow', 'events'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], 'config': self.model_config, + 'path': 'events.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'events').get('raw_sql') }, 'models.snowplow.sessions': { 'name': 'sessions', + 'unique_id': 'models.snowplow.sessions', 'fqn': ['snowplow', 'sessions'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], 'config': self.model_config, + 'path': 'sessions.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'sessions').get('raw_sql') }, 'models.snowplow.events_tx': { 'name': 'events_tx', + 'unique_id': 'models.snowplow.events_tx', 'fqn': ['snowplow', 'events_tx'], 'empty': False, 'package_name': 'snowplow', - 'depends_on': ['models.snowplow.events'], + 'depends_on': [], 'config': self.model_config, + 'path': 'events_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'events_tx').get('raw_sql') }, 'models.snowplow.sessions_tx': { 'name': 'sessions_tx', + 'unique_id': 'models.snowplow.sessions_tx', 'fqn': ['snowplow', 'sessions_tx'], 'empty': False, 'package_name': 'snowplow', - 'depends_on': ['models.snowplow.sessions'], + 'depends_on': [], 'config': self.model_config, + 'path': 'sessions_tx.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'sessions_tx').get('raw_sql') }, 'models.root.multi': { 'name': 'multi', + 'unique_id': 'models.root.multi', 
'fqn': ['root_project', 'multi'], 'empty': False, 'package_name': 'root', - 'depends_on': ['models.snowplow.sessions_tx', - 'models.snowplow.events_tx'], + 'depends_on': [], 'config': self.model_config, + 'path': 'multi.sql', + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'multi').get('raw_sql') } @@ -306,6 +377,8 @@ def test__in_model_config(self): models = [{ 'name': 'model_one', 'package_name': 'root', + 'path': 'model_one.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("{{config({'materialized':'table'})}}" "select * from events"), }] @@ -322,11 +395,14 @@ def test__in_model_config(self): { 'models.root.model_one': { 'name': 'model_one', + 'unique_id': 'models.root.model_one', 'fqn': ['root_project', 'model_one'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'root_path': '/usr/src/app', + 'path': 'model_one.sql', 'raw_sql': self.find_input_by_name( models, 'model_one').get('raw_sql') } @@ -353,17 +429,20 @@ def test__root_project_config(self): 'name': 'table', 'package_name': 'root', 'path': 'table.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("{{config({'materialized':'table'})}}" "select * from events"), }, { 'name': 'ephemeral', 'package_name': 'root', 'path': 'ephemeral.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'view', 'package_name': 'root', 'path': 'view.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }] @@ -389,33 +468,39 @@ def test__root_project_config(self): { 'models.root.table': { 'name': 'table', + 'unique_id': 'models.root.table', 'fqn': ['root_project', 'table'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'path': 'table.sql', 'config': self.model_config, + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'table').get('raw_sql') }, 'models.root.ephemeral': { 'name': 'ephemeral', + 'unique_id': 'models.root.ephemeral', 'fqn': ['root_project', 'ephemeral'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'path': 'ephemeral.sql', 'config': ephemeral_config, + 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') }, 'models.root.view': { 'name': 'view', + 'unique_id': 'models.root.view', 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'path': 'view.sql', + 'root_path': '/usr/src/app', 'config': view_config, 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') @@ -463,27 +548,32 @@ def test__other_project_config(self): 'name': 'table', 'package_name': 'root', 'path': 'table.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("{{config({'materialized':'table'})}}" "select * from events"), }, { 'name': 'ephemeral', 'package_name': 'root', 'path': 'ephemeral.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'view', 'package_name': 'root', 'path': 'view.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'disabled', 'package_name': 'snowplow', 'path': 'disabled.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'package', 'package_name': 'snowplow', 'path': 'models/views/package.sql', + 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }] @@ -521,55 +611,65 @@ def test__other_project_config(self): { 'models.root.table': { 'name': 'table', + 'unique_id': 'models.root.table', 'fqn': ['root_project', 'table'], 'empty': False, 'package_name': 'root', 'depends_on': 
[], 'path': 'table.sql', + 'root_path': '/usr/src/app', 'config': self.model_config, 'raw_sql': self.find_input_by_name( models, 'table').get('raw_sql') }, 'models.root.ephemeral': { 'name': 'ephemeral', + 'unique_id': 'models.root.ephemeral', 'fqn': ['root_project', 'ephemeral'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'path': 'ephemeral.sql', + 'root_path': '/usr/src/app', 'config': ephemeral_config, 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') }, 'models.root.view': { 'name': 'view', + 'unique_id': 'models.root.view', 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', 'depends_on': [], 'path': 'view.sql', + 'root_path': '/usr/src/app', 'config': view_config, 'raw_sql': self.find_input_by_name( models, 'view').get('raw_sql') }, 'models.snowplow.disabled': { 'name': 'disabled', + 'unique_id': 'models.snowplow.disabled', 'fqn': ['snowplow', 'disabled'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], 'path': 'disabled.sql', + 'root_path': '/usr/src/app', 'config': disabled_config, 'raw_sql': self.find_input_by_name( models, 'disabled').get('raw_sql') }, 'models.snowplow.package': { 'name': 'package', + 'unique_id': 'models.snowplow.package', 'fqn': ['snowplow', 'views', 'package'], 'empty': False, 'package_name': 'snowplow', 'depends_on': [], 'path': 'models/views/package.sql', + 'root_path': '/usr/src/app', 'config': sort_config, 'raw_sql': self.find_input_by_name( models, 'package').get('raw_sql') From 2473b120373b9dc151666bef44dc37d42afeca81 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Thu, 23 Feb 2017 16:53:23 -0500 Subject: [PATCH 04/25] down to 10 integration test failures --- dbt/compilation.py | 26 +- dbt/contracts/common.py | 16 + dbt/contracts/connection.py | 15 +- dbt/contracts/graph/compiled.py | 29 +- dbt/contracts/graph/parsed.py | 32 +- dbt/contracts/graph/unparsed.py | 12 +- dbt/contracts/project.py | 17 +- dbt/exceptions.py | 6 +- dbt/parser.py | 6 +- .../test_invalid_models.py | 11 +- test/unit/test_compiler.py | 324 ++++++++++++++++++ 11 files changed, 396 insertions(+), 98 deletions(-) create mode 100644 dbt/contracts/common.py create mode 100644 test/unit/test_compiler.py diff --git a/dbt/compilation.py b/dbt/compilation.py index 59f40f644e6..d57439fadc3 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -267,9 +267,6 @@ def do_ref(*args): "the project config".format(model.get('unique_id'), target_model.get('unique_id'))) - print('dependency {} to {}'.format(model.get('unique_id'), - target_model.get('unique_id'))) - model['depends_on'].append(target_model_id) if target_model.get('config', {}) \ .get('materialized') == 'ephemeral': @@ -563,6 +560,7 @@ def compile_models(self, linker, models): injected_model.get('root_path'), injected_model.get('path'), all_projects[injected_model.get('package_name')]) + model._config = injected_model.get('config', {}) context = self.get_context(linker, model, injected_models) @@ -594,17 +592,20 @@ def compile_models(self, linker, models): 'build_path': os.path.join(project['target-path'], build_path), 'name': wrapped_model.get('name'), - - # I think we always use tmp_name. 
- Connor 'tmp_name': model.tmp_name(), 'project_name': project.cfg.get('name') }) for dependency in wrapped_model.get('depends_on'): - if wrapped_models.get(dependency): + if compiled_models.get(dependency): linker.dependency( tuple(wrapped_model.get('fqn')), - tuple(wrapped_models.get(dependency).get('fqn'))) + tuple(compiled_models.get(dependency).get('fqn'))) + else: + compiler_error( + model, + "dependency {} not found in graph!".format( + dependency)) return compiled_models, written_models @@ -782,17 +783,8 @@ def compile(self): self.macro_generator = self.generate_macros(all_macros) - enabled_models = {} - - # TODO: don't worry about testing enabled here. do it after - # linking - for name, model in parsed_models.items(): - if model.get('config', {}).get('enabled') == True and \ - model.get('empty') == False: - enabled_models[model.get('unique_id')] = model - compiled_models, written_models = self.compile_models( - linker, enabled_models + linker, parsed_models ) compilers = { diff --git a/dbt/contracts/common.py b/dbt/contracts/common.py new file mode 100644 index 00000000000..5d6c14b7ddf --- /dev/null +++ b/dbt/contracts/common.py @@ -0,0 +1,16 @@ +from voluptuous.error import Invalid, MultipleInvalid + +from dbt.exceptions import ValidationException +from dbt.logger import GLOBAL_LOGGER as logger + +def validate_with(schema, data): + try: + schema(data) + + except MultipleInvalid as e: + logger.error(str(e)) + raise ValidationException(str(e)) + + except Invalid as e: + logger.error(str(e)) + raise ValidationException(str(e)) diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 46aeb90527d..1fd691c0f4d 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -1,7 +1,6 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional -from voluptuous.error import MultipleInvalid -from dbt.exceptions import ValidationException +from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as logger @@ -39,11 +38,7 @@ def validate_connection(connection): - try: - connection_contract(connection) - - credentials_contract = credentials_mapping.get(connection.get('type')) - credentials_contract(connection.get('credentials')) - except MultipleInvalid as e: - logger.info(e) - raise ValidationException(str(e)) + validate_with(connection_contract, connection) + + credentials_contract = credentials_mapping.get(connection.get('type')) + validate_with(credentials_contract, connection.get('credentials')) diff --git a/dbt/contracts/graph/compiled.py b/dbt/contracts/graph/compiled.py index 3ffc646256a..341818f26aa 100644 --- a/dbt/contracts/graph/compiled.py +++ b/dbt/contracts/graph/compiled.py @@ -1,10 +1,10 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ Length -from voluptuous.error import Invalid, MultipleInvalid from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger +from dbt.contracts.common import validate_with from dbt.contracts.graph.parsed import parsed_graph_item_contract compiled_graph_item_contract = parsed_graph_item_contract.extend({ @@ -21,24 +21,13 @@ def validate_one(compiled_graph_item): - try: - compiled_graph_item_contract(compiled_graph_item) - - except Invalid as e: - logger.info(e) - raise ValidationException(str(e)) - + validate_with(compiled_graph_item_contract, compiled_graph_item) def validate(compiled_graph): - try: - for k, v in compiled_graph.items(): - compiled_graph_item_contract(v) - - if v.get('unique_id') != k: - 
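
To make the new validate_with helper concrete, a short usage sketch (the schema
and payloads are made up for illustration): every voluptuous failure is logged
and re-raised as dbt's own ValidationException.

    from voluptuous import Schema, Required

    from dbt.contracts.common import validate_with
    from dbt.exceptions import ValidationException

    contract = Schema({Required('name'): str})

    validate_with(contract, {'name': 'events'})   # passes silently

    try:
        validate_with(contract, {'name': 1})      # violates the schema
    except ValidationException as e:
        print(str(e))
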
error_msg = 'unique_id must match key name in compiled graph!' - logger.info(error_msg) - raise ValidationException(error_msg) - - except Invalid as e: - logger.info(e) - raise ValidationException(str(e)) + for k, v in compiled_graph.items(): + validate_with(compiled_graph_item_contract, v) + + if v.get('unique_id') != k: + error_msg = 'unique_id must match key name in compiled graph!' + logger.info(error_msg) + raise ValidationException(error_msg) diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py index 5cebd06d171..89c7dfeb8b7 100644 --- a/dbt/contracts/graph/parsed.py +++ b/dbt/contracts/graph/parsed.py @@ -1,10 +1,10 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ Length -from voluptuous.error import Invalid, MultipleInvalid from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger +from dbt.contracts.common import validate_with from dbt.contracts.graph.unparsed import unparsed_graph_item_contract config_contract = { @@ -29,12 +29,7 @@ }) def validate_one(parsed_graph_item): - try: - parsed_graph_item_contract(parsed_graph_item) - - except Invalid as e: - logger.info(e) - raise ValidationException(str(e)) + validate_with(parsed_graph_item_contract, parsed_graph_item) materialization = parsed_graph_item.get('config', {}) \ .get('materialized') @@ -50,17 +45,12 @@ def validate_one(parsed_graph_item): def validate(parsed_graph): - try: - for k, v in parsed_graph.items(): - parsed_graph_item_contract(v) - - if v.get('unique_id') != k: - error_msg = ('unique_id must match key name in parsed graph!' - 'key: {}, model: {}' - .format(k, v)) - logger.info(error_msg) - raise ValidationException(error_msg) - - except Invalid as e: - logger.info(e) - raise ValidationException(str(e)) + for k, v in parsed_graph.items(): + validate_one(v) + + if v.get('unique_id') != k: + error_msg = ('unique_id must match key name in parsed graph!' 
+ 'key: {}, model: {}' + .format(k, v)) + logger.info(error_msg) + raise ValidationException(error_msg) diff --git a/dbt/contracts/graph/unparsed.py b/dbt/contracts/graph/unparsed.py index 74ab7ac6ddb..aaea1268f25 100644 --- a/dbt/contracts/graph/unparsed.py +++ b/dbt/contracts/graph/unparsed.py @@ -1,8 +1,7 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ Length -from voluptuous.error import Invalid, MultipleInvalid -from dbt.exceptions import ValidationException +from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as logger unparsed_graph_item_contract = Schema({ @@ -18,10 +17,5 @@ def validate(unparsed_graph): - try: - for item in unparsed_graph: - unparsed_graph_item_contract(item) - - except Invalid as e: - logger.info(e) - raise ValidationException(str(e)) + for item in unparsed_graph: + validate_with(unparsed_graph_item_contract, item) diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py index 4e88f41bfcb..43691ee27d7 100644 --- a/dbt/contracts/project.py +++ b/dbt/contracts/project.py @@ -1,8 +1,7 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ Length, ALLOW_EXTRA -from voluptuous.error import Invalid, MultipleInvalid -from dbt.exceptions import ValidationException +from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as logger project_contract = Schema({ @@ -12,17 +11,7 @@ projects_list_contract = Schema({str: project_contract}) def validate(project): - try: - project_contract(project) - - except Invalid as e: - logger.info(e) - raise ValidationException(str(e)) + validate_with(project_contract, project) def validate_list(projects): - try: - projects_list_contract(projects) - - except Invalid as e: - logger.info(e) - raise ValidationException(str(e)) + validate_with(projects_list_contract, project) diff --git a/dbt/exceptions.py b/dbt/exceptions.py index 991bfced3cd..b336e610e06 100644 --- a/dbt/exceptions.py +++ b/dbt/exceptions.py @@ -2,7 +2,11 @@ class Exception(BaseException): pass -class ValidationException(Exception): +class RuntimeException(RuntimeError, Exception): + pass + + +class ValidationException(RuntimeException): pass diff --git a/dbt/parser.py b/dbt/parser.py index 26f18cdf1cf..b41f028d81f 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -100,8 +100,7 @@ def parse_model(model, model_path, root_project_config, def parse_models(models, projects): to_return = {} - if dbt.flags.STRICT_MODE: - dbt.contracts.graph.unparsed.validate(models) + dbt.contracts.graph.unparsed.validate(models) for model in models: package_name = model.get('package_name', 'root') @@ -114,8 +113,7 @@ def parse_models(models, projects): projects.get('root'), projects.get(package_name)) - if dbt.flags.STRICT_MODE: - dbt.contracts.graph.parsed.validate(to_return) + dbt.contracts.graph.parsed.validate(to_return) return to_return diff --git a/test/integration/011_invalid_model_tests/test_invalid_models.py b/test/integration/011_invalid_model_tests/test_invalid_models.py index 0f8338a4ad9..4f526a7e289 100644 --- a/test/integration/011_invalid_model_tests/test_invalid_models.py +++ b/test/integration/011_invalid_model_tests/test_invalid_models.py @@ -1,6 +1,8 @@ from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest +from dbt.exceptions import ValidationException + class TestInvalidViewModels(DBTIntegrationTest): def setUp(self): @@ -18,7 +20,12 @@ def models(self): @attr(type='postgres') def 
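
The reworked exception hierarchy above is what the integration-test change
below relies on: ValidationException now inherits from RuntimeError (via the
new RuntimeException), so a plain `except RuntimeError` catches strict-mode
contract failures. A small sketch:

    from dbt.exceptions import ValidationException

    try:
        raise ValidationException('invalid model config')
    except RuntimeError as e:
        # caught: ValidationException -> RuntimeException -> RuntimeError
        print(str(e))
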
test_view_with_incremental_attributes(self): - self.run_dbt() + try: + self.run_dbt() + # should throw + self.assertTrue(False) + except RuntimeError as e: + pass class TestInvalidDisabledModels(DBTIntegrationTest): @@ -43,7 +50,7 @@ def test_view_with_incremental_attributes(self): # should throw self.assertTrue(False) except RuntimeError as e: - self.assertTrue("config must be either True or False" in str(e)) + self.assertTrue("enabled" in str(e)) class TestInvalidModelReference(DBTIntegrationTest): diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py new file mode 100644 index 00000000000..22a148e79e9 --- /dev/null +++ b/test/unit/test_compiler.py @@ -0,0 +1,324 @@ +from mock import MagicMock +import unittest + +import os + +import dbt.flags +import dbt.compilation + + +class CompilerTest(unittest.TestCase): + + def assertEqualIgnoreWhitespace(self, a, b): + self.assertEqual( + "".join(a.split()), + "".join(b.split())) + + def setUp(self): + dbt.flags.STRICT_MODE = True + + self.maxDiff = None + + self.root_project_config = { + 'name': 'root_project', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + } + + self.snowplow_project_config = { + 'name': 'snowplow', + 'version': '0.1', + 'project-root': os.path.abspath('./dbt_modules/snowplow'), + } + + self.model_config = { + 'enabled': True, + 'materialized': 'view', + 'post-hook': [], + 'pre-hook': [], + 'vars': {}, + } + + def test__prepend_ctes__already_has_cte(self): + ephemeral_config = self.model_config.copy() + ephemeral_config['materialized'] = 'ephemeral' + + compiled_models = { + 'models.root.view': { + 'name': 'view', + 'unique_id': 'models.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [ + 'models.root.ephemeral' + ], + 'config': self.model_config, + 'path': 'view.sql', + 'raw_sql': 'select * from {{ref("ephemeral")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'models.root.ephemeral' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': ('with cte as (select * from something_else) ' + 'select * from __dbt__CTE__ephemeral') + }, + 'models.root.ephemeral': { + 'name': 'ephemeral', + 'unique_id': 'models.root.ephemeral', + 'fqn': ['root_project', 'ephemeral'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'path': 'ephemeral.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'compiled_sql': 'select * from source_table', + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '' + } + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models['models.root.view'], + compiled_models) + + self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + ('with __dbt__CTE__ephemeral as (' + 'select * from source_table' + '), cte as (select * from something_else) ' + 'select * from __dbt__CTE__ephemeral')) + + self.assertEqual( + all_models.get('models.root.ephemeral').get('extra_ctes_injected'), + True) + + def test__prepend_ctes__no_ctes(self): + compiled_models = { + 'models.root.view': { + 'name': 'view', + 'unique_id': 'models.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 
'config': self.model_config, + 'path': 'view.sql', + 'raw_sql': ('with cte as (select * from something_else) ' + 'select * from source_table'), + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': ('with cte as (select * from something_else) ' + 'select * from source_table') + }, + 'models.root.view_no_cte': { + 'name': 'view_no_cte', + 'unique_id': 'models.root.view_no_cte', + 'fqn': ['root_project', 'view_no_cte'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': self.model_config, + 'path': 'view.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': ('select * from source_table') + } + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models.get('models.root.view'), + compiled_models) + + self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + compiled_models.get('models.root.view').get('compiled_sql')) + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models.get('models.root.view_no_cte'), + compiled_models) + + self.assertEqual(result, all_models.get('models.root.view_no_cte')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + compiled_models.get('models.root.view_no_cte').get('compiled_sql')) + + + def test__prepend_ctes(self): + ephemeral_config = self.model_config.copy() + ephemeral_config['materialized'] = 'ephemeral' + + compiled_models = { + 'models.root.view': { + 'name': 'view', + 'unique_id': 'models.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [ + 'models.root.ephemeral' + ], + 'config': self.model_config, + 'path': 'view.sql', + 'raw_sql': 'select * from {{ref("ephemeral")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'models.root.ephemeral' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from __dbt__CTE__ephemeral' + }, + 'models.root.ephemeral': { + 'name': 'ephemeral', + 'unique_id': 'models.root.ephemeral', + 'fqn': ['root_project', 'ephemeral'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'path': 'ephemeral.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from source_table' + } + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models['models.root.view'], + compiled_models) + + self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + ('with __dbt__CTE__ephemeral as (' + 'select * from source_table' + ') ' + 'select * from __dbt__CTE__ephemeral')) + + self.assertEqual( + all_models.get('models.root.ephemeral').get('extra_ctes_injected'), + True) + + + def test__prepend_ctes__multiple_levels(self): + ephemeral_config = self.model_config.copy() + ephemeral_config['materialized'] = 'ephemeral' + + compiled_models = { + 
'models.root.view': { + 'name': 'view', + 'unique_id': 'models.root.view', + 'fqn': ['root_project', 'view'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [ + 'models.root.ephemeral' + ], + 'config': self.model_config, + 'path': 'view.sql', + 'raw_sql': 'select * from {{ref("ephemeral")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'models.root.ephemeral' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from __dbt__CTE__ephemeral' + }, + 'models.root.ephemeral': { + 'name': 'ephemeral', + 'unique_id': 'models.root.ephemeral', + 'fqn': ['root_project', 'ephemeral'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'path': 'ephemeral.sql', + 'raw_sql': 'select * from {{ref("ephemeral_level_two")}}', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [ + 'models.root.ephemeral_level_two' + ], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from __dbt__CTE__ephemeral_level_two' + }, + 'models.root.ephemeral_level_two': { + 'name': 'ephemeral_level_two', + 'unique_id': 'models.root.ephemeral_level_two', + 'fqn': ['root_project', 'ephemeral_level_two'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'path': 'ephemeral_level_two.sql', + 'raw_sql': 'select * from source_table', + 'compiled': True, + 'extra_ctes_injected': False, + 'extra_cte_ids': [], + 'extra_cte_sql': [], + 'injected_sql': '', + 'compiled_sql': 'select * from source_table' + } + + } + + result, all_models = dbt.compilation.prepend_ctes( + compiled_models['models.root.view'], + compiled_models) + + self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result.get('extra_ctes_injected'), True) + self.assertEqualIgnoreWhitespace( + result.get('injected_sql'), + ('with __dbt__CTE__ephemeral_level_two as (' + 'select * from source_table' + '), __dbt__CTE__ephemeral as (' + 'select * from __dbt__CTE__ephemeral_level_two' + ') ' + 'select * from __dbt__CTE__ephemeral')) + + self.assertEqual( + all_models.get('models.root.ephemeral').get('extra_ctes_injected'), + True) + self.assertEqual( + all_models.get('models.root.ephemeral_level_two').get('extra_ctes_injected'), + True) From 50d9896e6ab67a7a881d2add70de6a5d18cd3e29 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sat, 25 Feb 2017 22:05:43 -0500 Subject: [PATCH 05/25] schema and data tests running in parser --- dbt/adapters/postgres.py | 4 +- dbt/compilation.py | 458 +++++----------- dbt/contracts/graph/unparsed.py | 3 + dbt/contracts/project.py | 2 +- dbt/flags.py | 1 + dbt/graph/selector.py | 20 +- dbt/linker.py | 10 + dbt/model.py | 6 +- dbt/parser.py | 248 +++++++-- dbt/runner.py | 518 ++++++++++++------ dbt/task/run.py | 6 +- dbt/task/test.py | 9 +- dbt/utils.py | 2 +- .../test_schema_test_graph_selection.py | 4 +- .../test_schema_tests.py | 1 + test/unit/test_compiler.py | 91 +-- test/unit/test_graph.py | 88 ++- test/unit/test_graph_selection.py | 47 +- test/unit/test_parser.py | 343 +++++++++--- test/unit/test_runner.py | 269 +++++++++ tox.ini | 2 +- 21 files changed, 1365 insertions(+), 767 deletions(-) create mode 100644 test/unit/test_runner.py diff --git a/dbt/adapters/postgres.py b/dbt/adapters/postgres.py index 06a74bef52a..6ea6a2d6462 100644 --- a/dbt/adapters/postgres.py +++ b/dbt/adapters/postgres.py @@ -291,7 +291,7 @@ def 
rename(cls, profile, from_name, to_name, model_name=None): @classmethod def execute_model(cls, profile, model): - parts = re.split(r'-- (DBT_OPERATION .*)', model.compiled_contents) + parts = re.split(r'-- (DBT_OPERATION .*)', model.get('wrapped_sql')) connection = cls.get_connection(profile) if flags.STRICT_MODE: @@ -317,7 +317,7 @@ def call_expand_target_column_types(kwargs): func_map[function](kwargs) else: handle, cursor = cls.add_query_to_transaction( - part, connection, model.name) + part, connection, model.get('name')) handle.commit() diff --git a/dbt/compilation.py b/dbt/compilation.py index d57439fadc3..535ca480a60 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -8,7 +8,7 @@ import dbt.project import dbt.utils -from dbt.model import Model +from dbt.model import Model, NodeType from dbt.source import Source from dbt.utils import find_model_by_fqn, find_model_by_name, \ split_path, This, Var, compiler_error, to_string @@ -148,15 +148,6 @@ def initialize(self): if not os.path.exists(self.project['modules-path']): os.makedirs(self.project['modules-path']) - def model_sources(self, this_project, own_project=None): - if own_project is None: - own_project = this_project - - paths = own_project.get('source-paths', []) - return Source( - this_project, - own_project=own_project - ).get_models(paths) def get_macros(self, this_project, own_project=None): if own_project is None: @@ -164,40 +155,18 @@ def get_macros(self, this_project, own_project=None): paths = own_project.get('macro-paths', []) return Source(this_project, own_project=own_project).get_macros(paths) + def get_archives(self, project): return Source( project, own_project=project ).get_archives() - def project_schemas(self, project): - source_paths = project.get('source-paths', []) - return Source(project).get_schemas(source_paths) - - def project_tests(self, project): - source_paths = project.get('test-paths', []) - return Source(project).get_tests(source_paths) def analysis_sources(self, project): paths = project.get('analysis-paths', []) return Source(project).get_analyses(paths) - def validate_models_unique(self, models, error_type): - found_models = defaultdict(list) - for model in models: - found_models[model.name].append(model) - for model_name, model_list in found_models.items(): - if len(model_list) > 1: - models_str = "\n - ".join( - [str(model) for model in model_list]) - - error_msg = "Found {} models with the same name.\n" \ - " Name='{}'\n" \ - " - {}".format( - len(model_list), model_name, models_str - ) - - error_type(model_list[0], error_msg) def __write(self, build_filepath, payload): target_path = os.path.join(self.project['target-path'], build_filepath) @@ -214,18 +183,6 @@ def do_config(*args, **kwargs): return do_config - def model_can_reference(self, src_model, other_model): - """ - returns True if the src_model can reference the other_model. Models - can access other models in their package and dependency models, but - a dependency model cannot access models "up" the dependency chain. 
- """ - - # hack for now b/c we don't support recursive dependencies - return ( - other_model.own_project['name'] == src_model.own_project['name'] or - src_model.own_project['name'] == src_model.project['name'] - ) def __ref(self, ctx, model, all_models): schema = ctx.get('env', {}).get('schema') @@ -268,6 +225,7 @@ def do_ref(*args): target_model.get('unique_id'))) model['depends_on'].append(target_model_id) + if target_model.get('config', {}) \ .get('materialized') == 'ephemeral': @@ -360,10 +318,10 @@ def get_context(self, linker, model, models): return runtime - def compile_model(self, linker, model, models): + def compile_node(self, linker, node, nodes): try: - compiled_model = model.copy() - compiled_model.update({ + compiled_node = node.copy() + compiled_node.update({ 'compiled': False, 'compiled_sql': None, 'extra_ctes_injected': False, @@ -372,86 +330,26 @@ def compile_model(self, linker, model, models): 'injected_sql': None, }) - context = self.get_compiler_context(linker, compiled_model, models) + context = self.get_compiler_context(linker, compiled_node, nodes) env = jinja2.sandbox.SandboxedEnvironment() - compiled_model['compiled_sql'] = env.from_string( - model.get('raw_sql')).render(context) + compiled_node['compiled_sql'] = env.from_string( + node.get('raw_sql')).render(context) - compiled_model['compiled'] = True + compiled_node['compiled'] = True except jinja2.exceptions.TemplateSyntaxError as e: - compiler_error(model, str(e)) + compiler_error(node, str(e)) except jinja2.exceptions.UndefinedError as e: - compiler_error(model, str(e)) + compiler_error(node, str(e)) - return compiled_model + return compiled_node def write_graph_file(self, linker): filename = graph_file_name graph_path = os.path.join(self.project['target-path'], filename) linker.write_graph(graph_path) - def combine_query_with_ctes(self, model, query, ctes, compiled_models): - parsed_stmts = sqlparse.parse(query) - if len(parsed_stmts) != 1: - raise RuntimeError( - "unexpectedly parsed {} queries from model " - "{}".format(len(parsed_stmts), model) - ) - - parsed = parsed_stmts[0] - - with_stmt = None - for token in parsed.tokens: - if token.is_keyword and token.normalized == 'WITH': - with_stmt = token - break - - if with_stmt is None: - # no with stmt, add one! 
- first_token = parsed.token_first() - with_stmt = sqlparse.sql.Token(sqlparse.tokens.Keyword, 'with') - parsed.insert_before(first_token, with_stmt) - else: - # stmt exists, add a comma (which will come after our injected - # CTE(s) ) - trailing_comma = sqlparse.sql.Token( - sqlparse.tokens.Punctuation, ',' - ) - parsed.insert_after(with_stmt, trailing_comma) - - cte_mapping = [ - (model.cte_name, compiled_models[model]) for model in ctes - ] - - # these newlines are important -- comments could otherwise interfere - # w/ query - cte_stmts = [ - " {} as (\n{}\n)".format(name, contents) - for (name, contents) in cte_mapping - ] - - cte_text = sqlparse.sql.Token( - sqlparse.tokens.Keyword, ", ".join(cte_stmts) - ) - parsed.insert_after(with_stmt, cte_text) - - return str(parsed) - - def __recursive_add_ctes(self, linker, model): - if model not in linker.cte_map: - return set() - - models_to_add = linker.cte_map[model] - recursive_models = [ - self.__recursive_add_ctes(linker, m) for m in models_to_add - ] - - for recursive_model_set in recursive_models: - models_to_add = models_to_add | recursive_model_set - - return models_to_add def new_add_cte_to_rendered_query(self, linker, primary_model, compiled_models): @@ -482,132 +380,79 @@ def new_add_cte_to_rendered_query(self, linker, primary_model, return compiled_query - def add_cte_to_rendered_query( - self, linker, primary_model, compiled_models - ): - fqn_to_model = {tuple(model.fqn): model for model in compiled_models} - sorted_nodes = linker.as_topological_ordering() - - models_to_add = self.__recursive_add_ctes(linker, primary_model) - - required_ctes = [] - for node in sorted_nodes: - - if node not in fqn_to_model: - continue - - model = fqn_to_model[node] - # add these in topological sort order -- significant for CTEs - if model.is_ephemeral and model in models_to_add: - required_ctes.append(model) - - query = compiled_models[primary_model] - if len(required_ctes) == 0: - return query - else: - compiled_query = self.combine_query_with_ctes( - primary_model, query, required_ctes, compiled_models - ) - return compiled_query - - def remove_node_from_graph(self, linker, model, models): - # remove the node - children = linker.remove_node(tuple(model.fqn)) - - # check if we bricked the graph. 
if so: throw compilation error - for child in children: - other_model = find_model_by_fqn(models, child) - - if other_model.is_enabled: - this_fqn = ".".join(model.fqn) - that_fqn = ".".join(other_model.fqn) - compiler_error( - model, - "Model '{}' depends on model '{}' which is " - "disabled".format(that_fqn, this_fqn) - ) - - def compile_models(self, linker, models): - all_projects = {'root': self.project} - dependency_projects = dbt.utils.dependency_projects(self.project) - - for project in dependency_projects: - name = project.cfg.get('name', 'unknown') - all_projects[name] = project + def compile_nodes(self, linker, nodes): + all_projects = self.get_all_projects() - compiled_models = {} - injected_models = {} - wrapped_models = {} - written_models = [] + compiled_nodes = {} + injected_nodes = {} + wrapped_nodes = {} + written_nodes = [] - for name, model in models.items(): - compiled_models[name] = self.compile_model(linker, model, models) + for name, node in nodes.items(): + compiled_nodes[name] = self.compile_node(linker, node, nodes) if dbt.flags.STRICT_MODE: - dbt.contracts.graph.compiled.validate(compiled_models) + dbt.contracts.graph.compiled.validate(compiled_nodes) - for name, model in compiled_models.items(): - model, compiled_models = prepend_ctes(model, compiled_models) - injected_models[name] = model + for name, node in compiled_nodes.items(): + node, compiled_nodes = prepend_ctes(node, compiled_nodes) + injected_nodes[name] = node if dbt.flags.STRICT_MODE: - dbt.contracts.graph.compiled.validate(injected_models) + dbt.contracts.graph.compiled.validate(injected_nodes) + + for name, injected_node in injected_nodes.items(): + # now turn model nodes back into the old-style model object for + # wrapping + if injected_node.get('resource_type') != NodeType.Model: + # don't wrap thing that aren't models, i.e. tests. 
+ injected_node['wrapped_sql'] = injected_node['injected_sql'] + wrapped_nodes[name] = injected_node + else: + model = Model( + self.project, + injected_node.get('root_path'), + injected_node.get('path'), + all_projects.get(injected_node.get('package_name'))) - for name, injected_model in injected_models.items(): - # now turn a model back into the old-style model object - model = Model( - self.project, - injected_model.get('root_path'), - injected_model.get('path'), - all_projects[injected_model.get('package_name')]) + model._config = injected_node.get('config', {}) - model._config = injected_model.get('config', {}) + context = self.get_context(linker, model, injected_nodes) - context = self.get_context(linker, model, injected_models) + wrapped_stmt = model.compile( + injected_node.get('injected_sql'), self.project, context) - wrapped_stmt = model.compile( - injected_model.get('injected_sql'), self.project, context) + injected_node['wrapped_sql'] = wrapped_stmt + wrapped_nodes[name] = injected_node - injected_model['wrapped_sql'] = wrapped_stmt - wrapped_model = injected_model - wrapped_models[name] = wrapped_model + build_path = os.path.join('build', injected_node.get('path')) - build_path = os.path.join('build', injected_model.get('path')) - if injected_model.get('config', {}) \ - .get('materialized') != 'ephemeral': + if injected_node.get('config', {}) \ + .get('materialized') != 'ephemeral': self.__write(build_path, wrapped_stmt) - written_models.append(model) + written_nodes.append(injected_node) + injected_node['build_path'] = build_path - linker.add_node(tuple(wrapped_model.get('fqn'))) - project = all_projects[wrapped_model.get('package_name')] + linker.add_node(injected_node.get('unique_id')) + project = all_projects[injected_node.get('package_name')] linker.update_node_data( - tuple(wrapped_model.get('fqn')), - { - 'materialized': (wrapped_model.get('config', {}) - .get('materialized')), - 'dbt_run_type': dbt.model.NodeType.Model, - 'enabled': (wrapped_model.get('config', {}) - .get('enabled')), - 'build_path': os.path.join(project['target-path'], - build_path), - 'name': wrapped_model.get('name'), - 'tmp_name': model.tmp_name(), - 'project_name': project.cfg.get('name') - }) - - for dependency in wrapped_model.get('depends_on'): - if compiled_models.get(dependency): + injected_node.get('unique_id'), + injected_node) + + for dependency in injected_node.get('depends_on'): + if compiled_nodes.get(dependency): linker.dependency( - tuple(wrapped_model.get('fqn')), - tuple(compiled_models.get(dependency).get('fqn'))) + injected_node.get('unique_id'), + compiled_nodes.get(dependency).get('unique_id')) else: compiler_error( model, "dependency {} not found in graph!".format( dependency)) - return compiled_models, written_models + return wrapped_nodes, written_nodes + def compile_analyses(self, linker, compiled_models): analyses = self.analysis_sources(self.project) @@ -637,75 +482,6 @@ def compile_analyses(self, linker, compiled_models): return written_analyses - def get_local_and_package_sources(self, project, source_getter): - all_sources = [] - - all_sources.extend(source_getter(project)) - - for package in dbt.utils.dependency_projects(project): - all_sources.extend(source_getter(package)) - - return all_sources - - def compile_schema_tests(self, linker, models): - all_schema_specs = self.get_local_and_package_sources( - self.project, - self.project_schemas - ) - - schema_tests = [] - - for schema in all_schema_specs: - # compiling a SchemaFile returns >= 0 SchemaTest models - try: - 
schema_tests.extend(schema.compile()) - except RuntimeError as e: - logger.info("\n" + str(e)) - schema_test_path = schema.filepath - logger.info("Skipping compilation for {}...\n" - .format(schema_test_path)) - - written_tests = [] - for schema_test in schema_tests: - # show a warning if the model being tested doesn't exist - try: - source_model = find_model_by_name(models, - schema_test.model_name, - None) - except RuntimeError as e: - dbt.utils.compiler_warning(schema_test, str(e)) - continue - - serialized = schema_test.serialize() - - model_node = tuple(source_model.get('fqn')) - test_node = tuple(schema_test.fqn) - - linker.dependency(test_node, model_node) - linker.update_node_data(test_node, serialized) - - query = schema_test.render() - self.__write(schema_test.build_path(), query) - written_tests.append(schema_test) - - return written_tests - - def compile_data_tests(self, linker, models): - tests = self.get_local_and_package_sources( - self.project, - self.project_tests - ) - - written_tests = [] - for data_test in tests: - serialized = data_test.serialize() - linker.update_node_data(tuple(data_test.fqn), serialized) - query = self.compile_model(linker, data_test, models) - wrapped = data_test.render(query) - self.__write(data_test.build_path(), wrapped) - written_tests.append(data_test) - - return written_tests def generate_macros(self, all_macros): def do_gen(ctx): @@ -716,6 +492,7 @@ def do_gen(ctx): return macros return do_gen + def compile_archives(self, linker, compiled_models): all_archives = self.get_archives(self.project) @@ -727,20 +504,10 @@ def compile_archives(self, linker, compiled_models): return all_archives - def get_models(self): - all_models = self.model_sources(this_project=self.project) - for project in dbt.utils.dependency_projects(self.project): - all_models.extend( - self.model_sources( - this_project=self.project, own_project=project - ) - ) - - return all_models def get_all_projects(self): root_project = self.project.cfg - all_projects = {'root': root_project} + all_projects = {root_project.get('name'): root_project} dependency_projects = dbt.utils.dependency_projects(self.project) for project in dependency_projects: @@ -758,21 +525,73 @@ def get_parsed_models(self, root_project, all_projects): for name, project in all_projects.items(): parsed_models.update( - dbt.parser.load_and_parse_models( + dbt.parser.load_and_parse_sql( package_name=name, + root_project=root_project, all_projects=all_projects, root_dir=project.get('project-root'), - relative_dirs=project.get('source-paths', []))) + relative_dirs=project.get('source-paths', []), + resource_type=NodeType.Model)) return parsed_models + + def get_parsed_data_tests(self, root_project, all_projects): + parsed_tests = {} + + for name, project in all_projects.items(): + parsed_tests.update( + dbt.parser.load_and_parse_sql( + package_name=name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('test-paths', []), + resource_type=NodeType.Test)) + + return parsed_tests + + + def get_parsed_schema_tests(self, root_project, all_projects): + parsed_tests = {} + + for name, project in all_projects.items(): + print('project') + print(project) + + parsed_tests.update( + dbt.parser.load_and_parse_yml( + package_name=name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('source-paths', []))) + + return parsed_tests + + + def load_all_nodes(self, root_project, 
all_projects): + all_nodes = {} + + all_nodes.update(self.get_parsed_models(root_project, all_projects)) + all_nodes.update( + self.get_parsed_data_tests(root_project, all_projects)) + all_nodes.update( + self.get_parsed_schema_tests(root_project, all_projects)) + + return all_nodes + + def compile(self): linker = Linker() root_project = self.project.cfg all_projects = self.get_all_projects() - parsed_models = self.get_parsed_models(root_project, all_projects) + all_nodes = self.load_all_nodes(root_project, all_projects) + + print('all_nodes') + print(all_nodes.keys()) all_macros = self.get_macros(this_project=self.project) @@ -783,41 +602,16 @@ def compile(self): self.macro_generator = self.generate_macros(all_macros) - compiled_models, written_models = self.compile_models( - linker, parsed_models - ) - - compilers = { - 'schema tests': self.compile_schema_tests, - 'data tests': self.compile_data_tests, - 'archives': self.compile_archives, - 'analyses': self.compile_analyses - } + compiled_nodes, written_nodes = self.compile_nodes(linker, all_nodes) - compiled = { - 'models': written_models - } + # TODO re-add archives - for (compile_type, compiler_f) in compilers.items(): - newly_compiled = compiler_f(linker, compiled_models) - compiled[compile_type] = newly_compiled - - self.validate_models_unique( - compiled['models'], - dbt.utils.compiler_error - ) - - self.validate_models_unique( - compiled['data tests'], - dbt.utils.compiler_warning - ) + self.write_graph_file(linker) - self.validate_models_unique( - compiled['schema tests'], - dbt.utils.compiler_warning - ) + stats = {} - self.write_graph_file(linker) + for node_name, node in compiled_nodes.items(): + stats[node.get('resource_type')] = stats.get( + node.get('resource_type'), 0) + 1 - stats = {ttype: len(m) for (ttype, m) in compiled.items()} return stats diff --git a/dbt/contracts/graph/unparsed.py b/dbt/contracts/graph/unparsed.py index aaea1268f25..4b89d8aedd9 100644 --- a/dbt/contracts/graph/unparsed.py +++ b/dbt/contracts/graph/unparsed.py @@ -4,10 +4,13 @@ from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as logger +from dbt.model import NodeType + unparsed_graph_item_contract = Schema({ # identifiers Required('name'): All(str, Length(min=1, max=63)), Required('package_name'): str, + Required('resource_type'): Any(NodeType.Model, NodeType.Test), # filesystem Required('root_path'): str, diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py index 43691ee27d7..be4f54fc85f 100644 --- a/dbt/contracts/project.py +++ b/dbt/contracts/project.py @@ -14,4 +14,4 @@ def validate(project): validate_with(project_contract, project) def validate_list(projects): - validate_with(projects_list_contract, project) + validate_with(projects_list_contract, projects) diff --git a/dbt/flags.py b/dbt/flags.py index 928c20aaeb6..a3ae28bd830 100644 --- a/dbt/flags.py +++ b/dbt/flags.py @@ -1 +1,2 @@ STRICT_MODE = False +NON_DESTRUCTIVE = True diff --git a/dbt/graph/selector.py b/dbt/graph/selector.py index dc6810a2ace..4f25be66d79 100644 --- a/dbt/graph/selector.py +++ b/dbt/graph/selector.py @@ -43,7 +43,9 @@ def parse_spec(node_spec): def get_package_names(graph): - return set([node[0] for node in graph.nodes()]) + print('get_package_names') + print([node.split(".")[1] for node in graph.nodes()]) + return set([node.split(".")[1] for node in graph.nodes()]) def is_selected_node(real_node, node_selector): @@ -79,17 +81,22 @@ def get_nodes_by_qualified_name(project, graph, qualified_name): package_names = 
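
The selector is being moved onto the new dot-notation node ids; a tiny sketch
of what the splits produce, with an illustrative id:

    node = 'models.snowplow.sessions_tx'

    package = node.split('.')[1]     # 'snowplow', as collected by get_package_names
    fqn_ish = node.split('.')[1:]    # ['snowplow', 'sessions_tx'], the old tuple shape
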
get_package_names(graph) for node in graph.nodes(): - if len(qualified_name) == 1 and node[-1] == qualified_name[0]: + # node naming has changed to dot notation. split to tuple for + # compatibility with this code. + fqn_ish = node.split('.')[1:] + print(fqn_ish) + + if len(qualified_name) == 1 and fqn_ish == qualified_name[0]: yield node elif qualified_name[0] in package_names: - if is_selected_node(node, qualified_name): + if is_selected_node(fqn_ish, qualified_name): yield node else: for package_name in package_names: local_qualified_node_name = (package_name,) + qualified_name - if is_selected_node(node, local_qualified_node_name): + if is_selected_node(fqn_ish, local_qualified_node_name): yield node break @@ -130,6 +137,11 @@ def warn_if_useless_spec(spec, nodes): def select_nodes(project, graph, raw_include_specs, raw_exclude_specs): selected_nodes = set() + print('select_nodes') + print(graph) + print(raw_include_specs) + print(raw_exclude_specs) + split_include_specs = split_specs(raw_include_specs) split_exclude_specs = split_specs(raw_exclude_specs) diff --git a/dbt/linker.py b/dbt/linker.py index b144892ae1d..a279e09b559 100644 --- a/dbt/linker.py +++ b/dbt/linker.py @@ -1,8 +1,18 @@ import networkx as nx from collections import defaultdict + +import dbt.compilation import dbt.model +def from_file(graph_file): + linker = Linker() + linker.read_graph(graph_file) + + return linker + + + class Linker(object): def __init__(self, data=None): if data is None: diff --git a/dbt/model.py b/dbt/model.py index 7dd750657b2..7a16ca9ea71 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -10,6 +10,7 @@ from dbt.adapters.factory import get_adapter from dbt.utils import deep_merge, DBTConfigKeys, compiler_error, \ compiler_warning +import dbt.flags class NodeType(object): @@ -319,10 +320,7 @@ def tmp_name(self): return "{}__dbt_tmp".format(self.name) def is_non_destructive(self): - if hasattr(self.project.args, 'non_destructive'): - return self.project.args.non_destructive - else: - return False + return dbt.flags.NON_DESTRUCTIVE def rename_query(self, schema): opts = { diff --git a/dbt/parser.py b/dbt/parser.py index b41f028d81f..6d36dcafec0 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -2,6 +2,7 @@ import jinja2 import jinja2.sandbox import os +import yaml import dbt.flags import dbt.model @@ -11,6 +12,55 @@ import dbt.contracts.graph.unparsed import dbt.contracts.project +from dbt.model import NodeType + +QUERY_VALIDATE_NOT_NULL = """ +with validation as ( + select {field} as f + from {ref} +) +select count(*) from validation where f is null +""" + + +QUERY_VALIDATE_UNIQUE = """ +with validation as ( + select {field} as f + from {ref} + where {field} is not null +), +validation_errors as ( + select f from validation group by f having count(*) > 1 +) +select count(*) from validation_errors +""" + + +QUERY_VALIDATE_ACCEPTED_VALUES = """ +with all_values as ( + select distinct {field} as f + from {ref} +), +validation_errors as ( + select f from all_values where f not in ({values_csv}) +) +select count(*) from validation_errors +""" + + +QUERY_VALIDATE_REFERENTIAL_INTEGRITY = """ +with parent as ( + select {parent_field} as id + from {parent_ref} +), child as ( + select {child_field} as id + from {child_ref} +) +select count(*) from child +where id not in (select id from parent) and id is not null +""" + + class SilentUndefined(jinja2.Undefined): """ This class sets up the parser to just ignore undefined jinja2 calls. 
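The selector now receives dot-separated unique ids ('resource_type.package.name') instead of fqn tuples, so it strips the resource type off before spec matching. A self-contained sketch of that conversion:

```python
def unique_id_to_fqn_ish(unique_id):
    # 'model.snowplow.sessions' -> ['snowplow', 'sessions']; the leading
    # resource type is dropped before spec matching.
    return unique_id.split('.')[1:]

def package_names(node_ids):
    # The package is always the second dotted component.
    return set(node_id.split('.')[1] for node_id in node_ids)

assert unique_id_to_fqn_ish('model.snowplow.sessions') == \
    ['snowplow', 'sessions']
assert package_names(['model.root.a', 'model.snowplow.b']) == \
    {'root', 'snowplow'}
```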
So, @@ -32,7 +82,10 @@ def get_path(resource_type, package_name, resource_name): return "{}.{}.{}".format(resource_type, package_name, resource_name) def get_model_path(package_name, resource_name): - return get_path('models', package_name, resource_name) + return get_path(NodeType.Model, package_name, resource_name) + +def get_test_path(package_name, resource_name): + return get_path(NodeType.Test, package_name, resource_name) def get_macro_path(package_name, resource_name): return get_path('macros', package_name, resource_name) @@ -61,71 +114,87 @@ def config(*args, **kwargs): return config +def get_fqn(path, package_project_config, extra=[]): + parts = dbt.utils.split_path(path) + name, _ = os.path.splitext(parts[-1]) + fqn = ([package_project_config.get('name')] + + parts[1:-1] + + extra + + [name]) + + return fqn -def parse_model(model, model_path, root_project_config, - package_project_config): - parsed_model = copy.deepcopy(model) +def parse_node(node, node_path, root_project_config, package_project_config, + fqn_extra=[]): + parsed_node = copy.deepcopy(node) - parsed_model.update({ + parsed_node.update({ 'depends_on': [], }) - parts = dbt.utils.split_path(model.get('path', '')) - name, _ = os.path.splitext(parts[-1]) - fqn = ([package_project_config.get('name')] + - parts[1:-1] + - [model.get('name')]) + fqn = get_fqn(node.get('path'), package_project_config, fqn_extra) config = dbt.model.SourceConfig( root_project_config, package_project_config, fqn) context = { - 'ref': __ref(parsed_model), - 'config': __config(parsed_model, config), + 'ref': __ref(parsed_node), + 'config': __config(parsed_node, config), } env = jinja2.sandbox.SandboxedEnvironment( undefined=SilentUndefined) - env.from_string(model.get('raw_sql')).render(context) + env.from_string(node.get('raw_sql')).render(context) + + parsed_node['unique_id'] = node_path + parsed_node['config'] = config.config + parsed_node['empty'] = (len(node.get('raw_sql').strip()) == 0) + parsed_node['fqn'] = fqn + + return parsed_node - parsed_model['unique_id'] = model_path - parsed_model['config'] = config.config - parsed_model['empty'] = (len(model.get('raw_sql').strip()) == 0) - parsed_model['fqn'] = fqn - return parsed_model +def parse_models(nodes, projects): + return parse_sql_nodes(nodes, projects) -def parse_models(models, projects): +def parse_sql_nodes(nodes, root_project, projects): to_return = {} - dbt.contracts.graph.unparsed.validate(models) + dbt.contracts.graph.unparsed.validate(nodes) - for model in models: - package_name = model.get('package_name', 'root') + for node in nodes: + package_name = node.get('package_name') - model_path = get_model_path(package_name, model.get('name')) + node_path = get_path(node.get('resource_type'), + package_name, + node.get('name')) # TODO if this is set, raise a compiler error - to_return[model_path] = parse_model(model, - model_path, - projects.get('root'), - projects.get(package_name)) + to_return[node_path] = parse_node(node, + node_path, + root_project, + projects.get(package_name)) dbt.contracts.graph.parsed.validate(to_return) return to_return -def load_and_parse_files(package_name, all_projects, root_dir, relative_dirs, - extension, resource_type): +def load_and_parse_sql(package_name, root_project, all_projects, root_dir, + relative_dirs, resource_type): + extension = "[!.#~]*.sql" + + if dbt.flags.STRICT_MODE: + dbt.contracts.project.validate_list(all_projects) + file_matches = dbt.clients.system.find_matching( root_dir, relative_dirs, extension) - models = [] + result = [] for 
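The SilentUndefined environment exists so the first parse pass can collect ref() dependencies before any macros are known. A minimal, runnable sketch of the same two-pass idea, using a throwaway `_Silent` class and a dependency-collecting `ref`; only the collected names matter, and the rendered string is discarded:

```python
import jinja2
import jinja2.sandbox

class _Silent(jinja2.Undefined):
    # Swallow calls and lookups on undefined names so the parse pass never
    # fails on macros that have not been loaded yet.
    def _fail_with_undefined_error(self, *args, **kwargs):
        return None
    __call__ = __getitem__ = _fail_with_undefined_error

deps = []

def ref(*args):
    # Record the dependency; the rendered value is irrelevant here.
    deps.append('.'.join(args))
    return ''

env = jinja2.sandbox.SandboxedEnvironment(undefined=_Silent)
env.from_string(
    "select * from {{ ref('events') }} where {{ some_macro() }}"
).render(ref=ref)

print(deps)  # ['events']
```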
file_match in file_matches: file_contents = dbt.clients.system.load_file_contents( @@ -134,25 +203,124 @@ def load_and_parse_files(package_name, all_projects, root_dir, relative_dirs, parts = dbt.utils.split_path(file_match.get('relative_path', '')) name, _ = os.path.splitext(parts[-1]) - # TODO: support more than just models - models.append({ + result.append({ 'name': name, 'root_path': root_dir, + 'resource_type': resource_type, 'path': file_match.get('relative_path'), 'package_name': package_name, 'raw_sql': file_contents }) - return parse_models(models, all_projects) + return parse_sql_nodes(result, root_project, all_projects) + + +def parse_schema_tests(tests, root_project, projects): + to_return = {} + + for test in tests: + test_yml = yaml.safe_load(test.get('raw_yml')) + + # validate schema test yml structure + + for model_name, test_spec in test_yml.items(): + for test_type, configs in test_spec.get('constraints', {}).items(): + for config in configs: + to_add = parse_schema_test( + test, model_name, config, test_type, + root_project, + projects.get(test.get('package_name'))) + to_return[to_add.get('unique_id')] = to_add + + return to_return + + +def parse_schema_test(test_base, model_name, test_config, test_type, + root_project_config, package_project_config): + if test_type == 'not_null': + raw_sql = QUERY_VALIDATE_NOT_NULL.format( + ref="{{ref('"+model_name+"')}}", field=test_config) + name_key = test_config + + elif test_type == 'unique': + raw_sql = QUERY_VALIDATE_UNIQUE.format( + ref="{{ref('"+model_name+"')}}", field=test_config) + name_key = test_config + + elif test_type == 'relationships': + child_field = test_config.get('from') + parent_field = test_config.get('field') + parent_model = test_config.get('to') + + raw_sql = QUERY_VALIDATE_REFERENTIAL_INTEGRITY.format( + child_field=child_field, + child_ref="{{ref('"+model_name+"')}}", + parent_field=parent_field, + parent_ref=("{{ref('"+parent_model+"')}}")) + + name_key = '{}_to_{}_{}'.format(child_field, parent_model, + parent_field) + + elif test_type == 'accepted_values': + raw_sql = QUERY_VALIDATE_ACCEPTED_VALUES.format( + ref="{{ref('"+model_name+"')}}", + field=test_config.get('field', ''), + values_csv="'{}'".format( + "','".join(test_config.get('values', [])))) + + name_key = test_config.get('field') + + else: + raise dbt.exceptions.ValidationException( + 'Unknown schema test type {}'.format(test_type)) + + name = '{}_{}_{}'.format(test_type, model_name, name_key) + + to_return = { + 'name': name, + 'resource_type': test_base.get('resource_type'), + 'package_name': test_base.get('package_name'), + 'root_path': test_base.get('root_path'), + 'path': test_base.get('path'), + 'raw_sql': raw_sql + } + + return parse_node(to_return, + get_test_path(test_base.get('package_name'), + name), + root_project_config, + package_project_config, + fqn_extra=['schema']) + +def load_and_parse_yml(package_name, root_project, all_projects, root_dir, + relative_dirs): + extension = "[!.#~]*.yml" -def load_and_parse_models(package_name, all_projects, root_dir, relative_dirs): if dbt.flags.STRICT_MODE: dbt.contracts.project.validate_list(all_projects) - return load_and_parse_files(package_name, - all_projects, - root_dir, - relative_dirs, - extension="[!.#~]*.sql", - resource_type='models') + file_matches = dbt.clients.system.find_matching( + root_dir, + relative_dirs, + extension) + + result = [] + + for file_match in file_matches: + file_contents = dbt.clients.system.load_file_contents( + file_match.get('absolute_path')) + + parts = 
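Each schema test template above is expanded with str.format, leaving the model reference behind as a literal Jinja ref() call for the normal parse path to resolve later (the braces in the argument are not format fields, so they survive untouched). For example, the not_null template:

```python
QUERY_VALIDATE_NOT_NULL = """
with validation as (
    select {field} as f
    from {ref}
)
select count(*) from validation where f is null
"""

sql = QUERY_VALIDATE_NOT_NULL.format(
    ref="{{ref('users')}}", field='user_id')
print(sql)
# with validation as (
#     select user_id as f
#     from {{ref('users')}}
# )
# select count(*) from validation where f is null
```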
dbt.utils.split_path(file_match.get('relative_path', '')) + name, _ = os.path.splitext(parts[-1]) + + result.append({ + 'name': name, + 'root_path': root_dir, + 'resource_type': NodeType.Test, + 'path': file_match.get('relative_path'), + 'package_name': package_name, + 'raw_yml': file_contents + }) + + return parse_schema_tests(result, root_project, all_projects) diff --git a/dbt/runner.py b/dbt/runner.py index 511fc1bb25c..c429b202d3b 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -1,6 +1,7 @@ from __future__ import print_function +import hashlib import psycopg2 import os import sys @@ -13,14 +14,16 @@ from dbt.adapters.factory import get_adapter from dbt.logger import GLOBAL_LOGGER as logger -import dbt.compilation -from dbt.linker import Linker + from dbt.source import Source from dbt.utils import find_model_by_fqn, find_model_by_name, \ dependency_projects from dbt.compiled_model import make_compiled_model +from dbt.model import NodeType +import dbt.compilation import dbt.exceptions +import dbt.linker import dbt.tracking import dbt.schema import dbt.graph.selector @@ -33,13 +36,233 @@ def get_timestamp(): - return "{} |".format(time.strftime("%H:%M:%S")) + return time.strftime("%H:%M:%S") + + +def get_materialization(model): + return model.get('config', {}).get('materialized') + + +def get_hash(model): + return hashlib.md5(model.get('unique_id').encode('utf-8')).hexdigest() + + +def get_hashed_contents(model): + return hashlib.md5(model.get('raw_sql').encode('utf-8')).hexdigest() + + +def is_enabled(model): + return model.get('config', {}).get('enabled') == True + + +def print_timestamped_line(msg): + logger.info("{} | {}".format(get_timestamp(), msg)) + + +def print_fancy_output_line(msg, status, index, total, execution_time=None): + prefix = "{timestamp} {index} of {total} {message}".format( + timestamp=get_timestamp(), + index=index, + total=total, + message=msg) + justified = prefix.ljust(80, ".") + + if execution_time is None: + status_time = "" + else: + status_time = " in {execution_time:0.2f}s".format( + execution_time=execution_time) + + output = "{justified} [{status}{status_time}]".format( + justified=justified, status=status, status_time=status_time) + + logger.info(output) + + +def print_skip_line(model, schema, relation, index, num_models): + msg = 'SKIP relation {}.{}'.format(schema, relation) + print_fancy_output_line(msg, 'SKIP', index, num_models) + + +def print_counts(flat_nodes): + counts = {} + for node in flat_nodes: + t = node.get('resource_type') + counts[t] = counts.get(t, 0) + 1 + + for k, v in counts.items(): + print_timestamped_line("Running {} {}s".format(v,k)) + + +def print_start_line(node, schema_name, index, total): + if node.get('resource_type') == NodeType.Model: + print_model_start_line(node, schema_name, index, total) + if node.get('resource_type') == NodeType.Test: + print_model_start_line(node, schema_name, index, total) + + +def print_model_start_line(model, schema_name, index, total): + msg = "START {model_type} model {schema}.{relation}".format( + model_type=get_materialization(model), + schema=schema_name, + relation=model.get('name')) + + print_fancy_output_line(msg, 'RUN', index, total) + + +def print_result_line(result, schema_name, index, total): + node = result.node + + if node.get('resource_type') == NodeType.Model: + print_model_result_line(result, schema_name, index, total) + elif node.get('resource_type') == NodeType.Test: + print_test_result_line(result, schema_name, index, total) + + +def print_test_result_line(result, 
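The fancy output line is a timestamp-and-index prefix left-justified with dots to a fixed width, followed by a status suffix. Condensed into one standalone helper (the 80-column width and format strings mirror the code above):

```python
import time

def fancy_line(msg, status, index, total, execution_time=None):
    # "<HH:MM:SS> <i> of <n> <msg>......... [<STATUS> in <t>s]"
    prefix = "{} {} of {} {}".format(
        time.strftime("%H:%M:%S"), index, total, msg)
    suffix = "" if execution_time is None else \
        " in {:0.2f}s".format(execution_time)
    return "{} [{}{}]".format(prefix.ljust(80, "."), status, suffix)

print(fancy_line("CREATE view analytics.users", "OK", 1, 3, 0.43))
```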
schema_name, index, total): + model = result.node + info = 'PASS' + + if result.errored: + info = "ERROR" + elif result.status > 0: + info = 'FAIL {}'.format(result.status) + elif result.status == 0: + info = 'PASS' + else: + raise RuntimeError("unexpected status: {}".format(result.status)) + + print_fancy_output_line( + "{info} {name}".format( + info=info, + name=model.get('name')), + result.status, + index, + total, + result.execution_time) + + +def execute_test(profile, test): + adapter = get_adapter(profile) + _, cursor = adapter.execute_one( + profile, + test.get('wrapped_sql'), + test.get('name')) + + rows = cursor.fetchall() + + cursor.close() + + if len(rows) > 1: + raise RuntimeError( + "Bad test {name}: Returned {num_rows} rows instead of 1" + .format(name=model.name, num_rows=len(rows))) + + row = rows[0] + if len(row) > 1: + raise RuntimeError( + "Bad test {name}: Returned {num_cols} cols instead of 1" + .format(name=model.name, num_cols=len(row))) + + return row[0] + + +def print_model_result_line(result, schema_name, index, total): + model = result.node + info = 'OK created' + + if result.errored: + info = 'ERROR creating' + + print_fancy_output_line( + "{info} {model_type} model {schema}.{relation}".format( + info=info, + model_type=get_materialization(model), + schema=schema_name, + relation=model.get('name')), + result.status, + index, + total, + result.execution_time) + + + +def execute_model(profile, model): + adapter = get_adapter(profile) + schema = adapter.get_default_schema(profile) + existing = adapter.query_for_existing(profile, schema) + + tmp_name = '{}__dbt_tmp'.format(model.get('name')) + + if dbt.flags.NON_DESTRUCTIVE: + # for non destructive mode, we only look at the already existing table. + tmp_name = model.get('name') + + result = None + + # TRUNCATE / DROP + if get_materialization(model) == 'table' and \ + dbt.flags.NON_DESTRUCTIVE and \ + existing.get(tmp_name) == 'table': + # tables get truncated instead of dropped in non-destructive mode. + adapter.truncate( + profile=profile, + table=tmp_name, + model_name=model.get('name')) + + elif dbt.flags.NON_DESTRUCTIVE: + # never drop existing relations in non destructive mode. + pass + + elif get_materialization(model) != 'incremental' and \ + existing.get(tmp_name) is not None: + # otherwise, for non-incremental things, drop them with IF EXISTS + adapter.drop( + profile=profile, + relation=tmp_name, + relation_type=existing.get(tmp_name), + model_name=model.get('name')) + + # and update the list of what exists + existing = adapter.query_for_existing(profile, schema) + + # EXECUTE + if get_materialization(model) == 'view' and dbt.flags.NON_DESTRUCTIVE and \ + model.get('name') in existing: + # views don't need to be recreated in non destructive mode since they + # will repopulate automatically. note that we won't run DDL for these + # views either. + pass + elif is_enabled(model) and get_materialization(model) != 'ephemeral': + result = adapter.execute_model(profile, model) + + # DROP OLD RELATION AND RENAME + if dbt.flags.NON_DESTRUCTIVE: + # in non-destructive mode, we truncate and repopulate tables, and + # don't modify views. 
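Test execution above relies on a convention: a test query comes back as exactly one row with one column, the count of records violating the constraint. A standalone sketch of that shape check over an already-fetched result set:

```python
def one_scalar(rows, test_name):
    # A test query must return a single row with a single column: the
    # number of failing records.
    if len(rows) != 1:
        raise RuntimeError("Bad test {}: returned {} rows instead of 1"
                           .format(test_name, len(rows)))
    if len(rows[0]) != 1:
        raise RuntimeError("Bad test {}: returned {} cols instead of 1"
                           .format(test_name, len(rows[0])))
    return rows[0][0]

assert one_scalar([(0,)], 'not_null_users_user_id') == 0
```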
+ pass + elif get_materialization(model) in ['table', 'view']: + # otherwise, drop tables and views, and rename tmp tables/views to + # their new names + if existing.get(model.get('name')) is not None: + adapter.drop( + profile=profile, + relation=model.get('name'), + relation_type=existing.get(model.get('name')), + model_name=model.get('name')) + + adapter.rename(profile=profile, + from_name=tmp_name, + to_name=model.get('name'), + model_name=model.get('name')) + + return result class RunModelResult(object): - def __init__(self, model, error=None, skip=False, status=None, + def __init__(self, node, error=None, skip=False, status=None, execution_time=0): - self.model = model + self.node = node self.error = error self.skip = skip self.status = status @@ -103,7 +326,7 @@ def pre_run_msg(self, model): return output def post_run_msg(self, result): - model = result.model + model = result.node print_vars = { "schema": self.adapter.get_default_schema(self.profile), "model_name": model.name, @@ -385,106 +608,112 @@ def call_table_exists(schema, table): "already_exists": call_table_exists, } + + def inject_runtime_config(self, node): + sql = dbt.compilation.compile_string(node.get('wrapped_sql'), + self.context) + + node['wrapped_sql'] = sql + + return node + def deserialize_graph(self): logger.info("Loading dependency graph file") - linker = Linker() base_target_path = self.project['target-path'] graph_file = os.path.join( base_target_path, dbt.compilation.graph_file_name ) - linker.read_graph(graph_file) - return linker + return dbt.linker.from_file(graph_file) + + + def execute_node(self, node): + profile = self.project.run_environment() + + logger.debug("executing node %s", node.get('unique_id')) - def execute_model(self, runner, model): - logger.debug("executing model %s", model) + node = self.inject_runtime_config(node) + + if node.get('resource_type') == NodeType.Model: + result = execute_model(profile, node) + elif node.get('resource_type') == NodeType.Test: + result = execute_test(profile, node) - result = runner.execute(model) return result - def safe_execute_model(self, data): - runner, model = data['runner'], data['model'] + + def safe_execute_node(self, data): + node = data start_time = time.time() error = None + try: - status = self.execute_model(runner, model) + status = self.execute_node(node) except (RuntimeError, dbt.exceptions.ProgrammingException, psycopg2.ProgrammingError, psycopg2.InternalError) as e: error = "Error executing {filepath}\n{error}".format( - filepath=model['build_path'], error=str(e).strip()) + filepath=node.get('build_path'), error=str(e).strip()) status = "ERROR" logger.debug(error) if type(e) == psycopg2.InternalError and \ ABORTED_TRANSACTION_STRING == e.diag.message_primary: return RunModelResult( - model, error=ABORTED_TRANSACTION_STRING, status="SKIP") + node, error=ABORTED_TRANSACTION_STRING, status="SKIP") except Exception as e: error = ("Unhandled error while executing {filepath}\n{error}" .format( - filepath=model['build_path'], error=str(e).strip())) + filepath=node.get('build_path'), error=str(e).strip())) logger.debug(error) raise e execution_time = time.time() - start_time - return RunModelResult(model, + return RunModelResult(node, error=error, status=status, execution_time=execution_time) - def as_concurrent_dep_list(self, linker, models_to_run): - # linker.as_dependency_list operates on nodes, but this method operates - # on compiled models. 
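The TRUNCATE/DROP branching in execute_model above, restated as a side-effect-free decision function; `existing` is assumed to map relation names to relation types, as adapter.query_for_existing returns:

```python
def pre_run_action(materialization, existing, name, non_destructive):
    # `existing` maps relation name -> relation type for the target schema.
    if non_destructive and materialization == 'table' \
            and existing.get(name) == 'table':
        return 'truncate'  # keep the table, clear its rows
    if non_destructive:
        return 'noop'      # never drop relations in non-destructive mode
    if materialization != 'incremental' and existing.get(name) is not None:
        return 'drop'      # rebuild non-incremental relations from scratch
    return 'noop'

assert pre_run_action('table', {'users': 'table'}, 'users', True) == 'truncate'
assert pre_run_action('view', {'users': 'view'}, 'users', False) == 'drop'
```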
Use a dict to translate between the two - node_model_map = {m.fqn: m for m in models_to_run} - dependency_list = linker.as_dependency_list(node_model_map.keys()) - model_dependency_list = [] - for node_level in dependency_list: - model_level = [node_model_map[n] for n in node_level] - model_dependency_list.append(model_level) + def as_concurrent_dep_list(self, linker, nodes_to_run): + dependency_list = linker.as_dependency_list(nodes_to_run) + + concurrent_dependency_list = [] + for level in dependency_list: + node_level = [linker.get_node(node) for node in level] + concurrent_dependency_list.append(node_level) + + return concurrent_dependency_list - return model_dependency_list - def on_model_failure(self, linker, models, selected_nodes): - def skip_dependent(model): - dependent_nodes = linker.get_dependent_nodes(model.fqn) + def on_model_failure(self, linker, selected_nodes): + def skip_dependent(node): + print(node) + dependent_nodes = linker.get_dependent_nodes(node.get('unique_id')) for node in dependent_nodes: if node in selected_nodes: - model_to_skip = find_model_by_fqn(models, node) - model_to_skip.do_skip() + # TODO fix skipping + pass + return skip_dependent - def print_fancy_output_line(self, message, status, index, total, - execution_time=None): - prefix = "{timestamp} {index} of {total} {message}".format( - timestamp=get_timestamp(), - index=index, - total=total, - message=message) - justified = prefix.ljust(80, ".") - - if execution_time is None: - status_time = "" - else: - status_time = " in {execution_time:0.2f}s".format( - execution_time=execution_time) - output = "{justified} [{status}{status_time}]".format( - justified=justified, status=status, status_time=status_time) - logger.info(output) + def execute_nodes(self, node_dependency_list, on_failure): + profile = self.project.run_environment() + adapter = get_adapter(profile) + schema_name = adapter.get_default_schema(profile) + + flat_nodes = list(itertools.chain.from_iterable( + node_dependency_list)) - def execute_models(self, runner, model_dependency_list, on_failure): - flat_models = list(itertools.chain.from_iterable( - model_dependency_list)) + num_nodes = len(flat_nodes) - num_models = len(flat_models) - if num_models == 0: + if num_nodes == 0: logger.info("WARNING: Nothing to do. 
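as_concurrent_dep_list groups nodes into levels where each level depends only on earlier ones, so a whole level can be dispatched to the thread pool at once. A sketch of the same grouping using networkx's topological_generations (available in newer networkx releases) as a stand-in for linker.as_dependency_list:

```python
import networkx as nx

g = nx.DiGraph([
    ('model.root.base', 'model.root.events'),
    ('model.root.base', 'model.root.sessions'),
    ('model.root.events', 'model.root.multi'),
    ('model.root.sessions', 'model.root.multi'),
])

# Everything within one level depends only on earlier levels.
levels = [sorted(level) for level in nx.topological_generations(g)]
print(levels)
# [['model.root.base'],
#  ['model.root.events', 'model.root.sessions'],
#  ['model.root.multi']]
```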
Try checking your model " "configs and running `dbt compile`".format( self.target_path)) @@ -499,115 +728,116 @@ def execute_models(self, runner, model_dependency_list, on_failure): pool = ThreadPool(num_threads) logger.info("") - logger.info(runner.pre_run_all_msg(flat_models)) - runner.pre_run_all(flat_models, self.context) + print_counts(flat_nodes) + + # TODO: re-add hooks + # runner.pre_run_all(flat_models, self.context) - fqn_to_id_map = {model.fqn: i + 1 for (i, model) - in enumerate(flat_models)} + node_id_to_index_map = {node.get('unique_id'): i + 1 for (i, node) + in enumerate(flat_nodes)} - def get_idx(model): - return fqn_to_id_map[model.fqn] + def get_idx(node): + return node_id_to_index_map[node.get('unique_id')] - model_results = [] - for model_list in model_dependency_list: - for i, model in enumerate([model for model in model_list - if model.should_skip()]): - msg = runner.skip_msg(model) - self.print_fancy_output_line( - msg, 'SKIP', get_idx(model), num_models) - model_result = RunModelResult(model, skip=True) - model_results.append(model_result) + node_results = [] + for node_list in node_dependency_list: + for i, node in enumerate([node for node in node_list + if node.get('skip')]): + print_skip_line( + schema_name, node.get('name'), get_idx(node), num_nodes) - models_to_execute = [model for model in model_list - if not model.should_skip()] + node_result = RunModelResult(node, skip=True) + node_results.append(node_result) + + nodes_to_execute = [node for node in node_list + if not node.get('skip')] threads = self.threads - num_models_this_batch = len(models_to_execute) - model_index = 0 + num_nodes_this_batch = len(nodes_to_execute) + node_index = 0 def on_complete(run_model_results): for run_model_result in run_model_results: - model_results.append(run_model_result) - - msg = runner.post_run_msg(run_model_result) - status = runner.status(run_model_result) - index = get_idx(run_model_result.model) - self.print_fancy_output_line( - msg, - status, - index, - num_models, - run_model_result.execution_time - ) + node_results.append(run_model_result) + + index = get_idx(run_model_result.node) + + print_result_line(run_model_result, + schema_name, + index, + num_nodes) invocation_id = dbt.tracking.active_user.invocation_id dbt.tracking.track_model_run({ "invocation_id": invocation_id, "index": index, - "total": num_models, + "total": num_nodes, "execution_time": run_model_result.execution_time, "run_status": run_model_result.status, "run_skipped": run_model_result.skip, "run_error": run_model_result.error, - "model_materialization": run_model_result.model['materialized'], # noqa - "model_id": run_model_result.model.hashed_name(), - "hashed_contents": run_model_result.model.hashed_contents(), # noqa + "model_materialization": get_materialization(run_model_result.node), # noqa + "model_id": get_hash(run_model_result.node), + "hashed_contents": get_hashed_contents(run_model_result.node), # noqa }) if run_model_result.errored: - on_failure(run_model_result.model) + on_failure(run_model_result.node) logger.info(run_model_result.error) - while model_index < num_models_this_batch: - local_models = [] + while node_index < num_nodes_this_batch: + local_nodes = [] for i in range( - model_index, - min(model_index + threads, num_models_this_batch)): - model = models_to_execute[i] - local_models.append(model) - msg = runner.pre_run_msg(model) - self.print_fancy_output_line( - msg, 'RUN', get_idx(model), num_models - ) - - wrapped_models_to_execute = [ - {"runner": runner, "model": model} - 
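The executor slices each dependency level into chunks of `threads` nodes and dispatches a chunk with map_async before starting the next. The core of that pattern, assuming multiprocessing.dummy's thread-backed Pool (which matches the map_async/wait usage above):

```python
from multiprocessing.dummy import Pool as ThreadPool

def run_level(nodes, execute, threads):
    # Dispatch `threads` nodes at a time and wait for each chunk before
    # starting the next, mirroring the map_async/wait loop above.
    pool = ThreadPool(threads)
    results = []
    for i in range(0, len(nodes), threads):
        chunk = nodes[i:i + threads]
        results.extend(pool.map_async(execute, chunk).get())
    pool.close()
    pool.join()
    return results

print(run_level(list(range(7)), lambda n: n * n, threads=3))
# [0, 1, 4, 9, 16, 25, 36]
```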
for model in local_models - ] + node_index, + min(node_index + threads, num_nodes_this_batch)): + node = nodes_to_execute[i] + local_nodes.append(node) + + print_start_line(node, + schema_name, + get_idx(node), + num_nodes) + map_result = pool.map_async( - self.safe_execute_model, - wrapped_models_to_execute, + self.safe_execute_node, + local_nodes, callback=on_complete ) map_result.wait() run_model_results = map_result.get() - model_index += threads + node_index += threads pool.close() pool.join() logger.info("") - logger.info(runner.post_run_all_msg(model_results)) - runner.post_run_all(flat_models, model_results, self.context) + logger.info("FIXME") + # logger.info(runner.post_run_all_msg(model_results)) + # runner.post_run_all(flat_models, model_results, self.context) - return model_results + return node_results - def get_nodes_to_run(self, graph, include_spec, exclude_spec, model_type): + def get_nodes_to_run(self, graph, include_spec, exclude_spec, resource_type): if include_spec is None: include_spec = ['*'] if exclude_spec is None: exclude_spec = [] - model_nodes = [ + print('graphnodes') + print(graph.nodes()) + + to_run = [ n for n in graph.nodes() - if graph.node[n]['dbt_run_type'] == model_type + if ((graph.node.get(n, {}).get('resource_type') == resource_type) + and graph.node.get(n, {}).get('empty') == False + and is_enabled(graph.node.get(n, {}))) ] - model_only_graph = graph.subgraph(model_nodes) + type_filtered_graph = graph.subgraph(to_run) selected_nodes = dbt.graph.selector.select_nodes(self.project, - model_only_graph, + type_filtered_graph, include_spec, exclude_spec) return selected_nodes @@ -650,7 +880,6 @@ def try_create_schema(self): raise def run_models_from_graph(self, include_spec, exclude_spec): - runner = ModelRunner(self.project) linker = self.deserialize_graph() selected_nodes = self.get_nodes_to_run( @@ -659,23 +888,15 @@ def run_models_from_graph(self, include_spec, exclude_spec): exclude_spec, dbt.model.NodeType.Model) - compiled_models = self.get_compiled_models( - linker, - selected_nodes, - runner.run_type) - self.try_create_schema() - model_dependency_list = self.as_concurrent_dep_list( + dependency_list = self.as_concurrent_dep_list( linker, - compiled_models - ) + selected_nodes) - on_failure = self.on_model_failure(linker, compiled_models, - selected_nodes) - results = self.execute_models( - runner, model_dependency_list, on_failure - ) + on_failure = self.on_model_failure(linker, selected_nodes) + + results = self.execute_nodes(dependency_list, on_failure) return results @@ -685,42 +906,21 @@ def run_tests_from_graph(self, include_spec, exclude_spec, runner = TestRunner(self.project) linker = self.deserialize_graph() - selected_model_nodes = self.get_nodes_to_run( + selected_nodes = self.get_nodes_to_run( linker.graph, include_spec, exclude_spec, - dbt.model.NodeType.Model) - - # just throw everything in this set, then pick out tests later - nodes_and_neighbors = set() - for model_node in selected_model_nodes: - nodes_and_neighbors.add(model_node) - neighbors = linker.graph.neighbors(model_node) - for neighbor in neighbors: - nodes_and_neighbors.add(neighbor) + dbt.model.NodeType.Test) - compiled_models = self.get_compiled_models( + dependency_list = self.as_concurrent_dep_list( linker, - nodes_and_neighbors, - runner.run_type) - - selected_nodes = set(cm.fqn for cm in compiled_models) + selected_nodes) self.try_create_schema() - all_tests = [] - if test_schemas: - all_tests.extend([cm for cm in compiled_models - if 
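get_nodes_to_run filters the graph down to enabled, non-empty nodes of one resource type before handing the subgraph to the selector. The same filter against a plain networkx graph whose node attributes follow the parsed-node contract ('resource_type', 'empty', 'config'):

```python
import networkx as nx

g = nx.DiGraph()
g.add_node('model.root.users', resource_type='model', empty=False,
           config={'enabled': True})
g.add_node('model.root.scratch', resource_type='model', empty=True,
           config={'enabled': True})
g.add_node('test.root.not_null_users_id', resource_type='test', empty=False,
           config={'enabled': True})

runnable = [
    n for n, data in g.nodes(data=True)
    if data['resource_type'] == 'model'
    and not data['empty']
    and data['config'].get('enabled')
]
print(runnable)  # ['model.root.users']
```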
cm.is_test_type(runner.test_schema_type)]) - - if test_data: - all_tests.extend([cm for cm in compiled_models - if cm.is_test_type(runner.test_data_type)]) - - dep_list = [all_tests] + on_failure = self.on_model_failure(linker, selected_nodes) - on_failure = self.on_model_failure(linker, all_tests, selected_nodes) - results = self.execute_models(runner, dep_list, on_failure) + results = self.execute_nodes(dependency_list, on_failure) return results diff --git a/dbt/task/run.py b/dbt/task/run.py index a06ec84c532..1ddfc88413c 100644 --- a/dbt/task/run.py +++ b/dbt/task/run.py @@ -17,9 +17,9 @@ def compile(self): compiler.initialize() results = compiler.compile() - stat_line = ", ".join([ - "{} {}".format(results[k], k) for k in CompilableEntities - ]) + stat_line = ", ".join( + ["{} {}s".format(ct, t) for t, ct in results.items()]) + logger.info("Compiled {}".format(stat_line)) def run(self): diff --git a/dbt/task/test.py b/dbt/task/test.py index 8875231f59e..4c43beda8b0 100644 --- a/dbt/task/test.py +++ b/dbt/task/test.py @@ -24,17 +24,16 @@ def compile(self): compiler.initialize() results = compiler.compile() - stat_line = ", ".join( - ["{} {}".format(results[k], k) for k in CompilableEntities] - ) + stat_line = ", ".join([ + "{} {}s".format(ct, t) for t, ct in results.items() + ]) logger.info("Compiled {}".format(stat_line)) def run(self): self.compile() runner = RunManager( - self.project, self.project['target-path'], self.args - ) + self.project, self.project['target-path'], self.args) include = self.args.models exclude = self.args.exclude diff --git a/dbt/utils.py b/dbt/utils.py index 8093a6a6619..8e3e341c3e5 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -106,7 +106,7 @@ def find_model_by_name(all_models, target_model_name, for name, model in all_models.items(): resource_type, package_name, model_name = name.split('.') - if (resource_type == 'models' and \ + if (resource_type == 'model' and \ ((target_model_name == model_name) and \ (target_model_package is None or target_model_package == package_name))): diff --git a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py index 3acd4a62a01..c36a146a431 100644 --- a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py @@ -43,7 +43,9 @@ def run_schema_and_assert(self, include, exclude, expected_tests): test_task = TestTask(args, self.project) test_results = test_task.run() - ran_tests = sorted([test.model.name for test in test_results]) + print(test_results) + + ran_tests = sorted([test.node.get('name') for test in test_results]) expected_sorted = sorted(expected_tests) self.assertEqual(ran_tests, expected_sorted) diff --git a/test/integration/008_schema_tests_test/test_schema_tests.py b/test/integration/008_schema_tests_test/test_schema_tests.py index 323423a857d..d53e8aeb5e5 100644 --- a/test/integration/008_schema_tests_test/test_schema_tests.py +++ b/test/integration/008_schema_tests_test/test_schema_tests.py @@ -25,6 +25,7 @@ def run_schema_validations(self): args = FakeArgs() test_task = TestTask(args, project) + print(project) return test_task.run() @attr(type='postgres') diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index 22a148e79e9..7c73be5c357 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -45,15 +45,16 @@ def test__prepend_ctes__already_has_cte(self): 
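The new stats line in the run/test tasks joins per-resource-type counts rather than iterating a fixed list of entity names. In miniature:

```python
results = {'model': 12, 'test': 5}
stat_line = ", ".join(
    "{} {}s".format(ct, t) for t, ct in results.items())
print("Compiled {}".format(stat_line))
# e.g. "Compiled 12 models, 5 tests"
```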
ephemeral_config['materialized'] = 'ephemeral' compiled_models = { - 'models.root.view': { + 'model.root.view': { 'name': 'view', - 'unique_id': 'models.root.view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', 'root_path': '/usr/src/app', 'depends_on': [ - 'models.root.ephemeral' + 'model.root.ephemeral' ], 'config': self.model_config, 'path': 'view.sql', @@ -61,16 +62,17 @@ def test__prepend_ctes__already_has_cte(self): 'compiled': True, 'extra_ctes_injected': False, 'extra_cte_ids': [ - 'models.root.ephemeral' + 'model.root.ephemeral' ], 'extra_cte_sql': [], 'injected_sql': '', 'compiled_sql': ('with cte as (select * from something_else) ' 'select * from __dbt__CTE__ephemeral') }, - 'models.root.ephemeral': { + 'model.root.ephemeral': { 'name': 'ephemeral', - 'unique_id': 'models.root.ephemeral', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', 'fqn': ['root_project', 'ephemeral'], 'empty': False, 'package_name': 'root', @@ -89,10 +91,10 @@ def test__prepend_ctes__already_has_cte(self): } result, all_models = dbt.compilation.prepend_ctes( - compiled_models['models.root.view'], + compiled_models['model.root.view'], compiled_models) - self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result, all_models.get('model.root.view')) self.assertEqual(result.get('extra_ctes_injected'), True) self.assertEqualIgnoreWhitespace( result.get('injected_sql'), @@ -102,14 +104,15 @@ def test__prepend_ctes__already_has_cte(self): 'select * from __dbt__CTE__ephemeral')) self.assertEqual( - all_models.get('models.root.ephemeral').get('extra_ctes_injected'), + all_models.get('model.root.ephemeral').get('extra_ctes_injected'), True) def test__prepend_ctes__no_ctes(self): compiled_models = { - 'models.root.view': { + 'model.root.view': { 'name': 'view', - 'unique_id': 'models.root.view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', @@ -127,9 +130,10 @@ def test__prepend_ctes__no_ctes(self): 'compiled_sql': ('with cte as (select * from something_else) ' 'select * from source_table') }, - 'models.root.view_no_cte': { + 'model.root.view_no_cte': { 'name': 'view_no_cte', - 'unique_id': 'models.root.view_no_cte', + 'resource_type': 'model', + 'unique_id': 'model.root.view_no_cte', 'fqn': ['root_project', 'view_no_cte'], 'empty': False, 'package_name': 'root', @@ -148,24 +152,24 @@ def test__prepend_ctes__no_ctes(self): } result, all_models = dbt.compilation.prepend_ctes( - compiled_models.get('models.root.view'), + compiled_models.get('model.root.view'), compiled_models) - self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result, all_models.get('model.root.view')) self.assertEqual(result.get('extra_ctes_injected'), True) self.assertEqualIgnoreWhitespace( result.get('injected_sql'), - compiled_models.get('models.root.view').get('compiled_sql')) + compiled_models.get('model.root.view').get('compiled_sql')) result, all_models = dbt.compilation.prepend_ctes( - compiled_models.get('models.root.view_no_cte'), + compiled_models.get('model.root.view_no_cte'), compiled_models) - self.assertEqual(result, all_models.get('models.root.view_no_cte')) + self.assertEqual(result, all_models.get('model.root.view_no_cte')) self.assertEqual(result.get('extra_ctes_injected'), True) self.assertEqualIgnoreWhitespace( result.get('injected_sql'), - 
compiled_models.get('models.root.view_no_cte').get('compiled_sql')) + compiled_models.get('model.root.view_no_cte').get('compiled_sql')) def test__prepend_ctes(self): @@ -173,15 +177,16 @@ def test__prepend_ctes(self): ephemeral_config['materialized'] = 'ephemeral' compiled_models = { - 'models.root.view': { + 'model.root.view': { 'name': 'view', - 'unique_id': 'models.root.view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', 'root_path': '/usr/src/app', 'depends_on': [ - 'models.root.ephemeral' + 'model.root.ephemeral' ], 'config': self.model_config, 'path': 'view.sql', @@ -189,15 +194,16 @@ def test__prepend_ctes(self): 'compiled': True, 'extra_ctes_injected': False, 'extra_cte_ids': [ - 'models.root.ephemeral' + 'model.root.ephemeral' ], 'extra_cte_sql': [], 'injected_sql': '', 'compiled_sql': 'select * from __dbt__CTE__ephemeral' }, - 'models.root.ephemeral': { + 'model.root.ephemeral': { 'name': 'ephemeral', - 'unique_id': 'models.root.ephemeral', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', 'fqn': ['root_project', 'ephemeral'], 'empty': False, 'package_name': 'root', @@ -216,10 +222,10 @@ def test__prepend_ctes(self): } result, all_models = dbt.compilation.prepend_ctes( - compiled_models['models.root.view'], + compiled_models['model.root.view'], compiled_models) - self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result, all_models.get('model.root.view')) self.assertEqual(result.get('extra_ctes_injected'), True) self.assertEqualIgnoreWhitespace( result.get('injected_sql'), @@ -229,7 +235,7 @@ def test__prepend_ctes(self): 'select * from __dbt__CTE__ephemeral')) self.assertEqual( - all_models.get('models.root.ephemeral').get('extra_ctes_injected'), + all_models.get('model.root.ephemeral').get('extra_ctes_injected'), True) @@ -238,15 +244,16 @@ def test__prepend_ctes__multiple_levels(self): ephemeral_config['materialized'] = 'ephemeral' compiled_models = { - 'models.root.view': { + 'model.root.view': { 'name': 'view', - 'unique_id': 'models.root.view', + 'resource_type': 'model', + 'unique_id': 'model.root.view', 'fqn': ['root_project', 'view'], 'empty': False, 'package_name': 'root', 'root_path': '/usr/src/app', 'depends_on': [ - 'models.root.ephemeral' + 'model.root.ephemeral' ], 'config': self.model_config, 'path': 'view.sql', @@ -254,15 +261,16 @@ def test__prepend_ctes__multiple_levels(self): 'compiled': True, 'extra_ctes_injected': False, 'extra_cte_ids': [ - 'models.root.ephemeral' + 'model.root.ephemeral' ], 'extra_cte_sql': [], 'injected_sql': '', 'compiled_sql': 'select * from __dbt__CTE__ephemeral' }, - 'models.root.ephemeral': { + 'model.root.ephemeral': { 'name': 'ephemeral', - 'unique_id': 'models.root.ephemeral', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', 'fqn': ['root_project', 'ephemeral'], 'empty': False, 'package_name': 'root', @@ -274,15 +282,16 @@ def test__prepend_ctes__multiple_levels(self): 'compiled': True, 'extra_ctes_injected': False, 'extra_cte_ids': [ - 'models.root.ephemeral_level_two' + 'model.root.ephemeral_level_two' ], 'extra_cte_sql': [], 'injected_sql': '', 'compiled_sql': 'select * from __dbt__CTE__ephemeral_level_two' }, - 'models.root.ephemeral_level_two': { + 'model.root.ephemeral_level_two': { 'name': 'ephemeral_level_two', - 'unique_id': 'models.root.ephemeral_level_two', + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral_level_two', 'fqn': ['root_project', 
'ephemeral_level_two'], 'empty': False, 'package_name': 'root', @@ -302,10 +311,10 @@ def test__prepend_ctes__multiple_levels(self): } result, all_models = dbt.compilation.prepend_ctes( - compiled_models['models.root.view'], + compiled_models['model.root.view'], compiled_models) - self.assertEqual(result, all_models.get('models.root.view')) + self.assertEqual(result, all_models.get('model.root.view')) self.assertEqual(result.get('extra_ctes_injected'), True) self.assertEqualIgnoreWhitespace( result.get('injected_sql'), @@ -317,8 +326,8 @@ def test__prepend_ctes__multiple_levels(self): 'select * from __dbt__CTE__ephemeral')) self.assertEqual( - all_models.get('models.root.ephemeral').get('extra_ctes_injected'), + all_models.get('model.root.ephemeral').get('extra_ctes_injected'), True) self.assertEqual( - all_models.get('models.root.ephemeral_level_two').get('extra_ctes_injected'), + all_models.get('model.root.ephemeral_level_two').get('extra_ctes_injected'), True) diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index ba8b06cfc55..29c0f888a70 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -131,7 +131,7 @@ def test__single_model(self): self.assertEquals( self.graph_result.nodes(), - [('test_models_compile', 'model_one')]) + ['model.root.model_one']) self.assertEquals( self.graph_result.edges(), @@ -149,18 +149,14 @@ def test__two_models_simple_ref(self): six.assertCountEqual(self, self.graph_result.nodes(), [ - ('test_models_compile', 'model_one'), - ('test_models_compile', 'model_two') + 'model.root.model_one', + 'model.root.model_two', ]) - six.assertCountEqual(self, - self.graph_result.edges(), - [ - ( - ('test_models_compile', 'model_one'), - ('test_models_compile', 'model_two') - ) - ]) + six.assertCountEqual( + self, + self.graph_result.edges(), + [ ('model.root.model_one','model.root.model_two',) ]) def test__model_materializations(self): self.use_models({ @@ -194,7 +190,8 @@ def test__model_materializations(self): nodes = self.graph_result.node for model, expected in expected_materialization.items(): - actual = nodes[("test_models_compile", model)]["materialized"] + actual = nodes['model.root.{}'.format(model)].get('config', {}) \ + .get('materialized') self.assertEquals(actual, expected) def test__model_enabled(self): @@ -216,11 +213,13 @@ def test__model_enabled(self): compiler = self.get_compiler(self.get_project(cfg)) compiler.compile() - six.assertCountEqual(self, - self.graph_result.nodes(), - [('test_models_compile', 'model_one')]) + six.assertCountEqual( + self, self.graph_result.nodes(), + ['model.root.model_one', 'model.root.model_two']) - six.assertCountEqual(self, self.graph_result.edges(), []) + six.assertCountEqual( + self, self.graph_result.edges(), + [('model.root.model_one','model.root.model_two',)]) def test__model_incremental_without_sql_where_fails(self): self.use_models({ @@ -261,13 +260,14 @@ def test__model_incremental(self): compiler = self.get_compiler(self.get_project(cfg)) compiler.compile() - node = ('test_models_compile', 'model_one') + node = 'model.root.model_one' self.assertEqual(self.graph_result.nodes(), [node]) self.assertEqual(self.graph_result.edges(), []) self.assertEqual( - self.graph_result.node[node]['materialized'], + self.graph_result.node[node].get('config', {}) \ + .get('materialized'), 'incremental') def test__topological_ordering(self): @@ -288,31 +288,19 @@ def test__topological_ordering(self): six.assertCountEqual(self, self.graph_result.nodes(), [ - ('test_models_compile', 'model_1'), - 
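What the prepend_ctes tests assert, in miniature: an ephemeral dependency's compiled SQL is spliced into its consumer as a CTE named __dbt__CTE__<model>. The SQL strings here are illustrative, taken from the fixtures above:

```python
ephemeral_sql = 'select * from source_table'
consumer_sql = 'select * from __dbt__CTE__ephemeral'

injected = 'with __dbt__CTE__ephemeral as ({}) {}'.format(
    ephemeral_sql, consumer_sql)
print(injected)
# with __dbt__CTE__ephemeral as (select * from source_table) select * ...
```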
('test_models_compile', 'model_2'), - ('test_models_compile', 'model_3'), - ('test_models_compile', 'model_4') + 'model.root.model_1', + 'model.root.model_2', + 'model.root.model_3', + 'model.root.model_4', ]) six.assertCountEqual(self, self.graph_result.edges(), [ - ( - ('test_models_compile', 'model_1'), - ('test_models_compile', 'model_2') - ), - ( - ('test_models_compile', 'model_1'), - ('test_models_compile', 'model_3') - ), - ( - ('test_models_compile', 'model_2'), - ('test_models_compile', 'model_3') - ), - ( - ('test_models_compile', 'model_3'), - ('test_models_compile', 'model_4') - ) + ('model.root.model_1', 'model.root.model_2',), + ('model.root.model_1', 'model.root.model_3',), + ('model.root.model_2', 'model.root.model_3',), + ('model.root.model_3', 'model.root.model_4',), ]) linker = dbt.linker.Linker() @@ -320,10 +308,10 @@ def test__topological_ordering(self): actual_ordering = linker.as_topological_ordering() expected_ordering = [ - ('test_models_compile', 'model_1'), - ('test_models_compile', 'model_2'), - ('test_models_compile', 'model_3'), - ('test_models_compile', 'model_4') + 'model.root.model_1', + 'model.root.model_2', + 'model.root.model_3', + 'model.root.model_4', ] self.assertEqual(actual_ordering, expected_ordering) @@ -349,18 +337,10 @@ def test__dependency_list(self): actual_dep_list = linker.as_dependency_list() expected_dep_list = [ - [ - ('test_models_compile', 'model_1') - ], - [ - ('test_models_compile', 'model_2') - ], - [ - ('test_models_compile', 'model_3') - ], - [ - ('test_models_compile', 'model_4'), - ] + ['model.root.model_1'], + ['model.root.model_2'], + ['model.root.model_3'], + ['model.root.model_4'], ] self.assertEqual(actual_dep_list, expected_dep_list) diff --git a/test/unit/test_graph_selection.py b/test/unit/test_graph_selection.py index 8c68a9da43f..fa37f94a071 100644 --- a/test/unit/test_graph_selection.py +++ b/test/unit/test_graph_selection.py @@ -12,18 +12,12 @@ class GraphSelectionTest(unittest.TestCase): def setUp(self): integer_graph = nx.balanced_tree(2, 2, nx.DiGraph()) - simple_mapping = { - i: letter for (i, letter) in enumerate(string.ascii_lowercase) - } package_mapping = { - i: ('X' if i % 2 == 0 else 'Y', letter) + i: 'm.' + ('X' if i % 2 == 0 else 'Y') + '.' 
+ letter for (i, letter) in enumerate(string.ascii_lowercase) } - # Edges: [(a, b), (a, c), (b, d), (b, e), (c, f), (c, g)] - self.simple_graph = nx.relabel_nodes(integer_graph, simple_mapping) - # Edges: [(X.a, Y.b), (X.a, X.c), (Y.b, Y.d), (Y.b, X.e), (X.c, Y.f), (X.c, X.g)] self.package_graph = nx.relabel_nodes(integer_graph, package_mapping) @@ -78,37 +72,13 @@ def run_specs_and_assert(self, graph, include, exclude, expected): self.assertEquals(selected, expected) - # Test the select_nodes() interface - def test__single_node_selection(self): - self.run_specs_and_assert(self.simple_graph, ['a'], [], set('a')) - - def test__node_and_children(self): - self.run_specs_and_assert(self.simple_graph, ['a+'], [], set('abcdefg')) - - def test__node_and_parents(self): - self.run_specs_and_assert(self.simple_graph, ['+g'], [], set('acg')) - - def test__node_and_children_and_parents(self): - self.run_specs_and_assert(self.simple_graph, ['+c+'], [], set('acfg')) - - def test__node_and_children_and_parents_except_one(self): - self.run_specs_and_assert(self.simple_graph, ['+c+'], ['c'], set('afg')) - - def test__node_and_children_and_parents_except_many(self): - self.run_specs_and_assert(self.simple_graph, ['+c+'], ['+f'], set('g')) - - def test__multiple_node_selection(self): - self.run_specs_and_assert(self.simple_graph, ['a', 'b'], [], set('ab')) - - def test__multiple_node_selection_mixed(self): - self.run_specs_and_assert(self.simple_graph, ['a+', 'b+'], ['b', '+c'], set('defg')) def test__single_node_selection_in_package(self): self.run_specs_and_assert( self.package_graph, ['X.a'], [], - set([('X', 'a')]) + set(['m.X.a']) ) def test__multiple_node_selection_in_package(self): @@ -116,7 +86,7 @@ def test__multiple_node_selection_in_package(self): self.package_graph, ['X.a', 'b'], [], - set([('X', 'a'), ('Y', 'b')]) + set(['m.X.a', 'm.Y.b']) ) def test__select_children_except_in_package(self): @@ -124,16 +94,7 @@ def test__select_children_except_in_package(self): self.package_graph, ['X.a+'], ['b'], - set([ - ('X', 'a'), - # ('Y', 'b'), - ('X', 'c'), - ('Y', 'd'), - ('X', 'e'), - ('Y', 'f'), - ('X', 'g') - ]) - ) + set(['m.X.a','m.X.c', 'm.Y.d','m.X.e','m.Y.f','m.X.g'])) def parse_spec_and_assert(self, spec, parents, children, qualified_node_name): parsed = graph_selector.parse_spec(spec) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index b65293f56dd..f500ff5594b 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -20,7 +20,7 @@ def setUp(self): self.maxDiff = None self.root_project_config = { - 'name': 'root_project', + 'name': 'root', 'version': '0.1', 'profile': 'test', 'project-root': os.path.abspath('.'), @@ -43,6 +43,7 @@ def setUp(self): def test__single_model(self): models = [{ 'name': 'model_one', + 'resource_type': 'model', 'package_name': 'root', 'root_path': '/usr/src/app', 'path': 'model_one.sql', @@ -50,15 +51,17 @@ def test__single_model(self): }] self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config, 'snowplow': self.snowplow_project_config}), { - 'models.root.model_one': { + 'model.root.model_one': { 'name': 'model_one', - 'unique_id': 'models.root.model_one', - 'fqn': ['root_project', 'model_one'], + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'model_one'], 'empty': False, 'package_name': 'root', 'root_path': '/usr/src/app', @@ -74,6 +77,7 @@ def test__single_model(self): def test__empty_model(self): models = [{ 
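The package-qualified selection semantics exercised above, restated: a spec like 'X.a+' selects m.X.a plus all of its graph descendants, and exclude specs are subtracted afterwards. A sketch over the same balanced tree the tests build:

```python
import networkx as nx

g = nx.DiGraph([('m.X.a', 'm.Y.b'), ('m.X.a', 'm.X.c'),
                ('m.Y.b', 'm.Y.d'), ('m.Y.b', 'm.X.e'),
                ('m.X.c', 'm.Y.f'), ('m.X.c', 'm.X.g')])

selected = {'m.X.a'} | nx.descendants(g, 'm.X.a')   # include spec 'X.a+'
excluded = {'m.Y.b'}                                # exclude spec 'b'
print(sorted(selected - excluded))
# ['m.X.a', 'm.X.c', 'm.X.e', 'm.X.g', 'm.Y.d', 'm.Y.f']
```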
'name': 'model_one', + 'resource_type': 'model', 'package_name': 'root', 'path': 'model_one.sql', 'root_path': '/usr/src/app', @@ -81,14 +85,16 @@ def test__empty_model(self): }] self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config}), { - 'models.root.model_one': { + 'model.root.model_one': { 'name': 'model_one', - 'unique_id': 'models.root.model_one', - 'fqn': ['root_project', 'model_one'], + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'model_one'], 'empty': True, 'package_name': 'root', 'depends_on': [], @@ -104,12 +110,14 @@ def test__empty_model(self): def test__simple_dependency(self): models = [{ 'name': 'base', + 'resource_type': 'model', 'package_name': 'root', 'path': 'base.sql', 'root_path': '/usr/src/app', 'raw_sql': 'select * from events' }, { 'name': 'events_tx', + 'resource_type': 'model', 'package_name': 'root', 'path': 'events_tx.sql', 'root_path': '/usr/src/app', @@ -117,15 +125,17 @@ def test__simple_dependency(self): }] self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config, 'snowplow': self.snowplow_project_config}), { - 'models.root.base': { + 'model.root.base': { 'name': 'base', - 'unique_id': 'models.root.base', - 'fqn': ['root_project', 'base'], + 'resource_type': 'model', + 'unique_id': 'model.root.base', + 'fqn': ['root', 'base'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -135,10 +145,11 @@ def test__simple_dependency(self): 'raw_sql': self.find_input_by_name( models, 'base').get('raw_sql') }, - 'models.root.events_tx': { + 'model.root.events_tx': { 'name': 'events_tx', - 'unique_id': 'models.root.events_tx', - 'fqn': ['root_project', 'events_tx'], + 'resource_type': 'model', + 'unique_id': 'model.root.events_tx', + 'fqn': ['root', 'events_tx'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -154,18 +165,21 @@ def test__simple_dependency(self): def test__multiple_dependencies(self): models = [{ 'name': 'events', + 'resource_type': 'model', 'package_name': 'root', 'path': 'events.sql', 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.events', }, { 'name': 'sessions', + 'resource_type': 'model', 'package_name': 'root', 'path': 'sessions.sql', 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.sessions', }, { 'name': 'events_tx', + 'resource_type': 'model', 'package_name': 'root', 'path': 'events_tx.sql', 'root_path': '/usr/src/app', @@ -173,6 +187,7 @@ def test__multiple_dependencies(self): "select * from events"), }, { 'name': 'sessions_tx', + 'resource_type': 'model', 'package_name': 'root', 'path': 'sessions_tx.sql', 'root_path': '/usr/src/app', @@ -180,6 +195,7 @@ def test__multiple_dependencies(self): "select * from sessions"), }, { 'name': 'multi', + 'resource_type': 'model', 'package_name': 'root', 'path': 'multi.sql', 'root_path': '/usr/src/app', @@ -189,15 +205,17 @@ def test__multiple_dependencies(self): }] self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config, 'snowplow': self.snowplow_project_config}), { - 'models.root.events': { + 'model.root.events': { 'name': 'events', - 'unique_id': 'models.root.events', - 'fqn': ['root_project', 'events'], + 'resource_type': 'model', + 'unique_id': 'model.root.events', + 'fqn': ['root', 'events'], 'empty': False, 'package_name': 'root', 'depends_on': [], 
@@ -207,10 +225,11 @@ def test__multiple_dependencies(self): 'raw_sql': self.find_input_by_name( models, 'events').get('raw_sql') }, - 'models.root.sessions': { + 'model.root.sessions': { 'name': 'sessions', - 'unique_id': 'models.root.sessions', - 'fqn': ['root_project', 'sessions'], + 'resource_type': 'model', + 'unique_id': 'model.root.sessions', + 'fqn': ['root', 'sessions'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -220,10 +239,11 @@ def test__multiple_dependencies(self): 'raw_sql': self.find_input_by_name( models, 'sessions').get('raw_sql') }, - 'models.root.events_tx': { + 'model.root.events_tx': { 'name': 'events_tx', - 'unique_id': 'models.root.events_tx', - 'fqn': ['root_project', 'events_tx'], + 'resource_type': 'model', + 'unique_id': 'model.root.events_tx', + 'fqn': ['root', 'events_tx'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -233,10 +253,11 @@ def test__multiple_dependencies(self): 'raw_sql': self.find_input_by_name( models, 'events_tx').get('raw_sql') }, - 'models.root.sessions_tx': { + 'model.root.sessions_tx': { 'name': 'sessions_tx', - 'unique_id': 'models.root.sessions_tx', - 'fqn': ['root_project', 'sessions_tx'], + 'resource_type': 'model', + 'unique_id': 'model.root.sessions_tx', + 'fqn': ['root', 'sessions_tx'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -246,10 +267,11 @@ def test__multiple_dependencies(self): 'raw_sql': self.find_input_by_name( models, 'sessions_tx').get('raw_sql') }, - 'models.root.multi': { + 'model.root.multi': { 'name': 'multi', - 'unique_id': 'models.root.multi', - 'fqn': ['root_project', 'multi'], + 'resource_type': 'model', + 'unique_id': 'model.root.multi', + 'fqn': ['root', 'multi'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -265,18 +287,21 @@ def test__multiple_dependencies(self): def test__multiple_dependencies__packages(self): models = [{ 'name': 'events', + 'resource_type': 'model', 'package_name': 'snowplow', 'path': 'events.sql', 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.events', }, { 'name': 'sessions', + 'resource_type': 'model', 'package_name': 'snowplow', 'path': 'sessions.sql', 'root_path': '/usr/src/app', 'raw_sql': 'select * from base.sessions', }, { 'name': 'events_tx', + 'resource_type': 'model', 'package_name': 'snowplow', 'path': 'events_tx.sql', 'root_path': '/usr/src/app', @@ -284,6 +309,7 @@ def test__multiple_dependencies__packages(self): "select * from events"), }, { 'name': 'sessions_tx', + 'resource_type': 'model', 'package_name': 'snowplow', 'path': 'sessions_tx.sql', 'root_path': '/usr/src/app', @@ -291,6 +317,7 @@ def test__multiple_dependencies__packages(self): "select * from sessions"), }, { 'name': 'multi', + 'resource_type': 'model', 'package_name': 'root', 'path': 'multi.sql', 'root_path': '/usr/src/app', @@ -300,14 +327,16 @@ def test__multiple_dependencies__packages(self): }] self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config, 'snowplow': self.snowplow_project_config}), { - 'models.snowplow.events': { + 'model.snowplow.events': { 'name': 'events', - 'unique_id': 'models.snowplow.events', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.events', 'fqn': ['snowplow', 'events'], 'empty': False, 'package_name': 'snowplow', @@ -318,9 +347,10 @@ def test__multiple_dependencies__packages(self): 'raw_sql': self.find_input_by_name( models, 'events').get('raw_sql') }, - 'models.snowplow.sessions': { + 
'model.snowplow.sessions': { 'name': 'sessions', - 'unique_id': 'models.snowplow.sessions', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.sessions', 'fqn': ['snowplow', 'sessions'], 'empty': False, 'package_name': 'snowplow', @@ -331,9 +361,10 @@ def test__multiple_dependencies__packages(self): 'raw_sql': self.find_input_by_name( models, 'sessions').get('raw_sql') }, - 'models.snowplow.events_tx': { + 'model.snowplow.events_tx': { 'name': 'events_tx', - 'unique_id': 'models.snowplow.events_tx', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.events_tx', 'fqn': ['snowplow', 'events_tx'], 'empty': False, 'package_name': 'snowplow', @@ -344,9 +375,10 @@ def test__multiple_dependencies__packages(self): 'raw_sql': self.find_input_by_name( models, 'events_tx').get('raw_sql') }, - 'models.snowplow.sessions_tx': { + 'model.snowplow.sessions_tx': { 'name': 'sessions_tx', - 'unique_id': 'models.snowplow.sessions_tx', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.sessions_tx', 'fqn': ['snowplow', 'sessions_tx'], 'empty': False, 'package_name': 'snowplow', @@ -357,10 +389,11 @@ def test__multiple_dependencies__packages(self): 'raw_sql': self.find_input_by_name( models, 'sessions_tx').get('raw_sql') }, - 'models.root.multi': { + 'model.root.multi': { 'name': 'multi', - 'unique_id': 'models.root.multi', - 'fqn': ['root_project', 'multi'], + 'resource_type': 'model', + 'unique_id': 'model.root.multi', + 'fqn': ['root', 'multi'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -376,6 +409,7 @@ def test__multiple_dependencies__packages(self): def test__in_model_config(self): models = [{ 'name': 'model_one', + 'resource_type': 'model', 'package_name': 'root', 'path': 'model_one.sql', 'root_path': '/usr/src/app', @@ -388,15 +422,17 @@ def test__in_model_config(self): }) self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config, 'snowplow': self.snowplow_project_config}), { - 'models.root.model_one': { + 'model.root.model_one': { 'name': 'model_one', - 'unique_id': 'models.root.model_one', - 'fqn': ['root_project', 'model_one'], + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'model_one'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -411,13 +447,13 @@ def test__in_model_config(self): def test__root_project_config(self): self.root_project_config = { - 'name': 'root_project', + 'name': 'root', 'version': '0.1', 'profile': 'test', 'project-root': os.path.abspath('.'), 'models': { 'materialized': 'ephemeral', - 'root_project': { + 'root': { 'view': { 'materialized': 'view' } @@ -427,6 +463,7 @@ def test__root_project_config(self): models = [{ 'name': 'table', + 'resource_type': 'model', 'package_name': 'root', 'path': 'table.sql', 'root_path': '/usr/src/app', @@ -434,12 +471,14 @@ def test__root_project_config(self): "select * from events"), }, { 'name': 'ephemeral', + 'resource_type': 'model', 'package_name': 'root', 'path': 'ephemeral.sql', 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'view', + 'resource_type': 'model', 'package_name': 'root', 'path': 'view.sql', 'root_path': '/usr/src/app', @@ -461,15 +500,17 @@ def test__root_project_config(self): }) self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config, 'snowplow': self.snowplow_project_config}), { - 'models.root.table': { + 'model.root.table': { 
'name': 'table', - 'unique_id': 'models.root.table', - 'fqn': ['root_project', 'table'], + 'resource_type': 'model', + 'unique_id': 'model.root.table', + 'fqn': ['root', 'table'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -479,10 +520,11 @@ def test__root_project_config(self): 'raw_sql': self.find_input_by_name( models, 'table').get('raw_sql') }, - 'models.root.ephemeral': { + 'model.root.ephemeral': { 'name': 'ephemeral', - 'unique_id': 'models.root.ephemeral', - 'fqn': ['root_project', 'ephemeral'], + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', + 'fqn': ['root', 'ephemeral'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -492,10 +534,11 @@ def test__root_project_config(self): 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') }, - 'models.root.view': { + 'model.root.view': { 'name': 'view', - 'unique_id': 'models.root.view', - 'fqn': ['root_project', 'view'], + 'resource_type': 'model', + 'unique_id': 'model.root.view', + 'fqn': ['root', 'view'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -511,13 +554,13 @@ def test__root_project_config(self): def test__other_project_config(self): self.root_project_config = { - 'name': 'root_project', + 'name': 'root', 'version': '0.1', 'profile': 'test', 'project-root': os.path.abspath('.'), 'models': { 'materialized': 'ephemeral', - 'root_project': { + 'root': { 'view': { 'materialized': 'view' } @@ -546,6 +589,7 @@ def test__other_project_config(self): models = [{ 'name': 'table', + 'resource_type': 'model', 'package_name': 'root', 'path': 'table.sql', 'root_path': '/usr/src/app', @@ -553,24 +597,28 @@ def test__other_project_config(self): "select * from events"), }, { 'name': 'ephemeral', + 'resource_type': 'model', 'package_name': 'root', 'path': 'ephemeral.sql', 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'view', + 'resource_type': 'model', 'package_name': 'root', 'path': 'view.sql', 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'disabled', + 'resource_type': 'model', 'package_name': 'snowplow', 'path': 'disabled.sql', 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }, { 'name': 'package', + 'resource_type': 'model', 'package_name': 'snowplow', 'path': 'models/views/package.sql', 'root_path': '/usr/src/app', @@ -604,15 +652,17 @@ def test__other_project_config(self): }) self.assertEquals( - dbt.parser.parse_models( + dbt.parser.parse_sql_nodes( models, + self.root_project_config, {'root': self.root_project_config, 'snowplow': self.snowplow_project_config}), { - 'models.root.table': { + 'model.root.table': { 'name': 'table', - 'unique_id': 'models.root.table', - 'fqn': ['root_project', 'table'], + 'resource_type': 'model', + 'unique_id': 'model.root.table', + 'fqn': ['root', 'table'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -622,10 +672,11 @@ def test__other_project_config(self): 'raw_sql': self.find_input_by_name( models, 'table').get('raw_sql') }, - 'models.root.ephemeral': { + 'model.root.ephemeral': { 'name': 'ephemeral', - 'unique_id': 'models.root.ephemeral', - 'fqn': ['root_project', 'ephemeral'], + 'resource_type': 'model', + 'unique_id': 'model.root.ephemeral', + 'fqn': ['root', 'ephemeral'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -635,10 +686,11 @@ def test__other_project_config(self): 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') }, - 'models.root.view': { + 'model.root.view': { 'name': 'view', 
- 'unique_id': 'models.root.view', - 'fqn': ['root_project', 'view'], + 'resource_type': 'model', + 'unique_id': 'model.root.view', + 'fqn': ['root', 'view'], 'empty': False, 'package_name': 'root', 'depends_on': [], @@ -648,9 +700,10 @@ def test__other_project_config(self): 'raw_sql': self.find_input_by_name( models, 'view').get('raw_sql') }, - 'models.snowplow.disabled': { + 'model.snowplow.disabled': { 'name': 'disabled', - 'unique_id': 'models.snowplow.disabled', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.disabled', 'fqn': ['snowplow', 'disabled'], 'empty': False, 'package_name': 'snowplow', @@ -661,9 +714,10 @@ def test__other_project_config(self): 'raw_sql': self.find_input_by_name( models, 'disabled').get('raw_sql') }, - 'models.snowplow.package': { + 'model.snowplow.package': { 'name': 'package', - 'unique_id': 'models.snowplow.package', + 'resource_type': 'model', + 'unique_id': 'model.snowplow.package', 'fqn': ['snowplow', 'views', 'package'], 'empty': False, 'package_name': 'snowplow', @@ -676,3 +730,140 @@ def test__other_project_config(self): } } ) + + def test__simple_schema_test(self): + tests = [{ + 'name': 'test_one', + 'resource_type': 'test', + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'path': 'test_one.yml', + 'raw_sql': None, + 'raw_yml': ('{model_one: {constraints: {not_null: [id],' + 'unique: [id],' + 'accepted_values: [{field: id, values: ["a","b"]}],' + 'relationships: [{from: id, to: model_two, field: id}]' + '}}}') + }] + + not_null_sql = dbt.parser.QUERY_VALIDATE_NOT_NULL \ + .format( + field='id', + ref="{{ref('model_one')}}") + + unique_sql = dbt.parser.QUERY_VALIDATE_UNIQUE \ + .format( + field='id', + ref="{{ref('model_one')}}") + + accepted_values_sql = dbt.parser.QUERY_VALIDATE_ACCEPTED_VALUES \ + .format( + field='id', + ref="{{ref('model_one')}}", + values_csv="'a','b'") + + relationships_sql = dbt.parser.QUERY_VALIDATE_REFERENTIAL_INTEGRITY \ + .format( + parent_field='id', + parent_ref="{{ref('model_two')}}", + child_field='id', + child_ref="{{ref('model_one')}}") + + self.assertEquals( + dbt.parser.parse_schema_tests( + tests, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'test.root.not_null_model_one_id': { + 'name': 'not_null_model_one_id', + 'resource_type': 'test', + 'unique_id': 'test.root.not_null_model_one_id', + 'fqn': ['root', 'schema', 'test_one'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': self.model_config, + 'path': 'test_one.yml', + 'raw_sql': not_null_sql, + }, + 'test.root.unique_model_one_id': { + 'name': 'unique_model_one_id', + 'resource_type': 'test', + 'unique_id': 'test.root.unique_model_one_id', + 'fqn': ['root', 'schema', 'test_one'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': self.model_config, + 'path': 'test_one.yml', + 'raw_sql': unique_sql, + }, + 'test.root.accepted_values_model_one_id': { + 'name': 'accepted_values_model_one_id', + 'resource_type': 'test', + 'unique_id': 'test.root.accepted_values_model_one_id', + 'fqn': ['root', 'schema', 'test_one'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': self.model_config, + 'path': 'test_one.yml', + 'raw_sql': accepted_values_sql, + }, + 'test.root.relationships_model_one_id_to_model_two_id': { + 'name': 'relationships_model_one_id_to_model_two_id', + 'resource_type': 'test', + 'unique_id': 
+                        'test.root.relationships_model_one_id_to_model_two_id',
+                    'fqn': ['root', 'schema', 'test_one'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'root_path': '/usr/src/app',
+                    'depends_on': [],
+                    'config': self.model_config,
+                    'path': 'test_one.yml',
+                    'raw_sql': relationships_sql,
+                }
+
+
+            }
+        )
+
+
+    def test__simple_data_test(self):
+        tests = [{
+            'name': 'no_events',
+            'resource_type': 'test',
+            'package_name': 'root',
+            'path': 'no_events.sql',
+            'root_path': '/usr/src/app',
+            'raw_sql': "select * from {{ref('base')}}"
+        }]
+
+        self.assertEquals(
+            dbt.parser.parse_sql_nodes(
+                tests,
+                self.root_project_config,
+                {'root': self.root_project_config,
+                 'snowplow': self.snowplow_project_config}),
+            {
+                'test.root.no_events': {
+                    'name': 'no_events',
+                    'resource_type': 'test',
+                    'unique_id': 'test.root.no_events',
+                    'fqn': ['root', 'no_events'],
+                    'empty': False,
+                    'package_name': 'root',
+                    'depends_on': [],
+                    'config': self.model_config,
+                    'path': 'no_events.sql',
+                    'root_path': '/usr/src/app',
+                    'raw_sql': self.find_input_by_name(
+                        tests, 'no_events').get('raw_sql')
+                }
+            }
+        )
diff --git a/test/unit/test_runner.py b/test/unit/test_runner.py
new file mode 100644
index 00000000000..16fbbe20b6c
--- /dev/null
+++ b/test/unit/test_runner.py
@@ -0,0 +1,269 @@
+from mock import MagicMock, patch
+import unittest
+
+import os
+
+import dbt.flags
+import dbt.parser
+import dbt.runner
+
+
+class TestRunner(unittest.TestCase):
+
+    def setUp(self):
+        dbt.flags.STRICT_MODE = True
+        dbt.flags.NON_DESTRUCTIVE = True
+
+        self.profile = {
+            'type': 'postgres',
+            'dbname': 'postgres',
+            'user': 'root',
+            'host': 'database',
+            'pass': 'password123',
+            'port': 5432,
+            'schema': 'public'
+        }
+
+        self.model_config = {
+            'enabled': True,
+            'materialized': 'view',
+            'post-hook': [],
+            'pre-hook': [],
+            'vars': {},
+        }
+
+        self.model = {
+            'name': 'view',
+            'resource_type': 'model',
+            'unique_id': 'model.root.view',
+            'fqn': ['root_project', 'view'],
+            'empty': False,
+            'package_name': 'root',
+            'root_path': '/usr/src/app',
+            'depends_on': [
+                'model.root.ephemeral'
+            ],
+            'config': self.model_config,
+            'path': 'view.sql',
+            'raw_sql': 'select * from {{ref("ephemeral")}}',
+            'compiled': True,
+            'extra_ctes_injected': False,
+            'extra_cte_ids': [
+                'model.root.ephemeral'
+            ],
+            'extra_cte_sql': [],
+            'compiled_sql': 'select * from __dbt__CTE__ephemeral',
+            'injected_sql': ('with __dbt__CTE__ephemeral as ('
+                             'select * from "public"."ephemeral"'
+                             ')'
+                             'select * from __dbt__CTE__ephemeral'),
+            'wrapped_sql': ('create view "public"."view" as ('
+                            'with __dbt__CTE__ephemeral as ('
+                            'select * from "public"."ephemeral"'
+                            ')'
+                            'select * from __dbt__CTE__ephemeral'
+                            '))')
+        }
+
+        self.existing = {}
+
+        def fake_drop(profile, relation, relation_type, model_name):
+            del self.existing[relation]
+
+        def fake_query_for_existing(profile, schema):
+            return self.existing
+
+        self._drop = dbt.adapters.postgres.PostgresAdapter.drop
+        self._query_for_existing = \
+            dbt.adapters.postgres.PostgresAdapter.query_for_existing
+
+        dbt.adapters.postgres.PostgresAdapter.drop = MagicMock(
+            side_effect=fake_drop)
+
+        dbt.adapters.postgres.PostgresAdapter.query_for_existing = MagicMock(
+            side_effect=fake_query_for_existing)
+
+    def tearDown(self):
+        dbt.adapters.postgres.PostgresAdapter.drop = self._drop
+        dbt.adapters.postgres.PostgresAdapter.query_for_existing = \
+            self._query_for_existing
+
+
+
+    @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1)
+    @patch('dbt.adapters.postgres.PostgresAdapter.rename',
return_value=None) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def test__execute_model__view(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + model = self.model.copy() + + dbt.runner.execute_model( + self.profile, + model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() + + mock_adapter_truncate.assert_not_called() + mock_adapter_rename.assert_not_called() + mock_adapter_execute_model.assert_called_once() + + + @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def test__execute_model__view__existing(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + self.existing = {'view': 'view'} + model = self.model.copy() + + dbt.runner.execute_model( + self.profile, + model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() + + mock_adapter_truncate.assert_not_called() + mock_adapter_rename.assert_not_called() + mock_adapter_execute_model.assert_not_called() + + + @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def test__execute_model__table(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + model = self.model.copy() + model['config']['materialized'] = 'table' + + dbt.runner.execute_model( + self.profile, + model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() + + mock_adapter_truncate.assert_not_called() + mock_adapter_rename.assert_not_called() + mock_adapter_execute_model.assert_called_once() + + + @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def test__execute_model__table__existing(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + self.existing = {'view': 'table'} + + model = self.model.copy() + model['config']['materialized'] = 'table' + + dbt.runner.execute_model( + self.profile, + self.model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() + + mock_adapter_truncate.assert_called_once() + mock_adapter_rename.assert_not_called() + mock_adapter_execute_model.assert_called_once() + + + + @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=None) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def test__execute_model__view__destructive(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + dbt.flags.NON_DESTRUCTIVE = False + + model = self.model.copy() + + dbt.runner.execute_model( + self.profile, + model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() + + mock_adapter_truncate.assert_not_called() + mock_adapter_rename.assert_called_once() + mock_adapter_execute_model.assert_called_once() + + + @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def 
test__execute_model__view__existing__destructive(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + dbt.flags.NON_DESTRUCTIVE = False + + self.existing = {'view': 'view'} + model = self.model.copy() + + dbt.runner.execute_model( + self.profile, + model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_called_once() + + mock_adapter_truncate.assert_not_called() + mock_adapter_rename.assert_called_once() + mock_adapter_execute_model.assert_called_once() + + + @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def test__execute_model__table__destructive(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + dbt.flags.NON_DESTRUCTIVE = False + + model = self.model.copy() + model['config']['materialized'] = 'table' + + dbt.runner.execute_model( + self.profile, + model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() + + mock_adapter_truncate.assert_not_called() + mock_adapter_rename.assert_called_once() + mock_adapter_execute_model.assert_called_once() + + + @patch('dbt.adapters.postgres.PostgresAdapter.execute_model', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.rename', return_value=1) + @patch('dbt.adapters.postgres.PostgresAdapter.truncate', return_value=None) + def test__execute_model__table__existing__destructive(self, + mock_adapter_truncate, + mock_adapter_rename, + mock_adapter_execute_model): + dbt.flags.NON_DESTRUCTIVE = False + + self.existing = {'view': 'table'} + + model = self.model.copy() + model['config']['materialized'] = 'table' + + dbt.runner.execute_model( + self.profile, + self.model) + + dbt.adapters.postgres.PostgresAdapter.drop.assert_called_once() + + mock_adapter_truncate.assert_not_called() + mock_adapter_rename.assert_called_once() + mock_adapter_execute_model.assert_called_once() diff --git a/tox.ini b/tox.ini index 92df45426a0..32957b36fb3 100644 --- a/tox.ini +++ b/tox.ini @@ -48,7 +48,7 @@ basepython = python3.6 passenv = * setenv = HOME=/root/ -commands = /bin/bash -c '{envpython} $(which nosetests) -v -a type=postgres {posargs} --with-coverage --cover-branches --cover-html --cover-html-dir=htmlcov test/integration/*' +commands = /bin/bash -c '{envpython} $(which nosetests) -v -a type=postgres {posargs} test/integration/*' deps = -r{toxinidir}/requirements.txt -r{toxinidir}/dev_requirements.txt From a12504365dc67c987ba3d3c2efbfe21ac9f6b2e8 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sun, 26 Feb 2017 15:25:35 -0500 Subject: [PATCH 06/25] archive passing, hooks not so much --- dbt/compilation.py | 26 ++- dbt/contracts/graph/parsed.py | 3 + dbt/graph/selector.py | 7 - dbt/linker.py | 5 + dbt/model.py | 1 + dbt/parser.py | 71 ++++++- dbt/runner.py | 189 ++++++++++-------- dbt/task/archive.py | 10 +- dbt/task/test.py | 9 +- dbt/templates.py | 16 +- dbt/utils.py | 21 +- .../test_simple_dependency_with_configs.py | 6 +- test/integration/base.py | 2 + 13 files changed, 235 insertions(+), 131 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index 535ca480a60..03dc3507aa2 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -404,10 +404,16 @@ def compile_nodes(self, linker, nodes): for name, injected_node in injected_nodes.items(): # now turn model nodes back into the old-style model object for # wrapping - if injected_node.get('resource_type') != 
NodeType.Model: - # don't wrap thing that aren't models, i.e. tests. + if injected_node.get('resource_type') == NodeType.Test: + # don't wrap tests. injected_node['wrapped_sql'] = injected_node['injected_sql'] wrapped_nodes[name] = injected_node + + elif injected_node.get('resource_type') == NodeType.Archive: + # unfortunately we do everything automagically for + # archives. in the future it'd be nice to generate + # the SQL at the parser level. + pass else: model = Model( self.project, @@ -427,7 +433,8 @@ def compile_nodes(self, linker, nodes): build_path = os.path.join('build', injected_node.get('path')) - if injected_node.get('config', {}) \ + if injected_node.get('resource_type') == NodeType.Model and \ + injected_node.get('config', {}) \ .get('materialized') != 'ephemeral': self.__write(build_path, wrapped_stmt) written_nodes.append(injected_node) @@ -547,7 +554,8 @@ def get_parsed_data_tests(self, root_project, all_projects): all_projects=all_projects, root_dir=project.get('project-root'), relative_dirs=project.get('test-paths', []), - resource_type=NodeType.Test)) + resource_type=NodeType.Test, + tags=['data'])) return parsed_tests @@ -556,9 +564,6 @@ def get_parsed_schema_tests(self, root_project, all_projects): parsed_tests = {} for name, project in all_projects.items(): - print('project') - print(project) - parsed_tests.update( dbt.parser.load_and_parse_yml( package_name=name, @@ -578,6 +583,9 @@ def load_all_nodes(self, root_project, all_projects): self.get_parsed_data_tests(root_project, all_projects)) all_nodes.update( self.get_parsed_schema_tests(root_project, all_projects)) + all_nodes.update( + dbt.parser.parse_archives_from_projects(root_project, + all_projects)) return all_nodes @@ -589,10 +597,6 @@ def compile(self): all_projects = self.get_all_projects() all_nodes = self.load_all_nodes(root_project, all_projects) - - print('all_nodes') - print(all_nodes.keys()) - all_macros = self.get_macros(this_project=self.project) for project in dbt.utils.dependency_projects(self.project): diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py index 89c7dfeb8b7..b70b41a0422 100644 --- a/dbt/contracts/graph/parsed.py +++ b/dbt/contracts/graph/parsed.py @@ -13,6 +13,8 @@ Required('post-hook'): list, Required('pre-hook'): list, Required('vars'): dict, + + # incremental optional fields Optional('sql_where'): str, Optional('unique_key'): str, } @@ -26,6 +28,7 @@ Required('depends_on'): All(list, [All(str, Length(min=1, max=255))]), Required('empty'): bool, Required('config'): config_contract, + Required('tags'): All(list, [str]), }) def validate_one(parsed_graph_item): diff --git a/dbt/graph/selector.py b/dbt/graph/selector.py index 4f25be66d79..c017d3bcd78 100644 --- a/dbt/graph/selector.py +++ b/dbt/graph/selector.py @@ -43,8 +43,6 @@ def parse_spec(node_spec): def get_package_names(graph): - print('get_package_names') - print([node.split(".")[1] for node in graph.nodes()]) return set([node.split(".")[1] for node in graph.nodes()]) @@ -137,11 +135,6 @@ def warn_if_useless_spec(spec, nodes): def select_nodes(project, graph, raw_include_specs, raw_exclude_specs): selected_nodes = set() - print('select_nodes') - print(graph) - print(raw_include_specs) - print(raw_exclude_specs) - split_include_specs = split_specs(raw_include_specs) split_exclude_specs = split_specs(raw_exclude_specs) diff --git a/dbt/linker.py b/dbt/linker.py index a279e09b559..b548dd3330a 100644 --- a/dbt/linker.py +++ b/dbt/linker.py @@ -104,6 +104,11 @@ def dependency(self, node1, node2): 
self.graph.add_node(node2) self.graph.add_edge(node2, node1) + if len(list(nx.simple_cycles(self.graph))) > 0: + raise ValidationException( + "Detected a cycle when adding dependency from {} to {}" + .format(node1, node2)) + def add_node(self, node): self.graph.add_node(node) diff --git a/dbt/model.py b/dbt/model.py index 7a16ca9ea71..093462af2bc 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -91,6 +91,7 @@ def config(self): cfg = self._merge(defaults, active_config, self.in_model_config) else: own_config = self.load_config_from_own_project() + cfg = self._merge( defaults, own_config, self.in_model_config, active_config ) diff --git a/dbt/parser.py b/dbt/parser.py index 6d36dcafec0..e22a14ef552 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -125,7 +125,7 @@ def get_fqn(path, package_project_config, extra=[]): return fqn def parse_node(node, node_path, root_project_config, package_project_config, - fqn_extra=[]): + tags=[], fqn_extra=[]): parsed_node = copy.deepcopy(node) parsed_node.update({ @@ -147,19 +147,19 @@ def parse_node(node, node_path, root_project_config, package_project_config, env.from_string(node.get('raw_sql')).render(context) + config_dict = node.get('config', {}) + config_dict.update(config.config) + parsed_node['unique_id'] = node_path - parsed_node['config'] = config.config + parsed_node['config'] = config_dict parsed_node['empty'] = (len(node.get('raw_sql').strip()) == 0) parsed_node['fqn'] = fqn + parsed_node['tags'] = tags return parsed_node -def parse_models(nodes, projects): - return parse_sql_nodes(nodes, projects) - - -def parse_sql_nodes(nodes, root_project, projects): +def parse_sql_nodes(nodes, root_project, projects, tags=[]): to_return = {} dbt.contracts.graph.unparsed.validate(nodes) @@ -175,7 +175,8 @@ def parse_sql_nodes(nodes, root_project, projects): to_return[node_path] = parse_node(node, node_path, root_project, - projects.get(package_name)) + projects.get(package_name), + tags=tags) dbt.contracts.graph.parsed.validate(to_return) @@ -183,7 +184,7 @@ def parse_sql_nodes(nodes, root_project, projects): def load_and_parse_sql(package_name, root_project, all_projects, root_dir, - relative_dirs, resource_type): + relative_dirs, resource_type, tags=[]): extension = "[!.#~]*.sql" if dbt.flags.STRICT_MODE: @@ -212,7 +213,7 @@ def load_and_parse_sql(package_name, root_project, all_projects, root_dir, 'raw_sql': file_contents }) - return parse_sql_nodes(result, root_project, all_projects) + return parse_sql_nodes(result, root_project, all_projects, tags) def parse_schema_tests(tests, root_project, projects): @@ -290,6 +291,7 @@ def parse_schema_test(test_base, model_name, test_config, test_type, name), root_project_config, package_project_config, + tags=['schema'], fqn_extra=['schema']) @@ -324,3 +326,52 @@ def load_and_parse_yml(package_name, root_project, all_projects, root_dir, }) return parse_schema_tests(result, root_project, all_projects) + + +def parse_archives_from_projects(root_project, all_projects): + archives = [] + to_return = {} + + for name, project in all_projects.items(): + archives = archives + parse_archives_from_project(project) + + for archive in archives: + node_path = get_path(archive.get('resource_type'), + archive.get('package_name'), + archive.get('name')) + + to_return[node_path] = parse_node( + archive, + node_path, + root_project, + all_projects.get(archive.get('package_name'))) + + return to_return + + +def parse_archives_from_project(project): + archives = [] + archive_configs = project.get('archive', []) + + for archive_config 
in archive_configs:
+        tables = archive_config.get('tables')
+
+        if tables is None:
+            continue
+
+        for table in tables:
+            config = table.copy()
+            config['source_schema'] = archive_config.get('source_schema')
+            config['target_schema'] = archive_config.get('target_schema')
+
+            archives.append({
+                'name': table.get('target_table'),
+                'root_path': project.get('project-root'),
+                'resource_type': NodeType.Archive,
+                'path': project.get('project-root'),
+                'package_name': project.get('name'),
+                'config': config,
+                'raw_sql': '-- noop'
+            })
+
+    return archives
diff --git a/dbt/runner.py b/dbt/runner.py
index c429b202d3b..e7422394dbf 100644
--- a/dbt/runner.py
+++ b/dbt/runner.py
@@ -1,6 +1,7 @@
 from __future__ import print_function
 
+import jinja2
 import hashlib
 import psycopg2
 import os
@@ -98,7 +99,14 @@ def print_start_line(node, schema_name, index, total):
     if node.get('resource_type') == NodeType.Model:
         print_model_start_line(node, schema_name, index, total)
     if node.get('resource_type') == NodeType.Test:
-        print_model_start_line(node, schema_name, index, total)
+        print_test_start_line(node, schema_name, index, total)
+
+
+def print_test_start_line(model, schema_name, index, total):
+    msg = "START test {name}".format(
+        name=model.get('name'))
+
+    print_fancy_output_line(msg, 'RUN', index, total)
 
 
 def print_model_start_line(model, schema_name, index, total):
@@ -186,7 +194,6 @@ def print_model_result_line(result, schema_name, index, total):
         result.execution_time)
 
 
-
 def execute_model(profile, model):
     adapter = get_adapter(profile)
     schema = adapter.get_default_schema(profile)
@@ -259,6 +266,66 @@ def execute_model(profile, model):
     return result
+
+def execute_archive(profile, node, context):
+    adapter = get_adapter(profile)
+
+    node_cfg = node.get('config', {})
+
+    source_columns = adapter.get_columns_in_table(
+        profile, node_cfg.get('source_schema'), node_cfg.get('source_table'))
+
+    if len(source_columns) == 0:
+        raise RuntimeError(
+            'Source table "{}"."{}" does not exist'.format(
+                node_cfg.get('source_schema'), node_cfg.get('source_table')))
+
+    dest_columns = source_columns + [
+        dbt.schema.Column("valid_from", "timestamp", None),
+        dbt.schema.Column("valid_to", "timestamp", None),
+        dbt.schema.Column("scd_id", "text", None),
+        dbt.schema.Column("dbt_updated_at", "timestamp", None)
+    ]
+
+    adapter.create_table(
+        profile,
+        schema=node_cfg.get('target_schema'),
+        table=node_cfg.get('target_table'),
+        columns=dest_columns,
+        sort=node_cfg.get('updated_at'),
+        dist=node_cfg.get('unique_key'))
+
+    # TODO move this to inject_runtime_config, generate archive SQL
+    # in wrap step. can't do this right now because we actually need
+    # to inspect status of the schema at runtime and archive requires
+    # a lot of information about the schema to generate queries.
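+    #
+    # Rough shape of the code below: render SCDArchiveTemplate with the
+    # archive config to build the SCD select; wrap that select with
+    # ArchiveInsertTemplate to produce an insert statement; then render
+    # the wrapped statement and execute it through the adapter as this
+    # node's 'wrapped_sql'.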
+ template_ctx = context.copy() + template_ctx.update(node_cfg) + + env = jinja2.Environment() + select = env.from_string( + dbt.templates.SCDArchiveTemplate, + template_ctx + ).render(node_cfg) + + insert_stmt = dbt.templates.ArchiveInsertTemplate().wrap( + schema=node_cfg.get('target_schema'), + table=node_cfg.get('target_table'), + query=select, + unique_key=node_cfg.get('unique_key')) + + env = jinja2.Environment() + node['wrapped_sql'] = env.from_string( + insert_stmt, + template_ctx + ).render(node_cfg) + + result = adapter.execute_model( + profile=profile, + model=node) + + return result + + class RunModelResult(object): def __init__(self, node, error=None, skip=False, status=None, execution_time=0): @@ -640,6 +707,8 @@ def execute_node(self, node): result = execute_model(profile, node) elif node.get('resource_type') == NodeType.Test: result = execute_test(profile, node) + elif node.get('resource_type') == NodeType.Archive: + result = execute_archive(profile, node, self.context) return result @@ -693,7 +762,6 @@ def as_concurrent_dep_list(self, linker, nodes_to_run): def on_model_failure(self, linker, selected_nodes): def skip_dependent(node): - print(node) dependent_nodes = linker.get_dependent_nodes(node.get('unique_id')) for node in dependent_nodes: if node in selected_nodes: @@ -750,7 +818,7 @@ def get_idx(node): node_results.append(node_result) nodes_to_execute = [node for node in node_list - if not node.get('skip')] + if not node.get('skip')] threads = self.threads num_nodes_this_batch = len(nodes_to_execute) @@ -818,29 +886,37 @@ def on_complete(run_model_results): return node_results - def get_nodes_to_run(self, graph, include_spec, exclude_spec, resource_type): + def get_nodes_to_run(self, graph, include_spec, exclude_spec, + resource_types, tags): if include_spec is None: include_spec = ['*'] if exclude_spec is None: exclude_spec = [] - print('graphnodes') - print(graph.nodes()) - to_run = [ n for n in graph.nodes() - if ((graph.node.get(n, {}).get('resource_type') == resource_type) - and graph.node.get(n, {}).get('empty') == False - and is_enabled(graph.node.get(n, {}))) + if ((graph.node.get(n).get('resource_type') in resource_types) + and is_enabled(graph.node.get(n)) + and (len(tags) == 0 or + # does the node share any tags with the run? 
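+                      # e.g. a run with tags=['data'] selects the data
+                      # tests parsed with tags=['data'] above, while
+                      # schema tests carry tags=['schema']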
+ bool(set(graph.node.get(n).get('tags')) & + set(tags)))) ] - type_filtered_graph = graph.subgraph(to_run) + filtered_graph = graph.subgraph(to_run) selected_nodes = dbt.graph.selector.select_nodes(self.project, - type_filtered_graph, + filtered_graph, include_spec, exclude_spec) - return selected_nodes + + post_filter = [ + n for n in selected_nodes + if (get_materialization(graph.node.get(n)) != 'ephemeral' and + graph.node.get(n).get('empty') == False) + ] + + return post_filter def get_compiled_models(self, linker, nodes, node_type): compiled_models = [] @@ -879,38 +955,16 @@ def try_create_schema(self): logger.info(str(e)) raise - def run_models_from_graph(self, include_spec, exclude_spec): - linker = self.deserialize_graph() - - selected_nodes = self.get_nodes_to_run( - linker.graph, - include_spec, - exclude_spec, - dbt.model.NodeType.Model) - - self.try_create_schema() - - dependency_list = self.as_concurrent_dep_list( - linker, - selected_nodes) - - on_failure = self.on_model_failure(linker, selected_nodes) - - results = self.execute_nodes(dependency_list, on_failure) - - return results - - def run_tests_from_graph(self, include_spec, exclude_spec, - test_schemas, test_data): - - runner = TestRunner(self.project) + def run_types_from_graph(self, include_spec, exclude_spec, + resource_types, tags): linker = self.deserialize_graph() selected_nodes = self.get_nodes_to_run( linker.graph, include_spec, exclude_spec, - dbt.model.NodeType.Test) + resource_types, + tags) dependency_list = self.as_concurrent_dep_list( linker, @@ -924,45 +978,22 @@ def run_tests_from_graph(self, include_spec, exclude_spec, return results - def run_archives_from_graph(self): - runner = ArchiveRunner(self.project) - linker = self.deserialize_graph() - - selected_nodes = self.get_nodes_to_run( - linker.graph, - None, - None, - dbt.model.NodeType.Archive) - - compiled_models = self.get_compiled_models( - linker, - selected_nodes, - runner.run_type) - - self.try_create_schema() - - model_dependency_list = self.as_concurrent_dep_list( - linker, - compiled_models - ) - - on_failure = self.on_model_failure(linker, compiled_models, - selected_nodes) - results = self.execute_models( - runner, model_dependency_list, on_failure - ) - - return results - # ------------------------------------ - def run_tests(self, include_spec, exclude_spec, - test_schemas=False, test_data=False): - return self.run_tests_from_graph(include_spec, exclude_spec, - test_schemas, test_data) - def run_models(self, include_spec, exclude_spec): - return self.run_models_from_graph(include_spec, exclude_spec) - - def run_archives(self): - return self.run_archives_from_graph() + return self.run_types_from_graph(include_spec, + exclude_spec, + [NodeType.Model], + []) + + def run_tests(self, include_spec, exclude_spec, tags): + return self.run_types_from_graph(include_spec, + exclude_spec, + [NodeType.Test], + tags) + + def run_archives(self, include_spec, exclude_spec): + return self.run_types_from_graph(include_spec, + exclude_spec, + [NodeType.Archive], + []) diff --git a/dbt/task/archive.py b/dbt/task/archive.py index df8b5e5d048..160c8bfa7ec 100644 --- a/dbt/task/archive.py +++ b/dbt/task/archive.py @@ -11,10 +11,12 @@ def __init__(self, args, project): def compile(self): compiler = Compiler(self.project, self.args) compiler.initialize() - compiled = compiler.compile() + results = compiler.compile() - count_compiled_archives = compiled['archives'] - logger.info("Compiled {} archives".format(count_compiled_archives)) + stat_line = ", ".join( + 
["{} {}s".format(ct, t) for t, ct in results.items()]) + + logger.info("Compiled {}".format(stat_line)) def run(self): self.compile() @@ -25,4 +27,4 @@ def run(self): self.args ) - runner.run_archives() + runner.run_archives(['*'], []) diff --git a/dbt/task/test.py b/dbt/task/test.py index 4c43beda8b0..b018540ffb7 100644 --- a/dbt/task/test.py +++ b/dbt/task/test.py @@ -40,14 +40,11 @@ def run(self): if (self.args.data and self.args.schema) or \ (not self.args.data and not self.args.schema): - res = runner.run_tests(include, exclude, - test_schemas=True, test_data=True) + res = runner.run_tests(include, exclude, []) elif self.args.data: - res = runner.run_tests(include, exclude, - test_schemas=False, test_data=True) + res = runner.run_tests(include, exclude, ['data']) elif self.args.schema: - res = runner.run_tests(include, exclude, - test_schemas=True, test_data=False) + res = runner.run_tests(include, exclude, ['schema']) else: raise RuntimeError("unexpected") diff --git a/dbt/templates.py b/dbt/templates.py index 6c5ed2e1334..a1aef839ec7 100644 --- a/dbt/templates.py +++ b/dbt/templates.py @@ -117,11 +117,9 @@ def wrap(self, opts): with "current_data" as ( select - {% raw %} - {% for col in get_columns_in_table(source_schema, source_table) %} - "{{ col.name }}" {% if not loop.last %},{% endif %} - {% endfor %}, - {% endraw %} + {% for col in get_columns_in_table(source_schema, source_table) %} + "{{ col.name }}" {% if not loop.last %},{% endif %} + {% endfor %}, "{{ updated_at }}" as "dbt_updated_at", "{{ unique_key }}" as "dbt_pk", "{{ updated_at }}" as "valid_from", @@ -133,11 +131,9 @@ def wrap(self, opts): "archived_data" as ( select - {% raw %} - {% for col in get_columns_in_table(source_schema, source_table) %} - "{{ col.name }}" {% if not loop.last %},{% endif %} - {% endfor %}, - {% endraw %} + {% for col in get_columns_in_table(source_schema, source_table) %} + "{{ col.name }}" {% if not loop.last %},{% endif %} + {% endfor %}, "{{ updated_at }}" as "dbt_updated_at", "{{ unique_key }}" as "dbt_pk", "valid_from", diff --git a/dbt/utils.py b/dbt/utils.py index 8e3e341c3e5..00e350eae29 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -1,4 +1,5 @@ import os +import jinja2 import json import dbt.project @@ -66,7 +67,10 @@ def __init__(self, model, context): if isinstance(model, dict) and model.get('unique_id'): self.local_vars = model.get('config', {}).get('vars') + self.model_name = model.get('name') else: + # still used for wrapping + self.model_name = model.nice_name self.local_vars = model.config.get('vars', {}) def pretty_dict(self, data): @@ -78,7 +82,7 @@ def __call__(self, var_name, default=None): compiler_error( self.model, self.UndefinedVarError.format( - var_name, self.model.nice_name, pretty_vars + var_name, self.model_name, pretty_vars ) ) elif var_name in self.local_vars: @@ -90,7 +94,20 @@ def __call__(self, var_name, default=None): var_name, self.model.nice_name, pretty_vars ) ) - compiled = self.model.compile_string(self.context, raw) + + # python 2+3 check for stringiness + try: + basestring + except NameError: + basestring = str + + # if bool/int/float/etc are passed in, don't compile anything + if not isinstance(raw, basestring): + return raw + + env = jinja2.Environment() + compiled = env.from_string(raw, self.context).render(self.context) + return compiled else: return default diff --git a/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py b/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py index 
971c06bbefd..ac0744ed74f 100644 --- a/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py +++ b/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py @@ -105,6 +105,8 @@ def project_config(self): @attr(type='postgres') def test_simple_dependency(self): + self.use_default_project() + self.run_dbt(["deps"]) self.run_dbt(["run"]) @@ -114,7 +116,7 @@ def test_simple_dependency(self): self.assertTablesEqual("seed","incremental") -class TestSimpleDependencyWithModelSpecificOverriddenConfigs(BaseTestSimpleDependencyWithConfigs): +class TestSimpleDependencyWithModelSpecificOverriddenConfigsAndMaterializations(BaseTestSimpleDependencyWithConfigs): @property def project_config(self): @@ -127,7 +129,7 @@ def project_config(self): "vars": { "config_1": "ghi", "config_2": "jkl", - #"bool_config": True + "bool_config": True } }, diff --git a/test/integration/base.py b/test/integration/base.py index bce6ec87b22..5d697984a52 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -194,6 +194,8 @@ def tearDown(self): except: os.rename("dbt_modules", "dbt_modules-{}".format(time.time())) + self.handle.close() + @property def project_config(self): return {} From 5a6411355d679101b877bcfa5a52c091a1bba2dc Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sun, 26 Feb 2017 17:53:15 -0500 Subject: [PATCH 07/25] integration tests passing! --- dbt/adapters/postgres.py | 10 ++ dbt/compilation.py | 97 ++++++++++--------- dbt/contracts/graph/parsed.py | 4 + dbt/flags.py | 2 +- dbt/graph/selector.py | 20 +++- dbt/main.py | 6 ++ dbt/parser.py | 45 +++++++-- dbt/runner.py | 75 ++++++++++---- .../test_schema_tests.py | 4 +- .../integration/010_permission_tests/seed.sql | 24 ++--- .../010_permission_tests/tearDown.sql | 2 +- .../010_permission_tests/test_permissions.py | 5 +- .../test_context_vars.py | 2 + .../test_cli_invocation.py | 3 +- 14 files changed, 203 insertions(+), 96 deletions(-) diff --git a/dbt/adapters/postgres.py b/dbt/adapters/postgres.py index 6ea6a2d6462..f886ed3c073 100644 --- a/dbt/adapters/postgres.py +++ b/dbt/adapters/postgres.py @@ -504,6 +504,16 @@ def commit(cls, profile): handle = connection.get('handle') handle.commit() + @classmethod + def rollback(cls, profile): + connection = cls.get_connection(profile) + + if flags.STRICT_MODE: + validate_connection(connection) + + handle = connection.get('handle') + handle.rollback() + @classmethod def get_status(cls, cursor): return cursor.statusmessage diff --git a/dbt/compilation.py b/dbt/compilation.py index 03dc3507aa2..dbd9b453aef 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -139,8 +139,6 @@ def __init__(self, project, args): self.args = args self.parsed_models = None - self.macro_generator = None - def initialize(self): if not os.path.exists(self.project['target-path']): os.makedirs(self.project['target-path']) @@ -246,41 +244,45 @@ def wrapped_do_ref(*args): return wrapped_do_ref - def get_compiler_context(self, linker, model, models): - runtime = RuntimeContext(model=model) - + def get_compiler_context(self, linker, model, models, + macro_generator=None): context = self.project.context() + if macro_generator is not None: + for macro_data in macro_generator(context): + macro = macro_data["macro"] + macro_name = macro_data["name"] + project = macro_data["project"] + + if context.get(project.get('name')) is None: + context[project.get('name')] = {} + + context.get(project.get('name'), {}) \ + .update({macro_name: macro}) + + if model.get('package_name') == 
project.get('name'): + context.update({macro_name: macro}) + + adapter = get_adapter(self.project.run_environment()) + # built-ins context['ref'] = self.__ref(context, model, models) context['config'] = self.__model_config(model, linker) - #context['this'] = This( - # context['env']['schema'], model.immediate_name, model.name - #) + context['this'] = This( + context['env']['schema'], + (model.get('name') if dbt.flags.NON_DESTRUCTIVE + else '{}__dbt_tmp'.format(model.get('name'))), + model.get('name') + ) context['var'] = Var(model, context=context) context['target'] = self.project.get_target() # these get re-interpolated at runtime! context['run_started_at'] = '{{ run_started_at }}' context['invocation_id'] = '{{ invocation_id }}' - - adapter = get_adapter(self.project.run_environment()) context['sql_now'] = adapter.date_function - runtime.update_global(context) - - # add in macros (can we cache these somehow?) - for macro_data in self.macro_generator(context): - macro = macro_data["macro"] - macro_name = macro_data["name"] - project = macro_data["project"] - - runtime.update_package(project['name'], {macro_name: macro}) - - if project['name'] == self.project['name']: - runtime.update_global({macro_name: macro}) - - return runtime + return context def get_context(self, linker, model, models): runtime = RuntimeContext(model=model) @@ -305,20 +307,9 @@ def get_context(self, linker, model, models): runtime.update_global(context) - # add in macros (can we cache these somehow?) - for macro_data in self.macro_generator(context): - macro = macro_data["macro"] - macro_name = macro_data["name"] - project = macro_data["project"] - - runtime.update_package(project['name'], {macro_name: macro}) - - if project['name'] == self.project['name']: - runtime.update_global({macro_name: macro}) - return runtime - def compile_node(self, linker, node, nodes): + def compile_node(self, linker, node, nodes, macro_generator): try: compiled_node = node.copy() compiled_node.update({ @@ -330,7 +321,8 @@ def compile_node(self, linker, node, nodes): 'injected_sql': None, }) - context = self.get_compiler_context(linker, compiled_node, nodes) + context = self.get_compiler_context(linker, compiled_node, nodes, + macro_generator) env = jinja2.sandbox.SandboxedEnvironment() @@ -380,7 +372,7 @@ def new_add_cte_to_rendered_query(self, linker, primary_model, return compiled_query - def compile_nodes(self, linker, nodes): + def compile_nodes(self, linker, nodes, macro_generator): all_projects = self.get_all_projects() compiled_nodes = {} @@ -389,7 +381,8 @@ def compile_nodes(self, linker, nodes): written_nodes = [] for name, node in nodes.items(): - compiled_nodes[name] = self.compile_node(linker, node, nodes) + compiled_nodes[name] = self.compile_node(linker, node, nodes, + macro_generator) if dbt.flags.STRICT_MODE: dbt.contracts.graph.compiled.validate(compiled_nodes) @@ -527,7 +520,7 @@ def get_all_projects(self): return all_projects - def get_parsed_models(self, root_project, all_projects): + def get_parsed_models(self, root_project, all_projects, macro_generator): parsed_models = {} for name, project in all_projects.items(): @@ -538,12 +531,14 @@ def get_parsed_models(self, root_project, all_projects): all_projects=all_projects, root_dir=project.get('project-root'), relative_dirs=project.get('source-paths', []), - resource_type=NodeType.Model)) + resource_type=NodeType.Model, + macro_generator=macro_generator)) return parsed_models - def get_parsed_data_tests(self, root_project, all_projects): + def 
get_parsed_data_tests(self, root_project, all_projects, + macro_generator): parsed_tests = {} for name, project in all_projects.items(): @@ -555,6 +550,7 @@ def get_parsed_data_tests(self, root_project, all_projects): root_dir=project.get('project-root'), relative_dirs=project.get('test-paths', []), resource_type=NodeType.Test, + macro_generator=macro_generator, tags=['data'])) return parsed_tests @@ -575,12 +571,14 @@ def get_parsed_schema_tests(self, root_project, all_projects): return parsed_tests - def load_all_nodes(self, root_project, all_projects): + def load_all_nodes(self, root_project, all_projects, macro_generator): all_nodes = {} - all_nodes.update(self.get_parsed_models(root_project, all_projects)) + all_nodes.update(self.get_parsed_models(root_project, all_projects, + macro_generator)) all_nodes.update( - self.get_parsed_data_tests(root_project, all_projects)) + self.get_parsed_data_tests(root_project, all_projects, + macro_generator)) all_nodes.update( self.get_parsed_schema_tests(root_project, all_projects)) all_nodes.update( @@ -596,7 +594,6 @@ def compile(self): root_project = self.project.cfg all_projects = self.get_all_projects() - all_nodes = self.load_all_nodes(root_project, all_projects) all_macros = self.get_macros(this_project=self.project) for project in dbt.utils.dependency_projects(self.project): @@ -604,9 +601,13 @@ def compile(self): self.get_macros(this_project=self.project, own_project=project) ) - self.macro_generator = self.generate_macros(all_macros) + macro_generator = self.generate_macros(all_macros) + + all_nodes = self.load_all_nodes(root_project, all_projects, + macro_generator) - compiled_nodes, written_nodes = self.compile_nodes(linker, all_nodes) + compiled_nodes, written_nodes = self.compile_nodes(linker, all_nodes, + macro_generator) # TODO re-add archives diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py index b70b41a0422..04a7e340c46 100644 --- a/dbt/contracts/graph/parsed.py +++ b/dbt/contracts/graph/parsed.py @@ -17,6 +17,10 @@ # incremental optional fields Optional('sql_where'): str, Optional('unique_key'): str, + + # adapter optional fields + Optional('sort'): str, + Optional('dist'): str, } parsed_graph_item_contract = unparsed_graph_item_contract.extend({ diff --git a/dbt/flags.py b/dbt/flags.py index a3ae28bd830..3048445bf21 100644 --- a/dbt/flags.py +++ b/dbt/flags.py @@ -1,2 +1,2 @@ STRICT_MODE = False -NON_DESTRUCTIVE = True +NON_DESTRUCTIVE = False diff --git a/dbt/graph/selector.py b/dbt/graph/selector.py index c017d3bcd78..9ae42d9beb7 100644 --- a/dbt/graph/selector.py +++ b/dbt/graph/selector.py @@ -3,6 +3,9 @@ import networkx as nx from dbt.logger import GLOBAL_LOGGER as logger +import dbt.model + + SELECTOR_PARENTS = '+' SELECTOR_CHILDREN = '+' SELECTOR_GLOB = '*' @@ -82,7 +85,6 @@ def get_nodes_by_qualified_name(project, graph, qualified_name): # node naming has changed to dot notation. split to tuple for # compatibility with this code. 
fqn_ish = node.split('.')[1:] - print(fqn_ish) if len(qualified_name) == 1 and fqn_ish == qualified_name[0]: yield node @@ -109,6 +111,8 @@ def get_nodes_from_spec(project, graph, spec): qualified_node_name)) additional_nodes = set() + test_nodes = set() + if select_parents: for node in selected_nodes: parent_nodes = nx.ancestors(graph, node) @@ -117,9 +121,21 @@ def get_nodes_from_spec(project, graph, spec): if select_children: for node in selected_nodes: child_nodes = nx.descendants(graph, node) + print('\nchild_nodes') + print(child_nodes) additional_nodes.update(child_nodes) - return selected_nodes | additional_nodes + model_nodes = selected_nodes | additional_nodes + + for node in model_nodes: + # include tests that depend on this node. if we aren't running tests, + # they'll be filtered out later. + child_tests = [n for n in graph.successors(node) + if graph.node.get(n).get('resource_type') == \ + dbt.model.NodeType.Test] + test_nodes.update(child_tests) + + return model_nodes | test_nodes def warn_if_useless_spec(spec, nodes): diff --git a/dbt/main.py b/dbt/main.py index 99710aaa1c7..565baa62bf0 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -175,6 +175,12 @@ def invoke_dbt(parsed): log_dir = proj.get('log-path', 'logs') + if hasattr(proj.args, 'non_destructive') and \ + proj.args.non_destructive == True: + flags.NON_DESTRUCTIVE = True + else: + flags.NON_DESTRUCTIVE = False + logger.debug("running dbt with arguments %s", parsed) task = parsed.cls(args=parsed, project=proj) diff --git a/dbt/parser.py b/dbt/parser.py index e22a14ef552..388db393244 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -125,7 +125,7 @@ def get_fqn(path, package_project_config, extra=[]): return fqn def parse_node(node, node_path, root_project_config, package_project_config, - tags=[], fqn_extra=[]): + macro_generator=None, tags=[], fqn_extra=[]): parsed_node = copy.deepcopy(node) parsed_node.update({ @@ -137,10 +137,25 @@ def parse_node(node, node_path, root_project_config, package_project_config, config = dbt.model.SourceConfig( root_project_config, package_project_config, fqn) - context = { - 'ref': __ref(parsed_node), - 'config': __config(parsed_node, config), - } + context = {} + + if macro_generator is not None: + for macro_data in macro_generator(context): + macro = macro_data["macro"] + macro_name = macro_data["name"] + project = macro_data["project"] + + if context.get(project.get('name')) is None: + context[project.get('name')] = {} + + context.get(project.get('name'), {}) \ + .update({macro_name: macro}) + + if node.get('package_name') == project.get('name'): + context.update({macro_name: macro}) + + context['ref'] = __ref(parsed_node) + context['config'] = __config(parsed_node, config) env = jinja2.sandbox.SandboxedEnvironment( undefined=SilentUndefined) @@ -159,7 +174,7 @@ def parse_node(node, node_path, root_project_config, package_project_config, return parsed_node -def parse_sql_nodes(nodes, root_project, projects, tags=[]): +def parse_sql_nodes(nodes, root_project, projects, macro_generator, tags=[]): to_return = {} dbt.contracts.graph.unparsed.validate(nodes) @@ -176,6 +191,7 @@ def parse_sql_nodes(nodes, root_project, projects, tags=[]): node_path, root_project, projects.get(package_name), + macro_generator, tags=tags) dbt.contracts.graph.parsed.validate(to_return) @@ -184,7 +200,7 @@ def parse_sql_nodes(nodes, root_project, projects, tags=[]): def load_and_parse_sql(package_name, root_project, all_projects, root_dir, - relative_dirs, resource_type, tags=[]): + relative_dirs, 
resource_type, macro_generator, tags=[]): extension = "[!.#~]*.sql" if dbt.flags.STRICT_MODE: @@ -213,7 +229,8 @@ def load_and_parse_sql(package_name, root_project, all_projects, root_dir, 'raw_sql': file_contents }) - return parse_sql_nodes(result, root_project, all_projects, tags) + return parse_sql_nodes(result, root_project, all_projects, macro_generator, + tags) def parse_schema_tests(tests, root_project, projects): @@ -231,7 +248,9 @@ def parse_schema_tests(tests, root_project, projects): test, model_name, config, test_type, root_project, projects.get(test.get('package_name'))) - to_return[to_add.get('unique_id')] = to_add + + if to_add is not None: + to_return[to_add.get('unique_id')] = to_add return to_return @@ -249,6 +268,9 @@ def parse_schema_test(test_base, model_name, test_config, test_type, name_key = test_config elif test_type == 'relationships': + if not isinstance(test_config, dict): + return None + child_field = test_config.get('from') parent_field = test_config.get('field') parent_model = test_config.get('to') @@ -263,11 +285,14 @@ def parse_schema_test(test_base, model_name, test_config, test_type, parent_field) elif test_type == 'accepted_values': + if not isinstance(test_config, dict): + return None + raw_sql = QUERY_VALIDATE_ACCEPTED_VALUES.format( ref="{{ref('"+model_name+"')}}", field=test_config.get('field', ''), values_csv="'{}'".format( - "','".join(test_config.get('values', [])))) + "','".join([str(v) for v in test_config.get('values', [])]))) name_key = test_config.get('field') diff --git a/dbt/runner.py b/dbt/runner.py index e7422394dbf..795fc347a0a 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -159,6 +159,8 @@ def execute_test(profile, test): rows = cursor.fetchall() + adapter.rollback(profile) + cursor.close() if len(rows) > 1: @@ -325,6 +327,33 @@ def execute_archive(profile, node, context): return result +def run_hooks(profile, hooks, context, source): + if type(hooks) not in (list, tuple): + hooks = [hooks] + + print('hooks') + print(hooks) + + ctx = { + "target": profile, + "state": "start", + "invocation_id": context['invocation_id'], + "run_started_at": context['run_started_at'] + } + + compiled_hooks = [ + dbt.compilation.compile_string(hook, ctx) for hook in hooks + ] + + adapter = get_adapter(profile) + + adapter.execute_all( + profile=profile, + queries=compiled_hooks, + model_name=source) + + adapter.commit(profile) + class RunModelResult(object): def __init__(self, node, error=None, skip=False, status=None, @@ -771,7 +800,8 @@ def skip_dependent(node): return skip_dependent - def execute_nodes(self, node_dependency_list, on_failure): + def execute_nodes(self, node_dependency_list, on_failure, + should_run_hooks=False): profile = self.project.run_environment() adapter = get_adapter(profile) schema_name = adapter.get_default_schema(profile) @@ -798,8 +828,11 @@ def execute_nodes(self, node_dependency_list, on_failure): logger.info("") print_counts(flat_nodes) - # TODO: re-add hooks - # runner.pre_run_all(flat_models, self.context) + if should_run_hooks: + run_hooks(self.project.get_target(), + self.project.cfg.get('on-run-start', []), + self.context, + 'on-run-start hooks') node_id_to_index_map = {node.get('unique_id'): i + 1 for (i, node) in enumerate(flat_nodes)} @@ -879,15 +912,21 @@ def on_complete(run_model_results): pool.close() pool.join() + if should_run_hooks: + run_hooks(self.project.get_target(), + self.project.cfg.get('on-run-end', []), + self.context, + 'on-run-end hooks') + logger.info("") logger.info("FIXME") # 
logger.info(runner.post_run_all_msg(model_results)) - # runner.post_run_all(flat_models, model_results, self.context) return node_results def get_nodes_to_run(self, graph, include_spec, exclude_spec, resource_types, tags): + if include_spec is None: include_spec = ['*'] @@ -896,12 +935,8 @@ def get_nodes_to_run(self, graph, include_spec, exclude_spec, to_run = [ n for n in graph.nodes() - if ((graph.node.get(n).get('resource_type') in resource_types) - and is_enabled(graph.node.get(n)) - and (len(tags) == 0 or - # does the node share any tags with the run? - bool(set(graph.node.get(n).get('tags')) & - set(tags)))) + if (graph.node.get(n).get('empty') == False + and is_enabled(graph.node.get(n))) ] filtered_graph = graph.subgraph(to_run) @@ -912,11 +947,15 @@ def get_nodes_to_run(self, graph, include_spec, exclude_spec, post_filter = [ n for n in selected_nodes - if (get_materialization(graph.node.get(n)) != 'ephemeral' and - graph.node.get(n).get('empty') == False) + if ((graph.node.get(n).get('resource_type') in resource_types) + and get_materialization(graph.node.get(n)) != 'ephemeral' + and (len(tags) == 0 or + # does the node share any tags with the run? + bool(set(graph.node.get(n).get('tags')) & + set(tags)))) ] - return post_filter + return set(post_filter) def get_compiled_models(self, linker, nodes, node_type): compiled_models = [] @@ -956,7 +995,7 @@ def try_create_schema(self): raise def run_types_from_graph(self, include_spec, exclude_spec, - resource_types, tags): + resource_types, tags, should_run_hooks=False): linker = self.deserialize_graph() selected_nodes = self.get_nodes_to_run( @@ -974,7 +1013,8 @@ def run_types_from_graph(self, include_spec, exclude_spec, on_failure = self.on_model_failure(linker, selected_nodes) - results = self.execute_nodes(dependency_list, on_failure) + results = self.execute_nodes(dependency_list, on_failure, + should_run_hooks) return results @@ -983,8 +1023,9 @@ def run_types_from_graph(self, include_spec, exclude_spec, def run_models(self, include_spec, exclude_spec): return self.run_types_from_graph(include_spec, exclude_spec, - [NodeType.Model], - []) + resource_types=[NodeType.Model], + tags=[], + should_run_hooks=True) def run_tests(self, include_spec, exclude_spec, tags): return self.run_types_from_graph(include_spec, diff --git a/test/integration/008_schema_tests_test/test_schema_tests.py b/test/integration/008_schema_tests_test/test_schema_tests.py index d53e8aeb5e5..ba679796e8c 100644 --- a/test/integration/008_schema_tests_test/test_schema_tests.py +++ b/test/integration/008_schema_tests_test/test_schema_tests.py @@ -35,7 +35,7 @@ def test_schema_tests(self): for result in test_results: # assert that all deliberately failing tests actually fail - if 'failure' in result.model.name: + if 'failure' in result.node.get('name'): self.assertFalse(result.errored) self.assertFalse(result.skipped) self.assertTrue(result.status > 0) @@ -75,4 +75,4 @@ def test_malformed_schema_test_wont_brick_run(self): self.run_dbt() ran_tests = self.run_schema_validations() - self.assertEqual(ran_tests, []) + self.assertEqual(len(ran_tests), 2) diff --git a/test/integration/010_permission_tests/seed.sql b/test/integration/010_permission_tests/seed.sql index 50ae457a701..6a3e0e6cf46 100644 --- a/test/integration/010_permission_tests/seed.sql +++ b/test/integration/010_permission_tests/seed.sql @@ -1,6 +1,6 @@ -create schema private; +create schema private_010; -create table private.seed ( +create table private_010.seed ( id BIGSERIAL PRIMARY KEY, first_name 
VARCHAR(50), last_name VARCHAR(50), @@ -10,13 +10,13 @@ create table private.seed ( ); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Kathryn', 'Walker', 'kwalker1@ezinearticles.com', 'Female', '194.121.179.35'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Gerald', 'Ryan', 'gryan2@com.com', 'Male', '11.3.212.243'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Bonnie', 'Spencer', 'bspencer3@ameblo.jp', 'Female', '216.32.196.175'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Harold', 'Taylor', 'htaylor4@people.com.cn', 'Male', '253.10.246.136'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Jacqueline', 'Griffin', 'jgriffin5@t.co', 'Female', '16.13.192.220'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Wanda', 'Arnold', 'warnold6@google.nl', 'Female', '232.116.150.64'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Craig', 'Ortiz', 'cortiz7@sciencedaily.com', 'Male', '199.126.106.13'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Gary', 'Day', 'gday8@nih.gov', 'Male', '35.81.68.186'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Rose', 'Wright', 'rwright9@yahoo.co.jp', 'Female', '236.82.178.100'); -insert into private.seed (first_name, last_name, email, gender, ip_address) values ('Raymond', 'Kelley', 'rkelleya@fc2.com', 'Male', '213.65.166.67'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Kathryn', 'Walker', 'kwalker1@ezinearticles.com', 'Female', '194.121.179.35'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Gerald', 'Ryan', 'gryan2@com.com', 'Male', '11.3.212.243'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Bonnie', 'Spencer', 'bspencer3@ameblo.jp', 'Female', '216.32.196.175'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Harold', 'Taylor', 'htaylor4@people.com.cn', 'Male', '253.10.246.136'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Jacqueline', 'Griffin', 'jgriffin5@t.co', 'Female', '16.13.192.220'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Wanda', 'Arnold', 'warnold6@google.nl', 'Female', '232.116.150.64'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Craig', 'Ortiz', 'cortiz7@sciencedaily.com', 'Male', '199.126.106.13'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Gary', 'Day', 'gday8@nih.gov', 'Male', '35.81.68.186'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Rose', 'Wright', 'rwright9@yahoo.co.jp', 'Female', '236.82.178.100'); +insert into private_010.seed (first_name, last_name, email, gender, ip_address) values ('Raymond', 'Kelley', 'rkelleya@fc2.com', 'Male', '213.65.166.67'); diff --git a/test/integration/010_permission_tests/tearDown.sql b/test/integration/010_permission_tests/tearDown.sql index 4da20cda0bf..f7125ff7824 100644 --- a/test/integration/010_permission_tests/tearDown.sql +++ b/test/integration/010_permission_tests/tearDown.sql @@ -1,2 +1,2 @@ -drop schema if exists private cascade; +drop 
schema if exists private_010 cascade; diff --git a/test/integration/010_permission_tests/test_permissions.py b/test/integration/010_permission_tests/test_permissions.py index a322457e596..1a01464e205 100644 --- a/test/integration/010_permission_tests/test_permissions.py +++ b/test/integration/010_permission_tests/test_permissions.py @@ -6,13 +6,14 @@ class TestPermissions(DBTIntegrationTest): def setUp(self): DBTIntegrationTest.setUp(self) + self.run_sql_file("test/integration/010_permission_tests/tearDown.sql") self.run_sql_file("test/integration/010_permission_tests/seed.sql") def tearDown(self): - DBTIntegrationTest.tearDown(self) - self.run_sql_file("test/integration/010_permission_tests/tearDown.sql") + DBTIntegrationTest.tearDown(self) + @property def schema(self): return "permission_tests_010" diff --git a/test/integration/013_context_var_tests/test_context_vars.py b/test/integration/013_context_var_tests/test_context_vars.py index 514a0b54540..ff102dc9be7 100644 --- a/test/integration/013_context_var_tests/test_context_vars.py +++ b/test/integration/013_context_var_tests/test_context_vars.py @@ -1,6 +1,8 @@ from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest +import dbt.flags + class TestContextVars(DBTIntegrationTest): def setUp(self): diff --git a/test/integration/015_cli_invocation_tests/test_cli_invocation.py b/test/integration/015_cli_invocation_tests/test_cli_invocation.py index e7782cbb795..d7d4ecbbc69 100644 --- a/test/integration/015_cli_invocation_tests/test_cli_invocation.py +++ b/test/integration/015_cli_invocation_tests/test_cli_invocation.py @@ -97,4 +97,5 @@ def test_toplevel_dbt_run_with_profile_dir_arg(self): # make sure the test runs against `custom_schema` for test_result in res: - self.assertTrue(self.custom_schema, test_result.model.compiled_contents) + self.assertTrue(self.custom_schema, + test_result.node.get('wrapped_sql')) From f34892f89e84ca915426006f837e9b365a447503 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sun, 26 Feb 2017 17:54:20 -0500 Subject: [PATCH 08/25] remove runners (they are unused now) --- dbt/runner.py | 287 -------------------------------------------------- 1 file changed, 287 deletions(-) diff --git a/dbt/runner.py b/dbt/runner.py index 795fc347a0a..d2ac6271512 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -373,293 +373,6 @@ def skipped(self): return self.skip -class BaseRunner(object): - def __init__(self, project): - self.project = project - - self.profile = project.run_environment() - self.adapter = get_adapter(self.profile) - - def pre_run_msg(self, model): - raise NotImplementedError("not implemented") - - def skip_msg(self, model): - return "SKIP relation {}.{}".format( - self.adapter.get_default_schema(self.profile), model.name) - - def post_run_msg(self, result): - raise NotImplementedError("not implemented") - - def pre_run_all_msg(self, models): - raise NotImplementedError("not implemented") - - def post_run_all_msg(self, results): - raise NotImplementedError("not implemented") - - def post_run_all(self, models, results, context): - pass - - def pre_run_all(self, models, context): - pass - - def status(self, result): - raise NotImplementedError("not implemented") - - -class ModelRunner(BaseRunner): - run_type = dbt.model.NodeType.Model - - def pre_run_msg(self, model): - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - "model_type": model.materialization, - "info": "START" - } - - output = ("START {model_type} model 
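The integration-test edits above track a representation change: results now carry the parsed node dict rather than a model object, so assertions go through .get(). A trimmed stand-in for RunModelResult to show the shape; the real constructor also takes execution_time, and errored is assumed to derive from error the way skipped derives from skip:

    class RunModelResult(object):
        # trimmed: the real class also accepts execution_time
        def __init__(self, node, error=None, skip=False, status=None):
            self.node = node
            self.error = error
            self.skip = skip
            self.status = status

        @property
        def errored(self):
            # assumption, mirroring the skipped property
            return self.error is not None

    result = RunModelResult({'name': 'failure_null_test',
                             'wrapped_sql': 'select ...'},
                            status=1)
    assert 'failure' in result.node.get('name')
    assert not result.errored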
{schema}.{model_name} " - .format(**print_vars)) - return output - - def post_run_msg(self, result): - model = result.node - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - "model_type": model.materialization, - "info": "ERROR creating" if result.errored else "OK created" - } - - output = ("{info} {model_type} model {schema}.{model_name} " - .format(**print_vars)) - return output - - def pre_run_all_msg(self, models): - return "{} Running {} models".format(get_timestamp(), len(models)) - - def post_run_all_msg(self, results): - return ("{} Finished running {} models" - .format(get_timestamp(), len(results))) - - def status(self, result): - return result.status - - def is_non_destructive(self): - if hasattr(self.project.args, 'non_destructive'): - return self.project.args.non_destructive - else: - return False - - def execute(self, model): - profile = self.project.run_environment() - adapter = get_adapter(profile) - - if model.tmp_drop_type is not None: - if model.materialization == 'table' and self.is_non_destructive(): - adapter.truncate( - profile=profile, - table=model.tmp_name, - model_name=model.name) - else: - adapter.drop( - profile=profile, - relation=model.tmp_name, - relation_type=model.tmp_drop_type, - model_name=model.name) - - status = adapter.execute_model( - profile=profile, - model=model) - - if model.final_drop_type is not None: - if model.materialization == 'table' and self.is_non_destructive(): - # we just inserted into this recently truncated table... - # do nothing here - pass - else: - adapter.drop( - profile=profile, - relation=model.name, - relation_type=model.final_drop_type, - model_name=model.name) - - if model.should_rename(self.project.args): - adapter.rename( - profile=profile, - from_name=model.tmp_name, - to_name=model.name, - model_name=model.name) - - adapter.commit( - profile=profile) - - return status - - def __run_hooks(self, hooks, context, source): - if type(hooks) not in (list, tuple): - hooks = [hooks] - - target = self.project.get_target() - - ctx = { - "target": target, - "state": "start", - "invocation_id": context['invocation_id'], - "run_started_at": context['run_started_at'] - } - - compiled_hooks = [ - dbt.compilation.compile_string(hook, ctx) for hook in hooks - ] - - profile = self.project.run_environment() - adapter = get_adapter(profile) - - adapter.execute_all( - profile=profile, - queries=compiled_hooks, - model_name=source) - - adapter.commit(profile) - - def pre_run_all(self, models, context): - hooks = self.project.cfg.get('on-run-start', []) - self.__run_hooks(hooks, context, 'on-run-start hooks') - - def post_run_all(self, models, results, context): - hooks = self.project.cfg.get('on-run-end', []) - self.__run_hooks(hooks, context, 'on-run-end hooks') - - -class TestRunner(ModelRunner): - run_type = dbt.model.NodeType.Test - - test_data_type = dbt.model.TestNodeType.DataTest - test_schema_type = dbt.model.TestNodeType.SchemaTest - - def pre_run_msg(self, model): - if model.is_test_type(self.test_data_type): - return "DATA TEST {name} ".format(name=model.name) - else: - return "SCHEMA TEST {name} ".format(name=model.name) - - def post_run_msg(self, result): - model = result.model - info = self.status(result) - - return "{info} {name} ".format(info=info, name=model.name) - - def pre_run_all_msg(self, models): - return "{} Running {} tests".format(get_timestamp(), len(models)) - - def post_run_all_msg(self, results): - total = len(results) - passed = len([result for result in 
results if not - result.errored and not result.skipped and - result.status == 0]) - failed = len([result for result in results if not - result.errored and not result.skipped and - result.status > 0]) - errored = len([result for result in results if result.errored]) - skipped = len([result for result in results if result.skipped]) - - total_errors = failed + errored - - overview = ("PASS={passed} FAIL={total_errors} SKIP={skipped} " - "TOTAL={total}".format( - total=total, - passed=passed, - total_errors=total_errors, - skipped=skipped)) - - if total_errors > 0: - final = "Tests completed with errors" - else: - final = "All tests passed" - - return "\n{overview}\n{final}".format(overview=overview, final=final) - - def status(self, result): - if result.errored: - info = "ERROR" - elif result.status > 0: - info = 'FAIL {}'.format(result.status) - elif result.status == 0: - info = 'PASS' - else: - raise RuntimeError("unexpected status: {}".format(result.status)) - - return info - - def execute(self, model): - profile = self.project.run_environment() - adapter = get_adapter(profile) - - _, cursor = adapter.execute_one( - profile, model.compiled_contents, model.name) - rows = cursor.fetchall() - - cursor.close() - - if len(rows) > 1: - raise RuntimeError( - "Bad test {name}: Returned {num_rows} rows instead of 1" - .format(name=model.name, num_rows=len(rows))) - - row = rows[0] - if len(row) > 1: - raise RuntimeError( - "Bad test {name}: Returned {num_cols} cols instead of 1" - .format(name=model.name, num_cols=len(row))) - - return row[0] - - -class ArchiveRunner(BaseRunner): - run_type = dbt.model.NodeType.Archive - - def pre_run_msg(self, model): - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - } - - output = ("START archive table {schema}.{model_name} " - .format(**print_vars)) - return output - - def post_run_msg(self, result): - model = result.model - print_vars = { - "schema": self.adapter.get_default_schema(self.profile), - "model_name": model.name, - "info": "ERROR archiving" if result.errored else "OK created" - } - - output = "{info} table {schema}.{model_name} ".format(**print_vars) - return output - - def pre_run_all_msg(self, models): - return "Archiving {} tables".format(len(models)) - - def post_run_all_msg(self, results): - return ("{} Finished archiving {} tables" - .format(get_timestamp(), len(results))) - - def status(self, result): - return result.status - - def execute(self, model): - profile = self.project.run_environment() - adapter = get_adapter(profile) - - status = adapter.execute_model( - profile=profile, - model=model) - - return status - - class RunManager(object): def __init__(self, project, target_path, args): self.project = project From 1a2ada75b20fc1a4a6429357de587b8e0c7fd1eb Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sun, 26 Feb 2017 17:56:22 -0500 Subject: [PATCH 09/25] remove get_compiled_models -- unused --- dbt/runner.py | 27 --------------------------- 1 file changed, 27 deletions(-) diff --git a/dbt/runner.py b/dbt/runner.py index d2ac6271512..6059b76894a 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -437,7 +437,6 @@ def deserialize_graph(self): return dbt.linker.from_file(graph_file) - def execute_node(self, node): profile = self.project.run_environment() @@ -454,7 +453,6 @@ def execute_node(self, node): return result - def safe_execute_node(self, data): node = data @@ -490,7 +488,6 @@ def safe_execute_node(self, data): status=status, execution_time=execution_time) - def 
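The deleted TestRunner spells out the contract that test queries still follow elsewhere in the runner: the compiled SQL must return exactly one row with one column, and that single value is the count of failing records. The same interpretation, lifted out of the class:

    def interpret_test_result(rows, name):
        # a compiled test query yields one row / one column: the number
        # of records that violate the constraint
        if len(rows) > 1:
            raise RuntimeError(
                "Bad test {name}: Returned {num_rows} rows instead of 1"
                .format(name=name, num_rows=len(rows)))
        row = rows[0]
        if len(row) > 1:
            raise RuntimeError(
                "Bad test {name}: Returned {num_cols} cols instead of 1"
                .format(name=name, num_cols=len(row)))
        status = row[0]
        return 'PASS' if status == 0 else 'FAIL {}'.format(status)

    assert interpret_test_result([(0,)], 'unique_users_id') == 'PASS'
    assert interpret_test_result([(3,)], 'not_null_users_email') == 'FAIL 3'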
as_concurrent_dep_list(self, linker, nodes_to_run): dependency_list = linker.as_dependency_list(nodes_to_run) @@ -501,7 +498,6 @@ def as_concurrent_dep_list(self, linker, nodes_to_run): return concurrent_dependency_list - def on_model_failure(self, linker, selected_nodes): def skip_dependent(node): dependent_nodes = linker.get_dependent_nodes(node.get('unique_id')) @@ -512,7 +508,6 @@ def skip_dependent(node): return skip_dependent - def execute_nodes(self, node_dependency_list, on_failure, should_run_hooks=False): profile = self.project.run_environment() @@ -670,28 +665,6 @@ def get_nodes_to_run(self, graph, include_spec, exclude_spec, return set(post_filter) - def get_compiled_models(self, linker, nodes, node_type): - compiled_models = [] - - for fqn in nodes: - compiled_model = make_compiled_model(fqn, linker.get_node(fqn)) - - if not compiled_model.is_type(node_type): - continue - - if not compiled_model.should_execute(self.args, - self.existing_models): - continue - - context = self.context.copy() - context.update(compiled_model.context()) - - profile = self.project.run_environment() - compiled_model.compile(context, profile, self.existing_models) - compiled_models.append(compiled_model) - - return compiled_models - def try_create_schema(self): profile = self.project.run_environment() adapter = get_adapter(profile) From c7dd77604bcaef510296f12b6e76fcb9c1300d29 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sun, 26 Feb 2017 18:05:42 -0500 Subject: [PATCH 10/25] ripping things out, part 1: compiled_model.py --- dbt/compilation.py | 31 -------- dbt/compiled_model.py | 178 ------------------------------------------ dbt/model.py | 78 ------------------ dbt/runner.py | 1 - dbt/source.py | 25 +----- 5 files changed, 1 insertion(+), 312 deletions(-) delete mode 100644 dbt/compiled_model.py diff --git a/dbt/compilation.py b/dbt/compilation.py index dbd9b453aef..e222318821a 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -146,26 +146,16 @@ def initialize(self): if not os.path.exists(self.project['modules-path']): os.makedirs(self.project['modules-path']) - def get_macros(self, this_project, own_project=None): if own_project is None: own_project = this_project paths = own_project.get('macro-paths', []) return Source(this_project, own_project=own_project).get_macros(paths) - - def get_archives(self, project): - return Source( - project, - own_project=project - ).get_archives() - - def analysis_sources(self, project): paths = project.get('analysis-paths', []) return Source(project).get_analyses(paths) - def __write(self, build_filepath, payload): target_path = os.path.join(self.project['target-path'], build_filepath) @@ -181,7 +171,6 @@ def do_config(*args, **kwargs): return do_config - def __ref(self, ctx, model, all_models): schema = ctx.get('env', {}).get('schema') @@ -342,7 +331,6 @@ def write_graph_file(self, linker): graph_path = os.path.join(self.project['target-path'], filename) linker.write_graph(graph_path) - def new_add_cte_to_rendered_query(self, linker, primary_model, compiled_models): @@ -482,7 +470,6 @@ def compile_analyses(self, linker, compiled_models): return written_analyses - def generate_macros(self, all_macros): def do_gen(ctx): macros = [] @@ -492,19 +479,6 @@ def do_gen(ctx): return macros return do_gen - - def compile_archives(self, linker, compiled_models): - all_archives = self.get_archives(self.project) - - for archive in all_archives: - sql = archive.compile() - fqn = tuple(archive.fqn) - linker.update_node_data(fqn, archive.serialize()) - 
self.__write(archive.build_path(), sql) - - return all_archives - - def get_all_projects(self): root_project = self.project.cfg all_projects = {root_project.get('name'): root_project} @@ -519,7 +493,6 @@ def get_all_projects(self): return all_projects - def get_parsed_models(self, root_project, all_projects, macro_generator): parsed_models = {} @@ -536,7 +509,6 @@ def get_parsed_models(self, root_project, all_projects, macro_generator): return parsed_models - def get_parsed_data_tests(self, root_project, all_projects, macro_generator): parsed_tests = {} @@ -555,7 +527,6 @@ def get_parsed_data_tests(self, root_project, all_projects, return parsed_tests - def get_parsed_schema_tests(self, root_project, all_projects): parsed_tests = {} @@ -570,7 +541,6 @@ def get_parsed_schema_tests(self, root_project, all_projects): return parsed_tests - def load_all_nodes(self, root_project, all_projects, macro_generator): all_nodes = {} @@ -587,7 +557,6 @@ def load_all_nodes(self, root_project, all_projects, macro_generator): return all_nodes - def compile(self): linker = Linker() diff --git a/dbt/compiled_model.py b/dbt/compiled_model.py deleted file mode 100644 index bea4d29d072..00000000000 --- a/dbt/compiled_model.py +++ /dev/null @@ -1,178 +0,0 @@ -import hashlib -import jinja2 -from dbt.utils import compiler_error, to_unicode -from dbt.adapters.factory import get_adapter -import dbt.model - - -class CompiledModel(object): - def __init__(self, fqn, data): - self.fqn = fqn - self.data = data - self.nice_name = ".".join(fqn) - - # these are set just before the models are executed - self.tmp_drop_type = None - self.final_drop_type = None - self.profile = None - - self.skip = False - self._contents = None - self.compiled_contents = None - - def __getitem__(self, key): - return self.data[key] - - def hashed_name(self): - fqn_string = ".".join(self.fqn) - return hashlib.md5(fqn_string.encode('utf-8')).hexdigest() - - def context(self): - return self.data - - def hashed_contents(self): - return hashlib.md5(self.contents.encode('utf-8')).hexdigest() - - def do_skip(self): - self.skip = True - - def should_skip(self): - return self.skip - - def is_type(self, run_type): - return self.data['dbt_run_type'] == run_type - - def is_test_type(self, test_type): - return self.data.get('dbt_test_type') == test_type - - def is_test(self): - return self.data['dbt_run_type'] == dbt.model.NodeType.Test - - @property - def contents(self): - if self._contents is None: - with open(self.data['build_path']) as fh: - self._contents = to_unicode(fh.read(), 'utf-8') - - return self._contents - - def compile(self, context, profile, existing): - self.prepare(existing, profile) - - contents = self.contents - try: - env = jinja2.Environment() - self.compiled_contents = env.from_string(contents).render(context) - return self.compiled_contents - except jinja2.exceptions.TemplateSyntaxError as e: - compiler_error(self, str(e)) - - @property - def materialization(self): - return self.data['materialized'] - - @property - def name(self): - return self.data['name'] - - @property - def tmp_name(self): - return self.data['tmp_name'] - - def project(self): - return {'name': self.data['project_name']} - - @property - def schema(self): - if self.profile is None: - raise RuntimeError( - "`profile` not set in compiled model {}".format(self) - ) - else: - return get_adapter(self.profile).get_default_schema(self.profile) - - def should_execute(self, args, existing): - if args.non_destructive and \ - self.materialization == 'view' and \ - self.name in 
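The CompiledModel being deleted identifies nodes two ways: an md5 over the dotted fully qualified name, and an md5 over the model's file contents. Those helpers restated standalone:

    import hashlib

    def hashed_name(fqn):
        # e.g. ('root', 'users') -> md5('root.users')
        fqn_string = ".".join(fqn)
        return hashlib.md5(fqn_string.encode('utf-8')).hexdigest()

    def hashed_contents(contents):
        return hashlib.md5(contents.encode('utf-8')).hexdigest()

    print(hashed_name(('root', 'users')))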
existing: - - return False - else: - return self.data['enabled'] and self.materialization != 'ephemeral' - - def should_rename(self, args): - if args.non_destructive and self.materialization == 'table': - return False - else: - return self.materialization in ['table', 'view'] - - def prepare(self, existing, profile): - if self.materialization == 'incremental': - tmp_drop_type = None - final_drop_type = None - else: - tmp_drop_type = existing.get(self.tmp_name, None) - final_drop_type = existing.get(self.name, None) - - self.tmp_drop_type = tmp_drop_type - self.final_drop_type = final_drop_type - self.profile = profile - - def __repr__(self): - return "".format( - self.data['project_name'], self.name, self.data['build_path'] - ) - - -class CompiledTest(CompiledModel): - def __init__(self, fqn, data): - super(CompiledTest, self).__init__(fqn, data) - - def should_rename(self): - return False - - def should_execute(self, args, existing): - return True - - def prepare(self, existing, profile): - self.profile = profile - - def __repr__(self): - return "".format( - self.data['project_name'], self.name, self.data['build_path'] - ) - - -class CompiledArchive(CompiledModel): - def __init__(self, fqn, data): - super(CompiledArchive, self).__init__(fqn, data) - - def should_rename(self): - return False - - def should_execute(self, args, existing): - return True - - def prepare(self, existing, profile): - self.profile = profile - - def __repr__(self): - return "".format( - self.data['project_name'], self.name, self.data['build_path'] - ) - - -def make_compiled_model(fqn, data): - run_type = data['dbt_run_type'] - - if run_type == dbt.model.NodeType.Model: - return CompiledModel(fqn, data) - - elif run_type == dbt.model.NodeType.Test: - return CompiledTest(fqn, data) - - elif run_type == dbt.model.NodeType.Archive: - return CompiledArchive(fqn, data) - - else: - raise RuntimeError("invalid run_type given: {}".format(run_type)) diff --git a/dbt/model.py b/dbt/model.py index 093462af2bc..acb101b0217 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -757,84 +757,6 @@ def __repr__(self): ) -class ArchiveModel(DBTSource): - dbt_run_type = NodeType.Archive - build_dir = 'archive' - template = ArchiveInsertTemplate() - - def __init__(self, project, archive_data): - - self.validate(archive_data) - - self.source_schema = archive_data['source_schema'] - self.target_schema = archive_data['target_schema'] - self.source_table = archive_data['source_table'] - self.target_table = archive_data['target_table'] - self.unique_key = archive_data['unique_key'] - self.updated_at = archive_data['updated_at'] - - rel_filepath = os.path.join(self.target_schema, self.target_table) - - super(ArchiveModel, self).__init__( - project, self.build_dir, rel_filepath, project - ) - - def validate(self, data): - required = [ - 'source_schema', - 'target_schema', - 'source_table', - 'target_table', - 'unique_key', - 'updated_at', - ] - - for key in required: - if data.get(key, None) is None: - compiler_error( - "Invalid archive config: missing required field '{}'" - .format(key) - ) - - def serialize(self): - data = DBTSource.serialize(self).copy() - - serialized = { - "source_schema": self.source_schema, - "target_schema": self.target_schema, - "source_table": self.source_table, - "target_table": self.target_table, - "unique_key": self.unique_key, - "updated_at": self.updated_at - } - - data.update(serialized) - return data - - def compile(self): - archival = dbt.archival.Archival(self.project, self) - query = archival.compile() - - sql = 
self.template.wrap( - self.target_schema, self.target_table, query, self.unique_key - ) - - return sql - - def build_path(self): - filename = "{}.sql".format(self.name) - path_parts = [self.build_dir] + self.fqn[:-1] + [filename] - return os.path.join(*path_parts) - - def __repr__(self): - return " {} unique:{} updated_at:{}>".format( - self.source_table, - self.target_table, - self.unique_key, - self.updated_at - ) - - class DataTest(DBTSource): dbt_run_type = NodeType.Test dbt_test_type = TestNodeType.DataTest diff --git a/dbt/runner.py b/dbt/runner.py index 6059b76894a..b59f625934e 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -19,7 +19,6 @@ from dbt.source import Source from dbt.utils import find_model_by_fqn, find_model_by_name, \ dependency_projects -from dbt.compiled_model import make_compiled_model from dbt.model import NodeType import dbt.compilation diff --git a/dbt/source.py b/dbt/source.py index afb7792fd5f..30a900de140 100644 --- a/dbt/source.py +++ b/dbt/source.py @@ -1,7 +1,7 @@ import os.path import fnmatch from dbt.model import Model, Analysis, SchemaFile, Csv, Macro, \ - ArchiveModel, DataTest + DataTest import dbt.clients.system @@ -90,26 +90,3 @@ def get_macros(self, macro_dirs): return self.build_models_from_file_matches( Macro, file_matches) - - def get_archives(self): - "Get Archive models defined in project config" - - if 'archive' not in self.project: - return [] - - raw_source_schemas = self.project['archive'] - - archives = [] - for schema in raw_source_schemas: - schema = schema.copy() - if 'tables' not in schema: - continue - - tables = schema.pop('tables') - for table in tables: - fields = table.copy() - fields.update(schema) - archives.append(ArchiveModel( - self.project, fields - )) - return archives From 8dc998b979ec896bd7f9ff54bf061bcd8986f7a5 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Sun, 26 Feb 2017 18:11:51 -0500 Subject: [PATCH 11/25] ripping stuff out, part 2: archival and other unused model types --- dbt/archival.py | 70 ------------ dbt/compilation.py | 33 ------ dbt/model.py | 262 --------------------------------------------- dbt/source.py | 33 +----- 4 files changed, 1 insertion(+), 397 deletions(-) delete mode 100644 dbt/archival.py diff --git a/dbt/archival.py b/dbt/archival.py deleted file mode 100644 index 74245922cf1..00000000000 --- a/dbt/archival.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import print_function -import dbt.schema -import dbt.templates -import jinja2 - -from dbt.adapters.factory import get_adapter - - -class Archival(object): - - def __init__(self, project, archive_model): - self.archive_model = archive_model - self.project = project - - def compile(self): - source_schema = self.archive_model.source_schema - target_schema = self.archive_model.target_schema - source_table = self.archive_model.source_table - target_table = self.archive_model.target_table - unique_key = self.archive_model.unique_key - updated_at = self.archive_model.updated_at - - profile = self.project.run_environment() - adapter = get_adapter(profile) - - adapter.create_schema(profile, target_schema) - - source_columns = adapter.get_columns_in_table( - profile, source_schema, source_table) - - if len(source_columns) == 0: - raise RuntimeError( - 'Source table "{}"."{}" does not ' - 'exist'.format(source_schema, source_table)) - - extra_cols = [ - dbt.schema.Column("valid_from", "timestamp", None), - dbt.schema.Column("valid_to", "timestamp", None), - dbt.schema.Column("scd_id", "text", None), - dbt.schema.Column("dbt_updated_at", 
"timestamp", None) - ] - - dest_columns = source_columns + extra_cols - - adapter.create_table( - profile, - target_schema, - target_table, - dest_columns, - sort=updated_at, - dist=unique_key - ) - - env = jinja2.Environment() - - ctx = { - "columns": source_columns, - "updated_at": updated_at, - "unique_key": unique_key, - "source_schema": source_schema, - "source_table": source_table, - "target_schema": target_schema, - "target_table": target_table - } - - base_query = dbt.templates.SCDArchiveTemplate - template = env.from_string(base_query, globals=ctx) - rendered = template.render(ctx) - - return rendered diff --git a/dbt/compilation.py b/dbt/compilation.py index e222318821a..5bc17ff050c 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -152,10 +152,6 @@ def get_macros(self, this_project, own_project=None): paths = own_project.get('macro-paths', []) return Source(this_project, own_project=own_project).get_macros(paths) - def analysis_sources(self, project): - paths = project.get('analysis-paths', []) - return Source(project).get_analyses(paths) - def __write(self, build_filepath, payload): target_path = os.path.join(self.project['target-path'], build_filepath) @@ -441,35 +437,6 @@ def compile_nodes(self, linker, nodes, macro_generator): return wrapped_nodes, written_nodes - - def compile_analyses(self, linker, compiled_models): - analyses = self.analysis_sources(self.project) - compiled_analyses = { - analysis: self.compile_model( - linker, analysis, compiled_models - ) for analysis in analyses - } - - written_analyses = [] - referenceable_models = {} - referenceable_models.update(compiled_models) - referenceable_models.update(compiled_analyses) - for analysis in analyses: - injected_stmt = self.add_cte_to_rendered_query( - linker, - analysis, - referenceable_models - ) - - serialized = analysis.serialize() - linker.update_node_data(tuple(analysis.fqn), serialized) - - build_path = analysis.build_path() - self.__write(build_path, injected_stmt) - written_analyses.append(analysis) - - return written_analyses - def generate_macros(self, all_macros): def do_gen(ctx): macros = [] diff --git a/dbt/model.py b/dbt/model.py index acb101b0217..a4e303dcb01 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -18,12 +18,6 @@ class NodeType(object): Model = 'model' Test = 'test' Archive = 'archive' - Analysis = 'analysis' - - -class TestNodeType(object): - SchemaTest = 'schema' - DataTest = 'data' class SourceConfig(object): @@ -503,224 +497,6 @@ def __repr__(self): ) -class Analysis(Model): - dbt_run_type = NodeType.Analysis - - def __init__(self, project, target_dir, rel_filepath, own_project): - return super(Analysis, self).__init__( - project, - target_dir, - rel_filepath, - own_project - ) - - def build_path(self): - build_dir = 'build-analysis' - filename = "{}.sql".format(self.name) - path_parts = [build_dir] + self.fqn[:-1] + [filename] - return os.path.join(*path_parts) - - def __repr__(self): - return "".format(self.name, self.filepath) - - -class SchemaTest(DBTSource): - test_type = "base" - dbt_run_type = NodeType.Test - dbt_test_type = TestNodeType.SchemaTest - - def __init__(self, project, target_dir, rel_filepath, model_name, options): - self.schema = project.context()['env']['schema'] - self.model_name = model_name - self.options = options - self.params = self.get_params(options) - - super(SchemaTest, self).__init__( - project, target_dir, rel_filepath, project - ) - - @property - def fqn(self): - parts = split_path(self.filepath) - name, _ = os.path.splitext(parts[-1]) - 
return [self.project['name']] + parts[1:-1] + \ - ['schema', self.get_filename()] - - def serialize(self): - serialized = DBTSource.serialize(self).copy() - serialized['dbt_test_type'] = self.dbt_test_type - - return serialized - - def get_params(self, options): - return { - "schema": self.schema, - "table": self.model_name, - "field": options - } - - def unique_option_key(self): - return self.params - - def get_filename(self): - key = re.sub('[^0-9a-zA-Z]+', '_', self.unique_option_key()) - filename = "{test_type}_{model_name}_{key}".format( - test_type=self.test_type, model_name=self.model_name, key=key - ) - return filename - - def build_path(self): - build_dir = "test" - filename = "{}.sql".format(self.get_filename()) - path_parts = [build_dir] + self.fqn[:-1] + [filename] - return os.path.join(*path_parts) - - @property - def template(self): - raise NotImplementedError("not implemented") - - def render(self): - return self.template.format(**self.params) - - def __repr__(self): - class_name = self.__class__.__name__ - return "<{} {}.{}: {}>".format( - class_name, self.project['name'], self.name, self.filepath - ) - - -class NotNullSchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_NOT_NULL - test_type = "not_null" - - def unique_option_key(self): - return self.params['field'] - - def describe(self): - return 'VALIDATE NOT NULL {schema}.{table}.{field}' \ - .format(**self.params) - - -class UniqueSchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_UNIQUE - test_type = "unique" - - def unique_option_key(self): - return self.params['field'] - - def describe(self): - return 'VALIDATE UNIQUE {schema}.{table}.{field}'.format(**self.params) - - -class ReferentialIntegritySchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_REFERENTIAL_INTEGRITY - test_type = "relationships" - - def get_params(self, options): - return { - "schema": self.schema, - "child_table": self.model_name, - "child_field": options['from'], - "parent_table": options['to'], - "parent_field": options['field'] - } - - def unique_option_key(self): - return "{child_field}_to_{parent_table}_{parent_field}" \ - .format(**self.params) - - def describe(self): - return """VALIDATE REFERENTIAL INTEGRITY - {schema}.{child_table}.{child_field} to - {schema}.{parent_table}.{parent_field}""".format(**self.params) - - -class AcceptedValuesSchemaTest(SchemaTest): - template = dbt.schema_tester.QUERY_VALIDATE_ACCEPTED_VALUES - test_type = "accepted_values" - - def get_params(self, options): - quoted_values = ["'{}'".format(v) for v in options['values']] - quoted_values_csv = ",".join(quoted_values) - return { - "schema": self.schema, - "table": self.model_name, - "field": options['field'], - "values_csv": quoted_values_csv - } - - def unique_option_key(self): - return "{field}".format(**self.params) - - def describe(self): - return """VALIDATE ACCEPTED VALUES - {schema}.{table}.{field} VALUES - ({values_csv})""".format(**self.params) - - -class SchemaFile(DBTSource): - SchemaTestMap = { - 'not_null': NotNullSchemaTest, - 'unique': UniqueSchemaTest, - 'relationships': ReferentialIntegritySchemaTest, - 'accepted_values': AcceptedValuesSchemaTest - } - - def __init__(self, project, target_dir, rel_filepath, own_project): - super(SchemaFile, self).__init__( - project, target_dir, rel_filepath, own_project - ) - self.og_target_dir = target_dir - self.schema = yaml.safe_load(self.contents) - - def get_test(self, test_type): - if test_type in SchemaFile.SchemaTestMap: - return 
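For reference, the constraint shape these deleted test classes consume, written out as the dict that yaml.safe_load hands to SchemaFile (the users model and its fields are hypothetical); the nested walk mirrors do_compile just below:

    # one test is instantiated per entry under each constraint type
    schema = {
        'users': {
            'constraints': {
                'not_null': ['id', 'email'],
                'unique': ['id'],
                'accepted_values': [{'field': 'status',
                                     'values': ['active', 'disabled']}],
                'relationships': [{'from': 'account_id',
                                   'to': 'accounts',
                                   'field': 'id'}],
            }
        }
    }

    for model_name, constraint_blob in schema.items():
        constraints = constraint_blob.get('constraints', {})
        for constraint_type, constraint_data in constraints.items():
            for params in constraint_data:
                print(model_name, constraint_type, params)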
SchemaFile.SchemaTestMap[test_type] - else: - possible_types = ", ".join(SchemaFile.SchemaTestMap.keys()) - compiler_error( - self, - "Invalid validation type given in {}: '{}'. Possible: {}" - .format(self.filepath, test_type, possible_types) - ) - - def do_compile(self): - schema_tests = [] - for model_name, constraint_blob in self.schema.items(): - constraints = constraint_blob.get('constraints', {}) - for constraint_type, constraint_data in constraints.items(): - if constraint_data is None: - compiler_error( - self, - "no constraints given to test: '{}.{}'" - .format(model_name, constraint_type) - ) - for params in constraint_data: - schema_test_klass = self.get_test(constraint_type) - schema_test = schema_test_klass( - self.project, - self.og_target_dir, - self.rel_filepath, - model_name, - params - ) - schema_tests.append(schema_test) - return schema_tests - - def compile(self): - try: - return self.do_compile() - except TypeError as e: - compiler_error(self, str(e)) - except AttributeError as e: - compiler_error(self, str(e)) - - def __repr__(self): - return "".format( - self.project['name'], self.model_name, self.filepath - ) - - class Csv(DBTSource): def __init__(self, project, target_dir, rel_filepath, own_project): super(Csv, self).__init__( @@ -755,41 +531,3 @@ def __repr__(self): return "".format( self.project['name'], self.name, self.filepath ) - - -class DataTest(DBTSource): - dbt_run_type = NodeType.Test - dbt_test_type = TestNodeType.DataTest - - def __init__(self, project, target_dir, rel_filepath, own_project): - super(DataTest, self).__init__( - project, - target_dir, - rel_filepath, - own_project - ) - - def build_path(self): - build_dir = "test" - filename = "{}.sql".format(self.name) - fqn_parts = self.fqn[0:1] + ['data'] + self.fqn[1:-1] - path_parts = [build_dir] + fqn_parts + [filename] - return os.path.join(*path_parts) - - def serialize(self): - serialized = DBTSource.serialize(self).copy() - serialized['dbt_test_type'] = self.dbt_test_type - - return serialized - - def render(self, query): - return "select count(*) from (\n{}\n) sbq".format(query) - - @property - def immediate_name(self): - return self.name - - def __repr__(self): - return "".format( - self.project['name'], self.name, self.filepath - ) diff --git a/dbt/source.py b/dbt/source.py index 30a900de140..ac8747e9f78 100644 --- a/dbt/source.py +++ b/dbt/source.py @@ -1,7 +1,6 @@ import os.path import fnmatch -from dbt.model import Model, Analysis, SchemaFile, Csv, Macro, \ - DataTest +from dbt.model import Model, Csv, Macro import dbt.clients.system @@ -41,36 +40,6 @@ def get_models(self, model_dirs): Model, file_matches) - def get_analyses(self, analysis_dirs): - file_matches = dbt.clients.system.find_matching( - self.own_project_root, - analysis_dirs, - "[!.#~]*.sql") - - return self.build_models_from_file_matches( - Analysis, - file_matches) - - def get_schemas(self, schema_dirs): - file_matches = dbt.clients.system.find_matching( - self.own_project_root, - schema_dirs, - "[!.#~]*.yml") - - return self.build_models_from_file_matches( - SchemaFile, - file_matches) - - def get_tests(self, test_dirs): - file_matches = dbt.clients.system.find_matching( - self.own_project_root, - test_dirs, - "[!.#~]*.sql") - - return self.build_models_from_file_matches( - DataTest, - file_matches) - def get_csvs(self, csv_dirs): file_matches = dbt.clients.system.find_matching( self.own_project_root, From b3b17eeca0c97e9eb979d6cacb75754923768076 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 
08:53:10 -0500 Subject: [PATCH 12/25] pep8 compliance --- dbt/compilation.py | 9 ++++----- dbt/contracts/common.py | 1 + dbt/contracts/graph/compiled.py | 2 ++ dbt/contracts/graph/parsed.py | 7 +++++-- dbt/contracts/project.py | 2 ++ dbt/graph/selector.py | 2 +- dbt/linker.py | 1 - dbt/main.py | 2 +- dbt/parser.py | 12 +++++++++--- dbt/runner.py | 32 ++++++++++++++++---------------- dbt/utils.py | 5 +++-- 11 files changed, 44 insertions(+), 31 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index 5bc17ff050c..519ef47424e 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -58,7 +58,7 @@ def recursively_prepend_ctes(model, all_models): model = model.copy() prepend_ctes = [] - if model.get('all_ctes_injected') == True: + if model.get('all_ctes_injected') is True: return (model, model.get('extra_cte_ids'), all_models) for cte_id in model.get('extra_cte_ids'): @@ -200,7 +200,7 @@ def do_ref(*args): target_model_id = target_model.get('unique_id') if target_model.get('config', {}) \ - .get('enabled') == False: + .get('enabled') is False: compiler_error( model, "Model '{}' depends on model '{}' which is disabled in " @@ -223,14 +223,14 @@ def wrapped_do_ref(*args): except RuntimeError as e: logger.info("Compiler error in {}".format(model.get('path'))) logger.info("Enabled models:") - for n,m in all_models.items(): + for n, m in all_models.items(): logger.info(" - {}".format(".".join(m.get('fqn')))) raise e return wrapped_do_ref def get_compiler_context(self, linker, model, models, - macro_generator=None): + macro_generator=None): context = self.project.context() if macro_generator is not None: @@ -355,7 +355,6 @@ def new_add_cte_to_rendered_query(self, linker, primary_model, ) return compiled_query - def compile_nodes(self, linker, nodes, macro_generator): all_projects = self.get_all_projects() diff --git a/dbt/contracts/common.py b/dbt/contracts/common.py index 5d6c14b7ddf..4f68581b294 100644 --- a/dbt/contracts/common.py +++ b/dbt/contracts/common.py @@ -3,6 +3,7 @@ from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger + def validate_with(schema, data): try: schema(data) diff --git a/dbt/contracts/graph/compiled.py b/dbt/contracts/graph/compiled.py index 341818f26aa..2c6670c8d34 100644 --- a/dbt/contracts/graph/compiled.py +++ b/dbt/contracts/graph/compiled.py @@ -7,6 +7,7 @@ from dbt.contracts.common import validate_with from dbt.contracts.graph.parsed import parsed_graph_item_contract + compiled_graph_item_contract = parsed_graph_item_contract.extend({ # compiled fields Required('compiled'): bool, @@ -23,6 +24,7 @@ def validate_one(compiled_graph_item): validate_with(compiled_graph_item_contract, compiled_graph_item) + def validate(compiled_graph): for k, v in compiled_graph.items(): validate_with(compiled_graph_item_contract, v) diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py index 04a7e340c46..2064c93d890 100644 --- a/dbt/contracts/graph/parsed.py +++ b/dbt/contracts/graph/parsed.py @@ -7,6 +7,7 @@ from dbt.contracts.common import validate_with from dbt.contracts.graph.unparsed import unparsed_graph_item_contract + config_contract = { Required('enabled'): bool, Required('materialized'): Any('table', 'view', 'ephemeral', 'incremental'), @@ -23,6 +24,7 @@ Optional('dist'): str, } + parsed_graph_item_contract = unparsed_graph_item_contract.extend({ # identifiers Required('unique_id'): All(str, Length(min=1, max=255)), @@ -35,6 +37,7 @@ Required('tags'): All(list, [str]), }) + def 
validate_one(parsed_graph_item): validate_with(parsed_graph_item_contract, parsed_graph_item) @@ -45,8 +48,8 @@ def validate_one(parsed_graph_item): parsed_graph_item.get('config', {}).get('sql_where') is None: raise ValidationException( 'missing `sql_where` for an incremental model') - elif materialization != 'incremental' and \ - parsed_graph_item.get('config', {}).get('sql_where') is not None: + elif (materialization != 'incremental' and + parsed_graph_item.get('config', {}).get('sql_where') is not None): raise ValidationException( 'invalid field `sql_where` for a non-incremental model') diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py index be4f54fc85f..5b5668eea47 100644 --- a/dbt/contracts/project.py +++ b/dbt/contracts/project.py @@ -10,8 +10,10 @@ projects_list_contract = Schema({str: project_contract}) + def validate(project): validate_with(project_contract, project) + def validate_list(projects): validate_with(projects_list_contract, projects) diff --git a/dbt/graph/selector.py b/dbt/graph/selector.py index 9ae42d9beb7..bd9510da6ad 100644 --- a/dbt/graph/selector.py +++ b/dbt/graph/selector.py @@ -131,7 +131,7 @@ def get_nodes_from_spec(project, graph, spec): # include tests that depend on this node. if we aren't running tests, # they'll be filtered out later. child_tests = [n for n in graph.successors(node) - if graph.node.get(n).get('resource_type') == \ + if graph.node.get(n).get('resource_type') == dbt.model.NodeType.Test] test_nodes.update(child_tests) diff --git a/dbt/linker.py b/dbt/linker.py index b548dd3330a..a56dc29777b 100644 --- a/dbt/linker.py +++ b/dbt/linker.py @@ -12,7 +12,6 @@ def from_file(graph_file): return linker - class Linker(object): def __init__(self, data=None): if data is None: diff --git a/dbt/main.py b/dbt/main.py index 565baa62bf0..1c1dd10bdef 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -176,7 +176,7 @@ def invoke_dbt(parsed): log_dir = proj.get('log-path', 'logs') if hasattr(proj.args, 'non_destructive') and \ - proj.args.non_destructive == True: + proj.args.non_destructive is True: flags.NON_DESTRUCTIVE = True else: flags.NON_DESTRUCTIVE = False diff --git a/dbt/parser.py b/dbt/parser.py index 388db393244..398cfba3d1e 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -81,15 +81,19 @@ def _fail_with_undefined_error(self, *args, **kwargs): def get_path(resource_type, package_name, resource_name): return "{}.{}.{}".format(resource_type, package_name, resource_name) + def get_model_path(package_name, resource_name): return get_path(NodeType.Model, package_name, resource_name) + def get_test_path(package_name, resource_name): return get_path(NodeType.Test, package_name, resource_name) + def get_macro_path(package_name, resource_name): return get_path('macros', package_name, resource_name) + def __ref(model): def ref(*args): @@ -114,16 +118,18 @@ def config(*args, **kwargs): return config + def get_fqn(path, package_project_config, extra=[]): parts = dbt.utils.split_path(path) name, _ = os.path.splitext(parts[-1]) fqn = ([package_project_config.get('name')] + - parts[1:-1] + - extra + - [name]) + parts[1:-1] + + extra + + [name]) return fqn + def parse_node(node, node_path, root_project_config, package_project_config, macro_generator=None, tags=[], fqn_extra=[]): parsed_node = copy.deepcopy(node) diff --git a/dbt/runner.py b/dbt/runner.py index b59f625934e..02ca83cfcfe 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -52,7 +52,7 @@ def get_hashed_contents(model): def is_enabled(model): - return model.get('config', {}).get('enabled') == 
True + return model.get('config', {}).get('enabled') is True def print_timestamped_line(msg): @@ -91,7 +91,7 @@ def print_counts(flat_nodes): counts[t] = counts.get(t, 0) + 1 for k, v in counts.items(): - print_timestamped_line("Running {} {}s".format(v,k)) + print_timestamped_line("Running {} {}s".format(v, k)) def print_start_line(node, schema_name, index, total): @@ -222,8 +222,8 @@ def execute_model(profile, model): # never drop existing relations in non destructive mode. pass - elif get_materialization(model) != 'incremental' and \ - existing.get(tmp_name) is not None: + elif (get_materialization(model) != 'incremental' and + existing.get(tmp_name) is not None): # otherwise, for non-incremental things, drop them with IF EXISTS adapter.drop( profile=profile, @@ -326,6 +326,7 @@ def execute_archive(profile, node, context): return result + def run_hooks(profile, hooks, context, source): if type(hooks) not in (list, tuple): hooks = [hooks] @@ -416,7 +417,6 @@ def call_table_exists(schema, table): "already_exists": call_table_exists, } - def inject_runtime_config(self, node): sql = dbt.compilation.compile_string(node.get('wrapped_sql'), self.context) @@ -476,7 +476,8 @@ def safe_execute_node(self, data): except Exception as e: error = ("Unhandled error while executing {filepath}\n{error}" .format( - filepath=node.get('build_path'), error=str(e).strip())) + filepath=node.get('build_path'), + error=str(e).strip())) logger.debug(error) raise e @@ -542,7 +543,7 @@ def execute_nodes(self, node_dependency_list, on_failure, 'on-run-start hooks') node_id_to_index_map = {node.get('unique_id'): i + 1 for (i, node) - in enumerate(flat_nodes)} + in enumerate(flat_nodes)} def get_idx(node): return node_id_to_index_map[node.get('unique_id')] @@ -550,7 +551,7 @@ def get_idx(node): node_results = [] for node_list in node_dependency_list: for i, node in enumerate([node for node in node_list - if node.get('skip')]): + if node.get('skip')]): print_skip_line( schema_name, node.get('name'), get_idx(node), num_nodes) @@ -642,8 +643,8 @@ def get_nodes_to_run(self, graph, include_spec, exclude_spec, to_run = [ n for n in graph.nodes() - if (graph.node.get(n).get('empty') == False - and is_enabled(graph.node.get(n))) + if (graph.node.get(n).get('empty') is False and + is_enabled(graph.node.get(n))) ] filtered_graph = graph.subgraph(to_run) @@ -654,12 +655,11 @@ def get_nodes_to_run(self, graph, include_spec, exclude_spec, post_filter = [ n for n in selected_nodes - if ((graph.node.get(n).get('resource_type') in resource_types) - and get_materialization(graph.node.get(n)) != 'ephemeral' - and (len(tags) == 0 or - # does the node share any tags with the run? - bool(set(graph.node.get(n).get('tags')) & - set(tags)))) + if ((graph.node.get(n).get('resource_type') in resource_types) and + get_materialization(graph.node.get(n)) != 'ephemeral' and + (len(tags) == 0 or + # does the node share any tags with the run? 
+ bool(set(graph.node.get(n).get('tags')) & set(tags)))) ] return set(post_filter) diff --git a/dbt/utils.py b/dbt/utils.py index 00e350eae29..2b751b9aacf 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -123,14 +123,15 @@ def find_model_by_name(all_models, target_model_name, for name, model in all_models.items(): resource_type, package_name, model_name = name.split('.') - if (resource_type == 'model' and \ - ((target_model_name == model_name) and \ + if (resource_type == 'model' and + ((target_model_name == model_name) and (target_model_package is None or target_model_package == package_name))): return model return None + def find_model_by_fqn(models, fqn): for model in models: if tuple(model.fqn) == tuple(fqn): From 7db4b1d4be11c471460886031b960eca8e76ab25 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 08:58:15 -0500 Subject: [PATCH 13/25] remove print() call from runner.py --- dbt/runner.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/dbt/runner.py b/dbt/runner.py index 02ca83cfcfe..defa90296da 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -331,9 +331,6 @@ def run_hooks(profile, hooks, context, source): if type(hooks) not in (list, tuple): hooks = [hooks] - print('hooks') - print(hooks) - ctx = { "target": profile, "state": "start", From 7a0039d0e386d91610a41661a470522ccdc0e1fd Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 08:59:15 -0500 Subject: [PATCH 14/25] remove print() calls from selector.py --- dbt/graph/selector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/dbt/graph/selector.py b/dbt/graph/selector.py index bd9510da6ad..1fe227925bb 100644 --- a/dbt/graph/selector.py +++ b/dbt/graph/selector.py @@ -121,8 +121,6 @@ def get_nodes_from_spec(project, graph, spec): if select_children: for node in selected_nodes: child_nodes = nx.descendants(graph, node) - print('\nchild_nodes') - print(child_nodes) additional_nodes.update(child_nodes) model_nodes = selected_nodes | additional_nodes From 6a3202ee1083fa0d85925329225438fc258e16de Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 09:00:31 -0500 Subject: [PATCH 15/25] remove schema_tester, dbt.archival import --- dbt/model.py | 2 - dbt/schema_tester.py | 118 ------------------------------------------- 2 files changed, 120 deletions(-) delete mode 100644 dbt/schema_tester.py diff --git a/dbt/model.py b/dbt/model.py index a4e303dcb01..3bd4b09c13f 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -4,9 +4,7 @@ import re from dbt.templates import BaseCreateTemplate, ArchiveInsertTemplate from dbt.utils import split_path -import dbt.schema_tester import dbt.project -import dbt.archival from dbt.adapters.factory import get_adapter from dbt.utils import deep_merge, DBTConfigKeys, compiler_error, \ compiler_warning diff --git a/dbt/schema_tester.py b/dbt/schema_tester.py deleted file mode 100644 index 42da30a5a6d..00000000000 --- a/dbt/schema_tester.py +++ /dev/null @@ -1,118 +0,0 @@ -import os - -from dbt.logger import GLOBAL_LOGGER as logger -import dbt.targets - -import psycopg2 -import logging -import time -import datetime - - -QUERY_VALIDATE_NOT_NULL = """ -with validation as ( - select {field} as f - from "{schema}"."{table}" -) -select count(*) from validation where f is null -""" - -QUERY_VALIDATE_UNIQUE = """ -with validation as ( - select {field} as f - from "{schema}"."{table}" - where {field} is not null -), -validation_errors as ( - select f from validation group by f having count(*) > 1 -) -select count(*) from validation_errors -""" - 
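Each of these deleted validation templates reduces its constraint to a single count of violating rows, which is why a zero result always reads as a pass. Rendering the not-null template the way SchemaTest.render did, with str.format over the get_params dict:

    QUERY_VALIDATE_NOT_NULL = """
    with validation as (
        select {field} as f
        from "{schema}"."{table}"
    )
    select count(*) from validation where f is null
    """

    # params as built by NotNullSchemaTest.get_params
    params = {'schema': 'analytics', 'table': 'users', 'field': 'email'}
    print(QUERY_VALIDATE_NOT_NULL.format(**params))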
-QUERY_VALIDATE_ACCEPTED_VALUES = """ -with all_values as ( - select distinct {field} as f - from "{schema}"."{table}" -), -validation_errors as ( - select f from all_values where f not in ({values_csv}) -) -select count(*) from validation_errors -""" - -QUERY_VALIDATE_REFERENTIAL_INTEGRITY = """ -with parent as ( - select {parent_field} as id - from "{schema}"."{parent_table}" -), child as ( - select {child_field} as id - from "{schema}"."{child_table}" -) -select count(*) from child -where id not in (select id from parent) and id is not null -""" - -DDL_TEST_RESULT_CREATE = """ -create table if not exists {schema}.dbt_test_results ( - tested_at timestamp without time zone, - model_name text, - errored bool, - skipped bool, - failed bool, - count_failures integer, - execution_time double precision -); -""" - - -class SchemaTester(object): - def __init__(self, project): - self.project = project - - self.test_started_at = datetime.datetime.now() - - def get_target(self): - target_cfg = self.project.run_environment() - return dbt.targets.get_target(target_cfg) - - def execute_query(self, model, sql): - target = self.get_target() - - with target.get_handle() as handle: - with handle.cursor() as cursor: - try: - logger.debug("SQL: %s", sql) - pre = time.time() - cursor.execute(sql) - post = time.time() - logger.debug( - "SQL status: %s in %d seconds", - cursor.statusmessage, post-pre) - except psycopg2.ProgrammingError as e: - logger.debug('programming error: %s', sql) - return e.diag.message_primary - except Exception as e: - logger.debug( - 'encountered exception while running: %s', sql) - e.model = model - raise e - - result = cursor.fetchone() - if len(result) != 1: - logger.debug("SQL: %s", sql) - logger.debug("RESULT: %s", result) - raise RuntimeError( - "Unexpected validation result. 
Expected 1 record, " - "got {}".format(len(result))) - else: - return result[0] - - def validate_schema(self, schema_test): - sql = schema_test.render() - num_rows = self.execute_query(model, sql) - if num_rows == 0: - logger.info(" OK") - yield True - else: - logger.info(" FAILED ({})".format(num_rows)) - yield False From 2a2aec2eca12be500e9fea1813f772f1e139b137 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 10:18:32 -0500 Subject: [PATCH 16/25] fix unit tests, compile cmd --- dbt/parser.py | 3 +- dbt/task/compile.py | 4 +-- test/unit/test_compiler.py | 9 ++++++ test/unit/test_graph.py | 57 +++++++++++++++++++++----------------- test/unit/test_parser.py | 28 +++++++++++++++++++ 5 files changed, 73 insertions(+), 28 deletions(-) diff --git a/dbt/parser.py b/dbt/parser.py index 398cfba3d1e..e545c09acca 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -180,7 +180,8 @@ def parse_node(node, node_path, root_project_config, package_project_config, return parsed_node -def parse_sql_nodes(nodes, root_project, projects, macro_generator, tags=[]): +def parse_sql_nodes(nodes, root_project, projects, macro_generator=None, + tags=[]): to_return = {} dbt.contracts.graph.unparsed.validate(nodes) diff --git a/dbt/task/compile.py b/dbt/task/compile.py index 04eb627d6b5..5321225614b 100644 --- a/dbt/task/compile.py +++ b/dbt/task/compile.py @@ -13,6 +13,6 @@ def run(self): results = compiler.compile() stat_line = ", ".join( - ["{} {}".format(results[k], k) for k in CompilableEntities] - ) + ["{} {}s".format(ct, t) for t, ct in results.items()]) + logger.info("Compiled {}".format(stat_line)) diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index 7c73be5c357..5d1e3053645 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -57,6 +57,7 @@ def test__prepend_ctes__already_has_cte(self): 'model.root.ephemeral' ], 'config': self.model_config, + 'tags': [], 'path': 'view.sql', 'raw_sql': 'select * from {{ref("ephemeral")}}', 'compiled': True, @@ -79,6 +80,7 @@ def test__prepend_ctes__already_has_cte(self): 'root_path': '/usr/src/app', 'depends_on': [], 'config': ephemeral_config, + 'tags': [], 'path': 'ephemeral.sql', 'raw_sql': 'select * from source_table', 'compiled': True, @@ -119,6 +121,7 @@ def test__prepend_ctes__no_ctes(self): 'root_path': '/usr/src/app', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'view.sql', 'raw_sql': ('with cte as (select * from something_else) ' 'select * from source_table'), @@ -140,6 +143,7 @@ def test__prepend_ctes__no_ctes(self): 'root_path': '/usr/src/app', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'view.sql', 'raw_sql': 'select * from source_table', 'compiled': True, @@ -189,6 +193,7 @@ def test__prepend_ctes(self): 'model.root.ephemeral' ], 'config': self.model_config, + 'tags': [], 'path': 'view.sql', 'raw_sql': 'select * from {{ref("ephemeral")}}', 'compiled': True, @@ -210,6 +215,7 @@ def test__prepend_ctes(self): 'root_path': '/usr/src/app', 'depends_on': [], 'config': ephemeral_config, + 'tags': [], 'path': 'ephemeral.sql', 'raw_sql': 'select * from source_table', 'compiled': True, @@ -256,6 +262,7 @@ def test__prepend_ctes__multiple_levels(self): 'model.root.ephemeral' ], 'config': self.model_config, + 'tags': [], 'path': 'view.sql', 'raw_sql': 'select * from {{ref("ephemeral")}}', 'compiled': True, @@ -277,6 +284,7 @@ def test__prepend_ctes__multiple_levels(self): 'root_path': '/usr/src/app', 'depends_on': [], 'config': ephemeral_config, + 'tags': [], 
'path': 'ephemeral.sql', 'raw_sql': 'select * from {{ref("ephemeral_level_two")}}', 'compiled': True, @@ -298,6 +306,7 @@ def test__prepend_ctes__multiple_levels(self): 'root_path': '/usr/src/app', 'depends_on': [], 'config': ephemeral_config, + 'tags': [], 'path': 'ephemeral_level_two.sql', 'raw_sql': 'select * from source_table', 'compiled': True, diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 29c0f888a70..989ca3cb808 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -131,7 +131,7 @@ def test__single_model(self): self.assertEquals( self.graph_result.nodes(), - ['model.root.model_one']) + ['model.test_models_compile.model_one']) self.assertEquals( self.graph_result.edges(), @@ -149,14 +149,14 @@ def test__two_models_simple_ref(self): six.assertCountEqual(self, self.graph_result.nodes(), [ - 'model.root.model_one', - 'model.root.model_two', + 'model.test_models_compile.model_one', + 'model.test_models_compile.model_two', ]) six.assertCountEqual( self, self.graph_result.edges(), - [ ('model.root.model_one','model.root.model_two',) ]) + [ ('model.test_models_compile.model_one','model.test_models_compile.model_two',) ]) def test__model_materializations(self): self.use_models({ @@ -190,8 +190,9 @@ def test__model_materializations(self): nodes = self.graph_result.node for model, expected in expected_materialization.items(): - actual = nodes['model.root.{}'.format(model)].get('config', {}) \ - .get('materialized') + key = 'model.test_models_compile.{}'.format(model) + actual = nodes[key].get('config', {}) \ + .get('materialized') self.assertEquals(actual, expected) def test__model_enabled(self): @@ -215,11 +216,13 @@ def test__model_enabled(self): six.assertCountEqual( self, self.graph_result.nodes(), - ['model.root.model_one', 'model.root.model_two']) + ['model.test_models_compile.model_one', + 'model.test_models_compile.model_two']) six.assertCountEqual( self, self.graph_result.edges(), - [('model.root.model_one','model.root.model_two',)]) + [('model.test_models_compile.model_one', + 'model.test_models_compile.model_two',)]) def test__model_incremental_without_sql_where_fails(self): self.use_models({ @@ -260,7 +263,7 @@ def test__model_incremental(self): compiler = self.get_compiler(self.get_project(cfg)) compiler.compile() - node = 'model.root.model_one' + node = 'model.test_models_compile.model_one' self.assertEqual(self.graph_result.nodes(), [node]) self.assertEqual(self.graph_result.edges(), []) @@ -288,19 +291,23 @@ def test__topological_ordering(self): six.assertCountEqual(self, self.graph_result.nodes(), [ - 'model.root.model_1', - 'model.root.model_2', - 'model.root.model_3', - 'model.root.model_4', + 'model.test_models_compile.model_1', + 'model.test_models_compile.model_2', + 'model.test_models_compile.model_3', + 'model.test_models_compile.model_4', ]) six.assertCountEqual(self, self.graph_result.edges(), [ - ('model.root.model_1', 'model.root.model_2',), - ('model.root.model_1', 'model.root.model_3',), - ('model.root.model_2', 'model.root.model_3',), - ('model.root.model_3', 'model.root.model_4',), + ('model.test_models_compile.model_1', + 'model.test_models_compile.model_2',), + ('model.test_models_compile.model_1', + 'model.test_models_compile.model_3',), + ('model.test_models_compile.model_2', + 'model.test_models_compile.model_3',), + ('model.test_models_compile.model_3', + 'model.test_models_compile.model_4',), ]) linker = dbt.linker.Linker() @@ -308,10 +315,10 @@ def test__topological_ordering(self): actual_ordering = 
linker.as_topological_ordering() expected_ordering = [ - 'model.root.model_1', - 'model.root.model_2', - 'model.root.model_3', - 'model.root.model_4', + 'model.test_models_compile.model_1', + 'model.test_models_compile.model_2', + 'model.test_models_compile.model_3', + 'model.test_models_compile.model_4', ] self.assertEqual(actual_ordering, expected_ordering) @@ -337,10 +344,10 @@ def test__dependency_list(self): actual_dep_list = linker.as_dependency_list() expected_dep_list = [ - ['model.root.model_1'], - ['model.root.model_2'], - ['model.root.model_3'], - ['model.root.model_4'], + ['model.test_models_compile.model_1'], + ['model.test_models_compile.model_2'], + ['model.test_models_compile.model_3'], + ['model.test_models_compile.model_4'], ] self.assertEqual(actual_dep_list, expected_dep_list) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index f500ff5594b..b5c3100d987 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -67,6 +67,7 @@ def test__single_model(self): 'root_path': '/usr/src/app', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'model_one.sql', 'raw_sql': self.find_input_by_name( models, 'model_one').get('raw_sql') @@ -99,6 +100,7 @@ def test__empty_model(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'model_one.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -140,6 +142,7 @@ def test__simple_dependency(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'base.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -154,6 +157,7 @@ def test__simple_dependency(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'events_tx.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -220,6 +224,7 @@ def test__multiple_dependencies(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'events.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -234,6 +239,7 @@ def test__multiple_dependencies(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'sessions.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -248,6 +254,7 @@ def test__multiple_dependencies(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'events_tx.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -262,6 +269,7 @@ def test__multiple_dependencies(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'sessions_tx.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -276,6 +284,7 @@ def test__multiple_dependencies(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'multi.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -342,6 +351,7 @@ def test__multiple_dependencies__packages(self): 'package_name': 'snowplow', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'events.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -356,6 +366,7 @@ def test__multiple_dependencies__packages(self): 'package_name': 'snowplow', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'sessions.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -370,6 
+381,7 @@ def test__multiple_dependencies__packages(self): 'package_name': 'snowplow', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'events_tx.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -384,6 +396,7 @@ def test__multiple_dependencies__packages(self): 'package_name': 'snowplow', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'sessions_tx.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -398,6 +411,7 @@ def test__multiple_dependencies__packages(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'path': 'multi.sql', 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( @@ -437,6 +451,7 @@ def test__in_model_config(self): 'package_name': 'root', 'depends_on': [], 'config': self.model_config, + 'tags': [], 'root_path': '/usr/src/app', 'path': 'model_one.sql', 'raw_sql': self.find_input_by_name( @@ -516,6 +531,7 @@ def test__root_project_config(self): 'depends_on': [], 'path': 'table.sql', 'config': self.model_config, + 'tags': [], 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'table').get('raw_sql') @@ -530,6 +546,7 @@ def test__root_project_config(self): 'depends_on': [], 'path': 'ephemeral.sql', 'config': ephemeral_config, + 'tags': [], 'root_path': '/usr/src/app', 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') @@ -545,6 +562,7 @@ def test__root_project_config(self): 'path': 'view.sql', 'root_path': '/usr/src/app', 'config': view_config, + 'tags': [], 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') } @@ -669,6 +687,7 @@ def test__other_project_config(self): 'path': 'table.sql', 'root_path': '/usr/src/app', 'config': self.model_config, + 'tags': [], 'raw_sql': self.find_input_by_name( models, 'table').get('raw_sql') }, @@ -683,6 +702,7 @@ def test__other_project_config(self): 'path': 'ephemeral.sql', 'root_path': '/usr/src/app', 'config': ephemeral_config, + 'tags': [], 'raw_sql': self.find_input_by_name( models, 'ephemeral').get('raw_sql') }, @@ -697,6 +717,7 @@ def test__other_project_config(self): 'path': 'view.sql', 'root_path': '/usr/src/app', 'config': view_config, + 'tags': [], 'raw_sql': self.find_input_by_name( models, 'view').get('raw_sql') }, @@ -711,6 +732,7 @@ def test__other_project_config(self): 'path': 'disabled.sql', 'root_path': '/usr/src/app', 'config': disabled_config, + 'tags': [], 'raw_sql': self.find_input_by_name( models, 'disabled').get('raw_sql') }, @@ -725,6 +747,7 @@ def test__other_project_config(self): 'path': 'models/views/package.sql', 'root_path': '/usr/src/app', 'config': sort_config, + 'tags': [], 'raw_sql': self.find_input_by_name( models, 'package').get('raw_sql') } @@ -787,6 +810,7 @@ def test__simple_schema_test(self): 'depends_on': [], 'config': self.model_config, 'path': 'test_one.yml', + 'tags': ['schema'], 'raw_sql': not_null_sql, }, 'test.root.unique_model_one_id': { @@ -800,6 +824,7 @@ def test__simple_schema_test(self): 'depends_on': [], 'config': self.model_config, 'path': 'test_one.yml', + 'tags': ['schema'], 'raw_sql': unique_sql, }, 'test.root.accepted_values_model_one_id': { @@ -813,6 +838,7 @@ def test__simple_schema_test(self): 'depends_on': [], 'config': self.model_config, 'path': 'test_one.yml', + 'tags': ['schema'], 'raw_sql': accepted_values_sql, }, 'test.root.relationships_model_one_id_to_model_two_id': { @@ -826,6 +852,7 @@ def test__simple_schema_test(self): 'depends_on': [], 'config': self.model_config, 
'path': 'test_one.yml', + 'tags': ['schema'], 'raw_sql': relationships_sql, } @@ -862,6 +889,7 @@ def test__simple_data_test(self): 'config': self.model_config, 'path': 'no_events.sql', 'root_path': '/usr/src/app', + 'tags': [], 'raw_sql': self.find_input_by_name( tests, 'no_events').get('raw_sql') } From 305c80cfdb55387c9e49849148383eb5a7d8ceec Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 12:09:34 -0500 Subject: [PATCH 17/25] functional test improvements --- dbt/adapters/snowflake.py | 4 ++-- dbt/runner.py | 24 ++++++++++-------------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/dbt/adapters/snowflake.py b/dbt/adapters/snowflake.py index c89c64fefbb..231d26ec4f1 100644 --- a/dbt/adapters/snowflake.py +++ b/dbt/adapters/snowflake.py @@ -182,7 +182,7 @@ def rename(cls, profile, from_name, to_name, model_name=None): @classmethod def execute_model(cls, profile, model): - parts = re.split(r'-- (DBT_OPERATION .*)', model.compiled_contents) + parts = re.split(r'-- (DBT_OPERATION .*)', model.get('wrapped_sql')) connection = cls.get_connection(profile) if flags.STRICT_MODE: @@ -216,7 +216,7 @@ def call_expand_target_column_types(kwargs): func_map[function](kwargs) else: handle, cursor = cls.add_query_to_transaction( - part, connection, model.name) + part, connection, model.get('name')) handle.commit() diff --git a/dbt/runner.py b/dbt/runner.py index defa90296da..a8e16089b2c 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -158,7 +158,7 @@ def execute_test(profile, test): rows = cursor.fetchall() - adapter.rollback(profile) + adapter.commit(profile) cursor.close() @@ -195,10 +195,9 @@ def print_model_result_line(result, schema_name, index, total): result.execution_time) -def execute_model(profile, model): +def execute_model(profile, model, existing): adapter = get_adapter(profile) schema = adapter.get_default_schema(profile) - existing = adapter.query_for_existing(profile, schema) tmp_name = '{}__dbt_tmp'.format(model.get('name')) @@ -387,11 +386,6 @@ def __init__(self, project, target_path, args): adapter = get_adapter(profile) schema_name = adapter.get_default_schema(profile) - self.existing_models = adapter.query_for_existing( - profile, - schema_name - ) - def call_get_columns_in_table(schema_name, table_name): return adapter.get_columns_in_table( profile, schema_name, table_name) @@ -423,7 +417,7 @@ def inject_runtime_config(self, node): return node def deserialize_graph(self): - logger.info("Loading dependency graph file") + logger.info("Loading dependency graph file.") base_target_path = self.project['target-path'] graph_file = os.path.join( @@ -433,7 +427,7 @@ def deserialize_graph(self): return dbt.linker.from_file(graph_file) - def execute_node(self, node): + def execute_node(self, node, existing): profile = self.project.run_environment() logger.debug("executing node %s", node.get('unique_id')) @@ -441,7 +435,7 @@ def execute_node(self, node): node = self.inject_runtime_config(node) if node.get('resource_type') == NodeType.Model: - result = execute_model(profile, node) + result = execute_model(profile, node, existing) elif node.get('resource_type') == NodeType.Test: result = execute_test(profile, node) elif node.get('resource_type') == NodeType.Archive: @@ -450,14 +444,14 @@ def execute_node(self, node): return result def safe_execute_node(self, data): - node = data + node, existing = data start_time = time.time() error = None try: - status = self.execute_node(node) + status = self.execute_node(node, existing) except (RuntimeError, 
dbt.exceptions.ProgrammingException, psycopg2.ProgrammingError, @@ -528,6 +522,8 @@ def execute_nodes(self, node_dependency_list, on_failure, ) logger.info("Running!") + existing = adapter.query_for_existing(profile, schema_name) + pool = ThreadPool(num_threads) logger.info("") @@ -606,7 +602,7 @@ def on_complete(run_model_results): map_result = pool.map_async( self.safe_execute_node, - local_nodes, + [(node, existing,) for node in local_nodes], callback=on_complete ) map_result.wait() From 77c480a1be1a3ca8c891ad933f3c7700cfd8353d Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 15:22:14 -0500 Subject: [PATCH 18/25] fix skipping, functional testing w/ revzilla --- dbt/compilation.py | 20 ++++++++++- dbt/model.py | 1 + dbt/parser.py | 2 +- dbt/runner.py | 33 ++++++++++++++---- dbt/task/archive.py | 13 ++----- dbt/task/compile.py | 14 +++----- dbt/task/run.py | 16 +++------ dbt/task/test.py | 15 ++------ test/unit/test_parser.py | 74 ++++++++++++++++++++++++++++++++++++---- test/unit/test_runner.py | 27 ++++++++++----- 10 files changed, 145 insertions(+), 70 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index 519ef47424e..56b2e5f27de 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -33,6 +33,23 @@ graph_file_name = 'graph.yml' +def compile_and_print_status(project, args): + compiler = Compiler(project, args) + compiler.initialize() + results = { + NodeType.Model: 0, + NodeType.Test: 0, + NodeType.Archive: 0, + } + + results.update(compiler.compile()) + + stat_line = ", ".join( + ["{} {}s".format(ct, t) for t, ct in results.items()]) + + logger.info("Compiled {}".format(stat_line)) + + def compile_string(string, ctx): try: env = jinja2.Environment() @@ -224,7 +241,8 @@ def wrapped_do_ref(*args): logger.info("Compiler error in {}".format(model.get('path'))) logger.info("Enabled models:") for n, m in all_models.items(): - logger.info(" - {}".format(".".join(m.get('fqn')))) + if m.get('resource_type') == NodeType.Model: + logger.info(" - {}".format(m.get('unique_id'))) raise e return wrapped_do_ref diff --git a/dbt/model.py b/dbt/model.py index 3bd4b09c13f..6eb986d280d 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -2,6 +2,7 @@ import yaml import jinja2 import re + from dbt.templates import BaseCreateTemplate, ArchiveInsertTemplate from dbt.utils import split_path import dbt.project diff --git a/dbt/parser.py b/dbt/parser.py index e545c09acca..ad3ce1f66e8 100644 --- a/dbt/parser.py +++ b/dbt/parser.py @@ -123,7 +123,7 @@ def get_fqn(path, package_project_config, extra=[]): parts = dbt.utils.split_path(path) name, _ = os.path.splitext(parts[-1]) fqn = ([package_project_config.get('name')] + - parts[1:-1] + + parts[:-1] + extra + [name]) diff --git a/dbt/runner.py b/dbt/runner.py index a8e16089b2c..eff946173e2 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -195,6 +195,20 @@ def print_model_result_line(result, schema_name, index, total): result.execution_time) + +def print_results_line(results): + stats = {} + + for result in results: + stats[result.node.get('resource_type')] = stats.get( + result.node.get('resource_type'), 0) + 1 + + stat_line = ", ".join( + ["{} {}s".format(ct, t) for t, ct in stats.items()]) + + print_timestamped_line("Finished running {}.".format(stat_line)) + + def execute_model(profile, model, existing): adapter = get_adapter(profile) schema = adapter.get_default_schema(profile) @@ -432,6 +446,9 @@ def execute_node(self, node, existing): logger.debug("executing node %s", node.get('unique_id')) + if node.get('skip') is 
True: + return RunModelResult(node, skip=True) + node = self.inject_runtime_config(node) if node.get('resource_type') == NodeType.Model: @@ -463,7 +480,9 @@ def safe_execute_node(self, data): if type(e) == psycopg2.InternalError and \ ABORTED_TRANSACTION_STRING == e.diag.message_primary: return RunModelResult( - node, error=ABORTED_TRANSACTION_STRING, status="SKIP") + node, + error='{}\n'.format(ABORTED_TRANSACTION_STRING), + status="SKIP") except Exception as e: error = ("Unhandled error while executing {filepath}\n{error}" .format( @@ -494,8 +513,9 @@ def skip_dependent(node): dependent_nodes = linker.get_dependent_nodes(node.get('unique_id')) for node in dependent_nodes: if node in selected_nodes: - # TODO fix skipping - pass + node_data = linker.get_node(node) + node_data['skip'] = True + linker.update_node_data(node, node_data) return skip_dependent @@ -545,8 +565,8 @@ def get_idx(node): for node_list in node_dependency_list: for i, node in enumerate([node for node in node_list if node.get('skip')]): - print_skip_line( - schema_name, node.get('name'), get_idx(node), num_nodes) + print_skip_line(node, schema_name, node.get('name'), + get_idx(node), num_nodes) node_result = RunModelResult(node, skip=True) node_results.append(node_result) @@ -620,8 +640,7 @@ def on_complete(run_model_results): 'on-run-end hooks') logger.info("") - logger.info("FIXME") - # logger.info(runner.post_run_all_msg(model_results)) + print_results_line(node_results) return node_results diff --git a/dbt/task/archive.py b/dbt/task/archive.py index 160c8bfa7ec..752daf404ae 100644 --- a/dbt/task/archive.py +++ b/dbt/task/archive.py @@ -8,18 +8,9 @@ def __init__(self, args, project): self.args = args self.project = project - def compile(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - results = compiler.compile() - - stat_line = ", ".join( - ["{} {}s".format(ct, t) for t, ct in results.items()]) - - logger.info("Compiled {}".format(stat_line)) - def run(self): - self.compile() + dbt.compilation.compile_and_print_status( + self.project, self.args) runner = RunManager( self.project, diff --git a/dbt/task/compile.py b/dbt/task/compile.py index 5321225614b..0868c54a9b3 100644 --- a/dbt/task/compile.py +++ b/dbt/task/compile.py @@ -1,6 +1,6 @@ -from dbt.compilation import Compiler, CompilableEntities -from dbt.logger import GLOBAL_LOGGER as logger +import dbt.compilation +from dbt.logger import GLOBAL_LOGGER as logger class CompileTask: def __init__(self, args, project): @@ -8,11 +8,5 @@ def __init__(self, args, project): self.project = project def run(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - results = compiler.compile() - - stat_line = ", ".join( - ["{} {}s".format(ct, t) for t, ct in results.items()]) - - logger.info("Compiled {}".format(stat_line)) + dbt.compilation.compile_and_print_status( + self.project, self.args) diff --git a/dbt/task/run.py b/dbt/task/run.py index 1ddfc88413c..c6f5eee9f9c 100644 --- a/dbt/task/run.py +++ b/dbt/task/run.py @@ -1,6 +1,7 @@ from __future__ import print_function -from dbt.compilation import Compiler, CompilableEntities +import dbt.compilation + from dbt.logger import GLOBAL_LOGGER as logger from dbt.runner import RunManager @@ -12,18 +13,9 @@ def __init__(self, args, project): self.args = args self.project = project - def compile(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - results = compiler.compile() - - stat_line = ", ".join( - ["{} {}s".format(ct, t) for t, ct in 
results.items()]) - - logger.info("Compiled {}".format(stat_line)) - def run(self): - self.compile() + dbt.compilation.compile_and_print_status( + self.project, self.args) runner = RunManager( self.project, self.project['target-path'], self.args diff --git a/dbt/task/test.py b/dbt/task/test.py index b018540ffb7..2771c6f6e8e 100644 --- a/dbt/task/test.py +++ b/dbt/task/test.py @@ -1,5 +1,5 @@ +import dbt.compilation -from dbt.compilation import Compiler, CompilableEntities from dbt.runner import RunManager from dbt.logger import GLOBAL_LOGGER as logger @@ -19,18 +19,9 @@ def __init__(self, args, project): self.args = args self.project = project - def compile(self): - compiler = Compiler(self.project, self.args) - compiler.initialize() - results = compiler.compile() - - stat_line = ", ".join([ - "{} {}s".format(ct, t) for t, ct in results.items() - ]) - logger.info("Compiled {}".format(stat_line)) - def run(self): - self.compile() + dbt.compilation.compile_and_print_status( + self.project, self.args) runner = RunManager( self.project, self.project['target-path'], self.args) diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index b5c3100d987..6e55030f382 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -75,6 +75,63 @@ def test__single_model(self): } ) + def test__single_model__nested_configuration(self): + models = [{ + 'name': 'model_one', + 'resource_type': 'model', + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'path': 'nested/path/model_one.sql', + 'raw_sql': ("select * from events"), + }] + + self.root_project_config = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + 'models': { + 'materialized': 'ephemeral', + 'root': { + 'nested': { + 'path': { + 'materialized': 'ephemeral' + } + } + } + } + } + + ephemeral_config = self.model_config.copy() + ephemeral_config.update({ + 'materialized': 'ephemeral' + }) + + self.assertEquals( + dbt.parser.parse_sql_nodes( + models, + self.root_project_config, + {'root': self.root_project_config, + 'snowplow': self.snowplow_project_config}), + { + 'model.root.model_one': { + 'name': 'model_one', + 'resource_type': 'model', + 'unique_id': 'model.root.model_one', + 'fqn': ['root', 'nested', 'path', 'model_one'], + 'empty': False, + 'package_name': 'root', + 'root_path': '/usr/src/app', + 'depends_on': [], + 'config': ephemeral_config, + 'tags': [], + 'path': 'nested/path/model_one.sql', + 'raw_sql': self.find_input_by_name( + models, 'model_one').get('raw_sql') + } + } + ) + def test__empty_model(self): models = [{ 'name': 'model_one', @@ -597,10 +654,12 @@ def test__other_project_config(self): 'version': '0.1', 'project-root': os.path.abspath('./dbt_modules/snowplow'), 'models': { - 'enabled': False, - 'views': { - 'materialized': 'table', - 'sort': 'timestamp' + 'snowplow': { + 'enabled': False, + 'views': { + 'materialized': 'table', + 'sort': 'timestamp' + } } } } @@ -638,7 +697,7 @@ def test__other_project_config(self): 'name': 'package', 'resource_type': 'model', 'package_name': 'snowplow', - 'path': 'models/views/package.sql', + 'path': 'views/package.sql', 'root_path': '/usr/src/app', 'raw_sql': ("select * from events"), }] @@ -666,7 +725,8 @@ def test__other_project_config(self): sort_config = self.model_config.copy() sort_config.update({ 'enabled': False, - 'materialized': 'view' + 'materialized': 'view', + 'sort': 'timestamp', }) self.assertEquals( @@ -744,7 +804,7 @@ def test__other_project_config(self): 'empty': False, 'package_name': 
'snowplow', 'depends_on': [], - 'path': 'models/views/package.sql', + 'path': 'views/package.sql', 'root_path': '/usr/src/app', 'config': sort_config, 'tags': [], diff --git a/test/unit/test_runner.py b/test/unit/test_runner.py index 16fbbe20b6c..e0797b85205 100644 --- a/test/unit/test_runner.py +++ b/test/unit/test_runner.py @@ -101,7 +101,8 @@ def test__execute_model__view(self, dbt.runner.execute_model( self.profile, - model) + model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() @@ -117,12 +118,14 @@ def test__execute_model__view__existing(self, mock_adapter_truncate, mock_adapter_rename, mock_adapter_execute_model): - self.existing = {'view': 'view'} + self.existing = {'view': 'table'} + model = self.model.copy() dbt.runner.execute_model( self.profile, - model) + model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() @@ -143,7 +146,8 @@ def test__execute_model__table(self, dbt.runner.execute_model( self.profile, - model) + model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() @@ -166,7 +170,8 @@ def test__execute_model__table__existing(self, dbt.runner.execute_model( self.profile, - self.model) + self.model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() @@ -189,7 +194,8 @@ def test__execute_model__view__destructive(self, dbt.runner.execute_model( self.profile, - model) + model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() @@ -212,7 +218,8 @@ def test__execute_model__view__existing__destructive(self, dbt.runner.execute_model( self.profile, - model) + model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_called_once() @@ -235,7 +242,8 @@ def test__execute_model__table__destructive(self, dbt.runner.execute_model( self.profile, - model) + model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_not_called() @@ -260,7 +268,8 @@ def test__execute_model__table__existing__destructive(self, dbt.runner.execute_model( self.profile, - self.model) + self.model, + existing=self.existing) dbt.adapters.postgres.PostgresAdapter.drop.assert_called_once() From ac3e5ee4266c4955ef0f8f33cb7f2dbc47bbd72b Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 16:37:01 -0500 Subject: [PATCH 19/25] hooks work... finishing up? 
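
This patch also reworks run-result reporting: results are aggregated by
resource type and printed alongside the total execution time. A minimal
sketch of the counting done in print_results_line() below (illustrative
only; it assumes each element of `results` is a RunModelResult whose
`node` attribute is a dict carrying 'resource_type', as elsewhere in
dbt/runner.py):

    # aggregate run results by resource type, e.g. "3 models, 2 tests"
    stats = {}
    for result in results:
        resource_type = result.node.get('resource_type')
        stats[resource_type] = stats.get(resource_type, 0) + 1

    stat_line = ", ".join(
        "{} {}s".format(count, t) for t, count in stats.items())
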
--- dbt/compilation.py | 3 +- dbt/main.py | 2 +- dbt/runner.py | 22 ++- dbt/task/test.py | 1 - .../models/hooks.sql | 0 test/integration/014_hook_tests/seed.sql | 39 +++++ .../integration/014_hook_tests/seed_model.sql | 19 +++ .../seed.sql => 014_hook_tests/seed_run.sql} | 4 +- .../014_hook_tests/test_model_hooks.py | 140 ++++++++++++++++++ .../test_run_hooks.py} | 13 +- 10 files changed, 225 insertions(+), 18 deletions(-) rename test/integration/{014_pre_post_run_hook_tests => 014_hook_tests}/models/hooks.sql (100%) create mode 100644 test/integration/014_hook_tests/seed.sql create mode 100644 test/integration/014_hook_tests/seed_model.sql rename test/integration/{014_pre_post_run_hook_tests/seed.sql => 014_hook_tests/seed_run.sql} (80%) create mode 100644 test/integration/014_hook_tests/test_model_hooks.py rename test/integration/{014_pre_post_run_hook_tests/test_pre_post_run_hooks.py => 014_hook_tests/test_run_hooks.py} (90%) diff --git a/dbt/compilation.py b/dbt/compilation.py index 56b2e5f27de..9da13eb908e 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -415,7 +415,8 @@ def compile_nodes(self, linker, nodes, macro_generator): injected_node.get('path'), all_projects.get(injected_node.get('package_name'))) - model._config = injected_node.get('config', {}) + cfg = injected_node.get('config', {}) + model._config = cfg context = self.get_context(linker, model, injected_nodes) diff --git a/dbt/main.py b/dbt/main.py index 1c1dd10bdef..a1813d5114c 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -28,7 +28,7 @@ def main(args=None): args = sys.argv[1:] try: - return handle(args) + handle(args) except RuntimeError as e: logger.info("Encountered an error:") diff --git a/dbt/runner.py b/dbt/runner.py index eff946173e2..6db3ec57ec5 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -60,7 +60,7 @@ def print_timestamped_line(msg): def print_fancy_output_line(msg, status, index, total, execution_time=None): - prefix = "{timestamp} {index} of {total} {message}".format( + prefix = "{timestamp} | {index} of {total} {message}".format( timestamp=get_timestamp(), index=index, total=total, @@ -91,7 +91,9 @@ def print_counts(flat_nodes): counts[t] = counts.get(t, 0) + 1 for k, v in counts.items(): + logger.info("") print_timestamped_line("Running {} {}s".format(v, k)) + print_timestamped_line("") def print_start_line(node, schema_name, index, total): @@ -143,7 +145,7 @@ def print_test_result_line(result, schema_name, index, total): "{info} {name}".format( info=info, name=model.get('name')), - result.status, + info, index, total, result.execution_time) @@ -196,7 +198,7 @@ def print_model_result_line(result, schema_name, index, total): -def print_results_line(results): +def print_results_line(results, execution_time): stats = {} for result in results: @@ -206,7 +208,10 @@ def print_results_line(results): stat_line = ", ".join( ["{} {}s".format(ct, t) for t, ct in stats.items()]) - print_timestamped_line("Finished running {}.".format(stat_line)) + print_timestamped_line("") + print_timestamped_line( + "Finished running {stat_line} in {execution_time:0.2f}s." 
+ .format(stat_line=stat_line, execution_time=execution_time)) def execute_model(profile, model, existing): @@ -540,15 +545,15 @@ def execute_nodes(self, node_dependency_list, on_failure, logger.info("Concurrency: {} threads (target='{}')".format( num_threads, self.project.get_target().get('name')) ) - logger.info("Running!") existing = adapter.query_for_existing(profile, schema_name) pool = ThreadPool(num_threads) - logger.info("") print_counts(flat_nodes) + start_time = time.time() + if should_run_hooks: run_hooks(self.project.get_target(), self.project.cfg.get('on-run-start', []), @@ -639,8 +644,9 @@ def on_complete(run_model_results): self.context, 'on-run-end hooks') - logger.info("") - print_results_line(node_results) + execution_time = time.time() - start_time + + print_results_line(node_results, execution_time) return node_results diff --git a/dbt/task/test.py b/dbt/task/test.py index 2771c6f6e8e..e96737b0a33 100644 --- a/dbt/task/test.py +++ b/dbt/task/test.py @@ -39,5 +39,4 @@ def run(self): else: raise RuntimeError("unexpected") - logger.info("Done!") return res diff --git a/test/integration/014_pre_post_run_hook_tests/models/hooks.sql b/test/integration/014_hook_tests/models/hooks.sql similarity index 100% rename from test/integration/014_pre_post_run_hook_tests/models/hooks.sql rename to test/integration/014_hook_tests/models/hooks.sql diff --git a/test/integration/014_hook_tests/seed.sql b/test/integration/014_hook_tests/seed.sql new file mode 100644 index 00000000000..b889daa446f --- /dev/null +++ b/test/integration/014_hook_tests/seed.sql @@ -0,0 +1,39 @@ + +drop table run_hooks_014.on_run_hook; + +create table run_hooks_014.on_run_hook ( + "state" TEXT, -- start|end + + "target.dbname" TEXT, + "target.host" TEXT, + "target.name" TEXT, + "target.schema" TEXT, + "target.type" TEXT, + "target.user" TEXT, + "target.pass" TEXT, + "target.port" INTEGER, + "target.threads" INTEGER, + + "run_started_at" TEXT, + "invocation_id" TEXT +); + + +drop table model_hooks_014.on_model_hook; + +create table model_hooks_014.on_model_hook ( + "state" TEXT, -- start|end + + "target.dbname" TEXT, + "target.host" TEXT, + "target.name" TEXT, + "target.schema" TEXT, + "target.type" TEXT, + "target.user" TEXT, + "target.pass" TEXT, + "target.port" INTEGER, + "target.threads" INTEGER, + + "run_started_at" TEXT, + "invocation_id" TEXT +); diff --git a/test/integration/014_hook_tests/seed_model.sql b/test/integration/014_hook_tests/seed_model.sql new file mode 100644 index 00000000000..0ac8eb49b62 --- /dev/null +++ b/test/integration/014_hook_tests/seed_model.sql @@ -0,0 +1,19 @@ + +drop table if exists model_hooks_014.on_model_hook; + +create table model_hooks_014.on_model_hook ( + "state" TEXT, -- start|end + + "target.dbname" TEXT, + "target.host" TEXT, + "target.name" TEXT, + "target.schema" TEXT, + "target.type" TEXT, + "target.user" TEXT, + "target.pass" TEXT, + "target.port" INTEGER, + "target.threads" INTEGER, + + "run_started_at" TEXT, + "invocation_id" TEXT +); diff --git a/test/integration/014_pre_post_run_hook_tests/seed.sql b/test/integration/014_hook_tests/seed_run.sql similarity index 80% rename from test/integration/014_pre_post_run_hook_tests/seed.sql rename to test/integration/014_hook_tests/seed_run.sql index 49c34acd0fc..918c993699d 100644 --- a/test/integration/014_pre_post_run_hook_tests/seed.sql +++ b/test/integration/014_hook_tests/seed_run.sql @@ -1,5 +1,7 @@ -create table pre_post_run_hooks_014.on_run_hook ( +drop table if exists run_hooks_014.on_run_hook; + +create table 
run_hooks_014.on_run_hook ( "state" TEXT, -- start|end "target.dbname" TEXT, diff --git a/test/integration/014_hook_tests/test_model_hooks.py b/test/integration/014_hook_tests/test_model_hooks.py new file mode 100644 index 00000000000..2f510e4771c --- /dev/null +++ b/test/integration/014_hook_tests/test_model_hooks.py @@ -0,0 +1,140 @@ +from nose.plugins.attrib import attr +from test.integration.base import DBTIntegrationTest + + +MODEL_PRE_HOOK = """ + insert into model_hooks_014.on_model_hook ( + "state", + "target.dbname", + "target.host", + "target.name", + "target.schema", + "target.type", + "target.user", + "target.pass", + "target.port", + "target.threads", + "run_started_at", + "invocation_id" + ) VALUES ( + 'start', + '{{ target.dbname }}', + '{{ target.host }}', + '{{ target.name }}', + '{{ target.schema }}', + '{{ target.type }}', + '{{ target.user }}', + '{{ target.pass }}', + {{ target.port }}, + {{ target.threads }}, + '{{ run_started_at }}', + '{{ invocation_id }}' + ) +""" + +MODEL_POST_HOOK = """ + insert into model_hooks_014.on_model_hook ( + "state", + "target.dbname", + "target.host", + "target.name", + "target.schema", + "target.type", + "target.user", + "target.pass", + "target.port", + "target.threads", + "run_started_at", + "invocation_id" + ) VALUES ( + 'end', + '{{ target.dbname }}', + '{{ target.host }}', + '{{ target.name }}', + '{{ target.schema }}', + '{{ target.type }}', + '{{ target.user }}', + '{{ target.pass }}', + {{ target.port }}, + {{ target.threads }}, + '{{ run_started_at }}', + '{{ invocation_id }}' + ) +""" + + +class TestPrePostModelHooks(DBTIntegrationTest): + + def setUp(self): + DBTIntegrationTest.setUp(self) + + self.run_sql_file("test/integration/014_hook_tests/seed_model.sql") + + self.fields = [ + 'state', + 'target.dbname', + 'target.host', + 'target.name', + 'target.port', + 'target.schema', + 'target.threads', + 'target.type', + 'target.user', + 'target.pass', + 'run_started_at', + 'invocation_id' + ] + + @property + def schema(self): + return "model_hooks_014" + + @property + def project_config(self): + return { + 'models': { + 'test': { + 'pre-hook': MODEL_PRE_HOOK, + 'post-hook': MODEL_POST_HOOK, + } + } + } + + @property + def models(self): + return "test/integration/014_hook_tests/models" + + def get_ctx_vars(self, state): + field_list = ", ".join(['"{}"'.format(f) for f in self.fields]) + query = "select {field_list} from {schema}.on_model_hook where state = '{state}'".format(field_list=field_list, schema=self.schema, state=state) + + vals = self.run_sql(query, fetch='all') + self.assertFalse(len(vals) == 0, 'nothing inserted into hooks table') + self.assertFalse(len(vals) > 1, 'too many rows in hooks table') + ctx = dict([(k,v) for (k,v) in zip(self.fields, vals[0])]) + + return ctx + + def check_hooks(self, state): + ctx = self.get_ctx_vars(state) + + self.assertEqual(ctx['state'], state) + self.assertEqual(ctx['target.dbname'], 'dbt') + self.assertEqual(ctx['target.host'], 'database') + self.assertEqual(ctx['target.name'], 'default2') + self.assertEqual(ctx['target.port'], 5432) + self.assertEqual(ctx['target.schema'], self.schema) + self.assertEqual(ctx['target.threads'], 1) + self.assertEqual(ctx['target.type'], 'postgres') + self.assertEqual(ctx['target.user'], 'root') + self.assertEqual(ctx['target.pass'], '') + + self.assertTrue(ctx['run_started_at'] is not None and len(ctx['run_started_at']) > 0, 'run_started_at was not set') + self.assertTrue(ctx['invocation_id'] is not None and len(ctx['invocation_id']) > 0, 
'invocation_id was not set') + + @attr(type='postgres') + def test_pre_and_post_model_hooks(self): + self.run_dbt(['run']) + + self.check_hooks('start') + self.check_hooks('end') diff --git a/test/integration/014_pre_post_run_hook_tests/test_pre_post_run_hooks.py b/test/integration/014_hook_tests/test_run_hooks.py similarity index 90% rename from test/integration/014_pre_post_run_hook_tests/test_pre_post_run_hooks.py rename to test/integration/014_hook_tests/test_run_hooks.py index 226dc9ef3df..451ab129df0 100644 --- a/test/integration/014_pre_post_run_hook_tests/test_pre_post_run_hooks.py +++ b/test/integration/014_hook_tests/test_run_hooks.py @@ -3,7 +3,7 @@ RUN_START_HOOK = """ - insert into pre_post_run_hooks_014.on_run_hook ( + insert into run_hooks_014.on_run_hook ( "state", "target.dbname", "target.host", @@ -33,7 +33,7 @@ """ RUN_END_HOOK = """ - insert into pre_post_run_hooks_014.on_run_hook ( + insert into run_hooks_014.on_run_hook ( "state", "target.dbname", "target.host", @@ -67,7 +67,7 @@ class TestPrePostRunHooks(DBTIntegrationTest): def setUp(self): DBTIntegrationTest.setUp(self) - self.run_sql_file("test/integration/014_pre_post_run_hook_tests/seed.sql") + self.run_sql_file("test/integration/014_hook_tests/seed_run.sql") self.fields = [ 'state', @@ -86,7 +86,7 @@ def setUp(self): @property def schema(self): - return "pre_post_run_hooks_014" + return "run_hooks_014" @property def project_config(self): @@ -97,7 +97,7 @@ def project_config(self): @property def models(self): - return "test/integration/014_pre_post_run_hook_tests/models" + return "test/integration/014_hook_tests/models" def get_ctx_vars(self, state): field_list = ", ".join(['"{}"'.format(f) for f in self.fields]) @@ -105,6 +105,7 @@ def get_ctx_vars(self, state): vals = self.run_sql(query, fetch='all') self.assertFalse(len(vals) == 0, 'nothing inserted into on_run_hook table') + self.assertFalse(len(vals) > 1, 'too many rows in hooks table') ctx = dict([(k,v) for (k,v) in zip(self.fields, vals[0])]) return ctx @@ -117,7 +118,7 @@ def check_hooks(self, state): self.assertEqual(ctx['target.host'], 'database') self.assertEqual(ctx['target.name'], 'default2') self.assertEqual(ctx['target.port'], 5432) - self.assertEqual(ctx['target.schema'], 'pre_post_run_hooks_014') + self.assertEqual(ctx['target.schema'], self.schema) self.assertEqual(ctx['target.threads'], 1) self.assertEqual(ctx['target.type'], 'postgres') self.assertEqual(ctx['target.user'], 'root') From 102e8dfdb8e6ce5b040ecfcbdb1f9ea101170d5a Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 16:56:52 -0500 Subject: [PATCH 20/25] add compat module to deal with str/unicode/basestring diffs in 2 vs 3 --- dbt/compat.py | 5 +++++ dbt/contracts/connection.py | 25 +++++++++++++------------ dbt/contracts/graph/compiled.py | 9 +++++---- dbt/contracts/graph/parsed.py | 17 +++++++++-------- dbt/contracts/graph/unparsed.py | 11 ++++++----- dbt/model.py | 7 +------ dbt/utils.py | 8 ++------ 7 files changed, 41 insertions(+), 41 deletions(-) create mode 100644 dbt/compat.py diff --git a/dbt/compat.py b/dbt/compat.py new file mode 100644 index 00000000000..25a7ea45332 --- /dev/null +++ b/dbt/compat.py @@ -0,0 +1,5 @@ +# python 2+3 check for stringiness +try: + basestring = basestring +except NameError: + basestring = str diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 1fd691c0f4d..6cd2a5ad0b0 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -1,5 +1,6 @@ from voluptuous import Schema, Required, 
All, Any, Extra, Range, Optional +from dbt.compat import basestring from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as logger @@ -12,22 +13,22 @@ }) postgres_credentials_contract = Schema({ - Required('dbname'): str, - Required('host'): str, - Required('user'): str, - Required('pass'): str, + Required('dbname'): basestring, + Required('host'): basestring, + Required('user'): basestring, + Required('pass'): basestring, Required('port'): All(int, Range(min=0, max=65535)), - Required('schema'): str, + Required('schema'): basestring, }) snowflake_credentials_contract = Schema({ - Required('account'): str, - Required('user'): str, - Required('password'): str, - Required('database'): str, - Required('schema'): str, - Required('warehouse'): str, - Optional('role'): str, + Required('account'): basestring, + Required('user'): basestring, + Required('password'): basestring, + Required('database'): basestring, + Required('schema'): basestring, + Required('warehouse'): basestring, + Optional('role'): basestring, }) credentials_mapping = { diff --git a/dbt/contracts/graph/compiled.py b/dbt/contracts/graph/compiled.py index 2c6670c8d34..5277e7889f7 100644 --- a/dbt/contracts/graph/compiled.py +++ b/dbt/contracts/graph/compiled.py @@ -1,6 +1,7 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ Length +from dbt.compat import basestring from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger @@ -11,13 +12,13 @@ compiled_graph_item_contract = parsed_graph_item_contract.extend({ # compiled fields Required('compiled'): bool, - Required('compiled_sql'): Any(str, None), + Required('compiled_sql'): Any(basestring, None), # injected fields Required('extra_ctes_injected'): bool, - Required('extra_cte_ids'): All(list, [str]), - Required('extra_cte_sql'): All(list, [str]), - Required('injected_sql'): Any(str, None), + Required('extra_cte_ids'): All(list, [basestring]), + Required('extra_cte_sql'): All(list, [basestring]), + Required('injected_sql'): Any(basestring, None), }) diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py index 2064c93d890..6951a0ad613 100644 --- a/dbt/contracts/graph/parsed.py +++ b/dbt/contracts/graph/parsed.py @@ -1,6 +1,7 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ Length +from dbt.compat import basestring from dbt.exceptions import ValidationException from dbt.logger import GLOBAL_LOGGER as logger @@ -16,25 +17,25 @@ Required('vars'): dict, # incremental optional fields - Optional('sql_where'): str, - Optional('unique_key'): str, + Optional('sql_where'): basestring, + Optional('unique_key'): basestring, # adapter optional fields - Optional('sort'): str, - Optional('dist'): str, + Optional('sort'): basestring, + Optional('dist'): basestring, } parsed_graph_item_contract = unparsed_graph_item_contract.extend({ # identifiers - Required('unique_id'): All(str, Length(min=1, max=255)), - Required('fqn'): All(list, [All(str)]), + Required('unique_id'): All(basestring, Length(min=1, max=255)), + Required('fqn'): All(list, [All(basestring)]), # parsed fields - Required('depends_on'): All(list, [All(str, Length(min=1, max=255))]), + Required('depends_on'): All(list, [All(basestring, Length(min=1, max=255))]), Required('empty'): bool, Required('config'): config_contract, - Required('tags'): All(list, [str]), + Required('tags'): All(list, [basestring]), }) diff --git a/dbt/contracts/graph/unparsed.py b/dbt/contracts/graph/unparsed.py index 
4b89d8aedd9..e0defcfce00 100644 --- a/dbt/contracts/graph/unparsed.py +++ b/dbt/contracts/graph/unparsed.py @@ -1,6 +1,7 @@ from voluptuous import Schema, Required, All, Any, Extra, Range, Optional, \ Length +from dbt.compat import basestring from dbt.contracts.common import validate_with from dbt.logger import GLOBAL_LOGGER as logger @@ -8,14 +9,14 @@ unparsed_graph_item_contract = Schema({ # identifiers - Required('name'): All(str, Length(min=1, max=63)), - Required('package_name'): str, + Required('name'): All(basestring, Length(min=1, max=63)), + Required('package_name'): basestring, Required('resource_type'): Any(NodeType.Model, NodeType.Test), # filesystem - Required('root_path'): str, - Required('path'): str, - Required('raw_sql'): str, + Required('root_path'): basestring, + Required('path'): basestring, + Required('raw_sql'): basestring, }) diff --git a/dbt/model.py b/dbt/model.py index 6eb986d280d..28328ad44ff 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -3,6 +3,7 @@ import jinja2 import re +from dbt.compat import basestring from dbt.templates import BaseCreateTemplate, ArchiveInsertTemplate from dbt.utils import split_path import dbt.project @@ -397,12 +398,6 @@ def build_path(self): return os.path.join(*path_parts) def compile_string(self, ctx, string): - # python 2+3 check for stringiness - try: - basestring - except NameError: - basestring = str - # if bool/int/float/etc are passed in, don't compile anything if not isinstance(string, basestring): return string diff --git a/dbt/utils.py b/dbt/utils.py index 2b751b9aacf..801f613eecb 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -3,6 +3,8 @@ import json import dbt.project + +from dbt.compat import basestring from dbt.logger import GLOBAL_LOGGER as logger DBTConfigKeys = [ @@ -95,12 +97,6 @@ def __call__(self, var_name, default=None): ) ) - # python 2+3 check for stringiness - try: - basestring - except NameError: - basestring = str - # if bool/int/float/etc are passed in, don't compile anything if not isinstance(raw, basestring): return raw From f24c0b196db0ee18bd7719603ce162d4a13c2c9e Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Mon, 27 Feb 2017 17:07:42 -0500 Subject: [PATCH 21/25] switch compilation import --- dbt/task/archive.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbt/task/archive.py b/dbt/task/archive.py index 752daf404ae..8077adb2793 100644 --- a/dbt/task/archive.py +++ b/dbt/task/archive.py @@ -1,5 +1,6 @@ +import dbt.compilation + from dbt.runner import RunManager -from dbt.compilation import Compiler from dbt.logger import GLOBAL_LOGGER as logger From 99b9ea175d83d948f0916708d43e61d54b1d9585 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Tue, 28 Feb 2017 11:18:53 -0500 Subject: [PATCH 22/25] fun with string compatibility --- dbt/compat.py | 44 +++++++++++++++++++++++++++++++++-- dbt/compilation.py | 8 ++++--- dbt/contracts/graph/parsed.py | 3 ++- dbt/runner.py | 2 -- dbt/task/compile.py | 1 + 5 files changed, 50 insertions(+), 8 deletions(-) diff --git a/dbt/compat.py b/dbt/compat.py index 25a7ea45332..26cba2577d3 100644 --- a/dbt/compat.py +++ b/dbt/compat.py @@ -1,5 +1,45 @@ -# python 2+3 check for stringiness +import codecs + +WHICH_PYTHON = None + try: - basestring = basestring + basestring + WHICH_PYTHON = 2 except NameError: + WHICH_PYTHON = 3 + +if WHICH_PYTHON == 2: + basestring = basestring +else: basestring = str + + +def to_unicode(s): + if WHICH_PYTHON == 2: + return unicode(s) + else: + return str(s) + + +def to_string(s): + if WHICH_PYTHON == 2: + if 
isinstance(s, unicode): + return s + elif isinstance(s, basestring): + return to_unicode(s) + else: + return to_unicode(str(s)) + else: + if isinstance(s, basestring): + return s + else: + return str(s) + + +def write_file(path, s): + if WHICH_PYTHON == 2: + with codecs.open(path, 'w', encoding='utf-8') as f: + return f.write(to_string(s)) + else: + with open(path, 'w') as f: + return f.write(to_string(s)) diff --git a/dbt/compilation.py b/dbt/compilation.py index 9da13eb908e..6377558d74d 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -1,3 +1,4 @@ +import codecs import os import fnmatch import jinja2 @@ -11,11 +12,12 @@ from dbt.model import Model, NodeType from dbt.source import Source from dbt.utils import find_model_by_fqn, find_model_by_name, \ - split_path, This, Var, compiler_error, to_string + split_path, This, Var, compiler_error from dbt.linker import Linker from dbt.runtime import RuntimeContext +import dbt.compat import dbt.contracts.graph.compiled import dbt.contracts.graph.parsed import dbt.contracts.project @@ -53,7 +55,7 @@ def compile_and_print_status(project, args): def compile_string(string, ctx): try: env = jinja2.Environment() - template = env.from_string(str(string), globals=ctx) + template = env.from_string(dbt.compat.to_string(string), globals=ctx) return template.render(ctx) except jinja2.exceptions.TemplateSyntaxError as e: compiler_error(None, str(e)) @@ -147,7 +149,7 @@ def inject_ctes_into_sql(sql, ctes): with_stmt, sqlparse.sql.Token(sqlparse.tokens.Keyword, ", ".join(ctes))) - return str(parsed) + return dbt.compat.to_string(parsed) class Compiler(object): diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py index 6951a0ad613..320a18e47cd 100644 --- a/dbt/contracts/graph/parsed.py +++ b/dbt/contracts/graph/parsed.py @@ -32,7 +32,8 @@ Required('fqn'): All(list, [All(basestring)]), # parsed fields - Required('depends_on'): All(list, [All(basestring, Length(min=1, max=255))]), + Required('depends_on'): All(list, + [All(basestring, Length(min=1, max=255))]), Required('empty'): bool, Required('config'): config_contract, Required('tags'): All(list, [basestring]), diff --git a/dbt/runner.py b/dbt/runner.py index 6db3ec57ec5..b86ed6cae66 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -1,4 +1,3 @@ - from __future__ import print_function import jinja2 @@ -197,7 +196,6 @@ def print_model_result_line(result, schema_name, index, total): result.execution_time) - def print_results_line(results, execution_time): stats = {} diff --git a/dbt/task/compile.py b/dbt/task/compile.py index 0868c54a9b3..bfa207a7384 100644 --- a/dbt/task/compile.py +++ b/dbt/task/compile.py @@ -2,6 +2,7 @@ from dbt.logger import GLOBAL_LOGGER as logger + class CompileTask: def __init__(self, args, project): self.args = args From d96913d6ada182db35f87cee90a412bcb4f59bf3 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Tue, 28 Feb 2017 11:27:19 -0500 Subject: [PATCH 23/25] write_file is necessary --- dbt/compilation.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index 6377558d74d..1870c0024b9 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -1,4 +1,3 @@ -import codecs import os import fnmatch import jinja2 @@ -177,8 +176,7 @@ def __write(self, build_filepath, payload): if not os.path.exists(os.path.dirname(target_path)): os.makedirs(os.path.dirname(target_path)) - with open(target_path, 'w') as f: - f.write(to_string(payload)) + dbt.compat.write_file(target_path, payload) def 
__model_config(self, model, linker): def do_config(*args, **kwargs): From caf61051d41a8467f0d2fab5948d22dd43766be0 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Wed, 1 Mar 2017 22:07:20 -0500 Subject: [PATCH 24/25] re-add analyses --- dbt/compilation.py | 39 ++++++++++++++++++++++++++++----- dbt/contracts/graph/unparsed.py | 4 +++- dbt/model.py | 1 + 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index 1870c0024b9..c2e0eaa103e 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -37,16 +37,24 @@ def compile_and_print_status(project, args): compiler = Compiler(project, args) compiler.initialize() + names = { + NodeType.Model: 'models', + NodeType.Test: 'tests', + NodeType.Archive: 'archives', + NodeType.Analysis: 'analyses', + } + results = { NodeType.Model: 0, NodeType.Test: 0, NodeType.Archive: 0, + NodeType.Analysis: 0, } results.update(compiler.compile()) stat_line = ", ".join( - ["{} {}s".format(ct, t) for t, ct in results.items()]) + ["{} {}".format(ct, names.get(t)) for t, ct in results.items()]) logger.info("Compiled {}".format(stat_line)) @@ -398,8 +406,9 @@ def compile_nodes(self, linker, nodes, macro_generator): for name, injected_node in injected_nodes.items(): # now turn model nodes back into the old-style model object for # wrapping - if injected_node.get('resource_type') == NodeType.Test: - # don't wrap tests. + if injected_node.get('resource_type') in [NodeType.Test, + NodeType.Analysis]: + # don't wrap tests or analyses. injected_node['wrapped_sql'] = injected_node['injected_sql'] wrapped_nodes[name] = injected_node @@ -408,6 +417,7 @@ def compile_nodes(self, linker, nodes, macro_generator): # archives. in the future it'd be nice to generate # the SQL at the parser level. 
pass + else: model = Model( self.project, @@ -428,10 +438,11 @@ def compile_nodes(self, linker, nodes, macro_generator): build_path = os.path.join('build', injected_node.get('path')) - if injected_node.get('resource_type') == NodeType.Model and \ + if injected_node.get('resource_type') in (NodeType.Model, + NodeType.Analysis) and \ injected_node.get('config', {}) \ .get('materialized') != 'ephemeral': - self.__write(build_path, wrapped_stmt) + self.__write(build_path, injected_node.get('wrapped_sql')) written_nodes.append(injected_node) injected_node['build_path'] = build_path @@ -494,6 +505,22 @@ def get_parsed_models(self, root_project, all_projects, macro_generator): return parsed_models + def get_parsed_analyses(self, root_project, all_projects, macro_generator): + parsed_models = {} + + for name, project in all_projects.items(): + parsed_models.update( + dbt.parser.load_and_parse_sql( + package_name=name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.get('project-root'), + relative_dirs=project.get('analysis-paths', []), + resource_type=NodeType.Analysis, + macro_generator=macro_generator)) + + return parsed_models + def get_parsed_data_tests(self, root_project, all_projects, macro_generator): parsed_tests = {} @@ -531,6 +558,8 @@ def load_all_nodes(self, root_project, all_projects, macro_generator): all_nodes.update(self.get_parsed_models(root_project, all_projects, macro_generator)) + all_nodes.update(self.get_parsed_analyses(root_project, all_projects, + macro_generator)) all_nodes.update( self.get_parsed_data_tests(root_project, all_projects, macro_generator)) diff --git a/dbt/contracts/graph/unparsed.py b/dbt/contracts/graph/unparsed.py index e0defcfce00..b1b91c166d3 100644 --- a/dbt/contracts/graph/unparsed.py +++ b/dbt/contracts/graph/unparsed.py @@ -11,7 +11,9 @@ # identifiers Required('name'): All(basestring, Length(min=1, max=63)), Required('package_name'): basestring, - Required('resource_type'): Any(NodeType.Model, NodeType.Test), + Required('resource_type'): Any(NodeType.Model, + NodeType.Test, + NodeType.Analysis), # filesystem Required('root_path'): basestring, diff --git a/dbt/model.py b/dbt/model.py index 28328ad44ff..7da379afb0c 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -16,6 +16,7 @@ class NodeType(object): Base = 'base' Model = 'model' + Analysis = 'analysis' Test = 'test' Archive = 'archive' From d3142cbeeba4b0885674f94cec01bc04f90952e0 Mon Sep 17 00:00:00 2001 From: Connor McArthur Date: Thu, 2 Mar 2017 16:44:06 -0500 Subject: [PATCH 25/25] update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d061621aaef..bf6e6a3912e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## dbt 0.7.2 (unreleased) + +### Changes + +- Graph refactor: fix common issues with load order ([#292](https://github.com/fishtown-analytics/dbt/pull/292)) + ## dbt 0.7.1 (February 28, 2017) ### Overview
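
For reference, patch 24 loads analyses through the same parser entry
point as models, differing only in the source directories and the
resource type. A minimal sketch of the call, mirroring
Compiler.get_parsed_analyses() above (the project dict and the
macro_generator=None below are illustrative placeholders, not values
taken from this series):

    import dbt.parser
    from dbt.model import NodeType

    # hypothetical single-project setup, for illustration only
    root_project = {'name': 'root',
                    'project-root': '/usr/src/app',
                    'analysis-paths': ['analysis']}

    parsed_analyses = dbt.parser.load_and_parse_sql(
        package_name='root',
        root_project=root_project,
        all_projects={'root': root_project},
        root_dir=root_project.get('project-root'),
        relative_dirs=root_project.get('analysis-paths', []),
        resource_type=NodeType.Analysis,
        macro_generator=None)  # assumed optional, per parse_sql_nodes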