def __src_index(self):
    """Index the project's SQL model files by source path.

    Walks every directory listed under the project's ``source-paths``
    setting and collects the files whose name matches ``*.sql``.

    Returns:
        A plain ``dict`` mapping each source path that contains at least
        one SQL file to a list of file paths relative to that source
        path, e.g.
        ``{'model': ['pardot/model.sql', 'segment/model.sql']}``.
        Source paths without any SQL file are absent from the result.
    """
    indexed_files = defaultdict(list)

    for source_path in self.project['source-paths']:
        for root, _dirs, files in os.walk(source_path):
            for filename in files:
                # Filter first so no path arithmetic is done for files
                # that are going to be ignored anyway.
                if not fnmatch.fnmatch(filename, "*.sql"):
                    continue

                abs_path = os.path.join(root, filename)
                rel_path = os.path.relpath(abs_path, source_path)
                indexed_files[source_path].append(rel_path)

    # Hand back a plain dict: a defaultdict would silently create an empty
    # entry whenever a caller looks up a source path that has no models.
    return dict(indexed_files)
Model config + takes precedence, then model_group, then base config""" + + config = self.project['model-defaults'].copy() + + model_configs = self.project['models'] + model_group_config = model_configs.get(model_group, {}) + model_config = model_group_config.get(model_name, {}) + + config.update(model_group_config) + config.update(model_config) + + return config + def __compile(self, src_index): - for src_path, files in src_index.iteritems(): + for src_path, files in src_index.items(): jinja = jinja2.Environment(loader=jinja2.FileSystemLoader(searchpath=src_path)) for f in files: + + model_group, model_name = self.__get_model_identifiers(f) + model_config = self.__get_model_config(model_group, model_name) + + if not model_config.get('enabled'): + continue + template = jinja.get_template(f) - self.__write(f, template.render(self.project.context())) + rendered = template.render(self.project.context()) + + create_stmt = self.__wrap_in_create(f, rendered, model_config) + + if create_stmt: + self.__write(f, create_stmt) def run(self): src_index = self.__src_index() diff --git a/dbt/task/debug.py b/dbt/task/debug.py index b3e876d61ad..45f5b56117f 100644 --- a/dbt/task/debug.py +++ b/dbt/task/debug.py @@ -7,6 +7,6 @@ def __init__(self, args, project): self.project = project def run(self): - print "args: {}".format(self.args) - print "project: " + print("args: {}".format(self.args)) + print("project: ") pprint.pprint(self.project) diff --git a/dbt/task/run.py b/dbt/task/run.py index acd79f498b4..ec8737c19c5 100644 --- a/dbt/task/run.py +++ b/dbt/task/run.py @@ -2,7 +2,10 @@ import psycopg2 import os import fnmatch +import re +import sqlparse +import networkx as nx class RedshiftTarget: def __init__(self, cfg): @@ -27,11 +30,105 @@ def get_handle(self): return psycopg2.connect(self.__get_spec()) +class Relation(object): + def __init__(self, schema, name): + self.schema = schema + self.name = name + + def valid(self): + return None not in (self.schema, self.name) + + 
@property + def val(self): + return "{}.{}".format(self.schema, self.name) + + def __repr__(self): + return self.val + + def __str__(self): + return self.val + +class Linker(object): + def __init__(self, graph=None): + if graph is None: + self.graph = nx.DiGraph() + else: + self.graph = graph + + self.node_sql_map = {} + + def extract_name_and_deps(self, stmt): + table_def = stmt.token_next_by_instance(0, sqlparse.sql.Identifier) + schema, tbl_or_view = table_def.get_parent_name(), table_def.get_real_name() + if schema is None or tbl_or_view is None: + raise RuntimeError('schema or view not defined?') + + definition = table_def.token_next_by_instance(0, sqlparse.sql.Parenthesis) + + definition_node = Relation(schema, tbl_or_view) + + local_defs = set() + new_nodes = set() + + def extract_deps(stmt): + token = stmt.token_first() + while token is not None: + excluded_types = [sqlparse.sql.Function] # don't dive into window functions + if type(token) not in excluded_types and token.is_group(): + # this is a thing that has a name -- note that! + local_defs.add(token.get_name()) + # recurse into the group + extract_deps(token) + + if type(token) == sqlparse.sql.Identifier: + new_node = Relation(token.get_parent_name(), token.get_real_name()) + + if new_node.valid(): + new_nodes.add(new_node) # don't add edges yet! 
+ + index = stmt.token_index(token) + token = stmt.token_next(index) + + extract_deps(definition) + + # only add nodes which don't reference locally defined constructs + for new_node in new_nodes: + if new_node.schema not in local_defs: + self.graph.add_node(new_node.val) + self.graph.add_edge(definition_node.val, new_node.val) + + return definition_node.val + + def as_dependency_list(self): + order = nx.topological_sort(self.graph, reverse=True) + for node in order: + if node in self.node_sql_map: # TODO : + yield (node, self.node_sql_map[node]) + else: + pass + + def register(self, node, sql): + if node in self.node_sql_map: + raise RuntimeError("multiple declarations of node: {}".format(node)) + self.node_sql_map[node] = sql + + def link(self, sql): + sql = sql.strip() + for statement in sqlparse.parse(sql): + if statement.get_type().startswith('CREATE'): + node = self.extract_name_and_deps(statement) + self.register(node, sql) + else: + print("Ignoring {}".format(sql[0:100].replace('\n', ' '))) + + class RunTask: def __init__(self, args, project): self.args = args self.project = project + self.linker = Linker() + def __compiled_files(self): compiled_files = [] sql_path = self.project['target-path'] @@ -59,15 +156,49 @@ def __create_schema(self): with handle.cursor() as cursor: cursor.execute('create schema if not exists "{}"'.format(target_cfg['schema'])) + def __load_models(self): + target = self.__get_target() + for f in self.__compiled_files(): + with open(os.path.join(self.project['target-path'], f), 'r') as fh: + self.linker.link(fh.read()) + + def __query_for_existing(self, cursor, schema): + sql = """ + select '{schema}.' || tablename as name, 'table' as type from pg_tables where schemaname = '{schema}' + union all + select '{schema}.' 
|| viewname as name, 'view' as type from pg_views where schemaname = '{schema}' """.format(schema=schema) + + cursor.execute(sql) + existing = [(name, relation_type) for (name, relation_type) in cursor.fetchall()] + + return dict(existing) + + def __drop(self, cursor, relation, relation_type): + sql = "drop {relation_type} if exists {relation} cascade".format(relation_type=relation_type, relation=relation) + cursor.execute(sql) + def __execute_models(self): target = self.__get_target() + with target.get_handle() as handle: with handle.cursor() as cursor: - for f in self.__compiled_files(): - with open(os.path.join(self.project['target-path'], f), 'r') as fh: - cursor.execute(fh.read()) - print " {}".format(cursor.statusmessage) + + existing = self.__query_for_existing(cursor, target.schema); + + for (relation, sql) in self.linker.as_dependency_list(): + + if relation in existing: + self.__drop(cursor, relation, existing[relation]) + handle.commit() + + print("creating {}".format(relation)) + #print(" {}...".format(re.sub( '\s+', ' ', sql[0:100] ).strip())) + cursor.execute(sql) + print(" {}".format(cursor.statusmessage)) + handle.commit() def run(self): self.__create_schema() + self.__load_models() self.__execute_models() + diff --git a/requirements.txt b/requirements.txt index bb604da014c..210bea51108 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,5 @@ argparse Jinja2>=2.8 PyYAML>=3.11 psycopg2==2.6.1 +sqlparse==0.1.19 +networkx==1.11 diff --git a/sample.dbt_project.yml b/sample.dbt_project.yml index 78219653860..eed170ec656 100644 --- a/sample.dbt_project.yml +++ b/sample.dbt_project.yml @@ -6,6 +6,17 @@ source-paths: ["model"] # paths with source code to compile target-path: "target" # path for compiled code clean-targets: ["target"] # directories removed by the clean task +model-defaults: + enabled: true # enable all models by default + materialized: false # If true, create tables. 
If false, create views + +models: + pardot: + enabled: false # disable all pardot models except where overridden + pardot_visitoractivity: # override configs for a particular model + enabled: true # enable this model + materialized: true # create a table instead of a view (overriding the base config) + # Run configuration # output environments outputs: diff --git a/setup.py b/setup.py index 1d0d30a00ab..a15f4c0afeb 100644 --- a/setup.py +++ b/setup.py @@ -13,9 +13,10 @@ 'scripts/dbt', ], install_requires=[ - 'argparse>=1.2.1', 'Jinja2>=2.8', 'PyYAML>=3.11', 'psycopg2==2.6.1', + 'sqlparse==0.1.19', + 'networkx==1.11', ], )