Skip to content

Commit

Permalink
custom schemas (#522)
Browse files Browse the repository at this point in the history
* set schema on node, add validations

* adapter, runner functions around multiple schemas

* update adapters/materializations to support custom schemas

* ref tweaks

* fix unit tests, pep8

* make schema work in dbt_project.yml

* faster schema selection, fix for snowflake

* don't look for existing if no schemas

* slight cleanup

* dumb

* pep8

* integration tests

* use macro instead of schema_prefix

* custom schemas thru macros

* rm makefile change

* fix unit tests
  • Loading branch information
drewbanin authored Sep 14, 2017
1 parent 1e10f5b commit 5cc2e13
Show file tree
Hide file tree
Showing 31 changed files with 531 additions and 137 deletions.
34 changes: 24 additions & 10 deletions dbt/adapters/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,14 @@ def open_connection(cls, connection):
return result

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
dataset = cls.get_dataset(profile, schema, model_name)
tables = dataset.list_tables()
def query_for_existing(cls, profile, schemas, model_name=None):
if not isinstance(schemas, (list, tuple)):
schemas = [schemas]

all_tables = []
for schema in schemas:
dataset = cls.get_dataset(profile, schema, model_name)
all_tables.extend(dataset.list_tables())

relation_type_lookup = {
'TABLE': 'table',
Expand All @@ -153,19 +158,18 @@ def query_for_existing(cls, profile, schema, model_name=None):
}

existing = [(table.name, relation_type_lookup.get(table.table_type))
for table in tables]
for table in all_tables]

return dict(existing)

@classmethod
def drop(cls, profile, relation, relation_type, model_name=None):
schema = cls.get_default_schema(profile)
def drop(cls, profile, schema, relation, relation_type, model_name=None):
dataset = cls.get_dataset(profile, schema, model_name)
relation_object = dataset.table(relation)
relation_object.delete()

@classmethod
def rename(cls, profile, from_name, to_name, model_name=None):
def rename(cls, profile, schema, from_name, to_name, model_name=None):
raise dbt.exceptions.NotImplementedException(
'`rename` is not implemented for this adapter!')

Expand Down Expand Up @@ -234,10 +238,10 @@ def execute_model(cls, profile, model, materialization, model_name=None):
validate_connection(connection)

model_name = model.get('name')
model_schema = model.get('schema')
model_sql = model.get('injected_sql')

schema = cls.get_default_schema(profile)
dataset = cls.get_dataset(profile, schema, model_name)
dataset = cls.get_dataset(profile, model_schema, model_name)

