Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

custom schemas #522

Merged
merged 16 commits into from
Sep 14, 2017
34 changes: 24 additions & 10 deletions dbt/adapters/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,14 @@ def open_connection(cls, connection):
return result

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
dataset = cls.get_dataset(profile, schema, model_name)
tables = dataset.list_tables()
def query_for_existing(cls, profile, schemas, model_name=None):
if not isinstance(schemas, (list, tuple)):
schemas = [schemas]

all_tables = []
for schema in schemas:
dataset = cls.get_dataset(profile, schema, model_name)
all_tables.extend(dataset.list_tables())

relation_type_lookup = {
'TABLE': 'table',
Expand All @@ -153,19 +158,18 @@ def query_for_existing(cls, profile, schema, model_name=None):
}

existing = [(table.name, relation_type_lookup.get(table.table_type))
for table in tables]
for table in all_tables]

return dict(existing)

@classmethod
def drop(cls, profile, relation, relation_type, model_name=None):
schema = cls.get_default_schema(profile)
def drop(cls, profile, schema, relation, relation_type, model_name=None):
dataset = cls.get_dataset(profile, schema, model_name)
relation_object = dataset.table(relation)
relation_object.delete()

@classmethod
def rename(cls, profile, from_name, to_name, model_name=None):
def rename(cls, profile, schema, from_name, to_name, model_name=None):
raise dbt.exceptions.NotImplementedException(
'`rename` is not implemented for this adapter!')

Expand Down Expand Up @@ -234,10 +238,10 @@ def execute_model(cls, profile, model, materialization, model_name=None):
validate_connection(connection)

model_name = model.get('name')
model_schema = model.get('schema')
model_sql = model.get('injected_sql')

schema = cls.get_default_schema(profile)
dataset = cls.get_dataset(profile, schema, model_name)
dataset = cls.get_dataset(profile, model_schema, model_name)

