Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile on-run-(start|end) hooks to file #412

Merged
merged 7 commits into the base branch (branch names not captured in this snapshot)
May 9, 2017
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 8 additions & 26 deletions dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,12 @@ def link_node(self, linker, node, flat_graph):
def link_graph(self, linker, flat_graph):
linked_graph = {
'nodes': {},
'macros': flat_graph.get('macros'),
'operations': flat_graph.get('operations'),
'macros': flat_graph.get('macros')
}

for node_type in ['nodes', 'operations']:
for name, node in flat_graph.get(node_type).items():
self.link_node(linker, node, flat_graph)
linked_graph[node_type][name] = node
for name, node in flat_graph.get('nodes').items():
self.link_node(linker, node, flat_graph)
linked_graph['nodes'][name] = node

cycle = linker.find_cycles()

Expand Down Expand Up @@ -399,17 +397,6 @@ def get_parsed_macros(self, root_project, all_projects):

return parsed_macros

def get_parsed_operations(self, root_project, all_projects):
    """Parse the on-run-start and on-run-end hooks for every project.

    :param root_project: the root project config.
    :param all_projects: mapping of project name -> project config.
    :return: dict of parsed operation nodes keyed by unique id.
    """
    parsed_operations = {}

    # load_and_parse_run_hooks already walks every entry in
    # all_projects internally, so one call per hook type is enough.
    # (The previous loop over all_projects.items() ignored its loop
    # variables and redundantly re-parsed everything once per project.)
    parsed_operations.update(
        dbt.parser.load_and_parse_run_hooks(
            root_project, all_projects, 'on-run-start'))
    parsed_operations.update(
        dbt.parser.load_and_parse_run_hooks(
            root_project, all_projects, 'on-run-end'))

    return parsed_operations

def get_parsed_models(self, root_project, all_projects):
parsed_models = {}

Expand Down Expand Up @@ -473,9 +460,6 @@ def get_parsed_schema_tests(self, root_project, all_projects):
def load_all_macros(self, root_project, all_projects):
    """Return all parsed macros; thin wrapper kept for symmetry with
    the other load_all_* entry points."""
    return self.get_parsed_macros(root_project, all_projects)

def load_all_operations(self, root_project, all_projects):
    """Return all parsed run-hook operations; thin wrapper kept for
    symmetry with the other load_all_* entry points."""
    return self.get_parsed_operations(root_project, all_projects)

def load_all_nodes(self, root_project, all_projects):
all_nodes = {}

Expand All @@ -488,6 +472,8 @@ def load_all_nodes(self, root_project, all_projects):
all_nodes.update(
dbt.parser.parse_archives_from_projects(root_project,
all_projects))
all_nodes.update(
dbt.parser.load_and_parse_run_hooks(root_project, all_projects))

return all_nodes

Expand All @@ -498,13 +484,12 @@ def compile(self):
all_projects = self.get_all_projects()

all_macros = self.load_all_macros(root_project, all_projects)

all_nodes = self.load_all_nodes(root_project, all_projects)
all_operations = self.load_all_operations(root_project, all_projects)

flat_graph = {
'nodes': all_nodes,
'macros': all_macros,
'operations': all_operations
'macros': all_macros
}

flat_graph = dbt.parser.process_refs(flat_graph,
Expand All @@ -520,9 +505,6 @@ def compile(self):
for node_name, node in linked_graph.get('macros').items():
stats[node.get('resource_type')] += 1

for node_name, node in linked_graph.get('operations').items():
stats[node.get('resource_type')] += 1

print_compile_stats(stats)

return linked_graph, linker
16 changes: 14 additions & 2 deletions dbt/parser.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def get_hooks(all_projects, hook_type):
return project_hooks


def load_and_parse_run_hooks(root_project, all_projects, hook_type):
def load_and_parse_run_hook_type(root_project, all_projects, hook_type):

if dbt.flags.STRICT_MODE:
dbt.contracts.project.validate_list(all_projects)
Expand All @@ -374,7 +374,19 @@ def load_and_parse_run_hooks(root_project, all_projects, hook_type):
'raw_sql': hooks
})

return parse_sql_nodes(result, root_project, all_projects, tags={hook_type})
tags = {hook_type}
return parse_sql_nodes(result, root_project, all_projects, tags=tags)


def load_and_parse_run_hooks(root_project, all_projects):
    """Parse every configured run hook across all projects.

    Collects both the 'on-run-start' and 'on-run-end' hooks (see
    dbt.utils.RunHookTypes.Both) into a single dict of parsed nodes.

    :param root_project: the root project config.
    :param all_projects: mapping of project name -> project config.
    :return: dict of parsed hook nodes keyed by unique id.
    """
    parsed = {}
    for kind in dbt.utils.RunHookTypes.Both:
        parsed.update(
            load_and_parse_run_hook_type(root_project, all_projects, kind))

    return parsed


