Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Quote databases properly (#1396) #1402

Merged
merged 6 commits into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions core/dbt/adapters/base/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,9 @@ def list_relations(self, database, schema, model_name=None):
information_schema = self.Relation.create(
database=database,
schema=schema,
model_name='').information_schema()
model_name='',
quote_policy=self.config.quoting
).information_schema()

# we can't build the relations cache because we don't have a
# manifest so we can't run any operations.
Expand All @@ -581,7 +583,7 @@ def _make_match_kwargs(self, database, schema, identifier):
if schema is not None and quoting['schema'] is False:
schema = schema.lower()

if database is not None and quoting['schema'] is False:
if database is not None and quoting['database'] is False:
database = database.lower()

return filter_null_values({
Expand Down
4 changes: 3 additions & 1 deletion core/dbt/adapters/sql/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def list_relations_without_caching(self, information_schema, schema,

relations = []
quote_policy = {
'database': True,
'schema': True,
'identifier': True
}
Expand Down Expand Up @@ -233,7 +234,8 @@ def list_schemas(self, database, model_name=None):

def check_schema_exists(self, database, schema, model_name=None):
information_schema = self.Relation.create(
database=database, schema=schema
database=database, schema=schema,
quote_policy=self.config.quoting
).information_schema()

kwargs = {'information_schema': information_schema, 'schema': schema}
Expand Down
5 changes: 4 additions & 1 deletion plugins/postgres/dbt/adapters/postgres/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ def date_function(cls):

@available_raw
def verify_database(self, database):
database = database.strip('"')
if database.startswith('"'):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't the most pleasant thing in the world, but I think it's a good and appropriate fix for this problem. Nice

database = database.strip('"')
else:
database = database.lower()
expected = self.config.credentials.database
if database != expected:
raise dbt.exceptions.NotImplementedException(
Expand Down
2 changes: 2 additions & 0 deletions test.env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ SNOWFLAKE_TEST_ACCOUNT=
SNOWFLAKE_TEST_USER=
SNOWFLAKE_TEST_PASSWORD=
SNOWFLAKE_TEST_DATABASE=
SNOWFLAKE_TEST_ALT_DATABASE=
SNOWFLAKE_TEST_QUOTED_DATABASE=
SNOWFLAKE_TEST_WAREHOUSE=

BIGQUERY_TYPE=
Expand Down
28 changes: 28 additions & 0 deletions test/integration/001_simple_copy_test/test_simple_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def models(self):


class TestSimpleCopy(BaseTestSimpleCopy):

@use_profile("postgres")
def test__postgres__simple_copy(self):
self.use_default_project({"data-paths": [self.dir("seed-initial")]})
Expand Down Expand Up @@ -83,6 +84,12 @@ def test__snowflake__simple_copy(self):

self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"])

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-update")],
})
self.run_dbt(['test'])

@use_profile("snowflake")
def test__snowflake__simple_copy__quoting_off(self):
self.use_default_project({
Expand All @@ -108,6 +115,13 @@ def test__snowflake__simple_copy__quoting_off(self):

self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"])

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-update")],
"quoting": {"identifier": False},
})
self.run_dbt(['test'])

@use_profile("snowflake")
def test__snowflake__seed__quoting_switch(self):
self.use_default_project({
Expand All @@ -124,6 +138,12 @@ def test__snowflake__seed__quoting_switch(self):
})
results = self.run_dbt(["seed"], expect_pass=False)

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-initial")],
})
self.run_dbt(['test'])

@use_profile("bigquery")
def test__bigquery__simple_copy(self):
self.use_default_project({"data-paths": [self.dir("seed-initial")]})
Expand Down Expand Up @@ -181,6 +201,8 @@ def test__snowflake__simple_copy__quoting_on(self):

self.assertManyTablesEqual(["seed", "view_model", "incremental", "materialized"])

# can't run the test as this one's identifiers will be the wrong case


class BaseLowercasedSchemaTest(BaseTestSimpleCopy):
def unique_schema(self):
Expand Down Expand Up @@ -210,6 +232,12 @@ def test__snowflake__simple_copy(self):

self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"])

self.use_default_project({
"test-paths": [self.dir("tests")],
"data-paths": [self.dir("seed-update")],
})
self.run_dbt(['test'])