if materialization == 'view':
res = cls.materialize_as_view(profile, dataset, model_name,
Expand Down Expand Up @@ -313,13 +317,23 @@ def drop_schema(cls, profile, schema, model_name=None):
cls.drop_tables_in_schema(dataset)
dataset.delete()

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
conn = cls.get_connection(profile, model_name)

client = conn.get('handle')

with cls.exception_handler(profile, 'list dataset', model_name):
all_datasets = client.list_datasets()
return [ds.name for ds in all_datasets]

@classmethod
def check_schema_exists(cls, profile, schema, model_name=None):
conn = cls.get_connection(profile, model_name)

client = conn.get('handle')

with cls.exception_handler(profile, 'create dataset', model_name):
with cls.exception_handler(profile, 'get dataset', model_name):
all_datasets = client.list_datasets()
return any([ds.name == schema for ds in all_datasets])

Expand Down
47 changes: 24 additions & 23 deletions dbt/adapters/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,15 @@ def alter_column_type(cls, profile, schema, table, column_name,
'`alter_column_type` is not implemented for this adapter!')

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
def query_for_existing(cls, profile, schemas, model_name=None):
raise dbt.exceptions.NotImplementedException(
'`query_for_existing` is not implemented for this adapter!')

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
raise dbt.exceptions.NotImplementedException(
'`get_existing_schemas` is not implemented for this adapter!')

@classmethod
def check_schema_exists(cls, profile, schema):
raise dbt.exceptions.NotImplementedException(
Expand All @@ -104,42 +109,43 @@ def get_result_from_cursor(cls, cursor):
return data

@classmethod
def drop(cls, profile, relation, relation_type, model_name=None):
def drop(cls, profile, schema, relation, relation_type, model_name=None):
if relation_type == 'view':
return cls.drop_view(profile, relation, model_name)
return cls.drop_view(profile, schema, relation, model_name)
elif relation_type == 'table':
return cls.drop_table(profile, relation, model_name)
return cls.drop_table(profile, schema, relation, model_name)
else:
raise RuntimeError(
"Invalid relation_type '{}'"
.format(relation_type))

@classmethod
def drop_view(cls, profile, view, model_name):
sql = ('drop view if exists {} cascade'
.format(cls._get_quoted_identifier(profile, view)))
def drop_relation(cls, profile, schema, rel_name, rel_type, model_name):
relation = cls.quote_schema_and_table(profile, schema, rel_name)
sql = 'drop {} if exists {} cascade'.format(rel_type, relation)

connection, cursor = cls.add_query(profile, sql, model_name)

@classmethod
def drop_table(cls, profile, table, model_name):
sql = ('drop table if exists {} cascade'
.format(cls._get_quoted_identifier(profile, table)))
def drop_view(cls, profile, schema, view, model_name):
cls.drop_relation(profile, schema, view, 'view', model_name)

connection, cursor = cls.add_query(profile, sql, model_name)
@classmethod
def drop_table(cls, profile, schema, table, model_name):
cls.drop_relation(profile, schema, table, 'table', model_name)

@classmethod
def truncate(cls, profile, table, model_name=None):
sql = ('truncate table {}'
.format(cls._get_quoted_identifier(profile, table)))
def truncate(cls, profile, schema, table, model_name=None):
relation = cls.quote_schema_and_table(profile, schema, table)
sql = 'truncate table {}'.format(relation)

connection, cursor = cls.add_query(profile, sql, model_name)

@classmethod
def rename(cls, profile, from_name, to_name, model_name=None):
sql = ('alter table {} rename to {}'
.format(cls._get_quoted_identifier(profile, from_name),
cls.quote(to_name)))
def rename(cls, profile, schema, from_name, to_name, model_name=None):
from_relation = cls.quote_schema_and_table(profile, schema, from_name)
to_relation = cls.quote(to_name)
sql = 'alter table {} rename to {}'.format(from_relation, to_relation)

connection, cursor = cls.add_query(profile, sql, model_name)

Expand Down Expand Up @@ -576,11 +582,6 @@ def already_exists(cls, profile, schema, table, model_name=None):
"""
return cls.table_exists(profile, schema, table, model_name)

@classmethod
def _get_quoted_identifier(cls, profile, identifier):
return cls.quote_schema_and_table(
profile, cls.get_default_schema(profile), identifier)

@classmethod
def quote(cls, identifier):
return '"{}"'.format(identifier)
Expand Down
23 changes: 19 additions & 4 deletions dbt/adapters/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,19 @@ def alter_column_type(cls, profile, schema, table, column_name,
return connection, cursor

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
def query_for_existing(cls, profile, schemas, model_name=None):
if not isinstance(schemas, (list, tuple)):
schemas = [schemas]

schema_list = ",".join(["'{}'".format(schema) for schema in schemas])

sql = """
select tablename as name, 'table' as type from pg_tables
where schemaname = '{schema}'
where schemaname in ({schema_list})
union all
select viewname as name, 'view' as type from pg_views
where schemaname = '{schema}'
""".format(schema=schema).strip() # noqa
where schemaname in ({schema_list})
""".format(schema_list=schema_list).strip() # noqa

connection, cursor = cls.add_query(profile, sql, model_name,
auto_begin=False)
Expand All @@ -125,6 +130,16 @@ def query_for_existing(cls, profile, schema, model_name=None):

return dict(existing)

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
sql = "select distinct nspname from pg_namespace"

connection, cursor = cls.add_query(profile, sql, model_name,
auto_begin=False)
results = cursor.fetchall()

return [row[0] for row in results]

@classmethod
def check_schema_exists(cls, profile, schema, model_name=None):
sql = """
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def date_function(cls):
return 'getdate()'

@classmethod
def drop(cls, profile, relation, relation_type, model_name=None):
def drop(cls, profile, schema, relation, relation_type, model_name=None):
global drop_lock

to_return = None
Expand All @@ -34,7 +34,7 @@ def drop(cls, profile, relation, relation_type, model_name=None):
cls.begin(profile, connection.get('name'))

to_return = super(PostgresAdapter, cls).drop(
profile, relation, relation_type, model_name)
profile, schema, relation, relation_type, model_name)

cls.commit(profile, connection)
cls.begin(profile, connection.get('name'))
Expand Down
26 changes: 20 additions & 6 deletions dbt/adapters/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,17 @@ def open_connection(cls, connection):
return result

@classmethod
def query_for_existing(cls, profile, schema, model_name=None):
def query_for_existing(cls, profile, schemas, model_name=None):
if not isinstance(schemas, (list, tuple)):
schemas = [schemas]

schema_list = ",".join(["'{}'".format(schema) for schema in schemas])

sql = """
select TABLE_NAME as name, TABLE_TYPE as type
from INFORMATION_SCHEMA.TABLES
where TABLE_SCHEMA = '{schema}'
""".format(schema=schema).strip() # noqa
where TABLE_SCHEMA in ({schema_list})
""".format(schema_list=schema_list).strip() # noqa

_, cursor = cls.add_query(profile, sql, model_name, auto_begin=False)
results = cursor.fetchall()
Expand All @@ -124,9 +129,7 @@ def query_for_existing(cls, profile, schema, model_name=None):
return dict(existing)

@classmethod
def rename(cls, profile, from_name, to_name, model_name=None):
schema = cls.get_default_schema(profile)

def rename(cls, profile, schema, from_name, to_name, model_name=None):
sql = (('alter table "{schema}"."{from_name}" '
'rename to "{schema}"."{to_name}"')
.format(schema=schema,
Expand All @@ -146,6 +149,17 @@ def create_schema(cls, profile, schema, model_name=None):
sql = cls.get_create_schema_sql(schema)
return cls.add_query(profile, sql, model_name, select_schema=False)

@classmethod
def get_existing_schemas(cls, profile, model_name=None):
sql = "select distinct SCHEMA_NAME from INFORMATION_SCHEMA.SCHEMATA"

connection, cursor = cls.add_query(profile, sql, model_name,
select_schema=False,
auto_begin=False)
results = cursor.fetchall()

return [row[0] for row in results]

@classmethod
def check_schema_exists(cls, profile, schema, model_name=None):
sql = """
Expand Down
20 changes: 14 additions & 6 deletions dbt/context/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def commit(self):


def _add_macros(context, model, flat_graph):
macros_to_add = {'global': [], 'local': []}

for unique_id, macro in flat_graph.get('macros', {}).items():
package_name = macro.get('package_name')

Expand All @@ -71,9 +73,15 @@ def _add_macros(context, model, flat_graph):
context.get(package_name, {}) \
.update(macro_map)

if(package_name == model.get('package_name') or
package_name == dbt.include.GLOBAL_PROJECT_NAME):
context.update(macro_map)
if package_name == model.get('package_name'):
macros_to_add['local'].append(macro_map)
elif package_name == dbt.include.GLOBAL_PROJECT_NAME:
macros_to_add['global'].append(macro_map)

# Load global macros before local macros -- local takes precedence
unprefixed_macros = macros_to_add['global'] + macros_to_add['local']
for macro_map in unprefixed_macros:
context.update(macro_map)

return context

Expand Down Expand Up @@ -268,14 +276,14 @@ def generate(model, project, flat_graph, provider=None):
"model": model,
"post_hooks": post_hooks,
"pre_hooks": pre_hooks,
"ref": provider.ref(model, project, profile, schema, flat_graph),
"schema": schema,
"ref": provider.ref(model, project, profile, flat_graph),
"schema": model.get('schema', schema),
"sql": model.get('injected_sql'),
"sql_now": adapter.date_function(),
"fromjson": fromjson(model),
"target": target,
"this": dbt.utils.This(
schema,
model.get('schema', schema),
dbt.utils.model_immediate_name(model, dbt.flags.NON_DESTRUCTIVE),
model.get('name')
)
Expand Down
2 changes: 1 addition & 1 deletion dbt/context/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
execute = False


def ref(model, project, profile, schema, flat_graph):
def ref(model, project, profile, flat_graph):

def ref(*args):
if len(args) == 1 or len(args) == 2:
Expand Down
3 changes: 2 additions & 1 deletion dbt/context/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
execute = True


def ref(model, project, profile, schema, flat_graph):
def ref(model, project, profile, flat_graph):
current_project = project.get('name')

def do_ref(*args):
Expand Down Expand Up @@ -54,6 +54,7 @@ def do_ref(*args):
else:
adapter = get_adapter(profile)
table = target_model.get('name')
schema = target_model.get('schema')

return adapter.quote_schema_and_table(profile, schema, table)

Expand Down
1 change: 1 addition & 0 deletions dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# identifiers
Required('unique_id'): All(basestring, Length(min=1, max=255)),
Required('fqn'): All(list, [All(basestring)]),
Required('schema'): basestring,

Required('refs'): [All(tuple)],

Expand Down
Loading