Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DAGnabit #9

Merged
merged 26 commits into from
Apr 3, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions dbt/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@
import copy

# Default values for the project configuration keys.
# NOTE(review): 'source-paths' is listed twice below (this view shows both the
# old and the new line); in a dict literal the later entry wins.
default_project_cfg = {
# directories walked for *.sql model sources by the compile task
'source-paths': ['model'],
'source-paths': ['models'],
'test-paths': ['test'],
# directory the run task reads compiled SQL files from
'target-path': 'target',
'clean-targets': ['target'],
# named output configurations; 'run-target' names the entry that
# run_environment() returns
'outputs': {'default': {}},
'run-target': 'default',
# per-model-group (and per-model) overrides layered on 'model-defaults'
'models': {},
'model-defaults': {
# models whose merged 'enabled' setting is falsy are skipped when compiling
"enabled": True,
# truthy -> compiled as "create table", otherwise "create view"
"materialized": False
}
}

default_profiles = {
Expand All @@ -18,7 +23,6 @@

default_active_profiles = ['user']


class Project:

def __init__(self, cfg, profiles, active_profile_names=[]):
Expand Down Expand Up @@ -46,6 +50,9 @@ def __contains__(self, key):
def __setitem__(self, key, value):
return self.cfg.__setitem__(key, value)

def get(self, key, default=None):
    """Look up ``key`` in the wrapped ``cfg`` mapping.

    Returns ``default`` (``None`` unless given) when the key is absent.
    """
    lookup = self.cfg.get
    return lookup(key, default)

def run_environment(self):
    """Return the ``outputs`` entry named by the ``run-target`` setting.

    Raises ``KeyError`` if either setting, or the named output, is missing.
    """
    selected = self.cfg['run-target']
    outputs = self.cfg['outputs']
    return outputs[selected]
Expand Down
87 changes: 63 additions & 24 deletions dbt/task/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,27 @@
import os
import fnmatch
import jinja2

import yaml
from collections import defaultdict

class CompileTask:
def __init__(self, args, project):
self.args = args
self.project = project

def __is_specified_model(self, path):
if 'models' not in self.project:
return True

path_parts = path.split("/")
if len(path_parts) < 2:
return False
else:
model = path_parts[1]
for allowed_model in self.project['models']:
if fnmatch.fnmatch(model, allowed_model):
return True
return False

def __src_index(self):
"""Index the *.sql files found under each configured source path.

returns: {'model': ['pardot/model.sql', 'segment/model.sql']}
i.e. a mapping of source path -> list of file paths relative to that path.
"""
indexed_files = {}
indexed_files = defaultdict(list)

for source_path in self.project['source-paths']:
# os.walk visits source_path and every directory below it
for root, dirs, files in os.walk(source_path):
if not self.__is_specified_model(root):
continue
for filename in files:
abs_path = os.path.join(root, filename)
rel_path = os.path.relpath(abs_path, source_path)

# only files matching *.sql are indexed
if fnmatch.fnmatch(filename, "*.sql"):
abs_path = os.path.join(root, filename)
rel_path = os.path.relpath(abs_path, source_path)
indexed_files.setdefault(source_path, []).append(rel_path)
indexed_files[source_path].append(rel_path)

return indexed_files

Expand All @@ -46,17 +32,70 @@ def __write(self, path, payload):
if not os.path.exists(os.path.dirname(target_path)):
os.makedirs(os.path.dirname(target_path))
elif os.path.exists(target_path):
print "Compiler overwrite of {}".format(target_path)
print("Compiler overwrite of {}".format(target_path))

with open(target_path, 'w') as f:
f.write(payload)

def __wrap_in_create(self, path, query, model_config):
filename = os.path.basename(path)
identifier, ext = os.path.splitext(filename)

# default to view if not provided in config!
table_or_view = 'table' if model_config['materialized'] else 'view'

ctx = self.project.context()
schema = ctx['env']['schema']

create_template = "create {table_or_view} {schema}.{identifier} as ( {query} );"

opts = {
"table_or_view": table_or_view,
"schema": schema,
"identifier": identifier,
"query": query
}

return create_template.format(**opts)

def __get_model_identifiers(self, model_filepath):
model_group = os.path.dirname(model_filepath)
model_name, _ = os.path.splitext(os.path.basename(model_filepath))
return model_group, model_name

def __get_model_config(self, model_group, model_name):
"""merges model, model group, and base configs together. Model config
takes precedence, then model_group, then base config"""

config = self.project['model-defaults'].copy()

model_configs = self.project['models']
model_group_config = model_configs.get(model_group, {})
model_config = model_group_config.get(model_name, {})

config.update(model_group_config)
config.update(model_config)

return config

def __compile(self, src_index):
"""Render each enabled model template and write it out as a create statement.

src_index: mapping of source path -> list of template paths relative to it.
"""
for src_path, files in src_index.iteritems():
for src_path, files in src_index.items():
# templates are looked up relative to the source path
jinja = jinja2.Environment(loader=jinja2.FileSystemLoader(searchpath=src_path))
for f in files:

model_group, model_name = self.__get_model_identifiers(f)
model_config = self.__get_model_config(model_group, model_name)

# models whose merged config has a falsy 'enabled' are not compiled
if not model_config.get('enabled'):
continue

template = jinja.get_template(f)
self.__write(f, template.render(self.project.context()))
rendered = template.render(self.project.context())

# turn the rendered query into a "create table/view ... as" statement
create_stmt = self.__wrap_in_create(f, rendered, model_config)

if create_stmt:
self.__write(f, create_stmt)

def run(self):
src_index = self.__src_index()
Expand Down
4 changes: 2 additions & 2 deletions dbt/task/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ def __init__(self, args, project):
self.project = project

def run(self):
"""Print the task's args and pretty-print its project."""
print "args: {}".format(self.args)
print "project: "
print("args: {}".format(self.args))
print("project: ")
pprint.pprint(self.project)
139 changes: 135 additions & 4 deletions dbt/task/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import psycopg2
import os
import fnmatch
import re

import sqlparse
import networkx as nx

class RedshiftTarget:
def __init__(self, cfg):
Expand All @@ -27,11 +30,105 @@ def get_handle(self):
return psycopg2.connect(self.__get_spec())


class Relation(object):
    """A schema-qualified relation name, rendered as ``schema.name``.

    Instances compare and hash by their ``(schema, name)`` pair, so equal
    relations collapse to one entry when collected in a set or used as
    dict keys. Do not reassign ``schema`` or ``name`` while an instance is
    stored in a set or dict.
    """

    def __init__(self, schema, name):
        self.schema = schema
        self.name = name

    def valid(self):
        """Return True when both the schema and the name are known."""
        return self.schema is not None and self.name is not None

    @property
    def val(self):
        """The ``schema.name`` string form of this relation."""
        return "{}.{}".format(self.schema, self.name)

    def __eq__(self, other):
        if not isinstance(other, Relation):
            return NotImplemented
        return (self.schema, self.name) == (other.schema, other.name)

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so define it explicitly.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        # Must agree with __eq__: equal relations hash alike.
        return hash((self.schema, self.name))

    def __repr__(self):
        return self.val

    def __str__(self):
        return self.val

class Linker(object):
def __init__(self, graph=None):
if graph is None:
self.graph = nx.DiGraph()
else:
self.graph = graph

self.node_sql_map = {}

def extract_name_and_deps(self, stmt):
"""Add a parsed statement's relation and its dependencies to the graph.

stmt: a parsed sqlparse statement. Its first Identifier token is taken
as the relation being defined, and the first Parenthesis group inside
that identifier as the defining query.

Returns the defined relation as a "schema.name" string.
Raises RuntimeError when the identifier has no schema or no name.
"""
table_def = stmt.token_next_by_instance(0, sqlparse.sql.Identifier)
schema, tbl_or_view = table_def.get_parent_name(), table_def.get_real_name()
if schema is None or tbl_or_view is None:
raise RuntimeError('schema or view not defined?')

definition = table_def.token_next_by_instance(0, sqlparse.sql.Parenthesis)

definition_node = Relation(schema, tbl_or_view)

# names of grouped constructs seen inside the definition
local_defs = set()
# candidate dependencies: identifiers that have both a parent name and a real name
new_nodes = set()

def extract_deps(stmt):
# walk the tokens of stmt in order, descending into nested groups
token = stmt.token_first()
while token is not None:
excluded_types = [sqlparse.sql.Function] # don't dive into window functions
if type(token) not in excluded_types and token.is_group():
# this is a thing that has a name -- note that!
local_defs.add(token.get_name())
# recurse into the group
extract_deps(token)

if type(token) == sqlparse.sql.Identifier:
new_node = Relation(token.get_parent_name(), token.get_real_name())

if new_node.valid():
new_nodes.add(new_node) # don't add edges yet!

# advance to the token that follows the current one
index = stmt.token_index(token)
token = stmt.token_next(index)

extract_deps(definition)

# only add nodes which don't reference locally defined constructs
for new_node in new_nodes:
if new_node.schema not in local_defs:
self.graph.add_node(new_node.val)
# edge direction: defined relation -> relation it depends on
self.graph.add_edge(definition_node.val, new_node.val)

return definition_node.val

def as_dependency_list(self):
    """Yield ``(node, sql)`` pairs in reverse topological order of the graph.

    Graph nodes that have no SQL registered in ``node_sql_map`` are
    skipped silently.
    """
    for node in nx.topological_sort(self.graph, reverse=True):
        if node not in self.node_sql_map:  # TODO :
            continue
        yield (node, self.node_sql_map[node])

def register(self, node, sql):
    """Record the SQL that defines ``node``.

    Raises RuntimeError if SQL was already registered for ``node``.
    """
    known = self.node_sql_map
    if node in known:
        raise RuntimeError("multiple declarations of node: {}".format(node))
    known[node] = sql

def link(self, sql):
sql = sql.strip()
for statement in sqlparse.parse(sql):
if statement.get_type().startswith('CREATE'):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this is expecting the model to be specified as create view as... or create table as... and then re-writes that create statement if necessary? I see that's a convenient way to allow the user to specify the name of the thing, but I think it will either break or produce surprising results if anything other than the most generic forms for create table/view as... are used. For example, what's going to happen if someone writes create temporary table as ...?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code runs after the compilation step. Input to the system should be a SELECT statement. This is compiled to a CREATE statement, which is then run. I suppose this code is redundant, but it was originally intended to avoid running destructive SQL unintentionally.

There is a single SELECT statement per file, and the table/view name is inferred from the name of the file containing that SELECT statement.

Sent from my iPhone

On Apr 1, 2016, at 6:29 PM, Christopher Merrick notifications@github.com wrote:

In dbt/task/run.py:

  •    order = nx.topological_sort(self.graph, reverse=True)
    
  •    for node in order:
    
  •        if node in self.node_sql_map: # TODO :
    
  •            yield (node, self.node_sql_map[node])
    
  •        else:
    
  •            pass
    
  • def register(self, node, sql):
  •    if node in self.node_sql_map:
    
  •        raise RuntimeError("multiple declarations of node: {}".format(node))
    
  •    self.node_sql_map[node] = sql
    
  • def link(self, sql):
  •    sql = sql.strip()
    
  •    for statement in sqlparse.parse(sql):
    
  •        if statement.get_type().startswith('CREATE'):
    
    so this is expecting the model to be specified as create view as... or create table as... and then re-writes that create statement if necessary? I see that's a convenient way to allow the user to specify the name of the thing, but I think it will either break or produce surprising results if anything other than the most generic forms for create table/view as... are used. For example, what's going to happen if someone writes create temporary table as ...?


You are receiving this because you authored the thread.
Reply to this email directly or view it on GitHub

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, OK: the simultaneous introduction of dependency-graphing and create-wrapping got me mixed up. That sounds right. In the future, it would be good to provide that blurb when you create the PR; it's useful for understanding the approach.

node = self.extract_name_and_deps(statement)
self.register(node, sql)
else:
print("Ignoring {}".format(sql[0:100].replace('\n', ' ')))


class RunTask:
def __init__(self, args, project):
    """Keep ``args`` and ``project`` and create a fresh ``Linker`` for the run."""
    self.args, self.project = args, project
    self.linker = Linker()

def __compiled_files(self):
compiled_files = []
sql_path = self.project['target-path']
Expand Down Expand Up @@ -59,15 +156,49 @@ def __create_schema(self):
with handle.cursor() as cursor:
cursor.execute('create schema if not exists "{}"'.format(target_cfg['schema']))

def __load_models(self):
    """Read every compiled file under 'target-path' and pass its text to the linker."""
    # The returned target is not used in this method; the call is kept as in
    # the original.
    self.__get_target()
    for compiled_name in self.__compiled_files():
        compiled_path = os.path.join(self.project['target-path'], compiled_name)
        with open(compiled_path, 'r') as source:
            contents = source.read()
        self.linker.link(contents)

def __query_for_existing(self, cursor, schema):
sql = """
select '{schema}.' || tablename as name, 'table' as type from pg_tables where schemaname = '{schema}'
union all
select '{schema}.' || viewname as name, 'view' as type from pg_views where schemaname = '{schema}' """.format(schema=schema)

cursor.execute(sql)
existing = [(name, relation_type) for (name, relation_type) in cursor.fetchall()]

return dict(existing)

def __drop(self, cursor, relation, relation_type):
sql = "drop {relation_type} if exists {relation} cascade".format(relation_type=relation_type, relation=relation)
cursor.execute(sql)

def __execute_models(self):
"""Create every linked model in the target, in dependency order.

Relations that already exist in the target schema are dropped before
being re-created.
"""
target = self.__get_target()

with target.get_handle() as handle:
with handle.cursor() as cursor:
for f in self.__compiled_files():
with open(os.path.join(self.project['target-path'], f), 'r') as fh:
cursor.execute(fh.read())
print " {}".format(cursor.statusmessage)

# mapping of "schema.relation" -> 'table' / 'view' for what already exists
existing = self.__query_for_existing(cursor, target.schema);

# the linker yields (relation, sql) pairs in dependency order
for (relation, sql) in self.linker.as_dependency_list():

if relation in existing:
# drop with the relation type recorded for the existing object
self.__drop(cursor, relation, existing[relation])
handle.commit()

print("creating {}".format(relation))
#print(" {}...".format(re.sub( '\s+', ' ', sql[0:100] ).strip()))
cursor.execute(sql)
print(" {}".format(cursor.statusmessage))
# commit after each model
handle.commit()

def run(self):
    """Run the task: create the schema, load the models, then execute them."""
    steps = (
        self.__create_schema,
        self.__load_models,
        self.__execute_models,
    )
    for step in steps:
        step()

2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@ argparse
Jinja2>=2.8
PyYAML>=3.11
psycopg2==2.6.1
sqlparse==0.1.19
networkx==1.11
11 changes: 11 additions & 0 deletions sample.dbt_project.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@ source-paths: ["model"] # paths with source code to compile
target-path: "target" # path for compiled code
clean-targets: ["target"] # directories removed by the clean task

model-defaults:
enabled: true # enable all models by default
materialized: false # If true, create tables. If false, create views

models:
pardot:
enabled: false # disable all pardot models except where overridden
pardot_visitoractivity: # override configs for a particular model
enabled: true # enable this model
materialized: true # create a table instead of a view (overriding the base config)

# Run configuration
# output environments
outputs:
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
'scripts/dbt',
],
install_requires=[
'argparse>=1.2.1',
'Jinja2>=2.8',
'PyYAML>=3.11',
'psycopg2==2.6.1',
'sqlparse==0.1.19',
'networkx==1.11',
],
)