def load_and_parse_macros(package_name, root_project, all_projects, root_dir,
Expand Down
34 changes: 16 additions & 18 deletions dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from dbt.adapters.factory import get_adapter
from dbt.logger import GLOBAL_LOGGER as logger

from dbt.utils import get_materialization, NodeType, is_type
from dbt.utils import get_materialization, NodeType, is_type, get_nodes_by_tags

import dbt.clients.jinja
import dbt.compilation
Expand Down Expand Up @@ -369,15 +369,6 @@ def execute_archive(profile, node, context):
return result


def run_hooks(profile, hooks):
    """Execute already-compiled hook nodes inside one transaction.

    :param profile: adapter profile used to open the connection.
    :param hooks: compiled hook nodes; each must carry 'wrapped_sql'.
    """
    adapter = get_adapter(profile)

    connection = adapter.begin(profile)
    compiled_hooks = [hook['wrapped_sql'] for hook in hooks]
    adapter.execute_all(profile=profile, sqls=compiled_hooks)
    # Commit closes out the transaction opened by begin(); the return
    # value was previously stored in an unused variable.
    adapter.commit(connection)


def track_model_run(index, num_nodes, run_model_result):
invocation_id = dbt.tracking.active_user.invocation_id
dbt.tracking.track_model_run({
Expand Down Expand Up @@ -619,6 +610,18 @@ def as_concurrent_dep_list(self, linker, nodes_to_run):

return concurrent_dependency_list

def run_hooks(self, profile, flat_graph, hook_type):
    """Compile and execute all hooks of the given type in one transaction.

    :param profile: adapter profile used to open the connection.
    :param flat_graph: the flat graph; hook nodes live under 'nodes'.
    :param hook_type: a dbt.utils.RunHookTypes value
        ('on-run-start' or 'on-run-end').
    """
    adapter = get_adapter(profile)

    # Select hook nodes by tag. Renamed from 'start_hooks': this method
    # serves the end hooks as well, depending on hook_type.
    nodes = flat_graph.get('nodes', {}).values()
    matched = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation)
    compiled = [self.compile_node(hook, flat_graph) for hook in matched]

    connection = adapter.begin(profile)
    compiled_sql = [hook['wrapped_sql'] for hook in compiled]
    adapter.execute_all(profile=profile, sqls=compiled_sql)
    # Commit result was previously bound to an unused variable.
    adapter.commit(connection)

def on_model_failure(self, linker, selected_nodes):
def skip_dependent(node):
dependent_nodes = linker.get_dependent_nodes(node.get('unique_id'))
Expand Down Expand Up @@ -672,9 +675,7 @@ def execute_nodes(self, flat_graph, node_dependency_list, on_failure,
start_time = time.time()

if should_run_hooks:
start_hooks = dbt.utils.get_nodes_by_tags(flat_graph, {'on-run-start'}, "operations")
hooks = [self.compile_node(hook, flat_graph) for hook in start_hooks]
run_hooks(profile, hooks)
self.run_hooks(profile, flat_graph, dbt.utils.RunHookTypes.Start)

def get_idx(node):
return node_id_to_index_map.get(node.get('unique_id'))
Expand Down Expand Up @@ -721,9 +722,7 @@ def get_idx(node):
pool.join()

if should_run_hooks:
end_hooks = dbt.utils.get_nodes_by_tags(flat_graph, {'on-run-end'}, "operations")
hooks = [self.compile_node(hook, flat_graph) for hook in end_hooks]
run_hooks(profile, hooks)
self.run_hooks(profile, flat_graph, dbt.utils.RunHookTypes.End)

execution_time = time.time() - start_time

Expand Down Expand Up @@ -879,8 +878,7 @@ def compile_models(self, include_spec, exclude_spec):
resource_types=resource_types,
tags=set(),
should_run_hooks=False,
should_execute=False,
flatten_graph=True)
should_execute=False)

def run_models(self, include_spec, exclude_spec):
return self.run_types_from_graph(include_spec,
Expand Down
16 changes: 11 additions & 5 deletions dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class NodeType(object):
Operation = 'operation'


class RunHookTypes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RunHookType

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reasonable

Start = 'on-run-start'
End = 'on-run-end'
Both = [Start, End]


class This(object):
def __init__(self, schema, table, name):
self.schema = schema
Expand Down Expand Up @@ -285,10 +291,10 @@ def get_run_status_line(results):
))


def get_nodes_by_tags(nodes, match_tags, resource_type):
    """Return the nodes whose tags intersect ``match_tags``.

    :param nodes: iterable of node dicts.
    :param match_tags: set of tag strings; a node matches when its
        'tags' set shares at least one member with this set.
    :param resource_type: only nodes with this 'resource_type' are
        considered. (The parameter was previously accepted but ignored;
        the pre-refactor version filtered by resource type, and the
        visible caller passes NodeType.Operation, so honor it.)
    :return: list of matching node dicts, in input order.
    """
    matched_nodes = []
    for node in nodes:
        if node.get('resource_type') != resource_type:
            continue
        node_tags = node.get('tags', set())
        # Truthiness of the set intersection replaces the len(...) check.
        if node_tags & match_tags:
            matched_nodes.append(node)
    return matched_nodes
return matched_nodes