diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index e290042ce80..29c11cc1601 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -561,7 +561,9 @@ def list_relations(self, database, schema, model_name=None): information_schema = self.Relation.create( database=database, schema=schema, - model_name='').information_schema() + model_name='', + quote_policy=self.config.quoting + ).information_schema() # we can't build the relations cache because we don't have a # manifest so we can't run any operations. @@ -581,7 +583,7 @@ def _make_match_kwargs(self, database, schema, identifier): if schema is not None and quoting['schema'] is False: schema = schema.lower() - if database is not None and quoting['schema'] is False: + if database is not None and quoting['database'] is False: database = database.lower() return filter_null_values({ diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index 8a6ace3bef7..65a0344d606 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -203,6 +203,7 @@ def list_relations_without_caching(self, information_schema, schema, relations = [] quote_policy = { + 'database': True, 'schema': True, 'identifier': True } @@ -233,7 +234,8 @@ def list_schemas(self, database, model_name=None): def check_schema_exists(self, database, schema, model_name=None): information_schema = self.Relation.create( - database=database, schema=schema + database=database, schema=schema, + quote_policy=self.config.quoting ).information_schema() kwargs = {'information_schema': information_schema, 'schema': schema} diff --git a/plugins/postgres/dbt/adapters/postgres/impl.py b/plugins/postgres/dbt/adapters/postgres/impl.py index 87487c7f791..34870c56eba 100644 --- a/plugins/postgres/dbt/adapters/postgres/impl.py +++ b/plugins/postgres/dbt/adapters/postgres/impl.py @@ -25,7 +25,10 @@ def date_function(cls): @available_raw def verify_database(self, database): - database = database.strip('"') + if database.startswith('"'): + database = database.strip('"') + else: + database = database.lower() expected = self.config.credentials.database if database != expected: raise dbt.exceptions.NotImplementedException( diff --git a/test.env.sample b/test.env.sample index 8cb09b20ef3..84bfeaddcd3 100644 --- a/test.env.sample +++ b/test.env.sample @@ -3,6 +3,8 @@ SNOWFLAKE_TEST_ACCOUNT= SNOWFLAKE_TEST_USER= SNOWFLAKE_TEST_PASSWORD= SNOWFLAKE_TEST_DATABASE= +SNOWFLAKE_TEST_ALT_DATABASE= +SNOWFLAKE_TEST_QUOTED_DATABASE= SNOWFLAKE_TEST_WAREHOUSE= BIGQUERY_TYPE= diff --git a/test/integration/001_simple_copy_test/test_simple_copy.py b/test/integration/001_simple_copy_test/test_simple_copy.py index bbe9349e439..d6a393761c4 100644 --- a/test/integration/001_simple_copy_test/test_simple_copy.py +++ b/test/integration/001_simple_copy_test/test_simple_copy.py @@ -17,6 +17,7 @@ def models(self): class TestSimpleCopy(BaseTestSimpleCopy): + @use_profile("postgres") def test__postgres__simple_copy(self): self.use_default_project({"data-paths": [self.dir("seed-initial")]}) @@ -83,6 +84,12 @@ def test__snowflake__simple_copy(self): self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"]) + self.use_default_project({ + "test-paths": [self.dir("tests")], + "data-paths": [self.dir("seed-update")], + }) + self.run_dbt(['test']) + @use_profile("snowflake") def test__snowflake__simple_copy__quoting_off(self): self.use_default_project({ @@ -108,6 +115,13 @@ def test__snowflake__simple_copy__quoting_off(self): self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"]) + self.use_default_project({ + "test-paths": [self.dir("tests")], + "data-paths": [self.dir("seed-update")], + "quoting": {"identifier": False}, + }) + self.run_dbt(['test']) + @use_profile("snowflake") def test__snowflake__seed__quoting_switch(self): self.use_default_project({ @@ -124,6 +138,12 @@ def test__snowflake__seed__quoting_switch(self): }) results = self.run_dbt(["seed"], expect_pass=False) + self.use_default_project({ + "test-paths": [self.dir("tests")], + "data-paths": [self.dir("seed-initial")], + }) + self.run_dbt(['test']) + @use_profile("bigquery") def test__bigquery__simple_copy(self): self.use_default_project({"data-paths": [self.dir("seed-initial")]}) @@ -181,6 +201,8 @@ def test__snowflake__simple_copy__quoting_on(self): self.assertManyTablesEqual(["seed", "view_model", "incremental", "materialized"]) + # can't run the test as this one's identifiers will be the wrong case + class BaseLowercasedSchemaTest(BaseTestSimpleCopy): def unique_schema(self): @@ -210,6 +232,12 @@ def test__snowflake__simple_copy(self): self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"]) + self.use_default_project({ + "test-paths": [self.dir("tests")], + "data-paths": [self.dir("seed-update")], + }) + self.run_dbt(['test']) + class TestSnowflakeSimpleLowercasedSchemaQuoted(BaseLowercasedSchemaTest): @property diff --git a/test/integration/001_simple_copy_test/tests/get_relation_quoting.sql b/test/integration/001_simple_copy_test/tests/get_relation_quoting.sql new file mode 100644 index 00000000000..822d468fae7 --- /dev/null +++ b/test/integration/001_simple_copy_test/tests/get_relation_quoting.sql @@ -0,0 +1,9 @@ +{%- set tgt = ref('seed') -%} +{%- set got = adapter.get_relation(database=tgt.database, schema=tgt.schema, identifier=tgt.identifier) | string -%} +{% set replaced = got.replace('"', '-') %} +{% set expected = "-" + tgt.database.upper() + '-.-' + tgt.schema.upper() + '-.-' + tgt.identifier.upper() + '-' %} + +with cte as ( + select '{{ replaced }}' as name +) +select * from cte where name not like '{{ expected }}' diff --git a/test/integration/040_override_database_test/models/view_1.sql b/test/integration/040_override_database_test/models/view_1.sql index 4b91aa0f2fa..a43f04646b8 100644 --- a/test/integration/040_override_database_test/models/view_1.sql +++ b/test/integration/040_override_database_test/models/view_1.sql @@ -1 +1,7 @@ +{# + We are running against a database that must be quoted. + These calls ensure that we trigger an error if we're failing to quote at parse-time +#} +{% do adapter.already_exists(this.schema, this.table) %} +{% do adapter.get_relation(this.database, this.schema, this.table) %} select * from {{ ref('seed') }} diff --git a/test/integration/040_override_database_test/models/view_2.sql b/test/integration/040_override_database_test/models/view_2.sql index efa1268fab8..bdfa5369fa7 100644 --- a/test/integration/040_override_database_test/models/view_2.sql +++ b/test/integration/040_override_database_test/models/view_2.sql @@ -1,5 +1,4 @@ {{ config(database=var('alternate_db')) }} - select * from {{ ref('seed') }} diff --git a/test/integration/040_override_database_test/test_override_database.py b/test/integration/040_override_database_test/test_override_database.py index a3319441d6b..3fc320957e5 100644 --- a/test/integration/040_override_database_test/test_override_database.py +++ b/test/integration/040_override_database_test/test_override_database.py @@ -1,6 +1,8 @@ from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest +import os + class BaseOverrideDatabase(DBTIntegrationTest): setup_alternate_db = True @@ -12,6 +14,45 @@ def schema(self): def models(self): return "test/integration/040_override_database_test/models" + @property + def alternative_database(self): + if self.adapter_type == 'snowflake': + return os.getenv('SNOWFLAKE_TEST_DATABASE') + else: + return super(BaseOverrideDatabase, self).alternative_database + + def snowflake_profile(self): + return { + 'config': { + 'send_anonymous_usage_stats': False + }, + 'test': { + 'outputs': { + 'default2': { + 'type': 'snowflake', + 'threads': 4, + 'account': os.getenv('SNOWFLAKE_TEST_ACCOUNT'), + 'user': os.getenv('SNOWFLAKE_TEST_USER'), + 'password': os.getenv('SNOWFLAKE_TEST_PASSWORD'), + 'database': os.getenv('SNOWFLAKE_TEST_QUOTED_DATABASE'), + 'schema': self.unique_schema(), + 'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'), + }, + 'noaccess': { + 'type': 'snowflake', + 'threads': 4, + 'account': os.getenv('SNOWFLAKE_TEST_ACCOUNT'), + 'user': 'noaccess', + 'password': 'password', + 'database': os.getenv('SNOWFLAKE_TEST_DATABASE'), + 'schema': self.unique_schema(), + 'warehouse': os.getenv('SNOWFLAKE_TEST_WAREHOUSE'), + } + }, + 'target': 'default2' + } + } + @property def project_config(self): return { @@ -20,9 +61,15 @@ def project_config(self): 'vars': { 'alternate_db': self.alternative_database, }, + }, + 'quoting': { + 'database': True, } } + def run_dbt_notstrict(self, args): + return self.run_dbt(args, strict=False) + class TestModelOverride(BaseOverrideDatabase): def run_database_override(self): @@ -31,9 +78,9 @@ def run_database_override(self): else: func = lambda x: x - self.run_dbt(['seed']) + self.run_dbt_notstrict(['seed']) - self.assertEqual(len(self.run_dbt(['run'])), 4) + self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4) self.assertManyRelationsEqual([ (func('seed'), self.unique_schema(), self.default_database), (func('view_2'), self.unique_schema(), self.alternative_database), @@ -71,9 +118,9 @@ def run_database_override(self): }, } }) - self.run_dbt(['seed']) + self.run_dbt_notstrict(['seed']) - self.assertEqual(len(self.run_dbt(['run'])), 4) + self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4) self.assertManyRelationsEqual([ (func('seed'), self.unique_schema(), self.default_database), (func('view_2'), self.unique_schema(), self.alternative_database), @@ -101,9 +148,9 @@ def run_database_override(self): self.use_default_project({ 'seeds': {'database': self.alternative_database} }) - self.run_dbt(['seed']) + self.run_dbt_notstrict(['seed']) - self.assertEqual(len(self.run_dbt(['run'])), 4) + self.assertEqual(len(self.run_dbt_notstrict(['run'])), 4) self.assertManyRelationsEqual([ (func('seed'), self.unique_schema(), self.alternative_database), (func('view_2'), self.unique_schema(), self.alternative_database), diff --git a/test/integration/base.py b/test/integration/base.py index a4c52326772..70620f506a6 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -582,7 +582,9 @@ def get_many_table_columns(self, tables, schema, database=None): and ({table_filter}) order by column_name asc""" - db_string = '' if database is None else database + '.' + db_string = '' + if database: + db_string = self.quote_as_configured(database, 'database') + '.' table_filters_s = " OR ".join( self._ilike('table_name', table.replace('"', ''))