class TestSnowflakeSimpleLowercasedSchemaQuoted(BaseLowercasedSchemaTest):
@property
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{%- set tgt = ref('seed') -%}
{%- set got = adapter.get_relation(database=tgt.database, schema=tgt.schema, identifier=tgt.identifier) | string -%}
{% set replaced = got.replace('"', '-') %}
{% set expected = "-" + tgt.database.upper() + '-.-' + tgt.schema.upper() + '-.-' + tgt.identifier.upper() + '-' %}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-.- is exactly how i feel about Snowflake quoting


with cte as (
select '{{ replaced }}' as name
)
select * from cte where name not like '{{ expected }}'
6 changes: 6 additions & 0 deletions test/integration/040_override_database_test/models/view_1.sql
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
{#
We are running against a database that must be quoted.
These calls ensure that we trigger an error if we're failing to quote at parse-time
#}
{% do adapter.already_exists(this.schema, this.table) %}
{% do adapter.get_relation(this.database, this.schema, this.table) %}
select * from {{ ref('seed') }}
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
{{
config(database=var('alternate_db'))
}}

select * from {{ ref('seed') }}
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from nose.plugins.attrib import attr
from test.integration.base import DBTIntegrationTest

import os


class BaseOverrideDatabase(DBTIntegrationTest):
setup_alternate_db = True
Expand All @@ -12,6 +14,45 @@ def schema(self):
def models(self):
return "test/integration/040_override_database_test/models"

@property
def alternative_database(self):
if self.adapter_type == 'snowflake':
return os.getenv('SNOWFLAKE_TEST_DATABASE')
else:
return super(BaseOverrideDatabase, self).alternative_database

def snowflake_profile(self):
return {
'config': {
'send_anonymous_usage_stats': False
},
'test': {
'outputs': {
'default2': {
'type': 'snowflake',
'threads': 4,
'account': os.getenv('SNOWFLAKE_TEST_ACCOUNT'),
'user': os.getenv('SNOWFLAKE_TEST_USER'),
'password': os.getenv('SNOWFLAKE_TEST_PASSWORD'),
'database': os.getenv('SNOWFLAKE_TEST_QUOTED_DATABASE'),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'schema': self.unique_schema(),
'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'),
},
'noaccess': {
'type': 'snowflake',
'threads': 4,
'account': os.getenv('SNOWFLAKE_TEST_ACCOUNT'),
'user': 'noaccess',
'password': 'password',
'database': os.getenv('SNOWFLAKE_TEST_DATABASE'),
'schema': self.unique_schema(),
'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'),
}
},
'target': 'default2'
}
}

@property
def project_config(self):
return {
Expand All @@ -20,9 +61,15 @@ def project_config(self):
'vars': {
'alternate_db': self.alternative_database,
},
},
'quoting': {
'database': True,
}
}

def run_dbt_notstrict(self, args):
return self.run_dbt(args, strict=False)


class TestModelOverride(BaseOverrideDatabase):
def run_database_override(self):
Expand All @@ -31,9 +78,9 @@ def run_database_override(self):
else:
func = lambda x: x

self.run_dbt(['seed'])
self.run_dbt_notstrict(['seed'])

self.assertEqual(len(self.run_dbt(['run'])), 4)
self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4)
self.assertManyRelationsEqual([
(func('seed'), self.unique_schema(), self.default_database),
(func('view_2'), self.unique_schema(), self.alternative_database),
Expand Down Expand Up @@ -71,9 +118,9 @@ def run_database_override(self):
},
}
})
self.run_dbt(['seed'])
self.run_dbt_notstrict(['seed'])

self.assertEqual(len(self.run_dbt(['run'])), 4)
self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4)
self.assertManyRelationsEqual([
(func('seed'), self.unique_schema(), self.default_database),
(func('view_2'), self.unique_schema(), self.alternative_database),
Expand Down Expand Up @@ -101,9 +148,9 @@ def run_database_override(self):
self.use_default_project({
'seeds': {'database': self.alternative_database}
})
self.run_dbt(['seed'])
self.run_dbt_notstrict(['seed'])

self.assertEqual(len(self.run_dbt(['run'])), 4)
self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4)
self.assertManyRelationsEqual([
(func('seed'), self.unique_schema(), self.alternative_database),
(func('view_2'), self.unique_schema(), self.alternative_database),
Expand Down
4 changes: 3 additions & 1 deletion test/integration/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,9 @@ def get_many_table_columns(self, tables, schema, database=None):
and ({table_filter})
order by column_name asc"""

db_string = '' if database is None else database + '.'
db_string = ''
if database:
db_string = self.quote_as_configured(database, 'database') + '.'

table_filters_s = " OR ".join(
self._ilike('table_name', table.replace('"', ''))
Expand Down