Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix catalog generation #808

Merged
merged 5 commits into from
Jul 3, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion dbt/clients/jinja.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ def call(*args, **kwargs):
name = node.get('name')
module = template.make_module(
context, False, context)
macro = module.__dict__[dbt.utils.get_dbt_macro_name(name)]

if node['resource_type'] == NodeType.Operation:
macro = module.__dict__[dbt.utils.get_dbt_operation_name(name)]
else:
macro = module.__dict__[dbt.utils.get_dbt_macro_name(name)]
module.__dict__.update(context)

try:
Expand Down
3 changes: 0 additions & 3 deletions dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,6 @@ def compile(self):

self._check_resource_uniqueness(flat_graph)

flat_graph = dbt.parser.process_refs(flat_graph,
root_project.get('name'))

linked_graph = self.link_graph(linker, flat_graph)

stats = defaultdict(int)
Expand Down
60 changes: 38 additions & 22 deletions dbt/context/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import json
import os
import pytz

from dbt.adapters.factory import get_adapter
from dbt.compat import basestring, to_string
from dbt.compat import basestring
from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedMacro, ParsedNode

Expand Down Expand Up @@ -341,7 +340,32 @@ class AdapterWithContext(adapter_type):
return AdapterWithContext


def generate(model, project_cfg, flat_graph, provider=None):
def _add_model_context(context, node, project_cfg, flat_graph, db_wrapper,
provider):

# These fields do not apply to operations.
if node.get('resource_type') == NodeType.Operation:
return context

target_name = project_cfg.get('target')
profile = project_cfg.get('outputs').get(target_name)

pre_hooks = node.get('config', {}).get('pre-hook')
post_hooks = node.get('config', {}).get('post-hook')

model_context = {
"post_hooks": post_hooks,
"pre_hooks": pre_hooks,
"model": node,
"sql": node.get('injected_sql'),
"this": get_this_relation(db_wrapper, project_cfg, profile, node),
"ref": provider.ref(db_wrapper, node, project_cfg, profile, flat_graph)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`ref` and `sql` should be in the global context.

}

return dbt.utils.merge(context, model_context)


