Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow muliple config() calls (#558) #1150

Merged
merged 3 commits into from
Nov 26, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dbt/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import dbt.exceptions
import dbt.flags
import dbt.model
import dbt.utils
import dbt.hooks
import dbt.clients.jinja
Expand All @@ -11,6 +10,7 @@
from dbt.utils import coalesce
from dbt.logger import GLOBAL_LOGGER as logger
from dbt.contracts.graph.parsed import ParsedNode
from dbt.parser.source_config import SourceConfig


class BaseParser(object):
Expand Down Expand Up @@ -71,7 +71,7 @@ def parse_node(cls, node, node_path, root_project_config,
fqn = cls.get_fqn(node.get('path'), package_project_config,
fqn_extra)

config = dbt.model.SourceConfig(
config = SourceConfig(
root_project_config,
package_project_config,
fqn,
Expand Down
108 changes: 31 additions & 77 deletions dbt/model.py → dbt/parser/source_config.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,15 @@
import os.path

import dbt.exceptions

from dbt.compat import basestring

from dbt.utils import split_path, deep_merge, DBTConfigKeys
from dbt.utils import deep_merge, DBTConfigKeys
from dbt.node_types import NodeType


class SourceConfig(object):
ConfigKeys = DBTConfigKeys

AppendListFields = ['pre-hook', 'post-hook', 'tags']
ExtendDictFields = ['vars', 'column_types', 'quoting']
ClobberFields = [
AppendListFields = {'pre-hook', 'post-hook', 'tags'}
ExtendDictFields = {'vars', 'column_types', 'quoting'}
ClobberFields = {
'alias',
'schema',
'enabled',
Expand All @@ -23,8 +19,8 @@ class SourceConfig(object):
'sql_where',
'unique_key',
'sort_type',
'bind'
]
'bind',
}

def __init__(self, active_project, own_project, fqn, node_type):
self._config = None
Expand All @@ -37,7 +33,7 @@ def __init__(self, active_project, own_project, fqn, node_type):
self.in_model_config = {}

# make sure we categorize all configs
all_configs = self.AppendListFields + self.ExtendDictFields + \
all_configs = self.AppendListFields | self.ExtendDictFields | \
self.ClobberFields

for config in self.ConfigKeys:
Expand Down Expand Up @@ -90,15 +86,23 @@ def config(self):
return cfg

def update_in_model_config(self, config):
config = config.copy()

# make sure we're not clobbering an array of hooks with a single hook
# string
for field in self.AppendListFields:
if field in config:
config[field] = self.__get_as_list(config, field)

self.in_model_config.update(config)
for key, value in config.items():
if key in self.AppendListFields:
current = self.in_model_config.get(key, [])
if not isinstance(value, (list, tuple)):
value = [value]
current.extend(value)
self.in_model_config[key] = current
elif key in self.ExtendDictFields:
current = self.in_model_config.get(key, {})
if not isinstance(current, dict):
dbt.exceptions.raise_compiler_error(
'Invalid config field: "{}" must be a dict'.format(key)
)
current.update(value)
self.in_model_config[key] = current
else: # key in self.ClobberFields
self.in_model_config[key] = value

def __get_as_list(self, relevant_configs, key):
if key not in relevant_configs:
Expand All @@ -116,17 +120,17 @@ def smart_update(self, mutable_config, new_configs):
in new_configs if key in self.ConfigKeys
}

for key in SourceConfig.AppendListFields:
for key in self.AppendListFields:
append_fields = self.__get_as_list(relevant_configs, key)
mutable_config[key].extend([
f for f in append_fields if f not in mutable_config[key]
])

for key in SourceConfig.ExtendDictFields:
for key in self.ExtendDictFields:
dict_val = relevant_configs.get(key, {})
mutable_config[key].update(dict_val)

for key in SourceConfig.ClobberFields:
for key in self.ClobberFields:
if key in relevant_configs:
mutable_config[key] = relevant_configs[key]

Expand All @@ -136,9 +140,9 @@ def get_project_config(self, runtime_config):
# most configs are overwritten by a more specific config, but pre/post
# hooks are appended!
config = {}
for k in SourceConfig.AppendListFields:
for k in self.AppendListFields:
config[k] = []
for k in SourceConfig.ExtendDictFields:
for k in self.ExtendDictFields:
config[k] = {}

if self.node_type == NodeType.Seed:
Expand All @@ -163,8 +167,8 @@ def get_project_config(self, runtime_config):

clobber_configs = {
k: v for (k, v) in relevant_configs.items()
if k not in SourceConfig.AppendListFields and
k not in SourceConfig.ExtendDictFields
if k not in self.AppendListFields and
k not in self.ExtendDictFields
}

config.update(clobber_configs)
Expand All @@ -177,53 +181,3 @@ def load_config_from_own_project(self):

def load_config_from_active_project(self):
return self.get_project_config(self.active_project)


class DBTSource(object):
def __init__(self, project, top_dir, rel_filepath, own_project):
self._config = None
self.project = project
self.own_project = own_project

self.top_dir = top_dir
self.rel_filepath = rel_filepath
self.filepath = os.path.join(top_dir, rel_filepath)
self.name = self.fqn[-1]

self.source_config = SourceConfig(project, own_project, self.fqn)

def compile(self):
raise RuntimeError("Not implemented!")

@property
def config(self):
if self._config is not None:
return self._config

return self.source_config.config

@property
def fqn(self):
"""
fully-qualified name for model. Includes all subdirs below 'models'
path and the filename
"""
parts = split_path(self.filepath)
name, _ = os.path.splitext(parts[-1])
return [self.own_project['name']] + parts[1:-1] + [name]

@property
def nice_name(self):
return "{}.{}".format(self.fqn[0], self.fqn[-1])


class Csv(DBTSource):
def __init__(self, project, target_dir, rel_filepath, own_project):
super(Csv, self).__init__(
project, target_dir, rel_filepath, own_project
)

def __repr__(self):
return "<Csv {}.{}: {}>".format(
self.project['name'], self.model_name, self.filepath
)
1 change: 0 additions & 1 deletion dbt/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import dbt.exceptions
import dbt.linker
import dbt.tracking
import dbt.model
import dbt.ui.printer
import dbt.utils
from dbt.clients.system import write_json
Expand Down
39 changes: 0 additions & 39 deletions dbt/source.py

This file was deleted.

2 changes: 2 additions & 0 deletions test/integration/039_config_test/data/seed.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
id,value
4,2
15 changes: 15 additions & 0 deletions test/integration/039_config_test/models/model.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{{
config(
materialized='view',
tags=['tag_two'],
)
}}

{{
config(
materialized='table',
tags=['tag_three'],
)
}}

select 4 as id, 2 as value
41 changes: 41 additions & 0 deletions test/integration/039_config_test/test_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

from test.integration.base import DBTIntegrationTest, use_profile


class TestConfigs(DBTIntegrationTest):
@property
def schema(self):
return "config_039"

def unique_schema(self):
return super(TestConfigs, self).unique_schema().upper()

@property
def project_config(self):
return {
'data-paths': ['test/integration/039_config_test/data'],
'models': {
'test': {
# the model configs will override this
'materialized': 'invalid',
# the model configs will append to these
'tags': ['tag_one'],
},
},
}

@property
def models(self):
return "test/integration/039_config_test/models"

@use_profile('postgres')
def test_postgres_config_layering(self):
self.assertEqual(len(self.run_dbt(['seed'])), 1)
# test the project-level tag, and both config() call tags
self.assertEqual(len(self.run_dbt(['run', '--model', 'tag:tag_one'])), 1)
self.assertEqual(len(self.run_dbt(['run', '--model', 'tag:tag_two'])), 1)
self.assertEqual(len(self.run_dbt(['run', '--model', 'tag:tag_three'])), 1)
self.assertTablesEqual('seed', 'model')
# make sure we overwrote the materialization properly
models = self.get_models_in_schema()
self.assertEqual(models['model'], 'table')
1 change: 0 additions & 1 deletion test/unit/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import dbt.exceptions
import dbt.flags
import dbt.linker
import dbt.model
import dbt.config
import dbt.templates
import dbt.utils
Expand Down
Loading