if materialization == 'view':
res = cls.materialize_as_view(profile, dataset, model_name,
Expand Down Expand Up @@ -313,13 +317,23 @@ def drop_schema(cls, profile, schema, model_name=None):
cls.drop_tables_in_schema(dataset)
dataset.delete()

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
    """List the names of all datasets reported by the BigQuery client.

    The client is taken from the 'handle' entry of the connection that
    `get_connection` returns for `profile` / `model_name`.
    """
    client = cls.get_connection(profile, model_name).get('handle')

    # The listing is consumed inside the handler so that anything raised
    # while iterating it still happens under `exception_handler`.
    with cls.exception_handler(profile, 'list dataset', model_name):
        return [dataset.name for dataset in client.list_datasets()]

@classmethod
def check_schema_exists(cls, profile, schema, model_name=None):
conn = cls.get_connection(profile, model_name)

client = conn.get('handle')

with cls.exception_handler(profile, 'create dataset', model_name):
with cls.exception_handler(profile, 'get dataset', model_name):
all_datasets = client.list_datasets()
return any([ds.name == schema for ds in all_datasets])

Expand Down
47 changes: 24 additions & 23 deletions dbt/adapters/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,15 @@ def alter_column_type(cls, profile, schema, table, column_name,
'`alter_column_type` is not implemented for this adapter!')

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
def query_for_existing(cls, profile, schemas, model_name=None):
raise dbt.exceptions.NotImplementedException(
'`query_for_existing` is not implemented for this adapter!')

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
    """Placeholder for listing existing schemas; always raises.

    Raises:
        dbt.exceptions.NotImplementedException: on every call.
    """
    message = '`get_existing_schemas` is not implemented for this adapter!'
    raise dbt.exceptions.NotImplementedException(message)

@classmethod
def check_schema_exists(cls, profile, schema):
raise dbt.exceptions.NotImplementedException(
Expand All @@ -104,42 +109,43 @@ def get_result_from_cursor(cls, cursor):
return data

@classmethod
def drop(cls, profile, relation, relation_type, model_name=None):
def drop(cls, profile, schema, relation, relation_type, model_name=None):
if relation_type == 'view':
return cls.drop_view(profile, relation, model_name)
return cls.drop_view(profile, schema, relation, model_name)
elif relation_type == 'table':
return cls.drop_table(profile, relation, model_name)
return cls.drop_table(profile, schema, relation, model_name)
else:
raise RuntimeError(
"Invalid relation_type '{}'"
.format(relation_type))

@classmethod
def drop_view(cls, profile, view, model_name):
sql = ('drop view if exists {} cascade'
.format(cls._get_quoted_identifier(profile, view)))
def drop_relation(cls, profile, schema, rel_name, rel_type, model_name):
relation = cls.quote_schema_and_table(profile, schema, rel_name)
sql = 'drop {} if exists {} cascade'.format(rel_type, relation)

connection, cursor = cls.add_query(profile, sql, model_name)

@classmethod
def drop_table(cls, profile, table, model_name):
sql = ('drop table if exists {} cascade'
.format(cls._get_quoted_identifier(profile, table)))
def drop_view(cls, profile, schema, view, model_name):
cls.drop_relation(profile, schema, view, 'view', model_name)

connection, cursor = cls.add_query(profile, sql, model_name)
@classmethod
def drop_table(cls, profile, schema, table, model_name):
    """Drop `table` in `schema` by delegating to `drop_relation`.

    The relation type passed along is the literal 'table'.
    """
    relation_type = 'table'
    cls.drop_relation(profile, schema, table, relation_type, model_name)

@classmethod
def truncate(cls, profile, table, model_name=None):
sql = ('truncate table {}'
.format(cls._get_quoted_identifier(profile, table)))
def truncate(cls, profile, schema, table, model_name=None):
relation = cls.quote_schema_and_table(profile, schema, table)
sql = 'truncate table {}'.format(relation)

connection, cursor = cls.add_query(profile, sql, model_name)

@classmethod
def rename(cls, profile, from_name, to_name, model_name=None):
sql = ('alter table {} rename to {}'
.format(cls._get_quoted_identifier(profile, from_name),
cls.quote(to_name)))
def rename(cls, profile, schema, from_name, to_name, model_name=None):
from_relation = cls.quote_schema_and_table(profile, schema, from_name)
to_relation = cls.quote(to_name)
sql = 'alter table {} rename to {}'.format(from_relation, to_relation)

connection, cursor = cls.add_query(profile, sql, model_name)

Expand Down Expand Up @@ -576,11 +582,6 @@ def already_exists(cls, profile, schema, table, model_name=None):
"""
return cls.table_exists(profile, schema, table, model_name)

@classmethod
def _get_quoted_identifier(cls, profile, identifier):
return cls.quote_schema_and_table(
profile, cls.get_default_schema(profile), identifier)

@classmethod
def quote(cls, identifier):
    """Return `identifier` wrapped in double quotes.

    The identifier is inserted as-is; no escaping is applied to it.
    """
    template = '"{}"'
    return template.format(identifier)
Expand Down
23 changes: 19 additions & 4 deletions dbt/adapters/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,19 @@ def alter_column_type(cls, profile, schema, table, column_name,
return connection, cursor

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
def query_for_existing(cls, profile, schemas, model_name=None):
if not isinstance(schemas, (list, tuple)):
schemas = [schemas]

schema_list = ",".join(["'{}'".format(schema) for schema in schemas])

sql = """
select tablename as name, 'table' as type from pg_tables
where schemaname = '{schema}'
where schemaname in ({schema_list})
union all
select viewname as name, 'view' as type from pg_views
where schemaname = '{schema}'
""".format(schema=schema).strip() # noqa
where schemaname in ({schema_list})
""".format(schema_list=schema_list).strip() # noqa

connection, cursor = cls.add_query(profile, sql, model_name,
auto_begin=False)
Expand All @@ -125,6 +130,16 @@ def query_for_existing(cls, profile, schema, model_name=None):

return dict(existing)

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
    """Return the distinct `nspname` values found in `pg_namespace`.

    The query is issued through `add_query` with auto_begin=False.
    """
    sql = "select distinct nspname from pg_namespace"
    _, cursor = cls.add_query(profile, sql, model_name, auto_begin=False)

    # The name is the first (and only selected) column of each row.
    return [row[0] for row in cursor.fetchall()]

@classmethod
def check_schema_exists(cls, profile, schema, model_name=None):
sql = """
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def date_function(cls):
return 'getdate()'

@classmethod
def drop(cls, profile, relation, relation_type, model_name=None):
def drop(cls, profile, schema, relation, relation_type, model_name=None):
global drop_lock

to_return = None
Expand All @@ -34,7 +34,7 @@ def drop(cls, profile, relation, relation_type, model_name=None):
cls.begin(profile, connection.get('name'))

to_return = super(PostgresAdapter, cls).drop(
profile, relation, relation_type, model_name)
profile, schema, relation, relation_type, model_name)

cls.commit(profile, connection)
cls.begin(profile, connection.get('name'))
Expand Down
26 changes: 20 additions & 6 deletions dbt/adapters/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,17 @@ def open_connection(cls, connection):
return result

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
def query_for_existing(cls, profile, schemas, model_name=None):
if not isinstance(schemas, (list, tuple)):
schemas = [schemas]

schema_list = ",".join(["'{}'".format(schema) for schema in schemas])

sql = """
select TABLE_NAME as name, TABLE_TYPE as type
from INFORMATION_SCHEMA.TABLES
where TABLE_SCHEMA = '{schema}'
""".format(schema=schema).strip() # noqa
where TABLE_SCHEMA in ({schema_list})
""".format(schema_list=schema_list).strip() # noqa

_, cursor = cls.add_query(profile, sql, model_name, auto_begin=False)
results = cursor.fetchall()
Expand All @@ -124,9 +129,7 @@ def query_for_existing(cls, profile, schema, model_name=None):
return dict(existing)

@classmethod
def rename(cls, profile, from_name, to_name, model_name=None):
schema = cls.get_default_schema(profile)

def rename(cls, profile, schema, from_name, to_name, model_name=None):
sql = (('alter table "{schema}"."{from_name}" '
'rename to "{schema}"."{to_name}"')
.format(schema=schema,
Expand All @@ -146,6 +149,17 @@ def create_schema(cls, profile, schema, model_name=None):
sql = cls.get_create_schema_sql(schema)
return cls.add_query(profile, sql, model_name, select_schema=False)

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
    """Return the distinct SCHEMA_NAME values from INFORMATION_SCHEMA.SCHEMATA.

    The query is issued through `add_query` with select_schema=False and
    auto_begin=False.
    """
    sql = "select distinct SCHEMA_NAME from INFORMATION_SCHEMA.SCHEMATA"
    query_options = {'select_schema': False, 'auto_begin': False}
    _, cursor = cls.add_query(profile, sql, model_name, **query_options)

    # The name is the first (and only selected) column of each row.
    return [row[0] for row in cursor.fetchall()]

@classmethod
def check_schema_exists(cls, profile, schema, model_name=None):
sql = """
Expand Down
20 changes: 14 additions & 6 deletions dbt/context/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def commit(self):


def _add_macros(context, model, flat_graph):
macros_to_add = {'global': [], 'local': []}

for unique_id, macro in flat_graph.get('macros', {}).items():
package_name = macro.get('package_name')

Expand All @@ -71,9 +73,15 @@ def _add_macros(context, model, flat_graph):
context.get(package_name, {}) \
.update(macro_map)

if(package_name == model.get('package_name') or
package_name == dbt.include.GLOBAL_PROJECT_NAME):
context.update(macro_map)
if package_name == model.get('package_name'):
macros_to_add['local'].append(macro_map)
elif package_name == dbt.include.GLOBAL_PROJECT_NAME:
macros_to_add['global'].append(macro_map)

# Load global macros before local macros -- local takes precedence
unprefixed_macros = macros_to_add['global'] + macros_to_add['local']
for macro_map in unprefixed_macros:
context.update(macro_map)

return context

Expand Down Expand Up @@ -268,14 +276,14 @@ def generate(model, project, flat_graph, provider=None):
"model": model,
"post_hooks": post_hooks,
"pre_hooks": pre_hooks,
"ref": provider.ref(model, project, profile, schema, flat_graph),
"schema": schema,
"ref": provider.ref(model, project, profile, flat_graph),
"schema": model.get('schema', schema),
"sql": model.get('injected_sql'),
"sql_now": adapter.date_function(),
"fromjson": fromjson(model),
"target": target,
"this": dbt.utils.This(
schema,
model.get('schema', schema),
dbt.utils.model_immediate_name(model, dbt.flags.NON_DESTRUCTIVE),
model.get('name')
)
Expand Down
2 changes: 1 addition & 1 deletion dbt/context/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
execute = False


def ref(model, project, profile, schema, flat_graph):
def ref(model, project, profile, flat_graph):

def ref(*args):
if len(args) == 1 or len(args) == 2:
Expand Down
3 changes: 2 additions & 1 deletion dbt/context/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
execute = True


def ref(model, project, profile, schema, flat_graph):
def ref(model, project, profile, flat_graph):
current_project = project.get('name')

def do_ref(*args):
Expand Down Expand Up @@ -54,6 +54,7 @@ def do_ref(*args):
else:
adapter = get_adapter(profile)
table = target_model.get('name')
schema = target_model.get('schema')

return adapter.quote_schema_and_table(profile, schema, table)

Expand Down
1 change: 1 addition & 0 deletions dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# identifiers
Required('unique_id'): All(basestring, Length(min=1, max=255)),
Required('fqn'): All(list, [All(basestring)]),
Required('schema'): basestring,

Required('refs'): [All(tuple)],

Expand Down
Loading

0 comments on commit 5cc2e13

Please sign in to comment.