def generate(node, project_cfg, flat_graph, provider=None):
"""
Not meant to be called directly. Call with either:
dbt.context.parser.generate
Expand All @@ -358,17 +382,14 @@ def generate(model, project_cfg, flat_graph, provider=None):
target.pop('pass', None)
target['name'] = target_name
adapter = get_adapter(profile)
default_schema = profile.get('schema', 'public')

context = {'env': target}
schema = profile.get('schema', 'public')

pre_hooks = model.get('config', {}).get('pre-hook')
post_hooks = model.get('config', {}).get('post-hook')

relation_type = create_relation(adapter.Relation,
project_cfg.get('quoting'))

db_wrapper = DatabaseWrapper(model,
db_wrapper = DatabaseWrapper(node,
create_adapter(adapter, relation_type),
profile,
project_cfg)
Expand All @@ -382,44 +403,39 @@ def generate(model, project_cfg, flat_graph, provider=None):
"Column": adapter.Column,
},
"column": adapter.Column,
"config": provider.Config(model),
"config": provider.Config(node),
"env_var": _env_var,
"exceptions": dbt.exceptions,
"execute": provider.execute,
"flags": dbt.flags,
"graph": flat_graph,
"log": log,
"model": model,
"modules": {
"pytz": pytz,
"datetime": datetime
},
"post_hooks": post_hooks,
"pre_hooks": pre_hooks,
"ref": provider.ref(db_wrapper, model, project_cfg,
profile, flat_graph),
"return": _return,
"schema": model.get('schema', schema),
"sql": model.get('injected_sql'),
"sql_now": adapter.date_function(),
"fromjson": fromjson,
"schema": node.get('schema', default_schema),
"tojson": tojson,
"target": target,
"this": get_this_relation(db_wrapper, project_cfg, profile, model),
"try_or_compiler_error": try_or_compiler_error(model)
"try_or_compiler_error": try_or_compiler_error(node)
})

context = _add_tracking(context)
context = _add_validation(context)
context = _add_sql_handlers(context)
context = _add_model_context(context, node, project_cfg,
flat_graph, db_wrapper, provider)

# we make a copy of the context for each of these ^^

context = _add_macros(context, model, flat_graph)
context = _add_macros(context, node, flat_graph)

context["write"] = write(model, project_cfg.get('target-path'), 'run')
context["render"] = render(context, model)
context["var"] = Var(model, context=context, overrides=cli_var_overrides)
context["write"] = write(node, project_cfg.get('target-path'), 'run')
context["render"] = render(context, node)
context["var"] = Var(node, context=context, overrides=cli_var_overrides)
context['context'] = context

return context
4 changes: 3 additions & 1 deletion dbt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def load_all(cls, root_project, all_projects):
for loader in cls._LOADERS:
nodes.update(loader.load_all(root_project, all_projects, macros))

return ParsedManifest(nodes=nodes, macros=macros)
manifest = ParsedManifest(nodes=nodes, macros=macros)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tell the ParsedManifest to process itself:

  • process refs
  • incorporate Schema Spec info

UnparsedManifest.parse() ---> ParsedManifest

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've started an issue for this here: #821

manifest = dbt.parser.process_refs(manifest, root_project)
return manifest

@classmethod
def register(cls, loader):
Expand Down
57 changes: 34 additions & 23 deletions dbt/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def resolve_ref(flat_graph, target_model_name, target_model_package,
None)


def process_refs(flat_graph, current_project):
for _, node in flat_graph.get('nodes').items():
def process_refs(manifest, current_project):
flat_graph = manifest.to_flat_graph()
for _, node in manifest.nodes.items():
target_model = None
target_model_name = None
target_model_package = None
Expand Down Expand Up @@ -106,7 +107,7 @@ def process_refs(flat_graph, current_project):
node['depends_on']['nodes'].append(target_model_id)
flat_graph['nodes'][node['unique_id']] = node

return flat_graph
return manifest


def get_fqn(path, package_project_config, extra=[]):
Expand Down Expand Up @@ -154,28 +155,38 @@ def parse_macro_file(macro_file_path,
raise e

for key, item in template.module.__dict__.items():
if type(item) == jinja2.runtime.Macro:
if type(item) != jinja2.runtime.Macro:
continue

node_type = None
if key.startswith(dbt.utils.MACRO_PREFIX):
node_type = NodeType.Macro
name = key.replace(dbt.utils.MACRO_PREFIX, '')

unique_id = get_path(resource_type,
package_name,
name)

merged = dbt.utils.deep_merge(
base_node.serialize(),
{
'name': name,
'unique_id': unique_id,
'tags': tags,
'resource_type': resource_type,
'depends_on': {'macros': []},
})

new_node = ParsedMacro(
template=template,
**merged)

to_return[unique_id] = new_node
elif key.startswith(dbt.utils.OPERATION_PREFIX):
node_type = NodeType.Operation
name = key.replace(dbt.utils.OPERATION_PREFIX, '')

if node_type != resource_type:
continue

unique_id = get_path(resource_type, package_name, name)

merged = dbt.utils.deep_merge(
base_node.serialize(),
{
'name': name,
'unique_id': unique_id,
'tags': tags,
'resource_type': resource_type,
'depends_on': {'macros': []},
})

new_node = ParsedMacro(
template=template,
**merged)

to_return[unique_id] = new_node

return to_return

Expand Down
38 changes: 9 additions & 29 deletions dbt/task/generate.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import json
import os

from dbt.contracts.graph.parsed import ParsedManifest, ParsedNode, ParsedMacro
from dbt.adapters.factory import get_adapter
from dbt.clients.system import write_file
from dbt.compat import bigint
from dbt.include import GLOBAL_DBT_MODULES_PATH
from dbt.node_types import NodeType
import dbt.ui.printer
import dbt.utils
import dbt.compilation

from dbt.task.base_task import BaseTask

Expand Down Expand Up @@ -92,42 +92,22 @@ def unflatten(columns):


class GenerateTask(BaseTask):
def get_all_projects(self):
root_project = self.project.cfg
all_projects = {root_project.get('name'): root_project}
# we only need to load the global deps. We haven't compiled, so our
# project['module-path'] does not exist.
dependency_projects = dbt.utils.dependencies_for_path(
self.project, GLOBAL_DBT_MODULES_PATH
)

for project in dependency_projects:
name = project.cfg.get('name', 'unknown')
all_projects[name] = project.cfg

if dbt.flags.STRICT_MODE:
dbt.contracts.project.ProjectList(**all_projects)

return all_projects
def _get_manifest(self, project):
compiler = dbt.compilation.Compiler(project)
compiler.initialize()

def _get_manifest(self):
# TODO: I'd like to do this better. We can't use
# utils.dependency_projects because it assumes you have compiled your
# project (I think?) - it assumes that you have an existing and
# populated project['modules-path'], but 'catalog generate' shouldn't
# require that. It might be better to suppress the exception in
# dependency_projects if that's reasonable, or make it a flag.
root_project = self.project.cfg
all_projects = self.get_all_projects()
root_project = project.cfg
all_projects = compiler.get_all_projects()

manifest = dbt.loader.GraphLoader.load_all(root_project, all_projects)
return manifest

def run(self):
manifest = self._get_manifest()
manifest = self._get_manifest(self.project)
profile = self.project.run_environment()
adapter = get_adapter(profile)

dbt.ui.printer.print_timestamped_line("Building catalog")
results = adapter.get_catalog(profile, self.project.cfg, manifest)

results = [
Expand Down
7 changes: 6 additions & 1 deletion dbt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,17 @@ def find_in_subgraph_by_name(subgraph, target_name, target_package, nodetype):


# Prefixes prepended to macro and operation names by the helpers below.
MACRO_PREFIX = 'dbt_macro__'
OPERATION_PREFIX = 'dbt_operation__'


def get_dbt_macro_name(name):
    """Return ``name`` with ``MACRO_PREFIX`` prepended."""
    return '{}{}'.format(MACRO_PREFIX, name)


def get_dbt_operation_name(name):
    """Return ``name`` with ``OPERATION_PREFIX`` prepended."""
    return '{}{}'.format(OPERATION_PREFIX, name)


def get_materialization_macro_name(materialization_name, adapter_type=None,
with_prefix=True):
if adapter_type is None:
Expand Down Expand Up @@ -193,7 +198,7 @@ def get_materialization_macro(flat_graph, materialization_name,

def get_operation_macro_name(operation_name, with_prefix=True):
    """Return the macro name to use for an operation.

    When ``with_prefix`` is true, ``operation_name`` is passed through
    ``get_dbt_operation_name``; otherwise it is returned unchanged.
    """
    if with_prefix:
        return get_dbt_operation_name(operation_name)
    else:
        return operation_name

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def project_config(self):

@attr(type='postgres')
def test_simple_generate(self):
self.run_dbt(["deps"])
self.run_dbt(["docs", "generate"])
self.assertTrue(os.path.exists('./target/catalog.json'))

Expand Down
18 changes: 14 additions & 4 deletions test/unit/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dbt.parser

from dbt.node_types import NodeType
from dbt.contracts.graph.parsed import ParsedManifest, ParsedNode, ParsedMacro

def get_os_path(unix_path):
    """Return ``unix_path`` normalized with ``os.path.normpath``."""
    return os.path.normpath(unix_path)
Expand Down Expand Up @@ -680,8 +681,14 @@ def test__process_refs__packages(self):
}
}

manifest = ParsedManifest(
nodes={k: ParsedNode(**v) for (k,v) in graph['nodes'].items()},
macros={k: ParsedMacro(**v) for (k,v) in graph['macros'].items()},
)

processed_manifest = dbt.parser.process_refs(manifest, 'root')
self.assertEquals(
dbt.parser.process_refs(graph, 'root'),
processed_manifest.to_flat_graph(),
{
'macros': {},
'nodes': {
Expand All @@ -703,7 +710,8 @@ def test__process_refs__packages(self):
'path': 'events.sql',
'original_file_path': 'events.sql',
'root_path': get_os_path('/usr/src/app'),
'raw_sql': 'does not matter'
'raw_sql': 'does not matter',
'agate_table': None,
},
'model.root.events': {
'name': 'events',
Expand All @@ -723,7 +731,8 @@ def test__process_refs__packages(self):
'path': 'events.sql',
'original_file_path': 'events.sql',
'root_path': get_os_path('/usr/src/app'),
'raw_sql': 'does not matter'
'raw_sql': 'does not matter',
'agate_table': None,
},
'model.root.dep': {
'name': 'dep',
Expand All @@ -743,7 +752,8 @@ def test__process_refs__packages(self):
'path': 'multi.sql',
'original_file_path': 'multi.sql',
'root_path': get_os_path('/usr/src/app'),
'raw_sql': 'does not matter'
'raw_sql': 'does not matter',
'agate_table': None,
}
}
}
Expand Down