From ff64f8166a853bfba1acee4723929b2849ad9c94 Mon Sep 17 00:00:00 2001
From: Claire Carroll
Date: Sun, 22 Jul 2018 16:05:33 +1000
Subject: [PATCH 001/133] Upgrade version of Jinja

---
 requirements.txt | 2 +-
 setup.py         | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 417a0974478..a25b5b729d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-Jinja2>=2.8
+Jinja2>=2.10
 PyYAML>=3.11
 psycopg2>=2.7.5,<2.8
 sqlparse==0.2.3
diff --git a/setup.py b/setup.py
index 8cc744a4295..4e69ebe229e 100644
--- a/setup.py
+++ b/setup.py
@@ -38,7 +38,7 @@ def read(fname):
         'scripts/dbt',
     ],
     install_requires=[
-        'Jinja2>=2.8',
+        'Jinja2>=2.10',
         'PyYAML>=3.11',
         'psycopg2>=2.7.5,<2.8',
         'sqlparse==0.2.3',

From c367d5bc75423ef13bc8f9149296fbc7ffe86bb8 Mon Sep 17 00:00:00 2001
From: Claire Carroll
Date: Tue, 28 Aug 2018 10:22:52 +0200
Subject: [PATCH 002/133] Check for unused configs in project file

---
 dbt/compilation.py       | 19 +++++++++++
 dbt/config.py            | 69 ++++++++++++++++++++++++++++++++++++++++
 test/unit/test_config.py | 53 ++++++++++++++++++++++++++++++
 3 files changed, 141 insertions(+)

diff --git a/dbt/compilation.py b/dbt/compilation.py
index 4057263b538..06ff94dfcc0 100644
--- a/dbt/compilation.py
+++ b/dbt/compilation.py
@@ -19,6 +19,7 @@
 import dbt.exceptions
 import dbt.flags
 import dbt.loader
+import dbt.config
 
 from dbt.contracts.graph.compiled import CompiledNode, CompiledGraph
 from dbt.clients.system import write_json
@@ -233,6 +234,16 @@ def _check_resource_uniqueness(cls, manifest):
             names_resources[name] = node
             alias_resources[alias] = node
 
+    def get_resource_fqns(self, manifest):
+        resource_fqns = {}
+        for unique_id, node in manifest.nodes.items():
+            resource_type_plural = node['resource_type'] + 's'
+            if resource_type_plural not in resource_fqns:
+                resource_fqns[resource_type_plural] = []
+            resource_fqns[resource_type_plural].append(node['fqn'])
+
+        return resource_fqns
+
     def compile(self):
         linker = Linker()
 
@@ -244,6 +255,14 @@ def compile(self):
 
         self._check_resource_uniqueness(manifest)
 
+        resource_fqns = self.get_resource_fqns(manifest)
+
+        root_project_resource_config_paths = dbt.config.get_project_resource_config_paths(
+            self.project)
+
+        dbt.config.warn_for_unused_resource_config_paths(
+            root_project_resource_config_paths, resource_fqns)
+
         self.link_graph(linker, manifest)
 
         stats = defaultdict(int)
diff --git a/dbt/config.py b/dbt/config.py
index 51681da6ddc..8c00ad4ce59 100644
--- a/dbt/config.py
+++ b/dbt/config.py
@@ -5,6 +5,7 @@
 import dbt.clients.system
 
 from dbt.logger import GLOBAL_LOGGER as logger
+from dbt.utils import DBTConfigKeys
 
 
 INVALID_PROFILE_MESSAGE = """
@@ -13,6 +14,12 @@
 {error_string}
 """
 
+UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE = """\
+WARNING: Configuration paths exist in your dbt_project.yml file which do not \
+apply to any resources.
+There are {} unused configuration paths:\ +""" + def read_profile(profiles_dir): path = os.path.join(profiles_dir, 'profiles.yml') @@ -43,3 +50,65 @@ def send_anonymous_usage_stats(config): def colorize_output(config): return config.get('use_colors', True) + + +def get_config_paths(config, path=None, paths=None): + if path is None: + path = [] + + if paths is None: + paths = [] + + for key, value in config.items(): + if isinstance(value, dict): + if key in DBTConfigKeys: + if path not in paths: + paths.append(path) + else: + get_config_paths(value, path + [key], paths) + else: + if path not in paths: + paths.append(path) + + return paths + + +def get_project_resource_config_paths(project): + resource_config_paths = {} + for resource_type in ['models', 'seeds']: + if resource_type in project: + resource_config_paths[resource_type] = get_config_paths( + project[resource_type]) + return resource_config_paths + + +def is_config_used(config_path, fqns): + for fqn in fqns: + if len(config_path) <= len(fqn) and fqn[:len(config_path)] == config_path: + return True + return False + + +def get_unused_resource_config_paths(resource_config_paths, resource_fqns): + unused_resource_config_paths = [] + for resource_type, config_paths in resource_config_paths.items(): + for config_path in config_paths: + if not is_config_used(config_path, resource_fqns[resource_type]): + unused_resource_config_paths.append( + [resource_type] + config_path) + return unused_resource_config_paths + + +def warn_for_unused_resource_config_paths(resource_config_paths, resource_fqns): + unused_resource_config_paths = get_unused_resource_config_paths( + resource_config_paths, resource_fqns) + if len(unused_resource_config_paths) == 0: + return + logger.info( + dbt.ui.printer.yellow(UNUSED_RESOURCE_CONFIGURATION_PATH_MESSAGE.format( + len(unused_resource_config_paths)))) + for unused_resource_config_path in unused_resource_config_paths: + logger.info( + dbt.ui.printer.yellow(" - {}".format( + ".".join(unused_resource_config_path)))) + logger.info("") diff --git a/test/unit/test_config.py b/test/unit/test_config.py index e15ab695501..9be959e6e60 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -9,6 +9,43 @@ else: TMPDIR = '/tmp' +model_config = { + 'my_package_name': { + 'enabled': True, + 'adwords': { + 'adwords_ads': { + 'materialized': 'table', + 'enabled': True, + 'schema': 'analytics' + } + }, + 'snowplow': { + 'snowplow_sessions': { + 'sort': 'timestamp', + 'materialized': 'incremental', + 'dist': 'user_id', + 'sql_where': 'created_at > (select max(created_at) from {{ this }})', + 'unique_key': 'id' + }, + 'base': { + 'snowplow_events': { + 'sort': ['timestamp', 'userid'], + 'materialized': 'table', + 'sort_type': 'interleaved', + 'dist': 'userid' + } + } + } + } +} + +model_fqns = [ + ['my_package_name', 'snowplow', 'snowplow_sessions'], + ['my_package_name', 'snowplow', 'base', 'snowplow_events'], + ['my_package_name', 'adwords', 'adwords_ads'] +] + + class ConfigTest(unittest.TestCase): def set_up_empty_config(self): @@ -64,3 +101,19 @@ def test__explicit_opt_in(self): self.set_up_config_options(use_colors=True) config = dbt.config.read_config(TMPDIR) self.assertTrue(dbt.config.colorize_output(config)) + + def test__no_unused_resource_config_paths(self): + resource_config = {'models': model_config} + resource_config_paths = dbt.config.get_project_resource_config_paths( + resource_config) + resource_fqns = {'models': model_fqns} + self.assertTrue(len(dbt.config.get_unused_resource_config_paths( + 
resource_config_paths, resource_fqns)) == 0) + + def test__unused_resource_config_paths(self): + resource_config = {'models': model_config['my_package_name']} + resource_config_paths = dbt.config.get_project_resource_config_paths( + resource_config) + resource_fqns = {'models': model_fqns} + self.assertFalse(len(dbt.config.get_unused_resource_config_paths( + resource_config_paths, resource_fqns)) == 0) From aa06a8a6065208b947d5a03520e8d7bf5cf56aa4 Mon Sep 17 00:00:00 2001 From: Claire Carroll Date: Tue, 28 Aug 2018 17:41:01 +0200 Subject: [PATCH 003/133] Handle case when configs supplied but no resource of that type in project --- dbt/config.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index 8c00ad4ce59..f3c20c8975e 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -83,9 +83,10 @@ def get_project_resource_config_paths(project): def is_config_used(config_path, fqns): - for fqn in fqns: - if len(config_path) <= len(fqn) and fqn[:len(config_path)] == config_path: - return True + if fqns: + for fqn in fqns: + if len(config_path) <= len(fqn) and fqn[:len(config_path)] == config_path: + return True return False @@ -93,7 +94,7 @@ def get_unused_resource_config_paths(resource_config_paths, resource_fqns): unused_resource_config_paths = [] for resource_type, config_paths in resource_config_paths.items(): for config_path in config_paths: - if not is_config_used(config_path, resource_fqns[resource_type]): + if not is_config_used(config_path, resource_fqns.get(resource_type)): unused_resource_config_paths.append( [resource_type] + config_path) return unused_resource_config_paths From 8587bd4435a43c2091b2d158d4f1026aa194c512 Mon Sep 17 00:00:00 2001 From: Claire Carroll Date: Tue, 28 Aug 2018 17:50:20 +0200 Subject: [PATCH 004/133] Add seeds to default_project_cfg --- dbt/config.py | 5 ++--- dbt/project.py | 1 + test/unit/test_config.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index f3c20c8975e..2489f8e1e43 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -76,9 +76,8 @@ def get_config_paths(config, path=None, paths=None): def get_project_resource_config_paths(project): resource_config_paths = {} for resource_type in ['models', 'seeds']: - if resource_type in project: - resource_config_paths[resource_type] = get_config_paths( - project[resource_type]) + resource_config_paths[resource_type] = get_config_paths( + project[resource_type]) return resource_config_paths diff --git a/dbt/project.py b/dbt/project.py index ddaf25a43c7..f3c9a4f979f 100644 --- a/dbt/project.py +++ b/dbt/project.py @@ -28,6 +28,7 @@ 'outputs': {'default': {}}, 'target': 'default', 'models': {}, + 'seeds': {}, 'quoting': {}, 'profile': None, 'packages': [], diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 9be959e6e60..2ea3d22011a 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -103,7 +103,7 @@ def test__explicit_opt_in(self): self.assertTrue(dbt.config.colorize_output(config)) def test__no_unused_resource_config_paths(self): - resource_config = {'models': model_config} + resource_config = {'models': model_config, 'seeds': {}} resource_config_paths = dbt.config.get_project_resource_config_paths( resource_config) resource_fqns = {'models': model_fqns} @@ -111,7 +111,7 @@ def test__no_unused_resource_config_paths(self): resource_config_paths, resource_fqns)) == 0) def test__unused_resource_config_paths(self): - resource_config = {'models': 
model_config['my_package_name']} + resource_config = {'models': model_config['my_package_name'], 'seeds': {}} resource_config_paths = dbt.config.get_project_resource_config_paths( resource_config) resource_fqns = {'models': model_fqns} From 1620a17ecae2d040af467ae5f4edd392debacfe4 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 13 Sep 2018 12:56:30 -0600 Subject: [PATCH 005/133] Create a somewhat sane, if huge, configuration object --- dbt/adapters/bigquery/impl.py | 254 +++--- dbt/adapters/bigquery/relation.py | 4 +- dbt/adapters/default/impl.py | 308 ++++--- dbt/adapters/default/relation.py | 4 +- dbt/adapters/factory.py | 5 +- dbt/adapters/postgres/impl.py | 62 +- dbt/adapters/redshift/impl.py | 38 +- dbt/adapters/snowflake/impl.py | 83 +- dbt/adapters/snowflake/relation.py | 4 +- dbt/api/object.py | 5 + dbt/compilation.py | 28 +- dbt/config.py | 738 +++++++++++++++++ dbt/context/common.py | 69 +- dbt/context/parser.py | 8 +- dbt/context/runtime.py | 10 +- dbt/contracts/common.py | 11 + dbt/contracts/connection.py | 131 ++- dbt/contracts/graph/manifest.py | 10 +- dbt/contracts/project.py | 311 +++++++- dbt/contracts/results.py | 12 +- dbt/exceptions.py | 7 + dbt/loader.py | 50 +- dbt/main.py | 95 +-- dbt/model.py | 8 +- dbt/node_runners.py | 85 +- dbt/parser/archives.py | 18 +- dbt/parser/base.py | 10 +- dbt/parser/docs.py | 3 +- dbt/parser/hooks.py | 11 +- dbt/parser/util.py | 4 +- dbt/project.py | 314 -------- dbt/runner.py | 51 +- dbt/task/archive.py | 6 +- dbt/task/base_task.py | 4 +- dbt/task/clean.py | 6 +- dbt/task/compile.py | 6 +- dbt/task/debug.py | 25 +- dbt/task/deps.py | 89 ++- dbt/task/generate.py | 13 +- dbt/task/init.py | 4 +- dbt/task/run.py | 4 +- dbt/task/seed.py | 6 +- dbt/task/serve.py | 2 +- dbt/task/test.py | 3 +- dbt/tracking.py | 28 +- dbt/utils.py | 24 +- .../001_simple_copy_test/test_simple_copy.py | 80 +- .../test_varchar_widening.py | 15 +- .../test_simple_reference.py | 25 +- .../test_simple_archive.py | 4 +- .../test_seed_type_override.py | 3 +- .../test_simple_dependency.py | 17 +- .../test_simple_dependency_with_configs.py | 56 +- .../test_schema_test_graph_selection.py | 10 +- .../test_schema_tests.py | 24 +- .../test_schema_v2_tests.py | 37 +- .../009_data_tests_test/test_data_tests.py | 4 +- .../014_hook_tests/test_model_hooks_bq.py | 11 +- .../014_hook_tests/test_run_hooks_bq.py | 2 +- .../test_cli_invocation.py | 2 +- .../016_macro_tests/test_macros.py | 22 +- .../020_ephemeral_test/test_ephemeral.py | 8 - .../021_concurrency_test/test_concurrency.py | 8 - .../test_bigquery_adapter_functions.py | 2 - .../test_bigquery_date_partitioning.py | 2 - .../023_exit_codes_test/test_exit_codes.py | 32 +- .../test_duplicate_model.py | 31 +- .../025_timezones_test/test_timezones.py | 3 - .../026_aliases_test/test_aliases.py | 3 - .../integration/027_cycle_test/test_cycles.py | 6 - .../028_cli_vars/test_cli_var_override.py | 2 - .../test_docs_generate.py | 18 +- .../models-bq/statement_actual.sql | 23 + .../030_statement_test/test_statements.py | 27 +- .../test_thread_count.py | 10 +- .../033_event_tracking_test/test_events.py | 13 +- .../test_late_binding_view.py | 8 +- .../test_changing_relation_type.py | 15 +- .../035_docs_blocks/test_docs_blocks.py | 1 - test/integration/base.py | 165 ++-- test/unit/test_bigquery_adapter.py | 37 +- test/unit/test_config.py | 752 +++++++++++++++++- test/unit/test_docs_blocks.py | 43 +- test/unit/test_graph.py | 54 +- test/unit/test_graph_selection.py | 42 - test/unit/test_manifest.py | 12 +- test/unit/test_parser.py | 59 
+- test/unit/test_postgres_adapter.py | 36 +- test/unit/test_project.py | 85 -- test/unit/test_redshift_adapter.py | 67 +- test/unit/utils.py | 23 + 91 files changed, 3136 insertions(+), 1734 deletions(-) create mode 100644 dbt/contracts/common.py delete mode 100644 dbt/project.py create mode 100644 test/integration/030_statement_test/models-bq/statement_actual.sql delete mode 100644 test/unit/test_project.py create mode 100644 test/unit/utils.py diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index 037fd7725fd..c102c0a3479 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -80,7 +80,7 @@ def handle_error(cls, error, message, sql): @classmethod @contextmanager - def exception_handler(cls, profile, sql, model_name=None, + def exception_handler(cls, config, sql, model_name=None, connection_name='master'): try: yield @@ -107,11 +107,11 @@ def date_function(cls): return 'CURRENT_TIMESTAMP()' @classmethod - def begin(cls, profile, name='master'): + def begin(cls, config, name='master'): pass @classmethod - def commit(cls, profile, connection): + def commit(cls, config, connection): pass @classmethod @@ -120,8 +120,8 @@ def get_status(cls, cursor): '`get_status` is not implemented for this adapter!') @classmethod - def get_bigquery_credentials(cls, config): - method = config.get('method') + def get_bigquery_credentials(cls, profile_credentials): + method = profile_credentials.method creds = google.oauth2.service_account.Credentials if method == 'oauth': @@ -129,71 +129,64 @@ def get_bigquery_credentials(cls, config): return credentials elif method == 'service-account': - keyfile = config.get('keyfile') + keyfile = profile_credentials.keyfile return creds.from_service_account_file(keyfile, scopes=cls.SCOPE) elif method == 'service-account-json': - details = config.get('keyfile_json') + details = profile_credentials.keyfile_json return creds.from_service_account_info(details, scopes=cls.SCOPE) error = ('Invalid `method` in profile: "{}"'.format(method)) raise dbt.exceptions.FailedToConnectException(error) @classmethod - def get_bigquery_client(cls, config): - project_name = config.get('project') - creds = cls.get_bigquery_credentials(config) + def get_bigquery_client(cls, profile_credentials): + project_name = profile_credentials.project + creds = cls.get_bigquery_credentials(profile_credentials) return google.cloud.bigquery.Client(project_name, creds) @classmethod def open_connection(cls, connection): - if connection.get('state') == 'open': + if connection.state == 'open': logger.debug('Connection is already open, skipping open.') return connection - result = connection.copy() - credentials = connection.get('credentials', {}) - try: - handle = cls.get_bigquery_client(credentials) + handle = cls.get_bigquery_client(connection.credentials) except google.auth.exceptions.DefaultCredentialsError as e: logger.info("Please log into GCP to continue") dbt.clients.gcloud.setup_default_credentials() - handle = cls.get_bigquery_client(credentials) + handle = cls.get_bigquery_client(connection.credentials) except Exception as e: raise logger.debug("Got an error when attempting to create a bigquery " "client: '{}'".format(e)) - result['handle'] = None - result['state'] = 'fail' + connection.handle = None + connection.state = 'fail' raise dbt.exceptions.FailedToConnectException(str(e)) - result['handle'] = handle - result['state'] = 'open' - return result + connection.handle = handle + connection.state = 'open' + return connection @classmethod def 
close(cls, connection): - if dbt.flags.STRICT_MODE: - Connection(**connection) - - connection['state'] = 'closed' + connection.state = 'closed' return connection @classmethod - def list_relations(cls, profile, project_cfg, schema, model_name=None): - connection = cls.get_connection(profile, model_name) - client = connection.get('handle') + def list_relations(cls, config, schema, model_name=None): + connection = cls.get_connection(config, model_name) + client = connection.handle - bigquery_dataset = cls.get_dataset( - profile, project_cfg, schema, model_name) + bigquery_dataset = cls.get_dataset(config, schema, model_name) all_tables = client.list_tables( bigquery_dataset, @@ -215,7 +208,7 @@ def list_relations(cls, profile, project_cfg, schema, model_name=None): return [] @classmethod - def get_relation(cls, profile, project_cfg, schema=None, identifier=None, + def get_relation(cls, config, schema=None, identifier=None, relations_list=None, model_name=None): if schema is None and relations_list is None: raise dbt.exceptions.RuntimeException( @@ -223,32 +216,31 @@ def get_relation(cls, profile, project_cfg, schema=None, identifier=None, 'of relations to use') if relations_list is None and identifier is not None: - table = cls.get_bq_table(profile, project_cfg, schema, identifier) + table = cls.get_bq_table(config, schema, identifier) return cls.bq_table_to_relation(table) return super(BigQueryAdapter, cls).get_relation( - profile, project_cfg, schema, identifier, relations_list, + config, schema, identifier, relations_list, model_name) @classmethod - def drop_relation(cls, profile, project_cfg, relation, model_name=None): - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + def drop_relation(cls, config, relation, model_name=None): + conn = cls.get_connection(config, model_name) + client = conn.handle - dataset = cls.get_dataset( - profile, project_cfg, relation.schema, model_name) + dataset = cls.get_dataset(config, relation.schema, model_name) relation_object = dataset.table(relation.identifier) client.delete_table(relation_object) @classmethod - def rename(cls, profile, project_cfg, schema, + def rename(cls, config, schema, from_name, to_name, model_name=None): raise dbt.exceptions.NotImplementedException( '`rename` is not implemented for this adapter!') @classmethod - def rename_relation(cls, profile, project_cfg, from_relation, to_relation, + def rename_relation(cls, config, from_relation, to_relation, model_name=None): raise dbt.exceptions.NotImplementedException( '`rename_relation` is not implemented for this adapter!') @@ -259,13 +251,13 @@ def get_timeout(cls, conn): return credentials.get('timeout_seconds', cls.QUERY_TIMEOUT) @classmethod - def materialize_as_view(cls, profile, project_cfg, dataset, model): + def materialize_as_view(cls, config, dataset, model): model_name = model.get('name') model_alias = model.get('alias') model_sql = model.get('injected_sql') - conn = cls.get_connection(profile, project_cfg, model_name) - client = conn.get('handle') + conn = cls.get_connection(config, model_name) + client = conn.handle view_ref = dataset.table(model_alias) view = google.cloud.bigquery.Table(view_ref) @@ -274,7 +266,7 @@ def materialize_as_view(cls, profile, project_cfg, dataset, model): logger.debug("Model SQL ({}):\n{}".format(model_name, model_sql)) - with cls.exception_handler(profile, model_sql, model_name, model_name): + with cls.exception_handler(config, model_sql, model_name, model_name): client.create_table(view) return "CREATE VIEW" @@ 
-295,13 +287,12 @@ def poll_until_job_completes(cls, job, timeout): raise job.exception() @classmethod - def make_date_partitioned_table(cls, profile, project_cfg, dataset_name, - identifier, model_name=None): - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + def make_date_partitioned_table(cls, config, dataset_name, identifier, + model_name=None): + conn = cls.get_connection(config, model_name) + client = conn.handle - dataset = cls.get_dataset(profile, project_cfg, - dataset_name, identifier) + dataset = cls.get_dataset(config, dataset_name, identifier) table_ref = dataset.table(identifier) table = google.cloud.bigquery.Table(table_ref) table.partitioning_type = 'DAY' @@ -309,13 +300,13 @@ def make_date_partitioned_table(cls, profile, project_cfg, dataset_name, return client.create_table(table) @classmethod - def materialize_as_table(cls, profile, project_cfg, dataset, - model, model_sql, decorator=None): + def materialize_as_table(cls, config, dataset, model, model_sql, + decorator=None): model_name = model.get('name') model_alias = model.get('alias') - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + conn = cls.get_connection(config, model_name) + client = conn.handle if decorator is None: table_name = model_alias @@ -331,14 +322,14 @@ def materialize_as_table(cls, profile, project_cfg, dataset, query_job = client.query(model_sql, job_config=job_config) # this waits for the job to complete - with cls.exception_handler(profile, model_sql, model_alias, + with cls.exception_handler(config, model_sql, model_alias, model_name): query_job.result(timeout=cls.get_timeout(conn)) return "CREATE TABLE" @classmethod - def execute_model(cls, profile, project_cfg, model, + def execute_model(cls, config, model, materialization, sql_override=None, decorator=None, model_name=None): @@ -346,20 +337,20 @@ def execute_model(cls, profile, project_cfg, model, sql_override = model.get('injected_sql') if flags.STRICT_MODE: - connection = cls.get_connection(profile, model.get('name')) + connection = cls.get_connection(config, model.get('name')) Connection(**connection) model_name = model.get('name') model_schema = model.get('schema') - dataset = cls.get_dataset(profile, project_cfg, + dataset = cls.get_dataset(config, model_schema, model_name) if materialization == 'view': - res = cls.materialize_as_view(profile, project_cfg, dataset, model) + res = cls.materialize_as_view(config, dataset, model) elif materialization == 'table': res = cls.materialize_as_table( - profile, project_cfg, dataset, model, + config, dataset, model, sql_override, decorator) else: msg = "Invalid relation type: '{}'".format(materialization) @@ -368,9 +359,9 @@ def execute_model(cls, profile, project_cfg, model, return res @classmethod - def raw_execute(cls, profile, sql, model_name=None, fetch=False, **kwargs): - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + def raw_execute(cls, config, sql, model_name=None, fetch=False, **kwargs): + conn = cls.get_connection(config, model_name) + client = conn.handle logger.debug('On %s: %s', model_name, sql) @@ -379,17 +370,17 @@ def raw_execute(cls, profile, sql, model_name=None, fetch=False, **kwargs): query_job = client.query(sql, job_config) # this blocks until the query has completed - with cls.exception_handler(profile, sql, model_name): + with cls.exception_handler(config, sql, model_name): iterator = query_job.result() return query_job, iterator @classmethod - def create_temporary_table(cls, profile, 
project, sql, model_name=None, + def create_temporary_table(cls, config, sql, model_name=None, **kwargs): # BQ queries always return a temp table with their results - query_job, _ = cls.raw_execute(profile, sql, model_name) + query_job, _ = cls.raw_execute(config, sql, model_name) bq_table = query_job.destination return cls.Relation.create( @@ -403,17 +394,16 @@ def create_temporary_table(cls, profile, project, sql, model_name=None, type=BigQueryRelation.Table) @classmethod - def alter_table_add_columns(cls, profile, project, relation, columns, + def alter_table_add_columns(cls, config, relation, columns, model_name=None): logger.debug('Adding columns ({}) to table {}".'.format( columns, relation)) - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + conn = cls.get_connection(config, model_name) + client = conn.handle - dataset = cls.get_dataset(profile, project, relation.schema, - model_name) + dataset = cls.get_dataset(config, relation.schema, model_name) table_ref = dataset.table(relation.name) table = client.get_table(table_ref) @@ -425,8 +415,8 @@ def alter_table_add_columns(cls, profile, project, relation, columns, client.update_table(new_table, ['schema']) @classmethod - def execute(cls, profile, sql, model_name=None, fetch=None, **kwargs): - _, iterator = cls.raw_execute(profile, sql, model_name, fetch, + def execute(cls, config, sql, model_name=None, fetch=None, **kwargs): + _, iterator = cls.raw_execute(config, sql, model_name, fetch, **kwargs) if fetch: @@ -439,8 +429,8 @@ def execute(cls, profile, sql, model_name=None, fetch=None, **kwargs): return status, res @classmethod - def execute_and_fetch(cls, profile, sql, model_name, auto_begin=None): - status, table = cls.execute(profile, sql, model_name, fetch=True) + def execute_and_fetch(cls, config, sql, model_name, auto_begin=None): + status, table = cls.execute(config, sql, model_name, fetch=True) return status, table @classmethod @@ -452,72 +442,71 @@ def get_table_from_response(cls, resp): # BigQuery doesn't support BEGIN/COMMIT, so stub these out. @classmethod - def add_begin_query(cls, profile, name): + def add_begin_query(cls, config, name): pass @classmethod - def add_commit_query(cls, profile, name): + def add_commit_query(cls, config, name): pass @classmethod - def create_schema(cls, profile, project_cfg, schema, model_name=None): + def create_schema(cls, config, schema, model_name=None): logger.debug('Creating schema "%s".', schema) - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + conn = cls.get_connection(config, model_name) + client = conn.handle - dataset = cls.get_dataset(profile, project_cfg, schema, model_name) + dataset = cls.get_dataset(config, schema, model_name) # Emulate 'create schema if not exists ...' 
try: client.get_dataset(dataset) except google.api_core.exceptions.NotFound: - with cls.exception_handler(profile, 'create dataset', model_name): + with cls.exception_handler(config, 'create dataset', model_name): client.create_dataset(dataset) @classmethod - def drop_tables_in_schema(cls, profile, project_cfg, dataset): - conn = cls.get_connection(profile) - client = conn.get('handle') + def drop_tables_in_schema(cls, config, dataset): + conn = cls.get_connection(config) + client = conn.handle for table in client.list_tables(dataset): client.delete_table(table.reference) @classmethod - def drop_schema(cls, profile, project_cfg, schema, model_name=None): + def drop_schema(cls, config, schema, model_name=None): logger.debug('Dropping schema "%s".', schema) - if not cls.check_schema_exists(profile, project_cfg, + if not cls.check_schema_exists(config, schema, model_name): return - conn = cls.get_connection(profile) - client = conn.get('handle') + conn = cls.get_connection(config) + client = conn.handle - dataset = cls.get_dataset(profile, project_cfg, schema, model_name) - with cls.exception_handler(profile, 'drop dataset', model_name): - cls.drop_tables_in_schema(profile, project_cfg, dataset) + dataset = cls.get_dataset(config, schema, model_name) + with cls.exception_handler(config, 'drop dataset', model_name): + cls.drop_tables_in_schema(config, dataset) client.delete_dataset(dataset) @classmethod - def get_existing_schemas(cls, profile, project_cfg, model_name=None): - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + def get_existing_schemas(cls, config, model_name=None): + conn = cls.get_connection(config, model_name) + client = conn.handle - with cls.exception_handler(profile, 'list dataset', model_name): + with cls.exception_handler(config, 'list dataset', model_name): all_datasets = client.list_datasets() return [ds.dataset_id for ds in all_datasets] @classmethod - def get_columns_in_table(cls, profile, project_cfg, - schema_name, table_name, + def get_columns_in_table(cls, config, schema_name, table_name, database=None, model_name=None): # BigQuery does not have databases -- the database parameter is here # for consistency with the base implementation - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + conn = cls.get_connection(config, model_name) + client = conn.handle try: dataset_ref = client.dataset(schema_name) @@ -544,21 +533,17 @@ def get_dbt_columns_from_bq_table(cls, table): return columns @classmethod - def check_schema_exists(cls, profile, project_cfg, - schema, model_name=None): - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + def check_schema_exists(cls, config, schema, model_name=None): + conn = cls.get_connection(config, model_name) - with cls.exception_handler(profile, 'get dataset', model_name): - all_datasets = client.list_datasets() + with cls.exception_handler(config, 'get dataset', model_name): + all_datasets = conn.handle.list_datasets() return any([ds.dataset_id == schema for ds in all_datasets]) @classmethod - def get_dataset(cls, profile, project_cfg, dataset_name, model_name=None): - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') - - dataset_ref = client.dataset(dataset_name) + def get_dataset(cls, config, dataset_name, model_name=None): + conn = cls.get_connection(config, model_name) + dataset_ref = conn.handle.dataset(dataset_name) return google.cloud.bigquery.Dataset(dataset_ref) @classmethod @@ -577,18 +562,15 @@ def 
bq_table_to_relation(cls, bq_table): type=cls.RELATION_TYPES.get(bq_table.table_type)) @classmethod - def get_bq_table(cls, profile, project_cfg, dataset_name, identifier, - model_name=None): - conn = cls.get_connection(profile, model_name) - client = conn.get('handle') + def get_bq_table(cls, config, dataset_name, identifier, model_name=None): + conn = cls.get_connection(config, model_name) - dataset = cls.get_dataset( - profile, project_cfg, dataset_name, model_name) + dataset = cls.get_dataset(config, dataset_name, model_name) table_ref = dataset.table(identifier) try: - return client.get_table(table_ref) + return conn.handle.get_table(table_ref) except google.cloud.exceptions.NotFound: return None @@ -599,7 +581,7 @@ def warning_on_hooks(cls, hook_type): dbt.ui.printer.COLOR_FG_YELLOW) @classmethod - def add_query(cls, profile, sql, model_name=None, auto_begin=True, + def add_query(cls, config, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): if model_name in ['on-run-start', 'on-run-end']: cls.warning_on_hooks(model_name) @@ -616,17 +598,14 @@ def quote(cls, identifier): return '`{}`'.format(identifier) @classmethod - def quote_schema_and_table(cls, profile, project_cfg, schema, + def quote_schema_and_table(cls, config, schema, table, model_name=None): - return cls.render_relation(profile, project_cfg, - cls.quote(schema), - cls.quote(table)) + return cls.render_relation(config, cls.quote(schema), cls.quote(table)) @classmethod - def render_relation(cls, profile, project_cfg, schema, table): - connection = cls.get_connection(profile) - credentials = connection.get('credentials', {}) - project = credentials.get('project') + def render_relation(cls, config, schema, table): + connection = cls.get_connection(config) + project = connection.credentials.project return '{}.{}.{}'.format(cls.quote(project), schema, table) @classmethod @@ -657,14 +636,13 @@ def _agate_to_schema(cls, agate_table, column_override): return bq_schema @classmethod - def load_dataframe(cls, profile, project_cfg, schema, - table_name, agate_table, + def load_dataframe(cls, config, schema, table_name, agate_table, column_override, model_name=None): bq_schema = cls._agate_to_schema(agate_table, column_override) - dataset = cls.get_dataset(profile, project_cfg, schema, None) + dataset = cls.get_dataset(config, schema, None) table = dataset.table(table_name) - conn = cls.get_connection(profile, None) - client = conn.get('handle') + conn = cls.get_connection(config, None) + client = conn.handle load_config = google.cloud.bigquery.LoadJobConfig() load_config.skip_leading_rows = 1 @@ -674,11 +652,11 @@ def load_dataframe(cls, profile, project_cfg, schema, job = client.load_table_from_file(f, table, rewind=True, job_config=load_config) - with cls.exception_handler(profile, "LOAD TABLE"): + with cls.exception_handler(config, "LOAD TABLE"): cls.poll_until_job_completes(job, cls.get_timeout(conn)) @classmethod - def expand_target_column_types(cls, profile, project_cfg, temp_table, + def expand_target_column_types(cls, config, temp_table, to_schema, to_table, model_name=None): # This is a no-op on BigQuery pass @@ -749,9 +727,9 @@ def _get_stats_columns(cls, table, relation_type): return zip(column_names, column_values) @classmethod - def get_catalog(cls, profile, project_cfg, manifest): - connection = cls.get_connection(profile, 'catalog') - client = connection.get('handle') + def get_catalog(cls, config, manifest): + connection = cls.get_connection(config, 'catalog') + client = connection.handle 
schemas = { node.to_dict()['schema'] @@ -774,7 +752,7 @@ def get_catalog(cls, profile, project_cfg, manifest): columns = [] for schema_name in schemas: - relations = cls.list_relations(profile, project_cfg, schema_name) + relations = cls.list_relations(config, schema_name) for relation in relations: # This relation contains a subset of the info we care about. diff --git a/dbt/adapters/bigquery/relation.py b/dbt/adapters/bigquery/relation.py index 1e696a848ee..f5807dc8e3e 100644 --- a/dbt/adapters/bigquery/relation.py +++ b/dbt/adapters/bigquery/relation.py @@ -86,9 +86,9 @@ def matches(self, project=None, schema=None, identifier=None): return True @classmethod - def create_from_node(cls, profile, node, **kwargs): + def create_from_node(cls, config, node, **kwargs): return cls.create( - project=profile.get('project'), + project=config.credentials.project, schema=node.get('schema'), identifier=node.get('alias'), **kwargs) diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index 920a996c3b5..5fdf83667f5 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -98,7 +98,7 @@ class DefaultAdapter(object): ### @classmethod @contextmanager - def exception_handler(cls, profile, sql, model_name=None, + def exception_handler(cls, config, sql, model_name=None, connection_name=None): raise dbt.exceptions.NotImplementedException( '`exception_handler` is not implemented for this adapter!') @@ -119,13 +119,13 @@ def get_status(cls, cursor): '`get_status` is not implemented for this adapter!') @classmethod - def alter_column_type(cls, profile, project_cfg, schema, table, + def alter_column_type(cls, config, schema, table, column_name, new_column_type, model_name=None): raise dbt.exceptions.NotImplementedException( '`alter_column_type` is not implemented for this adapter!') @classmethod - def query_for_existing(cls, profile, project_cfg, schemas, + def query_for_existing(cls, config, schemas, model_name=None): if not isinstance(schemas, (list, tuple)): schemas = [schemas] @@ -134,23 +134,23 @@ def query_for_existing(cls, profile, project_cfg, schemas, for schema in schemas: all_relations.extend( - cls.list_relations(profile, project_cfg, schema, model_name)) + cls.list_relations(config, schema, model_name)) return {relation.identifier: relation.type for relation in all_relations} @classmethod - def get_existing_schemas(cls, profile, project_cfg, model_name=None): + def get_existing_schemas(cls, config, model_name=None): raise dbt.exceptions.NotImplementedException( '`get_existing_schemas` is not implemented for this adapter!') @classmethod - def check_schema_exists(cls, profile, project_cfg, schema): + def check_schema_exists(cls, config, schema): raise dbt.exceptions.NotImplementedException( '`check_schema_exists` is not implemented for this adapter!') @classmethod - def cancel_connection(cls, project_cfg, connection): + def cancel_connection(cls, config, connection): raise dbt.exceptions.NotImplementedException( '`cancel_connection` is not implemented for this adapter!') @@ -171,19 +171,18 @@ def get_result_from_cursor(cls, cursor): return dbt.clients.agate_helper.table_from_data(data, column_names) @classmethod - def drop(cls, profile, project_cfg, schema, - relation, relation_type, model_name=None): + def drop(cls, config, schema, relation, relation_type, model_name=None): identifier = relation relation = cls.Relation.create( schema=schema, identifier=identifier, type=relation_type, - quote_policy=project_cfg.get('quoting', {})) + 
quote_policy=config.quoting) - return cls.drop_relation(profile, project_cfg, relation, model_name) + return cls.drop_relation(config, relation, model_name) @classmethod - def drop_relation(cls, profile, project_cfg, relation, model_name=None): + def drop_relation(cls, config, relation, model_name=None): if relation.type is None: dbt.exceptions.raise_compiler_error( 'Tried to drop relation {}, but its type is null.' @@ -191,31 +190,30 @@ def drop_relation(cls, profile, project_cfg, relation, model_name=None): sql = 'drop {} if exists {} cascade'.format(relation.type, relation) - connection, cursor = cls.add_query(profile, sql, model_name, + connection, cursor = cls.add_query(config, sql, model_name, auto_begin=False) @classmethod - def truncate(cls, profile, project_cfg, schema, table, model_name=None): + def truncate(cls, config, schema, table, model_name=None): relation = cls.Relation.create( schema=schema, identifier=table, type='table', - quote_policy=project_cfg.get('quoting', {})) + quote_policy=config.quoting) - return cls.truncate_relation(profile, project_cfg, - relation, model_name) + return cls.truncate_relation(config, relation, model_name) @classmethod - def truncate_relation(cls, profile, project_cfg, + def truncate_relation(cls, config, relation, model_name=None): sql = 'truncate table {}'.format(relation) - connection, cursor = cls.add_query(profile, sql, model_name) + connection, cursor = cls.add_query(config, sql, model_name) @classmethod - def rename(cls, profile, project_cfg, schema, + def rename(cls, config, schema, from_name, to_name, model_name=None): - quote_policy = project_cfg.get('quoting', {}) + quote_policy = config.quoting from_relation = cls.Relation.create( schema=schema, identifier=from_name, @@ -226,25 +224,25 @@ def rename(cls, profile, project_cfg, schema, quote_policy=quote_policy ) return cls.rename_relation( - profile, project_cfg, + config, from_relation=from_relation, to_relation=to_relation, model_name=model_name) @classmethod - def rename_relation(cls, profile, project_cfg, from_relation, - to_relation, model_name=None): + def rename_relation(cls, config, from_relation, to_relation, + model_name=None): sql = 'alter table {} rename to {}'.format( from_relation, to_relation.include(schema=False)) - connection, cursor = cls.add_query(profile, sql, model_name) + connection, cursor = cls.add_query(config, sql, model_name) @classmethod def is_cancelable(cls): return True @classmethod - def get_missing_columns(cls, profile, project_cfg, + def get_missing_columns(cls, config, from_schema, from_table, to_schema, to_table, model_name=None): @@ -252,11 +250,11 @@ def get_missing_columns(cls, profile, project_cfg, missing from to_table""" from_columns = {col.name: col for col in cls.get_columns_in_table( - profile, project_cfg, from_schema, from_table, + config, from_schema, from_table, model_name=model_name)} to_columns = {col.name: col for col in cls.get_columns_in_table( - profile, project_cfg, to_schema, to_table, + config, to_schema, to_table, model_name=model_name)} missing_columns = set(from_columns.keys()) - set(to_columns.keys()) @@ -290,11 +288,10 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): return sql @classmethod - def get_columns_in_table(cls, profile, project_cfg, schema_name, + def get_columns_in_table(cls, config, schema_name, table_name, database=None, model_name=None): sql = cls._get_columns_in_table_sql(schema_name, table_name, database) - connection, cursor = cls.add_query( - profile, sql, model_name) + 
connection, cursor = cls.add_query(config, sql, model_name) data = cursor.fetchall() columns = [] @@ -311,18 +308,18 @@ def _table_columns_to_dict(cls, columns): return {col.name: col for col in columns} @classmethod - def expand_target_column_types(cls, profile, project_cfg, + def expand_target_column_types(cls, config, temp_table, to_schema, to_table, model_name=None): reference_columns = cls._table_columns_to_dict( cls.get_columns_in_table( - profile, project_cfg, None, temp_table, model_name=model_name)) + config, None, temp_table, model_name=model_name)) target_columns = cls._table_columns_to_dict( cls.get_columns_in_table( - profile, project_cfg, to_schema, to_table, + config, to_schema, to_table, model_name=model_name)) for column_name, reference_column in reference_columns.items(): @@ -338,7 +335,7 @@ def expand_target_column_types(cls, profile, project_cfg, to_schema, to_table) - cls.alter_column_type(profile, project_cfg, to_schema, + cls.alter_column_type(config, to_schema, to_table, column_name, new_type, model_name) @@ -346,25 +343,23 @@ def expand_target_column_types(cls, profile, project_cfg, # RELATIONS ### @classmethod - def list_relations(cls, profile, project_cfg, schema, model_name=None): + def list_relations(cls, config, schema, model_name=None): raise dbt.exceptions.NotImplementedException( '`list_relations` is not implemented for this adapter!') @classmethod - def _make_match_kwargs(cls, project_cfg, schema, identifier): - if identifier is not None and \ - project_cfg.get('quoting', {}).get('identifier') is False: + def _make_match_kwargs(cls, config, schema, identifier): + if identifier is not None and config.quoting['identifier'] is False: identifier = identifier.lower() - if schema is not None and \ - project_cfg.get('quoting', {}).get('schema') is False: + if schema is not None and config.quoting['schema'] is False: schema = schema.lower() return filter_null_values({'identifier': identifier, 'schema': schema}) @classmethod - def get_relation(cls, profile, project_cfg, schema=None, identifier=None, + def get_relation(cls, config, schema=None, identifier=None, relations_list=None, model_name=None): if schema is None and relations_list is None: raise dbt.exceptions.RuntimeException( @@ -372,12 +367,11 @@ def get_relation(cls, profile, project_cfg, schema=None, identifier=None, 'of relations to use') if relations_list is None: - relations_list = cls.list_relations( - profile, project_cfg, schema, model_name) + relations_list = cls.list_relations(config, schema, model_name) matches = [] - search = cls._make_match_kwargs(project_cfg, schema, identifier) + search = cls._make_match_kwargs(config, schema, identifier) for relation in relations_list: if relation.matches(**search): @@ -396,15 +390,15 @@ def get_relation(cls, profile, project_cfg, schema=None, identifier=None, # SANE ANSI SQL DEFAULTS ### @classmethod - def get_create_schema_sql(cls, project_cfg, schema): - schema = cls._quote_as_configured(project_cfg, schema, 'schema') + def get_create_schema_sql(cls, config, schema): + schema = cls._quote_as_configured(config, schema, 'schema') return ('create schema if not exists {schema}' .format(schema=schema)) @classmethod - def get_drop_schema_sql(cls, project_cfg, schema): - schema = cls._quote_as_configured(project_cfg, schema, 'schema') + def get_drop_schema_sql(cls, config, schema): + schema = cls._quote_as_configured(config, schema, 'schema') return ('drop schema if exists {schema} cascade' .format(schema=schema)) @@ -414,11 +408,11 @@ def 
get_drop_schema_sql(cls, project_cfg, schema): # although some adapters may override them ### @classmethod - def get_default_schema(cls, profile, project_cfg): - return profile.get('schema') + def get_default_schema(cls, config): + return config.credentials.schema @classmethod - def get_connection(cls, profile, name=None, recache_if_missing=True): + def get_connection(cls, config, name=None, recache_if_missing=True): global connections_in_use if name is None: @@ -437,20 +431,20 @@ def get_connection(cls, profile, name=None, recache_if_missing=True): logger.debug('Acquiring new {} connection "{}".' .format(cls.type(), name)) - connection = cls.acquire_connection(profile, name) + connection = cls.acquire_connection(config, name) connections_in_use[name] = connection - return cls.get_connection(profile, name) + return cls.get_connection(config, name) @classmethod - def cancel_open_connections(cls, profile): + def cancel_open_connections(cls, config): global connections_in_use for name, connection in connections_in_use.items(): if name == 'master': continue - cls.cancel_connection(profile, connection) + cls.cancel_connection(config, connection) yield name @classmethod @@ -460,22 +454,21 @@ def total_connections_allocated(cls): return len(connections_in_use) + len(connections_available) @classmethod - def acquire_connection(cls, profile, name): + def acquire_connection(cls, config, name): global connections_available, lock # we add a magic number, 2 because there are overhead connections, # one for pre- and post-run hooks and other misc operations that occur # before the run starts, and one for integration tests. - max_connections = profile.get('threads', 1) + 2 + max_connections = config.threads + 2 - try: - lock.acquire() + with lock: num_allocated = cls.total_connections_allocated() if len(connections_available) > 0: logger.debug('Re-using an available connection from the pool.') to_return = connections_available.pop() - to_return['name'] = name + to_return.name = name return to_return elif num_allocated >= max_connections: @@ -487,46 +480,35 @@ def acquire_connection(cls, profile, name): logger.debug('Opening a new connection ({} currently allocated)' .format(num_allocated)) - credentials = copy.deepcopy(profile) - - credentials.pop('type', None) - credentials.pop('threads', None) - - result = { - 'type': cls.type(), - 'name': name, - 'state': 'init', - 'transaction_open': False, - 'handle': None, - 'credentials': credentials - } - - if dbt.flags.STRICT_MODE: - Connection(**result) + result = Connection( + type=cls.type(), + name=name, + state='init', + transaction_open=False, + handle=None, + credentials=config.credentials + ) return cls.open_connection(result) - finally: - lock.release() @classmethod - def release_connection(cls, profile, name='master'): + def release_connection(cls, config, name='master'): global connections_in_use, connections_available, lock - if connections_in_use.get(name) is None: + if name not in connections_in_use: return - to_release = cls.get_connection(profile, name, - recache_if_missing=False) + to_release = cls.get_connection(config, name, recache_if_missing=False) try: lock.acquire() - if to_release.get('state') == 'open': + if to_release.state == 'open': - if to_release.get('transaction_open') is True: + if to_release.transaction_open is True: cls.rollback(to_release) - to_release['name'] = None + to_release.name = None connections_available.append(to_release) else: cls.close(to_release) @@ -539,9 +521,7 @@ def release_connection(cls, profile, 
name='master'): def cleanup_connections(cls): global connections_in_use, connections_available, lock - try: - lock.acquire() - + with lock: for name, connection in connections_in_use.items(): if connection.get('state') != 'closed': logger.debug("Connection '{}' was left open." @@ -558,44 +538,41 @@ def cleanup_connections(cls): connections_in_use = {} connections_available = [] - finally: - lock.release() - @classmethod def reload(cls, connection): - return cls.get_connection(connection.get('credentials'), - connection.get('name')) + return cls.get_connection(connection.credentials, + connection.name) @classmethod - def add_begin_query(cls, profile, name): - return cls.add_query(profile, 'BEGIN', name, auto_begin=False) + def add_begin_query(cls, config, name): + return cls.add_query(config, 'BEGIN', name, auto_begin=False) @classmethod - def add_commit_query(cls, profile, name): - return cls.add_query(profile, 'COMMIT', name, auto_begin=False) + def add_commit_query(cls, config, name): + return cls.add_query(config, 'COMMIT', name, auto_begin=False) @classmethod - def begin(cls, profile, name='master'): + def begin(cls, config, name='master'): global connections_in_use - connection = cls.get_connection(profile, name) + connection = cls.get_connection(config, name) if dbt.flags.STRICT_MODE: - Connection(**connection) + assert isinstance(connection, Connection) - if connection['transaction_open'] is True: + if connection.transaction_open is True: raise dbt.exceptions.InternalException( 'Tried to begin a new transaction on connection "{}", but ' 'it already had one open!'.format(connection.get('name'))) - cls.add_begin_query(profile, name) + cls.add_begin_query(config, name) - connection['transaction_open'] = True + connection.transaction_open = True connections_in_use[name] = connection return connection @classmethod - def commit_if_has_connection(cls, profile, name): + def commit_if_has_connection(cls, config, name): global connections_in_use if name is None: @@ -604,29 +581,29 @@ def commit_if_has_connection(cls, profile, name): if connections_in_use.get(name) is None: return - connection = cls.get_connection(profile, name, False) + connection = cls.get_connection(config, name, False) - return cls.commit(profile, connection) + return cls.commit(config, connection) @classmethod - def commit(cls, profile, connection): + def commit(cls, config, connection): global connections_in_use if dbt.flags.STRICT_MODE: - Connection(**connection) + assert isinstance(connection, Connection) connection = cls.reload(connection) - if connection['transaction_open'] is False: + if connection.transaction_open is False: raise dbt.exceptions.InternalException( 'Tried to commit transaction on connection "{}", but ' - 'it does not have one open!'.format(connection.get('name'))) + 'it does not have one open!'.format(connection.name)) - logger.debug('On {}: COMMIT'.format(connection.get('name'))) - cls.add_commit_query(profile, connection.get('name')) + logger.debug('On {}: COMMIT'.format(connection.name)) + cls.add_commit_query(config, connection.name) - connection['transaction_open'] = False - connections_in_use[connection.get('name')] = connection + connection.transaction_open = False + connections_in_use[connection.name] = connection return connection @@ -637,49 +614,52 @@ def rollback(cls, connection): connection = cls.reload(connection) - if connection['transaction_open'] is False: + if connection.transaction_open is False: raise dbt.exceptions.InternalException( 'Tried to rollback transaction on connection 
"{}", but ' - 'it does not have one open!'.format(connection.get('name'))) + 'it does not have one open!'.format(connection.name)) - logger.debug('On {}: ROLLBACK'.format(connection.get('name'))) - connection.get('handle').rollback() + logger.debug('On {}: ROLLBACK'.format(connection.name)) + connection.handle.rollback() - connection['transaction_open'] = False - connections_in_use[connection.get('name')] = connection + connection.transaction_open = False + connections_in_use[connection.name] = connection return connection @classmethod def close(cls, connection): if dbt.flags.STRICT_MODE: - Connection(**connection) + assert isinstance(connection, Connection) + + # On windows, sometimes connection handles don't have a close() attr. + if hasattr(connection.handle, 'close'): + connection.handle.close() - connection.get('handle').close() - connection['state'] = 'closed' + connection.state = 'closed' return connection @classmethod - def add_query(cls, profile, sql, model_name=None, auto_begin=True, + def add_query(cls, config, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): - connection = cls.get_connection(profile, model_name) - connection_name = connection.get('name') + connection = cls.get_connection(config, model_name) + connection_name = connection.name - if auto_begin and connection['transaction_open'] is False: - cls.begin(profile, connection_name) + if auto_begin and connection.transaction_open is False: + cls.begin(config, connection_name) logger.debug('Using {} connection "{}".' .format(cls.type(), connection_name)) - with cls.exception_handler(profile, sql, model_name, connection_name): + with cls.exception_handler(config, sql, model_name, connection_name): if abridge_sql_log: logger.debug('On %s: %s....', connection_name, sql[0:512]) else: logger.debug('On %s: %s', connection_name, sql) pre = time.time() - cursor = connection.get('handle').cursor() + cursor = connection.handle.cursor() cursor.execute(sql, bindings) logger.debug("SQL status: %s in %0.2f seconds", @@ -688,69 +668,67 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True, return connection, cursor @classmethod - def clear_transaction(cls, profile, conn_name='master'): - conn = cls.begin(profile, conn_name) - cls.commit(profile, conn) + def clear_transaction(cls, config, conn_name='master'): + conn = cls.begin(config, conn_name) + cls.commit(config, conn) return conn_name @classmethod - def execute_one(cls, profile, sql, model_name=None, auto_begin=False): - cls.get_connection(profile, model_name) + def execute_one(cls, config, sql, model_name=None, auto_begin=False): + cls.get_connection(config, model_name) - return cls.add_query(profile, sql, model_name, auto_begin) + return cls.add_query(config, sql, model_name, auto_begin) @classmethod - def execute_and_fetch(cls, profile, sql, model_name=None, + def execute_and_fetch(cls, config, sql, model_name=None, auto_begin=False): - _, cursor = cls.execute_one(profile, sql, model_name, auto_begin) + _, cursor = cls.execute_one(config, sql, model_name, auto_begin) status = cls.get_status(cursor) table = cls.get_result_from_cursor(cursor) return status, table @classmethod - def execute(cls, profile, sql, model_name=None, auto_begin=False, + def execute(cls, config, sql, model_name=None, auto_begin=False, fetch=False): if fetch: - return cls.execute_and_fetch(profile, sql, model_name, auto_begin) + return cls.execute_and_fetch(config, sql, model_name, auto_begin) else: - _, cursor = cls.execute_one(profile, sql, model_name, auto_begin) 
+ _, cursor = cls.execute_one(config, sql, model_name, auto_begin) status = cls.get_status(cursor) return status, dbt.clients.agate_helper.empty_table() @classmethod - def execute_all(cls, profile, sqls, model_name=None): - connection = cls.get_connection(profile, model_name) + def execute_all(cls, config, sqls, model_name=None): + connection = cls.get_connection(config, model_name) if len(sqls) == 0: return connection for i, sql in enumerate(sqls): - connection, _ = cls.add_query(profile, sql, model_name) + connection, _ = cls.add_query(config, sql, model_name) return connection @classmethod - def create_schema(cls, profile, project_cfg, schema, model_name=None): + def create_schema(cls, config, schema, model_name=None): logger.debug('Creating schema "%s".', schema) - sql = cls.get_create_schema_sql(project_cfg, schema) - res = cls.add_query(profile, sql, model_name) + sql = cls.get_create_schema_sql(config, schema) + res = cls.add_query(config, sql, model_name) - cls.commit_if_has_connection(profile, model_name) + cls.commit_if_has_connection(config, model_name) return res @classmethod - def drop_schema(cls, profile, project_cfg, schema, model_name=None): + def drop_schema(cls, config, schema, model_name=None): logger.debug('Dropping schema "%s".', schema) - sql = cls.get_drop_schema_sql(project_cfg, schema) - return cls.add_query(profile, sql, model_name) + sql = cls.get_drop_schema_sql(config, schema) + return cls.add_query(config, sql, model_name) @classmethod - def already_exists(cls, profile, project_cfg, - schema, table, model_name=None): - relation = cls.get_relation( - profile, project_cfg, schema=schema, identifier=table) + def already_exists(cls, config, schema, table, model_name=None): + relation = cls.get_relation(config, schema=schema, identifier=table) return relation is not None @classmethod @@ -758,18 +736,18 @@ def quote(cls, identifier): return '"{}"'.format(identifier) @classmethod - def _quote_as_configured(cls, project_cfg, identifier, quote_key): + def _quote_as_configured(cls, config, identifier, quote_key): """This is the actual implementation of quote_as_configured, without the extra arguments needed for use inside materialization code. """ default = cls.Relation.DEFAULTS['quote_policy'].get(quote_key) - if project_cfg.get('quoting', {}).get(quote_key, default): + if config.quoting.get(quote_key, default): return cls.quote(identifier) else: return identifier @classmethod - def quote_as_configured(cls, profile, project_cfg, identifier, quote_key, + def quote_as_configured(cls, config, identifier, quote_key, model_name=None): """Quote or do not quote the given identifer as configured in the project config for the quote key. @@ -777,7 +755,7 @@ def quote_as_configured(cls, profile, project_cfg, identifier, quote_key, The quote key should be one of 'database' (on bigquery, 'profile'), 'identifier', or 'schema', or it will be treated as if you set `True`. """ - return cls._quote_as_configured(project_cfg, identifier, quote_key) + return cls._quote_as_configured(config, identifier, quote_key) @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -832,7 +810,7 @@ def convert_agate_type(cls, agate_table, col_idx): # Operations involving the manifest ### @classmethod - def run_operation(cls, profile, project_cfg, manifest, operation_name): + def run_operation(cls, config, manifest, operation_name): """Look the operation identified by operation_name up in the manifest and run it. 
@@ -846,7 +824,7 @@ def run_operation(cls, profile, project_cfg, manifest, operation_name): import dbt.context.runtime context = dbt.context.runtime.generate( operation, - project_cfg, + config, manifest, ) @@ -861,12 +839,12 @@ def _filter_table(cls, table, manifest): return table.where(_filter_schemas(manifest)) @classmethod - def get_catalog(cls, profile, project_cfg, manifest): + def get_catalog(cls, config, manifest): try: - table = cls.run_operation(profile, project_cfg, manifest, + table = cls.run_operation(config, manifest, GET_CATALOG_OPERATION_NAME) finally: - cls.release_connection(profile, GET_CATALOG_OPERATION_NAME) + cls.release_connection(config, GET_CATALOG_OPERATION_NAME) results = cls._filter_table(table, manifest) return results diff --git a/dbt/adapters/default/relation.py b/dbt/adapters/default/relation.py index 686b0a003f4..b5d1f46d38f 100644 --- a/dbt/adapters/default/relation.py +++ b/dbt/adapters/default/relation.py @@ -173,9 +173,9 @@ def quoted(self, identifier): identifier=identifier) @classmethod - def create_from_node(cls, profile, node, table_name=None, **kwargs): + def create_from_node(cls, project, node, table_name=None, **kwargs): return cls.create( - database=profile.get('dbname'), + database=project.credentials.dbname, schema=node.get('schema'), identifier=node.get('alias'), table_name=table_name, diff --git a/dbt/adapters/factory.py b/dbt/adapters/factory.py index a7be4e00e17..7d450dc2b73 100644 --- a/dbt/adapters/factory.py +++ b/dbt/adapters/factory.py @@ -29,6 +29,5 @@ def get_adapter_by_name(adapter_name): return adapter -def get_adapter(profile): - adapter_type = profile.get('type', None) - return get_adapter_by_name(adapter_type) +def get_adapter(config): + return get_adapter_by_name(config.credentials.type) diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 065529bf416..3ad39e3fb46 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -16,7 +16,7 @@ class PostgresAdapter(dbt.adapters.default.DefaultAdapter): @classmethod @contextmanager - def exception_handler(cls, profile, sql, model_name=None, + def exception_handler(cls, config, sql, model_name=None, connection_name=None): try: yield @@ -26,7 +26,7 @@ def exception_handler(cls, profile, sql, model_name=None, try: # attempt to release the connection - cls.release_connection(profile, connection_name) + cls.release_connection(config, connection_name) except psycopg2.Error: logger.debug("Failed to release connection!") pass @@ -37,7 +37,7 @@ def exception_handler(cls, profile, sql, model_name=None, except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - cls.release_connection(profile, connection_name) + cls.release_connection(config, connection_name) raise dbt.exceptions.RuntimeException(e) @classmethod @@ -58,14 +58,12 @@ def get_credentials(cls, credentials): @classmethod def open_connection(cls, connection): - if connection.get('state') == 'open': + if connection.state == 'open': logger.debug('Connection is already open, skipping open.') return connection - result = connection.copy() - - base_credentials = connection.get('credentials', {}) - credentials = cls.get_credentials(base_credentials.copy()) + base_credentials = connection.credentials + credentials = cls.get_credentials(connection.credentials.incorporate()) kwargs = {} keepalives_idle = credentials.get('keepalives_idle', cls.DEFAULT_TCP_KEEPALIVE) @@ -76,38 +74,38 @@ def open_connection(cls, connection): try: handle = 
psycopg2.connect( - dbname=credentials.get('dbname'), - user=credentials.get('user'), - host=credentials.get('host'), - password=credentials.get('pass'), - port=credentials.get('port'), + dbname=credentials.dbname, + user=credentials.user, + host=credentials.host, + password=credentials.password, + port=credentials.port, connect_timeout=10, **kwargs) - result['handle'] = handle - result['state'] = 'open' + connection.handle = handle + connection.state = 'open' except psycopg2.Error as e: logger.debug("Got an error when attempting to open a postgres " "connection: '{}'" .format(e)) - result['handle'] = None - result['state'] = 'fail' + connection.handle = None + connection.state = 'fail' raise dbt.exceptions.FailedToConnectException(str(e)) - return result + return connection @classmethod - def cancel_connection(cls, profile, connection): - connection_name = connection.get('name') - pid = connection.get('handle').get_backend_pid() + def cancel_connection(cls, config, connection): + connection_name = connection.name + pid = connection.handle.get_backend_pid() sql = "select pg_terminate_backend({})".format(pid) logger.debug("Cancelling query '{}' ({})".format(connection_name, pid)) - _, cursor = cls.add_query(profile, sql, 'master') + _, cursor = cls.add_query(config, sql, 'master') res = cursor.fetchone() logger.debug("Cancel query '{}': {}".format(connection_name, res)) @@ -116,7 +114,7 @@ def cancel_connection(cls, profile, connection): # These require the profile AND project, as they need to know # database-specific configs at the project level. @classmethod - def alter_column_type(cls, profile, project, schema, table, column_name, + def alter_column_type(cls, config, schema, table, column_name, new_column_type, model_name=None): """ 1. Create a new column (w/ temp name and correct type) @@ -128,7 +126,7 @@ def alter_column_type(cls, profile, project, schema, table, column_name, relation = cls.Relation.create( schema=schema, identifier=table, - quote_policy=project.get('quoting', {}) + quote_policy=config.quoting ) opts = { @@ -145,12 +143,12 @@ def alter_column_type(cls, profile, project, schema, table, column_name, alter table {relation} rename column "{tmp_column}" to "{old_column}"; """.format(**opts).strip() # noqa - connection, cursor = cls.add_query(profile, sql, model_name) + connection, cursor = cls.add_query(config, sql, model_name) return connection, cursor @classmethod - def list_relations(cls, profile, project, schema, model_name=None): + def list_relations(cls, config, schema, model_name=None): sql = """ select tablename as name, schemaname as schema, 'table' as type from pg_tables where schemaname ilike '{schema}' @@ -159,13 +157,13 @@ def list_relations(cls, profile, project, schema, model_name=None): where schemaname ilike '{schema}' """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(profile, sql, model_name, + connection, cursor = cls.add_query(config, sql, model_name, auto_begin=False) results = cursor.fetchall() return [cls.Relation.create( - database=profile.get('dbname'), + database=config.credentials.dbname, schema=_schema, identifier=name, quote_policy={ @@ -176,22 +174,22 @@ def list_relations(cls, profile, project, schema, model_name=None): for (name, _schema, type) in results] @classmethod - def get_existing_schemas(cls, profile, project, model_name=None): + def get_existing_schemas(cls, config, model_name=None): sql = "select distinct nspname from pg_namespace" - connection, cursor = cls.add_query(profile, sql, model_name, + 
connection, cursor = cls.add_query(config, sql, model_name, auto_begin=False) results = cursor.fetchall() return [row[0] for row in results] @classmethod - def check_schema_exists(cls, profile, project, schema, model_name=None): + def check_schema_exists(cls, config, schema, model_name=None): sql = """ select count(*) from pg_namespace where nspname = '{schema}' """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(profile, sql, model_name, + connection, cursor = cls.add_query(config, sql, model_name, auto_begin=False) results = cursor.fetchone() diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index d908a3da298..a1f36cb297e 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -54,21 +54,21 @@ def get_tmp_iam_cluster_credentials(cls, credentials): "authentication method selected") cluster_creds = cls.fetch_cluster_credentials( - credentials.get('user'), - credentials.get('dbname'), - credentials.get('cluster_id'), + credentials.user, + credentials.dbname, + credentials.cluster_id, iam_duration_s, ) # replace username and password with temporary redshift credentials - return dbt.utils.merge(credentials, { - 'user': cluster_creds.get('DbUser'), - 'pass': cluster_creds.get('DbPassword') - }) + return credentials.incorporate( + user=cluster_creds.get('DbUser'), + password=cluster_creds.get('DbPassword') + ) @classmethod def get_credentials(cls, credentials): - method = credentials.get('method') + method = credentials.method # Support missing 'method' for backwards compatibility if method == 'database' or method is None: @@ -157,7 +157,7 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): return sql @classmethod - def drop_relation(cls, profile, project, relation, model_name=None): + def drop_relation(cls, config, relation, model_name=None): """ In Redshift, DROP TABLE ... CASCADE should not be used inside a transaction. 
Redshift doesn't prevent the CASCADE @@ -176,27 +176,23 @@ def drop_relation(cls, profile, project, relation, model_name=None): to_return = None - try: - drop_lock.acquire() + with drop_lock: - connection = cls.get_connection(profile, model_name) + connection = cls.get_connection(config, model_name) - if connection.get('transaction_open'): - cls.commit(profile, connection) + if connection.transaction_open: + cls.commit(config, connection) - cls.begin(profile, connection.get('name')) + cls.begin(config, connection.name) to_return = super(PostgresAdapter, cls).drop_relation( - profile, project, relation, model_name) + config, relation, model_name) - cls.commit(profile, connection) - cls.begin(profile, connection.get('name')) + cls.commit(config, connection) + cls.begin(config, connection.name) return to_return - finally: - drop_lock.release() - @classmethod def convert_text_type(cls, agate_table, col_idx): column = agate_table.columns[col_idx] diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index 18214da1dc4..47bc08efa96 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -22,9 +22,9 @@ class SnowflakeAdapter(PostgresAdapter): @classmethod @contextmanager - def exception_handler(cls, profile, sql, model_name=None, + def exception_handler(cls, config, sql, model_name=None, connection_name='master'): - connection = cls.get_connection(profile, connection_name) + connection = cls.get_connection(config, connection_name) try: yield @@ -36,7 +36,7 @@ def exception_handler(cls, profile, sql, model_name=None, if 'Empty SQL statement' in msg: logger.debug("got empty sql statement, moving on") elif 'This session does not have a current database' in msg: - cls.release_connection(profile, connection_name) + cls.release_connection(config, connection_name) raise dbt.exceptions.FailedToConnectException( ('{}\n\nThis error sometimes occurs when invalid ' 'credentials are provided, or when your default role ' @@ -44,12 +44,12 @@ def exception_handler(cls, profile, sql, model_name=None, 'Please double check your profile and try again.') .format(msg)) else: - cls.release_connection(profile, connection_name) + cls.release_connection(config, connection_name) raise dbt.exceptions.DatabaseException(msg) except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - cls.release_connection(profile, connection_name) + cls.release_connection(config, connection_name) raise dbt.exceptions.RuntimeException(e.msg) @classmethod @@ -71,41 +71,39 @@ def get_status(cls, cursor): @classmethod def open_connection(cls, connection): - if connection.get('state') == 'open': + if connection.state == 'open': logger.debug('Connection is already open, skipping open.') return connection - result = connection.copy() - try: - credentials = connection.get('credentials', {}) + credentials = connection.credentials handle = snowflake.connector.connect( - account=credentials.get('account'), - user=credentials.get('user'), - password=credentials.get('password'), - database=credentials.get('database'), - schema=credentials.get('schema'), - warehouse=credentials.get('warehouse'), + account=credentials.account, + user=credentials.user, + password=credentials.password, + database=credentials.database, + schema=credentials.schema, + warehouse=credentials.warehouse, role=credentials.get('role', None), autocommit=False ) - result['handle'] = handle - result['state'] = 'open' + connection.handle = handle + connection.state = 'open' except 
snowflake.connector.errors.Error as e: logger.debug("Got an error when attempting to open a snowflake " "connection: '{}'" .format(e)) - result['handle'] = None - result['state'] = 'fail' + connection.handle = None + connection.state = 'fail' raise dbt.exceptions.FailedToConnectException(str(e)) - return result + return connection @classmethod - def list_relations(cls, profile, project_cfg, schema, model_name=None): + def list_relations(cls, config, schema, model_name=None): sql = """ select table_name as name, table_schema as schema, table_type as type @@ -114,7 +112,7 @@ def list_relations(cls, profile, project_cfg, schema, model_name=None): """.format(schema=schema).strip() # noqa _, cursor = cls.add_query( - profile, sql, model_name, auto_begin=False) + config, sql, model_name, auto_begin=False) results = cursor.fetchall() @@ -124,7 +122,7 @@ def list_relations(cls, profile, project_cfg, schema, model_name=None): } return [cls.Relation.create( - database=profile.get('database'), + database=config.credentials.database, schema=_schema, identifier=name, quote_policy={ @@ -135,37 +133,36 @@ def list_relations(cls, profile, project_cfg, schema, model_name=None): for (name, _schema, type) in results] @classmethod - def rename_relation(cls, profile, project_cfg, from_relation, - to_relation, model_name=None): + def rename_relation(cls, config, from_relation, to_relation, + model_name=None): sql = 'alter table {} rename to {}'.format( from_relation, to_relation) - connection, cursor = cls.add_query(profile, sql, model_name) + connection, cursor = cls.add_query(config, sql, model_name) @classmethod - def add_begin_query(cls, profile, name): - return cls.add_query(profile, 'BEGIN', name, auto_begin=False) + def add_begin_query(cls, config, name): + return cls.add_query(config, 'BEGIN', name, auto_begin=False) @classmethod - def get_existing_schemas(cls, profile, project_cfg, model_name=None): + def get_existing_schemas(cls, config, model_name=None): sql = "select distinct schema_name from information_schema.schemata" - connection, cursor = cls.add_query(profile, sql, model_name, + connection, cursor = cls.add_query(config, sql, model_name, auto_begin=False) results = cursor.fetchall() return [row[0] for row in results] @classmethod - def check_schema_exists(cls, profile, project_cfg, - schema, model_name=None): + def check_schema_exists(cls, config, schema, model_name=None): sql = """ select count(*) from information_schema.schemata where upper(schema_name) = upper('{schema}') """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(profile, sql, model_name, + connection, cursor = cls.add_query(config, sql, model_name, auto_begin=False) results = cursor.fetchone() @@ -181,7 +178,7 @@ def _split_queries(cls, sql): return [part[0] for part in split_query] @classmethod - def add_query(cls, profile, sql, model_name=None, auto_begin=True, + def add_query(cls, config, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): connection = None @@ -206,7 +203,7 @@ def add_query(cls, profile, sql, model_name=None, auto_begin=True, continue connection, cursor = super(PostgresAdapter, cls).add_query( - profile, individual_query, model_name, auto_begin, + config, individual_query, model_name, auto_begin, bindings=bindings, abridge_sql_log=abridge_sql_log) if cursor is None: @@ -228,30 +225,28 @@ def _filter_table(cls, table, manifest): return super(SnowflakeAdapter, cls)._filter_table(lowered, manifest) @classmethod - def _make_match_kwargs(cls, project_cfg, schema, 
identifier): - if identifier is not None and \ - project_cfg.get('quoting', {}).get('identifier', False) is False: + def _make_match_kwargs(cls, config, schema, identifier): + if identifier is not None and config.quoting['identifier'] is False: identifier = identifier.upper() - if schema is not None and \ - project_cfg.get('quoting', {}).get('schema', False) is False: + if schema is not None and config.quoting['schema'] is False: schema = schema.upper() return filter_null_values({'identifier': identifier, 'schema': schema}) @classmethod - def cancel_connection(cls, profile, connection): - handle = connection['handle'] + def cancel_connection(cls, config, connection): + handle = connection.handle sid = handle.session_id - connection_name = connection.get('name') + connection_name = connection.name sql = 'select system$abort_session({})'.format(sid) logger.debug("Cancelling query '{}' ({})".format(connection_name, sid)) - _, cursor = cls.add_query(profile, sql, 'master') + _, cursor = cls.add_query(config, sql, 'master') res = cursor.fetchone() logger.debug("Cancel query '{}': {}".format(connection_name, res)) diff --git a/dbt/adapters/snowflake/relation.py b/dbt/adapters/snowflake/relation.py index fcb3bd3207d..bd879965404 100644 --- a/dbt/adapters/snowflake/relation.py +++ b/dbt/adapters/snowflake/relation.py @@ -44,9 +44,9 @@ class SnowflakeRelation(DefaultRelation): } @classmethod - def create_from_node(cls, profile, node, **kwargs): + def create_from_node(cls, config, node, **kwargs): return cls.create( - database=profile.get('database'), + database=config.credentials.database, schema=node.get('schema'), identifier=node.get('alias'), **kwargs) diff --git a/dbt/api/object.py b/dbt/api/object.py index 68d5f929eea..45cc3232b6a 100644 --- a/dbt/api/object.py +++ b/dbt/api/object.py @@ -42,6 +42,11 @@ def __str__(self): def __repr__(self): return '{}(**{})'.format(self.__class__.__name__, self._contents) + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.serialize() == other.serialize() + def incorporate(self, **kwargs): """ Given a list of kwargs, incorporate these arguments diff --git a/dbt/compilation.py b/dbt/compilation.py index 4057263b538..c35ce1d4909 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -4,7 +4,6 @@ from collections import OrderedDict, defaultdict import sqlparse -import dbt.project import dbt.utils import dbt.include import dbt.tracking @@ -96,12 +95,12 @@ def recursively_prepend_ctes(model, manifest): class Compiler(object): - def __init__(self, project): - self.project = project + def __init__(self, config): + self.config = config def initialize(self): - dbt.clients.system.make_directory(self.project['target-path']) - dbt.clients.system.make_directory(self.project['modules-path']) + dbt.clients.system.make_directory(self.config.target_path) + dbt.clients.system.make_directory(self.config.modules_path) def compile_node(self, node, manifest): logger.debug("Compiling {}".format(node.get('unique_id'))) @@ -117,7 +116,7 @@ def compile_node(self, node, manifest): compiled_node = CompiledNode(**data) context = dbt.context.runtime.generate( - compiled_node, self.project, manifest) + compiled_node, self.config, manifest) compiled_node.compiled_sql = dbt.clients.jinja.get_rendered( node.get('raw_sql'), @@ -162,12 +161,12 @@ def write_manifest_file(self, manifest): manifest should be a Manifest. 
""" filename = manifest_file_name - manifest_path = os.path.join(self.project['target-path'], filename) + manifest_path = os.path.join(self.config.target_path, filename) write_json(manifest_path, manifest.serialize()) def write_graph_file(self, linker): filename = graph_file_name - graph_path = os.path.join(self.project['target-path'], filename) + graph_path = os.path.join(self.config.target_path, filename) linker.write_graph(graph_path) def link_node(self, linker, node, manifest): @@ -196,13 +195,12 @@ def link_graph(self, linker, manifest): raise RuntimeError("Found a cycle: {}".format(cycle)) def get_all_projects(self): - root_project = self.project.cfg - all_projects = {root_project.get('name'): root_project} - dependency_projects = dbt.utils.dependency_projects(self.project) + all_projects = {self.config.project_name: self.config} + dependency_projects = dbt.utils.dependency_projects(self.config) - for project in dependency_projects: - name = project.cfg.get('name', 'unknown') - all_projects[name] = project.cfg + for project_cfg in dependency_projects: + name = project_cfg.project_name + all_projects[name] = project_cfg if dbt.flags.STRICT_MODE: dbt.contracts.project.ProjectList(**all_projects) @@ -238,7 +236,7 @@ def compile(self): all_projects = self.get_all_projects() - manifest = dbt.loader.GraphLoader.load_all(self.project, all_projects) + manifest = dbt.loader.GraphLoader.load_all(self.config, all_projects) self.write_manifest_file(manifest) diff --git a/dbt/config.py b/dbt/config.py index 51681da6ddc..0326bd3d148 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -1,18 +1,73 @@ import os.path +from copy import deepcopy +import hashlib +import pprint import dbt.exceptions import dbt.clients.yaml_helper import dbt.clients.system +import dbt.utils +from dbt.contracts.connection import Connection, create_credentials +from dbt.contracts.project import Project as ProjectContract, Configuration, \ + PackageConfig, ProfileConfig +from dbt.context.common import env_var +from dbt import compat from dbt.logger import GLOBAL_LOGGER as logger +DEFAULT_THREADS = 1 +DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True +DEFAULT_USE_COLORS = True +DEFAULT_QUOTING_GLOBAL = { + 'identifier': True, + 'schema': True, +} +# some adapters need different quoting rules, for example snowflake gets a bit +# weird with quoting on +DEFAULT_QUOTING_ADAPTER = { + 'snowflake': { + 'identifier': False, + 'schema': False, + }, +} +DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt') + + INVALID_PROFILE_MESSAGE = """ dbt encountered an error while trying to read your profiles.yml file. {error_string} """ +NO_SUPPLIED_PROFILE_ERROR = """\ +dbt cannot run because no profile was specified for this dbt project. +To specify a profile for this project, add a line like the this to +your dbt_project.yml file: + +profile: [profile name] + +Here, [profile name] should be replaced with a profile name +defined in your profiles.yml file. 
You can find profiles.yml here: + +{profiles_file}/profiles.yml +""".format(profiles_file=DEFAULT_PROFILES_DIR) + + +class DbtConfigError(Exception): + def __init__(self, message, project=None, result_type='invalid_project'): + self.project = project + super(DbtConfigError, self).__init__(message) + self.result_type = result_type + + +class DbtProjectError(DbtConfigError): + pass + + +class DbtProfileError(DbtConfigError): + pass + def read_profile(profiles_dir): path = os.path.join(profiles_dir, 'profiles.yml') @@ -29,6 +84,21 @@ def read_profile(profiles_dir): return {} +def read_profiles(profiles_dir=None): + """This is only used in main, for some error handling""" + if profiles_dir is None: + profiles_dir = DEFAULT_PROFILES_DIR + + raw_profiles = read_profile(profiles_dir) + + if raw_profiles is None: + profiles = {} + else: + profiles = {k: v for (k, v) in raw_profiles.items() if k != 'config'} + + return profiles + + def read_config(profiles_dir): profile = read_profile(profiles_dir) if profile is None: @@ -43,3 +113,671 @@ def send_anonymous_usage_stats(config): def colorize_output(config): return config.get('use_colors', True) + + +def _render(key, value, ctx): + """Render an entry in the credentials dictionary, in case it's jinja. + + If the parsed entry is a string and has the name 'port', this will attempt + to cast it to an int, and on failure will return the parsed string. + + :param key str: The key to convert on. + :param value Any: The value to potentially render + :param ctx dict: The context dictionary, mapping function names to + functions that take a str and return a value + :return Any: The rendered entry. + """ + if not isinstance(value, compat.basestring): + return value + result = dbt.clients.jinja.get_rendered(value, ctx) + if key == 'port': + try: + return int(result) + except ValueError: + pass # let the validator or connection handle this + return result + + +class Project(object): + def __init__(self, project_name, version, project_root, profile_name, + source_paths, macro_paths, data_paths, test_paths, + analysis_paths, docs_paths, target_path, clean_targets, + log_path, modules_path, quoting, models, on_run_start, + on_run_end, archive, seeds, packages): + self.project_name = project_name + self.version = version + self.project_root = project_root + self.profile_name = profile_name + self.source_paths = source_paths + self.macro_paths = macro_paths + self.data_paths = data_paths + self.test_paths = test_paths + self.analysis_paths = analysis_paths + self.docs_paths = docs_paths + self.target_path = target_path + self.clean_targets = clean_targets + self.log_path = log_path + self.modules_path = modules_path + self.quoting = quoting + self.models = models + self.on_run_start = on_run_start + self.on_run_end = on_run_end + self.archive = archive + self.seeds = seeds + self.packages = packages + + @classmethod + def from_project_config(cls, project_dict, packages_dict=None): + """Create a project from its project and package configuration, as read + by yaml.safe_load(). + + :param project_dict dict: The dictionary as read from disk + :param packages_dict Optional[dict]: If it exists, the packages file as + read from disk. + :raises DbtProjectError: If the project is missing or invalid, or if + the packages file exists and is invalid. + :returns Project: The project, with defaults populated. + """ + # just for validation. 
+ try: + ProjectContract(**project_dict) + except dbt.exceptions.ValidationException as e: + raise DbtProjectError(str(e)) + + # name/version are required in the Project definition, so we can assume + # they are present + name = project_dict['name'] + version = project_dict['version'] + # this is added at project_dict parse time and should always be here + # once we see it. + project_root = project_dict['project-root'] + # this is only optional in the sense that if it's not present, it needs + # to have been a cli argument. + profile_name = project_dict.get('profile') + # these are optional + source_paths = project_dict.get('source-paths', ['models']) + macro_paths = project_dict.get('macro-paths', ['macros']) + data_paths = project_dict.get('data-paths', ['data']) + test_paths = project_dict.get('test-paths', ['test']) + analysis_paths = project_dict.get('analysis-paths', []) + docs_paths = project_dict.get('docs-paths', source_paths[:]) + target_path = project_dict.get('target-path', 'target') + # should this also include the modules path by default? + clean_targets = project_dict.get('clean-targets', [target_path]) + log_path = project_dict.get('log-path', 'logs') + modules_path = project_dict.get('modules-path', 'dbt_modules') + # in the default case we'll populate this once we know the adapter type + quoting = project_dict.get('quoting', {}) + models = project_dict.get('models', {}) + on_run_start = project_dict.get('on-run-start', []) + on_run_end = project_dict.get('on-run-end', []) + archive = project_dict.get('archive', []) + seeds = project_dict.get('seeds', {}) + + packages = package_config_from_data(packages_dict) + + project = cls( + project_name=name, + version=version, + project_root=project_root, + profile_name=profile_name, + source_paths=source_paths, + macro_paths=macro_paths, + data_paths=data_paths, + test_paths=test_paths, + analysis_paths=analysis_paths, + docs_paths=docs_paths, + target_path=target_path, + clean_targets=clean_targets, + log_path=log_path, + modules_path=modules_path, + quoting=quoting, + models=models, + on_run_start=on_run_start, + on_run_end=on_run_end, + archive=archive, + seeds=seeds, + packages=packages + ) + # sanity check - this means an internal issue + project.validate() + return project + + def __str__(self): + cfg = self.to_project_config(with_packages=True) + return pprint.pformat(cfg) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.to_project_config(with_packages=True) == \ + other.to_project_config(with_packages=True) + + def to_project_config(self, with_packages=False): + """Return a dict representation of the config that could be written to + disk with `yaml.safe_dump` to get this configuration. + + :param with_packages bool: If True, include the serialized packages + file in the root. + :returns dict: The serialized profile. 
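+
+        Illustrative sketch only (it assumes a plain `import yaml` in the
+        calling code, which this module does not itself use):
+
+            with open('dbt_project.yml', 'w') as fp:
+                yaml.safe_dump(project.to_project_config(), fp)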
+ """ + result = deepcopy({ + 'name': self.project_name, + 'version': self.version, + 'project-root': self.project_root, + 'profile': self.profile_name, + 'source-paths': self.source_paths, + 'macro-paths': self.macro_paths, + 'data-paths': self.data_paths, + 'test-paths': self.test_paths, + 'analysis-paths': self.analysis_paths, + 'docs-paths': self.docs_paths, + 'target-path': self.target_path, + 'clean-targets': self.clean_targets, + 'log-path': self.log_path, + 'quoting': self.quoting, + 'models': self.models, + 'on-run-start': self.on_run_start, + 'on-run-end': self.on_run_end, + 'archive': self.archive, + 'seeds': self.seeds, + }) + if with_packages: + result.update(self.packages.serialize()) + return result + + def validate(self): + try: + ProjectContract(**self.to_project_config()) + except dbt.exceptions.ValidationException as exc: + raise DbtProjectError(str(exc)) + + @classmethod + def from_project_root(cls, project_root): + """Create a project from a root directory. Reads in dbt_project.yml and + packages.yml, if it exists. + + :param project_root str: The path to the project root to load. + :raises DbtProjectError: If the project is missing or invalid, or if + the packages file exists and is invalid. + :returns Project: The project, with defaults populated. + """ + project_root = os.path.normpath(project_root) + project_yaml_filepath = os.path.join(project_root, 'dbt_project.yml') + + # get the project.yml contents + if not dbt.clients.system.path_exists(project_yaml_filepath): + raise DbtProjectError( + 'no dbt_project.yml found at expected path {}' + .format(project_yaml_filepath) + ) + + project_dict = _load_yaml(project_yaml_filepath) + project_dict['project-root'] = project_root + packages_dict = package_data_from_root(project_root) + return cls.from_project_config(project_dict, packages_dict) + + @classmethod + def from_current_directory(cls): + return cls.from_project_root(os.getcwd()) + + def hashed_name(self): + return hashlib.md5(self.project_name.encode('utf-8')).hexdigest() + + +class Profile(object): + def __init__(self, profile_name, target_name, send_anonymous_usage_stats, + use_colors, threads, credentials): + self.profile_name = profile_name + self.target_name = target_name + self.send_anonymous_usage_stats = send_anonymous_usage_stats + self.use_colors = use_colors + self.threads = threads + self.credentials = credentials + + def to_profile_info(self, serialize_credentials=False): + """Unlike to_project_config, this dict is not a mirror of any existing + on-disk data structure. It's used when creating a new profile from an + existing one. + + :param serialize_credentials bool: If True, serialize the credentials. + Otherwise, the Credentials object will be copied. + :returns dict: The serialized profile. 
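+
+        A minimal sketch of the round trip this enables (it mirrors what
+        new_project() does when copying a profile; the `profile` name here
+        is illustrative):
+
+            copied = Profile(**profile.to_profile_info())
+            copied.validate()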
+ """ + result = { + 'profile_name': self.profile_name, + 'target_name': self.target_name, + 'send_anonymous_usage_stats': self.send_anonymous_usage_stats, + 'use_colors': self.use_colors, + 'threads': self.threads, + 'credentials': self.credentials.incorporate(), + } + if serialize_credentials: + result['credentials'] = result['credentials'].serialize() + return result + + def __str__(self): + return pprint.pformat(self.to_profile_info()) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.to_profile_info() == other.to_profile_info() + + def validate(self): + try: + ProfileConfig(**self.to_profile_info(serialize_credentials=True)) + except dbt.exceptions.ValidationException as exc: + raise DbtProfileError(str(exc)) + + @staticmethod + def _rendered_profile(profile): + # if entries are strings, we want to render them so we can get any + # environment variables that might store important credentials + # elements. + return { + k: _render(k, v, {'env_var': env_var}) + for k, v in profile.items() + } + + @staticmethod + def _credentials_from_profile(profile, profile_name, target_name): + # credentials carry their 'type' in their actual type, not their + # attributes. We do want this in order to pick our Credentials class. + if 'type' not in profile: + raise DbtProfileError( + 'required field "type" not found in profile {} and target {}' + .format(profile_name, target_name)) + + typename = profile.pop('type') + try: + credentials = create_credentials(typename, profile) + except dbt.exceptions.RuntimeException as e: + raise DbtProfileError( + 'Credentials in profile "{}", target "{}" invalid: {}' + .format(profile_name, target_name, str(e)) + ) + return credentials + + @staticmethod + def _pick_profile_name(args_profile_name, project_profile_name=None): + profile_name = project_profile_name + if args_profile_name is not None: + profile_name = args_profile_name + if profile_name is None: + raise DbtProjectError(NO_SUPPLIED_PROFILE_ERROR) + return profile_name + + @staticmethod + def _pick_target(raw_profile, profile_name, target_override=None): + + if target_override is not None: + target_name = target_override + elif 'target' in raw_profile: + target_name = raw_profile['target'] + else: + raise DbtProfileError( + "target not specified in profile '{}'".format(profile_name) + ) + return target_name + + @staticmethod + def _get_profile_data(raw_profile, profile_name, target_name): + if 'outputs' not in raw_profile: + raise DbtProfileError( + "outputs not specified in profile '{}'".format(profile_name) + ) + outputs = raw_profile['outputs'] + + if target_name not in outputs: + outputs = '\n'.join(' - {}'.format(output) + for output in outputs) + msg = ("The profile '{}' does not have a target named '{}'. The " + "valid target names for this profile are:\n{}" + .format(profile_name, target_name, outputs)) + raise DbtProfileError(msg, result_type='invalid_target') + profile_data = outputs[target_name] + return profile_data + + @classmethod + def from_credentials(cls, credentials, threads, profile_name, target_name, + user_cfg=None): + """Create a profile from an existing set of Credentials and the + remaining information. + + :param credentials Credentials: The credentials for this profile. + :param threads int: The number of threads to use for connections. + :param profile_name str: The profile name used for this profile. + :param target_name str: The target name used for this profile. 
+ :param user_cfg Optional[dict]: The user-level config block from the + raw profiles, if specified. + :raises DbtProfileError: If the profile is invalid. + :returns Profile: The new Profile object. + """ + if user_cfg is None: + user_cfg = {} + send_anonymous_usage_stats = user_cfg.get( + 'send_anonymous_usage_stats', + DEFAULT_SEND_ANONYMOUS_USAGE_STATS + ) + use_colors = user_cfg.get( + 'use_colors', + DEFAULT_USE_COLORS + ) + profile = cls( + profile_name=profile_name, + target_name=target_name, + send_anonymous_usage_stats=send_anonymous_usage_stats, + use_colors=use_colors, + threads=threads, + credentials=credentials + ) + profile.validate() + return profile + + @classmethod + def from_raw_profile_info(cls, raw_profile, profile_name, user_cfg=None, + target_override=None, threads_override=None): + """Create a profile from its raw profile information. + + (this is an intermediate step, mostly useful for unit testing) + + :param raw_profiles dict: The profile data for a single profile, from + disk as yaml. + :param profile_name str: The profile name used. + :param user_cfg Optional[dict]: The global config for the user, if it + was present. + :param target_override Optional[str]: The target to use, if provided on + the command line. + :param threads_override Optional[str]: The thread count to use, if + provided on the command line. + :raises DbtProfileError: If the profile is invalid or missing, or the + target could not be found + :returns Profile: The new Profile object. + """ + target_name = cls._pick_target( + raw_profile, profile_name, target_override + ) + profile_data = cls._get_profile_data( + raw_profile, profile_name, target_name + ) + rendered_profile = cls._rendered_profile(profile_data) + + # valid connections never include the number of threads, but it's + # stored on a per-connection level in the raw configs + threads = rendered_profile.pop('threads', DEFAULT_THREADS) + if threads_override is not None: + threads = threads_override + + credentials = cls._credentials_from_profile( + rendered_profile, profile_name, target_name + ) + return cls.from_credentials( + credentials=credentials, + profile_name=profile_name, + target_name=target_name, + threads=threads, + user_cfg=user_cfg + ) + + @classmethod + def from_raw_profiles(cls, raw_profiles, profile_name, + target_override=None, threads_override=None): + """ + :param raw_profiles dict: The profile data, from disk as yaml. + :param profile_name str: The profile name to use. + :param target_override Optional[str]: The target to use, if provided on + the command line. + :param threads_override Optional[str]: The thread count to use, if + provided on the command line. + :raises DbtProjectError: If there is no profile name specified in the + project or the command line arguments + :raises DbtProfileError: If the profile is invalid or missing, or the + target could not be found + :returns Profile: The new Profile object. 
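+
+        A rough usage sketch (illustrative only; it mirrors the wiring in
+        from_args() and assumes profiles.yml defines a 'jaffle_shop'
+        profile with a 'dev' target):
+
+            raw_profiles = read_profile(DEFAULT_PROFILES_DIR)
+            profile = Profile.from_raw_profiles(
+                raw_profiles=raw_profiles,
+                profile_name='jaffle_shop',
+                target_override='dev',
+            )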
+ """ + # TODO(jeb): Validate the raw_profiles structure right here + if profile_name not in raw_profiles: + raise DbtProjectError( + "Could not find profile named '{}'".format(profile_name) + ) + raw_profile = raw_profiles[profile_name] + user_cfg = raw_profiles.get('config') + + return cls.from_raw_profile_info( + raw_profile=raw_profile, + profile_name=profile_name, + user_cfg=user_cfg, + target_override=target_override, + threads_override=threads_override, + ) + + @classmethod + def from_args(cls, args, project_profile_name=None): + """Given the raw profiles as read from disk and the name of the desired + profile if specified, return the profile component of the runtime + config. + + :param args argparse.Namespace: The arguments as parsed from the cli. + :param project_profile_name Optional[str]: The profile name, if + specified in a project. + :raises DbtProjectError: If there is no profile name specified in the + project or the command line arguments, or if the specified profile + is not found + :raises DbtProfileError: If the profile is invalid or missing, or the + target could not be found. + :returns Profile: The new Profile object. + """ + threads_override = getattr(args, 'threads', None) + # TODO(jeb): is it even possible for this to not be set? + profiles_dir = getattr(args, 'profiles_dir', DEFAULT_PROFILES_DIR) + target_override = getattr(args, 'target', None) + raw_profiles = read_profile(profiles_dir) + profile_name = cls._pick_profile_name(args.profile, + project_profile_name) + + return cls.from_raw_profiles( + raw_profiles=raw_profiles, + profile_name=profile_name, + target_override=target_override, + threads_override=threads_override + ) + + +def package_config_from_data(packages_data): + if packages_data is None: + packages_data = {'packages': []} + + try: + packages = PackageConfig(**packages_data) + except dbt.exceptions.ValidationException as e: + raise DbtProjectError('Invalid package config: {}'.format(str(e))) + return packages + + +def package_data_from_root(project_root): + package_filepath = dbt.clients.system.resolve_path_from_base( + 'packages.yml', project_root + ) + + if dbt.clients.system.path_exists(package_filepath): + packages_dict = _load_yaml(package_filepath) + else: + packages_dict = None + return packages_dict + + +def package_config_from_root(project_root): + packages_dict = package_data_from_root(project_root) + return package_config_from_data(packages_dict) + + +class RuntimeConfig(Project, Profile): + """The runtime configuration, as constructed from its components. There's a + lot because there is a lot of stuff! 
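+
+    Instances are typically built via from_args(), which stitches the two
+    halves together with from_parts(). A minimal sketch of that path
+    (`args` stands in for a parsed argparse namespace and is not shown):
+
+        project = Project.from_current_directory()
+        profile = Profile.from_args(args, project.profile_name)
+        config = RuntimeConfig.from_parts(
+            project=project,
+            profile=profile,
+            cli_vars={},  # from_args() parses this from the cli 'vars' argument
+        )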
+ """ + def __init__(self, project_name, version, project_root, source_paths, + macro_paths, data_paths, test_paths, analysis_paths, + docs_paths, target_path, clean_targets, log_path, + modules_path, quoting, models, on_run_start, on_run_end, + archive, seeds, profile_name, target_name, + send_anonymous_usage_stats, use_colors, threads, credentials, + packages, cli_vars): + # 'vars' + self.cli_vars = cli_vars + # 'project' + Project.__init__( + self, + project_name=project_name, + version=version, + project_root=project_root, + profile_name=profile_name, + source_paths=source_paths, + macro_paths=macro_paths, + data_paths=data_paths, + test_paths=test_paths, + analysis_paths=analysis_paths, + docs_paths=docs_paths, + target_path=target_path, + clean_targets=clean_targets, + log_path=log_path, + modules_path=modules_path, + quoting=quoting, + models=models, + on_run_start=on_run_start, + on_run_end=on_run_end, + archive=archive, + seeds=seeds, + packages=packages, + ) + # 'profile' + Profile.__init__( + self, + profile_name=profile_name, + target_name=target_name, + send_anonymous_usage_stats=send_anonymous_usage_stats, + use_colors=use_colors, + threads=threads, + credentials=credentials + ) + self.validate() + + @classmethod + def from_parts(cls, project, profile, cli_vars): + """Instantiate a RuntimeConfig from its components. + + :param profile Profile: A parsed dbt Profile. + :param project Project: A parsed dbt Project. + :param cli_vars dict: A dict of vars, as provided from the command + line. + :returns RuntimeConfig: The new configuration. + """ + quoting = deepcopy( + DEFAULT_QUOTING_ADAPTER.get(profile.credentials.type, + DEFAULT_QUOTING_GLOBAL) + ) + quoting.update(project.quoting) + return cls( + project_name=project.project_name, + version=project.version, + project_root=project.project_root, + source_paths=project.source_paths, + macro_paths=project.macro_paths, + data_paths=project.data_paths, + test_paths=project.test_paths, + analysis_paths=project.analysis_paths, + docs_paths=project.docs_paths, + target_path=project.target_path, + clean_targets=project.clean_targets, + log_path=project.log_path, + modules_path=project.modules_path, + quoting=quoting, + models=project.models, + on_run_start=project.on_run_start, + on_run_end=project.on_run_end, + archive=project.archive, + seeds=project.seeds, + packages=project.packages, + profile_name=profile.profile_name, + target_name=profile.target_name, + send_anonymous_usage_stats=profile.send_anonymous_usage_stats, + use_colors=profile.use_colors, + threads=profile.threads, + credentials=profile.credentials, + cli_vars=cli_vars + ) + + def new_project(self, project_root): + """Given a new project root, read in its project dictionary, supply the + existing project's profile info, and create a new project file. + + :param project_root str: A filepath to a dbt project. + :raises DbtProfileError: If the profile is invalid. + :raises DbtProjectError: If project is missing or invalid. + :returns RuntimeConfig: The new configuration. + """ + # copy profile + profile = Profile(**self.to_profile_info()) + profile.validate() + # load the new project and its packages + project = Project.from_project_root(project_root) + + return self.from_parts( + project=project, + profile=profile, + cli_vars=deepcopy(self.cli_vars) + ) + + def serialize(self): + """Serialize the full configuration to a single dictionary. For any + instance that has passed validate() (which happens in __init__), it + matches the Configuration contract. 
+ + :returns dict: The serialized configuration. + """ + result = self.to_project_config(with_packages=True) + result.update(self.to_profile_info(serialize_credentials=True)) + result['cli_vars'] = deepcopy(self.cli_vars) + return result + + def __str__(self): + return pprint.pformat(self.serialize()) + + def validate(self): + """Validate the configuration against its contract. + + :raises DbtProjectError: If the configuration fails validation. + """ + try: + Configuration(**self.serialize()) + except dbt.exceptions.ValidationException as e: + raise DbtProjectError(str(e)) + + @classmethod + def from_args(cls, args): + """Given arguments, read in dbt_project.yml from the current directory, + read in packages.yml if it exists, and use them to find the profile to + load. + + :param args argparse.Namespace: The arguments as parsed from the cli. + :raises DbtProjectError: If the project is invalid or missing. + :raises DbtProfileError: If the profile is invalid or missing. + :raises ValidationException: If the cli variables are invalid. + """ + # build the project and read in packages.yml + project = Project.from_current_directory() + + # build the profile + profile = Profile.from_args(args, project.profile_name) + + cli_vars = dbt.utils.parse_cli_vars(getattr(args, 'vars', '{}')) + return cls.from_parts( + project=project, + profile=profile, + cli_vars=cli_vars + ) + + +def _load_yaml(path): + contents = dbt.clients.system.load_file_contents(path) + return dbt.clients.yaml_helper.load_yaml_text(contents) diff --git a/dbt/context/common.py b/dbt/context/common.py index 2f627272b0f..8187207270e 100644 --- a/dbt/context/common.py +++ b/dbt/context/common.py @@ -30,11 +30,10 @@ class DatabaseWrapper(object): functions. """ - def __init__(self, model, adapter, profile, project): + def __init__(self, model, adapter, config): self.model = model self.adapter = adapter - self.profile = profile - self.project = project + self.config = config self.Relation = adapter.Relation # Fun with metaprogramming @@ -44,12 +43,12 @@ def __init__(self, model, adapter, profile, project): for context_function in self.adapter.context_functions: setattr(self, context_function, - self.wrap(context_function, (self.profile, self.project,))) + self.wrap(context_function, (self.config,))) for profile_function in self.adapter.profile_functions: setattr(self, profile_function, - self.wrap(profile_function, (self.profile,))) + self.wrap(profile_function, (self.config,))) for raw_function in self.adapter.raw_functions: setattr(self, @@ -69,7 +68,7 @@ def type(self): def commit(self): return self.adapter.commit_if_has_connection( - self.profile, self.model.get('name')) + self.config, self.model.get('name')) def _add_macros(context, model, manifest): @@ -140,7 +139,7 @@ def inner(value): {'validation': validation_utils}) -def _env_var(var, default=None): +def env_var(var, default=None): if var in os.environ: return os.environ[var] elif default is not None: @@ -303,9 +302,9 @@ def _return(value): raise dbt.exceptions.MacroReturn(value) -def get_this_relation(db_wrapper, project_cfg, profile, model): +def get_this_relation(db_wrapper, config, model): return db_wrapper.adapter.Relation.create_from_node( - profile, model) + config, model) def create_relation(relation_type, quoting_config): @@ -336,33 +335,34 @@ class AdapterWithContext(adapter_type): return AdapterWithContext -def generate_base(model, model_dict, project_cfg, manifest, source_config, +def generate_base(model, model_dict, config, manifest, source_config, provider): 
"""Generate the common aspects of the config dict.""" if provider is None: raise dbt.exceptions.InternalException( "Invalid provider given to context: {}".format(provider)) - target_name = project_cfg.get('target') - profile = project_cfg.get('outputs').get(target_name) - target = profile.copy() + target_name = config.target_name + target = config.to_profile_info() + del target['credentials'] + target.update(config.credentials.serialize()) + target['type'] = config.credentials.type target.pop('pass', None) target['name'] = target_name - adapter = get_adapter(profile) + adapter = get_adapter(config) context = {'env': target} - schema = profile.get('schema', 'public') + schema = config.credentials.schema pre_hooks = None post_hooks = None relation_type = create_relation(adapter.Relation, - project_cfg.get('quoting')) + config.quoting) db_wrapper = DatabaseWrapper(model_dict, create_adapter(adapter, relation_type), - profile, - project_cfg) + config) context = dbt.utils.merge(context, { "adapter": db_wrapper, "api": { @@ -371,7 +371,7 @@ def generate_base(model, model_dict, project_cfg, manifest, source_config, }, "column": adapter.Column, "config": provider.Config(model_dict, source_config), - "env_var": _env_var, + "env_var": env_var, "exceptions": dbt.exceptions, "execute": provider.execute, "flags": dbt.flags, @@ -385,8 +385,7 @@ def generate_base(model, model_dict, project_cfg, manifest, source_config, }, "post_hooks": post_hooks, "pre_hooks": pre_hooks, - "ref": provider.ref(db_wrapper, model, project_cfg, - profile, manifest), + "ref": provider.ref(db_wrapper, model, config, manifest), "return": _return, "schema": schema, "sql": None, @@ -406,19 +405,19 @@ def generate_base(model, model_dict, project_cfg, manifest, source_config, # https://github.com/fishtown-analytics/dbt/issues/878 if model.resource_type == NodeType.Operation: this = db_wrapper.adapter.Relation.create( - schema=target['schema'], + schema=config.credentials.schema, identifier=model.name ) else: - this = get_this_relation(db_wrapper, project_cfg, profile, model_dict) + this = get_this_relation(db_wrapper, config, model_dict) context["this"] = this return context -def modify_generated_context(context, model, model_dict, project_cfg, +def modify_generated_context(context, model, model_dict, config, manifest): - cli_var_overrides = project_cfg.get('cli_vars', {}) + cli_var_overrides = config.cli_vars context = _add_tracking(context) context = _add_validation(context) @@ -428,7 +427,7 @@ def modify_generated_context(context, model, model_dict, project_cfg, context = _add_macros(context, model, manifest) - context["write"] = write(model_dict, project_cfg.get('target-path'), 'run') + context["write"] = write(model_dict, config.target_path, 'run') context["render"] = render(context, model_dict) context["var"] = Var(model, context=context, overrides=cli_var_overrides) context['context'] = context @@ -436,21 +435,21 @@ def modify_generated_context(context, model, model_dict, project_cfg, return context -def generate_operation_macro(model, project_cfg, manifest, provider): +def generate_operation_macro(model, config, manifest, provider): """This is an ugly hack to support the fact that the `docs generate` operation ends up in here, and macros are not nodes. 
""" model_dict = model.serialize() - context = generate_base(model, model_dict, project_cfg, manifest, + context = generate_base(model, model_dict, config, manifest, None, provider) - return modify_generated_context(context, model, model_dict, project_cfg, + return modify_generated_context(context, model, model_dict, config, manifest) -def generate_model(model, project_cfg, manifest, source_config, provider): +def generate_model(model, config, manifest, source_config, provider): model_dict = model.to_dict() - context = generate_base(model, model_dict, project_cfg, manifest, + context = generate_base(model, model_dict, config, manifest, source_config, provider) # overwrite schema if we have it, and hooks + sql context.update({ @@ -460,11 +459,11 @@ def generate_model(model, project_cfg, manifest, source_config, provider): 'sql': model.get('injected_sql'), }) - return modify_generated_context(context, model, model_dict, project_cfg, + return modify_generated_context(context, model, model_dict, config, manifest) -def generate(model, project_cfg, manifest, source_config=None, provider=None): +def generate(model, config, manifest, source_config=None, provider=None): """ Not meant to be called directly. Call with either: dbt.context.parser.generate @@ -472,7 +471,7 @@ def generate(model, project_cfg, manifest, source_config=None, provider=None): dbt.context.runtime.generate """ if isinstance(model, ParsedMacro): - return generate_operation_macro(model, project_cfg, manifest, provider) + return generate_operation_macro(model, config, manifest, provider) else: - return generate_model(model, project_cfg, manifest, source_config, + return generate_model(model, config, manifest, source_config, provider) diff --git a/dbt/context/parser.py b/dbt/context/parser.py index 1794a22a5dd..2788b04c68b 100644 --- a/dbt/context/parser.py +++ b/dbt/context/parser.py @@ -6,7 +6,7 @@ execute = False -def ref(db_wrapper, model, project_cfg, profile, manifest): +def ref(db_wrapper, model, config, manifest): def ref(*args): if len(args) == 1 or len(args) == 2: @@ -15,7 +15,7 @@ def ref(*args): else: dbt.exceptions.ref_invalid_args(model, args) - return db_wrapper.adapter.Relation.create_from_node(profile, model) + return db_wrapper.adapter.Relation.create_from_node(config, model) return ref @@ -73,6 +73,6 @@ def get(self, name, validator=None, default=None): return '' -def generate(model, project_cfg, manifest, source_config): +def generate(model, runtime_config, manifest, source_config): return dbt.context.common.generate( - model, project_cfg, manifest, source_config, dbt.context.parser) + model, runtime_config, manifest, source_config, dbt.context.parser) diff --git a/dbt/context/runtime.py b/dbt/context/runtime.py index b9861c5af0c..7ae496f40c8 100644 --- a/dbt/context/runtime.py +++ b/dbt/context/runtime.py @@ -11,8 +11,8 @@ execute = True -def ref(db_wrapper, model, project_cfg, profile, manifest): - current_project = project_cfg.get('name') +def ref(db_wrapper, model, config, manifest): + current_project = config.project_name adapter = db_wrapper.adapter def do_ref(*args): @@ -55,7 +55,7 @@ def do_ref(*args): identifier=add_ephemeral_model_prefix( target_model_name)).quote(identifier=False) else: - return adapter.Relation.create_from_node(profile, target_model) + return adapter.Relation.create_from_node(config, target_model) return do_ref @@ -94,6 +94,6 @@ def get(self, name, validator=None, default=None): return to_return -def generate(model, project_cfg, manifest): +def generate(model, runtime_config, 
manifest): return dbt.context.common.generate( - model, project_cfg, manifest, None, dbt.context.runtime) + model, runtime_config, manifest, None, dbt.context.runtime) diff --git a/dbt/contracts/common.py b/dbt/contracts/common.py new file mode 100644 index 00000000000..b42f6306be8 --- /dev/null +++ b/dbt/contracts/common.py @@ -0,0 +1,11 @@ + + +def named_property(name, doc=None): + def get_prop(self): + return self._contents.get(name) + + def set_prop(self, value): + self._contents[name] = value + self.validate() + + return property(get_prop, set_prop, doc=doc) diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index 111450fc407..a3d00852c9c 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -1,4 +1,6 @@ +import dbt.exceptions from dbt.api.object import APIObject +from dbt.contracts.common import named_property from dbt.logger import GLOBAL_LOGGER as logger # noqa POSTGRES_CREDENTIALS_CONTRACT = { @@ -18,16 +20,9 @@ 'type': 'string', }, 'port': { - 'oneOf': [ - { - 'type': 'integer', - 'minimum': 0, - 'maximum': 65535, - }, - { - 'type': 'string' - }, - ], + 'type': 'integer', + 'minimum': 0, + 'maximum': 65535, }, 'schema': { 'type': 'string', @@ -62,16 +57,9 @@ 'type': 'string', }, 'port': { - 'oneOf': [ - { - 'type': 'integer', - 'minimum': 0, - 'maximum': 65535, - }, - { - 'type': 'string' - }, - ], + 'type': 'integer', + 'minimum': 0, + 'maximum': 65535, }, 'schema': { 'type': 'string', @@ -169,9 +157,11 @@ 'transaction_open': { 'type': 'boolean', }, - 'handle': { - 'type': ['null', 'object'], - }, + # we can't serialize this so we can't require it as part of the + # contract. + # 'handle': { + # 'type': ['null', 'object'], + # }, 'credentials': { 'description': ( 'The credentials object here should match the connection type.' @@ -185,26 +175,67 @@ } }, 'required': [ - 'type', 'name', 'state', 'transaction_open', 'handle', 'credentials' + 'type', 'name', 'state', 'transaction_open', 'credentials' ], } -class PostgresCredentials(APIObject): +class Credentials(APIObject): + """Common base class for credentials. 
This is not valid to instantiate""" + SCHEMA = NotImplemented + + @property + def type(self): + raise NotImplementedError( + 'type not implemented for base credentials class' + ) + + +class PostgresCredentials(Credentials): SCHEMA = POSTGRES_CREDENTIALS_CONTRACT + @property + def type(self): + return 'postgres' + + def incorporate(self, **kwargs): + if 'password' in kwargs: + kwargs['pass'] = kwargs.pop('password') + return super(PostgresCredentials, self).incorporate(**kwargs) + + @property + def password(self): + # we can't access this as 'pass' since that's reserved + return self._contents['pass'] + -class RedshiftCredentials(APIObject): +class RedshiftCredentials(PostgresCredentials): SCHEMA = REDSHIFT_CREDENTIALS_CONTRACT + def __init__(self, *args, **kwargs): + kwargs.setdefault('method', 'database') + super(RedshiftCredentials, self).__init__(*args, **kwargs) -class SnowflakeCredentials(APIObject): + @property + def type(self): + return 'redshift' + + +class SnowflakeCredentials(Credentials): SCHEMA = SNOWFLAKE_CREDENTIALS_CONTRACT + @property + def type(self): + return 'snowflake' + -class BigQueryCredentials(APIObject): +class BigQueryCredentials(Credentials): SCHEMA = BIGQUERY_CREDENTIALS_CONTRACT + @property + def type(self): + return 'bigquery' + CREDENTIALS_MAPPING = { 'postgres': PostgresCredentials, @@ -214,11 +245,45 @@ class BigQueryCredentials(APIObject): } +def create_credentials(typename, credentials): + if typename not in CREDENTIALS_MAPPING: + dbt.exceptions.raise_unrecognized_credentials_type( + typename, CREDENTIALS_MAPPING.keys() + ) + cls = CREDENTIALS_MAPPING[typename] + return cls(**credentials) + + class Connection(APIObject): SCHEMA = CONNECTION_CONTRACT - def validate(self): - super(Connection, self).validate() - # make sure our credentials match our adapter type - ContractType = CREDENTIALS_MAPPING.get(self.get('type')) - ContractType(**self.get('credentials')) + def __init__(self, credentials, *args, **kwargs): + # this is a bit clunky but we deserialize and then reserialize for now + if isinstance(credentials, Credentials): + credentials = credentials.serialize() + # we can't serialize handles + self._handle = kwargs.pop('handle') + super(Connection, self).__init__(credentials=credentials, + *args, **kwargs) + # this will validate itself in its own __init__. + self._credentials = create_credentials(self.type, + self._contents['credentials']) + + @property + def credentials(self): + return self._credentials + + @property + def handle(self): + return self._handle + + @handle.setter + def handle(self, value): + self._handle = value + + name = named_property('name', 'The name of this connection') + state = named_property('state', 'The state of the connection') + transaction_open = named_property( + 'transaction_open', + 'True if there is an open transaction, False otherwise.' + ) diff --git a/dbt/contracts/graph/manifest.py b/dbt/contracts/graph/manifest.py index b2afc7617c5..9f99e0b7a8c 100644 --- a/dbt/contracts/graph/manifest.py +++ b/dbt/contracts/graph/manifest.py @@ -160,13 +160,13 @@ class Manifest(APIObject): the current state of the compiler. Macros will always be ParsedMacros and docs will always be ParsedDocumentations. """ - def __init__(self, nodes, macros, docs, generated_at, project=None): + def __init__(self, nodes, macros, docs, generated_at, config=None): """The constructor. nodes and macros are dictionaries mapping unique IDs to ParsedNode/CompiledNode and ParsedMacro objects, respectively. 
docs is a dictionary mapping unique IDs to ParsedDocumentation objects. generated_at is a text timestamp in RFC 3339 format. """ - metadata = self.get_metadata(project) + metadata = self.get_metadata(config) self.nodes = nodes self.macros = macros self.docs = docs @@ -175,13 +175,13 @@ def __init__(self, nodes, macros, docs, generated_at, project=None): super(Manifest, self).__init__() @staticmethod - def get_metadata(project): + def get_metadata(config): project_id = None user_id = None send_anonymous_usage_stats = None - if project is not None: - project_id = project.hashed_name() + if config is not None: + project_id = config.hashed_name() if tracking.active_user is not None: user_id = tracking.active_user.id diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py index 2bba409bb8b..f46c51b2107 100644 --- a/dbt/contracts/project.py +++ b/dbt/contracts/project.py @@ -1,27 +1,326 @@ from dbt.api.object import APIObject from dbt.logger import GLOBAL_LOGGER as logger # noqa +from dbt.utils import deep_merge +from dbt.contracts.connection import POSTGRES_CREDENTIALS_CONTRACT, \ + REDSHIFT_CREDENTIALS_CONTRACT, SNOWFLAKE_CREDENTIALS_CONTRACT, \ + BIGQUERY_CREDENTIALS_CONTRACT + +# TODO: add description fields. +ARCHIVE_TABLE_CONFIG_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'source_table': {'type': 'string'}, + 'target_table': {'type': 'string'}, + 'updated_at': {'type': 'string'}, + 'unique_key': {'type': 'string'}, + }, + 'required': ['source_table', 'target_table', 'updated_at', 'unique_key'], +} + + +ARCHIVE_CONFIG_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'source_schema': {'type': 'string'}, + 'target_schema': {'type': 'string'}, + 'tables': { + 'type': 'array', + 'item': ARCHIVE_TABLE_CONFIG_CONTRACT, + } + }, + 'required': ['source_schema', 'target_schema', 'tables'], +} + PROJECT_CONTRACT = { 'type': 'object', - 'additionalProperties': True, - # TODO: Come back and wire the rest of the project config stuff into this. - 'description': 'The project configuration. This is incomplete.', + 'description': 'The project configuration.', + 'additionalProperties': False, 'properties': { 'name': { 'type': 'string', - } + 'pattern': r'^[^\d\W]\w*\Z', + }, + 'version': { + 'anyOf': [ + { + 'type': 'string', + 'pattern': ( + # this does not support the full semver (does not + # allow a trailing -fooXYZ) and is not restrictive + # enough for full semver, (allows '1.0'). But it's like + # 'semver lite'. + r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(\.(?:0|[1-9]\d*))?$' + ), + }, + { + # the internal global_project/dbt_project.yml is actually + # 1.0. 
Heaven only knows how many users have done the same + 'type': 'number', + }, + ], + }, + 'project-root': { + 'type': 'string', + }, + 'source-paths': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'macro-paths': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'data-paths': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'test-paths': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'analysis-paths': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'docs-paths': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'target-path': { + 'type': 'string', + }, + 'clean-targets': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'profile': { + 'type': ['null', 'string'], + }, + 'log-path': { + 'type': 'string', + }, + 'modules-path': { + 'type': 'string', + }, + 'quoting': { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'identifier': { + 'type': 'boolean', + }, + 'schema': { + 'type': 'boolean', + }, + }, + }, + 'models': { + 'type': 'object', + 'additionalProperties': True, + }, + 'on-run-start': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'on-run-end': { + 'type': 'array', + 'items': {'type': 'string'}, + }, + 'archive': { + 'type': 'array', + 'items': ARCHIVE_CONFIG_CONTRACT, + }, + 'seeds': { + 'type': 'object', + 'additionalProperties': True, + }, + }, + 'required': ['name', 'version'], +} + + +class Project(APIObject): + SCHEMA = PROJECT_CONTRACT + + +LOCAL_PACKAGE_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'local': { + 'type': 'string', + 'description': 'The absolute path to the local package.', + }, + 'required': ['local'], + }, +} + + +GIT_PACKAGE_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'git': { + 'type': 'string', + 'description': ( + 'The URL to the git repository that stores the pacakge' + ), + }, + 'revision': { + 'type': 'string', + 'description': 'The git revision to use, if it is not tip', + }, + }, + 'required': ['git'], +} + + +REGISTRY_PACKAGE_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'package': { + 'type': 'string', + 'description': 'The name of the package', + }, + 'version': { + 'type': 'string', + 'description': 'The version of the package', + }, + }, + 'required': ['package'], +} + + +class Package(APIObject): + SCHEMA = NotImplemented + + +class LocalPackage(Package): + SCHEMA = LOCAL_PACKAGE_CONTRACT + + +class GitPackage(Package): + SCHEMA = GIT_PACKAGE_CONTRACT + + +class RegistryPackage(Package): + SCHEMA = REGISTRY_PACKAGE_CONTRACT + + +PACKAGE_FILE_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'packages': { + 'type': 'array', + 'items': { + 'anyOf': [ + LOCAL_PACKAGE_CONTRACT, + GIT_PACKAGE_CONTRACT, + REGISTRY_PACKAGE_CONTRACT, + ], + }, + }, }, - 'required': ['name'], + 'required': ['packages'], } + +class PackageConfig(APIObject): + SCHEMA = PACKAGE_FILE_CONTRACT + + +PROFILE_INFO_CONTRACT = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'profile_name': { + 'type': 'string', + }, + 'target_name': { + 'type': 'string', + }, + 'send_anonymous_usage_stats': { + 'type': 'boolean', + }, + 'use_colors': { + 'type': 'boolean', + }, + 'threads': { + 'type': 'number', + }, + 'credentials': { + 'anyOf': [ + POSTGRES_CREDENTIALS_CONTRACT, + REDSHIFT_CREDENTIALS_CONTRACT, + SNOWFLAKE_CREDENTIALS_CONTRACT, + BIGQUERY_CREDENTIALS_CONTRACT, + ], + }, + }, + 'required': [ 
+ 'profile_name', 'target_name', 'send_anonymous_usage_stats', + 'use_colors', 'threads', 'credentials' + ], +} + + +class ProfileConfig(APIObject): + SCHEMA = PROFILE_INFO_CONTRACT + + +def _merge_requirements(base, *args): + required = base[:] + for arg in args: + required.extend(arg['required']) + return required + + +CONFIG_CONTRACT = deep_merge( + PROJECT_CONTRACT, + PACKAGE_FILE_CONTRACT, + PROFILE_INFO_CONTRACT, + { + 'properties': { + 'cli_vars': { + 'type': 'object', + 'additionalProperties': True, + }, + # override quoting: both 'identifier' and 'schema' must be + # populated + 'quoting': { + 'required': ['identifier', 'schema'], + }, + }, + 'required': _merge_requirements( + ['cli_vars'], + PROJECT_CONTRACT, + PACKAGE_FILE_CONTRACT, + PROFILE_INFO_CONTRACT + ), + }, +) + + +class Configuration(APIObject): + SCHEMA = CONFIG_CONTRACT + + PROJECTS_LIST_PROJECT = { 'type': 'object', 'additionalProperties': False, 'patternProperties': { - '.*': PROJECT_CONTRACT, + '.*': CONFIG_CONTRACT, }, } class ProjectList(APIObject): SCHEMA = PROJECTS_LIST_PROJECT + + def serialize(self): + return {k: v.serialize() for k, v in self._contents.items()} diff --git a/dbt/contracts/results.py b/dbt/contracts/results.py index d565327955b..25f7cdca3df 100644 --- a/dbt/contracts/results.py +++ b/dbt/contracts/results.py @@ -1,5 +1,6 @@ from dbt.api.object import APIObject from dbt.utils import deep_merge +from dbt.contracts.common import named_property from dbt.contracts.graph.manifest import COMPILE_RESULT_NODE_CONTRACT from dbt.contracts.graph.parsed import PARSED_NODE_CONTRACT from dbt.contracts.graph.compiled import COMPILED_NODE_CONTRACT @@ -38,17 +39,6 @@ } -def named_property(name, doc=None): - def get_prop(self): - return self._contents.get(name) - - def set_prop(self, value): - self._contents[name] = value - self.validate() - - return property(get_prop, set_prop, doc=doc) - - class RunModelResult(APIObject): SCHEMA = RUN_MODEL_RESULT_CONTRACT diff --git a/dbt/exceptions.py b/dbt/exceptions.py index b1b6984390b..7bf3d2566b0 100644 --- a/dbt/exceptions.py +++ b/dbt/exceptions.py @@ -447,3 +447,10 @@ def raise_incorrect_version(path): 'for more information on schema.yml syntax:\n\n' 'https://docs.getdbt.com/v0.11/docs/schemayml-files'.format(path) ) + + +def raise_unrecognized_credentials_type(typename, supported_types): + raise_compiler_error( + 'Unrecognized credentials type "{}" - supported types are ({})' + .format(typename, ', '.join('"{}"'.format(t) for t in supported_types)) + ) diff --git a/dbt/loader.py b/dbt/loader.py index 52df20d17b3..f737b36cd6b 100644 --- a/dbt/loader.py +++ b/dbt/loader.py @@ -12,8 +12,8 @@ class GraphLoader(object): _LOADERS = [] @classmethod - def load_all(cls, project_obj, all_projects): - root_project = project_obj.cfg + def load_all(cls, project_config, all_projects): + root_project = project_config macros = MacroLoader.load_all(root_project, all_projects) macros.update(OperationLoader.load_all(root_project, all_projects)) nodes = {} @@ -24,13 +24,13 @@ def load_all(cls, project_obj, all_projects): tests, patches = SchemaTestLoader.load_all(root_project, all_projects) manifest = Manifest(nodes=nodes, macros=macros, docs=docs, - generated_at=timestring(), project=project_obj) + generated_at=timestring(), config=project_config) manifest.add_nodes(tests) manifest.patch_nodes(patches) manifest = dbt.parser.ParserUtils.process_refs( manifest, - root_project.get('name') + root_project.project_name ) manifest = dbt.parser.ParserUtils.process_docs(manifest, 
root_project) return manifest @@ -68,8 +68,8 @@ def load_project(cls, root_project, all_projects, project, project_name, package_name=project_name, root_project=root_project, all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('macro-paths', []), + root_dir=project.project_root, + relative_dirs=project.macro_paths, resource_type=NodeType.Macro) @@ -93,13 +93,13 @@ def load_all(cls, root_project, all_projects, macros=None): def load_project(cls, root_project, all_projects, project, project_name, macros): return dbt.parser.ModelParser.load_and_parse( - package_name=project_name, - root_project=root_project, - all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('source-paths', []), - resource_type=NodeType.Model, - macros=macros) + package_name=project_name, + root_project=root_project, + all_projects=all_projects, + root_dir=project.project_root, + relative_dirs=project.source_paths, + resource_type=NodeType.Model, + macros=macros) class OperationLoader(ResourceLoader): @@ -111,8 +111,8 @@ def load_project(cls, root_project, all_projects, project, project_name, package_name=project_name, root_project=root_project, all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('macro-paths', []), + root_dir=project.project_root, + relative_dirs=project.macro_paths, resource_type=NodeType.Operation) @@ -125,8 +125,8 @@ def load_project(cls, root_project, all_projects, project, project_name, package_name=project_name, root_project=root_project, all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('analysis-paths', []), + root_dir=project.project_root, + relative_dirs=project.analysis_paths, resource_type=NodeType.Analysis, macros=macros) @@ -161,8 +161,8 @@ def load_project(cls, root_project, all_projects, project, project_name, package_name=project_name, root_project=root_project, all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('source-paths', []), + root_dir=project.project_root, + relative_dirs=project.source_paths, macros=macros) @@ -175,8 +175,8 @@ def load_project(cls, root_project, all_projects, project, project_name, package_name=project_name, root_project=root_project, all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('test-paths', []), + root_dir=project.project_root, + relative_dirs=project.test_paths, resource_type=NodeType.Test, tags=['data'], macros=macros) @@ -218,8 +218,8 @@ def load_project(cls, root_project, all_projects, project, project_name, package_name=project_name, root_project=root_project, all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('data-paths', []), + root_dir=project.project_root, + relative_dirs=project.data_paths, macros=macros) @@ -231,8 +231,8 @@ def load_project(cls, root_project, all_projects, project, project_name, package_name=project_name, root_project=root_project, all_projects=all_projects, - root_dir=project.get('project-root'), - relative_dirs=project.get('docs-paths', [])) + root_dir=project.project_root, + relative_dirs=project.docs_paths) # node loaders GraphLoader.register(ModelLoader) diff --git a/dbt/main.py b/dbt/main.py index 5bb6c91b7cb..d5135b2df3d 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -1,4 +1,5 @@ -from dbt.logger import initialize_logger, GLOBAL_LOGGER as logger +from dbt.logger import initialize_logger, GLOBAL_LOGGER as logger, \ + 
initialized as logger_initialized import argparse import os.path @@ -7,7 +8,6 @@ import dbt.version import dbt.flags as flags -import dbt.project as project import dbt.task.run as run_task import dbt.task.compile as compile_task import dbt.task.debug as debug_task @@ -28,6 +28,7 @@ from dbt.utils import ExitCodes + PROFILES_HELP_MESSAGE = """ For more information on configuring profiles, please consult the dbt docs: @@ -87,7 +88,10 @@ def main(args=None): logger.info("Encountered an error:") logger.info(str(e)) - logger.debug(traceback.format_exc()) + if logger_initialized: + logger.debug(traceback.format_exc()) + else: + logger.error(traceback.format_exc()) exit_code = ExitCodes.UnhandledError sys.exit(exit_code) @@ -138,7 +142,7 @@ def get_nearest_project_dir(): def run_from_args(parsed): task = None - proj = None + cfg = None if parsed.which == 'init': # bypass looking for a project file if we're running `dbt init` @@ -157,39 +161,39 @@ def run_from_args(parsed): if res is None: raise RuntimeError("Could not run dbt") else: - task, proj = res + task, cfg = res log_path = None - if proj is not None: - log_path = proj.get('log-path', 'logs') + if cfg is not None: + log_path = cfg.log_path initialize_logger(parsed.debug, log_path) logger.debug("Tracking: {}".format(dbt.tracking.active_user.state())) - dbt.tracking.track_invocation_start(project=proj, args=parsed) + dbt.tracking.track_invocation_start(config=cfg, args=parsed) - results = run_from_task(task, proj, parsed) + results = run_from_task(task, cfg, parsed) return task, results -def run_from_task(task, proj, parsed_args): +def run_from_task(task, cfg, parsed_args): result = None try: result = task.run() dbt.tracking.track_invocation_end( - project=proj, args=parsed_args, result_type="ok" + config=cfg, args=parsed_args, result_type="ok" ) except (dbt.exceptions.NotImplementedException, dbt.exceptions.FailedToConnectException) as e: logger.info('ERROR: {}'.format(e)) dbt.tracking.track_invocation_end( - project=proj, args=parsed_args, result_type="error" + config=cfg, args=parsed_args, result_type="error" ) except Exception as e: dbt.tracking.track_invocation_end( - project=proj, args=parsed_args, result_type="error" + config=cfg, args=parsed_args, result_type="error" ) raise @@ -198,22 +202,21 @@ def run_from_task(task, proj, parsed_args): def invoke_dbt(parsed): task = None - proj = None + cfg = None try: - proj = project.read_project( - 'dbt_project.yml', - parsed.profiles_dir, - validate=False, - profile_to_load=parsed.profile, - args=parsed - ) - proj.validate() - except project.DbtProjectError as e: + if parsed.which == 'deps': + # deps doesn't need a profile, so don't require one. + cfg = config.Project.from_current_directory() + elif parsed.which != 'debug': + # for debug, we will attempt to load the various configurations as + # part of the task, so just leave cfg=None. 
+ cfg = config.RuntimeConfig.from_args(parsed) + except config.DbtProjectError as e: logger.info("Encountered an error while reading the project:") logger.info(dbt.compat.to_string(e)) - all_profiles = project.read_profiles(parsed.profiles_dir).keys() + all_profiles = config.read_profiles(parsed.profiles_dir).keys() if len(all_profiles) > 0: logger.info("Defined profiles:") @@ -226,46 +229,26 @@ def invoke_dbt(parsed): logger.info(PROFILES_HELP_MESSAGE) dbt.tracking.track_invalid_invocation( - project=proj, + config=cfg, args=parsed, - result_type="invalid_profile") + result_type=e.result_type) return None - except project.DbtProfileError as e: + except config.DbtProfileError as e: logger.info("Encountered an error while reading profiles:") logger.info(" ERROR {}".format(str(e))) dbt.tracking.track_invalid_invocation( - project=proj, + config=cfg, args=parsed, - result_type="invalid_profile") + result_type=e.result_type) return None - if parsed.target is not None: - targets = proj.cfg.get('outputs', {}).keys() - if parsed.target in targets: - proj.cfg['target'] = parsed.target - # make sure we update the target if this is overriden on the cli - proj.compile_and_update_target() - else: - logger.info("Encountered an error while reading the project:") - logger.info(" ERROR Specified target {} is not a valid option " - "for profile {}" - .format(parsed.target, proj.profile_to_load)) - logger.info("Valid targets are: {}".format( - ', '.join(targets))) - dbt.tracking.track_invalid_invocation( - project=proj, - args=parsed, - result_type="invalid_target") - - return None - - flags.NON_DESTRUCTIVE = getattr(proj.args, 'non_destructive', False) + flags.NON_DESTRUCTIVE = getattr(parsed, 'non_destructive', False) - arg_drop_existing = getattr(proj.args, 'drop_existing', False) - arg_full_refresh = getattr(proj.args, 'full_refresh', False) + arg_drop_existing = getattr(parsed, 'drop_existing', False) + arg_full_refresh = getattr(parsed, 'full_refresh', False) if arg_drop_existing: dbt.deprecations.warn('drop-existing') @@ -275,9 +258,9 @@ def invoke_dbt(parsed): logger.debug("running dbt with arguments %s", parsed) - task = parsed.cls(args=parsed, project=proj) + task = parsed.cls(args=parsed, config=cfg) - return task, proj + return task, cfg def parse_args(args): @@ -310,11 +293,11 @@ def parse_args(args): base_subparser.add_argument( '--profiles-dir', - default=project.default_profiles_dir, + default=config.DEFAULT_PROFILES_DIR, type=str, help=""" Which directory to look in for the profiles.yml file. Default = {} - """.format(project.default_profiles_dir) + """.format(config.DEFAULT_PROFILES_DIR) ) base_subparser.add_argument( diff --git a/dbt/model.py b/dbt/model.py index f96676b3ea3..eb4e5a3f491 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -77,7 +77,7 @@ def config(self): active_config = self.load_config_from_active_project() - if self.active_project['name'] == self.own_project['name']: + if self.active_project.project_name == self.own_project.project_name: cfg = self._merge(defaults, active_config, self.in_model_config) else: @@ -133,7 +133,7 @@ def smart_update(self, mutable_config, new_configs): return relevant_configs - def get_project_config(self, project): + def get_project_config(self, runtime_config): # most configs are overwritten by a more specific config, but pre/post # hooks are appended! 
config = {} @@ -143,9 +143,9 @@ def get_project_config(self, project): config[k] = {} if self.node_type == NodeType.Seed: - model_configs = project.get('seeds') + model_configs = runtime_config.seeds else: - model_configs = project.get('models') + model_configs = runtime_config.models if model_configs is None: return config diff --git a/dbt/node_runners.py b/dbt/node_runners.py index 58e337fce9b..01cb12e5ed1 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -43,9 +43,8 @@ def track_model_run(index, num_nodes, run_model_result): class BaseRunner(object): print_header = True - def __init__(self, project, adapter, node, node_index, num_nodes): - self.project = project - self.profile = project.run_environment() + def __init__(self, config, adapter, node, node_index, num_nodes): + self.config = config self.adapter = adapter self.node = node self.node_index = node_index @@ -118,7 +117,7 @@ def safe_run(self, manifest): finally: node_name = self.node.name - self.adapter.release_connection(self.profile, node_name) + self.adapter.release_connection(self.config, node_name) result.execution_time = time.time() - started return result @@ -159,19 +158,19 @@ def get_model_schemas(cls, manifest): return schemas @classmethod - def before_hooks(self, project, adapter, manifest): + def before_hooks(self, config, adapter, manifest): pass @classmethod - def before_run(self, project, adapter, manifest): + def before_run(self, config, adapter, manifest): pass @classmethod - def after_run(self, project, adapter, results, manifest): + def after_run(self, config, adapter, results, manifest): pass @classmethod - def after_hooks(self, project, adapter, results, manifest, elapsed): + def after_hooks(self, config, adapter, results, manifest, elapsed): pass @@ -191,14 +190,14 @@ def execute(self, compiled_node, manifest): return RunModelResult(compiled_node) def compile(self, manifest): - return self._compile_node(self.adapter, self.project, self.node, + return self._compile_node(self.adapter, self.config, self.node, manifest) @classmethod - def _compile_node(cls, adapter, project, node, manifest): - compiler = dbt.compilation.Compiler(project) + def _compile_node(cls, adapter, config, node, manifest): + compiler = dbt.compilation.Compiler(config) node = compiler.compile_node(node, manifest) - node = cls._inject_runtime_config(adapter, project, node) + node = cls._inject_runtime_config(adapter, config, node) if(node.injected_sql is not None and not (dbt.utils.is_type(node, NodeType.Archive))): @@ -207,7 +206,7 @@ def _compile_node(cls, adapter, project, node, manifest): written_path = dbt.writer.write_node( node, - project.get('target-path'), + config.target_path, 'compiled', node.injected_sql) @@ -216,31 +215,30 @@ def _compile_node(cls, adapter, project, node, manifest): return node @classmethod - def _inject_runtime_config(cls, adapter, project, node): + def _inject_runtime_config(cls, adapter, config, node): wrapped_sql = node.wrapped_sql - context = cls._node_context(adapter, project, node) + context = cls._node_context(adapter, config, node) sql = dbt.clients.jinja.get_rendered(wrapped_sql, context) node.wrapped_sql = sql return node @classmethod - def _node_context(cls, adapter, project, node): - profile = project.run_environment() + def _node_context(cls, adapter, config, node): def call_get_columns_in_table(schema_name, table_name): return adapter.get_columns_in_table( - profile, project, schema_name, + config, schema_name, table_name, model_name=node.alias) def call_get_missing_columns(from_schema, 
from_table, to_schema, to_table): return adapter.get_missing_columns( - profile, project, from_schema, from_table, + config, from_schema, from_table, to_schema, to_table, node.alias) def call_already_exists(schema, table): return adapter.already_exists( - profile, project, schema, table, node.alias) + config, schema, table, node.alias) return { "run_started_at": dbt.tracking.active_user.run_started_at, @@ -251,12 +249,11 @@ def call_already_exists(schema, table): } @classmethod - def create_schemas(cls, project, adapter, manifest): - profile = project.run_environment() + def create_schemas(cls, config, adapter, manifest): required_schemas = cls.get_model_schemas(manifest) - existing_schemas = set(adapter.get_existing_schemas(profile, project)) + existing_schemas = set(adapter.get_existing_schemas(config)) for schema in (required_schemas - existing_schemas): - adapter.create_schema(profile, project, schema) + adapter.create_schema(config, schema) class ModelRunner(CompileRunner): @@ -265,8 +262,7 @@ def raise_on_first_error(self): return False @classmethod - def run_hooks(cls, project, adapter, manifest, hook_type): - profile = project.run_environment() + def run_hooks(cls, config, adapter, manifest, hook_type): nodes = manifest.nodes.values() hooks = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation) @@ -283,8 +279,8 @@ def run_hooks(cls, project, adapter, manifest, hook_type): # implement a for-loop over these sql statements in jinja-land. # Also, consider configuring psycopg2 (and other adapters?) to # ensure that a transaction is only created if dbt initiates it. - adapter.clear_transaction(profile, model_name) - compiled = cls._compile_node(adapter, project, hook, manifest) + adapter.clear_transaction(config, model_name) + compiled = cls._compile_node(adapter, config, hook, manifest) statement = compiled.wrapped_sql hook_index = hook.get('index', len(hooks)) @@ -296,40 +292,39 @@ def run_hooks(cls, project, adapter, manifest, hook_type): sql = hook_dict.get('sql', '') if len(sql.strip()) > 0: - adapter.execute(profile, sql, model_name=model_name, + adapter.execute(config, sql, model_name=model_name, auto_begin=False, fetch=False) - adapter.release_connection(profile, model_name) + adapter.release_connection(config, model_name) @classmethod - def safe_run_hooks(cls, project, adapter, manifest, hook_type): + def safe_run_hooks(cls, config, adapter, manifest, hook_type): try: - cls.run_hooks(project, adapter, manifest, hook_type) + cls.run_hooks(config, adapter, manifest, hook_type) except dbt.exceptions.RuntimeException: logger.info("Database error while running {}".format(hook_type)) raise @classmethod - def create_schemas(cls, project, adapter, manifest): - profile = project.run_environment() + def create_schemas(cls, config, adapter, manifest): required_schemas = cls.get_model_schemas(manifest) # Snowflake needs to issue a "use {schema}" query, where schema # is the one defined in the profile. Create this schema if it # does not exist, otherwise subsequent queries will fail. Generally, # dbt expects that this schema will exist anyway. 
- required_schemas.add(adapter.get_default_schema(profile, project)) + required_schemas.add(adapter.get_default_schema(config)) - existing_schemas = set(adapter.get_existing_schemas(profile, project)) + existing_schemas = set(adapter.get_existing_schemas(config)) for schema in (required_schemas - existing_schemas): - adapter.create_schema(profile, project, schema) + adapter.create_schema(config, schema) @classmethod - def before_run(cls, project, adapter, manifest): - cls.safe_run_hooks(project, adapter, manifest, RunHookType.Start) - cls.create_schemas(project, adapter, manifest) + def before_run(cls, config, adapter, manifest): + cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start) + cls.create_schemas(config, adapter, manifest) @classmethod def print_results_line(cls, results, execution_time): @@ -348,11 +343,11 @@ def print_results_line(cls, results, execution_time): .format(stat_line=stat_line, execution=execution)) @classmethod - def after_run(cls, project, adapter, results, manifest): - cls.safe_run_hooks(project, adapter, manifest, RunHookType.End) + def after_run(cls, config, adapter, results, manifest): + cls.safe_run_hooks(config, adapter, manifest, RunHookType.End) @classmethod - def after_hooks(cls, project, adapter, results, manifest, elapsed): + def after_hooks(cls, config, adapter, results, manifest, elapsed): cls.print_results_line(results, elapsed) def describe_node(self): @@ -382,7 +377,7 @@ def after_execute(self, result): def execute(self, model, manifest): context = dbt.context.runtime.generate( - model, self.project.cfg, manifest) + model, self.config, manifest) materialization_macro = manifest.get_materialization_macro( model.get_materialization(), @@ -423,7 +418,7 @@ def print_start_line(self): def execute_test(self, test): res, table = self.adapter.execute_and_fetch( - self.profile, + self.config, test.wrapped_sql, test.name, auto_begin=True) diff --git a/dbt/parser/archives.py b/dbt/parser/archives.py index e73c18edec8..84e89ce6975 100644 --- a/dbt/parser/archives.py +++ b/dbt/parser/archives.py @@ -8,9 +8,9 @@ class ArchiveParser(BaseParser): @classmethod - def parse_archives_from_project(cls, project): + def parse_archives_from_project(cls, config): archives = [] - archive_configs = project.get('archive', []) + archive_configs = config.archive for archive_config in archive_configs: tables = archive_config.get('tables') @@ -19,19 +19,19 @@ def parse_archives_from_project(cls, project): continue for table in tables: - config = table.copy() - config['source_schema'] = archive_config.get('source_schema') - config['target_schema'] = archive_config.get('target_schema') + cfg = table.copy() + cfg['source_schema'] = archive_config.get('source_schema') + cfg['target_schema'] = archive_config.get('target_schema') - fake_path = [config['target_schema'], config['target_table']] + fake_path = [cfg['target_schema'], cfg['target_table']] archives.append({ 'name': table.get('target_table'), - 'root_path': project.get('project-root'), + 'root_path': config.project_root, 'resource_type': NodeType.Archive, 'path': os.path.join('archive', *fake_path), 'original_file_path': 'dbt_project.yml', - 'package_name': project.get('name'), - 'config': config, + 'package_name': config.project_name, + 'config': cfg, 'raw_sql': '{{config(materialized="archive")}} -- noop' }) diff --git a/dbt/parser/base.py b/dbt/parser/base.py index 96a9f7b520b..1ee4bb15178 100644 --- a/dbt/parser/base.py +++ b/dbt/parser/base.py @@ -30,7 +30,7 @@ def get_path(cls, resource_type, package_name, 
resource_name): def get_fqn(cls, path, package_project_config, extra=[]): parts = dbt.utils.split_path(path) name, _ = os.path.splitext(parts[-1]) - fqn = ([package_project_config.get('name')] + + fqn = ([package_project_config.project_name] + parts[:-1] + extra + [name]) @@ -93,8 +93,8 @@ def parse_node(cls, node, node_path, root_project_config, node['config'] = config_dict # Set this temporarily so get_rendered() has access to a schema & alias - profile = dbt.utils.get_profile_from_project(root_project_config) - default_schema = profile.get('schema', 'public') + default_schema = getattr(root_project_config.credentials, 'schema', + 'public') node['schema'] = default_schema default_alias = node.get('name') node['alias'] = default_alias @@ -118,8 +118,8 @@ def parse_node(cls, node, node_path, root_project_config, # Clean up any open conns opened by adapter functions that hit the db db_wrapper = context['adapter'] adapter = db_wrapper.adapter - profile = db_wrapper.profile - adapter.release_connection(profile, parsed_node.name) + runtime_config = db_wrapper.config + adapter.release_connection(runtime_config, parsed_node.name) # Special macro defined in the global project schema_override = config.config.get('schema') diff --git a/dbt/parser/docs.py b/dbt/parser/docs.py index 80a992d9a50..7f1f0766058 100644 --- a/dbt/parser/docs.py +++ b/dbt/parser/docs.py @@ -51,8 +51,7 @@ def parse(cls, all_projects, root_project_config, docfile): e.node = docfile raise - profile = dbt.utils.get_profile_from_project(root_project_config) - schema = profile.get('schema', 'public') + schema = getattr(root_project_config.credentials, 'schema', 'public') for key, item in template.module.__dict__.items(): if type(item) != jinja2.runtime.Macro: diff --git a/dbt/parser/hooks.py b/dbt/parser/hooks.py index 3f8d513bd05..26140161135 100644 --- a/dbt/parser/hooks.py +++ b/dbt/parser/hooks.py @@ -11,8 +11,15 @@ class HookParser(BaseSqlParser): @classmethod - def get_hooks_from_project(cls, project_cfg, hook_type): - hooks = project_cfg.get(hook_type, []) + def get_hooks_from_project(cls, config, hook_type): + if hook_type == RunHookType.Start: + hooks = config.on_run_start + elif hook_type == RunHookType.End: + hooks = config.on_run_end + else: + dbt.exceptions.InternalException( + 'hook_type must be one of "{}" or "{}"' + .format(RunHookType.Start, RunHookType.End)) if type(hooks) not in (list, tuple): hooks = [hooks] diff --git a/dbt/parser/util.py b/dbt/parser/util.py index dbd58867200..85c5a84ec66 100644 --- a/dbt/parser/util.py +++ b/dbt/parser/util.py @@ -3,11 +3,11 @@ import dbt.utils -def docs(node, manifest, project_cfg, column_name=None): +def docs(node, manifest, config, column_name=None): """Return a function that will process `doc()` references in jinja, look them up in the manifest, and return the appropriate block contents. 
""" - current_project = project_cfg.get('name') + current_project = config.project_name def do_docs(*args): if len(args) == 1: diff --git a/dbt/project.py b/dbt/project.py deleted file mode 100644 index ddaf25a43c7..00000000000 --- a/dbt/project.py +++ /dev/null @@ -1,314 +0,0 @@ -import os.path -import pprint -import copy -import hashlib -import re - -import dbt.deprecations -import dbt.contracts.connection -import dbt.clients.yaml_helper -import dbt.clients.jinja -import dbt.compat -import dbt.context.common -import dbt.clients.system -import dbt.ui.printer -import dbt.links - -from dbt.api.object import APIObject -from dbt.utils import deep_merge -from dbt.logger import GLOBAL_LOGGER as logger # noqa - -default_project_cfg = { - 'source-paths': ['models'], - 'macro-paths': ['macros'], - 'data-paths': ['data'], - 'test-paths': ['test'], - 'target-path': 'target', - 'clean-targets': ['target'], - 'outputs': {'default': {}}, - 'target': 'default', - 'models': {}, - 'quoting': {}, - 'profile': None, - 'packages': [], - 'modules-path': 'dbt_modules' -} - -default_profiles = {} - -default_profiles_dir = os.path.join(os.path.expanduser('~'), '.dbt') - -NO_SUPPLIED_PROFILE_ERROR = """\ -dbt cannot run because no profile was specified for this dbt project. -To specify a profile for this project, add a line like the this to -your dbt_project.yml file: - -profile: [profile name] - -Here, [profile name] should be replaced with a profile name -defined in your profiles.yml file. You can find profiles.yml here: - -{profiles_file}/profiles.yml -""".format(profiles_file=default_profiles_dir) - - -class DbtProjectError(Exception): - def __init__(self, message, project): - self.project = project - super(DbtProjectError, self).__init__(message) - - -class DbtProfileError(Exception): - def __init__(self, message, project): - super(DbtProfileError, self).__init__(message) - - -class Project(object): - - def __init__(self, cfg, profiles, profiles_dir, profile_to_load=None, - args=None): - - self.cfg = default_project_cfg.copy() - self.cfg.update(cfg) - # docs paths defaults to the exact value of source-paths - if 'docs-paths' not in self.cfg: - self.cfg['docs-paths'] = self.cfg['source-paths'][:] - self.profiles = default_profiles.copy() - self.profiles.update(profiles) - self.profiles_dir = profiles_dir - self.profile_to_load = profile_to_load - self.args = args - - # load profile from dbt_config.yml if cli arg isn't supplied - if self.profile_to_load is None and self.cfg['profile'] is not None: - self.profile_to_load = self.cfg['profile'] - - if self.profile_to_load is None: - raise DbtProjectError(NO_SUPPLIED_PROFILE_ERROR, self) - - if self.profile_to_load in self.profiles: - self.cfg.update(self.profiles[self.profile_to_load]) - self.compile_and_update_target() - - else: - raise DbtProjectError( - "Could not find profile named '{}'" - .format(self.profile_to_load), self) - - if self.cfg.get('models') is None: - self.cfg['models'] = {} - - if self.cfg.get('quoting') is None: - self.cfg['quoting'] = {} - - if self.cfg['models'].get('vars') is None: - self.cfg['models']['vars'] = {} - - global_vars = dbt.utils.parse_cli_vars(getattr(args, 'vars', '{}')) - self.cfg['cli_vars'] = global_vars - - def __str__(self): - return pprint.pformat({'project': self.cfg, 'profiles': self.profiles}) - - def __repr__(self): - return self.__str__() - - def __getitem__(self, key): - return self.cfg.__getitem__(key) - - def __contains__(self, key): - return self.cfg.__contains__(key) - - def __setitem__(self, key, value): - 
return self.cfg.__setitem__(key, value) - - def get(self, key, default=None): - return self.cfg.get(key, default) - - def handle_deprecations(self): - pass - - def is_valid_package_name(self): - if re.match(r"^[^\d\W]\w*\Z", self['name']): - return True - else: - return False - - def compile_target(self, target_cfg): - ctx = self.base_context() - - compiled = {} - for (key, value) in target_cfg.items(): - is_str = isinstance(value, dbt.compat.basestring) - - if is_str: - compiled_val = dbt.clients.jinja.get_rendered(value, ctx) - else: - compiled_val = value - - compiled[key] = compiled_val - - if self.args and hasattr(self.args, 'threads') and self.args.threads: - compiled['threads'] = self.args.threads - - return compiled - - def compile_and_update_target(self): - target = self.cfg['target'] - run_env = self.run_environment() - self.cfg['outputs'][target].update(run_env) - - def run_environment(self): - target_name = self.cfg['target'] - if target_name in self.cfg['outputs']: - target_cfg = self.cfg['outputs'][target_name] - return self.compile_target(target_cfg) - else: - - outputs = self.cfg.get('outputs', {}).keys() - output_names = [" - {}".format(output) for output in outputs] - - msg = ("The profile '{}' does not have a target named '{}'. The " - "valid target names for this profile are:\n{}".format( - self.profile_to_load, - target_name, - "\n".join(output_names))) - - raise DbtProfileError(msg, self) - - def get_target(self): - ctx = self.context().get('env').copy() - ctx['name'] = self.cfg['target'] - return ctx - - def base_context(self): - return { - 'env_var': dbt.context.common._env_var - } - - def context(self): - target_cfg = self.run_environment() - filtered_target = copy.deepcopy(target_cfg) - filtered_target.pop('pass', None) - - ctx = self.base_context() - ctx.update({ - 'env': filtered_target - }) - - return ctx - - def validate(self): - self.handle_deprecations() - - target_cfg = self.run_environment() - package_name = self.cfg.get('name', None) - package_version = self.cfg.get('version', None) - - if package_name is None or package_version is None: - raise DbtProjectError( - "Project name and version is not provided", self) - - if not self.is_valid_package_name(): - raise DbtProjectError( - ('Package name can only contain letters, numbers, and ' - 'underscores, and must start with a letter.'), self) - - db_type = target_cfg.get('type') - validator = dbt.contracts.connection.CREDENTIALS_MAPPING.get(db_type) - - if validator is None: - valid_types = dbt.contracts.connection.CREDENTIALS_MAPPING.keys() - raise DbtProjectError( - "Invalid db type '{}' should be one of [{}]".format( - db_type, - ", ".join(valid_types)), self) - - # This is python so I guess we'll just make a class here... - # it might be wise to tack an extend classmethod onto APIObject, - # similar to voluptous, to do all the deep merge stuff for us and spit - # out a new class. 
- class CredentialsValidator(APIObject): - SCHEMA = deep_merge( - validator.SCHEMA, - { - 'properties': { - 'type': {'type': 'string'}, - 'threads': {'type': 'integer'}, - }, - 'required': ( - validator.SCHEMA.get('required', []) + - ['type', 'threads'] - ), - } - ) - - try: - CredentialsValidator(**target_cfg) - except dbt.exceptions.ValidationException as e: - raise DbtProjectError(str(e), self) - - def hashed_name(self): - if self.cfg.get("name", None) is None: - return None - - project_name = self['name'] - return hashlib.md5(project_name.encode('utf-8')).hexdigest() - - -def read_profiles(profiles_dir=None): - if profiles_dir is None: - profiles_dir = default_profiles_dir - - raw_profiles = dbt.config.read_profile(profiles_dir) - - if raw_profiles is None: - profiles = {} - else: - profiles = {k: v for (k, v) in raw_profiles.items() if k != 'config'} - - return profiles - - -def read_packages(project_dir): - - package_filepath = dbt.clients.system.resolve_path_from_base( - 'packages.yml', project_dir) - - if dbt.clients.system.path_exists(package_filepath): - package_file_contents = dbt.clients.system.load_file_contents( - package_filepath) - package_cfg = dbt.clients.yaml_helper.load_yaml_text( - package_file_contents) - else: - package_cfg = {} - - return package_cfg.get('packages', []) - - -def read_project(project_filepath, profiles_dir=None, validate=True, - profile_to_load=None, args=None): - if profiles_dir is None: - profiles_dir = default_profiles_dir - - project_dir = os.path.dirname(os.path.abspath(project_filepath)) - project_file_contents = dbt.clients.system.load_file_contents( - project_filepath) - - project_cfg = dbt.clients.yaml_helper.load_yaml_text(project_file_contents) - package_cfg = read_packages(project_dir) - - project_cfg['project-root'] = project_dir - project_cfg['packages'] = package_cfg - - profiles = read_profiles(profiles_dir) - proj = Project(project_cfg, - profiles, - profiles_dir, - profile_to_load=profile_to_load, - args=args) - - if validate: - proj.validate() - - return proj diff --git a/dbt/runner.py b/dbt/runner.py index a1d34ddc796..1949b31caae 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -26,23 +26,13 @@ class RunManager(object): - def __init__(self, project, target_path, args): - self.project = project - self.target_path = target_path - self.args = args - - profile = self.project.run_environment() - - # TODO validate the number of threads - if not getattr(self.args, "threads", None): - self.threads = profile.get('threads', 1) - else: - self.threads = self.args.threads + def __init__(self, config): + self.config = config def deserialize_graph(self): logger.info("Loading dependency graph file.") - base_target_path = self.project['target-path'] + base_target_path = self.project.target_path graph_file = os.path.join( base_target_path, dbt.compilation.graph_file_name @@ -67,10 +57,10 @@ def get_runners(self, Runner, adapter, node_dependency_list): for node in all_nodes: uid = node.get('unique_id') if Runner.is_ephemeral_model(node): - runner = Runner(self.project, adapter, node, 0, 0) + runner = Runner(self.config, adapter, node, 0, 0) else: i += 1 - runner = Runner(self.project, adapter, node, i, num_nodes) + runner = Runner(self.config, adapter, node, i, num_nodes) node_runners[uid] = runner return node_runners @@ -105,11 +95,10 @@ def get_relevant_runners(self, node_runners, node_subset): return runners def execute_nodes(self, linker, Runner, manifest, node_dependency_list): - profile = self.project.run_environment() - adapter = 
get_adapter(profile) + adapter = get_adapter(self.config) - num_threads = self.threads - target_name = self.project.get_target().get('name') + num_threads = self.config.threads + target_name = self.config.target_name text = "Concurrency: {} threads (target='{}')" concurrency_line = text.format(num_threads, target_name) @@ -150,8 +139,7 @@ def execute_nodes(self, linker, Runner, manifest, node_dependency_list): pool.close() pool.terminate() - profile = self.project.run_environment() - adapter = get_adapter(profile) + adapter = get_adapter(self.config) if not adapter.is_cancelable(): msg = ("The {} adapter does not support query " @@ -162,7 +150,7 @@ def execute_nodes(self, linker, Runner, manifest, node_dependency_list): dbt.ui.printer.print_timestamped_line(msg, yellow) raise - for conn_name in adapter.cancel_open_connections(profile): + for conn_name in adapter.cancel_open_connections(self.config): dbt.ui.printer.print_cancel_line(conn_name) dbt.ui.printer.print_run_end_messages(node_results, @@ -177,11 +165,11 @@ def execute_nodes(self, linker, Runner, manifest, node_dependency_list): return node_results def write_results(self, execution_result): - filepath = os.path.join(self.project['target-path'], RESULT_FILE_NAME) + filepath = os.path.join(self.config.target_path, RESULT_FILE_NAME) write_json(filepath, execution_result.serialize()) - def compile(self, project): - compiler = dbt.compilation.Compiler(project) + def compile(self, config): + compiler = dbt.compilation.Compiler(config) compiler.initialize() return compiler.compile() @@ -194,14 +182,13 @@ def run_from_graph(self, Selector, Runner, query): dbt.node_runners.BaseRunner """ - manifest, linker = self.compile(self.project) + manifest, linker = self.compile(self.config) selector = Selector(linker, manifest) selected_nodes = selector.select(query) dep_list = selector.as_node_list(selected_nodes) - profile = self.project.run_environment() - adapter = get_adapter(profile) + adapter = get_adapter(self.config) flat_nodes = dbt.utils.flatten_nodes(dep_list) if len(flat_nodes) == 0: @@ -217,13 +204,13 @@ def run_from_graph(self, Selector, Runner, query): logger.info("") try: - Runner.before_hooks(self.project, adapter, manifest) + Runner.before_hooks(self.config, adapter, manifest) started = time.time() - Runner.before_run(self.project, adapter, manifest) + Runner.before_run(self.config, adapter, manifest) res = self.execute_nodes(linker, Runner, manifest, dep_list) - Runner.after_run(self.project, adapter, res, manifest) + Runner.after_run(self.config, adapter, res, manifest) elapsed = time.time() - started - Runner.after_hooks(self.project, adapter, res, manifest, elapsed) + Runner.after_hooks(self.config, adapter, res, manifest, elapsed) finally: adapter.cleanup_connections() diff --git a/dbt/task/archive.py b/dbt/task/archive.py index ca072955aa7..62a1e8cb0b5 100644 --- a/dbt/task/archive.py +++ b/dbt/task/archive.py @@ -10,11 +10,7 @@ class ArchiveTask(RunnableTask): def run(self): - runner = RunManager( - self.project, - self.project['target-path'], - self.args - ) + runner = RunManager(self.config) query = { 'include': ['*'], diff --git a/dbt/task/base_task.py b/dbt/task/base_task.py index 754415c7508..1b9e8e58177 100644 --- a/dbt/task/base_task.py +++ b/dbt/task/base_task.py @@ -2,9 +2,9 @@ class BaseTask(object): - def __init__(self, args, project=None): + def __init__(self, args, config=None): self.args = args - self.project = project + self.config = config def run(self): raise dbt.exceptions.NotImplementedException('Not 
Implemented') diff --git a/dbt/task/clean.py b/dbt/task/clean.py index 07560ac2d8c..f7b524057b8 100644 --- a/dbt/task/clean.py +++ b/dbt/task/clean.py @@ -15,14 +15,14 @@ def __is_project_path(self, path): def __is_protected_path(self, path): abs_path = os.path.abspath(path) - protected_paths = self.project['source-paths'] + \ - self.project['test-paths'] + ['.'] + protected_paths = self.config.source_paths + \ + self.config.test_paths + ['.'] protected_abs_paths = [os.path.abspath for p in protected_paths] return abs_path in set(protected_abs_paths) or \ self.__is_project_path(abs_path) def run(self): - for path in self.project['clean-targets']: + for path in self.config.clean_targets: if not self.__is_protected_path(path): shutil.rmtree(path, True) diff --git a/dbt/task/compile.py b/dbt/task/compile.py index bd3266dc170..121ca02e39b 100644 --- a/dbt/task/compile.py +++ b/dbt/task/compile.py @@ -11,15 +11,13 @@ class CompileTask(RunnableTask): def run(self): - runner = RunManager( - self.project, self.project['target-path'], self.args - ) + runner = RunManager(self.config) query = { "include": self.args.models, "exclude": self.args.exclude, "resource_types": NodeType.executable(), - "tags": [] + "tags": [], } results = runner.run(query, CompileRunner) diff --git a/dbt/task/debug.py b/dbt/task/debug.py index 936bcdaf4c2..2d521371e9e 100644 --- a/dbt/task/debug.py +++ b/dbt/task/debug.py @@ -2,7 +2,7 @@ from dbt.logger import GLOBAL_LOGGER as logger import dbt.clients.system -import dbt.project +import dbt.config from dbt.task.base_task import BaseTask @@ -14,7 +14,7 @@ class DebugTask(BaseTask): def path_info(self): open_cmd = dbt.clients.system.open_dir_cmd() - profiles_dir = dbt.project.default_profiles_dir + profiles_dir = dbt.config.DEFAULT_PROFILES_DIR message = PROFILE_DIR_MESSAGE.format( open_cmd=open_cmd, @@ -24,9 +24,26 @@ def path_info(self): logger.info(message) def diag(self): + # if we got here, a 'dbt_project.yml' does exist, but we have not tried + # to parse it. + project_profile = None + try: + project = dbt.config.Project.from_current_directory() + project_profile = project.profile_name + except dbt.config.DbtConfigError as exc: + project = 'ERROR loading project: {!s}'.format(exc) + + # log the profile we decided on as well, if it's available. 
+ try: + profile = dbt.config.Profile.from_args(self.args, project_profile) + except dbt.config.DbtConfigError as exc: + profile = 'ERROR loading profile: {!s}'.format(exc) + logger.info("args: {}".format(self.args)) - logger.info("project: ") - pprint.pprint(self.project) + logger.info("") + logger.info("project:\n{!s}".format(project)) + logger.info("") + logger.info("profile:\n{!s}".format(profile)) def run(self): diff --git a/dbt/task/deps.py b/dbt/task/deps.py index 777fd1aab5d..fd07f2b6f58 100644 --- a/dbt/task/deps.py +++ b/dbt/task/deps.py @@ -6,7 +6,6 @@ import yaml import dbt.utils -import dbt.project import dbt.deprecations import dbt.clients.git import dbt.clients.system @@ -16,17 +15,26 @@ from dbt.logger import GLOBAL_LOGGER as logger from dbt.semver import VersionSpecifier, UnboundedVersionSpecifier from dbt.utils import AttrDict +from dbt.api.object import APIObject +from dbt.contracts.project import LOCAL_PACKAGE_CONTRACT, \ + GIT_PACKAGE_CONTRACT, REGISTRY_PACKAGE_CONTRACT from dbt.task.base_task import BaseTask DOWNLOADS_PATH = os.path.join(tempfile.gettempdir(), "dbt-downloads") -class Package(object): - def __init__(self, name): - self.name = name +class Package(APIObject): + SCHEMA = NotImplemented + + def __init__(self, *args, **kwargs): + super(Package, self).__init__(*args, **kwargs) self._cached_metadata = None + @property + def name(self): + raise NotImplementedError + def __str__(self): version = getattr(self, 'version', None) if not version: @@ -77,18 +85,23 @@ def fetch_metadata(self, project): def get_project_name(self, project): metadata = self.fetch_metadata(project) - return metadata["name"] + return metadata.project_name def get_installation_path(self, project): dest_dirname = self.get_project_name(project) - return os.path.join(project['modules-path'], dest_dirname) + return os.path.join(project.modules_path, dest_dirname) class RegistryPackage(Package): - def __init__(self, package, version): - super(RegistryPackage, self).__init__(package) - self.package = package - self._version = self._sanitize_version(version) + SCHEMA = REGISTRY_PACKAGE_CONTRACT + + def __init__(self, *args, **kwargs): + super(RegistryPackage, self).__init__(*args, **kwargs) + self._version = self._sanitize_version(self._contents['version']) + + @property + def name(self): + return self.package @classmethod def _sanitize_version(cls, version): @@ -145,6 +158,8 @@ def _check_version_pinned(self): def _fetch_metadata(self, project): version_string = self.version_name() + # TODO(jeb): this needs to actually return a RuntimeConfig, instead of + # parsed json from a URL return registry.package_version(self.package, version_string) def install(self, project): @@ -157,17 +172,22 @@ def install(self, project): download_url = metadata.get('downloads').get('tarball') dbt.clients.system.download(download_url, tar_path) - deps_path = project['modules-path'] + deps_path = project.modules_path package_name = self.get_project_name(project) dbt.clients.system.untar_package(tar_path, deps_path, package_name) class GitPackage(Package): - def __init__(self, git, version): - super(GitPackage, self).__init__(git) - self.git = git - self._checkout_name = hashlib.md5(six.b(git)).hexdigest() - self._version = self._sanitize_version(version) + SCHEMA = GIT_PACKAGE_CONTRACT + + def __init__(self, *args, **kwargs): + super(GitPackage, self).__init__(*args, **kwargs) + self._checkout_name = hashlib.md5(six.b(self.git)).hexdigest() + self.version = self._contents.get('revision') + + @property + def name(self): 
+ return self.git @classmethod def _sanitize_version(cls, version): @@ -191,7 +211,8 @@ def nice_version_name(self): return "revision {}".format(self.version_name()) def incorporate(self, other): - return GitPackage(self.git, self.version + other.version) + return GitPackage(git=self.git, + revision=(self.version + other.version)) def _resolve_version(self): requested = set(self.version) @@ -216,7 +237,7 @@ def _checkout(self, project): def _fetch_metadata(self, project): path = self._checkout(project) - return dbt.utils.load_project_with_profile(project, path) + return project.from_project_root(path) def install(self, project): dest_path = self.get_installation_path(project) @@ -229,9 +250,11 @@ def install(self, project): class LocalPackage(Package): - def __init__(self, local): - super(LocalPackage, self).__init__(local) - self.local = local + SCHEMA = LOCAL_PACKAGE_CONTRACT + + @property + def name(self): + return self.local def incorporate(self, _): return LocalPackage(self.local) @@ -248,14 +271,14 @@ def nice_version_name(self): def _fetch_metadata(self, project): project_file_path = dbt.clients.system.resolve_path_from_base( self.local, - project['project-root']) + project.project_root) - return dbt.utils.load_project_with_profile(project, project_file_path) + return project.from_project_root(project_file_path) def install(self, project): src_path = dbt.clients.system.resolve_path_from_base( self.local, - project['project-root']) + project.project_root) dest_path = self.get_installation_path(project) @@ -286,15 +309,15 @@ def _parse_package(dict_): 'yours has {} of them - {}' .format(only_1_keys, len(specified), specified)) if dict_.get('package'): - return RegistryPackage(dict_['package'], dict_.get('version')) + return RegistryPackage(**dict_) if dict_.get('git'): if dict_.get('version'): msg = ("Keyword 'version' specified for git package {}.\nDid " "you mean 'revision'?".format(dict_.get('git'))) dbt.exceptions.raise_dependency_error(msg) - return GitPackage(dict_['git'], dict_.get('revision')) + return GitPackage(**dict_) if dict_.get('local'): - return LocalPackage(dict_['local']) + return LocalPackage(**dict_) dbt.exceptions.raise_dependency_error( 'Malformed package definition. Must contain package, git, or local.') @@ -376,7 +399,7 @@ class DepsTask(BaseTask): def _check_for_duplicate_project_names(self, final_deps): seen = set() for _, package in final_deps.items(): - project_name = package.get_project_name(self.project) + project_name = package.get_project_name(self.config) if project_name in seen: dbt.exceptions.raise_dependency_error( 'Found duplicate project {}. 
This occurs when a dependency' @@ -397,10 +420,10 @@ def track_package_install(self, package_name, source_type, version): }) def run(self): - dbt.clients.system.make_directory(self.project['modules-path']) + dbt.clients.system.make_directory(self.config.modules_path) dbt.clients.system.make_directory(DOWNLOADS_PATH) - packages = _read_packages(self.project) + packages = self.config.packages.packages if not packages: logger.info('Warning: No packages were found in packages.yml') return @@ -413,15 +436,15 @@ def run(self): for name, package in pending_deps.items(): final_deps.incorporate(package) final_deps[name].resolve_version() - target_metadata = final_deps[name].fetch_metadata(self.project) - sub_deps.incorporate_from_yaml(_read_packages(target_metadata)) + target_config = final_deps[name].fetch_metadata(self.config) + sub_deps.incorporate_from_yaml(target_config.packages.packages) pending_deps = sub_deps self._check_for_duplicate_project_names(final_deps) for _, package in final_deps.items(): logger.info('Installing %s', package) - package.install(self.project) + package.install(self.config) logger.info(' Installed from %s\n', package.nice_version_name()) self.track_package_install( diff --git a/dbt/task/generate.py b/dbt/task/generate.py index 4591f5fd43e..05808f1016f 100644 --- a/dbt/task/generate.py +++ b/dbt/task/generate.py @@ -187,12 +187,12 @@ def incorporate_catalog_unique_ids(catalog, manifest): class GenerateTask(CompileTask): def _get_manifest(self): - compiler = dbt.compilation.Compiler(self.project) + compiler = dbt.compilation.Compiler(self.config) compiler.initialize() all_projects = compiler.get_all_projects() - manifest = dbt.loader.GraphLoader.load_all(self.project, all_projects) + manifest = dbt.loader.GraphLoader.load_all(self.config, all_projects) return manifest def run(self): @@ -207,14 +207,13 @@ def run(self): shutil.copyfile( DOCS_INDEX_FILE_PATH, - os.path.join(self.project['target-path'], 'index.html')) + os.path.join(self.config.target_path, 'index.html')) manifest = self._get_manifest() - profile = self.project.run_environment() - adapter = get_adapter(profile) + adapter = get_adapter(self.config) dbt.ui.printer.print_timestamped_line("Building catalog") - results = adapter.get_catalog(profile, self.project.cfg, manifest) + results = adapter.get_catalog(self.config, manifest) results = [ dict(zip(results.column_names, row)) @@ -227,7 +226,7 @@ def run(self): 'generated_at': dbt.utils.timestring(), } - path = os.path.join(self.project['target-path'], CATALOG_FILENAME) + path = os.path.join(self.config.target_path, CATALOG_FILENAME) write_json(path, results) dbt.ui.printer.print_timestamped_line( diff --git a/dbt/task/init.py b/dbt/task/init.py index cace35a0294..1146bb9e2d4 100644 --- a/dbt/task/init.py +++ b/dbt/task/init.py @@ -1,6 +1,6 @@ import os -import dbt.project +import dbt.config import dbt.clients.git import dbt.clients.system @@ -91,7 +91,7 @@ def get_addendum(self, project_name, profiles_path): def run(self): project_dir = self.args.project_name - profiles_dir = dbt.project.default_profiles_dir + profiles_dir = dbt.config.DEFAULT_PROFILES_DIR profiles_file = os.path.join(profiles_dir, 'profiles.yml') self.create_profiles_dir(profiles_dir) diff --git a/dbt/task/run.py b/dbt/task/run.py index 03b2272e496..efcedbe03b4 100644 --- a/dbt/task/run.py +++ b/dbt/task/run.py @@ -12,9 +12,7 @@ class RunTask(RunnableTask): def run(self): - runner = RunManager( - self.project, self.project['target-path'], self.args - ) + runner = RunManager(self.config) 
query = { "include": self.args.models, diff --git a/dbt/task/seed.py b/dbt/task/seed.py index 2d58bdbc91b..2463bf5db3e 100644 --- a/dbt/task/seed.py +++ b/dbt/task/seed.py @@ -9,11 +9,7 @@ class SeedTask(RunnableTask): def run(self): - runner = RunManager( - self.project, - self.project["target-path"], - self.args, - ) + runner = RunManager(self.config) query = { "include": ["*"], "exclude": [], diff --git a/dbt/task/serve.py b/dbt/task/serve.py index 99a0dccb2e8..448ca2fcb93 100644 --- a/dbt/task/serve.py +++ b/dbt/task/serve.py @@ -12,7 +12,7 @@ class ServeTask(RunnableTask): def run(self): - os.chdir(self.project['target-path']) + os.chdir(self.config.target_path) port = self.args.port diff --git a/dbt/task/test.py b/dbt/task/test.py index 272a7774fcf..a276db6dd44 100644 --- a/dbt/task/test.py +++ b/dbt/task/test.py @@ -18,8 +18,7 @@ class TestTask(RunnableTask): d) accepted value """ def run(self): - runner = RunManager( - self.project, self.project['target-path'], self.args) + runner = RunManager(self.config) include = self.args.models exclude = self.args.exclude diff --git a/dbt/tracking.py b/dbt/tracking.py index de6bcd5e98c..4d0f3e8dc44 100644 --- a/dbt/tracking.py +++ b/dbt/tracking.py @@ -85,9 +85,9 @@ def get_run_type(args): return 'regular' -def get_invocation_context(user, project, args): +def get_invocation_context(user, config, args): return { - "project_id": None if project is None else project.hashed_name(), + "project_id": None if config is None else config.hashed_name(), "user_id": user.id, "invocation_id": user.invocation_id, @@ -99,8 +99,8 @@ def get_invocation_context(user, project, args): } -def get_invocation_start_context(user, project, args): - data = get_invocation_context(user, project, args) +def get_invocation_start_context(user, config, args): + data = get_invocation_context(user, config, args) start_data = { "progress": "start", @@ -112,8 +112,8 @@ def get_invocation_start_context(user, project, args): return SelfDescribingJson(INVOCATION_SPEC, data) -def get_invocation_end_context(user, project, args, result_type): - data = get_invocation_context(user, project, args) +def get_invocation_end_context(user, config, args, result_type): + data = get_invocation_context(user, config, args) start_data = { "progress": "end", @@ -125,8 +125,8 @@ def get_invocation_end_context(user, project, args, result_type): return SelfDescribingJson(INVOCATION_SPEC, data) -def get_invocation_invalid_context(user, project, args, result_type): - data = get_invocation_context(user, project, args) +def get_invocation_invalid_context(user, config, args, result_type): + data = get_invocation_context(user, config, args) start_data = { "progress": "invalid", @@ -175,9 +175,9 @@ def track(user, *args, **kwargs): ) -def track_invocation_start(project=None, args=None): +def track_invocation_start(config=None, args=None): context = [ - get_invocation_start_context(active_user, project, args), + get_invocation_start_context(active_user, config, args), get_platform_context(), get_dbt_env_context() ] @@ -216,11 +216,11 @@ def track_package_install(options): def track_invocation_end( - project=None, args=None, result_type=None + config=None, args=None, result_type=None ): user = active_user context = [ - get_invocation_end_context(user, project, args, result_type), + get_invocation_end_context(user, config, args, result_type), get_platform_context(), get_dbt_env_context() ] @@ -234,13 +234,13 @@ def track_invocation_end( def track_invalid_invocation( - project=None, args=None, result_type=None + 
config=None, args=None, result_type=None ): user = active_user invocation_context = get_invocation_invalid_context( user, - project, + config, args, result_type ) diff --git a/dbt/utils.py b/dbt/utils.py index c7f225d1251..f2d65d67892 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -168,20 +168,10 @@ def get_docs_macro_name(docs_name, with_prefix=True): return docs_name -def load_project_with_profile(source_project, project_dir): - project_filepath = os.path.join(project_dir, 'dbt_project.yml') - return dbt.project.read_project( - project_filepath, - source_project.profiles_dir, - profile_to_load=source_project.profile_to_load, - args=source_project.args) - - -def dependencies_for_path(project, module_path): +def dependencies_for_path(config, module_path): """Given a module path, yield all dependencies in that path.""" - import dbt.project logger.debug("Loading dependency project from {}".format(module_path)) - + import dbt.config for obj in os.listdir(module_path): full_obj = os.path.join(module_path, obj) @@ -192,8 +182,8 @@ def dependencies_for_path(project, module_path): continue try: - yield load_project_with_profile(project, full_obj) - except dbt.project.DbtProjectError as e: + yield config.new_project(full_obj) + except dbt.config.DbtProjectError as e: logger.info( "Error reading dependency project at {}".format( full_obj) @@ -201,14 +191,14 @@ def dependencies_for_path(project, module_path): logger.info(str(e)) -def dependency_projects(project): +def dependency_projects(config): module_paths = [ GLOBAL_DBT_MODULES_PATH, - os.path.join(project['project-root'], project['modules-path']) + os.path.join(config.project_root, config.modules_path) ] for module_path in module_paths: - for entry in dependencies_for_path(project, module_path): + for entry in dependencies_for_path(config, module_path): yield entry diff --git a/test/integration/001_simple_copy_test/test_simple_copy.py b/test/integration/001_simple_copy_test/test_simple_copy.py index f99ab209efc..6f9930827d3 100644 --- a/test/integration/001_simple_copy_test/test_simple_copy.py +++ b/test/integration/001_simple_copy_test/test_simple_copy.py @@ -2,7 +2,7 @@ from test.integration.base import DBTIntegrationTest, use_profile -class TestSimpleCopy(DBTIntegrationTest): +class BaseTestSimpleCopy(DBTIntegrationTest): @property def schema(self): @@ -16,6 +16,7 @@ def dir(path): def models(self): return self.dir("models") +class TestSimpleCopy(BaseTestSimpleCopy): @use_profile("postgres") def test__postgres__simple_copy(self): self.use_default_project({"data-paths": [self.dir("seed-initial")]}) @@ -64,31 +65,6 @@ def test__snowflake__simple_copy(self): self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"]) - @use_profile("snowflake") - def test__snowflake__simple_copy__quoting_on(self): - self.use_default_project({ - "data-paths": [self.dir("seed-initial")], - "quoting": {"identifier": True}, - }) - - results = self.run_dbt(["seed"]) - self.assertEqual(len(results), 1) - results = self.run_dbt() - self.assertEqual(len(results), 6) - - self.assertManyTablesEqual(["seed", "view_model", "incremental", "materialized"]) - - self.use_default_project({ - "data-paths": [self.dir("seed-update")], - "quoting": {"identifier": True}, - }) - results = self.run_dbt(["seed"]) - self.assertEqual(len(results), 1) - results = self.run_dbt() - self.assertEqual(len(results), 6) - - self.assertManyTablesEqual(["seed", "view_model", "incremental", "materialized"]) - @use_profile("snowflake") def 
test__snowflake__simple_copy__quoting_off(self): self.use_default_project({ @@ -155,24 +131,47 @@ def test__bigquery__simple_copy(self): self.assertTablesEqual("seed","materialized") -class TestSimpleCopyLowercasedSchema(DBTIntegrationTest): +class TestSimpleCopyQuotingIdentifierOn(BaseTestSimpleCopy): @property - def schema(self): - return "simple_copy_001" + def project_config(self): + return { + 'quoting': { + 'identifier': True, + }, + } - @staticmethod - def dir(path): - return "test/integration/001_simple_copy_test/" + path.lstrip("/") + @use_profile("snowflake") + def test__snowflake__simple_copy__quoting_on(self): + self.use_default_project({ + "data-paths": [self.dir("seed-initial")], + }) - @property - def models(self): - return self.dir("models") + results = self.run_dbt(["seed"]) + self.assertEqual(len(results), 1) + results = self.run_dbt() + self.assertEqual(len(results), 6) + self.assertManyTablesEqual(["seed", "view_model", "incremental", "materialized"]) + + self.use_default_project({ + "data-paths": [self.dir("seed-update")], + }) + results = self.run_dbt(["seed"]) + self.assertEqual(len(results), 1) + results = self.run_dbt() + self.assertEqual(len(results), 6) + + self.assertManyTablesEqual(["seed", "view_model", "incremental", "materialized"]) + + +class BaseLowercasedSchemaTest(BaseTestSimpleCopy): def unique_schema(self): # bypass the forced uppercasing that unique_schema() does on snowflake - schema = super(TestSimpleCopyLowercasedSchema, self).unique_schema() + schema = super(BaseLowercasedSchemaTest, self).unique_schema() return schema.lower() + +class TestSnowflakeSimpleLowercasedSchemaCopy(BaseLowercasedSchemaTest): @use_profile('snowflake') def test__snowflake__simple_copy(self): self.use_default_project({"data-paths": [self.dir("seed-initial")]}) @@ -188,11 +187,18 @@ def test__snowflake__simple_copy(self): self.assertManyTablesEqual(["SEED", "VIEW_MODEL", "INCREMENTAL", "MATERIALIZED"]) + +class TestSnowflakeSimpleLowercasedSchemaQuoted(BaseLowercasedSchemaTest): + @property + def project_config(self): + return { + 'quoting': {'identifier': False, 'schema': True} + } + @use_profile("snowflake") def test__snowflake__seed__quoting_switch_schema(self): self.use_default_project({ "data-paths": [self.dir("seed-initial")], - "quoting": {"identifier": False, "schema": True}, }) results = self.run_dbt(["seed"]) diff --git a/test/integration/002_varchar_widening_test/test_varchar_widening.py b/test/integration/002_varchar_widening_test/test_varchar_widening.py index e2b47f03e5e..c276876a909 100644 --- a/test/integration/002_varchar_widening_test/test_varchar_widening.py +++ b/test/integration/002_varchar_widening_test/test_varchar_widening.py @@ -1,11 +1,6 @@ -from nose.plugins.attrib import attr -from test.integration.base import DBTIntegrationTest +from test.integration.base import DBTIntegrationTest, use_profile class TestVarcharWidening(DBTIntegrationTest): - - def setUp(self): - pass - @property def schema(self): return "varchar_widening_002" @@ -14,10 +9,8 @@ def schema(self): def models(self): return "test/integration/002_varchar_widening_test/models" - @attr(type='postgres') + @use_profile('postgres') def test__postgres__varchar_widening(self): - self.use_profile('postgres') - self.use_default_project() self.run_sql_file("test/integration/002_varchar_widening_test/seed.sql") results = self.run_dbt() @@ -34,10 +27,8 @@ def test__postgres__varchar_widening(self): self.assertTablesEqual("seed","incremental") self.assertTablesEqual("seed","materialized") - 
@attr(type='snowflake') + @use_profile('snowflake') def test__snowflake__varchar_widening(self): - self.use_profile('snowflake') - self.use_default_project() self.run_sql_file("test/integration/002_varchar_widening_test/seed.sql") results = self.run_dbt() diff --git a/test/integration/003_simple_reference_test/test_simple_reference.py b/test/integration/003_simple_reference_test/test_simple_reference.py index 810a321ab77..1fb17884ce8 100644 --- a/test/integration/003_simple_reference_test/test_simple_reference.py +++ b/test/integration/003_simple_reference_test/test_simple_reference.py @@ -1,11 +1,6 @@ -from nose.plugins.attrib import attr -from test.integration.base import DBTIntegrationTest +from test.integration.base import DBTIntegrationTest, use_profile class TestSimpleReference(DBTIntegrationTest): - - def setUp(self): - pass - @property def schema(self): return "simple_reference_003" @@ -14,9 +9,8 @@ def schema(self): def models(self): return "test/integration/003_simple_reference_test/models" - @attr(type='postgres') + @use_profile('postgres') def test__postgres__simple_reference(self): - self.use_profile('postgres') self.use_default_project() self.run_sql_file( "test/integration/003_simple_reference_test/seed.sql") @@ -52,9 +46,8 @@ def test__postgres__simple_reference(self): self.assertTablesEqual("summary_expected","view_summary") self.assertTablesEqual("summary_expected","ephemeral_summary") - @attr(type='snowflake') + @use_profile('snowflake') def test__snowflake__simple_reference(self): - self.use_profile('snowflake') self.use_default_project() self.run_sql_file("test/integration/003_simple_reference_test/seed.sql") @@ -78,9 +71,8 @@ def test__snowflake__simple_reference(self): ["SUMMARY_EXPECTED", "INCREMENTAL_SUMMARY", "MATERIALIZED_SUMMARY", "VIEW_SUMMARY", "EPHEMERAL_SUMMARY"] ) - @attr(type='postgres') + @use_profile('postgres') def test__postgres__simple_reference_with_models(self): - self.use_profile('postgres') self.use_default_project() self.run_sql_file("test/integration/003_simple_reference_test/seed.sql") @@ -97,9 +89,8 @@ def test__postgres__simple_reference_with_models(self): created_models = self.get_models_in_schema() self.assertTrue('materialized_copy' in created_models) - @attr(type='postgres') + @use_profile('postgres') def test__postgres__simple_reference_with_models_and_children(self): - self.use_profile('postgres') self.use_default_project() self.run_sql_file("test/integration/003_simple_reference_test/seed.sql") @@ -136,9 +127,8 @@ def test__postgres__simple_reference_with_models_and_children(self): self.assertTrue('ephemeral_summary' in created_models) self.assertEqual(created_models['ephemeral_summary'], 'table') - @attr(type='snowflake') + @use_profile('snowflake') def test__snowflake__simple_reference_with_models(self): - self.use_profile('snowflake') self.use_default_project() self.run_sql_file("test/integration/003_simple_reference_test/seed.sql") @@ -155,9 +145,8 @@ def test__snowflake__simple_reference_with_models(self): created_models = self.get_models_in_schema() self.assertTrue('MATERIALIZED_COPY' in created_models) - @attr(type='snowflake') + @use_profile('snowflake') def test__snowflake__simple_reference_with_models_and_children(self): - self.use_profile('snowflake') self.use_default_project() self.run_sql_file("test/integration/003_simple_reference_test/seed.sql") diff --git a/test/integration/004_simple_archive_test/test_simple_archive.py b/test/integration/004_simple_archive_test/test_simple_archive.py index cb0d5f6368e..970990ecaef 
100644 --- a/test/integration/004_simple_archive_test/test_simple_archive.py +++ b/test/integration/004_simple_archive_test/test_simple_archive.py @@ -162,8 +162,8 @@ def test__bigquery__archive_with_new_field(self): # A more thorough test would assert that archived == expected, but BigQuery does not support the # "EXCEPT DISTINCT" operator on nested fields! Instead, just check that schemas are congruent. - expected_cols = self.adapter.get_columns_in_table(self._profile, self.project_config, self.unique_schema(), 'archive_expected') - archived_cols = self.adapter.get_columns_in_table(self._profile, self.project_config, self.unique_schema(), 'archive_actual') + expected_cols = self.adapter.get_columns_in_table(self.config, self.unique_schema(), 'archive_expected') + archived_cols = self.adapter.get_columns_in_table(self.config, self.unique_schema(), 'archive_actual') self.assertTrue(len(expected_cols) > 0, "source table does not exist -- bad test") self.assertEqual(len(expected_cols), len(archived_cols), "actual and expected column lengths are different") diff --git a/test/integration/005_simple_seed_test/test_seed_type_override.py b/test/integration/005_simple_seed_test/test_seed_type_override.py index 63e678a8b14..ac585de6453 100644 --- a/test/integration/005_simple_seed_test/test_seed_type_override.py +++ b/test/integration/005_simple_seed_test/test_seed_type_override.py @@ -70,6 +70,7 @@ def test_simple_seed_with_column_override_snowflake(self): results = self.run_dbt(["test"]) self.assertEqual(len(results), 2) + class TestSimpleSeedColumnOverrideBQ(TestSimpleSeedColumnOverride): @property def models(self): @@ -86,7 +87,7 @@ def profile_config(self): return self.bigquery_profile() @attr(type='bigquery') - def test_simple_seed_with_column_override_bq(self): + def test_simple_seed_with_column_override_bigquery(self): results = self.run_dbt(["seed"]) self.assertEqual(len(results), 1) results = self.run_dbt(["test"]) diff --git a/test/integration/006_simple_dependency_test/test_simple_dependency.py b/test/integration/006_simple_dependency_test/test_simple_dependency.py index ff84441e962..e20b4b2b4eb 100644 --- a/test/integration/006_simple_dependency_test/test_simple_dependency.py +++ b/test/integration/006_simple_dependency_test/test_simple_dependency.py @@ -16,10 +16,12 @@ def models(self): return "test/integration/006_simple_dependency_test/models" @property - def project_config(self): + def packages_config(self): return { - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project' + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project' + } ] } @@ -77,10 +79,13 @@ def models(self): return "test/integration/006_simple_dependency_test/models" @property - def project_config(self): + def packages_config(self): return { - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@master' + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'master', + }, ] } diff --git a/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py b/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py index eb4ea23cf92..ae96afd7f41 100644 --- a/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py +++ b/test/integration/006_simple_dependency_test/test_simple_dependency_with_configs.py @@ -16,6 +16,16 @@ def models(self): return "test/integration/006_simple_dependency_test/models" class 
TestSimpleDependencyWithConfigs(BaseTestSimpleDependencyWithConfigs): + @property + def packages_config(self): + return { + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'with-configs', + }, + ] + } @property def project_config(self): @@ -28,9 +38,6 @@ def project_config(self): } }, - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@with-configs' - ] } @attr(type='postgres') @@ -47,6 +54,17 @@ def test_simple_dependency(self): class TestSimpleDependencyWithOverriddenConfigs(BaseTestSimpleDependencyWithConfigs): + @property + def packages_config(self): + return { + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'with-configs', + }, + ] + } + @property def project_config(self): return { @@ -62,9 +80,6 @@ def project_config(self): } }, - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@with-configs' - ] } @@ -83,6 +98,17 @@ def test_simple_dependency(self): class TestSimpleDependencyWithModelSpecificOverriddenConfigs(BaseTestSimpleDependencyWithConfigs): + @property + def packages_config(self): + return { + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'with-configs', + }, + ] + } + @property def project_config(self): return { @@ -97,11 +123,7 @@ def project_config(self): } } } - }, - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@with-configs' - ] } @@ -121,6 +143,17 @@ def test_simple_dependency(self): class TestSimpleDependencyWithModelSpecificOverriddenConfigsAndMaterializations(BaseTestSimpleDependencyWithConfigs): + @property + def packages_config(self): + return { + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'with-configs', + }, + ] + } + @property def project_config(self): return { @@ -147,9 +180,6 @@ def project_config(self): } }, - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@with-configs' - ] } diff --git a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py index 4964a4b1f00..af21250936e 100644 --- a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py @@ -2,7 +2,6 @@ from test.integration.base import DBTIntegrationTest, FakeArgs from dbt.task.test import TestTask -from dbt.project import read_project class TestSchemaTestGraphSelection(DBTIntegrationTest): @@ -16,17 +15,16 @@ def models(self): return "test/integration/007_graph_selection_tests/models" @property - def project_config(self): + def packages_config(self): return { - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project' + "packages": [ + {'git': 'https://github.com/fishtown-analytics/dbt-integration-project'} ] } def run_schema_and_assert(self, include, exclude, expected_tests): self.use_profile('postgres') self.use_default_project() - self.project = read_project('dbt_project.yml') self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql") self.run_dbt(["deps"]) @@ -37,7 +35,7 @@ def run_schema_and_assert(self, include, exclude, expected_tests): args.models = include args.exclude = exclude - test_task = TestTask(args, self.project) + test_task = TestTask(args, self.config) test_results 
= test_task.run() ran_tests = sorted([test.node.get('name') for test in test_results]) diff --git a/test/integration/008_schema_tests_test/test_schema_tests.py b/test/integration/008_schema_tests_test/test_schema_tests.py index 57d7a935c91..6ba72ebd418 100644 --- a/test/integration/008_schema_tests_test/test_schema_tests.py +++ b/test/integration/008_schema_tests_test/test_schema_tests.py @@ -2,7 +2,6 @@ from test.integration.base import DBTIntegrationTest, FakeArgs from dbt.task.test import TestTask -from dbt.project import read_project class TestSchemaTests(DBTIntegrationTest): @@ -21,10 +20,9 @@ def models(self): return "test/integration/008_schema_tests_test/models-v1/models" def run_schema_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @attr(type='postgres') @@ -72,10 +70,9 @@ def models(self): return "test/integration/008_schema_tests_test/models-v1/malformed" def run_schema_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @attr(type='postgres') @@ -99,6 +96,16 @@ def setUp(self): def schema(self): return "schema_tests_008" + @property + def packages_config(self): + return { + "packages": [ + {'git': 'https://github.com/fishtown-analytics/dbt-utils'}, + {'git': 'https://github.com/fishtown-analytics/dbt-integration-project'}, + ] + } + + @property def project_config(self): # dbt-utils containts a schema test (equality) @@ -106,10 +113,6 @@ def project_config(self): # both should work! return { "macro-paths": ["test/integration/008_schema_tests_test/macros-v1"], - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-utils', - 'https://github.com/fishtown-analytics/dbt-integration-project' - ] } @property @@ -117,10 +120,9 @@ def models(self): return "test/integration/008_schema_tests_test/models-v1/custom" def run_schema_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @attr(type='postgres') diff --git a/test/integration/008_schema_tests_test/test_schema_v2_tests.py b/test/integration/008_schema_tests_test/test_schema_v2_tests.py index f18c9c349ee..c32908b64e5 100644 --- a/test/integration/008_schema_tests_test/test_schema_v2_tests.py +++ b/test/integration/008_schema_tests_test/test_schema_v2_tests.py @@ -3,9 +3,9 @@ import os from dbt.task.test import TestTask -from dbt.project import read_project from dbt.exceptions import CompilationException + class TestSchemaTests(DBTIntegrationTest): def setUp(self): @@ -22,10 +22,9 @@ def models(self): return "test/integration/008_schema_tests_test/models-v2/models" def run_schema_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @attr(type='postgres') @@ -73,10 +72,9 @@ def models(self): return "test/integration/008_schema_tests_test/models-v2/malformed" def run_schema_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @attr(type='postgres') @@ -107,6 +105,20 @@ def setUp(self): def schema(self): return "schema_tests_008" + @property + def packages_config(self): + return 
{ + 'packages': [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-utils', + 'revision': 'macros-v2', + }, + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + }, + ] + } + @property def project_config(self): # dbt-utils containts a schema test (equality) @@ -114,10 +126,6 @@ def project_config(self): # both should work! return { "macro-paths": ["test/integration/008_schema_tests_test/macros-v2"], - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-utils@macros-v2', - 'https://github.com/fishtown-analytics/dbt-integration-project' - ] } @property @@ -125,10 +133,9 @@ def models(self): return "test/integration/008_schema_tests_test/models-v2/custom" def run_schema_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @attr(type='postgres') @@ -147,7 +154,8 @@ def test_schema_tests(self): self.assertTrue(result.node['name'] in expected_failures) self.assertEqual(sum(x.status for x in test_results), 52) -class TestSchemaTests(DBTIntegrationTest): + +class TestBQSchemaTests(DBTIntegrationTest): @property def schema(self): return "schema_tests_008" @@ -162,14 +170,13 @@ def dir(path): os.path.join('test/integration/008_schema_tests_test/models-v2', path)) def run_schema_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @use_profile('bigquery') - def test_schema_tests(self): + def test_schema_tests_bigquery(self): self.use_default_project({'data-paths': [self.dir('seed')]}) self.assertEqual(len(self.run_dbt(['seed'])), 1) results = self.run_dbt() diff --git a/test/integration/009_data_tests_test/test_data_tests.py b/test/integration/009_data_tests_test/test_data_tests.py index 3f7504b706a..4bfc9b96a98 100644 --- a/test/integration/009_data_tests_test/test_data_tests.py +++ b/test/integration/009_data_tests_test/test_data_tests.py @@ -2,7 +2,6 @@ from test.integration.base import DBTIntegrationTest, FakeArgs from dbt.task.test import TestTask -from dbt.project import read_project import os @@ -25,11 +24,10 @@ def models(self): return "test/integration/009_data_tests_test/models" def run_data_validations(self): - project = read_project('dbt_project.yml') args = FakeArgs() args.data = True - test_task = TestTask(args, project) + test_task = TestTask(args, self.config) return test_task.run() @attr(type='postgres') diff --git a/test/integration/014_hook_tests/test_model_hooks_bq.py b/test/integration/014_hook_tests/test_model_hooks_bq.py index 25577515d51..d9c81b212f5 100644 --- a/test/integration/014_hook_tests/test_model_hooks_bq.py +++ b/test/integration/014_hook_tests/test_model_hooks_bq.py @@ -45,8 +45,6 @@ class TestBigqueryPrePostModelHooks(DBTIntegrationTest): def setUp(self): DBTIntegrationTest.setUp(self) - self.use_profile('bigquery') - self.use_default_project() self.run_sql_file("test/integration/014_hook_tests/seed_model_bigquery.sql") self.fields = [ @@ -109,7 +107,7 @@ def check_hooks(self, state): self.assertTrue(ctx['invocation_id'] is not None and len(ctx['invocation_id']) > 0, 'invocation_id was not set') @attr(type='bigquery') - def test_pre_and_post_model_hooks(self): + def test_pre_and_post_model_hooks_bigquery(self): self.run_dbt(['run']) self.check_hooks('start') @@ -117,11 +115,6 @@ def test_pre_and_post_model_hooks(self): class 
TestBigqueryPrePostModelHooksOnSeeds(DBTIntegrationTest): - def setUp(self): - DBTIntegrationTest.setUp(self) - self.use_profile('bigquery') - self.use_default_project() - @property def schema(self): return "model_hooks_014" @@ -143,7 +136,7 @@ def project_config(self): } @attr(type='bigquery') - def test_hooks_on_seeds(self): + def test_hooks_on_seeds_bigquery(self): res = self.run_dbt(['seed']) self.assertEqual(len(res), 1, 'Expected exactly one item') res = self.run_dbt(['test']) diff --git a/test/integration/014_hook_tests/test_run_hooks_bq.py b/test/integration/014_hook_tests/test_run_hooks_bq.py index c8699f0e97a..5dc05e57faa 100644 --- a/test/integration/014_hook_tests/test_run_hooks_bq.py +++ b/test/integration/014_hook_tests/test_run_hooks_bq.py @@ -89,7 +89,7 @@ def test_bigquery_pre_and_post_run_hooks(self): self.assertTableDoesNotExist("end_hook_order_test") @attr(type='bigquery') - def test_pre_and_post_seed_hooks(self): + def test_bigquery_pre_and_post_seed_hooks(self): self.run_dbt(['seed']) self.check_hooks('start') diff --git a/test/integration/015_cli_invocation_tests/test_cli_invocation.py b/test/integration/015_cli_invocation_tests/test_cli_invocation.py index 86747b66f9e..8a43eba8794 100644 --- a/test/integration/015_cli_invocation_tests/test_cli_invocation.py +++ b/test/integration/015_cli_invocation_tests/test_cli_invocation.py @@ -70,7 +70,7 @@ def custom_profile_config(self): 'schema': self.custom_schema }, }, - 'run-target': 'default' + 'target': 'default', } } diff --git a/test/integration/016_macro_tests/test_macros.py b/test/integration/016_macro_tests/test_macros.py index f6b6e41d7a8..a2164c9c62f 100644 --- a/test/integration/016_macro_tests/test_macros.py +++ b/test/integration/016_macro_tests/test_macros.py @@ -16,6 +16,14 @@ def schema(self): def models(self): return "test/integration/016_macro_tests/models" + @property + def packages_config(self): + return { + 'packages': [ + {'git': 'https://github.com/fishtown-analytics/dbt-integration-project'}, + ] + } + @property def project_config(self): return { @@ -25,9 +33,6 @@ def project_config(self): } }, "macro-paths": ["test/integration/016_macro_tests/macros"], - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project' - ] } @attr(type='postgres') @@ -84,13 +89,18 @@ def schema(self): def models(self): return "test/integration/016_macro_tests/bad-models" + @property + def packages_config(self): + return { + 'packages': [ + {'git': 'https://github.com/fishtown-analytics/dbt-integration-project'} + ] + } + @property def project_config(self): return { "macro-paths": ["test/integration/016_macro_tests/macros"], - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project' - ] } # TODO: compilation no longer exists, so while the model calling this macro diff --git a/test/integration/020_ephemeral_test/test_ephemeral.py b/test/integration/020_ephemeral_test/test_ephemeral.py index 0759264e4a4..4c0aaaac6ca 100644 --- a/test/integration/020_ephemeral_test/test_ephemeral.py +++ b/test/integration/020_ephemeral_test/test_ephemeral.py @@ -3,10 +3,6 @@ class TestEphemeral(DBTIntegrationTest): - - def setUp(self): - pass - @property def schema(self): return "ephemeral_020" @@ -17,8 +13,6 @@ def models(self): @attr(type='postgres') def test__postgres(self): - self.use_profile('postgres') - self.use_default_project() self.run_sql_file("test/integration/020_ephemeral_test/seed.sql") results = self.run_dbt() @@ -30,8 +24,6 @@ def test__postgres(self): @attr(type='snowflake') 
def test__snowflake(self): - self.use_profile('snowflake') - self.use_default_project() self.run_sql_file("test/integration/020_ephemeral_test/seed.sql") results = self.run_dbt() diff --git a/test/integration/021_concurrency_test/test_concurrency.py b/test/integration/021_concurrency_test/test_concurrency.py index dd76d1c7a52..a2e5d497007 100644 --- a/test/integration/021_concurrency_test/test_concurrency.py +++ b/test/integration/021_concurrency_test/test_concurrency.py @@ -3,10 +3,6 @@ class TestConcurrency(DBTIntegrationTest): - - def setUp(self): - pass - @property def schema(self): return "concurrency_021" @@ -17,8 +13,6 @@ def models(self): @attr(type='postgres') def test__postgres__concurrency(self): - self.use_profile('postgres') - self.use_default_project() self.run_sql_file("test/integration/021_concurrency_test/seed.sql") results = self.run_dbt(expect_pass=False) @@ -45,8 +39,6 @@ def test__postgres__concurrency(self): @attr(type='snowflake') def test__snowflake__concurrency(self): - self.use_profile('snowflake') - self.use_default_project() self.run_sql_file("test/integration/021_concurrency_test/seed.sql") results = self.run_dbt(expect_pass=False) diff --git a/test/integration/022_bigquery_test/test_bigquery_adapter_functions.py b/test/integration/022_bigquery_test/test_bigquery_adapter_functions.py index 359ef5b204c..a8e6bd0cc5a 100644 --- a/test/integration/022_bigquery_test/test_bigquery_adapter_functions.py +++ b/test/integration/022_bigquery_test/test_bigquery_adapter_functions.py @@ -18,8 +18,6 @@ def profile_config(self): @attr(type='bigquery') def test__bigquery_adapter_functions(self): - self.use_profile('bigquery') - self.use_default_project() results = self.run_dbt() self.assertEqual(len(results), 3) diff --git a/test/integration/022_bigquery_test/test_bigquery_date_partitioning.py b/test/integration/022_bigquery_test/test_bigquery_date_partitioning.py index b7fde63992c..6a9ba233378 100644 --- a/test/integration/022_bigquery_test/test_bigquery_date_partitioning.py +++ b/test/integration/022_bigquery_test/test_bigquery_date_partitioning.py @@ -18,8 +18,6 @@ def profile_config(self): @attr(type='bigquery') def test__bigquery_date_partitioning(self): - self.use_profile('bigquery') - self.use_default_project() results = self.run_dbt() self.assertEqual(len(results), 6) diff --git a/test/integration/023_exit_codes_test/test_exit_codes.py b/test/integration/023_exit_codes_test/test_exit_codes.py index 5aa96cc52c1..bda5a5dee9f 100644 --- a/test/integration/023_exit_codes_test/test_exit_codes.py +++ b/test/integration/023_exit_codes_test/test_exit_codes.py @@ -35,8 +35,6 @@ def project_config(self): @attr(type='postgres') def test_exit_code_run_succeed(self): - self.use_profile('postgres') - self.use_default_project() results, success = self.run_dbt_and_check(['run', '--model', 'good']) self.assertEqual(len(results), 1) self.assertTrue(success) @@ -44,8 +42,6 @@ def test_exit_code_run_succeed(self): @attr(type='postgres') def test__exit_code_run_fail(self): - self.use_profile('postgres') - self.use_default_project() results, success = self.run_dbt_and_check(['run', '--model', 'bad']) self.assertEqual(len(results), 1) self.assertFalse(success) @@ -53,8 +49,6 @@ def test__exit_code_run_fail(self): @attr(type='postgres') def test___schema_test_pass(self): - self.use_profile('postgres') - self.use_default_project() results, success = self.run_dbt_and_check(['run', '--model', 'good']) self.assertEqual(len(results), 1) self.assertTrue(success) @@ -64,8 +58,6 @@ def 
test___schema_test_pass(self): @attr(type='postgres') def test___schema_test_fail(self): - self.use_profile('postgres') - self.use_default_project() results, success = self.run_dbt_and_check(['run', '--model', 'dupe']) self.assertEqual(len(results), 1) self.assertTrue(success) @@ -75,17 +67,12 @@ def test___schema_test_fail(self): @attr(type='postgres') def test___compile(self): - self.use_profile('postgres') - self.use_default_project() results, success = self.run_dbt_and_check(['compile']) self.assertEqual(len(results), 7) self.assertTrue(success) @attr(type='postgres') def test___archive_pass(self): - self.use_profile('postgres') - self.use_default_project() - self.run_dbt_and_check(['run', '--model', 'good']) results, success = self.run_dbt_and_check(['archive']) self.assertEqual(len(results), 1) @@ -123,9 +110,6 @@ def project_config(self): @attr(type='postgres') def test___archive_fail(self): - self.use_profile('postgres') - self.use_default_project() - results, success = self.run_dbt_and_check(['run', '--model', 'good']) self.assertTrue(success) self.assertEqual(len(results), 1) @@ -146,10 +130,10 @@ def models(self): return "test/integration/023_exit_codes_test/models" @property - def project_config(self): + def packages_config(self): return { - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project' + "packages": [ + {'git': 'https://github.com/fishtown-analytics/dbt-integration-project'} ] } @@ -167,11 +151,15 @@ def schema(self): def models(self): return "test/integration/023_exit_codes_test/models" + @property - def project_config(self): + def packages_config(self): return { - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@bad-branch' + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'bad-branch', + }, ] } diff --git a/test/integration/025_duplicate_model_test/test_duplicate_model.py b/test/integration/025_duplicate_model_test/test_duplicate_model.py index 23cd526b215..1aa52693293 100644 --- a/test/integration/025_duplicate_model_test/test_duplicate_model.py +++ b/test/integration/025_duplicate_model_test/test_duplicate_model.py @@ -6,9 +6,6 @@ class TestDuplicateModelEnabled(DBTIntegrationTest): - def setUp(self): - DBTIntegrationTest.setUp(self) - @property def schema(self): return "duplicate_model_025" @@ -49,9 +46,6 @@ def test_duplicate_model_enabled(self): class TestDuplicateModelDisabled(DBTIntegrationTest): - def setUp(self): - DBTIntegrationTest.setUp(self) - @property def schema(self): return "duplicate_model_025" @@ -96,9 +90,6 @@ def test_duplicate_model_disabled(self): class TestDuplicateModelEnabledAcrossPackages(DBTIntegrationTest): - def setUp(self): - DBTIntegrationTest.setUp(self) - @property def schema(self): return "duplicate_model_025" @@ -108,11 +99,14 @@ def models(self): return "test/integration/025_duplicate_model_test/models-3" @property - def project_config(self): + def packages_config(self): return { - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@master' - ] + "packages": [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'master', + }, + ], } @attr(type="postgres") @@ -141,11 +135,14 @@ def models(self): return "test/integration/025_duplicate_model_test/models-4" @property - def project_config(self): + def packages_config(self): return { - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project@master' - ] + "packages": [ + 
{ + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + 'revision': 'master', + }, + ], } @attr(type="postgres") diff --git a/test/integration/025_timezones_test/test_timezones.py b/test/integration/025_timezones_test/test_timezones.py index 3391d4fe1bf..ed1ce6a9023 100644 --- a/test/integration/025_timezones_test/test_timezones.py +++ b/test/integration/025_timezones_test/test_timezones.py @@ -5,9 +5,6 @@ class TestTimezones(DBTIntegrationTest): - def setUp(self): - DBTIntegrationTest.setUp(self) - @property def schema(self): return "timezones_025" diff --git a/test/integration/026_aliases_test/test_aliases.py b/test/integration/026_aliases_test/test_aliases.py index 39e301e40e6..418b799cc63 100644 --- a/test/integration/026_aliases_test/test_aliases.py +++ b/test/integration/026_aliases_test/test_aliases.py @@ -29,21 +29,18 @@ def project_config(self): @attr(type='postgres') def test__alias_model_name(self): - self.use_profile('postgres') results = self.run_dbt(['run']) self.assertEqual(len(results), 4) self.run_dbt(['test']) @attr(type='bigquery') def test__alias_model_name_bigquery(self): - self.use_profile('bigquery') results = self.run_dbt(['run']) self.assertEqual(len(results), 4) self.run_dbt(['test']) @attr(type='snowflake') def test__alias_model_name_snowflake(self): - self.use_profile('snowflake') results = self.run_dbt(['run']) self.assertEqual(len(results), 4) self.run_dbt(['test']) diff --git a/test/integration/027_cycle_test/test_cycles.py b/test/integration/027_cycle_test/test_cycles.py index d8ab20fbb7f..91fd22705cb 100644 --- a/test/integration/027_cycle_test/test_cycles.py +++ b/test/integration/027_cycle_test/test_cycles.py @@ -5,9 +5,6 @@ class TestSimpleCycle(DBTIntegrationTest): - def setUp(self): - DBTIntegrationTest.setUp(self) - @property def schema(self): return "cycles_simple_025" @@ -25,9 +22,6 @@ def test_simple_cycle(self): class TestComplexCycle(DBTIntegrationTest): - def setUp(self): - DBTIntegrationTest.setUp(self) - @property def schema(self): return "cycles_complex_025" diff --git a/test/integration/028_cli_vars/test_cli_var_override.py b/test/integration/028_cli_vars/test_cli_var_override.py index 77506178dc1..0a2451118d0 100644 --- a/test/integration/028_cli_vars/test_cli_var_override.py +++ b/test/integration/028_cli_vars/test_cli_var_override.py @@ -55,8 +55,6 @@ def project_config(self): @attr(type='postgres') def test__overriden_vars_project_level(self): - self.use_default_project() - self.use_profile('postgres') # This should be "override" self.run_dbt(["run", "--vars", "{required: override}"]) diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index a2a67b0cc80..ca67a3f6ae3 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -45,11 +45,18 @@ def models(self): return self.dir("models") @property - def project_config(self): + def packages_config(self): return { - 'repositories': [ - 'https://github.com/fishtown-analytics/dbt-integration-project' + 'packages': [ + { + 'git': 'https://github.com/fishtown-analytics/dbt-integration-project', + }, ], + } + + @property + def project_config(self): + return { 'quoting': { 'identifier': False } @@ -1516,8 +1523,9 @@ def expected_run_results(self, quote_schema=True, quote_model=False): compiled_seed = self._quote('seed') if quote_model else 'seed' if self.adapter_type == 'bigquery': - compiled_sql = 
'\n\nselect * from `{}`.{}.{}'.format( - self._profile['project'], compiled_schema, compiled_seed + status = 'OK' + compiled_sql = '\n\nselect * from `{}`.`{}`.seed'.format( + self.config.credentials.project, schema ) else: compiled_sql = '\n\nselect * from {}.{}'.format(compiled_schema, diff --git a/test/integration/030_statement_test/models-bq/statement_actual.sql b/test/integration/030_statement_test/models-bq/statement_actual.sql new file mode 100644 index 00000000000..92f9ab1ab95 --- /dev/null +++ b/test/integration/030_statement_test/models-bq/statement_actual.sql @@ -0,0 +1,23 @@ + +-- {{ ref('seed') }} + +{%- call statement('test_statement', fetch_result=True) -%} + + select + count(*) as `num_records` + + from {{ ref('seed') }} + +{%- endcall -%} + +{% set result = load_result('test_statement') %} + +{% set res_table = result['table'] %} +{% set res_matrix = result['data'] %} + +{% set matrix_value = res_matrix[0][0] %} +{% set table_value = res_table[0]['num_records'] %} + +select 'matrix' as source, {{ matrix_value }} as value +union all +select 'table' as source, {{ table_value }} as value diff --git a/test/integration/030_statement_test/test_statements.py b/test/integration/030_statement_test/test_statements.py index cc70fb2414b..6d3c204e070 100644 --- a/test/integration/030_statement_test/test_statements.py +++ b/test/integration/030_statement_test/test_statements.py @@ -1,5 +1,4 @@ -from nose.plugins.attrib import attr -from test.integration.base import DBTIntegrationTest +from test.integration.base import DBTIntegrationTest, use_profile class TestStatements(DBTIntegrationTest): @@ -16,9 +15,8 @@ def dir(path): def models(self): return self.dir("models") - @attr(type="postgres") + @use_profile("postgres") def test_postgres_statements(self): - self.use_profile("postgres") self.use_default_project({"data-paths": [self.dir("seed")]}) results = self.run_dbt(["seed"]) @@ -28,9 +26,8 @@ def test_postgres_statements(self): self.assertTablesEqual("statement_actual","statement_expected") - @attr(type="snowflake") + @use_profile("snowflake") def test_snowflake_statements(self): - self.use_profile("snowflake") self.use_default_project({"data-paths": [self.dir("seed")]}) results = self.run_dbt(["seed"]) @@ -40,9 +37,23 @@ def test_snowflake_statements(self): self.assertManyTablesEqual(["STATEMENT_ACTUAL", "STATEMENT_EXPECTED"]) - @attr(type="bigquery") + +class TestStatementsBigquery(DBTIntegrationTest): + + @property + def schema(self): + return "statements_030" + + @staticmethod + def dir(path): + return "test/integration/030_statement_test/" + path.lstrip("/") + + @property + def models(self): + return self.dir("models-bq") + + @use_profile("bigquery") def test_bigquery_statements(self): - self.use_profile("postgres") self.use_default_project({"data-paths": [self.dir("seed")]}) results = self.run_dbt(["seed"]) diff --git a/test/integration/031_thread_count_test/test_thread_count.py b/test/integration/031_thread_count_test/test_thread_count.py index 208acadd686..2d968eb7bad 100644 --- a/test/integration/031_thread_count_test/test_thread_count.py +++ b/test/integration/031_thread_count_test/test_thread_count.py @@ -1,9 +1,5 @@ -from nose.plugins.attrib import attr -from test.integration.base import DBTIntegrationTest, FakeArgs -from dbt.task.test import TestTask -from dbt.project import read_project -import os +from test.integration.base import DBTIntegrationTest, use_profile class TestThreadCount(DBTIntegrationTest): @@ -26,9 +22,7 @@ def schema(self): def models(self): return 
"test/integration/031_thread_count_test/models" - @attr(type='postgres') + @use_profile('postgres') def test_postgres_threading_8x(self): - self.use_profile('postgres') - results = self.run_dbt(args=['run', '--threads', '16']) self.assertTrue(len(results), 20) diff --git a/test/integration/033_event_tracking_test/test_events.py b/test/integration/033_event_tracking_test/test_events.py index 5acd3c20d27..702dec05bc5 100644 --- a/test/integration/033_event_tracking_test/test_events.py +++ b/test/integration/033_event_tracking_test/test_events.py @@ -50,7 +50,7 @@ def run_event_test( track_fn.reset_mock() project_id = hashlib.md5( - self.project['name'].encode('utf-8')).hexdigest() + self.config.project_name.encode('utf-8')).hexdigest() version = str(dbt.version.get_installed_version()) if expect_raise: @@ -166,14 +166,19 @@ def populate(project_id, user_id, invocation_id, version): class TestEventTrackingSuccess(TestEventTracking): + @property + def packages_config(self): + return { + 'packages': [ + {'git': 'https://github.com/fishtown-analytics/dbt-integration-project'}, + ], + } + @property def project_config(self): return { "data-paths": [self.dir("data")], "test-paths": [self.dir("test")], - "repositories": [ - 'https://github.com/fishtown-analytics/dbt-integration-project' - ] } @attr(type="postgres") diff --git a/test/integration/034_redshift_test/test_late_binding_view.py b/test/integration/034_redshift_test/test_late_binding_view.py index e4742e329c0..a55318443bf 100644 --- a/test/integration/034_redshift_test/test_late_binding_view.py +++ b/test/integration/034_redshift_test/test_late_binding_view.py @@ -20,10 +20,12 @@ def dir(path): def models(self): return self.dir("models") - @use_profile('redshift') - def test__late_binding_view_query(self): - self.use_default_project({"data-paths": [self.dir("seed")]}) + @property + def project_config(self): + return {"data-paths": [self.dir("seed")]} + @use_profile('redshift') + def test__redshift_late_binding_view_query(self): self.assertEqual(len(self.run_dbt(["seed"])), 1) self.assertEqual(len(self.run_dbt()), 1) # remove the table. 
Use 'cascade' here so that if late-binding views diff --git a/test/integration/035_changing_relation_type_test/test_changing_relation_type.py b/test/integration/035_changing_relation_type_test/test_changing_relation_type.py index 0eee9545f09..8e32fc3a747 100644 --- a/test/integration/035_changing_relation_type_test/test_changing_relation_type.py +++ b/test/integration/035_changing_relation_type_test/test_changing_relation_type.py @@ -1,5 +1,4 @@ -from nose.plugins.attrib import attr -from test.integration.base import DBTIntegrationTest +from test.integration.base import DBTIntegrationTest, use_profile class TestChangingRelationType(DBTIntegrationTest): @@ -40,28 +39,24 @@ def swap_types_and_test(self): self.assertEquals(results[0].node['config']['materialized'], 'view') self.assertEqual(len(results), 1) - @attr(type="postgres") + @use_profile("postgres") def test__postgres__switch_materialization(self): - self.use_profile("postgres") self.swap_types_and_test() - @attr(type="snowflake") + @use_profile("snowflake") def test__snowflake__switch_materialization(self): - self.use_profile("snowflake") self.swap_types_and_test() - @attr(type="redshift") + @use_profile("redshift") def test__redshift__switch_materialization(self): - self.use_profile("redshift") self.swap_types_and_test() - @attr(type="bigquery") + @use_profile("bigquery") def test__bigquery__switch_materialization(self): # BQ has a weird check that prevents the dropping of tables in the view materialization # if --full-refresh is not provided. This is to prevent the clobbering of a date-sharded # table with a view if a model config is accidently changed. We should probably remove that check # and then remove these bq-specific tests - self.use_profile("bigquery") results = self.run_dbt(['run', '--vars', 'materialized: view']) self.assertEquals(results[0].node['config']['materialized'], 'view') diff --git a/test/integration/035_docs_blocks/test_docs_blocks.py b/test/integration/035_docs_blocks/test_docs_blocks.py index 4f25e55b966..9b8af115cb2 100644 --- a/test/integration/035_docs_blocks/test_docs_blocks.py +++ b/test/integration/035_docs_blocks/test_docs_blocks.py @@ -1,7 +1,6 @@ import json import os -from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest, use_profile import dbt.exceptions diff --git a/test/integration/base.py b/test/integration/base.py index 594696c1493..aae81e99b83 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -12,7 +12,7 @@ import dbt.flags as flags from dbt.adapters.factory import get_adapter -from dbt.project import Project +from dbt.config import RuntimeConfig from dbt.logger import GLOBAL_LOGGER as logger @@ -34,6 +34,27 @@ def __init__(self): self.exclude = None +class TestArgs(object): + def __init__(self, kwargs): + self.__dict__.update(kwargs) + + +def _profile_from_test_name(test_name): + adapters_in_name = sum(x in test_name for x in + ('postgres', 'snowflake', 'redshift', 'bigquery')) + if adapters_in_name > 1: + raise ValueError('test names must only have 1 profile choice embedded') + + if 'snowflake' in test_name: + return 'snowflake' + elif 'redshift' in test_name: + return 'redshift' + elif 'bigquery' in test_name: + return 'bigquery' + else: + return 'postgres' + + class DBTIntegrationTest(unittest.TestCase): prefix = "test{}{:04}".format(int(time.time()), random.randint(0, 9999)) @@ -148,6 +169,10 @@ def bigquery_profile(self): } } + @property + def packages_config(self): + return None + def unique_schema(self): schema = self.schema @@ 
-167,65 +192,21 @@ def get_profile(self, adapter_type): return self.bigquery_profile() elif adapter_type == 'redshift': return self.redshift_profile() + else: + raise ValueError('invalid adapter type {}'.format(adapter_type)) + + def _pick_profile(self): + test_name = self.id().split('.')[-1] + return _profile_from_test_name(test_name) def setUp(self): flags.reset() - self.adapter_type = 'postgres' - - # create a dbt_project.yml - - base_project_config = { - 'name': 'test', - 'version': '1.0', - 'test-paths': [], - 'source-paths': [self.models], - 'profile': 'test' - } - - project_config = {} - project_config.update(base_project_config) - project_config.update(self.project_config) - - with open("dbt_project.yml", 'w') as f: - yaml.safe_dump(project_config, f, default_flow_style=True) - - # create profiles - - profile_config = {} - default_profile_config = self.postgres_profile() - profile_config.update(default_profile_config) - profile_config.update(self.profile_config) - - if not os.path.exists(DBT_CONFIG_DIR): - os.makedirs(DBT_CONFIG_DIR) - - with open(DBT_PROFILES, 'w') as f: - yaml.safe_dump(profile_config, f, default_flow_style=True) - - target = profile_config.get('test').get('target') - - if target is None: - target = profile_config.get('test').get('run-target') - - profile = profile_config.get('test').get('outputs').get(target) - - project = Project(project_config, profile_config, DBT_CONFIG_DIR) + self._clean_files() - adapter = get_adapter(profile) - - # it's important to use a different connection handle here so - # we don't look into an incomplete transaction - adapter.cleanup_connections() - connection = adapter.acquire_connection(profile, '__test') - self.handle = connection.get('handle') - self.adapter_type = profile.get('type') - self.adapter = adapter - self._profile = profile - self._profile_config = profile_config - self.project = project - - self._drop_schema() - self._create_schema() + self.use_profile(self._pick_profile()) + self.use_default_project() + self.set_packages() + self.load_config() def use_default_project(self, overrides=None): # create a dbt_project.yml @@ -242,9 +223,6 @@ def use_default_project(self, overrides=None): project_config.update(self.project_config) project_config.update(overrides or {}) - project = Project(project_config, self._profile_config, DBT_CONFIG_DIR) - self.project = project - with open("dbt_project.yml", 'w') as f: yaml.safe_dump(project_config, f, default_flow_style=True) @@ -262,35 +240,50 @@ def use_profile(self, adapter_type): with open(DBT_PROFILES, 'w') as f: yaml.safe_dump(profile_config, f, default_flow_style=True) + self._profile_config = profile_config - profile = profile_config.get('test').get('outputs').get('default2') - adapter = get_adapter(profile) - - self.adapter = adapter + def set_packages(self): + if self.packages_config is not None: + with open('packages.yml', 'w') as f: + yaml.safe_dump(self.packages_config, f, default_flow_style=True) + def load_config(self): + # we've written our profile and project. Now we want to instantiate a + # fresh adapter for the tests. 
# it's important to use a different connection handle here so # we don't look into an incomplete transaction - connection = adapter.acquire_connection(profile, '__test') - self.handle = connection.get('handle') - self.adapter_type = profile.get('type') - self._profile_config = profile_config - self._profile = profile + kwargs = { + 'profile': None, + 'profile_dir': DBT_CONFIG_DIR, + 'target': None, + } + + config = RuntimeConfig.from_args(TestArgs(kwargs)) + + adapter = get_adapter(config) + + adapter.cleanup_connections() + connection = adapter.acquire_connection(config, '__test') + self.handle = connection.handle + self.adapter_type = connection.type + self.adapter = adapter + self.config = config self._drop_schema() self._create_schema() def quote_as_configured(self, value, quote_key): - # we need this because some tests explicitly skip setUp - but they are - # all ok with default values here. - project = getattr(self, 'project', {}) return self.adapter.quote_as_configured( - self._profile, project, value, quote_key + self.config, value, quote_key ) - def tearDown(self): - os.remove(DBT_PROFILES) - os.remove("dbt_project.yml") - + def _clean_files(self): + if os.path.exists(DBT_PROFILES): + os.remove(DBT_PROFILES) + if os.path.exists('dbt_project.yml'): + os.remove("dbt_project.yml") + if os.path.exists('packages.yml'): + os.remove('packages.yml') # quick fix for windows bug that prevents us from deleting dbt_modules try: if os.path.exists('dbt_modules'): @@ -298,7 +291,10 @@ def tearDown(self): except: os.rename("dbt_modules", "dbt_modules-{}".format(time.time())) - self.adapter = get_adapter(self._profile) + def tearDown(self): + self._clean_files() + + self.adapter = get_adapter(self.config) self._drop_schema() @@ -310,7 +306,7 @@ def tearDown(self): def _create_schema(self): if self.adapter_type == 'bigquery': - self.adapter.create_schema(self._profile, self.project, self.unique_schema(), '__test') + self.adapter.create_schema(self.config, self.unique_schema(), '__test') else: schema = self.quote_as_configured(self.unique_schema(), 'schema') self.run_sql('CREATE SCHEMA {}'.format(schema)) @@ -318,8 +314,7 @@ def _create_schema(self): def _drop_schema(self): if self.adapter_type == 'bigquery': - self.adapter.drop_schema(self._profile, self.project, - self.unique_schema(), '__test') + self.adapter.drop_schema(self.config, self.unique_schema(), '__test') else: had_existing = False try: @@ -386,9 +381,9 @@ def run_sql_bigquery(self, sql, fetch): """Run an SQL query on a bigquery adapter. No cursors, transactions, etc. 
to worry about""" - adapter = get_adapter(self._profile) + adapter = get_adapter(self.config) do_fetch = fetch != 'None' - _, res = adapter.execute(self._profile, sql, fetch=do_fetch) + _, res = adapter.execute(self.config, sql, fetch=do_fetch) # convert dataframe to matrix-ish repr if fetch == 'one': @@ -451,8 +446,7 @@ def filter_many_columns(self, column): def get_table_columns(self, table, schema=None): schema = self.unique_schema() if schema is None else schema columns = self.adapter.get_columns_in_table( - self._profile, - self.project_config, + self.config, schema, table ) @@ -664,7 +658,8 @@ def outer(wrapped): @attr(type=profile_name) @wraps(wrapped) def func(self, *args, **kwargs): - self.use_profile(profile_name) return wrapped(self, *args, **kwargs) + # sanity check at import time + assert _profile_from_test_name(wrapped.__name__) == profile_name return func return outer diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index 1ca78796141..d1f0f402153 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -1,8 +1,9 @@ import unittest -from mock import patch +from mock import patch, MagicMock import dbt.flags as flags +from dbt.contracts.connection import BigQueryCredentials from dbt.adapters.bigquery import BigQueryAdapter from dbt.adapters.bigquery.relation import BigQueryRelation import dbt.exceptions @@ -15,20 +16,28 @@ class TestBigQueryAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True - self.oauth_profile = { - "type": "bigquery", - "method": "oauth", - "project": 'dbt-unit-000000', - "schema": "dummy_schema", - } - self.service_account_profile = { - "type": "bigquery", - "method": "service-account", - "project": 'dbt-unit-000000', - "schema": "dummy_schema", - "keyfile": "/tmp/dummy-service-account.json", - } + self.oauth_credentials = BigQueryCredentials( + method='oauth', + project='dbt-unit-000000', + schema='dummy_schema' + ) + self.oauth_profile = MagicMock( + credentials=self.oauth_credentials, + threads=1 + ) + + self.service_account_credentials = BigQueryCredentials( + method='service-account', + project='dbt-unit-000000', + schema='dummy_schema', + keyfile='/tmp/dummy-service-account.json' + ) + self.service_account_profile = MagicMock( + credentials=self.service_account_credentials, + threads=1 + ) + @patch('dbt.adapters.bigquery.BigQueryAdapter.open_connection', return_value=fake_conn) def test_acquire_connection_oauth_validations(self, mock_open_connection): diff --git a/test/unit/test_config.py b/test/unit/test_config.py index e15ab695501..4c0d3d3a819 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -1,66 +1,780 @@ +from copy import deepcopy +from contextlib import contextmanager +import json import os +import shutil +import tempfile import unittest + +import mock import yaml import dbt.config +from dbt.contracts.connection import PostgresCredentials, RedshiftCredentials +from dbt.contracts.project import PackageConfig + + +@contextmanager +def temp_cd(path): + current_path = os.getcwd() + os.chdir(path) + try: + yield + finally: + os.chdir(current_path) -if os.name == 'nt': - TMPDIR = 'c:/Windows/TEMP' -else: - TMPDIR = '/tmp' class ConfigTest(unittest.TestCase): + def setUp(self): + self.base_dir = tempfile.mkdtemp() + self.profiles_path = os.path.join(self.base_dir, 'profiles.yml') def set_up_empty_config(self): - profiles_path = '{}/profiles.yml'.format(TMPDIR) - - with open(profiles_path, 'w') as f: + with open(self.profiles_path, 'w') as f: 
f.write(yaml.dump({})) def set_up_config_options(self, **kwargs): - profiles_path = '{}/profiles.yml'.format(TMPDIR) - config = { 'config': kwargs } - with open(profiles_path, 'w') as f: + with open(self.profiles_path, 'w') as f: f.write(yaml.dump(config)) def tearDown(self): - profiles_path = '{}/profiles.yml'.format(TMPDIR) - try: - os.remove(profiles_path) + shutil.rmtree(self.base_dir) except: pass def test__implicit_opt_in(self): self.set_up_empty_config() - config = dbt.config.read_config(TMPDIR) + config = dbt.config.read_config(self.base_dir) self.assertTrue(dbt.config.send_anonymous_usage_stats(config)) def test__explicit_opt_out(self): self.set_up_config_options(send_anonymous_usage_stats=False) - config = dbt.config.read_config(TMPDIR) + config = dbt.config.read_config(self.base_dir) self.assertFalse(dbt.config.send_anonymous_usage_stats(config)) def test__explicit_opt_in(self): self.set_up_config_options(send_anonymous_usage_stats=True) - config = dbt.config.read_config(TMPDIR) + config = dbt.config.read_config(self.base_dir) self.assertTrue(dbt.config.send_anonymous_usage_stats(config)) def test__implicit_colors(self): self.set_up_empty_config() - config = dbt.config.read_config(TMPDIR) + config = dbt.config.read_config(self.base_dir) self.assertTrue(dbt.config.colorize_output(config)) def test__explicit_opt_out(self): self.set_up_config_options(use_colors=False) - config = dbt.config.read_config(TMPDIR) + config = dbt.config.read_config(self.base_dir) self.assertFalse(dbt.config.colorize_output(config)) def test__explicit_opt_in(self): self.set_up_config_options(use_colors=True) - config = dbt.config.read_config(TMPDIR) + config = dbt.config.read_config(self.base_dir) self.assertTrue(dbt.config.colorize_output(config)) + + +class Args(object): + def __init__(self, profiles_dir=None, threads=None, profile=None, cli_vars=None): + self.profile = profile + if threads is not None: + self.threads = threads + if profiles_dir is not None: + self.profiles_dir = profiles_dir + if cli_vars is not None: + self.vars = cli_vars + + +class BaseConfigTest(unittest.TestCase): + """Subclass this, and before calling the superclass setUp, set + profiles_dir. 
+ """ + def setUp(self): + self.default_project_data = { + 'version': '0.0.1', + 'name': 'my_test_project', + 'profile': 'default', + } + self.default_profile_data = { + 'default': { + 'outputs': { + 'postgres': { + 'type': 'postgres', + 'host': 'postgres-db-hostname', + 'port': 5555, + 'user': 'db_user', + 'pass': 'db_pass', + 'dbname': 'postgres-db-name', + 'schema': 'postgres-schema', + 'threads': 7, + }, + 'redshift': { + 'type': 'redshift', + 'host': 'redshift-db-hostname', + 'port': 5555, + 'user': 'db_user', + 'pass': 'db_pass', + 'dbname': 'redshift-db-name', + 'schema': 'redshift-schema', + }, + 'with-vars': { + 'type': "{{ env_var('env_value_type') }}", + 'host': "{{ env_var('env_value_host') }}", + 'port': "{{ env_var('env_value_port') }}", + 'user': "{{ env_var('env_value_user') }}", + 'pass': "{{ env_var('env_value_pass') }}", + 'dbname': "{{ env_var('env_value_dbname') }}", + 'schema': "{{ env_var('env_value_schema') }}", + } + }, + 'target': 'postgres', + }, + 'other': { + 'outputs': { + 'other-postgres': { + 'type': 'postgres', + 'host': 'other-postgres-db-hostname', + 'port': 4444, + 'user': 'other_db_user', + 'pass': 'other_db_pass', + 'dbname': 'other-postgres-db-name', + 'schema': 'other-postgres-schema', + 'threads': 2, + } + }, + 'target': 'other-postgres', + } + } + self.args = Args(profiles_dir=self.profiles_dir, cli_vars='{}') + self.env_override = { + 'env_value_type': 'postgres', + 'env_value_host': 'env-postgres-host', + 'env_value_port': '6543', + 'env_value_user': 'env-postgres-user', + 'env_value_pass': 'env-postgres-pass', + 'env_value_dbname': 'env-postgres-dbname', + 'env_value_schema': 'env-postgres-schema', + } + + +class BaseFileTest(BaseConfigTest): + def setUp(self): + self.project_dir = os.path.normpath(tempfile.mkdtemp()) + self.profiles_dir = os.path.normpath(tempfile.mkdtemp()) + super(BaseFileTest, self).setUp() + + def tearDown(self): + try: + shutil.rmtree(self.project_dir) + except EnvironmentError: + pass + try: + shutil.rmtree(self.profiles_dir) + except EnvironmentError: + pass + + def proejct_path(self, name): + return os.path.join(self.project_dir, name) + + def profile_path(self, name): + return os.path.join(self.profiles_dir, name) + + def write_project(self, project_data=None): + if project_data is None: + project_data = self.project_data + with open(self.proejct_path('dbt_project.yml'), 'w') as fp: + yaml.dump(project_data, fp) + + def write_packages(self, package_data): + with open(self.proejct_path('packages.yml'), 'w') as fp: + yaml.dump(package_data, fp) + + def write_profile(self, profile_data=None): + if profile_data is None: + profile_data = self.profile_data + with open(self.profile_path('profiles.yml'), 'w') as fp: + yaml.dump(profile_data, fp) + + +class TestProfile(BaseConfigTest): + def setUp(self): + self.profiles_dir = '/invalid-path' + super(TestProfile, self).setUp() + + def test_from_raw_profiles(self): + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'postgres') + self.assertEqual(profile.threads, 7) + self.assertTrue(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + self.assertTrue(isinstance(profile.credentials, PostgresCredentials)) + self.assertEqual(profile.credentials.type, 'postgres') + self.assertEqual(profile.credentials.host, 'postgres-db-hostname') + self.assertEqual(profile.credentials.port, 5555) + self.assertEqual(profile.credentials.user, 
'db_user') + self.assertEqual(profile.credentials.password, 'db_pass') + self.assertEqual(profile.credentials.schema, 'postgres-schema') + self.assertEqual(profile.credentials.dbname, 'postgres-db-name') + + def test_config_override(self): + self.default_profile_data['config'] = { + 'send_anonymous_usage_stats': False, + 'use_colors': False + } + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'postgres') + self.assertFalse(profile.send_anonymous_usage_stats) + self.assertFalse(profile.use_colors) + + def test_partial_config_override(self): + self.default_profile_data['config'] = { + 'send_anonymous_usage_stats': False, + } + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'postgres') + self.assertFalse(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + + def test_missing_type(self): + del self.default_profile_data['default']['outputs']['postgres']['type'] + with self.assertRaises(dbt.config.DbtProfileError) as exc: + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertIn('type', str(exc.exception)) + self.assertIn('postgres', str(exc.exception)) + self.assertIn('default', str(exc.exception)) + + def test_bad_type(self): + self.default_profile_data['default']['outputs']['postgres']['type'] = 'invalid' + with self.assertRaises(dbt.config.DbtProfileError) as exc: + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertIn('Credentials', str(exc.exception)) + self.assertIn('postgres', str(exc.exception)) + self.assertIn('default', str(exc.exception)) + + def test_invalid_credentials(self): + del self.default_profile_data['default']['outputs']['postgres']['host'] + with self.assertRaises(dbt.config.DbtProfileError) as exc: + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertIn('Credentials', str(exc.exception)) + self.assertIn('postgres', str(exc.exception)) + self.assertIn('default', str(exc.exception)) + + def test_target_missing(self): + del self.default_profile_data['default']['target'] + with self.assertRaises(dbt.config.DbtProfileError) as exc: + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertIn('target not specified in profile', str(exc.exception)) + self.assertIn('default', str(exc.exception)) + + def test_profile_invalid_project(self): + with self.assertRaises(dbt.config.DbtProjectError) as exc: + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'invalid-profile' + ) + + self.assertEqual(exc.exception.result_type, 'invalid_project') + self.assertIn('Could not find', str(exc.exception)) + self.assertIn('invalid-profile', str(exc.exception)) + + def test_profile_invalid_target(self): + with self.assertRaises(dbt.config.DbtProfileError) as exc: + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default', target_override='nope', + ) + + self.assertIn('nope', str(exc.exception)) + self.assertIn('- postgres', str(exc.exception)) + self.assertIn('- redshift', str(exc.exception)) + self.assertIn('- with-vars', str(exc.exception)) + + def test_no_outputs(self): + with self.assertRaises(dbt.config.DbtProfileError) as exc: + profile = 
dbt.config.Profile.from_raw_profiles( + {'some-profile': {'target': 'blah'}}, 'some-profile' + ) + self.assertIn('outputs not specified', str(exc.exception)) + self.assertIn('some-profile', str(exc.exception)) + + def test_neq(self): + profile = dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default' + ) + self.assertNotEqual(profile, object()) + + def test_eq(self): + profile = dbt.config.Profile.from_raw_profiles( + deepcopy(self.default_profile_data), 'default' + ) + + other = dbt.config.Profile.from_raw_profiles( + deepcopy(self.default_profile_data), 'default' + ) + self.assertEqual(profile, other) + + def test_invalid_env_vars(self): + self.env_override['env_value_port'] = 'hello' + with mock.patch.dict(os.environ, self.env_override): + with self.assertRaises(dbt.config.DbtProfileError) as exc: + dbt.config.Profile.from_raw_profile_info( + self.default_profile_data['default'], + 'default', + target_override='with-vars' + ) + self.assertIn("not of type 'integer'", str(exc.exception)) + + +class TestProfileFile(BaseFileTest): + def setUp(self): + super(TestProfileFile, self).setUp() + self.write_profile(self.default_profile_data) + + def test_profile_simple(self): + profile = dbt.config.Profile.from_args(self.args, 'default') + from_raw = dbt.config.Profile.from_raw_profile_info( + self.default_profile_data['default'], + 'default' + ) + + self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'postgres') + self.assertEqual(profile.threads, 7) + self.assertTrue(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + self.assertTrue(isinstance(profile.credentials, PostgresCredentials)) + self.assertEqual(profile.credentials.type, 'postgres') + self.assertEqual(profile.credentials.host, 'postgres-db-hostname') + self.assertEqual(profile.credentials.port, 5555) + self.assertEqual(profile.credentials.user, 'db_user') + self.assertEqual(profile.credentials.password, 'db_pass') + self.assertEqual(profile.credentials.schema, 'postgres-schema') + self.assertEqual(profile.credentials.dbname, 'postgres-db-name') + self.assertEqual(profile, from_raw) + + def test_profile_override(self): + self.args.profile = 'other' + self.args.threads = 3 + profile = dbt.config.Profile.from_args(self.args, 'default') + from_raw = dbt.config.Profile.from_raw_profile_info( + self.default_profile_data['other'], + 'other', + threads_override=3, + ) + + self.assertEqual(profile.profile_name, 'other') + self.assertEqual(profile.target_name, 'other-postgres') + self.assertEqual(profile.threads, 3) + self.assertTrue(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + self.assertTrue(isinstance(profile.credentials, PostgresCredentials)) + self.assertEqual(profile.credentials.type, 'postgres') + self.assertEqual(profile.credentials.host, 'other-postgres-db-hostname') + self.assertEqual(profile.credentials.port, 4444) + self.assertEqual(profile.credentials.user, 'other_db_user') + self.assertEqual(profile.credentials.password, 'other_db_pass') + self.assertEqual(profile.credentials.schema, 'other-postgres-schema') + self.assertEqual(profile.credentials.dbname, 'other-postgres-db-name') + self.assertEqual(profile, from_raw) + + def test_target_override(self): + self.args.target = 'redshift' + profile = dbt.config.Profile.from_args(self.args, 'default') + from_raw = dbt.config.Profile.from_raw_profile_info( + self.default_profile_data['default'], + 'default', + target_override='redshift' + ) + + 
self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'redshift') + self.assertEqual(profile.threads, 1) + self.assertTrue(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + self.assertTrue(isinstance(profile.credentials, RedshiftCredentials)) + self.assertEqual(profile.credentials.type, 'redshift') + self.assertEqual(profile.credentials.host, 'redshift-db-hostname') + self.assertEqual(profile.credentials.port, 5555) + self.assertEqual(profile.credentials.user, 'db_user') + self.assertEqual(profile.credentials.password, 'db_pass') + self.assertEqual(profile.credentials.schema, 'redshift-schema') + self.assertEqual(profile.credentials.dbname, 'redshift-db-name') + self.assertEqual(profile, from_raw) + + def test_env_vars(self): + self.args.target = 'with-vars' + with mock.patch.dict(os.environ, self.env_override): + profile = dbt.config.Profile.from_args(self.args, 'default') + from_raw = dbt.config.Profile.from_raw_profile_info( + self.default_profile_data['default'], + 'default', + target_override='with-vars' + ) + + self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'with-vars') + self.assertEqual(profile.threads, 1) + self.assertTrue(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + self.assertEqual(profile.credentials.type, 'postgres') + self.assertEqual(profile.credentials.host, 'env-postgres-host') + self.assertEqual(profile.credentials.port, 6543) + self.assertEqual(profile.credentials.user, 'env-postgres-user') + self.assertEqual(profile.credentials.password, 'env-postgres-pass') + self.assertEqual(profile, from_raw) + + def test_no_profile(self): + with self.assertRaises(dbt.config.DbtProjectError) as exc: + dbt.config.Profile.from_args(self.args) + self.assertIn('no profile was specified', str(exc.exception)) + + +class TestProject(BaseConfigTest): + def setUp(self): + self.profiles_dir = '/invalid-profiles-path' + self.project_dir = '/invalid-root-path' + super(TestProject, self).setUp() + self.default_project_data['project-root'] = self.project_dir + + def test_defaults(self): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + self.assertEqual(project.project_name, 'my_test_project') + self.assertEqual(project.version, '0.0.1') + self.assertEqual(project.profile_name, 'default') + self.assertEqual(project.project_root, '/invalid-root-path') + self.assertEqual(project.source_paths, ['models']) + self.assertEqual(project.macro_paths, ['macros']) + self.assertEqual(project.data_paths, ['data']) + self.assertEqual(project.test_paths, ['test']) + self.assertEqual(project.analysis_paths, []) + self.assertEqual(project.docs_paths, ['models']) + self.assertEqual(project.target_path, 'target') + self.assertEqual(project.clean_targets, ['target']) + self.assertEqual(project.log_path, 'logs') + self.assertEqual(project.modules_path, 'dbt_modules') + self.assertEqual(project.quoting, {}) + self.assertEqual(project.models, {}) + self.assertEqual(project.on_run_start, []) + self.assertEqual(project.on_run_end, []) + self.assertEqual(project.archive, []) + self.assertEqual(project.seeds, {}) + self.assertEqual(project.packages, PackageConfig(packages=[])) + # just make sure str() doesn't crash anything, that's always + # embarrassing + str(project) + + def test_eq(self): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + other = dbt.config.Project.from_project_config( + self.default_project_data + ) 
+ self.assertEqual(project, other) + + def test_neq(self): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + self.assertNotEqual(project, object()) + + def test_implicit_overrides(self): + self.default_project_data.update({ + 'source-paths': ['other-models'], + 'target-path': 'other-target', + }) + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + self.assertEqual(project.docs_paths, ['other-models']) + self.assertEqual(project.clean_targets, ['other-target']) + + def test_hashed_name(self): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + self.assertEqual(project.hashed_name(), '754cd47eac1d6f50a5f7cd399ec43da4') + + def test_all_overrides(self): + self.default_project_data.update({ + 'source-paths': ['other-models'], + 'macro-paths': ['other-macros'], + 'data-paths': ['other-data'], + 'test-paths': ['other-test'], + 'analysis-paths': ['analysis'], + 'docs-paths': ['docs'], + 'target-path': 'other-target', + 'clean-targets': ['another-target'], + 'log-path': 'other-logs', + 'modules-path': 'other-dbt_modules', + 'quoting': {'identifier': False}, + 'models': { + 'pre-hook': ['{{ logging.log_model_start_event() }}'], + 'post-hook': ['{{ logging.log_model_end_event() }}'], + 'my_test_project': { + 'first': { + 'enabled': False, + 'sub': { + 'enabled': True, + } + }, + 'second': { + 'materialized': 'table', + }, + }, + 'third_party': { + 'third': { + 'materialized': 'view', + }, + }, + }, + 'on-run-start': [ + '{{ logging.log_run_start_event() }}', + ], + 'on-run-end': [ + '{{ logging.log_run_end_event() }}', + ], + 'archive': [ + { + 'source_schema': 'my_schema', + 'target_schema': 'my_other_schema', + 'tables': [ + { + 'source_table': 'my_table', + 'target_Table': 'my_table_archived', + 'updated_at': 'updated_at_field', + 'unique_key': 'table_id', + }, + ], + }, + ], + 'seeds': { + 'my_test_project': { + 'enabled': True, + 'schema': 'seed_data', + 'post-hook': 'grant select on {{ this }} to bi_user', + }, + }, + }) + packages = { + 'packages': [ + { + 'local': 'foo', + }, + { + 'git': 'git@example.com:fishtown-analytics/dbt-utils.git', + 'revision': 'test-rev' + }, + ], + } + project = dbt.config.Project.from_project_config( + self.default_project_data, packages + ) + self.assertEqual(project.project_name, 'my_test_project') + self.assertEqual(project.version, '0.0.1') + self.assertEqual(project.profile_name, 'default') + self.assertEqual(project.project_root, '/invalid-root-path') + self.assertEqual(project.source_paths, ['other-models']) + self.assertEqual(project.macro_paths, ['other-macros']) + self.assertEqual(project.data_paths, ['other-data']) + self.assertEqual(project.test_paths, ['other-test']) + self.assertEqual(project.analysis_paths, ['analysis']) + self.assertEqual(project.docs_paths, ['docs']) + self.assertEqual(project.target_path, 'other-target') + self.assertEqual(project.clean_targets, ['another-target']) + self.assertEqual(project.log_path, 'other-logs') + self.assertEqual(project.modules_path, 'other-dbt_modules') + self.assertEqual(project.quoting, {'identifier': False}) + self.assertEqual(project.models, { + 'pre-hook': ['{{ logging.log_model_start_event() }}'], + 'post-hook': ['{{ logging.log_model_end_event() }}'], + 'my_test_project': { + 'first': { + 'enabled': False, + 'sub': { + 'enabled': True, + } + }, + 'second': { + 'materialized': 'table', + }, + }, + 'third_party': { + 'third': { + 'materialized': 'view', + }, + }, + }) + 
self.assertEqual(project.on_run_start, ['{{ logging.log_run_start_event() }}']) + self.assertEqual(project.on_run_end, ['{{ logging.log_run_end_event() }}']) + self.assertEqual(project.archive, [{ + 'source_schema': 'my_schema', + 'target_schema': 'my_other_schema', + 'tables': [ + { + 'source_table': 'my_table', + 'target_Table': 'my_table_archived', + 'updated_at': 'updated_at_field', + 'unique_key': 'table_id', + }, + ], + }]) + self.assertEqual(project.seeds, { + 'my_test_project': { + 'enabled': True, + 'schema': 'seed_data', + 'post-hook': 'grant select on {{ this }} to bi_user', + }, + }) + self.assertEqual(project.packages, PackageConfig(packages=[ + { + 'local': 'foo', + }, + { + 'git': 'git@example.com:fishtown-analytics/dbt-utils.git', + 'revision': 'test-rev' + }, + ])) + str(project) + json.dumps(project.to_project_config()) + + def test_invalid_project_name(self): + self.default_project_data['name'] = 'invalid-project-name' + with self.assertRaises(dbt.config.DbtProjectError) as exc: + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + self.assertIn('invalid-project-name', str(exc.exception)) + + def test_no_project(self): + with self.assertRaises(dbt.config.DbtProjectError) as exc: + dbt.config.Project.from_project_root(self.project_dir) + + self.assertIn('no dbt_project.yml', str(exc.exception)) + + +class TestProjectFile(BaseFileTest): + def setUp(self): + super(TestProjectFile, self).setUp() + self.write_project(self.default_project_data) + # and after the fact, add the project root + self.default_project_data['project-root'] = self.project_dir + + def test_from_project_root(self): + project = dbt.config.Project.from_project_root(self.project_dir) + from_config = dbt.config.Project.from_project_config( + self.default_project_data + ) + self.assertEqual(project, from_config) + + def test_with_invalid_package(self): + self.write_packages({'invalid': ['not a package of any kind']}) + with self.assertRaises(dbt.config.DbtProjectError) as exc: + dbt.config.Project.from_project_root(self.project_dir) + + +class TestRuntimeConfig(BaseConfigTest): + def setUp(self): + self.profiles_dir = '/invalid-profiles-path' + self.project_dir = '/invalid-root-path' + super(TestRuntimeConfig, self).setUp() + self.default_project_data['project-root'] = self.project_dir + + def get_project(self): + return dbt.config.Project.from_project_config( + self.default_project_data + ) + + def get_profile(self): + return dbt.config.Profile.from_raw_profiles( + self.default_profile_data, self.default_project_data['profile'] + ) + + def test_from_parts(self): + project = self.get_project() + profile = self.get_profile() + config = dbt.config.RuntimeConfig.from_parts(project, profile, {}) + + self.assertEqual(config.cli_vars, {}) + self.assertEqual(config.to_profile_info(), profile.to_profile_info()) + # we should have the default quoting set in the full config, but not in + # the project + # TODO(jeb): Adapters must assert that quoting is populated? 
+ expected_project = project.to_project_config() + self.assertEqual(expected_project['quoting'], {}) + + expected_project['quoting'] = {'identifier': True, 'schema': True} + self.assertEqual(config.to_project_config(), expected_project) + + def test_str(self): + project = self.get_project() + profile = self.get_profile() + config = dbt.config.RuntimeConfig.from_parts(project, profile, {}) + + # to make sure nothing terrible happens + str(config) + + def test_validate_fails(self): + project = self.get_project() + profile = self.get_profile() + # invalid - must be boolean + profile.use_colors = None + with self.assertRaises(dbt.config.DbtProjectError): + dbt.config.RuntimeConfig.from_parts(project, profile, {}) + + +class TestRuntimeConfigFiles(BaseFileTest): + def setUp(self): + super(TestRuntimeConfigFiles, self).setUp() + self.write_profile(self.default_profile_data) + self.write_project(self.default_project_data) + # and after the fact, add the project root + self.default_project_data['project-root'] = self.project_dir + + def test_from_args(self): + with temp_cd(self.project_dir): + config = dbt.config.RuntimeConfig.from_args(self.args) + self.assertEqual(config.project_name, 'my_test_project') + self.assertEqual(config.version, '0.0.1') + self.assertEqual(config.profile_name, 'default') + # on osx, for example, these are not necessarily equal due to /private + self.assertTrue(os.path.samefile(config.project_root, + self.project_dir)) + self.assertEqual(config.source_paths, ['models']) + self.assertEqual(config.macro_paths, ['macros']) + self.assertEqual(config.data_paths, ['data']) + self.assertEqual(config.test_paths, ['test']) + self.assertEqual(config.analysis_paths, []) + self.assertEqual(config.docs_paths, ['models']) + self.assertEqual(config.target_path, 'target') + self.assertEqual(config.clean_targets, ['target']) + self.assertEqual(config.log_path, 'logs') + self.assertEqual(config.modules_path, 'dbt_modules') + self.assertEqual(config.quoting, {'identifier': True, 'schema': True}) + self.assertEqual(config.models, {}) + self.assertEqual(config.on_run_start, []) + self.assertEqual(config.on_run_end, []) + self.assertEqual(config.archive, []) + self.assertEqual(config.seeds, {}) + self.assertEqual(config.packages, PackageConfig(packages=[])) diff --git a/test/unit/test_docs_blocks.py b/test/unit/test_docs_blocks.py index 6a03760c7ef..d1dd2af7b85 100644 --- a/test/unit/test_docs_blocks.py +++ b/test/unit/test_docs_blocks.py @@ -1,11 +1,14 @@ import mock import unittest +from dbt.config import RuntimeConfig from dbt.node_types import NodeType import dbt.utils from dbt.parser import docs from dbt.contracts.graph.unparsed import UnparsedDocumentationFile +from .utils import config_from_parts_or_dicts + #DocumentationParser @@ -51,36 +54,42 @@ class DocumentationParserTest(unittest.TestCase): def setUp(self): - self.root_project_config = { - 'name': 'root', - 'version': '0.1', - 'profile': 'test', - 'project-root': '/test_root', - 'target': 'test', - 'quoting': {}, + profile_data = { 'outputs': { 'test': { 'type': 'postgres', 'host': 'localhost', 'schema': 'analytics', + 'user': 'test', + 'pass': 'test', + 'dbname': 'test', + 'port': 1, } - } + }, + 'target': 'test', + } + root_project = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/test_root', } - self.subdir_project_config = { + subdir_project = { 'name': 'some_package', 'version': '0.1', + 'profile': 'test', 'project-root': '/test_root/test_subdir', - 'target': 'test', 'quoting': {}, - 
'outputs': { - 'test': { - 'type': 'postgres', - 'host': 'localhost', - 'schema': 'analytics', - } - } } + self.root_project_config = config_from_parts_or_dicts( + project=root_project, profile=profile_data + ) + self.subdir_project_config = config_from_parts_or_dicts( + project=subdir_project, profile=profile_data + ) + + @mock.patch('dbt.clients.system') def test_load_file(self, system): system.load_file_contents.return_value = TEST_DOCUMENTATION_FILE diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py index 25ca2a26de2..e663d199af8 100644 --- a/test/unit/test_graph.py +++ b/test/unit/test_graph.py @@ -9,7 +9,7 @@ import dbt.flags import dbt.linker import dbt.model -import dbt.project +import dbt.config import dbt.templates import dbt.utils @@ -17,6 +17,8 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa +from .utils import config_from_parts_or_dicts + class GraphTest(unittest.TestCase): @@ -37,22 +39,20 @@ def mock_write_gpickle(graph, outfile): self.graph_result = None - self.profiles = { - 'test': { - 'outputs': { - 'test': { - 'type': 'postgres', - 'threads': 4, - 'host': 'database', - 'port': 5432, - 'user': 'root', - 'pass': 'password', - 'dbname': 'dbt', - 'schema': 'dbt_test' - } - }, - 'target': 'test' - } + self.profile = { + 'outputs': { + 'test': { + 'type': 'postgres', + 'threads': 4, + 'host': 'database', + 'port': 5432, + 'user': 'root', + 'pass': 'password', + 'dbname': 'dbt', + 'schema': 'dbt_test' + } + }, + 'target': 'test' } self.real_dependency_projects = dbt.utils.dependency_projects @@ -84,7 +84,7 @@ def mock_load_file_contents(path): dbt.clients.system.load_file_contents = MagicMock( side_effect=mock_load_file_contents) - def get_project(self, extra_cfg=None): + def get_config(self, extra_cfg=None): if extra_cfg is None: extra_cfg = {} @@ -96,13 +96,7 @@ def get_project(self, extra_cfg=None): } cfg.update(extra_cfg) - project = dbt.project.Project( - cfg=cfg, - profiles=self.profiles, - profiles_dir=None) - - project.validate() - return project + return config_from_parts_or_dicts(project=cfg, profile=self.profile) def get_compiler(self, project): return dbt.compilation.Compiler(project) @@ -121,7 +115,7 @@ def test__single_model(self): 'model_one': 'select * from events', }) - compiler = self.get_compiler(self.get_project()) + compiler = self.get_compiler(self.get_config()) graph, linker = compiler.compile() self.assertEquals( @@ -138,7 +132,7 @@ def test__two_models_simple_ref(self): 'model_two': "select * from {{ref('model_one')}}", }) - compiler = self.get_compiler(self.get_project()) + compiler = self.get_compiler(self.get_config()) graph, linker = compiler.compile() six.assertCountEqual(self, @@ -172,7 +166,7 @@ def test__model_materializations(self): } } - compiler = self.get_compiler(self.get_project(cfg)) + compiler = self.get_compiler(self.get_config(cfg)) graph, linker = compiler.compile() expected_materialization = { @@ -207,7 +201,7 @@ def test__model_incremental(self): } } - compiler = self.get_compiler(self.get_project(cfg)) + compiler = self.get_compiler(self.get_config(cfg)) graph, linker = compiler.compile() node = 'model.test_models_compile.model_one' @@ -231,7 +225,7 @@ def test__dependency_list(self): 'model_4': 'select * from {{ ref("model_3") }}' }) - compiler = self.get_compiler(self.get_project({})) + compiler = self.get_compiler(self.get_config({})) graph, linker = compiler.compile() actual_dep_list = linker.as_dependency_list() diff --git a/test/unit/test_graph_selection.py b/test/unit/test_graph_selection.py index 
96f2145d8bc..5d4865875bf 100644 --- a/test/unit/test_graph_selection.py +++ b/test/unit/test_graph_selection.py @@ -3,7 +3,6 @@ import os import string import dbt.graph.selector as graph_selector -import dbt.project import networkx as nx @@ -24,47 +23,6 @@ def setUp(self): for node in self.package_graph: self.package_graph.node[node]['fqn'] = node.split('.')[1:] - self.project = self.get_project() - - def get_project(self, extra_cfg=None): - if extra_cfg is None: - extra_cfg = {} - - cfg = { - 'name': 'X', - 'version': '0.1', - 'profile': 'test', - 'project-root': os.path.abspath('.'), - } - - profiles = { - 'test': { - 'outputs': { - 'test': { - 'type': 'postgres', - 'threads': 4, - 'host': 'database', - 'port': 5432, - 'user': 'root', - 'pass': 'password', - 'dbname': 'dbt', - 'schema': 'dbt_test' - } - }, - 'target': 'test' - } - } - - cfg.update(extra_cfg) - - project = dbt.project.Project( - cfg=cfg, - profiles=profiles, - profiles_dir=None) - - project.validate() - return project - def run_specs_and_assert(self, graph, include, exclude, expected): selected = graph_selector.select_nodes( graph, diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index f52770f12a7..a4fd33f8e7a 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -259,11 +259,11 @@ def test__to_flat_graph(self): def test_get_metadata(self, mock_user): mock_user.id = 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf' mock_user.do_not_track = True - project = mock.MagicMock() + config = mock.MagicMock() # md5 of 'test' - project.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' + config.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' self.assertEqual( - Manifest.get_metadata(project), + Manifest.get_metadata(config), { 'project_id': '098f6bcd4621d373cade4e832627b4f6', 'user_id': 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf', @@ -276,11 +276,11 @@ def test_get_metadata(self, mock_user): def test_no_nodes_with_metadata(self, mock_user): mock_user.id = 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf' mock_user.do_not_track = True - project = mock.MagicMock() + config = mock.MagicMock() # md5 of 'test' - project.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' + config.hashed_name.return_value = '098f6bcd4621d373cade4e832627b4f6' manifest = Manifest(nodes={}, macros={}, docs={}, - generated_at=timestring(), project=project) + generated_at=timestring(), config=config) metadata = { 'project_id': '098f6bcd4621d373cade4e832627b4f6', 'user_id': 'cfc9500f-dc7f-4c83-9ea7-2c581c1b38cf', diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index aecc7894820..58164da0ac0 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -8,12 +8,15 @@ import dbt.parser from dbt.parser import ModelParser, MacroParser, DataTestParser, SchemaParser, ParserUtils from dbt.utils import timestring +from dbt.config import RuntimeConfig from dbt.node_types import NodeType from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.parsed import ParsedNode, ParsedMacro, ParsedNodePatch from dbt.contracts.graph.unparsed import UnparsedNode +from .utils import config_from_parts_or_dicts + def get_os_path(unix_path): return os.path.normpath(unix_path) @@ -30,11 +33,7 @@ def setUp(self): self.maxDiff = None - self.root_project_config = { - 'name': 'root', - 'version': '0.1', - 'profile': 'test', - 'project-root': os.path.abspath('.'), + profile_data = { 'target': 'test', 'quoting': {}, 'outputs': { @@ -42,25 +41,38 @@ def setUp(self): 'type': 'postgres', 'host': 
'localhost', 'schema': 'analytics', + 'user': 'test', + 'pass': 'test', + 'dbname': 'test', + 'port': 1, } } } - self.snowplow_project_config = { + root_project = { + 'name': 'root', + 'version': '0.1', + 'profile': 'test', + 'project-root': os.path.abspath('.'), + } + + + self.root_project_config = config_from_parts_or_dicts( + project=root_project, + profile=profile_data + ) + + snowplow_project = { 'name': 'snowplow', 'version': '0.1', + 'profile': 'test', 'project-root': os.path.abspath('./dbt_modules/snowplow'), - 'target': 'test', - 'quoting': {}, - 'outputs': { - 'test': { - 'type': 'postgres', - 'host': 'localhost', - 'schema': 'analytics', - } - } } + self.snowplow_project_config = config_from_parts_or_dicts( + project=snowplow_project, profile=profile_data + ) + self.model_config = { 'enabled': True, 'materialized': 'view', @@ -137,7 +149,7 @@ def test__single_model__nested_configuration(self): 'raw_sql': ("select * from events"), }] - self.root_project_config['models'] = { + self.root_project_config.models = { 'materialized': 'ephemeral', 'root': { 'nested': { @@ -876,7 +888,7 @@ def test__in_model_config(self): ) def test__root_project_config(self): - self.root_project_config['models'] = { + self.root_project_config.models = { 'materialized': 'ephemeral', 'root': { 'view': { @@ -1010,7 +1022,7 @@ def test__root_project_config(self): ) def test__other_project_config(self): - self.root_project_config['models'] = { + self.root_project_config.models = { 'materialized': 'ephemeral', 'root': { 'view': { @@ -1029,7 +1041,7 @@ def test__other_project_config(self): } } - self.snowplow_project_config['models'] = { + self.snowplow_project_config.models = { 'snowplow': { 'enabled': False, 'views': { @@ -1625,10 +1637,11 @@ def test__simple_macro(self): self.assertTrue(callable(result['macro.root.simple'].generator)) + self.assertEqual( result, { - 'macro.root.simple': { + 'macro.root.simple': ParsedMacro(**{ 'name': 'simple', 'resource_type': 'macro', 'unique_id': 'macro.root.simple', @@ -1641,7 +1654,7 @@ def test__simple_macro(self): 'tags': [], 'path': 'simple_macro.sql', 'raw_sql': macro_file_contents, - } + }) } ) @@ -1664,7 +1677,7 @@ def test__simple_macro_used_in_model(self): self.assertEqual( result, { - 'macro.root.simple': { + 'macro.root.simple': ParsedMacro(**{ 'name': 'simple', 'resource_type': 'macro', 'unique_id': 'macro.root.simple', @@ -1677,7 +1690,7 @@ def test__simple_macro_used_in_model(self): 'tags': [], 'path': 'simple_macro.sql', 'raw_sql': macro_file_contents, - } + }), } ) diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 567f9af41e0..b65410b0a7c 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -10,26 +10,34 @@ from psycopg2 import extensions as psycopg2_extensions import agate +from dbt.config import Profile + class TestPostgresAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True - self.profile = { - 'dbname': 'postgres', - 'user': 'root', - 'host': 'database', - 'pass': 'password', - 'port': 5432, - 'schema': 'public' - } + self.profile = Profile.from_raw_profile_info({ + 'outputs': { + 'test': { + 'type': 'postgres', + 'dbname': 'postgres', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5432, + 'schema': 'public' + } + }, + 'target': 'test' + }, 'test') def test_acquire_connection_validations(self): try: connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') - self.assertEquals(connection.get('type'), 'postgres') + 
self.assertEquals(connection.type, 'postgres') except ValidationException as e: self.fail('got ValidationException: {}'.format(str(e))) except BaseException as e: @@ -39,8 +47,8 @@ def test_acquire_connection_validations(self): def test_acquire_connection(self): connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') - self.assertEquals(connection.get('state'), 'open') - self.assertNotEquals(connection.get('handle'), None) + self.assertEquals(connection.state, 'open') + self.assertNotEquals(connection.handle, None) @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_default_keepalive(self, psycopg2): @@ -56,7 +64,8 @@ def test_default_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_changed_keepalive(self, psycopg2): - self.profile['keepalives_idle'] = 256 + credentials = self.profile.credentials.incorporate(keepalives_idle=256) + self.profile.credentials = credentials connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') psycopg2.connect.assert_called_once_with( @@ -70,7 +79,8 @@ def test_changed_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_set_zero_keepalive(self, psycopg2): - self.profile['keepalives_idle'] = 0 + credentials = self.profile.credentials.incorporate(keepalives_idle=0) + self.profile.credentials = credentials connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') psycopg2.connect.assert_called_once_with( diff --git a/test/unit/test_project.py b/test/unit/test_project.py deleted file mode 100644 index 866044d4963..00000000000 --- a/test/unit/test_project.py +++ /dev/null @@ -1,85 +0,0 @@ -import unittest - -import os -import dbt.project - - -class ProjectTest(unittest.TestCase): - def setUp(self): - self.profiles = { - 'test': { - 'outputs': { - 'test': { - 'type': 'postgres', - 'threads': 4, - 'host': 'database', - 'port': 5432, - 'user': 'root', - 'pass': 'password', - 'dbname': 'dbt', - 'schema': 'dbt_test' - } - }, - 'target': 'test' - } - } - self.cfg = { - 'name': 'X', - 'version': '0.1', - 'profile': 'test', - 'project-root': os.path.abspath('.'), - } - - def test_profile_validate_success(self): - # Make sure we can instantiate + validate a valid profile - - project = dbt.project.Project( - cfg=self.cfg, - profiles=self.profiles, - profiles_dir=None - ) - - project.validate() - - def test_profile_validate_missing(self): - del self.profiles['test']['outputs']['test']['schema'] - - project = dbt.project.Project( - cfg=self.cfg, - profiles=self.profiles, - profiles_dir=None - ) - - message = r'.*schema.* is a required property.*' - with self.assertRaisesRegexp(dbt.project.DbtProjectError, message): - project.validate() - - def test_profile_validate_extra(self): - self.profiles['test']['outputs']['test']['foo'] = 'bar' - - project = dbt.project.Project( - cfg=self.cfg, - profiles=self.profiles, - profiles_dir=None - ) - - message = r'.*not allowed.*foo.* was unexpected.*' - with self.assertRaisesRegexp(dbt.project.DbtProjectError, message): - project.validate() - - def test_profile_validate_missing_and_extra(self): - del self.profiles['test']['outputs']['test']['schema'] - self.profiles['test']['outputs']['test']['foo'] = 'bar' - - project = dbt.project.Project( - cfg=self.cfg, - profiles=self.profiles, - profiles_dir=None - ) - - unrecognized = r'not allowed.*foo.* was unexpected' - extra = r'schema.* is a required property' - # fun with regexp ordering: want both, don't care about order - message = 
'.*({0}.*{1}|{1}.*{0}).*'.format(unrecognized, extra) - with self.assertRaisesRegexp(dbt.project.DbtProjectError, message): - project.validate() diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index f8b4432e3bd..4b713737d3d 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -8,6 +8,8 @@ from dbt.adapters.redshift import RedshiftAdapter from dbt.exceptions import ValidationException, FailedToConnectException from dbt.logger import GLOBAL_LOGGER as logger # noqa +from dbt.config import Profile + @classmethod def fetch_cluster_credentials(*args, **kwargs): @@ -21,52 +23,61 @@ class TestRedshiftAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True - self.profile = { - 'dbname': 'redshift', - 'user': 'root', - 'host': 'database', - 'pass': 'password', - 'port': 5439, - 'schema': 'public' - } + self.profile = Profile.from_raw_profile_info({ + 'outputs': { + 'test': { + 'type': 'redshift', + 'dbname': 'redshift', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5439, + 'schema': 'public' + } + }, + 'target': 'test' + }, 'test') def test_implicit_database_conn(self): - creds = RedshiftAdapter.get_credentials(self.profile) - self.assertEquals(creds, self.profile) + creds = RedshiftAdapter.get_credentials(self.profile.credentials) + self.assertEquals(creds, self.profile.credentials) def test_explicit_database_conn(self): - self.profile['method'] = 'database' + self.profile.method = 'database' - creds = RedshiftAdapter.get_credentials(self.profile) - self.assertEquals(creds, self.profile) + creds = RedshiftAdapter.get_credentials(self.profile.credentials) + self.assertEquals(creds, self.profile.credentials) def test_explicit_iam_conn(self): - self.profile.update({ - 'method': 'iam', - 'cluster_id': 'my_redshift', - 'iam_duration_s': 1200, - }) + self.profile.credentials = self.profile.credentials.incorporate( + method='iam', + cluster_id='my_redshift', + iam_duration_seconds=1200 + ) with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - creds = RedshiftAdapter.get_credentials(self.profile) + creds = RedshiftAdapter.get_credentials(self.profile.credentials) - expected_creds = dbt.utils.merge(self.profile, {'pass': 'tmp_password'}) + expected_creds = self.profile.credentials.incorporate(password='tmp_password') self.assertEquals(creds, expected_creds) def test_invalid_auth_method(self): - self.profile['method'] = 'badmethod' + # we have to set method this way, otherwise it won't validate + self.profile.credentials._contents['method'] = 'badmethod' with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.get_credentials(self.profile) + RedshiftAdapter.get_credentials(self.profile.credentials) self.assertTrue('badmethod' in context.exception.msg) def test_invalid_iam_no_cluster_id(self): - self.profile['method'] = 'iam' + self.profile.credentials = self.profile.credentials.incorporate( + method='iam' + ) with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.get_credentials(self.profile) + RedshiftAdapter.get_credentials(self.profile.credentials) self.assertTrue("'cluster_id' must be provided" in context.exception.msg) @@ -86,7 +97,9 @@ def test_default_keepalive(self, 
psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_changed_keepalive(self, psycopg2): - self.profile['keepalives_idle'] = 256 + self.profile.credentials = self.profile.credentials.incorporate( + keepalives_idle=256 + ) connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') psycopg2.connect.assert_called_once_with( @@ -100,7 +113,9 @@ def test_changed_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_set_zero_keepalive(self, psycopg2): - self.profile['keepalives_idle'] = 0 + self.profile.credentials = self.profile.credentials.incorporate( + keepalives_idle=0 + ) connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') psycopg2.connect.assert_called_once_with( diff --git a/test/unit/utils.py b/test/unit/utils.py new file mode 100644 index 00000000000..d09fae907e1 --- /dev/null +++ b/test/unit/utils.py @@ -0,0 +1,23 @@ +"""Unit test utility functions. + +Note that all imports should be inside the functions to avoid import/mocking +issues. +""" + +def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): + from dbt.config import Project, Profile, RuntimeConfig + from dbt.utils import parse_cli_vars + from copy import deepcopy + if not isinstance(project, Project): + project = Project.from_project_config(deepcopy(project), packages) + if not isinstance(profile, Profile): + profile = Profile.from_raw_profile_info(deepcopy(profile), + project.profile_name) + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) + + return RuntimeConfig.from_parts( + project=project, + profile=profile, + cli_vars=cli_vars + ) From 582f9f91436fbb6cd2a326c4c816299204e9ea8a Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 11 Sep 2018 10:19:59 -0600 Subject: [PATCH 006/133] make __eq__ type checks symmetrical --- dbt/config.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index 0326bd3d148..d56826f5daf 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -248,7 +248,8 @@ def __str__(self): return pprint.pformat(cfg) def __eq__(self, other): - if not isinstance(other, self.__class__): + if not (isinstance(other, self.__class__) and + isinstance(self, other.__class__)): return False return self.to_project_config(with_packages=True) == \ other.to_project_config(with_packages=True) @@ -360,7 +361,9 @@ def __str__(self): return pprint.pformat(self.to_profile_info()) def __eq__(self, other): - if not isinstance(other, self.__class__): + if not (isinstance(other, self.__class__) and + isinstance(self, other.__class__)): + return False return False return self.to_profile_info() == other.to_profile_info() From 273af5368fb3d13ae0d1e9ecee7c7ad239ccbf72 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 11 Sep 2018 10:22:08 -0600 Subject: [PATCH 007/133] add more explicit dbt.config imports for readability --- dbt/main.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/dbt/main.py b/dbt/main.py index d5135b2df3d..126b9d0b872 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -21,12 +21,14 @@ import dbt.task.serve as serve_task import dbt.tracking -import dbt.config as config import dbt.ui.printer import dbt.compat import dbt.deprecations from dbt.utils import ExitCodes +from dbt.config import Project, RuntimeConfig, DbtProjectError, \ + DbtProfileError, DEFAULT_PROFILES_DIR, read_config, \ + send_anonymous_usage_stats, colorize_output, read_profiles PROFILES_HELP_MESSAGE = """ @@ -108,13 +110,13 @@ def 
handle_and_check(args): # this needs to happen after args are parsed so we can determine the # correct profiles.yml file - profile_config = config.read_config(parsed.profiles_dir) - if not config.send_anonymous_usage_stats(profile_config): + profile_config = read_config(parsed.profiles_dir) + if not send_anonymous_usage_stats(profile_config): dbt.tracking.do_not_track() else: dbt.tracking.initialize_tracking() - if dbt.config.colorize_output(profile_config): + if colorize_output(profile_config): dbt.ui.printer.use_colors() try: @@ -207,16 +209,16 @@ def invoke_dbt(parsed): try: if parsed.which == 'deps': # deps doesn't need a profile, so don't require one. - cfg = config.Project.from_current_directory() + cfg = Project.from_current_directory() elif parsed.which != 'debug': # for debug, we will attempt to load the various configurations as # part of the task, so just leave cfg=None. - cfg = config.RuntimeConfig.from_args(parsed) - except config.DbtProjectError as e: + cfg = RuntimeConfig.from_args(parsed) + except DbtProjectError as e: logger.info("Encountered an error while reading the project:") logger.info(dbt.compat.to_string(e)) - all_profiles = config.read_profiles(parsed.profiles_dir).keys() + all_profiles = read_profiles(parsed.profiles_dir).keys() if len(all_profiles) > 0: logger.info("Defined profiles:") @@ -234,7 +236,7 @@ def invoke_dbt(parsed): result_type=e.result_type) return None - except config.DbtProfileError as e: + except DbtProfileError as e: logger.info("Encountered an error while reading profiles:") logger.info(" ERROR {}".format(str(e))) @@ -293,11 +295,11 @@ def parse_args(args): base_subparser.add_argument( '--profiles-dir', - default=config.DEFAULT_PROFILES_DIR, + default=DEFAULT_PROFILES_DIR, type=str, help=""" Which directory to look in for the profiles.yml file. 
Default = {} - """.format(config.DEFAULT_PROFILES_DIR) + """.format(DEFAULT_PROFILES_DIR) ) base_subparser.add_argument( From b4772bc3b68970c824532d35d32deebde8b4735c Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 11 Sep 2018 10:27:09 -0600 Subject: [PATCH 008/133] Missed a project->config --- dbt/runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dbt/runner.py b/dbt/runner.py index 1949b31caae..9a359b17c86 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -32,7 +32,7 @@ def __init__(self, config): def deserialize_graph(self): logger.info("Loading dependency graph file.") - base_target_path = self.project.target_path + base_target_path = self.config.target_path graph_file = os.path.join( base_target_path, dbt.compilation.graph_file_name From 6652eced954916aa83787de7019d2a8335890950 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 11 Sep 2018 10:27:46 -0600 Subject: [PATCH 009/133] warn on unspecified test name --- test/integration/base.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/test/integration/base.py b/test/integration/base.py index aae81e99b83..e83bc2dc32b 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -15,6 +15,8 @@ from dbt.config import RuntimeConfig from dbt.logger import GLOBAL_LOGGER as logger +import logging +import warnings DBT_CONFIG_DIR = os.path.abspath( @@ -40,19 +42,19 @@ def __init__(self, kwargs): def _profile_from_test_name(test_name): - adapters_in_name = sum(x in test_name for x in - ('postgres', 'snowflake', 'redshift', 'bigquery')) + adapter_names = ('postgres', 'snowflake', 'redshift', 'bigquery') + adapters_in_name = sum(x in test_name for x in adapter_names) if adapters_in_name > 1: raise ValueError('test names must only have 1 profile choice embedded') - if 'snowflake' in test_name: - return 'snowflake' - elif 'redshift' in test_name: - return 'redshift' - elif 'bigquery' in test_name: - return 'bigquery' - else: - return 'postgres' + for adapter_name in adapter_names: + if adapter_name in test_name: + return adapter_name + + warnings.warn( + 'could not find adapter name in test name {}'.format(test_name) + ) + return 'postgres' class DBTIntegrationTest(unittest.TestCase): @@ -201,6 +203,8 @@ def _pick_profile(self): def setUp(self): flags.reset() + # disable capturing warnings + logging.captureWarnings(False) self._clean_files() self.use_profile(self._pick_profile()) From 18a5e44dbcc9805c0988e3c78c1a3796ea970904 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 13 Sep 2018 13:12:08 -0600 Subject: [PATCH 010/133] fix unit tests --- test/unit/test_postgres_adapter.py | 75 +++++++++++++++++------------ test/unit/test_snowflake_adapter.py | 39 ++++++++------- 2 files changed, 65 insertions(+), 49 deletions(-) diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index b65410b0a7c..a3ac793411d 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -10,15 +10,20 @@ from psycopg2 import extensions as psycopg2_extensions import agate -from dbt.config import Profile +from .utils import config_from_parts_or_dicts class TestPostgresAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True - - self.profile = Profile.from_raw_profile_info({ + project_cfg = { + 'name': 'X', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/tmp/dbt/does-not-exist', + } + profile_cfg = { 'outputs': { 'test': { 'type': 'postgres', @@ -31,11 +36,13 @@ def setUp(self): } }, 'target': 'test' - 
}, 'test') + } + + self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) def test_acquire_connection_validations(self): try: - connection = PostgresAdapter.acquire_connection(self.profile, + connection = PostgresAdapter.acquire_connection(self.config, 'dummy') self.assertEquals(connection.type, 'postgres') except ValidationException as e: @@ -45,14 +52,14 @@ def test_acquire_connection_validations(self): .format(str(e))) def test_acquire_connection(self): - connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') + connection = PostgresAdapter.acquire_connection(self.config, 'dummy') self.assertEquals(connection.state, 'open') self.assertNotEquals(connection.handle, None) @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_default_keepalive(self, psycopg2): - connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') + connection = PostgresAdapter.acquire_connection(self.config, 'dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -64,9 +71,9 @@ def test_default_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_changed_keepalive(self, psycopg2): - credentials = self.profile.credentials.incorporate(keepalives_idle=256) - self.profile.credentials = credentials - connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') + credentials = self.config.credentials.incorporate(keepalives_idle=256) + self.config.credentials = credentials + connection = PostgresAdapter.acquire_connection(self.config, 'dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -79,9 +86,9 @@ def test_changed_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_set_zero_keepalive(self, psycopg2): - credentials = self.profile.credentials.incorporate(keepalives_idle=0) - self.profile.credentials = credentials - connection = PostgresAdapter.acquire_connection(self.profile, 'dummy') + credentials = self.config.credentials.incorporate(keepalives_idle=0) + self.config.credentials = credentials + connection = PostgresAdapter.acquire_connection(self.config, 'dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -114,26 +121,32 @@ def test_get_catalog_various_schemas(self, mock_run): # give manifest the dict it wants mock_manifest = mock.MagicMock(spec_set=['nodes'], nodes=nodes) - catalog = PostgresAdapter.get_catalog({}, {}, mock_manifest) + catalog = PostgresAdapter.get_catalog(mock.MagicMock(), mock_manifest) self.assertEqual( set(map(tuple, catalog)), {('foo', 'bar'), ('FOO', 'baz'), ('quux', 'bar')} ) + class TestConnectingPostgresAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = False - self.profile = { - 'dbname': 'postgres', - 'user': 'root', - 'host': 'database', - 'pass': 'password', - 'port': 5432, - 'schema': 'public' + profile_cfg = { + 'outputs': { + 'test': { + 'type': 'postgres', + 'dbname': 'postgres', + 'user': 'root', + 'host': 'database', + 'pass': 'password', + 'port': 5432, + 'schema': 'public' + } + }, + 'target': 'test' } - - self.project = { + project_cfg = { 'name': 'X', 'version': '0.1', 'profile': 'test', @@ -144,6 +157,8 @@ def setUp(self): } } + self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) + self.handle = mock.MagicMock(spec=psycopg2_extensions.connection) self.cursor = self.handle.cursor.return_value self.mock_execute = self.cursor.execute @@ -151,7 +166,7 @@ def setUp(self): self.psycopg2 = self.patcher.start() self.psycopg2.connect.return_value = self.handle - conn = 
PostgresAdapter.get_connection(self.profile) + conn = PostgresAdapter.get_connection(self.config) def tearDown(self): # we want a unique self.handle every time. @@ -160,8 +175,7 @@ def tearDown(self): def test_quoting_on_drop_schema(self): PostgresAdapter.drop_schema( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema' ) @@ -171,8 +185,7 @@ def test_quoting_on_drop_schema(self): def test_quoting_on_drop(self): PostgresAdapter.drop( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema', relation='test_table', relation_type='table' @@ -183,8 +196,7 @@ def test_quoting_on_drop(self): def test_quoting_on_truncate(self): PostgresAdapter.truncate( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema', table='test_table' ) @@ -194,8 +206,7 @@ def test_quoting_on_truncate(self): def test_quoting_on_rename(self): PostgresAdapter.rename( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema', from_name='table_a', to_name='table_b' diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 71746e2e30b..9aa290c7f57 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -9,20 +9,28 @@ from dbt.logger import GLOBAL_LOGGER as logger # noqa from snowflake import connector as snowflake_connector +from .utils import config_from_parts_or_dicts + class TestSnowflakeAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = False - self.profile = { - 'dbname': 'postgres', - 'user': 'root', - 'host': 'database', - 'pass': 'password', - 'port': 5432, - 'schema': 'public' + profile_cfg = { + 'outputs': { + 'test': { + 'type': 'snowflake', + 'account': 'test_account', + 'user': 'test_user', + 'password': 'test_password', + 'database': 'test_databse', + 'warehouse': 'test_warehouse', + 'schema': 'public', + }, + }, + 'target': 'test', } - self.project = { + project_cfg = { 'name': 'X', 'version': '0.1', 'profile': 'test', @@ -32,6 +40,7 @@ def setUp(self): 'schema': True, } } + self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) self.handle = mock.MagicMock(spec=snowflake_connector.SnowflakeConnection) self.cursor = self.handle.cursor.return_value @@ -40,7 +49,7 @@ def setUp(self): self.snowflake = self.patcher.start() self.snowflake.return_value = self.handle - conn = SnowflakeAdapter.get_connection(self.profile) + conn = SnowflakeAdapter.get_connection(self.config) def tearDown(self): # we want a unique self.handle every time. 
@@ -49,8 +58,7 @@ def tearDown(self): def test_quoting_on_drop_schema(self): SnowflakeAdapter.drop_schema( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema' ) @@ -60,8 +68,7 @@ def test_quoting_on_drop_schema(self): def test_quoting_on_drop(self): SnowflakeAdapter.drop( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema', relation='test_table', relation_type='table' @@ -72,8 +79,7 @@ def test_quoting_on_drop(self): def test_quoting_on_truncate(self): SnowflakeAdapter.truncate( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema', table='test_table' ) @@ -83,8 +89,7 @@ def test_quoting_on_truncate(self): def test_quoting_on_rename(self): SnowflakeAdapter.rename( - profile=self.profile, - project_cfg=self.project, + config=self.config, schema='test_schema', from_name='table_a', to_name='table_b' From 28ef796d479f216217aa6c68ae7ce8838e9dc72b Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 18 Sep 2018 09:23:38 -0600 Subject: [PATCH 011/133] add extra logic around connection release in the finally block to avoid raising during exception handling --- dbt/node_runners.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/dbt/node_runners.py b/dbt/node_runners.py index 01cb12e5ed1..458be5a93dd 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -16,6 +16,8 @@ import dbt.templates import dbt.writer +import six +import sys import time @@ -112,16 +114,44 @@ def safe_run(self, manifest): prefix=dbt.ui.printer.red(prefix), error=str(e).strip()) - logger.debug(error) + logger.error(error) raise e finally: - node_name = self.node.name - self.adapter.release_connection(self.config, node_name) + exc_info = sys.exc_info() + exc_str = self._safe_release_connection() + + # if we had an unhandled exception, re-raise it + if exc_info and exc_info[1]: + six.reraise(*exc_info) + + # if releasing failed and the result doesn't have an error yet, set + # an error + if exc_str is not None and result.error is None: + result.error = exc_str + result.status = 'ERROR' result.execution_time = time.time() - started return result + def _safe_release_connection(self): + """Try to release a connection. If an exception is hit, log and return + the error string. + """ + node_name = self.node.name + try: + self.adapter.release_connection(self.config, node_name) + except Exception as exc: + # log it + logger.error( + 'Error releasing connection for node {}: {!s}' + .format(node_name, exc) + ) + logger.debug(traceback.format_exc()) + return dbt.compat.to_string(exc) + + return None + def before_execute(self): raise NotImplementedException() From c58daa1dc920f32db0bddc770c6fb0c50f2d613d Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 13 Sep 2018 09:05:34 -0600 Subject: [PATCH 012/133] fix an error where a non-string is passed to RuntimeException for a message on python 2.7 --- dbt/exceptions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dbt/exceptions.py b/dbt/exceptions.py index 7bf3d2566b0..e28011dd193 100644 --- a/dbt/exceptions.py +++ b/dbt/exceptions.py @@ -69,7 +69,8 @@ def __str__(self, prefix="! 
"): if hasattr(self.msg, 'split'): split_msg = self.msg.split("\n") else: - split_msg = basestring(self.msg).split("\n") + # can't use basestring here, as on python2 it's an abstract class + split_msg = str(self.msg).split("\n") lines = ["{}{}".format(self.type + ' Error', node_string)] + split_msg From fb970192cd05f9fcc7f3c508f9584e7c7d5768f8 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 13 Sep 2018 09:58:39 -0600 Subject: [PATCH 013/133] handle python 2 quirks --- dbt/node_runners.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dbt/node_runners.py b/dbt/node_runners.py index 458be5a93dd..f6ee27b5560 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -75,6 +75,7 @@ def safe_run(self, manifest): result = RunModelResult(self.node) started = time.time() + exc_info = (None, None, None) try: # if we fail here, we still have a compiled node to return @@ -107,6 +108,9 @@ def safe_run(self, manifest): result.status = 'ERROR' except Exception as e: + # set this here instead of finally, as python 2/3 exc_info() + # behavior with re-raised exceptions are slightly different + exc_info = sys.exc_info() prefix = "Unhandled error while executing {filepath}".format( filepath=self.node.build_path) @@ -118,7 +122,6 @@ def safe_run(self, manifest): raise e finally: - exc_info = sys.exc_info() exc_str = self._safe_release_connection() # if we had an unhandled exception, re-raise it From 22d4a1d73dde634885c5e4d73eabe4a9bb40f55b Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 18 Sep 2018 09:00:12 -0600 Subject: [PATCH 014/133] PR feedback --- dbt/node_runners.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/dbt/node_runners.py b/dbt/node_runners.py index f6ee27b5560..aca890c74b0 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -19,6 +19,7 @@ import six import sys import time +import traceback INTERNAL_ERROR_STRING = """This is an error in dbt. Please try again. 
If \ @@ -145,12 +146,10 @@ def _safe_release_connection(self): try: self.adapter.release_connection(self.config, node_name) except Exception as exc: - # log it - logger.error( - 'Error releasing connection for node {}: {!s}' - .format(node_name, exc) + logger.debug( + 'Error releasing connection for node {}: {!s}\n{}' + .format(node_name, exc, traceback.format_exc()) ) - logger.debug(traceback.format_exc()) return dbt.compat.to_string(exc) return None From 15b13054d1d7da5bb9405264fc5d5104c31cf3bc Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 18 Sep 2018 08:10:02 -0600 Subject: [PATCH 015/133] Include datasets with underscores when listing BigQuery datasets Co-authored-by: kf Fellows --- dbt/adapters/bigquery/impl.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index c102c0a3479..e7b0fcea239 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -495,7 +495,7 @@ def get_existing_schemas(cls, config, model_name=None): client = conn.handle with cls.exception_handler(config, 'list dataset', model_name): - all_datasets = client.list_datasets() + all_datasets = client.list_datasets(include_all=True) return [ds.dataset_id for ds in all_datasets] @classmethod @@ -535,9 +535,10 @@ def get_dbt_columns_from_bq_table(cls, table): @classmethod def check_schema_exists(cls, config, schema, model_name=None): conn = cls.get_connection(config, model_name) + client = conn.handle with cls.exception_handler(config, 'get dataset', model_name): - all_datasets = conn.handle.list_datasets() + all_datasets = client.list_datasets(include_all=True) return any([ds.dataset_id == schema for ds in all_datasets]) @classmethod From 9a91aa25844e686a02fed41a62b35a98cf28f846 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 18 Sep 2018 08:13:37 -0600 Subject: [PATCH 016/133] remove never-called method --- dbt/node_runners.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/dbt/node_runners.py b/dbt/node_runners.py index 01cb12e5ed1..30d99d8f909 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -248,13 +248,6 @@ def call_already_exists(schema, table): "already_exists": call_already_exists, } - @classmethod - def create_schemas(cls, config, adapter, manifest): - required_schemas = cls.get_model_schemas(manifest) - existing_schemas = set(adapter.get_existing_schemas(config)) - for schema in (required_schemas - existing_schemas): - adapter.create_schema(config, schema) - class ModelRunner(CompileRunner): From c11cd92b831c54064664b8a103369b00ce4875df Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 14 Sep 2018 08:42:25 -0600 Subject: [PATCH 017/133] add a test that already passed anyway --- .../test_simple_bigquery_view.py | 36 +++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/test/integration/022_bigquery_test/test_simple_bigquery_view.py b/test/integration/022_bigquery_test/test_simple_bigquery_view.py index 2e49990a0e2..df824e6da59 100644 --- a/test/integration/022_bigquery_test/test_simple_bigquery_view.py +++ b/test/integration/022_bigquery_test/test_simple_bigquery_view.py @@ -1,7 +1,8 @@ from test.integration.base import DBTIntegrationTest, FakeArgs, use_profile +import random +import time - -class TestSimpleBigQueryRun(DBTIntegrationTest): +class TestBaseBigQueryRun(DBTIntegrationTest): @property def schema(self): @@ -22,6 +23,9 @@ def project_config(self): def profile_config(self): return self.bigquery_profile() + +class 
TestSimpleBigQueryRun(TestBaseBigQueryRun): + @use_profile('bigquery') def test__bigquery_simple_run(self): # make sure seed works twice. Full-refresh is a no-op @@ -68,3 +72,31 @@ def test__bigquery_exists_non_destructive(self): self.assertFalse(result.skipped) # status = # of failing rows self.assertEqual(result.status, 0) + + +class TestUnderscoreBigQueryRun(TestBaseBigQueryRun): + prefix = "_test{}{:04}".format(int(time.time()), random.randint(0, 9999)) + + @use_profile('bigquery') + def test_bigquery_run_twice(self): + self.run_dbt(['seed']) + results = self.run_dbt() + self.assertEqual(len(results), 4) + results = self.run_dbt() + self.assertEqual(len(results), 4) + + # The 'dupe' model should fail, but all others should pass + test_results = self.run_dbt(['test'], expect_pass=False) + + for result in test_results: + if 'dupe' in result.node.get('name'): + self.assertFalse(result.errored) + self.assertFalse(result.skipped) + self.assertTrue(result.status > 0) + + # assert that actual tests pass + else: + self.assertFalse(result.errored) + self.assertFalse(result.skipped) + # status = # of failing rows + self.assertEqual(result.status, 0) From 8e84f53c65bba0eb5197e1580d81b48aa2b8ac74 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 18 Sep 2018 14:24:43 -0600 Subject: [PATCH 018/133] make adapters into objects, fix unit tests --- dbt/adapters/bigquery/impl.py | 265 +++++++++----------- dbt/adapters/default/impl.py | 365 ++++++++++++---------------- dbt/adapters/factory.py | 29 ++- dbt/adapters/postgres/impl.py | 46 ++-- dbt/adapters/redshift/impl.py | 17 +- dbt/adapters/snowflake/impl.py | 68 +++--- dbt/context/common.py | 117 ++++----- dbt/node_runners.py | 36 ++- dbt/parser/base.py | 2 +- dbt/runner.py | 2 +- dbt/task/generate.py | 2 +- test/unit/test_bigquery_adapter.py | 59 +++-- test/unit/test_postgres_adapter.py | 42 ++-- test/unit/test_redshift_adapter.py | 58 +++-- test/unit/test_snowflake_adapter.py | 17 +- 15 files changed, 521 insertions(+), 604 deletions(-) diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index e7b0fcea239..1757c564a39 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -54,9 +54,6 @@ class BigQueryAdapter(PostgresAdapter): "get_columns_in_table" ] - Relation = BigQueryRelation - Column = dbt.schema.BigQueryColumn - SCOPE = ('https://www.googleapis.com/auth/bigquery', 'https://www.googleapis.com/auth/cloud-platform', 'https://www.googleapis.com/auth/drive') @@ -68,6 +65,8 @@ class BigQueryAdapter(PostgresAdapter): } QUERY_TIMEOUT = 300 + Relation = BigQueryRelation + Column = dbt.schema.BigQueryColumn @classmethod def handle_error(cls, error, message, sql): @@ -78,20 +77,19 @@ def handle_error(cls, error, message, sql): raise dbt.exceptions.DatabaseException(error_msg) - @classmethod @contextmanager - def exception_handler(cls, config, sql, model_name=None, + def exception_handler(self, sql, model_name=None, connection_name='master'): try: yield except google.cloud.exceptions.BadRequest as e: message = "Bad request while running:\n{sql}" - cls.handle_error(e, message, sql) + self.handle_error(e, message, sql) except google.cloud.exceptions.Forbidden as e: message = "Access denied while running:\n{sql}" - cls.handle_error(e, message, sql) + self.handle_error(e, message, sql) except Exception as e: logger.debug("Unhandled error while running:\n{}".format(sql)) @@ -106,12 +104,10 @@ def type(cls): def date_function(cls): return 'CURRENT_TIMESTAMP()' - @classmethod - def begin(cls, config, 
name='master'): + def begin(self, name): pass - @classmethod - def commit(cls, config, connection): + def commit(self, connection): pass @classmethod @@ -181,12 +177,11 @@ def close(cls, connection): return connection - @classmethod - def list_relations(cls, config, schema, model_name=None): - connection = cls.get_connection(config, model_name) + def list_relations(self, schema, model_name=None): + connection = self.get_connection(model_name) client = connection.handle - bigquery_dataset = cls.get_dataset(config, schema, model_name) + bigquery_dataset = self.get_dataset(schema, model_name) all_tables = client.list_tables( bigquery_dataset, @@ -203,12 +198,11 @@ def list_relations(cls, config, schema, model_name=None): # This will 404 if the dataset does not exist. This behavior mirrors # the implementation of list_relations for other adapters try: - return [cls.bq_table_to_relation(table) for table in all_tables] + return [self.bq_table_to_relation(table) for table in all_tables] except google.api_core.exceptions.NotFound as e: return [] - @classmethod - def get_relation(cls, config, schema=None, identifier=None, + def get_relation(self, schema=None, identifier=None, relations_list=None, model_name=None): if schema is None and relations_list is None: raise dbt.exceptions.RuntimeException( @@ -216,32 +210,27 @@ def get_relation(cls, config, schema=None, identifier=None, 'of relations to use') if relations_list is None and identifier is not None: - table = cls.get_bq_table(config, schema, identifier) + table = self.get_bq_table(schema, identifier) - return cls.bq_table_to_relation(table) + return self.bq_table_to_relation(table) - return super(BigQueryAdapter, cls).get_relation( - config, schema, identifier, relations_list, + return super(BigQueryAdapter, self).get_relation( + schema, identifier, relations_list, model_name) - @classmethod - def drop_relation(cls, config, relation, model_name=None): - conn = cls.get_connection(config, model_name) + def drop_relation(self, relation, model_name=None): + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, relation.schema, model_name) + dataset = self.get_dataset(relation.schema, model_name) relation_object = dataset.table(relation.identifier) client.delete_table(relation_object) - @classmethod - def rename(cls, config, schema, - from_name, to_name, model_name=None): + def rename(self, schema, from_name, to_name, model_name=None): raise dbt.exceptions.NotImplementedException( '`rename` is not implemented for this adapter!') - @classmethod - def rename_relation(cls, config, from_relation, to_relation, - model_name=None): + def rename_relation(self, from_relation, to_relation, model_name=None): raise dbt.exceptions.NotImplementedException( '`rename_relation` is not implemented for this adapter!') @@ -250,13 +239,12 @@ def get_timeout(cls, conn): credentials = conn['credentials'] return credentials.get('timeout_seconds', cls.QUERY_TIMEOUT) - @classmethod - def materialize_as_view(cls, config, dataset, model): + def materialize_as_view(self, dataset, model): model_name = model.get('name') model_alias = model.get('alias') model_sql = model.get('injected_sql') - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle view_ref = dataset.table(model_alias) @@ -266,7 +254,7 @@ def materialize_as_view(cls, config, dataset, model): logger.debug("Model SQL ({}):\n{}".format(model_name, model_sql)) - with cls.exception_handler(config, model_sql, model_name, 
model_name): + with self.exception_handler(model_sql, model_name, model_name): client.create_table(view) return "CREATE VIEW" @@ -286,26 +274,24 @@ def poll_until_job_completes(cls, job, timeout): elif job.error_result: raise job.exception() - @classmethod - def make_date_partitioned_table(cls, config, dataset_name, identifier, + def make_date_partitioned_table(self, dataset_name, identifier, model_name=None): - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, dataset_name, identifier) + dataset = self.get_dataset(dataset_name, identifier) table_ref = dataset.table(identifier) table = google.cloud.bigquery.Table(table_ref) table.partitioning_type = 'DAY' return client.create_table(table) - @classmethod - def materialize_as_table(cls, config, dataset, model, model_sql, + def materialize_as_table(self, dataset, model, model_sql, decorator=None): model_name = model.get('name') model_alias = model.get('alias') - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle if decorator is None: @@ -322,14 +308,13 @@ def materialize_as_table(cls, config, dataset, model, model_sql, query_job = client.query(model_sql, job_config=job_config) # this waits for the job to complete - with cls.exception_handler(config, model_sql, model_alias, - model_name): - query_job.result(timeout=cls.get_timeout(conn)) + with self.exception_handler(model_sql, model_alias, + model_name): + query_job.result(timeout=self.get_timeout(conn)) return "CREATE TABLE" - @classmethod - def execute_model(cls, config, model, + def execute_model(self, model, materialization, sql_override=None, decorator=None, model_name=None): @@ -337,20 +322,19 @@ def execute_model(cls, config, model, sql_override = model.get('injected_sql') if flags.STRICT_MODE: - connection = cls.get_connection(config, model.get('name')) + connection = self.get_connection(model.get('name')) Connection(**connection) model_name = model.get('name') model_schema = model.get('schema') - dataset = cls.get_dataset(config, - model_schema, model_name) + dataset = self.get_dataset(model_schema, model_name) if materialization == 'view': - res = cls.materialize_as_view(config, dataset, model) + res = self.materialize_as_view(dataset, model) elif materialization == 'table': - res = cls.materialize_as_table( - config, dataset, model, + res = self.materialize_as_table( + dataset, model, sql_override, decorator) else: msg = "Invalid relation type: '{}'".format(materialization) @@ -358,9 +342,8 @@ def execute_model(cls, config, model, return res - @classmethod - def raw_execute(cls, config, sql, model_name=None, fetch=False, **kwargs): - conn = cls.get_connection(config, model_name) + def raw_execute(self, sql, model_name=None, fetch=False, **kwargs): + conn = self.get_connection(model_name) client = conn.handle logger.debug('On %s: %s', model_name, sql) @@ -370,20 +353,18 @@ def raw_execute(cls, config, sql, model_name=None, fetch=False, **kwargs): query_job = client.query(sql, job_config) # this blocks until the query has completed - with cls.exception_handler(config, sql, model_name): + with self.exception_handler(sql, model_name): iterator = query_job.result() return query_job, iterator - @classmethod - def create_temporary_table(cls, config, sql, model_name=None, - **kwargs): + def create_temporary_table(self, sql, model_name=None, **kwargs): # BQ queries always return a temp table with their results - query_job, _ = 
cls.raw_execute(config, sql, model_name) + query_job, _ = self.raw_execute(sql, model_name) bq_table = query_job.destination - return cls.Relation.create( + return self.Relation.create( project=bq_table.project, schema=bq_table.dataset_id, identifier=bq_table.table_id, @@ -393,17 +374,15 @@ def create_temporary_table(cls, config, sql, model_name=None, }, type=BigQueryRelation.Table) - @classmethod - def alter_table_add_columns(cls, config, relation, columns, - model_name=None): + def alter_table_add_columns(self, relation, columns, model_name=None): logger.debug('Adding columns ({}) to table {}".'.format( columns, relation)) - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, relation.schema, model_name) + dataset = self.get_dataset(relation.schema, model_name) table_ref = dataset.table(relation.name) table = client.get_table(table_ref) @@ -414,13 +393,11 @@ def alter_table_add_columns(cls, config, relation, columns, new_table = google.cloud.bigquery.Table(table_ref, schema=new_schema) client.update_table(new_table, ['schema']) - @classmethod - def execute(cls, config, sql, model_name=None, fetch=None, **kwargs): - _, iterator = cls.raw_execute(config, sql, model_name, fetch, - **kwargs) + def execute(self, sql, model_name=None, fetch=None, **kwargs): + _, iterator = self.raw_execute(sql, model_name, fetch, **kwargs) if fetch: - res = cls.get_table_from_response(iterator) + res = self.get_table_from_response(iterator) else: res = dbt.clients.agate_helper.empty_table() @@ -428,9 +405,8 @@ def execute(cls, config, sql, model_name=None, fetch=None, **kwargs): status = 'OK' return status, res - @classmethod - def execute_and_fetch(cls, config, sql, model_name, auto_begin=None): - status, table = cls.execute(config, sql, model_name, fetch=True) + def execute_and_fetch(self, sql, model_name, auto_begin=None): + status, table = self.execute(sql, model_name, fetch=True) return status, table @classmethod @@ -441,118 +417,106 @@ def get_table_from_response(cls, resp): # BigQuery doesn't support BEGIN/COMMIT, so stub these out. - @classmethod - def add_begin_query(cls, config, name): + def add_begin_query(self, name): pass - @classmethod - def add_commit_query(cls, config, name): + def add_commit_query(self, name): pass - @classmethod - def create_schema(cls, config, schema, model_name=None): + def create_schema(self, schema, model_name=None): logger.debug('Creating schema "%s".', schema) - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, schema, model_name) + dataset = self.get_dataset(schema, model_name) # Emulate 'create schema if not exists ...' 
try: client.get_dataset(dataset) except google.api_core.exceptions.NotFound: - with cls.exception_handler(config, 'create dataset', model_name): + with self.exception_handler('create dataset', model_name): client.create_dataset(dataset) - @classmethod - def drop_tables_in_schema(cls, config, dataset): - conn = cls.get_connection(config) + def drop_tables_in_schema(self, dataset): + conn = self.get_connection() client = conn.handle for table in client.list_tables(dataset): client.delete_table(table.reference) - @classmethod - def drop_schema(cls, config, schema, model_name=None): + def drop_schema(self, schema, model_name=None): logger.debug('Dropping schema "%s".', schema) - if not cls.check_schema_exists(config, - schema, model_name): + if not self.check_schema_exists(schema, model_name): return - conn = cls.get_connection(config) + conn = self.get_connection(model_name) client = conn.handle - dataset = cls.get_dataset(config, schema, model_name) - with cls.exception_handler(config, 'drop dataset', model_name): - cls.drop_tables_in_schema(config, dataset) + dataset = self.get_dataset(schema, model_name) + with self.exception_handler('drop dataset', model_name): + self.drop_tables_in_schema(dataset) client.delete_dataset(dataset) - @classmethod - def get_existing_schemas(cls, config, model_name=None): - conn = cls.get_connection(config, model_name) + def get_existing_schemas(self, model_name=None): + conn = self.get_connection(model_name) client = conn.handle - with cls.exception_handler(config, 'list dataset', model_name): + with self.exception_handler('list dataset', model_name): all_datasets = client.list_datasets(include_all=True) return [ds.dataset_id for ds in all_datasets] - @classmethod - def get_columns_in_table(cls, config, schema_name, table_name, + def get_columns_in_table(self, schema_name, table_name, database=None, model_name=None): # BigQuery does not have databases -- the database parameter is here # for consistency with the base implementation - conn = cls.get_connection(config, model_name) + conn = self.get_connection(model_name) client = conn.handle try: dataset_ref = client.dataset(schema_name) table_ref = dataset_ref.table(table_name) table = client.get_table(table_ref) - return cls.get_dbt_columns_from_bq_table(table) + return self.get_dbt_columns_from_bq_table(table) except (ValueError, google.cloud.exceptions.NotFound) as e: logger.debug("get_columns_in_table error: {}".format(e)) return [] - @classmethod - def get_dbt_columns_from_bq_table(cls, table): + def get_dbt_columns_from_bq_table(self, table): "Translates BQ SchemaField dicts into dbt BigQueryColumn objects" columns = [] for col in table.schema: # BigQuery returns type labels that are not valid type specifiers - dtype = cls.Column.translate_type(col.field_type) - column = cls.Column( + dtype = self.Column.translate_type(col.field_type) + column = self.Column( col.name, dtype, col.fields, col.mode) columns.append(column) return columns - @classmethod - def check_schema_exists(cls, config, schema, model_name=None): - conn = cls.get_connection(config, model_name) + def check_schema_exists(self, schema, model_name=None): + conn = self.get_connection(model_name) client = conn.handle - with cls.exception_handler(config, 'get dataset', model_name): + with self.exception_handler('get dataset', model_name): all_datasets = client.list_datasets(include_all=True) return any([ds.dataset_id == schema for ds in all_datasets]) - @classmethod - def get_dataset(cls, config, dataset_name, model_name=None): - conn = 
cls.get_connection(config, model_name) + def get_dataset(self, dataset_name, model_name=None): + conn = self.get_connection(model_name) dataset_ref = conn.handle.dataset(dataset_name) return google.cloud.bigquery.Dataset(dataset_ref) - @classmethod - def bq_table_to_relation(cls, bq_table): + def bq_table_to_relation(self, bq_table): if bq_table is None: return None - return cls.Relation.create( + return self.Relation.create( project=bq_table.project, schema=bq_table.dataset_id, identifier=bq_table.table_id, @@ -560,13 +524,12 @@ def bq_table_to_relation(cls, bq_table): 'schema': True, 'identifier': True }, - type=cls.RELATION_TYPES.get(bq_table.table_type)) + type=self.RELATION_TYPES.get(bq_table.table_type)) - @classmethod - def get_bq_table(cls, config, dataset_name, identifier, model_name=None): - conn = cls.get_connection(config, model_name) + def get_bq_table(self, dataset_name, identifier, model_name=None): + conn = self.get_connection(model_name) - dataset = cls.get_dataset(config, dataset_name, model_name) + dataset = self.get_dataset(dataset_name, model_name) table_ref = dataset.table(identifier) @@ -576,16 +539,15 @@ def get_bq_table(cls, config, dataset_name, identifier, model_name=None): return None @classmethod - def warning_on_hooks(cls, hook_type): + def warning_on_hooks(hook_type): msg = "{} is not supported in bigquery and will be ignored" dbt.ui.printer.print_timestamped_line(msg.format(hook_type), dbt.ui.printer.COLOR_FG_YELLOW) - @classmethod - def add_query(cls, config, sql, model_name=None, auto_begin=True, + def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): if model_name in ['on-run-start', 'on-run-end']: - cls.warning_on_hooks(model_name) + self.warning_on_hooks(model_name) else: raise dbt.exceptions.NotImplementedException( '`add_query` is not implemented for this adapter!') @@ -598,16 +560,13 @@ def is_cancelable(cls): def quote(cls, identifier): return '`{}`'.format(identifier) - @classmethod - def quote_schema_and_table(cls, config, schema, - table, model_name=None): - return cls.render_relation(config, cls.quote(schema), cls.quote(table)) + def quote_schema_and_table(self, schema, table, model_name=None): + return self.render_relation(self.quote(schema), self.quote(table)) - @classmethod - def render_relation(cls, config, schema, table): - connection = cls.get_connection(config) + def render_relation(cls, schema, table): + connection = self.get_connection() project = connection.credentials.project - return '{}.{}.{}'.format(cls.quote(project), schema, table) + return '{}.{}.{}'.format(self.quote(project), schema, table) @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -636,13 +595,12 @@ def _agate_to_schema(cls, agate_table, column_override): google.cloud.bigquery.SchemaField(col_name, type_)) return bq_schema - @classmethod - def load_dataframe(cls, config, schema, table_name, agate_table, + def load_dataframe(self, schema, table_name, agate_table, column_override, model_name=None): - bq_schema = cls._agate_to_schema(agate_table, column_override) - dataset = cls.get_dataset(config, schema, None) + bq_schema = self._agate_to_schema(agate_table, column_override) + dataset = self.get_dataset(schema, None) table = dataset.table(table_name) - conn = cls.get_connection(config, None) + conn = self.get_connection(None) client = conn.handle load_config = google.cloud.bigquery.LoadJobConfig() @@ -653,21 +611,19 @@ def load_dataframe(cls, config, schema, table_name, agate_table, job = 
client.load_table_from_file(f, table, rewind=True, job_config=load_config) - with cls.exception_handler(config, "LOAD TABLE"): - cls.poll_until_job_completes(job, cls.get_timeout(conn)) + with self.exception_handler("LOAD TABLE"): + self.poll_until_job_completes(job, self.get_timeout(conn)) - @classmethod - def expand_target_column_types(cls, config, temp_table, + def expand_target_column_types(self, temp_table, to_schema, to_table, model_name=None): # This is a no-op on BigQuery pass - @classmethod - def _flat_columns_in_table(cls, table): + def _flat_columns_in_table(self, table): """An iterator over the flattened columns for a given schema and table. Resolves child columns as having the name "parent.child". """ - for col in cls.get_dbt_columns_from_bq_table(table): + for col in self.get_dbt_columns_from_bq_table(table): flattened = col.flatten() for subcol in flattened: yield subcol @@ -727,9 +683,8 @@ def _get_stats_columns(cls, table, relation_type): ) return zip(column_names, column_values) - @classmethod - def get_catalog(cls, config, manifest): - connection = cls.get_connection(config, 'catalog') + def get_catalog(self, manifest): + connection = self.get_connection('catalog') client = connection.handle schemas = { @@ -749,11 +704,11 @@ def get_catalog(cls, config, manifest): 'column_type', 'column_comment', ) - all_names = column_names + cls._get_stats_column_names() + all_names = column_names + self._get_stats_column_names() columns = [] for schema_name in schemas: - relations = cls.list_relations(config, schema_name) + relations = self.list_relations(schema_name) for relation in relations: # This relation contains a subset of the info we care about. @@ -762,9 +717,9 @@ def get_catalog(cls, config, manifest): table_ref = dataset_ref.table(relation.identifier) table = client.get_table(table_ref) - flattened = cls._flat_columns_in_table(table) - relation_stats = dict(cls._get_stats_columns(table, - relation.type)) + flattened = self._flat_columns_in_table(table) + relation_stats = dict(self._get_stats_columns(table, + relation.type)) for index, column in enumerate(flattened, start=1): column_data = ( diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index 5fdf83667f5..3a879ec36de 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -88,17 +88,18 @@ class DefaultAdapter(object): "quote", "convert_type" ] - Relation = DefaultRelation Column = Column + def __init__(self, config): + self.config = config + ### # ADAPTER-SPECIFIC FUNCTIONS -- each of these must be overridden in # every adapter ### - @classmethod @contextmanager - def exception_handler(cls, config, sql, model_name=None, + def exception_handler(self, sql, model_name=None, connection_name=None): raise dbt.exceptions.NotImplementedException( '`exception_handler` is not implemented for this adapter!') @@ -118,15 +119,12 @@ def get_status(cls, cursor): raise dbt.exceptions.NotImplementedException( '`get_status` is not implemented for this adapter!') - @classmethod - def alter_column_type(cls, config, schema, table, - column_name, new_column_type, model_name=None): + def alter_column_type(self, schema, table, column_name, new_column_type, + model_name=None): raise dbt.exceptions.NotImplementedException( '`alter_column_type` is not implemented for this adapter!') - @classmethod - def query_for_existing(cls, config, schemas, - model_name=None): + def query_for_existing(self, schemas, model_name=None): if not isinstance(schemas, (list, tuple)): schemas = [schemas] @@ -134,23 +132,20 
@@ def query_for_existing(cls, config, schemas, for schema in schemas: all_relations.extend( - cls.list_relations(config, schema, model_name)) + self.list_relations(schema, model_name)) return {relation.identifier: relation.type for relation in all_relations} - @classmethod - def get_existing_schemas(cls, config, model_name=None): + def get_existing_schemas(self, model_name=None): raise dbt.exceptions.NotImplementedException( '`get_existing_schemas` is not implemented for this adapter!') - @classmethod - def check_schema_exists(cls, config, schema): + def check_schema_exists(self, schema): raise dbt.exceptions.NotImplementedException( '`check_schema_exists` is not implemented for this adapter!') - @classmethod - def cancel_connection(cls, config, connection): + def cancel_connection(self, connection): raise dbt.exceptions.NotImplementedException( '`cancel_connection` is not implemented for this adapter!') @@ -170,19 +165,17 @@ def get_result_from_cursor(cls, cursor): return dbt.clients.agate_helper.table_from_data(data, column_names) - @classmethod - def drop(cls, config, schema, relation, relation_type, model_name=None): + def drop(self, schema, relation, relation_type, model_name=None): identifier = relation - relation = cls.Relation.create( + relation = self.Relation.create( schema=schema, identifier=identifier, type=relation_type, - quote_policy=config.quoting) + quote_policy=self.config.quoting) - return cls.drop_relation(config, relation, model_name) + return self.drop_relation(relation, model_name) - @classmethod - def drop_relation(cls, config, relation, model_name=None): + def drop_relation(self, relation, model_name=None): if relation.type is None: dbt.exceptions.raise_compiler_error( 'Tried to drop relation {}, but its type is null.' @@ -190,72 +183,65 @@ def drop_relation(cls, config, relation, model_name=None): sql = 'drop {} if exists {} cascade'.format(relation.type, relation) - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) - @classmethod - def truncate(cls, config, schema, table, model_name=None): - relation = cls.Relation.create( + def truncate(self, schema, table, model_name=None): + relation = self.Relation.create( schema=schema, identifier=table, type='table', - quote_policy=config.quoting) + quote_policy=self.config.quoting) - return cls.truncate_relation(config, relation, model_name) + return self.truncate_relation(relation, model_name) - @classmethod - def truncate_relation(cls, config, - relation, model_name=None): + def truncate_relation(self, relation, model_name=None): sql = 'truncate table {}'.format(relation) - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) - @classmethod - def rename(cls, config, schema, - from_name, to_name, model_name=None): - quote_policy = config.quoting - from_relation = cls.Relation.create( + def rename(self, schema, from_name, to_name, model_name=None): + quote_policy = self.config.quoting + from_relation = self.Relation.create( schema=schema, identifier=from_name, quote_policy=quote_policy ) - to_relation = cls.Relation.create( + to_relation = self.Relation.create( identifier=to_name, quote_policy=quote_policy ) - return cls.rename_relation( - config, + return self.rename_relation( from_relation=from_relation, to_relation=to_relation, model_name=model_name) - @classmethod - def rename_relation(cls, config, from_relation, to_relation, + def rename_relation(self, 
from_relation, to_relation, model_name=None): sql = 'alter table {} rename to {}'.format( from_relation, to_relation.include(schema=False)) - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) @classmethod def is_cancelable(cls): return True - @classmethod - def get_missing_columns(cls, config, - from_schema, from_table, - to_schema, to_table, - model_name=None): + def get_missing_columns(self, from_schema, from_table, + to_schema, to_table, model_name=None): """Returns dict of {column:type} for columns in from_table that are missing from to_table""" - from_columns = {col.name: col for col in - cls.get_columns_in_table( - config, from_schema, from_table, - model_name=model_name)} - to_columns = {col.name: col for col in - cls.get_columns_in_table( - config, to_schema, to_table, - model_name=model_name)} + from_columns = { + col.name: col for col in + self.get_columns_in_table( + from_schema, from_table, + model_name=model_name) + } + to_columns = { + col.name: col for col in + self.get_columns_in_table( + to_schema, to_table, + model_name=model_name) + } missing_columns = set(from_columns.keys()) - set(to_columns.keys()) @@ -287,18 +273,17 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): return sql - @classmethod - def get_columns_in_table(cls, config, schema_name, + def get_columns_in_table(self, schema_name, table_name, database=None, model_name=None): - sql = cls._get_columns_in_table_sql(schema_name, table_name, database) - connection, cursor = cls.add_query(config, sql, model_name) + sql = self._get_columns_in_table_sql(schema_name, table_name, database) + connection, cursor = self.add_query(sql, model_name) data = cursor.fetchall() columns = [] for row in data: name, data_type, char_size, numeric_size = row - column = cls.Column(name, data_type, char_size, numeric_size) + column = self.Column(name, data_type, char_size, numeric_size) columns.append(column) return columns @@ -307,20 +292,19 @@ def get_columns_in_table(cls, config, schema_name, def _table_columns_to_dict(cls, columns): return {col.name: col for col in columns} - @classmethod - def expand_target_column_types(cls, config, + def expand_target_column_types(self, temp_table, to_schema, to_table, model_name=None): - reference_columns = cls._table_columns_to_dict( - cls.get_columns_in_table( - config, None, temp_table, model_name=model_name)) + reference_columns = self._table_columns_to_dict( + self.get_columns_in_table(None, temp_table, model_name=model_name) + ) - target_columns = cls._table_columns_to_dict( - cls.get_columns_in_table( - config, to_schema, to_table, - model_name=model_name)) + target_columns = self._table_columns_to_dict( + self.get_columns_in_table(to_schema, to_table, + model_name=model_name) + ) for column_name, reference_column in reference_columns.items(): target_column = target_columns.get(column_name) @@ -328,38 +312,35 @@ def expand_target_column_types(cls, config, if target_column is not None and \ target_column.can_expand_to(reference_column): col_string_size = reference_column.string_size() - new_type = cls.Column.string_type(col_string_size) + new_type = self.Column.string_type(col_string_size) logger.debug("Changing col type from %s to %s in table %s.%s", target_column.data_type, new_type, to_schema, to_table) - cls.alter_column_type(config, to_schema, - to_table, column_name, new_type, - model_name) + self.alter_column_type(to_schema, to_table, column_name, + new_type, model_name) ### # RELATIONS ### - 
@classmethod - def list_relations(cls, config, schema, model_name=None): + def list_relations(self, schema, model_name=None): raise dbt.exceptions.NotImplementedException( '`list_relations` is not implemented for this adapter!') - @classmethod - def _make_match_kwargs(cls, config, schema, identifier): - if identifier is not None and config.quoting['identifier'] is False: + def _make_match_kwargs(self, schema, identifier): + quoting = self.config.quoting + if identifier is not None and quoting['identifier'] is False: identifier = identifier.lower() - if schema is not None and config.quoting['schema'] is False: + if schema is not None and quoting['schema'] is False: schema = schema.lower() return filter_null_values({'identifier': identifier, 'schema': schema}) - @classmethod - def get_relation(cls, config, schema=None, identifier=None, + def get_relation(self, schema=None, identifier=None, relations_list=None, model_name=None): if schema is None and relations_list is None: raise dbt.exceptions.RuntimeException( @@ -367,11 +348,11 @@ def get_relation(cls, config, schema=None, identifier=None, 'of relations to use') if relations_list is None: - relations_list = cls.list_relations(config, schema, model_name) + relations_list = self.list_relations(schema, model_name) matches = [] - search = cls._make_match_kwargs(config, schema, identifier) + search = self._make_match_kwargs(schema, identifier) for relation in relations_list: if relation.matches(**search): @@ -389,16 +370,14 @@ def get_relation(cls, config, schema=None, identifier=None, ### # SANE ANSI SQL DEFAULTS ### - @classmethod - def get_create_schema_sql(cls, config, schema): - schema = cls._quote_as_configured(config, schema, 'schema') + def get_create_schema_sql(self, schema): + schema = self.quote_as_configured(schema, 'schema') return ('create schema if not exists {schema}' .format(schema=schema)) - @classmethod - def get_drop_schema_sql(cls, config, schema): - schema = cls._quote_as_configured(config, schema, 'schema') + def get_drop_schema_sql(self, schema): + schema = self.quote_as_configured(schema, 'schema') return ('drop schema if exists {schema} cascade' .format(schema=schema)) @@ -407,12 +386,10 @@ def get_drop_schema_sql(cls, config, schema): # ODBC FUNCTIONS -- these should not need to change for every adapter, # although some adapters may override them ### - @classmethod - def get_default_schema(cls, config): - return config.credentials.schema + def get_default_schema(self): + return self.config.credentials.schema - @classmethod - def get_connection(cls, config, name=None, recache_if_missing=True): + def get_connection(self, name=None, recache_if_missing=True): global connections_in_use if name is None: @@ -429,22 +406,21 @@ def get_connection(cls, config, name=None, recache_if_missing=True): '(recache_if_missing is off).'.format(name)) logger.debug('Acquiring new {} connection "{}".' 
- .format(cls.type(), name)) + .format(self.type(), name)) - connection = cls.acquire_connection(config, name) + connection = self.acquire_connection(name) connections_in_use[name] = connection - return cls.get_connection(config, name) + return self.get_connection(name) - @classmethod - def cancel_open_connections(cls, config): + def cancel_open_connections(self): global connections_in_use for name, connection in connections_in_use.items(): if name == 'master': continue - cls.cancel_connection(config, connection) + self.cancel_connection(connection) yield name @classmethod @@ -453,17 +429,16 @@ def total_connections_allocated(cls): return len(connections_in_use) + len(connections_available) - @classmethod - def acquire_connection(cls, config, name): + def acquire_connection(self, name): global connections_available, lock # we add a magic number, 2 because there are overhead connections, # one for pre- and post-run hooks and other misc operations that occur # before the run starts, and one for integration tests. - max_connections = config.threads + 2 + max_connections = self.config.threads + 2 with lock: - num_allocated = cls.total_connections_allocated() + num_allocated = self.total_connections_allocated() if len(connections_available) > 0: logger.debug('Re-using an available connection from the pool.') @@ -481,41 +456,37 @@ def acquire_connection(cls, config, name): .format(num_allocated)) result = Connection( - type=cls.type(), + type=self.type(), name=name, state='init', transaction_open=False, handle=None, - credentials=config.credentials + credentials=self.config.credentials ) - return cls.open_connection(result) + return self.open_connection(result) - @classmethod - def release_connection(cls, config, name='master'): + def release_connection(self, name): global connections_in_use, connections_available, lock - if name not in connections_in_use: - return + with lock: - to_release = cls.get_connection(config, name, recache_if_missing=False) + if name not in connections_in_use: + return - try: - lock.acquire() + to_release = self.get_connection(name, recache_if_missing=False) if to_release.state == 'open': if to_release.transaction_open is True: - cls.rollback(to_release) + self.rollback(to_release) to_release.name = None connections_available.append(to_release) else: - cls.close(to_release) + self.close(to_release) del connections_in_use[name] - finally: - lock.release() @classmethod def cleanup_connections(cls): @@ -538,23 +509,18 @@ def cleanup_connections(cls): connections_in_use = {} connections_available = [] - @classmethod - def reload(cls, connection): - return cls.get_connection(connection.credentials, - connection.name) + def reload(self, connection): + return self.get_connection(connection.name) - @classmethod - def add_begin_query(cls, config, name): - return cls.add_query(config, 'BEGIN', name, auto_begin=False) + def add_begin_query(self, name): + return self.add_query('BEGIN', name, auto_begin=False) - @classmethod - def add_commit_query(cls, config, name): - return cls.add_query(config, 'COMMIT', name, auto_begin=False) + def add_commit_query(self, name): + return self.add_query('COMMIT', name, auto_begin=False) - @classmethod - def begin(cls, config, name='master'): + def begin(self, name): global connections_in_use - connection = cls.get_connection(config, name) + connection = self.get_connection(name) if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) @@ -564,15 +530,14 @@ def begin(cls, config, name='master'): 'Tried to begin a new transaction on 
connection "{}", but ' 'it already had one open!'.format(connection.get('name'))) - cls.add_begin_query(config, name) + self.add_begin_query(name) connection.transaction_open = True connections_in_use[name] = connection return connection - @classmethod - def commit_if_has_connection(cls, config, name): + def commit_if_has_connection(self, name): global connections_in_use if name is None: @@ -581,18 +546,17 @@ def commit_if_has_connection(cls, config, name): if connections_in_use.get(name) is None: return - connection = cls.get_connection(config, name, False) + connection = self.get_connection(name, False) - return cls.commit(config, connection) + return self.commit(connection) - @classmethod - def commit(cls, config, connection): + def commit(self, connection): global connections_in_use if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) - connection = cls.reload(connection) + connection = self.reload(connection) if connection.transaction_open is False: raise dbt.exceptions.InternalException( @@ -600,19 +564,18 @@ def commit(cls, config, connection): 'it does not have one open!'.format(connection.name)) logger.debug('On {}: COMMIT'.format(connection.name)) - cls.add_commit_query(config, connection.name) + self.add_commit_query(connection.name) connection.transaction_open = False connections_in_use[connection.name] = connection return connection - @classmethod - def rollback(cls, connection): + def rollback(self, connection): if dbt.flags.STRICT_MODE: Connection(**connection) - connection = cls.reload(connection) + connection = self.reload(connection) if connection.transaction_open is False: raise dbt.exceptions.InternalException( @@ -640,19 +603,18 @@ def close(cls, connection): return connection - @classmethod - def add_query(cls, config, sql, model_name=None, auto_begin=True, + def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): - connection = cls.get_connection(config, model_name) + connection = self.get_connection(model_name) connection_name = connection.name if auto_begin and connection.transaction_open is False: - cls.begin(config, connection_name) + self.begin(connection_name) logger.debug('Using {} connection "{}".' 
- .format(cls.type(), connection_name)) + .format(self.type(), connection_name)) - with cls.exception_handler(config, sql, model_name, connection_name): + with self.exception_handler(sql, model_name, connection_name): if abridge_sql_log: logger.debug('On %s: %s....', connection_name, sql[0:512]) else: @@ -663,99 +625,82 @@ def add_query(cls, config, sql, model_name=None, auto_begin=True, cursor.execute(sql, bindings) logger.debug("SQL status: %s in %0.2f seconds", - cls.get_status(cursor), (time.time() - pre)) + self.get_status(cursor), (time.time() - pre)) return connection, cursor - @classmethod - def clear_transaction(cls, config, conn_name='master'): - conn = cls.begin(config, conn_name) - cls.commit(config, conn) + def clear_transaction(self, conn_name='master'): + conn = self.begin(conn_name) + self.commit(conn) return conn_name - @classmethod - def execute_one(cls, config, sql, model_name=None, auto_begin=False): - cls.get_connection(config, model_name) + def execute_one(self, sql, model_name=None, auto_begin=False): + self.get_connection(model_name) - return cls.add_query(config, sql, model_name, auto_begin) + return self.add_query(sql, model_name, auto_begin) - @classmethod - def execute_and_fetch(cls, config, sql, model_name=None, + def execute_and_fetch(self, sql, model_name=None, auto_begin=False): - _, cursor = cls.execute_one(config, sql, model_name, auto_begin) + _, cursor = self.execute_one(sql, model_name, auto_begin) - status = cls.get_status(cursor) - table = cls.get_result_from_cursor(cursor) + status = self.get_status(cursor) + table = self.get_result_from_cursor(cursor) return status, table - @classmethod - def execute(cls, config, sql, model_name=None, auto_begin=False, + def execute(self, sql, model_name=None, auto_begin=False, fetch=False): if fetch: - return cls.execute_and_fetch(config, sql, model_name, auto_begin) + return self.execute_and_fetch(sql, model_name, auto_begin) else: - _, cursor = cls.execute_one(config, sql, model_name, auto_begin) - status = cls.get_status(cursor) + _, cursor = self.execute_one(sql, model_name, auto_begin) + status = self.get_status(cursor) return status, dbt.clients.agate_helper.empty_table() - @classmethod - def execute_all(cls, config, sqls, model_name=None): - connection = cls.get_connection(config, model_name) + def execute_all(self, sqls, model_name=None): + connection = self.get_connection(model_name) if len(sqls) == 0: return connection for i, sql in enumerate(sqls): - connection, _ = cls.add_query(config, sql, model_name) + connection, _ = self.add_query(sql, model_name) return connection - @classmethod - def create_schema(cls, config, schema, model_name=None): + def create_schema(self, schema, model_name=None): logger.debug('Creating schema "%s".', schema) - sql = cls.get_create_schema_sql(config, schema) - res = cls.add_query(config, sql, model_name) + sql = self.get_create_schema_sql(schema) + res = self.add_query(sql, model_name) - cls.commit_if_has_connection(config, model_name) + self.commit_if_has_connection(model_name) return res - @classmethod - def drop_schema(cls, config, schema, model_name=None): + def drop_schema(self, schema, model_name=None): logger.debug('Dropping schema "%s".', schema) - sql = cls.get_drop_schema_sql(config, schema) - return cls.add_query(config, sql, model_name) + sql = self.get_drop_schema_sql(schema) + return self.add_query(sql, model_name) - @classmethod - def already_exists(cls, config, schema, table, model_name=None): - relation = cls.get_relation(config, schema=schema, 
identifier=table) + def already_exists(self, schema, table, model_name=None): + relation = self.get_relation(schema=schema, identifier=table) return relation is not None @classmethod def quote(cls, identifier): return '"{}"'.format(identifier) - @classmethod - def _quote_as_configured(cls, config, identifier, quote_key): - """This is the actual implementation of quote_as_configured, without - the extra arguments needed for use inside materialization code. - """ - default = cls.Relation.DEFAULTS['quote_policy'].get(quote_key) - if config.quoting.get(quote_key, default): - return cls.quote(identifier) - else: - return identifier - - @classmethod - def quote_as_configured(cls, config, identifier, quote_key, - model_name=None): + def quote_as_configured(self, identifier, quote_key, model_name=None): """Quote or do not quote the given identifer as configured in the project config for the quote key. The quote key should be one of 'database' (on bigquery, 'profile'), 'identifier', or 'schema', or it will be treated as if you set `True`. """ - return cls._quote_as_configured(config, identifier, quote_key) + default = self.Relation.DEFAULTS['quote_policy'].get(quote_key) + if self.config.quoting.get(quote_key, default): + return self.quote(identifier) + else: + return identifier @classmethod def convert_text_type(cls, agate_table, col_idx): @@ -809,8 +754,7 @@ def convert_agate_type(cls, agate_table, col_idx): ### # Operations involving the manifest ### - @classmethod - def run_operation(cls, config, manifest, operation_name): + def run_operation(self, manifest, operation_name): """Look the operation identified by operation_name up in the manifest and run it. @@ -824,7 +768,7 @@ def run_operation(cls, config, manifest, operation_name): import dbt.context.runtime context = dbt.context.runtime.generate( operation, - config, + self.config, manifest, ) @@ -838,13 +782,12 @@ def run_operation(cls, config, manifest, operation_name): def _filter_table(cls, table, manifest): return table.where(_filter_schemas(manifest)) - @classmethod - def get_catalog(cls, config, manifest): + def get_catalog(self, manifest): try: - table = cls.run_operation(config, manifest, - GET_CATALOG_OPERATION_NAME) + table = self.run_operation(manifest, + GET_CATALOG_OPERATION_NAME) finally: - cls.release_connection(config, GET_CATALOG_OPERATION_NAME) + self.release_connection(GET_CATALOG_OPERATION_NAME) - results = cls._filter_table(table, manifest) + results = self._filter_table(table, manifest) return results diff --git a/dbt/adapters/factory.py b/dbt/adapters/factory.py index 7d450dc2b73..12ac37ca3d9 100644 --- a/dbt/adapters/factory.py +++ b/dbt/adapters/factory.py @@ -7,21 +7,26 @@ import dbt.exceptions +import threading -adapters = { + +ADAPTER_TYPES = { 'postgres': PostgresAdapter, 'redshift': RedshiftAdapter, 'snowflake': SnowflakeAdapter, 'bigquery': BigQueryAdapter } +_ADAPTERS = {} +_ADAPTER_LOCK = threading.Lock() + -def get_adapter_by_name(adapter_name): - adapter = adapters.get(adapter_name, None) +def get_adapter_class_by_name(adapter_name): + adapter = ADAPTER_TYPES.get(adapter_name, None) if adapter is None: message = "Invalid adapter type {}! 
Must be one of {}" - adapter_names = ", ".join(adapters.keys()) + adapter_names = ", ".join(ADAPTER_TYPES.keys()) formatted_message = message.format(adapter_name, adapter_names) raise dbt.exceptions.RuntimeException(formatted_message) @@ -29,5 +34,19 @@ def get_adapter_by_name(adapter_name): return adapter + def get_adapter(config): - return get_adapter_by_name(config.credentials.type) + adapter_name = config.credentials.type + if adapter_name in _ADAPTERS: + return _ADAPTERS[adapter_name] + + adapter_type = get_adapter_class_by_name(adapter_name) + with _ADAPTER_LOCK: + # check again, in case something was setting it before + if adapter_name in _ADAPTERS: + return _ADAPTERS[adapter_name] + + adapter = adapter_type(config) + _ADAPTERS[adapter_name] = adapter + return adapter + diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 3ad39e3fb46..29355df5a76 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -14,10 +14,8 @@ class PostgresAdapter(dbt.adapters.default.DefaultAdapter): DEFAULT_TCP_KEEPALIVE = 0 # 0 means to use the default value - @classmethod @contextmanager - def exception_handler(cls, config, sql, model_name=None, - connection_name=None): + def exception_handler(self, sql, model_name=None, connection_name=None): try: yield @@ -26,7 +24,7 @@ def exception_handler(cls, config, sql, model_name=None, try: # attempt to release the connection - cls.release_connection(config, connection_name) + self.release_connection(connection_name) except psycopg2.Error: logger.debug("Failed to release connection!") pass @@ -37,7 +35,7 @@ def exception_handler(cls, config, sql, model_name=None, except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.RuntimeException(e) @classmethod @@ -96,8 +94,7 @@ def open_connection(cls, connection): return connection - @classmethod - def cancel_connection(cls, config, connection): + def cancel_connection(self, connection): connection_name = connection.name pid = connection.handle.get_backend_pid() @@ -105,7 +102,7 @@ def cancel_connection(cls, config, connection): logger.debug("Cancelling query '{}' ({})".format(connection_name, pid)) - _, cursor = cls.add_query(config, sql, 'master') + _, cursor = self.add_query(sql, 'master') res = cursor.fetchone() logger.debug("Cancel query '{}': {}".format(connection_name, res)) @@ -113,8 +110,7 @@ def cancel_connection(cls, config, connection): # DATABASE INSPECTION FUNCTIONS # These require the profile AND project, as they need to know # database-specific configs at the project level. - @classmethod - def alter_column_type(cls, config, schema, table, column_name, + def alter_column_type(self, schema, table, column_name, new_column_type, model_name=None): """ 1. Create a new column (w/ temp name and correct type) @@ -123,10 +119,10 @@ def alter_column_type(cls, config, schema, table, column_name, 4. 
Rename the new column to existing column """ - relation = cls.Relation.create( + relation = self.Relation.create( schema=schema, identifier=table, - quote_policy=config.quoting + quote_policy=self.config.quoting ) opts = { @@ -143,12 +139,11 @@ def alter_column_type(cls, config, schema, table, column_name, alter table {relation} rename column "{tmp_column}" to "{old_column}"; """.format(**opts).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) return connection, cursor - @classmethod - def list_relations(cls, config, schema, model_name=None): + def list_relations(self, schema, model_name=None): sql = """ select tablename as name, schemaname as schema, 'table' as type from pg_tables where schemaname ilike '{schema}' @@ -157,13 +152,13 @@ def list_relations(cls, config, schema, model_name=None): where schemaname ilike '{schema}' """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, + auto_begin=False) results = cursor.fetchall() - return [cls.Relation.create( - database=config.credentials.dbname, + return [self.Relation.create( + database=self.config.credentials.dbname, schema=_schema, identifier=name, quote_policy={ @@ -173,24 +168,21 @@ def list_relations(cls, config, schema, model_name=None): type=type) for (name, _schema, type) in results] - @classmethod - def get_existing_schemas(cls, config, model_name=None): + def get_existing_schemas(self, model_name=None): sql = "select distinct nspname from pg_namespace" - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchall() return [row[0] for row in results] - @classmethod - def check_schema_exists(cls, config, schema, model_name=None): + def check_schema_exists(self, schema, model_name=None): sql = """ select count(*) from pg_namespace where nspname = '{schema}' """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, + auto_begin=False) results = cursor.fetchone() return results[0] > 0 diff --git a/dbt/adapters/redshift/impl.py b/dbt/adapters/redshift/impl.py index a1f36cb297e..a15ec186f0e 100644 --- a/dbt/adapters/redshift/impl.py +++ b/dbt/adapters/redshift/impl.py @@ -156,8 +156,7 @@ def _get_columns_in_table_sql(cls, schema_name, table_name, database): table_schema_filter=table_schema_filter).strip() return sql - @classmethod - def drop_relation(cls, config, relation, model_name=None): + def drop_relation(self, relation, model_name=None): """ In Redshift, DROP TABLE ... CASCADE should not be used inside a transaction. 
Redshift doesn't prevent the CASCADE @@ -178,18 +177,18 @@ def drop_relation(cls, config, relation, model_name=None): with drop_lock: - connection = cls.get_connection(config, model_name) + connection = self.get_connection(model_name) if connection.transaction_open: - cls.commit(config, connection) + self.commit(connection) - cls.begin(config, connection.name) + self.begin(connection.name) - to_return = super(PostgresAdapter, cls).drop_relation( - config, relation, model_name) + to_return = super(RedshiftAdapter, self).drop_relation( + relation, model_name) - cls.commit(config, connection) - cls.begin(config, connection.name) + self.commit(connection) + self.begin(connection.name) return to_return diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index 47bc08efa96..44963c38a22 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -20,11 +20,10 @@ class SnowflakeAdapter(PostgresAdapter): Relation = SnowflakeRelation - @classmethod @contextmanager - def exception_handler(cls, config, sql, model_name=None, + def exception_handler(self, sql, model_name=None, connection_name='master'): - connection = cls.get_connection(config, connection_name) + connection = self.get_connection(connection_name) try: yield @@ -36,7 +35,7 @@ def exception_handler(cls, config, sql, model_name=None, if 'Empty SQL statement' in msg: logger.debug("got empty sql statement, moving on") elif 'This session does not have a current database' in msg: - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.FailedToConnectException( ('{}\n\nThis error sometimes occurs when invalid ' 'credentials are provided, or when your default role ' @@ -44,12 +43,12 @@ def exception_handler(cls, config, sql, model_name=None, 'Please double check your profile and try again.') .format(msg)) else: - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.DatabaseException(msg) except Exception as e: logger.debug("Error running SQL: %s", sql) logger.debug("Rolling back transaction.") - cls.release_connection(config, connection_name) + self.release_connection(connection_name) raise dbt.exceptions.RuntimeException(e.msg) @classmethod @@ -102,8 +101,7 @@ def open_connection(cls, connection): return connection - @classmethod - def list_relations(cls, config, schema, model_name=None): + def list_relations(self, schema, model_name=None): sql = """ select table_name as name, table_schema as schema, table_type as type @@ -111,8 +109,7 @@ def list_relations(cls, config, schema, model_name=None): where table_schema ilike '{schema}' """.format(schema=schema).strip() # noqa - _, cursor = cls.add_query( - config, sql, model_name, auto_begin=False) + _, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchall() @@ -121,8 +118,8 @@ def list_relations(cls, config, schema, model_name=None): 'VIEW': 'view' } - return [cls.Relation.create( - database=config.credentials.database, + return [self.Relation.create( + database=self.config.credentials.database, schema=_schema, identifier=name, quote_policy={ @@ -132,38 +129,32 @@ def list_relations(cls, config, schema, model_name=None): type=relation_type_lookup.get(type)) for (name, _schema, type) in results] - @classmethod - def rename_relation(cls, config, from_relation, to_relation, + def rename_relation(self, from_relation, to_relation, model_name=None): sql = 'alter table {} rename to {}'.format( from_relation, 
to_relation) - connection, cursor = cls.add_query(config, sql, model_name) + connection, cursor = self.add_query(sql, model_name) - @classmethod - def add_begin_query(cls, config, name): - return cls.add_query(config, 'BEGIN', name, auto_begin=False) + def add_begin_query(self, name): + return self.add_query('BEGIN', name, auto_begin=False) - @classmethod - def get_existing_schemas(cls, config, model_name=None): + def get_existing_schemas(self, model_name=None): sql = "select distinct schema_name from information_schema.schemata" - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchall() return [row[0] for row in results] - @classmethod - def check_schema_exists(cls, config, schema, model_name=None): + def check_schema_exists(self, schema, model_name=None): sql = """ select count(*) from information_schema.schemata where upper(schema_name) = upper('{schema}') """.format(schema=schema).strip() # noqa - connection, cursor = cls.add_query(config, sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchone() return results[0] > 0 @@ -177,8 +168,7 @@ def _split_queries(cls, sql): split_query = snowflake.connector.util_text.split_statements(sql_buf) return [part[0] for part in split_query] - @classmethod - def add_query(cls, config, sql, model_name=None, auto_begin=True, + def add_query(self, sql, model_name=None, auto_begin=True, bindings=None, abridge_sql_log=False): connection = None @@ -189,7 +179,7 @@ def add_query(cls, config, sql, model_name=None, auto_begin=True, # which allows any iterable thing to be passed as a binding. bindings = tuple(bindings) - queries = cls._split_queries(sql) + queries = self._split_queries(sql) for individual_query in queries: # hack -- after the last ';', remove comments and don't run @@ -202,9 +192,10 @@ def add_query(cls, config, sql, model_name=None, auto_begin=True, if without_comments == "": continue - connection, cursor = super(PostgresAdapter, cls).add_query( - config, individual_query, model_name, auto_begin, - bindings=bindings, abridge_sql_log=abridge_sql_log) + connection, cursor = super(SnowflakeAdapter, self).add_query( + individual_query, model_name, auto_begin, bindings=bindings, + abridge_sql_log=abridge_sql_log + ) if cursor is None: raise dbt.exceptions.RuntimeException( @@ -224,19 +215,18 @@ def _filter_table(cls, table, manifest): ) return super(SnowflakeAdapter, cls)._filter_table(lowered, manifest) - @classmethod - def _make_match_kwargs(cls, config, schema, identifier): - if identifier is not None and config.quoting['identifier'] is False: + def _make_match_kwargs(self, schema, identifier): + quoting = self.config.quoting + if identifier is not None and quoting['identifier'] is False: identifier = identifier.upper() - if schema is not None and config.quoting['schema'] is False: + if schema is not None and quoting['schema'] is False: schema = schema.upper() return filter_null_values({'identifier': identifier, 'schema': schema}) - @classmethod - def cancel_connection(cls, config, connection): + def cancel_connection(self, connection): handle = connection.handle sid = handle.session_id @@ -246,7 +236,7 @@ def cancel_connection(cls, config, connection): logger.debug("Cancelling query '{}' ({})".format(connection_name, sid)) - _, cursor = cls.add_query(config, sql, 'master') + _, cursor = self.add_query(sql, 'master') res = cursor.fetchone() 
logger.debug("Cancel query '{}': {}".format(connection_name, res)) diff --git a/dbt/context/common.py b/dbt/context/common.py index 8187207270e..d1e20eeafb6 100644 --- a/dbt/context/common.py +++ b/dbt/context/common.py @@ -1,3 +1,5 @@ +import copy +import functools import json import os @@ -24,51 +26,68 @@ import datetime +class RelationProxy(object): + def __init__(self, adapter): + self.config = adapter.config + self.relation_type = adapter.Relation + + def __getattr__(self, key): + return getattr(self.relation_type, key) + + def create(self, *args, **kwargs): + kwargs['quote_policy'] = dbt.utils.merge( + self.config.quoting, + kwargs.pop('quote_policy', {}) + ) + return self.relation_type.create(*args, **kwargs) + + class DatabaseWrapper(object): """ - Wrapper for runtime database interaction. Should only call adapter - functions. + Wrapper for runtime database interaction. Mostly a compatibility layer now. """ - - def __init__(self, model, adapter, config): + def __init__(self, model, adapter): self.model = model self.adapter = adapter - self.config = config - self.Relation = adapter.Relation - - # Fun with metaprogramming - # Most adapter functions take `profile` as the first argument, and - # `model_name` as the last. This automatically injects those arguments. - # In model code, these functions can be called without those two args. - for context_function in self.adapter.context_functions: - setattr(self, - context_function, - self.wrap(context_function, (self.config,))) - - for profile_function in self.adapter.profile_functions: - setattr(self, - profile_function, - self.wrap(profile_function, (self.config,))) - - for raw_function in self.adapter.raw_functions: - setattr(self, - raw_function, - getattr(self.adapter, raw_function)) - - def wrap(self, fn, arg_prefix): + self.Relation = RelationProxy(self.adapter) + + # TODO: clean up this part of the adapter classes + self._wrapped = frozenset( + self.adapter.context_functions + self.adapter.profile_functions + ) + self._proxied = frozenset(self.adapter.raw_functions) + + def wrap(self, name): + func = getattr(self.adapter, name) + + @functools.wraps(func) def wrapped(*args, **kwargs): - args = arg_prefix + args kwargs['model_name'] = self.model.get('name') - return getattr(self.adapter, fn)(*args, **kwargs) + return func(*args, **kwargs) return wrapped + def __getattr__(self, name): + if name in self._wrapped: + return self.wrap(name) + elif name in self._proxied: + return getattr(self.adapter, name) + else: + raise AttributeError( + "'{}' object has no attribute '{}'".format( + self.__class__.__name__, name + ) + ) + + @property + def config(self): + return self.adapter.config + def type(self): return self.adapter.type() def commit(self): - return self.adapter.commit_if_has_connection( - self.config, self.model.get('name')) + return self.adapter.commit_if_has_connection(self.model.get('name')) def _add_macros(context, model, manifest): @@ -307,34 +326,6 @@ def get_this_relation(db_wrapper, config, model): config, model) -def create_relation(relation_type, quoting_config): - - class RelationWithContext(relation_type): - @classmethod - def create(cls, *args, **kwargs): - quote_policy = quoting_config - - if 'quote_policy' in kwargs: - quote_policy = dbt.utils.merge( - quote_policy, - kwargs.pop('quote_policy')) - - return relation_type.create(*args, - quote_policy=quote_policy, - **kwargs) - - return RelationWithContext - - -def create_adapter(adapter_type, relation_type): - - class AdapterWithContext(adapter_type): - - Relation = 
relation_type - - return AdapterWithContext - - def generate_base(model, model_dict, config, manifest, source_config, provider): """Generate the common aspects of the config dict.""" @@ -357,16 +348,12 @@ def generate_base(model, model_dict, config, manifest, source_config, pre_hooks = None post_hooks = None - relation_type = create_relation(adapter.Relation, - config.quoting) + db_wrapper = DatabaseWrapper(model_dict, adapter) - db_wrapper = DatabaseWrapper(model_dict, - create_adapter(adapter, relation_type), - config) context = dbt.utils.merge(context, { "adapter": db_wrapper, "api": { - "Relation": relation_type, + "Relation": db_wrapper.Relation, "Column": adapter.Column, }, "column": adapter.Column, diff --git a/dbt/node_runners.py b/dbt/node_runners.py index f67618937a2..20a25a835e1 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -144,7 +144,7 @@ def _safe_release_connection(self): """ node_name = self.node.name try: - self.adapter.release_connection(self.config, node_name) + self.adapter.release_connection(node_name) except Exception as exc: logger.debug( 'Error releasing connection for node {}: {!s}\n{}' @@ -229,7 +229,7 @@ def compile(self, manifest): def _compile_node(cls, adapter, config, node, manifest): compiler = dbt.compilation.Compiler(config) node = compiler.compile_node(node, manifest) - node = cls._inject_runtime_config(adapter, config, node) + node = cls._inject_runtime_config(adapter, node) if(node.injected_sql is not None and not (dbt.utils.is_type(node, NodeType.Archive))): @@ -247,30 +247,29 @@ def _compile_node(cls, adapter, config, node, manifest): return node @classmethod - def _inject_runtime_config(cls, adapter, config, node): + def _inject_runtime_config(cls, adapter, node): wrapped_sql = node.wrapped_sql - context = cls._node_context(adapter, config, node) + context = cls._node_context(adapter, node) sql = dbt.clients.jinja.get_rendered(wrapped_sql, context) node.wrapped_sql = sql return node @classmethod - def _node_context(cls, adapter, config, node): + def _node_context(cls, adapter, node): def call_get_columns_in_table(schema_name, table_name): return adapter.get_columns_in_table( - config, schema_name, - table_name, model_name=node.alias) + schema_name, table_name, model_name=node.alias + ) def call_get_missing_columns(from_schema, from_table, to_schema, to_table): return adapter.get_missing_columns( - config, from_schema, from_table, - to_schema, to_table, node.alias) + from_schema, from_table, to_schema, to_table, node.alias + ) def call_already_exists(schema, table): - return adapter.already_exists( - config, schema, table, node.alias) + return adapter.already_exists(schema, table, node.alias) return { "run_started_at": dbt.tracking.active_user.run_started_at, @@ -304,7 +303,7 @@ def run_hooks(cls, config, adapter, manifest, hook_type): # implement a for-loop over these sql statements in jinja-land. # Also, consider configuring psycopg2 (and other adapters?) to # ensure that a transaction is only created if dbt initiates it. 
- adapter.clear_transaction(config, model_name) + adapter.clear_transaction(model_name) compiled = cls._compile_node(adapter, config, hook, manifest) statement = compiled.wrapped_sql @@ -317,10 +316,10 @@ def run_hooks(cls, config, adapter, manifest, hook_type): sql = hook_dict.get('sql', '') if len(sql.strip()) > 0: - adapter.execute(config, sql, model_name=model_name, - auto_begin=False, fetch=False) + adapter.execute(sql, model_name=model_name, auto_begin=False, + fetch=False) - adapter.release_connection(config, model_name) + adapter.release_connection(model_name) @classmethod def safe_run_hooks(cls, config, adapter, manifest, hook_type): @@ -339,12 +338,12 @@ def create_schemas(cls, config, adapter, manifest): # is the one defined in the profile. Create this schema if it # does not exist, otherwise subsequent queries will fail. Generally, # dbt expects that this schema will exist anyway. - required_schemas.add(adapter.get_default_schema(config)) + required_schemas.add(adapter.get_default_schema()) - existing_schemas = set(adapter.get_existing_schemas(config)) + existing_schemas = set(adapter.get_existing_schemas()) for schema in (required_schemas - existing_schemas): - adapter.create_schema(config, schema) + adapter.create_schema(schema) @classmethod def before_run(cls, config, adapter, manifest): @@ -443,7 +442,6 @@ def print_start_line(self): def execute_test(self, test): res, table = self.adapter.execute_and_fetch( - self.config, test.wrapped_sql, test.name, auto_begin=True) diff --git a/dbt/parser/base.py b/dbt/parser/base.py index 1ee4bb15178..89ca46a5c33 100644 --- a/dbt/parser/base.py +++ b/dbt/parser/base.py @@ -119,7 +119,7 @@ def parse_node(cls, node, node_path, root_project_config, db_wrapper = context['adapter'] adapter = db_wrapper.adapter runtime_config = db_wrapper.config - adapter.release_connection(runtime_config, parsed_node.name) + adapter.release_connection(parsed_node.name) # Special macro defined in the global project schema_override = config.config.get('schema') diff --git a/dbt/runner.py b/dbt/runner.py index 9a359b17c86..e968ca7d83b 100644 --- a/dbt/runner.py +++ b/dbt/runner.py @@ -150,7 +150,7 @@ def execute_nodes(self, linker, Runner, manifest, node_dependency_list): dbt.ui.printer.print_timestamped_line(msg, yellow) raise - for conn_name in adapter.cancel_open_connections(self.config): + for conn_name in adapter.cancel_open_connections(): dbt.ui.printer.print_cancel_line(conn_name) dbt.ui.printer.print_run_end_messages(node_results, diff --git a/dbt/task/generate.py b/dbt/task/generate.py index 05808f1016f..6cc1df5e02c 100644 --- a/dbt/task/generate.py +++ b/dbt/task/generate.py @@ -213,7 +213,7 @@ def run(self): adapter = get_adapter(self.config) dbt.ui.printer.print_timestamped_line("Building catalog") - results = adapter.get_catalog(self.config, manifest) + results = adapter.get_catalog(manifest) results = [ dict(zip(results.column_names, row)) diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index d1f0f402153..f00283e539f 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -11,38 +11,56 @@ fake_conn = {"handle": None, "state": "open", "type": "bigquery"} +from .utils import config_from_parts_or_dicts + class TestBigQueryAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True + self.raw_profile = { + 'outputs': { + 'oauth': { + 'type': 'bigquery', + 'method': 'oauth', + 'project': 'dbt-unit-000000', + 'schema': 'dummy_schema', + 'threads': 1, + }, + 
'service_account': { + 'type': 'bigquery', + 'method': 'service-account', + 'project': 'dbt-unit-000000', + 'schema': 'dummy_schema', + 'keyfile': '/tmp/dummy-service-account.json', + 'threads': 1, + }, + }, + 'target': 'oauth', + } - self.oauth_credentials = BigQueryCredentials( - method='oauth', - project='dbt-unit-000000', - schema='dummy_schema' - ) - self.oauth_profile = MagicMock( - credentials=self.oauth_credentials, - threads=1 - ) + self.project_cfg = { + 'name': 'X', + 'version': '0.1', + 'project-root': '/tmp/dbt/does-not-exist', + } - self.service_account_credentials = BigQueryCredentials( - method='service-account', - project='dbt-unit-000000', - schema='dummy_schema', - keyfile='/tmp/dummy-service-account.json' - ) - self.service_account_profile = MagicMock( - credentials=self.service_account_credentials, - threads=1 + def get_adapter(self, profile): + project = self.project_cfg.copy() + project['profile'] = profile + + config = config_from_parts_or_dicts( + project=project, + profile=self.raw_profile, ) + return BigQueryAdapter(config) @patch('dbt.adapters.bigquery.BigQueryAdapter.open_connection', return_value=fake_conn) def test_acquire_connection_oauth_validations(self, mock_open_connection): + adapter = self.get_adapter('oauth') try: - connection = BigQueryAdapter.acquire_connection(self.oauth_profile, 'dummy') + connection = adapter.acquire_connection('dummy') self.assertEquals(connection.get('type'), 'bigquery') except dbt.exceptions.ValidationException as e: @@ -56,8 +74,9 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection): @patch('dbt.adapters.bigquery.BigQueryAdapter.open_connection', return_value=fake_conn) def test_acquire_connection_service_account_validations(self, mock_open_connection): + adapter = self.get_adapter('service_account') try: - connection = BigQueryAdapter.acquire_connection(self.service_account_profile, 'dummy') + connection = adapter.acquire_connection('dummy') self.assertEquals(connection.get('type'), 'bigquery') except dbt.exceptions.ValidationException as e: diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index a3ac793411d..7dea6a40a3c 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -40,10 +40,13 @@ def setUp(self): self.config = config_from_parts_or_dicts(project_cfg, profile_cfg) + @property + def adapter(self): + return PostgresAdapter(self.config) + def test_acquire_connection_validations(self): try: - connection = PostgresAdapter.acquire_connection(self.config, - 'dummy') + connection = self.adapter.acquire_connection('dummy') self.assertEquals(connection.type, 'postgres') except ValidationException as e: self.fail('got ValidationException: {}'.format(str(e))) @@ -52,14 +55,14 @@ def test_acquire_connection_validations(self): .format(str(e))) def test_acquire_connection(self): - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + connection = self.adapter.acquire_connection('dummy') self.assertEquals(connection.state, 'open') self.assertNotEquals(connection.handle, None) @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_default_keepalive(self, psycopg2): - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -71,9 +74,11 @@ def test_default_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_changed_keepalive(self, psycopg2): - 
credentials = self.config.credentials.incorporate(keepalives_idle=256) - self.config.credentials = credentials - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + credentials = self.adapter.config.credentials.incorporate( + keepalives_idle=256 + ) + self.adapter.config.credentials = credentials + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -88,7 +93,7 @@ def test_changed_keepalive(self, psycopg2): def test_set_zero_keepalive(self, psycopg2): credentials = self.config.credentials.incorporate(keepalives_idle=0) self.config.credentials = credentials - connection = PostgresAdapter.acquire_connection(self.config, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='postgres', @@ -121,7 +126,7 @@ def test_get_catalog_various_schemas(self, mock_run): # give manifest the dict it wants mock_manifest = mock.MagicMock(spec_set=['nodes'], nodes=nodes) - catalog = PostgresAdapter.get_catalog(mock.MagicMock(), mock_manifest) + catalog = self.adapter.get_catalog(mock_manifest) self.assertEqual( set(map(tuple, catalog)), {('foo', 'bar'), ('FOO', 'baz'), ('quux', 'bar')} @@ -166,26 +171,23 @@ def setUp(self): self.psycopg2 = self.patcher.start() self.psycopg2.connect.return_value = self.handle - conn = PostgresAdapter.get_connection(self.config) + self.adapter = PostgresAdapter(self.config) + self.adapter.get_connection() def tearDown(self): # we want a unique self.handle every time. - PostgresAdapter.cleanup_connections() + self.adapter.cleanup_connections() self.patcher.stop() def test_quoting_on_drop_schema(self): - PostgresAdapter.drop_schema( - config=self.config, - schema='test_schema' - ) + self.adapter.drop_schema(schema='test_schema') self.mock_execute.assert_has_calls([ mock.call('drop schema if exists "test_schema" cascade', None) ]) def test_quoting_on_drop(self): - PostgresAdapter.drop( - config=self.config, + self.adapter.drop( schema='test_schema', relation='test_table', relation_type='table' @@ -195,8 +197,7 @@ def test_quoting_on_drop(self): ]) def test_quoting_on_truncate(self): - PostgresAdapter.truncate( - config=self.config, + self.adapter.truncate( schema='test_schema', table='test_table' ) @@ -205,8 +206,7 @@ def test_quoting_on_truncate(self): ]) def test_quoting_on_rename(self): - PostgresAdapter.rename( - config=self.config, + self.adapter.rename( schema='test_schema', from_name='table_a', to_name='table_b' diff --git a/test/unit/test_redshift_adapter.py b/test/unit/test_redshift_adapter.py index 4b713737d3d..d1b6559178d 100644 --- a/test/unit/test_redshift_adapter.py +++ b/test/unit/test_redshift_adapter.py @@ -8,7 +8,8 @@ from dbt.adapters.redshift import RedshiftAdapter from dbt.exceptions import ValidationException, FailedToConnectException from dbt.logger import GLOBAL_LOGGER as logger # noqa -from dbt.config import Profile + +from .utils import config_from_parts_or_dicts @classmethod @@ -23,7 +24,7 @@ class TestRedshiftAdapter(unittest.TestCase): def setUp(self): flags.STRICT_MODE = True - self.profile = Profile.from_raw_profile_info({ + profile_cfg = { 'outputs': { 'test': { 'type': 'redshift', @@ -36,55 +37,72 @@ def setUp(self): } }, 'target': 'test' - }, 'test') + } + + project_cfg = { + 'name': 'X', + 'version': '0.1', + 'profile': 'test', + 'project-root': '/tmp/dbt/does-not-exist', + 'quoting': { + 'identifier': False, + 'schema': True, + } + } + + self.config = config_from_parts_or_dicts(project_cfg, 
profile_cfg) + + @property + def adapter(self): + return RedshiftAdapter(self.config) def test_implicit_database_conn(self): - creds = RedshiftAdapter.get_credentials(self.profile.credentials) - self.assertEquals(creds, self.profile.credentials) + creds = RedshiftAdapter.get_credentials(self.config.credentials) + self.assertEquals(creds, self.config.credentials) def test_explicit_database_conn(self): - self.profile.method = 'database' + self.config.method = 'database' - creds = RedshiftAdapter.get_credentials(self.profile.credentials) - self.assertEquals(creds, self.profile.credentials) + creds = RedshiftAdapter.get_credentials(self.config.credentials) + self.assertEquals(creds, self.config.credentials) def test_explicit_iam_conn(self): - self.profile.credentials = self.profile.credentials.incorporate( + self.config.credentials = self.config.credentials.incorporate( method='iam', cluster_id='my_redshift', iam_duration_seconds=1200 ) with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - creds = RedshiftAdapter.get_credentials(self.profile.credentials) + creds = RedshiftAdapter.get_credentials(self.config.credentials) - expected_creds = self.profile.credentials.incorporate(password='tmp_password') + expected_creds = self.config.credentials.incorporate(password='tmp_password') self.assertEquals(creds, expected_creds) def test_invalid_auth_method(self): # we have to set method this way, otherwise it won't validate - self.profile.credentials._contents['method'] = 'badmethod' + self.config.credentials._contents['method'] = 'badmethod' with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.get_credentials(self.profile.credentials) + RedshiftAdapter.get_credentials(self.config.credentials) self.assertTrue('badmethod' in context.exception.msg) def test_invalid_iam_no_cluster_id(self): - self.profile.credentials = self.profile.credentials.incorporate( + self.config.credentials = self.config.credentials.incorporate( method='iam' ) with self.assertRaises(dbt.exceptions.FailedToConnectException) as context: with mock.patch.object(RedshiftAdapter, 'fetch_cluster_credentials', new=fetch_cluster_credentials): - RedshiftAdapter.get_credentials(self.profile.credentials) + RedshiftAdapter.get_credentials(self.config.credentials) self.assertTrue("'cluster_id' must be provided" in context.exception.msg) @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_default_keepalive(self, psycopg2): - connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='redshift', @@ -97,10 +115,10 @@ def test_default_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_changed_keepalive(self, psycopg2): - self.profile.credentials = self.profile.credentials.incorporate( + self.config.credentials = self.config.credentials.incorporate( keepalives_idle=256 ) - connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='redshift', @@ -113,10 +131,10 @@ def test_changed_keepalive(self, psycopg2): @mock.patch('dbt.adapters.postgres.impl.psycopg2') def test_set_zero_keepalive(self, psycopg2): - self.profile.credentials = self.profile.credentials.incorporate( + 
self.config.credentials = self.config.credentials.incorporate( keepalives_idle=0 ) - connection = RedshiftAdapter.acquire_connection(self.profile, 'dummy') + connection = self.adapter.acquire_connection('dummy') psycopg2.connect.assert_called_once_with( dbname='redshift', diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index 9aa290c7f57..d85086c04ac 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -49,16 +49,16 @@ def setUp(self): self.snowflake = self.patcher.start() self.snowflake.return_value = self.handle - conn = SnowflakeAdapter.get_connection(self.config) + self.adapter = SnowflakeAdapter(self.config) + self.adapter.get_connection() def tearDown(self): # we want a unique self.handle every time. - SnowflakeAdapter.cleanup_connections() + self.adapter.cleanup_connections() self.patcher.stop() def test_quoting_on_drop_schema(self): - SnowflakeAdapter.drop_schema( - config=self.config, + self.adapter.drop_schema( schema='test_schema' ) @@ -67,8 +67,7 @@ def test_quoting_on_drop_schema(self): ]) def test_quoting_on_drop(self): - SnowflakeAdapter.drop( - config=self.config, + self.adapter.drop( schema='test_schema', relation='test_table', relation_type='table' @@ -78,8 +77,7 @@ def test_quoting_on_drop(self): ]) def test_quoting_on_truncate(self): - SnowflakeAdapter.truncate( - config=self.config, + self.adapter.truncate( schema='test_schema', table='test_table' ) @@ -88,8 +86,7 @@ def test_quoting_on_truncate(self): ]) def test_quoting_on_rename(self): - SnowflakeAdapter.rename( - config=self.config, + self.adapter.rename( schema='test_schema', from_name='table_a', to_name='table_b' From c0ce5cb3e387085f5865d85ef2d4d73c6a18e7b2 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 14 Sep 2018 15:25:42 -0600 Subject: [PATCH 019/133] fix some tests and shut up some warnings --- .../test_simple_archive.py | 4 ++-- .../test_docs_generate.py | 17 ++++++++++------ test/integration/base.py | 20 +++++++------------ 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/test/integration/004_simple_archive_test/test_simple_archive.py b/test/integration/004_simple_archive_test/test_simple_archive.py index 970990ecaef..b6924bd7c40 100644 --- a/test/integration/004_simple_archive_test/test_simple_archive.py +++ b/test/integration/004_simple_archive_test/test_simple_archive.py @@ -162,8 +162,8 @@ def test__bigquery__archive_with_new_field(self): # A more thorough test would assert that archived == expected, but BigQuery does not support the # "EXCEPT DISTINCT" operator on nested fields! Instead, just check that schemas are congruent. 
- expected_cols = self.adapter.get_columns_in_table(self.config, self.unique_schema(), 'archive_expected') - archived_cols = self.adapter.get_columns_in_table(self.config, self.unique_schema(), 'archive_actual') + expected_cols = self.adapter.get_columns_in_table(self.unique_schema(), 'archive_expected') + archived_cols = self.adapter.get_columns_in_table(self.unique_schema(), 'archive_actual') self.assertTrue(len(expected_cols) > 0, "source table does not exist -- bad test") self.assertEqual(len(expected_cols), len(archived_cols), "actual and expected column lengths are different") diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index ca67a3f6ae3..955e324eb27 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -26,6 +26,11 @@ def __eq__(self, other): return isinstance(other, basestring) and self.contains in other +def _read_file(path): + with open(path) as fp: + return fp.read() + + class TestDocsGenerate(DBTIntegrationTest): def setUp(self): super(TestDocsGenerate,self).setUp() @@ -729,7 +734,7 @@ def expected_seeded_manifest(self): 'path': 'model.sql', 'original_file_path': model_sql_path, 'package_name': 'test', - 'raw_sql': open(model_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(model_sql_path).rstrip('\n'), 'refs': [['seed']], 'depends_on': {'nodes': ['seed.test.seed'], 'macros': []}, 'unique_id': 'model.test.model', @@ -920,7 +925,7 @@ def expected_seeded_manifest(self): def expected_postgres_references_manifest(self): my_schema_name = self.unique_schema() docs_path = self.dir('ref_models/docs.md') - docs_file = open(docs_path).read().lstrip() + docs_file = _read_file(docs_path).lstrip() return { 'nodes': { 'model.test.ephemeral_copy': { @@ -1204,7 +1209,7 @@ def expected_bigquery_complex_manifest(self): 'original_file_path': clustered_sql_path, 'package_name': 'test', 'path': 'clustered.sql', - 'raw_sql': open(clustered_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(clustered_sql_path).rstrip('\n'), 'refs': [['seed']], 'resource_type': 'model', 'root_path': os.getcwd(), @@ -1258,7 +1263,7 @@ def expected_bigquery_complex_manifest(self): 'original_file_path': nested_view_sql_path, 'package_name': 'test', 'path': 'nested_view.sql', - 'raw_sql': open(nested_view_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(nested_view_sql_path).rstrip('\n'), 'refs': [['nested_table']], 'resource_type': 'model', 'root_path': os.getcwd(), @@ -1312,7 +1317,7 @@ def expected_bigquery_complex_manifest(self): 'original_file_path': nested_table_sql_path, 'package_name': 'test', 'path': 'nested_table.sql', - 'raw_sql': open(nested_table_sql_path).read().rstrip('\n'), + 'raw_sql': _read_file(nested_table_sql_path).rstrip('\n'), 'refs': [], 'resource_type': 'model', 'root_path': os.getcwd(), @@ -1388,7 +1393,7 @@ def expected_redshift_incremental_view_manifest(self): "path": "model.sql", "original_file_path": model_sql_path, "package_name": "test", - "raw_sql": open(model_sql_path).read().rstrip('\n'), + "raw_sql": _read_file(model_sql_path).rstrip('\n'), "refs": [["seed"]], "depends_on": { "nodes": ["seed.test.seed"], diff --git a/test/integration/base.py b/test/integration/base.py index e83bc2dc32b..4a1f23cafc7 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -267,7 +267,7 @@ def load_config(self): adapter = get_adapter(config) adapter.cleanup_connections() - connection = 
adapter.acquire_connection(config, '__test') + connection = adapter.acquire_connection('__test') self.handle = connection.handle self.adapter_type = connection.type self.adapter = adapter @@ -277,9 +277,7 @@ def load_config(self): self._create_schema() def quote_as_configured(self, value, quote_key): - return self.adapter.quote_as_configured( - self.config, value, quote_key - ) + return self.adapter.quote_as_configured(value, quote_key) def _clean_files(self): if os.path.exists(DBT_PROFILES): @@ -310,7 +308,7 @@ def tearDown(self): def _create_schema(self): if self.adapter_type == 'bigquery': - self.adapter.create_schema(self.config, self.unique_schema(), '__test') + self.adapter.create_schema(self.unique_schema(), '__test') else: schema = self.quote_as_configured(self.unique_schema(), 'schema') self.run_sql('CREATE SCHEMA {}'.format(schema)) @@ -318,7 +316,7 @@ def _create_schema(self): def _drop_schema(self): if self.adapter_type == 'bigquery': - self.adapter.drop_schema(self.config, self.unique_schema(), '__test') + self.adapter.drop_schema(self.unique_schema(), '__test') else: had_existing = False try: @@ -385,9 +383,8 @@ def run_sql_bigquery(self, sql, fetch): """Run an SQL query on a bigquery adapter. No cursors, transactions, etc. to worry about""" - adapter = get_adapter(self.config) do_fetch = fetch != 'None' - _, res = adapter.execute(self.config, sql, fetch=do_fetch) + _, res = self.adapter.execute(sql, fetch=do_fetch) # convert dataframe to matrix-ish repr if fetch == 'one': @@ -449,11 +446,8 @@ def filter_many_columns(self, column): def get_table_columns(self, table, schema=None): schema = self.unique_schema() if schema is None else schema - columns = self.adapter.get_columns_in_table( - self.config, - schema, - table - ) + columns = self.adapter.get_columns_in_table(schema, table) + return sorted(((c.name, c.dtype, c.char_size) for c in columns), key=lambda x: x[0]) From 5c60f18146be919079d0ffafece8f167355129bf Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 17 Sep 2018 11:26:37 -0600 Subject: [PATCH 020/133] move quoting all the way into the project config, make tests reset adapters --- dbt/adapters/bigquery/relation.py | 2 +- dbt/adapters/default/relation.py | 17 ++++++++++++-- dbt/adapters/factory.py | 5 ++++- dbt/adapters/snowflake/relation.py | 3 ++- dbt/config.py | 22 ++++++------------- dbt/context/common.py | 9 ++++---- dbt/contracts/project.py | 6 +++++ .../001_simple_copy_test/test_simple_copy.py | 1 + test/unit/test_config.py | 8 +++++-- 9 files changed, 46 insertions(+), 27 deletions(-) diff --git a/dbt/adapters/bigquery/relation.py b/dbt/adapters/bigquery/relation.py index f5807dc8e3e..d962d493d72 100644 --- a/dbt/adapters/bigquery/relation.py +++ b/dbt/adapters/bigquery/relation.py @@ -86,7 +86,7 @@ def matches(self, project=None, schema=None, identifier=None): return True @classmethod - def create_from_node(cls, config, node, **kwargs): + def _create_from_node(cls, config, node, **kwargs): return cls.create( project=config.credentials.project, schema=node.get('schema'), diff --git a/dbt/adapters/default/relation.py b/dbt/adapters/default/relation.py index b5d1f46d38f..03928bc3939 100644 --- a/dbt/adapters/default/relation.py +++ b/dbt/adapters/default/relation.py @@ -173,14 +173,27 @@ def quoted(self, identifier): identifier=identifier) @classmethod - def create_from_node(cls, project, node, table_name=None, **kwargs): + def _create_from_node(cls, config, node, table_name, quote_policy, + **kwargs): return cls.create( - database=project.credentials.dbname, 
+ database=config.credentials.dbname, schema=node.get('schema'), identifier=node.get('alias'), table_name=table_name, + quote_policy=quote_policy, **kwargs) + @classmethod + def create_from_node(cls, config, node, table_name=None, quote_policy=None, + **kwargs): + if quote_policy is None: + quote_policy = {} + + quote_policy = dbt.utils.merge(config.quoting, quote_policy) + return cls._create_from_node(config=config, quote_policy=quote_policy, + node=node, table_name=table_name, + **kwargs) + @classmethod def create(cls, database=None, schema=None, identifier=None, table_name=None, diff --git a/dbt/adapters/factory.py b/dbt/adapters/factory.py index 12ac37ca3d9..4ddc923074d 100644 --- a/dbt/adapters/factory.py +++ b/dbt/adapters/factory.py @@ -34,7 +34,6 @@ def get_adapter_class_by_name(adapter_name): return adapter - def get_adapter(config): adapter_name = config.credentials.type if adapter_name in _ADAPTERS: @@ -50,3 +49,7 @@ def get_adapter(config): _ADAPTERS[adapter_name] = adapter return adapter + +def get_relation_class_by_name(adapter_name): + adapter = get_adapter_class_by_name(adapter_name) + return adapter.Relation diff --git a/dbt/adapters/snowflake/relation.py b/dbt/adapters/snowflake/relation.py index bd879965404..bf5b61c6485 100644 --- a/dbt/adapters/snowflake/relation.py +++ b/dbt/adapters/snowflake/relation.py @@ -1,4 +1,5 @@ from dbt.adapters.default.relation import DefaultRelation +import dbt.utils class SnowflakeRelation(DefaultRelation): @@ -44,7 +45,7 @@ class SnowflakeRelation(DefaultRelation): } @classmethod - def create_from_node(cls, config, node, **kwargs): + def _create_from_node(cls, config, node, **kwargs): return cls.create( database=config.credentials.database, schema=node.get('schema'), diff --git a/dbt/config.py b/dbt/config.py index d56826f5daf..cadb7d8411f 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -12,6 +12,7 @@ PackageConfig, ProfileConfig from dbt.context.common import env_var from dbt import compat +from dbt.adapters.factory import get_relation_class_by_name from dbt.logger import GLOBAL_LOGGER as logger @@ -19,18 +20,6 @@ DEFAULT_THREADS = 1 DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True DEFAULT_USE_COLORS = True -DEFAULT_QUOTING_GLOBAL = { - 'identifier': True, - 'schema': True, -} -# some adapters need different quoting rules, for example snowflake gets a bit -# weird with quoting on -DEFAULT_QUOTING_ADAPTER = { - 'snowflake': { - 'identifier': False, - 'schema': False, - }, -} DEFAULT_PROFILES_DIR = os.path.join(os.path.expanduser('~'), '.dbt') @@ -676,8 +665,8 @@ def from_parts(cls, project, profile, cli_vars): :returns RuntimeConfig: The new configuration. """ quoting = deepcopy( - DEFAULT_QUOTING_ADAPTER.get(profile.credentials.type, - DEFAULT_QUOTING_GLOBAL) + get_relation_class_by_name(profile.credentials.type) + .DEFAULTS['quote_policy'] ) quoting.update(project.quoting) return cls( @@ -725,11 +714,14 @@ def new_project(self, project_root): # load the new project and its packages project = Project.from_project_root(project_root) - return self.from_parts( + cfg = self.from_parts( project=project, profile=profile, cli_vars=deepcopy(self.cli_vars) ) + # force our quoting back onto the new project. + cfg.quoting = deepcopy(self.quoting) + return cfg def serialize(self): """Serialize the full configuration to a single dictionary. 
For any diff --git a/dbt/context/common.py b/dbt/context/common.py index d1e20eeafb6..917747d08f6 100644 --- a/dbt/context/common.py +++ b/dbt/context/common.py @@ -28,7 +28,7 @@ class RelationProxy(object): def __init__(self, adapter): - self.config = adapter.config + self.quoting_config = adapter.config.quoting self.relation_type = adapter.Relation def __getattr__(self, key): @@ -36,7 +36,7 @@ def __getattr__(self, key): def create(self, *args, **kwargs): kwargs['quote_policy'] = dbt.utils.merge( - self.config.quoting, + self.quoting_config, kwargs.pop('quote_policy', {}) ) return self.relation_type.create(*args, **kwargs) @@ -49,7 +49,7 @@ class DatabaseWrapper(object): def __init__(self, model, adapter): self.model = model self.adapter = adapter - self.Relation = RelationProxy(self.adapter) + self.Relation = RelationProxy(adapter) # TODO: clean up this part of the adapter classes self._wrapped = frozenset( @@ -322,8 +322,7 @@ def _return(value): def get_this_relation(db_wrapper, config, model): - return db_wrapper.adapter.Relation.create_from_node( - config, model) + return db_wrapper.Relation.create_from_node(config, model) def generate_base(model, model_dict, config, manifest, source_config, diff --git a/dbt/contracts/project.py b/dbt/contracts/project.py index f46c51b2107..6e69ac2b1b3 100644 --- a/dbt/contracts/project.py +++ b/dbt/contracts/project.py @@ -115,6 +115,12 @@ 'schema': { 'type': 'boolean', }, + 'database': { + 'type': 'boolean', + }, + 'project': { + 'type': 'boolean', + } }, }, 'models': { diff --git a/test/integration/001_simple_copy_test/test_simple_copy.py b/test/integration/001_simple_copy_test/test_simple_copy.py index 6f9930827d3..7cc4ed580e4 100644 --- a/test/integration/001_simple_copy_test/test_simple_copy.py +++ b/test/integration/001_simple_copy_test/test_simple_copy.py @@ -16,6 +16,7 @@ def dir(path): def models(self): return self.dir("models") + class TestSimpleCopy(BaseTestSimpleCopy): @use_profile("postgres") def test__postgres__simple_copy(self): diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 4c0d3d3a819..67b7cc26d2b 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -724,7 +724,11 @@ def test_from_parts(self): expected_project = project.to_project_config() self.assertEqual(expected_project['quoting'], {}) - expected_project['quoting'] = {'identifier': True, 'schema': True} + expected_project['quoting'] = { + 'database': True, + 'identifier': True, + 'schema': True, + } self.assertEqual(config.to_project_config(), expected_project) def test_str(self): @@ -771,7 +775,7 @@ def test_from_args(self): self.assertEqual(config.clean_targets, ['target']) self.assertEqual(config.log_path, 'logs') self.assertEqual(config.modules_path, 'dbt_modules') - self.assertEqual(config.quoting, {'identifier': True, 'schema': True}) + self.assertEqual(config.quoting, {'database': True, 'identifier': True, 'schema': True}) self.assertEqual(config.models, {}) self.assertEqual(config.on_run_start, []) self.assertEqual(config.on_run_end, []) From 0b0e9e02e785adb6d6a75c758b2f967e3c2068ea Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 17 Sep 2018 11:27:05 -0600 Subject: [PATCH 021/133] add a way for tests to reset the adapters known to dbt between runs --- dbt/adapters/factory.py | 7 +++++++ test/integration/base.py | 8 ++++++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dbt/adapters/factory.py b/dbt/adapters/factory.py index 4ddc923074d..b1ec2bbcd0f 100644 --- a/dbt/adapters/factory.py +++ 
b/dbt/adapters/factory.py @@ -50,6 +50,13 @@ def get_adapter(config): return adapter +def reset_adapters(): + """Clear the adapters. This is useful for tests, which change configs. + """ + with _ADAPTER_LOCK: + _ADAPTERS.clear() + + def get_relation_class_by_name(adapter_name): adapter = get_adapter_class_by_name(adapter_name) return adapter.Relation diff --git a/test/integration/base.py b/test/integration/base.py index 4a1f23cafc7..9dc0040b95b 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -11,7 +11,7 @@ import dbt.flags as flags -from dbt.adapters.factory import get_adapter +from dbt.adapters.factory import get_adapter, reset_adapters from dbt.config import RuntimeConfig from dbt.logger import GLOBAL_LOGGER as logger @@ -296,7 +296,8 @@ def _clean_files(self): def tearDown(self): self._clean_files() - self.adapter = get_adapter(self.config) + if not hasattr(self, 'adapter'): + self.adapter = get_adapter(self.config) self._drop_schema() @@ -305,6 +306,7 @@ def tearDown(self): self.handle.close() self.adapter.cleanup_connections() + reset_adapters() def _create_schema(self): if self.adapter_type == 'bigquery': @@ -340,6 +342,8 @@ def profile_config(self): return {} def run_dbt(self, args=None, expect_pass=True, strict=True): + # clear the adapter cache + reset_adapters() if args is None: args = ["run"] From 63793b74f2895ff3e6c1d5ce34cb895e4048ace3 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 17 Sep 2018 14:09:21 -0600 Subject: [PATCH 022/133] combine profile and context functions --- dbt/adapters/bigquery/impl.py | 7 +++++-- dbt/adapters/default/impl.py | 5 ++--- dbt/context/common.py | 3 +-- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index 1757c564a39..f0722e40d7b 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -28,7 +28,7 @@ class BigQueryAdapter(PostgresAdapter): - context_functions = [ + config_functions = [ # deprecated -- use versions that take relations instead "query_for_existing", "execute_model", @@ -51,7 +51,10 @@ class BigQueryAdapter(PostgresAdapter): "drop_relation", "rename_relation", - "get_columns_in_table" + "get_columns_in_table", + + # formerly profile functions + "add_query", ] SCOPE = ('https://www.googleapis.com/auth/bigquery', diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index 3a879ec36de..8fe29e4a684 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -52,7 +52,7 @@ def test(row): class DefaultAdapter(object): requires = {} - context_functions = [ + config_functions = [ "get_columns_in_table", "get_missing_columns", "expand_target_column_types", @@ -75,9 +75,8 @@ class DefaultAdapter(object): "drop_relation", "rename_relation", "truncate_relation", - ] - profile_functions = [ + # formerly profile functions "execute", "add_query", ] diff --git a/dbt/context/common.py b/dbt/context/common.py index 917747d08f6..0c5dbc23ff6 100644 --- a/dbt/context/common.py +++ b/dbt/context/common.py @@ -51,9 +51,8 @@ def __init__(self, model, adapter): self.adapter = adapter self.Relation = RelationProxy(adapter) - # TODO: clean up this part of the adapter classes self._wrapped = frozenset( - self.adapter.context_functions + self.adapter.profile_functions + self.adapter.config_functions ) self._proxied = frozenset(self.adapter.raw_functions) From 665264723d078fd0b7125c3c1da6b37c4382532e Mon Sep 17 00:00:00 2001 From: Ben Edwards Date: Thu, 20 Sep 2018 14:14:36 +1000 Subject: [PATCH 
023/133] Add newline around SQL in incremental materialization to guard against line comments. --- .../macros/materializations/incremental/incremental.sql | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dbt/include/global_project/macros/materializations/incremental/incremental.sql b/dbt/include/global_project/macros/materializations/incremental/incremental.sql index ef4ab3ae43e..6d32e503631 100644 --- a/dbt/include/global_project/macros/materializations/incremental/incremental.sql +++ b/dbt/include/global_project/macros/materializations/incremental/incremental.sql @@ -60,7 +60,9 @@ {% set tmp_table_sql -%} {# We are using a subselect instead of a CTE here to allow PostgreSQL to use indexes. -#} - select * from ({{ sql }}) as dbt_incr_sbq + select * from ( + {{ sql }} + ) as dbt_incr_sbq where ({{ sql_where }}) or ({{ sql_where }}) is null {%- endset %} From 2cb73945835227000fb7d9a7fa43d8bef529887b Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 26 Sep 2018 09:56:56 -0600 Subject: [PATCH 024/133] move get_resource_fqns method into the manifest, add tests, add a missing import --- dbt/compilation.py | 13 +------ dbt/config.py | 2 +- dbt/contracts/graph/manifest.py | 10 +++++ test/unit/test_config.py | 66 +++++++++++++++++++++++++++++++++ test/unit/test_manifest.py | 44 +++++++++++++++++++++- 5 files changed, 121 insertions(+), 14 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index 20657de019d..685273b7683 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -232,16 +232,6 @@ def _check_resource_uniqueness(cls, manifest): names_resources[name] = node alias_resources[alias] = node - def get_resource_fqns(self, manifest): - resource_fqns = {} - for unique_id, node in manifest.nodes.items(): - resource_type_plural = node.resource_type + 's' - if resource_type_plural not in resource_fqns: - resource_fqns[resource_type_plural] = [] - resource_fqns[resource_type_plural].append(node.fqn) - - return resource_fqns - def compile(self): linker = Linker() @@ -253,8 +243,7 @@ def compile(self): self._check_resource_uniqueness(manifest) - resource_fqns = self.get_resource_fqns(manifest) - + resource_fqns = manifest.get_resource_fqns() self.config.warn_for_unused_resource_config_paths(resource_fqns) self.link_graph(linker, manifest) diff --git a/dbt/config.py b/dbt/config.py index 8a77f821b48..5e51c3f1a74 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -16,7 +16,7 @@ from dbt.logger import GLOBAL_LOGGER as logger from dbt.utils import DBTConfigKeys - +import dbt.ui.printer DEFAULT_THREADS = 1 DEFAULT_SEND_ANONYMOUS_USAGE_STATS = True diff --git a/dbt/contracts/graph/manifest.py b/dbt/contracts/graph/manifest.py index 9f99e0b7a8c..60e59d7c761 100644 --- a/dbt/contracts/graph/manifest.py +++ b/dbt/contracts/graph/manifest.py @@ -284,6 +284,16 @@ def get_materialization_macro(self, materialization_name, return macro + def get_resource_fqns(self): + resource_fqns = {} + for unique_id, node in self.nodes.items(): + resource_type_plural = node.resource_type + 's' + if resource_type_plural not in resource_fqns: + resource_fqns[resource_type_plural] = [] + resource_fqns[resource_type_plural].append(node.fqn) + + return resource_fqns + def _filter_subgraph(self, subgraph, predicate): """ Given a subgraph of the manifest, and a predicate, filter diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 2a18e94292d..7dcd51dd0dd 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -736,6 +736,72 @@ def 
test__unused_resource_config_paths(self): unused = project.get_unused_resource_config_paths(resource_fqns) self.assertEqual(len(unused), 3) + def test__get_unused_resource_config_paths_empty(self): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + unused = project.get_unused_resource_config_paths({'models': [ + ['my_test_project', 'foo', 'bar'], + ['my_test_project', 'foo', 'baz'], + ]}) + self.assertEqual(len(unused), 0) + + @mock.patch.object(dbt.config, 'logger') + def test__warn_for_unused_resource_config_paths_empty(self, mock_logger): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + unused = project.warn_for_unused_resource_config_paths({'models': [ + ['my_test_project', 'foo', 'bar'], + ['my_test_project', 'foo', 'baz'], + ]}) + mock_logger.info.assert_not_called() + + +class TestProjectWithConfigs(BaseConfigTest): + def setUp(self): + self.profiles_dir = '/invalid-profiles-path' + self.project_dir = '/invalid-root-path' + super(TestProjectWithConfigs, self).setUp() + self.default_project_data['project-root'] = self.project_dir + self.default_project_data['models'] = { + 'enabled': True, + 'my_test_project': { + 'foo': { + 'materialized': 'view', + 'bar': { + 'materialized': 'table', + } + }, + 'baz': { + 'materialized': 'table', + } + } + } + + def test__get_unused_resource_config_paths(self): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + unused = project.get_unused_resource_config_paths({'models': [ + ['my_test_project', 'foo', 'bar'], + ['my_test_project', 'foo', 'baz'], + ]}) + self.assertEqual(len(unused), 1) + self.assertEqual(unused[0], ['models', 'my_test_project', 'baz']) + + @mock.patch.object(dbt.config, 'logger') + def test__warn_for_unused_resource_config_paths(self, mock_logger): + project = dbt.config.Project.from_project_config( + self.default_project_data + ) + unused = project.warn_for_unused_resource_config_paths({'models': [ + ['my_test_project', 'foo', 'bar'], + ['my_test_project', 'foo', 'baz'], + ]}) + mock_logger.info.assert_called_once() + + class TestProjectFile(BaseFileTest): def setUp(self): diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index a4fd33f8e7a..bc0c2aaca69 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -303,8 +303,50 @@ def test_no_nodes_with_metadata(self, mock_user): } ) + def test_get_resource_fqns_empty(self): + manifest = Manifest(nodes={}, macros={}, docs={}, + generated_at=timestring()) + self.assertEqual(manifest.get_resource_fqns(), {}) - + def test_get_resource_fqns(self): + nodes = copy.copy(self.nested_nodes) + nodes['seed.root.seed'] = ParsedNode( + name='seed', + schema='analytics', + alias='seed', + resource_type='seed', + unique_id='seed.root.seed', + fqn=['root', 'seed'], + empty=False, + package_name='root', + refs=[['events']], + depends_on={ + 'nodes': [], + 'macros': [] + }, + config=self.model_config, + tags=[], + path='seed.csv', + original_file_path='seed.csv', + root_path='', + raw_sql='-- csv --' + ) + manifest = Manifest(nodes=nodes, macros={}, docs={}, + generated_at=timestring()) + expect = { + 'models': sorted(( + ['snowplow', 'events'], + ['root', 'events'], + ['root', 'dep'], + ['root', 'nested'], + ['root', 'sibling'], + ['root', 'multi'], + )), + 'seeds': [['root', 'seed']], + } + resource_fqns = manifest.get_resource_fqns() + resource_fqns['models'].sort() + self.assertEqual(resource_fqns, expect) class MixedManifestTest(unittest.TestCase): From 
c2bc1c5361b8c2062e4f217ea7fc7aba0ca998fc Mon Sep 17 00:00:00 2001 From: Boris Uvarov Date: Mon, 3 Sep 2018 18:47:51 +0300 Subject: [PATCH 025/133] Add client_session_keep_alive option for Snowflake adapter in order to prevent session timeout after 4 hours of inactivity --- dbt/adapters/snowflake/impl.py | 4 +++- dbt/contracts/connection.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index 44963c38a22..20d088e59c4 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -84,7 +84,9 @@ def open_connection(cls, connection): schema=credentials.schema, warehouse=credentials.warehouse, role=credentials.get('role', None), - autocommit=False + autocommit=False, + client_session_keep_alive=credentials.get( + 'client_session_keep_alive', False) ) connection.handle = handle diff --git a/dbt/contracts/connection.py b/dbt/contracts/connection.py index a3d00852c9c..f253b043005 100644 --- a/dbt/contracts/connection.py +++ b/dbt/contracts/connection.py @@ -110,6 +110,9 @@ 'role': { 'type': 'string', }, + 'client_session_keep_alive': { + 'type': 'boolean', + } }, 'required': ['account', 'user', 'password', 'database', 'schema'], } From 16e055a7402cb7e2ab9b47fffba7e954a7f3daab Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 27 Sep 2018 09:09:32 -0600 Subject: [PATCH 026/133] add a deep_map utility function --- dbt/utils.py | 57 +++++++++++++++++++++++++-- test/unit/test_utils.py | 87 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+), 4 deletions(-) diff --git a/dbt/utils.py b/dbt/utils.py index f2d65d67892..fe4a092953c 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -1,12 +1,14 @@ from datetime import datetime from decimal import Decimal -import os -import hashlib -import itertools -import json + import collections import copy import functools +import hashlib +import itertools +import json +import numbers +import os import dbt.exceptions import dbt.flags @@ -263,6 +265,53 @@ def deep_merge_item(destination, key, value): destination[key] = value +def deep_map(func, value, keypath=(), memo=None, _notfound=object()): + """map the function func() onto each non-container value in 'value' + recursively, returning a new value. As long as func does not manipulate + value, then deep_map will also not manipulate it. + + value should be a value returned by `yaml.safe_load` or `json.load` - the + only expected types are list, dict, native python number, str, NoneType, + and bool. + + func() will be called on numbers, strings, Nones, and booleans. Its first + parameter will be the value, and the second will be its keypath, an + iterable over the __getitem__ keys needed to get to it. 
+ """ + # TODO: if we could guarantee no cycles, we would not need to memoize + if memo is None: + memo = {} + + value_id = id(value) + cached = memo.get(value_id, _notfound) + if cached is not _notfound: + return cached + + atomic_types = (int, float, basestring, type(None), bool) + + if isinstance(value, list): + ret = [ + deep_map(func, v, (keypath + (idx,)), memo) + for idx, v in enumerate(value) + ] + elif isinstance(value, dict): + ret = { + k: deep_map(func, v, (keypath + (k,)), memo) + for k, v in value.items() + } + elif isinstance(value, atomic_types): + ret = func(value, keypath) + else: + ok_types = (list, dict) + atomic_types + # TODO(jeb): real error + raise TypeError( + 'in deep_map, expected one of {!r}, got {!r}' + .format(ok_types, type(value)) + ) + memo[value_id] = ret + return ret + + class AttrDict(dict): def __init__(self, *args, **kwargs): super(AttrDict, self).__init__(*args, **kwargs) diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index e18feeef1d4..b8cbcf64ffd 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -41,3 +41,90 @@ def test__simple_cases(self): case['expected'], actual, 'failed on {} (actual {}, expected {})'.format( case['description'], actual, case['expected'])) + + + +class TestDeepMap(unittest.TestCase): + def setUp(self): + self.input_value = { + 'foo': { + 'bar': 'hello', + 'baz': [1, 90.5, '990', '89.9'], + }, + 'nested': [ + { + 'test': '90', + 'other_test': None, + }, + { + 'test': 400, + 'other_test': 4.7e9, + }, + ], + } + + @staticmethod + def intify_all(value, _): + try: + return int(value) + except (TypeError, ValueError): + return -1 + + def test__simple_cases(self): + expected = { + 'foo': { + 'bar': -1, + 'baz': [1, 90, 990, -1], + }, + 'nested': [ + { + 'test': 90, + 'other_test': -1, + }, + { + 'test': 400, + 'other_test': 4700000000, + }, + ], + } + actual = dbt.utils.deep_map(self.intify_all, self.input_value) + self.assertEquals(actual, expected) + + actual = dbt.utils.deep_map(self.intify_all, expected) + self.assertEquals(actual, expected) + + + @staticmethod + def special_keypath(value, keypath): + + if tuple(keypath) == ('foo', 'baz', 1): + return 'hello' + else: + return value + + def test__keypath(self): + expected = { + 'foo': { + 'bar': 'hello', + # the only change from input is the second entry here + 'baz': [1, 'hello', '990', '89.9'], + }, + 'nested': [ + { + 'test': '90', + 'other_test': None, + }, + { + 'test': 400, + 'other_test': 4.7e9, + }, + ], + } + actual = dbt.utils.deep_map(self.special_keypath, self.input_value) + self.assertEquals(actual, expected) + + actual = dbt.utils.deep_map(self.special_keypath, expected) + self.assertEquals(actual, expected) + + + From 5e5916ce08ef777a575187f6a0d599d70326c5cf Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 27 Sep 2018 11:15:53 -0600 Subject: [PATCH 027/133] on profiles, render env vars and cli vars --- dbt/config.py | 218 ++++++++++++++++++++++++++++----------- dbt/context/common.py | 4 + test/unit/test_config.py | 144 +++++++++++++++----------- test/unit/test_utils.py | 3 - test/unit/utils.py | 7 +- 5 files changed, 248 insertions(+), 128 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index cadb7d8411f..513c76e2289 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -10,7 +10,7 @@ from dbt.contracts.connection import Connection, create_credentials from dbt.contracts.project import Project as ProjectContract, Configuration, \ PackageConfig, ProfileConfig -from dbt.context.common import env_var +from dbt.context.common 
import env_var, Var from dbt import compat from dbt.adapters.factory import get_relation_class_by_name @@ -104,27 +104,94 @@ def colorize_output(config): return config.get('use_colors', True) -def _render(key, value, ctx): - """Render an entry in the credentials dictionary, in case it's jinja. +class ConfigRenderer(object): + """A renderer provides configuration rendering for a given set of cli + variables and a render type. + """ + def __init__(self, cli_vars): + self.context = {'env_var': env_var} + self.context['var'] = Var(None, self.context, cli_vars) - If the parsed entry is a string and has the name 'port', this will attempt - to cast it to an int, and on failure will return the parsed string. + @staticmethod + def _is_hook_path(keypath): + if not keypath: + return False - :param key str: The key to convert on. - :param value Any: The value to potentially render - :param ctx dict: The context dictionary, mapping function names to - functions that take a str and return a value - :return Any: The rendered entry. - """ - if not isinstance(value, compat.basestring): + first = keypath[0] + # run hooks + if first in {'on-run-start', 'on-run-end'}: + return True + # model hooks + if first in {'seeds', 'models'}: + if 'pre-hook' in keypath or 'post-hook' in keypath: + return True + + return False + + @staticmethod + def _is_port_path(keypath): + return len(keypath) == 2 and keypath[-1] == 'port' + + @staticmethod + def _convert_port(value, keypath): + + if len(keypath) != 4: + return value + + if keypath[-1] == 'port' and keypath[1] == 'outputs': + try: + return int(value) + except ValueError: + pass # let the validator or connection handle this return value - result = dbt.clients.jinja.get_rendered(value, ctx) - if key == 'port': - try: - return int(result) - except ValueError: - pass # let the validator or connection handle this - return result + + def _render_project_entry(self, value, keypath): + """Render an entry, in case it's jinja. This is meant to be passed to + dbt.utils.deep_map. + + If the parsed entry is a string and has the name 'port', this will + attempt to cast it to an int, and on failure will return the parsed + string. + + :param value Any: The value to potentially render + :param key str: The key to convert on. + :return Any: The rendered entry. + """ + # hooks should be treated as raw sql, they'll get rendered later + if self._is_hook_path(keypath): + return value + + return self.render_value(value) + + def render_value(self, value, keypath=None): + # keypath is ignored. + # if it wasn't read as a string, ignore it + if not isinstance(value, compat.basestring): + return value + + return dbt.clients.jinja.get_rendered(value, self.context) + + def _render_profile_data(self, value, keypath): + result = self.render_value(value) + if len(keypath) == 1 and keypath[-1] == 'port': + try: + result = int(result) + except ValueError: + # let the validator or connection handle this + pass + return result + + def render(self, as_parsed): + return dbt.utils.deep_map(self.render_value, as_parsed) + + def render_project(self, as_parsed): + """Render the parsed data, returning a new dict (or whatever was read). 
+ """ + return dbt.utils.deep_map(self._render_project_entry, as_parsed) + + def render_profile_data(self, as_parsed): + """Render the chosen profile entry, as it was parsed.""" + return dbt.utils.deep_map(self._render_profile_data, as_parsed) class Project(object): @@ -362,16 +429,6 @@ def validate(self): except dbt.exceptions.ValidationException as exc: raise DbtProfileError(str(exc)) - @staticmethod - def _rendered_profile(profile): - # if entries are strings, we want to render them so we can get any - # environment variables that might store important credentials - # elements. - return { - k: _render(k, v, {'env_var': env_var}) - for k, v in profile.items() - } - @staticmethod def _credentials_from_profile(profile, profile_name, target_name): # credentials carry their 'type' in their actual type, not their @@ -401,25 +458,12 @@ def _pick_profile_name(args_profile_name, project_profile_name=None): return profile_name @staticmethod - def _pick_target(raw_profile, profile_name, target_override=None): - - if target_override is not None: - target_name = target_override - elif 'target' in raw_profile: - target_name = raw_profile['target'] - else: - raise DbtProfileError( - "target not specified in profile '{}'".format(profile_name) - ) - return target_name - - @staticmethod - def _get_profile_data(raw_profile, profile_name, target_name): - if 'outputs' not in raw_profile: + def _get_profile_data(profile, profile_name, target_name): + if 'outputs' not in profile: raise DbtProfileError( "outputs not specified in profile '{}'".format(profile_name) ) - outputs = raw_profile['outputs'] + outputs = profile['outputs'] if target_name not in outputs: outputs = '\n'.join(' - {}'.format(output) @@ -468,15 +512,50 @@ def from_credentials(cls, credentials, threads, profile_name, target_name, return profile @classmethod - def from_raw_profile_info(cls, raw_profile, profile_name, user_cfg=None, - target_override=None, threads_override=None): + def _render_profile(cls, raw_profile, profile_name, target_override, + cli_vars): + """This is a containment zone for the hateful way we're rendering + profiles. + """ + renderer = ConfigRenderer(cli_vars=cli_vars) + + # rendering profiles is a bit complex. Two constraints cause trouble: + # 1) users should be able to use environment/cli variables to specify + # the target in their profile. + # 2) Missing environment/cli variables in profiles/targets that don't + # end up getting selected should not cause errors. + # so first we'll just render the target name, then we use that rendered + # name to extract a profile that we can render. + if target_override is not None: + target_name = target_override + elif 'target' in raw_profile: + # render the target if it was parsed from yaml + target_name = renderer.render_value(raw_profile['target']) + else: + raise DbtProfileError( + "target not specified in profile '{}'".format(profile_name) + ) + + raw_profile_data = cls._get_profile_data( + raw_profile, profile_name, target_name + ) + + profile_data = renderer.render_profile_data(raw_profile_data) + return target_name, profile_data + + @classmethod + def from_raw_profile_info(cls, raw_profile, profile_name, cli_vars, + user_cfg=None, target_override=None, + threads_override=None): """Create a profile from its raw profile information. (this is an intermediate step, mostly useful for unit testing) - :param raw_profiles dict: The profile data for a single profile, from - disk as yaml. 
+ :param raw_profile dict: The profile data for a single profile, from + disk as yaml and its values rendered with jinja. :param profile_name str: The profile name used. + :param cli_vars dict: The command-line variables passed as arguments, + as a dict. :param user_cfg Optional[dict]: The global config for the user, if it was present. :param target_override Optional[str]: The target to use, if provided on @@ -487,22 +566,20 @@ def from_raw_profile_info(cls, raw_profile, profile_name, user_cfg=None, target could not be found :returns Profile: The new Profile object. """ - target_name = cls._pick_target( - raw_profile, profile_name, target_override + # user_cfg is not rendered since it only contains booleans. + # TODO: should it be, and the values coerced to bool? + target_name, profile_data = cls._render_profile( + raw_profile, profile_name, target_override, cli_vars ) - profile_data = cls._get_profile_data( - raw_profile, profile_name, target_name - ) - rendered_profile = cls._rendered_profile(profile_data) # valid connections never include the number of threads, but it's # stored on a per-connection level in the raw configs - threads = rendered_profile.pop('threads', DEFAULT_THREADS) + threads = profile_data.pop('threads', DEFAULT_THREADS) if threads_override is not None: threads = threads_override credentials = cls._credentials_from_profile( - rendered_profile, profile_name, target_name + profile_data, profile_name, target_name ) return cls.from_credentials( credentials=credentials, @@ -513,11 +590,13 @@ def from_raw_profile_info(cls, raw_profile, profile_name, user_cfg=None, ) @classmethod - def from_raw_profiles(cls, raw_profiles, profile_name, + def from_raw_profiles(cls, raw_profiles, profile_name, cli_vars, target_override=None, threads_override=None): """ :param raw_profiles dict: The profile data, from disk as yaml. :param profile_name str: The profile name to use. + :param cli_vars dict: The command-line variables passed as arguments, + as a dict. :param target_override Optional[str]: The target to use, if provided on the command line. :param threads_override Optional[str]: The thread count to use, if @@ -528,29 +607,35 @@ def from_raw_profiles(cls, raw_profiles, profile_name, target could not be found :returns Profile: The new Profile object. """ - # TODO(jeb): Validate the raw_profiles structure right here if profile_name not in raw_profiles: raise DbtProjectError( "Could not find profile named '{}'".format(profile_name) ) + + # First, we've already got our final decision on profile name, and we + # don't render keys, so we can pluck that out raw_profile = raw_profiles[profile_name] + user_cfg = raw_profiles.get('config') return cls.from_raw_profile_info( raw_profile=raw_profile, profile_name=profile_name, + cli_vars=cli_vars, user_cfg=user_cfg, target_override=target_override, threads_override=threads_override, ) @classmethod - def from_args(cls, args, project_profile_name=None): + def from_args(cls, args, project_profile_name=None, cli_vars=None): """Given the raw profiles as read from disk and the name of the desired profile if specified, return the profile component of the runtime config. :param args argparse.Namespace: The arguments as parsed from the cli. + :param cli_vars dict: The command-line variables passed as arguments, + as a dict. :param project_profile_name Optional[str]: The profile name, if specified in a project. 
:raises DbtProjectError: If there is no profile name specified in the @@ -560,6 +645,9 @@ def from_args(cls, args, project_profile_name=None): target could not be found. :returns Profile: The new Profile object. """ + if cli_vars is None: + cli_vars = dbt.utils.parse_cli_vars(getattr(args, 'vars', '{}')) + threads_override = getattr(args, 'threads', None) # TODO(jeb): is it even possible for this to not be set? profiles_dir = getattr(args, 'profiles_dir', DEFAULT_PROFILES_DIR) @@ -571,6 +659,7 @@ def from_args(cls, args, project_profile_name=None): return cls.from_raw_profiles( raw_profiles=raw_profiles, profile_name=profile_name, + cli_vars=cli_vars, target_override=target_override, threads_override=threads_override ) @@ -759,13 +848,18 @@ def from_args(cls, args): :raises DbtProfileError: If the profile is invalid or missing. :raises ValidationException: If the cli variables are invalid. """ + cli_vars = dbt.utils.parse_cli_vars(getattr(args, 'vars', '{}')) + # build the project and read in packages.yml project = Project.from_current_directory() # build the profile - profile = Profile.from_args(args, project.profile_name) + profile = Profile.from_args( + args=args, + project_profile_name=project.profile_name, + cli_vars=cli_vars + ) - cli_vars = dbt.utils.parse_cli_vars(getattr(args, 'vars', '{}')) return cls.from_parts( project=project, profile=profile, diff --git a/dbt/context/common.py b/dbt/context/common.py index 0c5dbc23ff6..88088354ffa 100644 --- a/dbt/context/common.py +++ b/dbt/context/common.py @@ -229,6 +229,10 @@ def __init__(self, model, context, overrides): elif isinstance(model, ParsedNode): local_vars = model.config.get('vars', {}) self.model_name = model.name + elif model is None: + # during config parsing we have no model and no local vars + self.model_name = '' + local_vars = {} else: # still used for wrapping self.model_name = model.nice_name diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 67b7cc26d2b..8f1c45363ba 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -205,10 +205,13 @@ def setUp(self): self.profiles_dir = '/invalid-path' super(TestProfile, self).setUp() - def test_from_raw_profiles(self): - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' + def from_raw_profiles(self): + return dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default', {} ) + + def test_from_raw_profiles(self): + profile = self.from_raw_profiles() self.assertEqual(profile.profile_name, 'default') self.assertEqual(profile.target_name, 'postgres') self.assertEqual(profile.threads, 7) @@ -228,9 +231,7 @@ def test_config_override(self): 'send_anonymous_usage_stats': False, 'use_colors': False } - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' - ) + profile = self.from_raw_profiles() self.assertEqual(profile.profile_name, 'default') self.assertEqual(profile.target_name, 'postgres') self.assertFalse(profile.send_anonymous_usage_stats) @@ -240,9 +241,7 @@ def test_partial_config_override(self): self.default_profile_data['config'] = { 'send_anonymous_usage_stats': False, } - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' - ) + profile = self.from_raw_profiles() self.assertEqual(profile.profile_name, 'default') self.assertEqual(profile.target_name, 'postgres') self.assertFalse(profile.send_anonymous_usage_stats) @@ -251,9 +250,7 @@ def test_partial_config_override(self): def test_missing_type(self): del 
self.default_profile_data['default']['outputs']['postgres']['type'] with self.assertRaises(dbt.config.DbtProfileError) as exc: - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' - ) + self.from_raw_profiles() self.assertIn('type', str(exc.exception)) self.assertIn('postgres', str(exc.exception)) self.assertIn('default', str(exc.exception)) @@ -261,9 +258,7 @@ def test_missing_type(self): def test_bad_type(self): self.default_profile_data['default']['outputs']['postgres']['type'] = 'invalid' with self.assertRaises(dbt.config.DbtProfileError) as exc: - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' - ) + self.from_raw_profiles() self.assertIn('Credentials', str(exc.exception)) self.assertIn('postgres', str(exc.exception)) self.assertIn('default', str(exc.exception)) @@ -271,9 +266,7 @@ def test_bad_type(self): def test_invalid_credentials(self): del self.default_profile_data['default']['outputs']['postgres']['host'] with self.assertRaises(dbt.config.DbtProfileError) as exc: - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' - ) + self.from_raw_profiles() self.assertIn('Credentials', str(exc.exception)) self.assertIn('postgres', str(exc.exception)) self.assertIn('default', str(exc.exception)) @@ -281,16 +274,14 @@ def test_invalid_credentials(self): def test_target_missing(self): del self.default_profile_data['default']['target'] with self.assertRaises(dbt.config.DbtProfileError) as exc: - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' - ) + self.from_raw_profiles() self.assertIn('target not specified in profile', str(exc.exception)) self.assertIn('default', str(exc.exception)) def test_profile_invalid_project(self): with self.assertRaises(dbt.config.DbtProjectError) as exc: - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'invalid-profile' + dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'invalid-profile', {} ) self.assertEqual(exc.exception.result_type, 'invalid_project') @@ -299,8 +290,9 @@ def test_profile_invalid_project(self): def test_profile_invalid_target(self): with self.assertRaises(dbt.config.DbtProfileError) as exc: - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default', target_override='nope', + dbt.config.Profile.from_raw_profiles( + self.default_profile_data, 'default', {}, + target_override='nope', ) self.assertIn('nope', str(exc.exception)) @@ -310,51 +302,56 @@ def test_profile_invalid_target(self): def test_no_outputs(self): with self.assertRaises(dbt.config.DbtProfileError) as exc: - profile = dbt.config.Profile.from_raw_profiles( - {'some-profile': {'target': 'blah'}}, 'some-profile' + dbt.config.Profile.from_raw_profiles( + {'some-profile': {'target': 'blah'}}, 'some-profile', {} ) self.assertIn('outputs not specified', str(exc.exception)) self.assertIn('some-profile', str(exc.exception)) def test_neq(self): - profile = dbt.config.Profile.from_raw_profiles( - self.default_profile_data, 'default' - ) + profile = self.from_raw_profiles() self.assertNotEqual(profile, object()) def test_eq(self): profile = dbt.config.Profile.from_raw_profiles( - deepcopy(self.default_profile_data), 'default' + deepcopy(self.default_profile_data), 'default', {} ) other = dbt.config.Profile.from_raw_profiles( - deepcopy(self.default_profile_data), 'default' + deepcopy(self.default_profile_data), 'default', {} ) self.assertEqual(profile, other) - def 
test_invalid_env_vars(self): - self.env_override['env_value_port'] = 'hello' - with mock.patch.dict(os.environ, self.env_override): - with self.assertRaises(dbt.config.DbtProfileError) as exc: - dbt.config.Profile.from_raw_profile_info( - self.default_profile_data['default'], - 'default', - target_override='with-vars' - ) - self.assertIn("not of type 'integer'", str(exc.exception)) - class TestProfileFile(BaseFileTest): def setUp(self): super(TestProfileFile, self).setUp() self.write_profile(self.default_profile_data) + def from_raw_profile_info(self, raw_profile=None, profile_name='default', **kwargs): + if raw_profile is None: + raw_profile = self.default_profile_data['default'] + kw = { + 'raw_profile': raw_profile, + 'profile_name': profile_name, + 'cli_vars': {}, + } + kw.update(kwargs) + return dbt.config.Profile.from_raw_profile_info(**kw) + + def from_args(self, project_profile_name='default', **kwargs): + kw = { + 'args': self.args, + 'project_profile_name': project_profile_name, + 'cli_vars': {}, + } + kw.update(kwargs) + return dbt.config.Profile.from_args(**kw) + + def test_profile_simple(self): - profile = dbt.config.Profile.from_args(self.args, 'default') - from_raw = dbt.config.Profile.from_raw_profile_info( - self.default_profile_data['default'], - 'default' - ) + profile = self.from_args() + from_raw = self.from_raw_profile_info() self.assertEqual(profile.profile_name, 'default') self.assertEqual(profile.target_name, 'postgres') @@ -374,8 +371,8 @@ def test_profile_simple(self): def test_profile_override(self): self.args.profile = 'other' self.args.threads = 3 - profile = dbt.config.Profile.from_args(self.args, 'default') - from_raw = dbt.config.Profile.from_raw_profile_info( + profile = self.from_args() + from_raw = self.from_raw_profile_info( self.default_profile_data['other'], 'other', threads_override=3, @@ -398,10 +395,8 @@ def test_profile_override(self): def test_target_override(self): self.args.target = 'redshift' - profile = dbt.config.Profile.from_args(self.args, 'default') - from_raw = dbt.config.Profile.from_raw_profile_info( - self.default_profile_data['default'], - 'default', + profile = self.from_args() + from_raw = self.from_raw_profile_info( target_override='redshift' ) @@ -423,10 +418,8 @@ def test_target_override(self): def test_env_vars(self): self.args.target = 'with-vars' with mock.patch.dict(os.environ, self.env_override): - profile = dbt.config.Profile.from_args(self.args, 'default') - from_raw = dbt.config.Profile.from_raw_profile_info( - self.default_profile_data['default'], - 'default', + profile = self.from_args() + from_raw = self.from_raw_profile_info( target_override='with-vars' ) @@ -442,9 +435,40 @@ def test_env_vars(self): self.assertEqual(profile.credentials.password, 'env-postgres-pass') self.assertEqual(profile, from_raw) + def test_env_vars_env_target(self): + self.default_profile_data['default']['target'] = "{{ env_var('env_value_target') }}" + self.write_profile(self.default_profile_data) + self.env_override['env_value_target'] = 'with-vars' + with mock.patch.dict(os.environ, self.env_override): + profile = self.from_args() + from_raw = self.from_raw_profile_info( + target_override='with-vars' + ) + + self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'with-vars') + self.assertEqual(profile.threads, 1) + self.assertTrue(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + self.assertEqual(profile.credentials.type, 'postgres') + self.assertEqual(profile.credentials.host, 
'env-postgres-host') + self.assertEqual(profile.credentials.port, 6543) + self.assertEqual(profile.credentials.user, 'env-postgres-user') + self.assertEqual(profile.credentials.password, 'env-postgres-pass') + self.assertEqual(profile, from_raw) + + def test_invalid_env_vars(self): + self.env_override['env_value_port'] = 'hello' + self.args.target = 'with-vars' + with mock.patch.dict(os.environ, self.env_override): + with self.assertRaises(dbt.config.DbtProfileError) as exc: + profile = self.from_args() + + self.assertIn("not of type 'integer'", str(exc.exception)) + def test_no_profile(self): with self.assertRaises(dbt.config.DbtProjectError) as exc: - dbt.config.Profile.from_args(self.args) + profile = self.from_args(project_profile_name=None) self.assertIn('no profile was specified', str(exc.exception)) @@ -708,7 +732,7 @@ def get_project(self): def get_profile(self): return dbt.config.Profile.from_raw_profiles( - self.default_profile_data, self.default_project_data['profile'] + self.default_profile_data, self.default_project_data['profile'], {} ) def test_from_parts(self): diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index b8cbcf64ffd..61d02d8097f 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -43,7 +43,6 @@ def test__simple_cases(self): case['description'], actual, case['expected'])) - class TestDeepMap(unittest.TestCase): def setUp(self): self.input_value = { @@ -126,5 +125,3 @@ def test__keypath(self): actual = dbt.utils.deep_map(self.special_keypath, expected) self.assertEquals(actual, expected) - - diff --git a/test/unit/utils.py b/test/unit/utils.py index d09fae907e1..ed94245814b 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -8,13 +8,14 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): from dbt.config import Project, Profile, RuntimeConfig from dbt.utils import parse_cli_vars from copy import deepcopy + if not isinstance(cli_vars, dict): + cli_vars = parse_cli_vars(cli_vars) if not isinstance(project, Project): project = Project.from_project_config(deepcopy(project), packages) if not isinstance(profile, Profile): profile = Profile.from_raw_profile_info(deepcopy(profile), - project.profile_name) - if not isinstance(cli_vars, dict): - cli_vars = parse_cli_vars(cli_vars) + project.profile_name, + cli_vars) return RuntimeConfig.from_parts( project=project, From 4da156f3921ce77e79ccb3e5149edd324149670e Mon Sep 17 00:00:00 2001 From: Boris Uvarov Date: Fri, 28 Sep 2018 12:20:18 +0300 Subject: [PATCH 028/133] Add unit tests --- test/unit/test_snowflake_adapter.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/unit/test_snowflake_adapter.py b/test/unit/test_snowflake_adapter.py index d85086c04ac..3239d681941 100644 --- a/test/unit/test_snowflake_adapter.py +++ b/test/unit/test_snowflake_adapter.py @@ -94,3 +94,26 @@ def test_quoting_on_rename(self): self.mock_execute.assert_has_calls([ mock.call('alter table "test_schema".table_a rename to table_b', None) ]) + + def test_client_session_keep_alive_false_by_default(self): + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=False, + client_session_keep_alive=False, database='test_databse', + password='test_password', role=None, schema='public', + user='test_user', warehouse='test_warehouse') + ]) + + def test_client_session_keep_alive_true(self): + self.config.credentials = self.config.credentials.incorporate( + client_session_keep_alive=True) + self.adapter = 
SnowflakeAdapter(self.config) + self.adapter.get_connection(name='new_connection_with_new_config') + + self.snowflake.assert_has_calls([ + mock.call( + account='test_account', autocommit=False, + client_session_keep_alive=True, database='test_databse', + password='test_password', role=None, schema='public', + user='test_user', warehouse='test_warehouse') + ]) From addcb1460be6e5893ee5fb4371907f8bb3e62365 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 28 Sep 2018 07:36:47 -0600 Subject: [PATCH 029/133] Add rendering to projects as well, fix up existing unit tests --- dbt/config.py | 21 +++++++++++++-------- dbt/main.py | 2 +- dbt/task/debug.py | 8 ++++++-- dbt/task/deps.py | 4 ++-- test/unit/test_config.py | 6 +++--- 5 files changed, 25 insertions(+), 16 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index 513c76e2289..abaa8d255da 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -350,7 +350,7 @@ def validate(self): raise DbtProjectError(str(exc)) @classmethod - def from_project_root(cls, project_root): + def from_project_root(cls, project_root, cli_vars): """Create a project from a root directory. Reads in dbt_project.yml and packages.yml, if it exists. @@ -369,14 +369,19 @@ def from_project_root(cls, project_root): .format(project_yaml_filepath) ) + if isinstance(cli_vars, compat.basestring): + cli_vars = dbt.utils.parse_cli_vars(cli_vars) + renderer = ConfigRenderer(cli_vars) + project_dict = _load_yaml(project_yaml_filepath) - project_dict['project-root'] = project_root + rendered_project = renderer.render_project(project_dict) + rendered_project['project-root'] = project_root packages_dict = package_data_from_root(project_root) - return cls.from_project_config(project_dict, packages_dict) + return cls.from_project_config(rendered_project, packages_dict) @classmethod - def from_current_directory(cls): - return cls.from_project_root(os.getcwd()) + def from_current_directory(cls, cli_vars): + return cls.from_project_root(os.getcwd(), cli_vars) def hashed_name(self): return hashlib.md5(self.project_name.encode('utf-8')).hexdigest() @@ -800,8 +805,8 @@ def new_project(self, project_root): # copy profile profile = Profile(**self.to_profile_info()) profile.validate() - # load the new project and its packages - project = Project.from_project_root(project_root) + # load the new project and its packages. Don't pass cli variables. + project = Project.from_project_root(project_root, {}) cfg = self.from_parts( project=project, @@ -851,7 +856,7 @@ def from_args(cls, args): cli_vars = dbt.utils.parse_cli_vars(getattr(args, 'vars', '{}')) # build the project and read in packages.yml - project = Project.from_current_directory() + project = Project.from_current_directory(cli_vars) # build the profile profile = Profile.from_args( diff --git a/dbt/main.py b/dbt/main.py index 126b9d0b872..89618ccb375 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -209,7 +209,7 @@ def invoke_dbt(parsed): try: if parsed.which == 'deps': # deps doesn't need a profile, so don't require one. - cfg = Project.from_current_directory() + cfg = Project.from_current_directory(getattr(parsed, 'vars', '{}')) elif parsed.which != 'debug': # for debug, we will attempt to load the various configurations as # part of the task, so just leave cfg=None. 
diff --git a/dbt/task/debug.py b/dbt/task/debug.py index 2d521371e9e..3918897c9b4 100644 --- a/dbt/task/debug.py +++ b/dbt/task/debug.py @@ -3,6 +3,7 @@ from dbt.logger import GLOBAL_LOGGER as logger import dbt.clients.system import dbt.config +import dbt.utils from dbt.task.base_task import BaseTask @@ -27,15 +28,18 @@ def diag(self): # if we got here, a 'dbt_project.yml' does exist, but we have not tried # to parse it. project_profile = None + cli_vars = dbt.utils.parse_cli_vars(getattr(self.args, 'vars', '{}')) + try: - project = dbt.config.Project.from_current_directory() + project = dbt.config.Project.from_current_directory(cli_vars) project_profile = project.profile_name except dbt.config.DbtConfigError as exc: project = 'ERROR loading project: {!s}'.format(exc) # log the profile we decided on as well, if it's available. try: - profile = dbt.config.Profile.from_args(self.args, project_profile) + profile = dbt.config.Profile.from_args(self.args, project_profile, + cli_vars) except dbt.config.DbtConfigError as exc: profile = 'ERROR loading profile: {!s}'.format(exc) diff --git a/dbt/task/deps.py b/dbt/task/deps.py index fd07f2b6f58..9f5daba2dde 100644 --- a/dbt/task/deps.py +++ b/dbt/task/deps.py @@ -237,7 +237,7 @@ def _checkout(self, project): def _fetch_metadata(self, project): path = self._checkout(project) - return project.from_project_root(path) + return project.from_project_root(path, {}) def install(self, project): dest_path = self.get_installation_path(project) @@ -273,7 +273,7 @@ def _fetch_metadata(self, project): self.local, project.project_root) - return project.from_project_root(project_file_path) + return project.from_project_root(project_file_path, {}) def install(self, project): src_path = dbt.clients.system.resolve_path_from_base( diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 8f1c45363ba..7eb40fe9364 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -693,7 +693,7 @@ def test_invalid_project_name(self): def test_no_project(self): with self.assertRaises(dbt.config.DbtProjectError) as exc: - dbt.config.Project.from_project_root(self.project_dir) + dbt.config.Project.from_project_root(self.project_dir, {}) self.assertIn('no dbt_project.yml', str(exc.exception)) @@ -706,7 +706,7 @@ def setUp(self): self.default_project_data['project-root'] = self.project_dir def test_from_project_root(self): - project = dbt.config.Project.from_project_root(self.project_dir) + project = dbt.config.Project.from_project_root(self.project_dir, {}) from_config = dbt.config.Project.from_project_config( self.default_project_data ) @@ -715,7 +715,7 @@ def test_from_project_root(self): def test_with_invalid_package(self): self.write_packages({'invalid': ['not a package of any kind']}) with self.assertRaises(dbt.config.DbtProjectError) as exc: - dbt.config.Project.from_project_root(self.project_dir) + dbt.config.Project.from_project_root(self.project_dir, {}) class TestRuntimeConfig(BaseConfigTest): From a9487e89bfdfef17cbc8de46fdb4ac8807fdd4a0 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 28 Sep 2018 10:22:16 -0600 Subject: [PATCH 030/133] more unit tests --- test/unit/test_config.py | 108 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 7eb40fe9364..a9aef85cc65 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -129,6 +129,15 @@ def setUp(self): 'pass': "{{ env_var('env_value_pass') }}", 'dbname': "{{ env_var('env_value_dbname') 
}}", 'schema': "{{ env_var('env_value_schema') }}", + }, + 'cli-and-env-vars': { + 'type': "{{ env_var('env_value_type') }}", + 'host': "{{ var('cli_value_host') }}", + 'port': "{{ env_var('env_value_port') }}", + 'user': "{{ env_var('env_value_user') }}", + 'pass': "{{ env_var('env_value_pass') }}", + 'dbname': "{{ env_var('env_value_dbname') }}", + 'schema': "{{ env_var('env_value_schema') }}", } }, 'target': 'postgres', @@ -158,6 +167,7 @@ def setUp(self): 'env_value_pass': 'env-postgres-pass', 'env_value_dbname': 'env-postgres-dbname', 'env_value_schema': 'env-postgres-schema', + 'env_value_project': 'blah', } @@ -466,6 +476,28 @@ def test_invalid_env_vars(self): self.assertIn("not of type 'integer'", str(exc.exception)) + def test_cli_and_env_vars(self): + self.args.target = 'cli-and-env-vars' + self.args.vars = '{"cli_value_host": "cli-postgres-host"}' + with mock.patch.dict(os.environ, self.env_override): + profile = self.from_args(cli_vars=None) + from_raw = self.from_raw_profile_info( + target_override='cli-and-env-vars', + cli_vars={'cli_value_host': 'cli-postgres-host'}, + ) + + self.assertEqual(profile.profile_name, 'default') + self.assertEqual(profile.target_name, 'cli-and-env-vars') + self.assertEqual(profile.threads, 1) + self.assertTrue(profile.send_anonymous_usage_stats) + self.assertTrue(profile.use_colors) + self.assertEqual(profile.credentials.type, 'postgres') + self.assertEqual(profile.credentials.host, 'cli-postgres-host') + self.assertEqual(profile.credentials.port, 6543) + self.assertEqual(profile.credentials.user, 'env-postgres-user') + self.assertEqual(profile.credentials.password, 'env-postgres-pass') + self.assertEqual(profile, from_raw) + def test_no_profile(self): with self.assertRaises(dbt.config.DbtProjectError) as exc: profile = self.from_args(project_profile_name=None) @@ -711,6 +743,8 @@ def test_from_project_root(self): self.default_project_data ) self.assertEqual(project, from_config) + self.assertEqual(project.version, "0.0.1") + self.assertEqual(project.project_name, 'my_test_project') def test_with_invalid_package(self): self.write_packages({'invalid': ['not a package of any kind']}) @@ -718,6 +752,28 @@ def test_with_invalid_package(self): dbt.config.Project.from_project_root(self.project_dir, {}) +class TestVariableProjectFile(BaseFileTest): + def setUp(self): + super(TestVariableProjectFile, self).setUp() + self.default_project_data['version'] = "{{ var('cli_version') }}" + self.default_project_data['name'] = "{{ env_var('env_value_project') }}" + self.write_project(self.default_project_data) + # and after the fact, add the project root + self.default_project_data['project-root'] = self.project_dir + + def test_cli_and_env_vars(self): + cli_vars = '{"cli_version": "0.1.2"}' + with mock.patch.dict(os.environ, self.env_override): + project = dbt.config.Project.from_project_root( + self.project_dir, + cli_vars + ) + + self.assertEqual(project.version, "0.1.2") + self.assertEqual(project.project_name, 'blah') + + + class TestRuntimeConfig(BaseConfigTest): def setUp(self): self.profiles_dir = '/invalid-profiles-path' @@ -806,3 +862,55 @@ def test_from_args(self): self.assertEqual(config.archive, []) self.assertEqual(config.seeds, {}) self.assertEqual(config.packages, PackageConfig(packages=[])) + + +class TestVariableRuntimeConfigFiles(BaseFileTest): + def setUp(self): + super(TestVariableRuntimeConfigFiles, self).setUp() + self.default_project_data.update({ + 'version': "{{ var('cli_version') }}", + 'name': "{{ env_var('env_value_project') }}", + 
'on-run-end': [ + "{{ env_var('env_value_project') }}", + ], + 'models': { + 'foo': { + 'post-hook': "{{ env_var('env_value_target') }}", + }, + 'bar': { + # just gibberish, make sure it gets interpreted + 'materialized': "{{ env_var('env_value_project') }}", + } + }, + 'seeds': { + 'foo': { + 'post-hook': "{{ env_var('env_value_target') }}", + }, + 'bar': { + # just gibberish, make sure it gets interpreted + 'materialized': "{{ env_var('env_value_project') }}", + } + }, + }) + self.write_project(self.default_project_data) + self.write_profile(self.default_profile_data) + # and after the fact, add the project root + self.default_project_data['project-root'] = self.project_dir + + def test_cli_and_env_vars(self): + self.args.target = 'cli-and-env-vars' + self.args.vars = '{"cli_value_host": "cli-postgres-host", "cli_version": "0.1.2"}' + with mock.patch.dict(os.environ, self.env_override), temp_cd(self.project_dir): + config = dbt.config.RuntimeConfig.from_args(self.args) + + self.assertEqual(config.version, "0.1.2") + self.assertEqual(config.project_name, 'blah') + self.assertEqual(config.credentials.host, 'cli-postgres-host') + self.assertEqual(config.credentials.user, 'env-postgres-user') + # make sure hooks are not interpreted + self.assertEqual(config.on_run_end, ["{{ env_var('env_value_project') }}"]) + self.assertEqual(config.models['foo']['post-hook'], "{{ env_var('env_value_target') }}") + self.assertEqual(config.models['bar']['materialized'], 'blah') + self.assertEqual(config.seeds['foo']['post-hook'], "{{ env_var('env_value_target') }}") + self.assertEqual(config.seeds['bar']['materialized'], 'blah') + From af44abf7a6eb73a440b9f48d34a54e5895334aca Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 25 Sep 2018 09:27:50 -0600 Subject: [PATCH 031/133] make clean not require a profile, make bare dbt -d fail with better errors --- dbt/main.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/dbt/main.py b/dbt/main.py index 126b9d0b872..57ebc25ec29 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -107,7 +107,6 @@ def handle(args): def handle_and_check(args): parsed = parse_args(args) - # this needs to happen after args are parsed so we can determine the # correct profiles.yml file profile_config = read_config(parsed.profiles_dir) @@ -207,7 +206,7 @@ def invoke_dbt(parsed): cfg = None try: - if parsed.which == 'deps': + if parsed.which in {'deps', 'clean'}: # deps doesn't need a profile, so don't require one. cfg = Project.from_current_directory() elif parsed.which != 'debug': @@ -495,4 +494,10 @@ def parse_args(args): parsed = p.parse_args(args) + if not hasattr(parsed, 'which'): + # the user did not provide a valid subcommand. 
trigger the help message + # and exit with a error + p.print_help() + p.exit(1) + return parsed From ba4cc78a75eec41a50e3ef94061f1d44d045d6cf Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 28 Sep 2018 14:22:39 -0600 Subject: [PATCH 032/133] properly move config errors into the dbt exception hierarchy --- dbt/config.py | 16 +--------------- dbt/exceptions.py | 15 +++++++++++++++ dbt/task/debug.py | 5 +++-- dbt/utils.py | 3 +-- test/unit/test_config.py | 27 ++++++++++++++------------- 5 files changed, 34 insertions(+), 32 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index cadb7d8411f..023b188e5c1 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -10,6 +10,7 @@ from dbt.contracts.connection import Connection, create_credentials from dbt.contracts.project import Project as ProjectContract, Configuration, \ PackageConfig, ProfileConfig +from dbt.exceptions import DbtProjectError, DbtProfileError from dbt.context.common import env_var from dbt import compat from dbt.adapters.factory import get_relation_class_by_name @@ -43,21 +44,6 @@ """.format(profiles_file=DEFAULT_PROFILES_DIR) -class DbtConfigError(Exception): - def __init__(self, message, project=None, result_type='invalid_project'): - self.project = project - super(DbtConfigError, self).__init__(message) - self.result_type = result_type - - -class DbtProjectError(DbtConfigError): - pass - - -class DbtProfileError(DbtConfigError): - pass - - def read_profile(profiles_dir): path = os.path.join(profiles_dir, 'profiles.yml') diff --git a/dbt/exceptions.py b/dbt/exceptions.py index e28011dd193..511be3fc427 100644 --- a/dbt/exceptions.py +++ b/dbt/exceptions.py @@ -129,6 +129,21 @@ class DependencyException(Exception): pass +class DbtConfigError(RuntimeException): + def __init__(self, message, project=None, result_type='invalid_project'): + self.project = project + super(DbtConfigError, self).__init__(message) + self.result_type = result_type + + +class DbtProjectError(DbtConfigError): + pass + + +class DbtProfileError(DbtConfigError): + pass + + class SemverException(Exception): def __init__(self, msg=None): self.msg = msg diff --git a/dbt/task/debug.py b/dbt/task/debug.py index 2d521371e9e..69679988387 100644 --- a/dbt/task/debug.py +++ b/dbt/task/debug.py @@ -3,6 +3,7 @@ from dbt.logger import GLOBAL_LOGGER as logger import dbt.clients.system import dbt.config +import dbt.exceptions from dbt.task.base_task import BaseTask @@ -30,13 +31,13 @@ def diag(self): try: project = dbt.config.Project.from_current_directory() project_profile = project.profile_name - except dbt.config.DbtConfigError as exc: + except dbt.exceptions.DbtConfigError as exc: project = 'ERROR loading project: {!s}'.format(exc) # log the profile we decided on as well, if it's available. 
try: profile = dbt.config.Profile.from_args(self.args, project_profile) - except dbt.config.DbtConfigError as exc: + except dbt.exceptions.DbtConfigError as exc: profile = 'ERROR loading profile: {!s}'.format(exc) logger.info("args: {}".format(self.args)) diff --git a/dbt/utils.py b/dbt/utils.py index f2d65d67892..8c18482de06 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -171,7 +171,6 @@ def get_docs_macro_name(docs_name, with_prefix=True): def dependencies_for_path(config, module_path): """Given a module path, yield all dependencies in that path.""" logger.debug("Loading dependency project from {}".format(module_path)) - import dbt.config for obj in os.listdir(module_path): full_obj = os.path.join(module_path, obj) @@ -183,7 +182,7 @@ def dependencies_for_path(config, module_path): try: yield config.new_project(full_obj) - except dbt.config.DbtProjectError as e: + except dbt.exceptions.DbtProjectError as e: logger.info( "Error reading dependency project at {}".format( full_obj) diff --git a/test/unit/test_config.py b/test/unit/test_config.py index 67b7cc26d2b..8de9980fb27 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -10,6 +10,7 @@ import yaml import dbt.config +import dbt.exceptions from dbt.contracts.connection import PostgresCredentials, RedshiftCredentials from dbt.contracts.project import PackageConfig @@ -250,7 +251,7 @@ def test_partial_config_override(self): def test_missing_type(self): del self.default_profile_data['default']['outputs']['postgres']['type'] - with self.assertRaises(dbt.config.DbtProfileError) as exc: + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: profile = dbt.config.Profile.from_raw_profiles( self.default_profile_data, 'default' ) @@ -260,7 +261,7 @@ def test_missing_type(self): def test_bad_type(self): self.default_profile_data['default']['outputs']['postgres']['type'] = 'invalid' - with self.assertRaises(dbt.config.DbtProfileError) as exc: + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: profile = dbt.config.Profile.from_raw_profiles( self.default_profile_data, 'default' ) @@ -270,7 +271,7 @@ def test_bad_type(self): def test_invalid_credentials(self): del self.default_profile_data['default']['outputs']['postgres']['host'] - with self.assertRaises(dbt.config.DbtProfileError) as exc: + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: profile = dbt.config.Profile.from_raw_profiles( self.default_profile_data, 'default' ) @@ -280,7 +281,7 @@ def test_invalid_credentials(self): def test_target_missing(self): del self.default_profile_data['default']['target'] - with self.assertRaises(dbt.config.DbtProfileError) as exc: + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: profile = dbt.config.Profile.from_raw_profiles( self.default_profile_data, 'default' ) @@ -288,7 +289,7 @@ def test_target_missing(self): self.assertIn('default', str(exc.exception)) def test_profile_invalid_project(self): - with self.assertRaises(dbt.config.DbtProjectError) as exc: + with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: profile = dbt.config.Profile.from_raw_profiles( self.default_profile_data, 'invalid-profile' ) @@ -298,7 +299,7 @@ def test_profile_invalid_project(self): self.assertIn('invalid-profile', str(exc.exception)) def test_profile_invalid_target(self): - with self.assertRaises(dbt.config.DbtProfileError) as exc: + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: profile = dbt.config.Profile.from_raw_profiles( self.default_profile_data, 'default', 
target_override='nope', ) @@ -309,7 +310,7 @@ def test_profile_invalid_target(self): self.assertIn('- with-vars', str(exc.exception)) def test_no_outputs(self): - with self.assertRaises(dbt.config.DbtProfileError) as exc: + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: profile = dbt.config.Profile.from_raw_profiles( {'some-profile': {'target': 'blah'}}, 'some-profile' ) @@ -335,7 +336,7 @@ def test_eq(self): def test_invalid_env_vars(self): self.env_override['env_value_port'] = 'hello' with mock.patch.dict(os.environ, self.env_override): - with self.assertRaises(dbt.config.DbtProfileError) as exc: + with self.assertRaises(dbt.exceptions.DbtProfileError) as exc: dbt.config.Profile.from_raw_profile_info( self.default_profile_data['default'], 'default', @@ -443,7 +444,7 @@ def test_env_vars(self): self.assertEqual(profile, from_raw) def test_no_profile(self): - with self.assertRaises(dbt.config.DbtProjectError) as exc: + with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: dbt.config.Profile.from_args(self.args) self.assertIn('no profile was specified', str(exc.exception)) @@ -661,14 +662,14 @@ def test_all_overrides(self): def test_invalid_project_name(self): self.default_project_data['name'] = 'invalid-project-name' - with self.assertRaises(dbt.config.DbtProjectError) as exc: + with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: project = dbt.config.Project.from_project_config( self.default_project_data ) self.assertIn('invalid-project-name', str(exc.exception)) def test_no_project(self): - with self.assertRaises(dbt.config.DbtProjectError) as exc: + with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: dbt.config.Project.from_project_root(self.project_dir) self.assertIn('no dbt_project.yml', str(exc.exception)) @@ -690,7 +691,7 @@ def test_from_project_root(self): def test_with_invalid_package(self): self.write_packages({'invalid': ['not a package of any kind']}) - with self.assertRaises(dbt.config.DbtProjectError) as exc: + with self.assertRaises(dbt.exceptions.DbtProjectError) as exc: dbt.config.Project.from_project_root(self.project_dir) @@ -744,7 +745,7 @@ def test_validate_fails(self): profile = self.get_profile() # invalid - must be boolean profile.use_colors = None - with self.assertRaises(dbt.config.DbtProjectError): + with self.assertRaises(dbt.exceptions.DbtProjectError): dbt.config.RuntimeConfig.from_parts(project, profile, {}) From 36dcca2f1fea7f4427f21b286ceff9b19c8db988 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 28 Sep 2018 14:30:16 -0600 Subject: [PATCH 033/133] avoid logging stack traces to the console on dbt-created errors --- dbt/main.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dbt/main.py b/dbt/main.py index 126b9d0b872..b1897fb631e 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -29,6 +29,7 @@ from dbt.config import Project, RuntimeConfig, DbtProjectError, \ DbtProfileError, DEFAULT_PROFILES_DIR, read_config, \ send_anonymous_usage_stats, colorize_output, read_profiles +from dbt.exceptions import DbtProfileError, DbtProfileError, RuntimeException PROFILES_HELP_MESSAGE = """ @@ -92,7 +93,10 @@ def main(args=None): if logger_initialized: logger.debug(traceback.format_exc()) - else: + elif not isinstance(e, RuntimeException): + # if it did not come from dbt proper and the logger is not + # initialized (so there's no safe path to log to), log the stack + # trace at error level. 
logger.error(traceback.format_exc()) exit_code = ExitCodes.UnhandledError From e4ca3503914aa56aab7abe3548d730baaa297272 Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Wed, 19 Sep 2018 17:32:05 -0400 Subject: [PATCH 034/133] (fixes #311) Configure tags, and select them with --models --- dbt/contracts/graph/parsed.py | 15 +++- dbt/graph/selector.py | 62 +++++++++++-- dbt/model.py | 23 +++-- dbt/parser/base.py | 5 ++ dbt/utils.py | 1 + .../models/base_users.sql | 3 +- .../models/emails.sql | 2 +- .../models/users.sql | 3 +- .../models/users_rollup.sql | 3 +- .../test_graph_selection.py | 30 +++++++ .../test_schema_test_graph_selection.py | 19 ++++ .../test_tag_selection.py | 63 ++++++++++++++ .../test_docs_generate.py | 54 ++++++++---- test/unit/test_compiler.py | 1 + test/unit/test_context.py | 1 + test/unit/test_graph_selection.py | 87 +++++++++++++++---- test/unit/test_manifest.py | 2 + test/unit/test_parser.py | 2 + 18 files changed, 321 insertions(+), 55 deletions(-) create mode 100644 test/integration/007_graph_selection_tests/test_tag_selection.py diff --git a/dbt/contracts/graph/parsed.py b/dbt/contracts/graph/parsed.py index 4a45de5c080..88720ee4cfc 100644 --- a/dbt/contracts/graph/parsed.py +++ b/dbt/contracts/graph/parsed.py @@ -68,10 +68,23 @@ 'type': 'object', 'additionalProperties': True, }, + 'tags': { + 'anyOf': [ + { + 'type': 'array', + 'items': { + 'type': 'string' + }, + }, + { + 'type': 'string' + } + ] + }, }, 'required': [ 'enabled', 'materialized', 'post-hook', 'pre-hook', 'vars', - 'quoting', 'column_types' + 'quoting', 'column_types', 'tags' ] } diff --git a/dbt/graph/selector.py b/dbt/graph/selector.py index 5f447558185..759c4f44c43 100644 --- a/dbt/graph/selector.py +++ b/dbt/graph/selector.py @@ -4,10 +4,17 @@ from dbt.utils import is_enabled, get_materialization, coalesce from dbt.node_types import NodeType from dbt.contracts.graph.parsed import ParsedNode +import dbt.exceptions SELECTOR_PARENTS = '+' SELECTOR_CHILDREN = '+' SELECTOR_GLOB = '*' +SELECTOR_DELIMITER = ':' + + +class SELECTOR_FILTERS: + FQN = 'fqn' + TAG = 'tag' def split_specs(node_specs): @@ -34,12 +41,27 @@ def parse_spec(node_spec): index_end -= 1 node_selector = node_spec[index_start:index_end] - qualified_node_name = node_selector.split('.') + + if SELECTOR_DELIMITER in node_selector: + selector_parts = node_selector.split(SELECTOR_DELIMITER, 1) + selector_type, selector_value = selector_parts + + node_filter = { + "type": selector_type, + "value": selector_value + } + + else: + node_filter = { + "type": SELECTOR_FILTERS.FQN, + "value": node_selector + + } return { "select_parents": select_parents, "select_children": select_children, - "qualified_node_name": qualified_node_name, + "filter": node_filter, "raw": node_spec } @@ -74,10 +96,11 @@ def is_selected_node(real_node, node_selector): return True -def get_nodes_by_qualified_name(graph, qualified_name): +def get_nodes_by_qualified_name(graph, name): """ returns a node if matched, else throws a CompilerError. 
qualified_name should be either 1) a node name or 2) a dot-notation qualified selector""" + qualified_name = name.split('.') package_names = get_package_names(graph) for node in graph.nodes(): @@ -98,13 +121,40 @@ def get_nodes_by_qualified_name(graph, qualified_name): break +def get_nodes_by_tag(graph, tag_name): + """ yields nodes from graph that have the specified tag """ + + for node in graph.nodes(): + tags = graph.node[node]['tags'] + + if tag_name in tags: + yield node + + def get_nodes_from_spec(graph, spec): select_parents = spec['select_parents'] select_children = spec['select_children'] - qualified_node_name = spec['qualified_node_name'] - selected_nodes = set(get_nodes_by_qualified_name(graph, - qualified_node_name)) + filter_map = { + SELECTOR_FILTERS.FQN: get_nodes_by_qualified_name, + SELECTOR_FILTERS.TAG: get_nodes_by_tag, + } + + node_filter = spec['filter'] + filter_func = filter_map.get(node_filter['type']) + + if filter_func is None: + valid_selectors = ", ".join(filter_map.keys()) + logger.info("The '{}' selector specified in {} is invalid. Must be " + "one of [{}]".format( + node_filter['type'], + spec['raw'], + valid_selectors)) + + selected_nodes = set() + + else: + selected_nodes = set(filter_func(graph, node_filter['value'])) additional_nodes = set() test_nodes = set() diff --git a/dbt/model.py b/dbt/model.py index eb4e5a3f491..27c0546de72 100644 --- a/dbt/model.py +++ b/dbt/model.py @@ -11,7 +11,7 @@ class SourceConfig(object): ConfigKeys = DBTConfigKeys - AppendListFields = ['pre-hook', 'post-hook'] + AppendListFields = ['pre-hook', 'post-hook', 'tags'] ExtendDictFields = ['vars', 'column_types', 'quoting'] ClobberFields = [ 'alias', @@ -94,22 +94,21 @@ def update_in_model_config(self, config): # make sure we're not clobbering an array of hooks with a single hook # string - hook_fields = ['pre-hook', 'post-hook'] - for hook_field in hook_fields: - if hook_field in config: - config[hook_field] = self.__get_hooks(config, hook_field) + for field in self.AppendListFields: + if field in config: + config[field] = self.__get_as_list(config, field) self.in_model_config.update(config) - def __get_hooks(self, relevant_configs, key): + def __get_as_list(self, relevant_configs, key): if key not in relevant_configs: return [] - hooks = relevant_configs[key] - if not isinstance(hooks, (list, tuple)): - hooks = [hooks] + items = relevant_configs[key] + if not isinstance(items, (list, tuple)): + items = [items] - return hooks + return items def smart_update(self, mutable_config, new_configs): relevant_configs = { @@ -118,9 +117,9 @@ def smart_update(self, mutable_config, new_configs): } for key in SourceConfig.AppendListFields: - new_hooks = self.__get_hooks(relevant_configs, key) + append_fields = self.__get_as_list(relevant_configs, key) mutable_config[key].extend([ - h for h in new_hooks if h not in mutable_config[key] + f for f in append_fields if f not in mutable_config[key] ]) for key in SourceConfig.ExtendDictFields: diff --git a/dbt/parser/base.py b/dbt/parser/base.py index 89ca46a5c33..2e79b28ca1f 100644 --- a/dbt/parser/base.py +++ b/dbt/parser/base.py @@ -7,6 +7,7 @@ import dbt.hooks import dbt.clients.jinja import dbt.context.parser +from dbt.compat import basestring from dbt.utils import coalesce from dbt.logger import GLOBAL_LOGGER as logger @@ -128,6 +129,10 @@ def parse_node(cls, node, node_path, root_project_config, parsed_node.schema = get_schema(schema_override) parsed_node.alias = config.config.get('alias', default_alias) + # Set tags on node provided in 
config blocks + model_tags = config.config.get('tags', []) + parsed_node.tags.extend(model_tags) + # Overwrite node config config_dict = parsed_node.get('config', {}) config_dict.update(config.config) diff --git a/dbt/utils.py b/dbt/utils.py index f2d65d67892..0d8218f9b65 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -34,6 +34,7 @@ 'column_types', 'bind', 'quoting', + 'tags', ] diff --git a/test/integration/007_graph_selection_tests/models/base_users.sql b/test/integration/007_graph_selection_tests/models/base_users.sql index fbab29807f3..858eb5ba343 100644 --- a/test/integration/007_graph_selection_tests/models/base_users.sql +++ b/test/integration/007_graph_selection_tests/models/base_users.sql @@ -1,7 +1,8 @@ {{ config( - materialized = 'ephemeral' + materialized = 'ephemeral', + tags = ['base'] ) }} diff --git a/test/integration/007_graph_selection_tests/models/emails.sql b/test/integration/007_graph_selection_tests/models/emails.sql index a233a42d901..c9134abadfe 100644 --- a/test/integration/007_graph_selection_tests/models/emails.sql +++ b/test/integration/007_graph_selection_tests/models/emails.sql @@ -1,6 +1,6 @@ {{ - config(materialized='ephemeral') + config(materialized='ephemeral', tags=['base']) }} select distinct email from {{ ref('base_users') }} diff --git a/test/integration/007_graph_selection_tests/models/users.sql b/test/integration/007_graph_selection_tests/models/users.sql index ba7d4f48f41..058f048e660 100644 --- a/test/integration/007_graph_selection_tests/models/users.sql +++ b/test/integration/007_graph_selection_tests/models/users.sql @@ -1,7 +1,8 @@ {{ config( - materialized = 'table' + materialized = 'table', + tags='bi' ) }} diff --git a/test/integration/007_graph_selection_tests/models/users_rollup.sql b/test/integration/007_graph_selection_tests/models/users_rollup.sql index 8013ab10829..b9764f3ae5d 100644 --- a/test/integration/007_graph_selection_tests/models/users_rollup.sql +++ b/test/integration/007_graph_selection_tests/models/users_rollup.sql @@ -1,7 +1,8 @@ {{ config( - materialized = 'view' + materialized = 'view', + tags = ['bi'] ) }} diff --git a/test/integration/007_graph_selection_tests/test_graph_selection.py b/test/integration/007_graph_selection_tests/test_graph_selection.py index 9dd1d349b93..b13a497ea56 100644 --- a/test/integration/007_graph_selection_tests/test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_graph_selection.py @@ -26,6 +26,36 @@ def test__postgres__specific_model(self): self.assertFalse('base_users' in created_models) self.assertFalse('emails' in created_models) + @attr(type='postgres') + def test__postgres__tags(self): + self.use_profile('postgres') + self.use_default_project() + self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql") + + results = self.run_dbt(['run', '--models', 'tag:bi']) + self.assertEqual(len(results), 2) + + created_models = self.get_models_in_schema() + self.assertFalse('base_users' in created_models) + self.assertFalse('emails' in created_models) + self.assertTrue('users' in created_models) + self.assertTrue('users_rollup' in created_models) + + @attr(type='postgres') + def test__postgres__tags_and_children(self): + self.use_profile('postgres') + self.use_default_project() + self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql") + + results = self.run_dbt(['run', '--models', 'tag:base+']) + self.assertEqual(len(results), 2) + + created_models = self.get_models_in_schema() + self.assertFalse('base_users' in created_models) + 
self.assertFalse('emails' in created_models) + self.assertTrue('users_rollup' in created_models) + self.assertTrue('users' in created_models) + @attr(type='snowflake') def test__snowflake__specific_model(self): self.use_profile('snowflake') diff --git a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py index af21250936e..35e991a965b 100644 --- a/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py +++ b/test/integration/007_graph_selection_tests/test_schema_test_graph_selection.py @@ -62,6 +62,15 @@ def test__postgres__schema_tests_specify_model(self): ['unique_users_id'] ) + @attr(type='postgres') + def test__postgres__schema_tests_specify_tag(self): + self.run_schema_and_assert( + ['tag:bi'], + None, + ['unique_users_id', + 'unique_users_rollup_gender'] + ) + @attr(type='postgres') def test__postgres__schema_tests_specify_model_and_children(self): self.run_schema_and_assert( @@ -70,6 +79,16 @@ def test__postgres__schema_tests_specify_model_and_children(self): ['unique_users_id', 'unique_users_rollup_gender'] ) + @attr(type='postgres') + def test__postgres__schema_tests_specify_tag_and_children(self): + self.run_schema_and_assert( + ['tag:base+'], + None, + ['unique_emails_email', + 'unique_users_id', + 'unique_users_rollup_gender'] + ) + @attr(type='postgres') def test__postgres__schema_tests_specify_model_and_parents(self): self.run_schema_and_assert( diff --git a/test/integration/007_graph_selection_tests/test_tag_selection.py b/test/integration/007_graph_selection_tests/test_tag_selection.py new file mode 100644 index 00000000000..35ea38fb899 --- /dev/null +++ b/test/integration/007_graph_selection_tests/test_tag_selection.py @@ -0,0 +1,63 @@ +from test.integration.base import DBTIntegrationTest, use_profile + +class TestGraphSelection(DBTIntegrationTest): + + @property + def schema(self): + return "graph_selection_tests_007" + + @property + def models(self): + return "test/integration/007_graph_selection_tests/models" + + @property + def project_config(self): + return { + "models": { + "test": { + "users": { + "tags": "specified_as_string" + }, + + "users_rollup": { + "tags": ["specified_in_project"] + } + } + } + } + + @use_profile('postgres') + def test__postgres__select_tag(self): + self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql") + + results = self.run_dbt(['run', '--models', 'tag:specified_as_string']) + self.assertEqual(len(results), 1) + + models_run = [r.node['name'] for r in results] + self.assertTrue('users' in models_run) + + + @use_profile('postgres') + def test__postgres__select_tag_and_children(self): + self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql") + + results = self.run_dbt(['run', '--models', '+tag:specified_in_project+']) + self.assertEqual(len(results), 2) + + models_run = [r.node['name'] for r in results] + self.assertTrue('users' in models_run) + self.assertTrue('users_rollup' in models_run) + + + # check that model configs aren't squashed by project configs + @use_profile('postgres') + def test__postgres__select_tag_in_model_with_project_Config(self): + self.run_sql_file("test/integration/007_graph_selection_tests/seed.sql") + + results = self.run_dbt(['run', '--models', 'tag:bi']) + self.assertEqual(len(results), 2) + + models_run = [r.node['name'] for r in results] + self.assertTrue('users' in models_run) + self.assertTrue('users_rollup' in models_run) + diff --git 
a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index 955e324eb27..2202f2c085e 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -748,7 +748,8 @@ def expected_seeded_manifest(self): 'post-hook': [], 'vars': {}, 'column_types': {}, - 'quoting': {} + 'quoting': {}, + 'tags': [], }, 'schema': my_schema_name, 'alias': 'model', @@ -800,7 +801,8 @@ def expected_seeded_manifest(self): 'post-hook': [], 'vars': {}, 'column_types': {}, - 'quoting': {} + 'quoting': {}, + 'tags': [], }, 'schema': my_schema_name, 'alias': 'seed', @@ -818,7 +820,8 @@ def expected_seeded_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', @@ -846,7 +849,8 @@ def expected_seeded_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', @@ -875,7 +879,8 @@ def expected_seeded_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', @@ -938,7 +943,8 @@ def expected_postgres_references_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': []}, 'description': '', @@ -978,7 +984,8 @@ def expected_postgres_references_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': { 'macros': [], @@ -1040,7 +1047,8 @@ def expected_postgres_references_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': { 'macros': [], @@ -1092,7 +1100,8 @@ def expected_postgres_references_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': []}, 'description': '', @@ -1200,7 +1209,8 @@ def expected_bigquery_complex_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': ['seed.test.seed']}, 'empty': False, @@ -1251,7 +1261,8 @@ def expected_bigquery_complex_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': { 'macros': [], @@ -1352,6 +1363,7 @@ def expected_bigquery_complex_manifest(self): 'vars': {}, 'column_types': {}, 'quoting': {}, + 'tags': [], }, 'schema': my_schema_name, 'alias': 'seed', @@ -1412,6 +1424,7 @@ def expected_redshift_incremental_view_manifest(self): "vars": {}, "column_types": {}, "quoting": {}, + "tags": [], }, "schema": my_schema_name, "alias": "model", @@ -1466,6 +1479,7 @@ def expected_redshift_incremental_view_manifest(self): "vars": {}, "column_types": {}, "quoting": {}, + "tags": [], }, "schema": my_schema_name, "alias": "seed", @@ -1562,7 +1576,8 @@ def expected_run_results(self, quote_schema=True, quote_model=False): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': { 'macros': [], @@ -1612,6 +1627,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False): 'pre-hook': [], 'quoting': {}, 'vars': {}, + 'tags': [], }, 
'depends_on': {'macros': [], 'nodes': []}, 'description': '', @@ -1654,7 +1670,8 @@ def expected_run_results(self, quote_schema=True, quote_model=False): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', @@ -1696,7 +1713,8 @@ def expected_run_results(self, quote_schema=True, quote_model=False): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', @@ -1740,6 +1758,7 @@ def expected_run_results(self, quote_schema=True, quote_model=False): 'pre-hook': [], 'quoting': {}, 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': ['model.test.model']}, 'description': '', @@ -1817,7 +1836,8 @@ def expected_postgres_references_run_results(self): 'post-hook': [], 'vars': {}, 'column_types': {}, - 'quoting': {} + 'quoting': {}, + 'tags': [], }, 'depends_on': { 'nodes': ['model.test.ephemeral_copy'], @@ -1900,7 +1920,8 @@ def expected_postgres_references_run_results(self): 'post-hook': [], 'vars': {}, 'column_types': {}, - 'quoting': {} + 'quoting': {}, + 'tags': [], }, 'depends_on': { 'nodes': ['model.test.ephemeral_summary'], @@ -1972,6 +1993,7 @@ def expected_postgres_references_run_results(self): 'pre-hook': [], 'quoting': {}, 'vars': {}, + 'tags': [], }, 'depends_on': {'macros': [], 'nodes': []}, 'description': '', diff --git a/test/unit/test_compiler.py b/test/unit/test_compiler.py index 03f4db22a4d..ede4b48665f 100644 --- a/test/unit/test_compiler.py +++ b/test/unit/test_compiler.py @@ -41,6 +41,7 @@ def setUp(self): 'vars': {}, 'quoting': {}, 'column_types': {}, + 'tags': [], } def test__prepend_ctes__already_has_cte(self): diff --git a/test/unit/test_context.py b/test/unit/test_context.py index 1829596a299..e89e811ccd3 100644 --- a/test/unit/test_context.py +++ b/test/unit/test_context.py @@ -31,6 +31,7 @@ def setUp(self): 'vars': {}, 'quoting': {}, 'column_types': {}, + 'tags': [], }, tags=[], path='model_one.sql', diff --git a/test/unit/test_graph_selection.py b/test/unit/test_graph_selection.py index 5d4865875bf..518102f9128 100644 --- a/test/unit/test_graph_selection.py +++ b/test/unit/test_graph_selection.py @@ -23,6 +23,15 @@ def setUp(self): for node in self.package_graph: self.package_graph.node[node]['fqn'] = node.split('.')[1:] + self.package_graph.node['m.X.a']['tags'] = ['abc'] + self.package_graph.node['m.Y.b']['tags'] = ['abc'] + self.package_graph.node['m.X.c']['tags'] = ['abc'] + self.package_graph.node['m.Y.d']['tags'] = [] + self.package_graph.node['m.X.e']['tags'] = ['efg'] + self.package_graph.node['m.Y.f']['tags'] = ['efg'] + self.package_graph.node['m.X.g']['tags'] = ['efg'] + + def run_specs_and_assert(self, graph, include, exclude, expected): selected = graph_selector.select_nodes( graph, @@ -41,6 +50,37 @@ def test__single_node_selection_in_package(self): set(['m.X.a']) ) + def test__select_by_tag(self): + self.run_specs_and_assert( + self.package_graph, + ['tag:abc'], + [], + set(['m.X.a', 'm.Y.b', 'm.X.c']) + ) + + def test__exclude_by_tag(self): + self.run_specs_and_assert( + self.package_graph, + ['*'], + ['tag:abc'], + set(['m.Y.d', 'm.X.e', 'm.Y.f', 'm.X.g']) + ) + + def test__select_by_tag_and_model_name(self): + self.run_specs_and_assert( + self.package_graph, + ['tag:abc', 'a'], + [], + set(['m.X.a', 'm.Y.b', 'm.X.c']) + ) + + self.run_specs_and_assert( + self.package_graph, + ['tag:abc', 'd'], + [], + 
set(['m.X.a', 'm.Y.b', 'm.X.c', 'm.Y.d']) + ) + def test__multiple_node_selection_in_package(self): self.run_specs_and_assert( self.package_graph, @@ -56,33 +96,48 @@ def test__select_children_except_in_package(self): ['b'], set(['m.X.a','m.X.c', 'm.Y.d','m.X.e','m.Y.f','m.X.g'])) - def parse_spec_and_assert(self, spec, parents, children, qualified_node_name): + def test__select_children_except_tag(self): + self.run_specs_and_assert( + self.package_graph, + ['X.a+'], + ['tag:efg'], + set(['m.X.a','m.Y.b','m.X.c', 'm.Y.d'])) + + def parse_spec_and_assert(self, spec, parents, children, filter_type, filter_value): parsed = graph_selector.parse_spec(spec) self.assertEquals( parsed, { "select_parents": parents, "select_children": children, - "qualified_node_name": qualified_node_name, + "filter": { + 'type': filter_type, + 'value': filter_value + }, "raw": spec } ) def test__spec_parsing(self): - self.parse_spec_and_assert('a', False, False, ['a']) - self.parse_spec_and_assert('+a', True, False, ['a']) - self.parse_spec_and_assert('a+', False, True, ['a']) - self.parse_spec_and_assert('+a+', True, True, ['a']) - - self.parse_spec_and_assert('a.b', False, False, ['a', 'b']) - self.parse_spec_and_assert('+a.b', True, False, ['a', 'b']) - self.parse_spec_and_assert('a.b+', False, True, ['a', 'b']) - self.parse_spec_and_assert('+a.b+', True, True, ['a', 'b']) - - self.parse_spec_and_assert('a.b.*', False, False, ['a', 'b', '*']) - self.parse_spec_and_assert('+a.b.*', True, False, ['a', 'b', '*']) - self.parse_spec_and_assert('a.b.*+', False, True, ['a', 'b', '*']) - self.parse_spec_and_assert('+a.b.*+', True, True, ['a', 'b', '*']) + self.parse_spec_and_assert('a', False, False, 'fqn', 'a') + self.parse_spec_and_assert('+a', True, False, 'fqn', 'a') + self.parse_spec_and_assert('a+', False, True, 'fqn', 'a') + self.parse_spec_and_assert('+a+', True, True, 'fqn', 'a') + + self.parse_spec_and_assert('a.b', False, False, 'fqn', 'a.b') + self.parse_spec_and_assert('+a.b', True, False, 'fqn', 'a.b') + self.parse_spec_and_assert('a.b+', False, True, 'fqn', 'a.b') + self.parse_spec_and_assert('+a.b+', True, True, 'fqn', 'a.b') + + self.parse_spec_and_assert('a.b.*', False, False, 'fqn', 'a.b.*') + self.parse_spec_and_assert('+a.b.*', True, False, 'fqn', 'a.b.*') + self.parse_spec_and_assert('a.b.*+', False, True, 'fqn', 'a.b.*') + self.parse_spec_and_assert('+a.b.*+', True, True, 'fqn', 'a.b.*') + + self.parse_spec_and_assert('tag:a', False, False, 'tag', 'a') + self.parse_spec_and_assert('+tag:a', True, False, 'tag', 'a') + self.parse_spec_and_assert('tag:a+', False, True, 'tag', 'a') + self.parse_spec_and_assert('+tag:a+', True, True, 'tag', 'a') def test__package_name_getter(self): found = graph_selector.get_package_names(self.package_graph) diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py index a4fd33f8e7a..d1706ab64af 100644 --- a/test/unit/test_manifest.py +++ b/test/unit/test_manifest.py @@ -26,6 +26,7 @@ def setUp(self): 'vars': {}, 'quoting': {}, 'column_types': {}, + 'tags': [], } self.nested_nodes = { @@ -321,6 +322,7 @@ def setUp(self): 'vars': {}, 'quoting': {}, 'column_types': {}, + 'tags': [], } self.nested_nodes = { diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py index 58164da0ac0..3706e53899f 100644 --- a/test/unit/test_parser.py +++ b/test/unit/test_parser.py @@ -81,6 +81,7 @@ def setUp(self): 'vars': {}, 'quoting': {}, 'column_types': {}, + 'tags': [], } self.disabled_config = { @@ -91,6 +92,7 @@ def setUp(self): 'vars': {}, 'quoting': {}, 
'column_types': {}, + 'tags': [], } def test__single_model(self): From 4cc5e6d6484bae2c418ae3bb8f92ed2bb101ab0a Mon Sep 17 00:00:00 2001 From: Drew Banin Date: Sun, 30 Sep 2018 12:01:59 -0400 Subject: [PATCH 035/133] fix bq tests --- test/integration/029_docs_generate_tests/test_docs_generate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/integration/029_docs_generate_tests/test_docs_generate.py b/test/integration/029_docs_generate_tests/test_docs_generate.py index 2202f2c085e..08b74bfeca9 100644 --- a/test/integration/029_docs_generate_tests/test_docs_generate.py +++ b/test/integration/029_docs_generate_tests/test_docs_generate.py @@ -1316,7 +1316,8 @@ def expected_bigquery_complex_manifest(self): 'post-hook': [], 'pre-hook': [], 'quoting': {}, - 'vars': {} + 'vars': {}, + 'tags': [], }, 'depends_on': { 'macros': [], From bc8d523a4ea82388ea4c2592a8fbe06a81c384c0 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 2 Oct 2018 08:20:50 -0600 Subject: [PATCH 036/133] PR feedback --- dbt/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dbt/main.py b/dbt/main.py index b1897fb631e..d56be5b87c4 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -156,7 +156,7 @@ def run_from_args(parsed): else: nearest_project_dir = get_nearest_project_dir() if nearest_project_dir is None: - raise RuntimeError( + raise RuntimeException( "fatal: Not a dbt project (or any of the parent directories). " "Missing dbt_project.yml file" ) @@ -165,7 +165,7 @@ def run_from_args(parsed): res = invoke_dbt(parsed) if res is None: - raise RuntimeError("Could not run dbt") + raise RuntimeException("Could not run dbt") else: task, cfg = res From 52c1d5ace2708cac5a8b00e063575a79d48105c2 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 28 Sep 2018 10:57:53 -0600 Subject: [PATCH 037/133] attach args to the config directly instead of just the cli parameters --- dbt/config.py | 18 ++++++++++-------- test/unit/utils.py | 9 +++++++-- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index 2d2d45b9174..f1d4ea232d5 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -694,9 +694,10 @@ def __init__(self, project_name, version, project_root, source_paths, modules_path, quoting, models, on_run_start, on_run_end, archive, seeds, profile_name, target_name, send_anonymous_usage_stats, use_colors, threads, credentials, - packages, cli_vars): + packages, args): # 'vars' - self.cli_vars = cli_vars + self.args = args + self.cli_vars = dbt.utils.parse_cli_vars(getattr(args, 'vars', '{}')) # 'project' Project.__init__( self, @@ -735,13 +736,12 @@ def __init__(self, project_name, version, project_root, source_paths, self.validate() @classmethod - def from_parts(cls, project, profile, cli_vars): + def from_parts(cls, project, profile, args): """Instantiate a RuntimeConfig from its components. :param profile Profile: A parsed dbt Profile. :param project Project: A parsed dbt Project. - :param cli_vars dict: A dict of vars, as provided from the command - line. + :param args argparse.Namespace: The parsed command-line arguments. :returns RuntimeConfig: The new configuration. 
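        :raises DbtProjectError: If the assembled configuration fails
            validation.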
""" quoting = deepcopy( @@ -776,7 +776,7 @@ def from_parts(cls, project, profile, cli_vars): use_colors=profile.use_colors, threads=profile.threads, credentials=profile.credentials, - cli_vars=cli_vars + args=args ) def new_project(self, project_root): @@ -797,7 +797,7 @@ def new_project(self, project_root): cfg = self.from_parts( project=project, profile=profile, - cli_vars=deepcopy(self.cli_vars) + args=deepcopy(self.args), ) # force our quoting back onto the new project. cfg.quoting = deepcopy(self.quoting) @@ -808,6 +808,8 @@ def serialize(self): instance that has passed validate() (which happens in __init__), it matches the Configuration contract. + Note that args are not serialized. + :returns dict: The serialized configuration. """ result = self.to_project_config(with_packages=True) @@ -854,7 +856,7 @@ def from_args(cls, args): return cls.from_parts( project=project, profile=profile, - cli_vars=cli_vars + args=args ) diff --git a/test/unit/utils.py b/test/unit/utils.py index ed94245814b..4c09cecc66b 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -4,6 +4,10 @@ issues. """ +class Obj(object): + pass + + def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): from dbt.config import Project, Profile, RuntimeConfig from dbt.utils import parse_cli_vars @@ -16,9 +20,10 @@ def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): profile = Profile.from_raw_profile_info(deepcopy(profile), project.profile_name, cli_vars) - + args = Obj() + args.cli_vars = cli_vars return RuntimeConfig.from_parts( project=project, profile=profile, - cli_vars=cli_vars + args=args ) From a1b44201d462e35026b8606a309a8d119c819b91 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 28 Sep 2018 11:06:41 -0600 Subject: [PATCH 038/133] make seed parsing conditional on being a dbt seed invocation --- dbt/parser/seeds.py | 18 ++++++++++++------ test/integration/base.py | 1 + test/unit/utils.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/dbt/parser/seeds.py b/dbt/parser/seeds.py index d2e4ab9e372..eb3b9982571 100644 --- a/dbt/parser/seeds.py +++ b/dbt/parser/seeds.py @@ -15,7 +15,7 @@ class SeedParser(BaseParser): @classmethod - def parse_seed_file(cls, file_match, root_dir, package_name): + def parse_seed_file(cls, file_match, root_dir, package_name, should_parse): """Parse the given seed file, returning an UnparsedNode and the agate table. 
""" @@ -34,10 +34,13 @@ def parse_seed_file(cls, file_match, root_dir, package_name): original_file_path=os.path.join(file_match.get('searched_path'), file_match.get('relative_path')), ) - try: - table = dbt.clients.agate_helper.from_csv(abspath) - except ValueError as e: - dbt.exceptions.raise_compiler_error(str(e), node) + if should_parse: + try: + table = dbt.clients.agate_helper.from_csv(abspath) + except ValueError as e: + dbt.exceptions.raise_compiler_error(str(e), node) + else: + table = dbt.clients.agate_helper.empty_table() table.original_abspath = abspath return node, table @@ -56,10 +59,13 @@ def load_and_parse(cls, package_name, root_project, all_projects, root_dir, relative_dirs, extension) + # we only want to parse seeds if we're inside 'dbt seed' + should_parse = root_project.args.which == 'seed' + result = {} for file_match in file_matches: node, agate_table = cls.parse_seed_file(file_match, root_dir, - package_name) + package_name, should_parse) node_path = cls.get_path(NodeType.Seed, package_name, node.name) parsed = cls.parse_node(node, node_path, root_project, all_projects.get(package_name), diff --git a/test/integration/base.py b/test/integration/base.py index 9dc0040b95b..ca86f1b1df7 100644 --- a/test/integration/base.py +++ b/test/integration/base.py @@ -38,6 +38,7 @@ def __init__(self): class TestArgs(object): def __init__(self, kwargs): + self.which = 'run' self.__dict__.update(kwargs) diff --git a/test/unit/utils.py b/test/unit/utils.py index 4c09cecc66b..21281d97824 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -5,7 +5,7 @@ """ class Obj(object): - pass + which = 'blah' def config_from_parts_or_dicts(project, profile, packages=None, cli_vars='{}'): From 4b8d19c75ce8ca891f0e5ec247d2fe1c89c31991 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 5 Oct 2018 14:59:16 -0600 Subject: [PATCH 039/133] add test to ensure "dbt run" skips seed parsing --- .../005_simple_seed_test/data-bad/seed.csv | 3 ++ .../models-exist/model.sql | 1 + .../005_simple_seed_test/test_simple_seed.py | 35 +++++++++++++++++++ 3 files changed, 39 insertions(+) create mode 100644 test/integration/005_simple_seed_test/data-bad/seed.csv create mode 100644 test/integration/005_simple_seed_test/models-exist/model.sql diff --git a/test/integration/005_simple_seed_test/data-bad/seed.csv b/test/integration/005_simple_seed_test/data-bad/seed.csv new file mode 100644 index 00000000000..09f9007dcec --- /dev/null +++ b/test/integration/005_simple_seed_test/data-bad/seed.csv @@ -0,0 +1,3 @@ +a,b,c +1,7,23,90,5 +2 diff --git a/test/integration/005_simple_seed_test/models-exist/model.sql b/test/integration/005_simple_seed_test/models-exist/model.sql new file mode 100644 index 00000000000..809a05ba8f9 --- /dev/null +++ b/test/integration/005_simple_seed_test/models-exist/model.sql @@ -0,0 +1 @@ +select * from {{ this.schema }}.seed_expected diff --git a/test/integration/005_simple_seed_test/test_simple_seed.py b/test/integration/005_simple_seed_test/test_simple_seed.py index a4496a14cf6..c6a0b482d26 100644 --- a/test/integration/005_simple_seed_test/test_simple_seed.py +++ b/test/integration/005_simple_seed_test/test_simple_seed.py @@ -1,6 +1,8 @@ from nose.plugins.attrib import attr from test.integration.base import DBTIntegrationTest +from dbt.exceptions import CompilationException + class TestSimpleSeed(DBTIntegrationTest): def setUp(self): @@ -129,3 +131,36 @@ def test_simple_seed_with_disabled(self): self.assertEqual(len(results), 1) self.assertTableDoesExist('seed_enabled') 
self.assertTableDoesNotExist('seed_disabled') + + +class TestSeedParsing(DBTIntegrationTest): + def setUp(self): + super(TestSeedParsing, self).setUp() + self.run_sql_file("test/integration/005_simple_seed_test/seed.sql") + + @property + def schema(self): + return "simple_seed_005" + + @property + def models(self): + return "test/integration/005_simple_seed_test/models-exist" + + @property + def project_config(self): + return { + "data-paths": ['test/integration/005_simple_seed_test/data-bad'] + } + + @attr(type='postgres') + def test_postgres_dbt_run_skips_seeds(self): + # run does not try to parse the seed files + self.assertEqual(len(self.run_dbt(['run'])), 1) + + # make sure 'dbt seed' fails, otherwise our test is invalid! + with self.assertRaises(CompilationException): + self.run_dbt(['seed']) + + + + From 30b6868d95b402a559c0217ef3e1081301a087f2 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Wed, 26 Sep 2018 14:28:59 -0600 Subject: [PATCH 040/133] make schemas available to on-run-end hooks --- dbt/compilation.py | 6 +++- dbt/node_runners.py | 34 ++++++++++++------- .../014_hook_tests/test_run_hooks.py | 15 ++++++-- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/dbt/compilation.py b/dbt/compilation.py index c35ce1d4909..1c688d27771 100644 --- a/dbt/compilation.py +++ b/dbt/compilation.py @@ -102,7 +102,10 @@ def initialize(self): dbt.clients.system.make_directory(self.config.target_path) dbt.clients.system.make_directory(self.config.modules_path) - def compile_node(self, node, manifest): + def compile_node(self, node, manifest, extra_context=None): + if extra_context is None: + extra_context = {} + logger.debug("Compiling {}".format(node.get('unique_id'))) data = node.to_dict() @@ -117,6 +120,7 @@ def compile_node(self, node, manifest): context = dbt.context.runtime.generate( compiled_node, self.config, manifest) + context.update(extra_context) compiled_node.compiled_sql = dbt.clients.jinja.get_rendered( node.get('raw_sql'), diff --git a/dbt/node_runners.py b/dbt/node_runners.py index 20a25a835e1..0c54b074c32 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -223,13 +223,13 @@ def execute(self, compiled_node, manifest): def compile(self, manifest): return self._compile_node(self.adapter, self.config, self.node, - manifest) + manifest, {}) @classmethod - def _compile_node(cls, adapter, config, node, manifest): + def _compile_node(cls, adapter, config, node, manifest, extra_context): compiler = dbt.compilation.Compiler(config) - node = compiler.compile_node(node, manifest) - node = cls._inject_runtime_config(adapter, node) + node = compiler.compile_node(node, manifest, extra_context) + node = cls._inject_runtime_config(adapter, node, extra_context) if(node.injected_sql is not None and not (dbt.utils.is_type(node, NodeType.Archive))): @@ -247,9 +247,10 @@ def _compile_node(cls, adapter, config, node, manifest): return node @classmethod - def _inject_runtime_config(cls, adapter, node): + def _inject_runtime_config(cls, adapter, node, extra_context): wrapped_sql = node.wrapped_sql context = cls._node_context(adapter, node) + context.update(extra_context) sql = dbt.clients.jinja.get_rendered(wrapped_sql, context) node.wrapped_sql = sql return node @@ -286,7 +287,7 @@ def raise_on_first_error(self): return False @classmethod - def run_hooks(cls, config, adapter, manifest, hook_type): + def run_hooks(cls, config, adapter, manifest, hook_type, extra_context): nodes = manifest.nodes.values() hooks = get_nodes_by_tags(nodes, {hook_type}, NodeType.Operation) @@ 
-304,7 +305,8 @@ def run_hooks(cls, config, adapter, manifest, hook_type): # Also, consider configuring psycopg2 (and other adapters?) to # ensure that a transaction is only created if dbt initiates it. adapter.clear_transaction(model_name) - compiled = cls._compile_node(adapter, config, hook, manifest) + compiled = cls._compile_node(adapter, config, hook, manifest, + extra_context) statement = compiled.wrapped_sql hook_index = hook.get('index', len(hooks)) @@ -322,10 +324,10 @@ def run_hooks(cls, config, adapter, manifest, hook_type): adapter.release_connection(model_name) @classmethod - def safe_run_hooks(cls, config, adapter, manifest, hook_type): + def safe_run_hooks(cls, config, adapter, manifest, hook_type, + extra_context): try: - cls.run_hooks(config, adapter, manifest, hook_type) - + cls.run_hooks(config, adapter, manifest, hook_type, extra_context) except dbt.exceptions.RuntimeException: logger.info("Database error while running {}".format(hook_type)) raise @@ -347,7 +349,7 @@ def create_schemas(cls, config, adapter, manifest): @classmethod def before_run(cls, config, adapter, manifest): - cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start) + cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start, {}) cls.create_schemas(config, adapter, manifest) @classmethod @@ -368,7 +370,15 @@ def print_results_line(cls, results, execution_time): @classmethod def after_run(cls, config, adapter, results, manifest): - cls.safe_run_hooks(config, adapter, manifest, RunHookType.End) + # in on-run-end hooks, provide the value 'schemas', which is a list of + # unique schemas that successfully executed models were in + # errored failed skipped + schemas = list(set( + r.node.schema for r in results + if not any((r.errored, r.failed, r.skipped)) + )) + cls.safe_run_hooks(config, adapter, manifest, RunHookType.End, + {'schemas': schemas}) @classmethod def after_hooks(cls, config, adapter, results, manifest, elapsed): diff --git a/test/integration/014_hook_tests/test_run_hooks.py b/test/integration/014_hook_tests/test_run_hooks.py index 304f3b82b61..a92790c8473 100644 --- a/test/integration/014_hook_tests/test_run_hooks.py +++ b/test/integration/014_hook_tests/test_run_hooks.py @@ -45,6 +45,8 @@ def project_config(self): "{{ custom_run_hook('end', target, run_started_at, invocation_id) }}", "create table {{ target.schema }}.end_hook_order_test ( id int )", "drop table {{ target.schema }}.end_hook_order_test", + "create table {{ target.schema }}.schemas ( schema text )", + "insert into {{ target.schema }}.schemas values ({% for schema in schemas %}( '{{ schema }}' ){% if not loop.last %},{% endif %}{% endfor %})", ] } @@ -63,6 +65,12 @@ def get_ctx_vars(self, state): return ctx + def assert_used_schemas(self): + schemas_query = 'select * from {}.schemas'.format(self.unique_schema()) + results = self.run_sql(schemas_query, fetch='all') + self.assertEqual(len(results), 1) + self.assertEqual(results[0][0], self.unique_schema()) + def check_hooks(self, state): ctx = self.get_ctx_vars(state) @@ -81,7 +89,7 @@ def check_hooks(self, state): self.assertTrue(ctx['invocation_id'] is not None and len(ctx['invocation_id']) > 0, 'invocation_id was not set') @attr(type='postgres') - def test_pre_and_post_run_hooks(self): + def test__postgres__pre_and_post_run_hooks(self): self.run_dbt(['run']) self.check_hooks('start') @@ -89,9 +97,10 @@ def test_pre_and_post_run_hooks(self): self.assertTableDoesNotExist("start_hook_order_test") self.assertTableDoesNotExist("end_hook_order_test") + 
self.assert_used_schemas() @attr(type='postgres') - def test_pre_and_post_seed_hooks(self): + def test__postgres__pre_and_post_seed_hooks(self): self.run_dbt(['seed']) self.check_hooks('start') @@ -99,4 +108,4 @@ def test_pre_and_post_seed_hooks(self): self.assertTableDoesNotExist("start_hook_order_test") self.assertTableDoesNotExist("end_hook_order_test") - + self.assert_used_schemas() From deab38a4e138d62b2e2b31b5931322ab3d6b9dbc Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 8 Oct 2018 09:41:53 -0600 Subject: [PATCH 041/133] remove unused staticmethods --- dbt/config.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index 2d2d45b9174..c59c365bdac 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -114,23 +114,6 @@ def _is_hook_path(keypath): return False - @staticmethod - def _is_port_path(keypath): - return len(keypath) == 2 and keypath[-1] == 'port' - - @staticmethod - def _convert_port(value, keypath): - - if len(keypath) != 4: - return value - - if keypath[-1] == 'port' and keypath[1] == 'outputs': - try: - return int(value) - except ValueError: - pass # let the validator or connection handle this - return value - def _render_project_entry(self, value, keypath): """Render an entry, in case it's jinja. This is meant to be passed to dbt.utils.deep_map. From 45ddd3d7f1a477fe7c50b085d88f08ba9d498eaa Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 8 Oct 2018 10:03:19 -0600 Subject: [PATCH 042/133] pr feedback: more tests --- dbt/utils.py | 3 +-- test/unit/test_utils.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/dbt/utils.py b/dbt/utils.py index 281dedf33b2..3ada06189aa 100644 --- a/dbt/utils.py +++ b/dbt/utils.py @@ -302,8 +302,7 @@ def deep_map(func, value, keypath=(), memo=None, _notfound=object()): ret = func(value, keypath) else: ok_types = (list, dict) + atomic_types - # TODO(jeb): real error - raise TypeError( + raise dbt.exceptions.DbtConfigError( 'in deep_map, expected one of {!r}, got {!r}' .format(ok_types, type(value)) ) diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index 61d02d8097f..29ae515d641 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -1,5 +1,6 @@ import unittest +import dbt.exceptions import dbt.utils @@ -125,3 +126,16 @@ def test__keypath(self): actual = dbt.utils.deep_map(self.special_keypath, expected) self.assertEquals(actual, expected) + def test__noop(self): + actual = dbt.utils.deep_map(lambda x, _: x, self.input_value) + self.assertEquals(actual, self.input_value) + + def test_trivial(self): + cases = [[], {}, 1, 'abc', None, True] + for case in cases: + result = dbt.utils.deep_map(lambda x, _: x, case) + self.assertEquals(result, case) + + with self.assertRaises(dbt.exceptions.DbtConfigError): + dbt.utils.deep_map(lambda x, _: x, {'foo': object()}) + From 4ccab99765013a64cddcb8f0b84e68aa6b8691f0 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 9 Oct 2018 06:54:52 -0600 Subject: [PATCH 043/133] added test --- .../models/view_using_ref.sql | 9 +++++++++ .../test_simple_reference.py | 20 +++++++++++++++---- 2 files changed, 25 insertions(+), 4 deletions(-) create mode 100644 test/integration/003_simple_reference_test/models/view_using_ref.sql diff --git a/test/integration/003_simple_reference_test/models/view_using_ref.sql b/test/integration/003_simple_reference_test/models/view_using_ref.sql new file mode 100644 index 00000000000..664f6290681 --- /dev/null +++ 
b/test/integration/003_simple_reference_test/models/view_using_ref.sql @@ -0,0 +1,9 @@ +{{ + config( + materialized = "view" + ) +}} + +select gender, count(*) as ct from {{ var('var_ref') }} +group by gender +order by gender asc diff --git a/test/integration/003_simple_reference_test/test_simple_reference.py b/test/integration/003_simple_reference_test/test_simple_reference.py index 1fb17884ce8..66fe496d212 100644 --- a/test/integration/003_simple_reference_test/test_simple_reference.py +++ b/test/integration/003_simple_reference_test/test_simple_reference.py @@ -9,6 +9,16 @@ def schema(self): def models(self): return "test/integration/003_simple_reference_test/models" + @property + def project_config(self): + return { + 'models': { + 'vars': { + 'var_ref': '{{ ref("view_copy") }}', + } + } + } + @use_profile('postgres') def test__postgres__simple_reference(self): self.use_default_project() @@ -17,7 +27,7 @@ def test__postgres__simple_reference(self): results = self.run_dbt() # ephemeral_copy doesn't show up in results - self.assertEqual(len(results), 7) + self.assertEqual(len(results), 8) # Copies should match self.assertTablesEqual("seed","incremental_copy") @@ -29,11 +39,12 @@ def test__postgres__simple_reference(self): self.assertTablesEqual("summary_expected","materialized_summary") self.assertTablesEqual("summary_expected","view_summary") self.assertTablesEqual("summary_expected","ephemeral_summary") + self.assertTablesEqual("summary_expected","view_using_ref") self.run_sql_file("test/integration/003_simple_reference_test/update.sql") results = self.run_dbt() - self.assertEqual(len(results), 7) + self.assertEqual(len(results), 8) # Copies should match self.assertTablesEqual("seed","incremental_copy") @@ -45,6 +56,7 @@ def test__postgres__simple_reference(self): self.assertTablesEqual("summary_expected","materialized_summary") self.assertTablesEqual("summary_expected","view_summary") self.assertTablesEqual("summary_expected","ephemeral_summary") + self.assertTablesEqual("summary_expected","view_using_ref") @use_profile('snowflake') def test__snowflake__simple_reference(self): @@ -52,7 +64,7 @@ def test__snowflake__simple_reference(self): self.run_sql_file("test/integration/003_simple_reference_test/seed.sql") results = self.run_dbt() - self.assertEqual(len(results), 7) + self.assertEqual(len(results), 8) # Copies should match self.assertManyTablesEqual( @@ -64,7 +76,7 @@ def test__snowflake__simple_reference(self): "test/integration/003_simple_reference_test/update.sql") results = self.run_dbt() - self.assertEqual(len(results), 7) + self.assertEqual(len(results), 8) self.assertManyTablesEqual( ["SEED", "INCREMENTAL_COPY", "MATERIALIZED_COPY", "VIEW_COPY"], From 29584e3c5161073e25898f18f65dab1db5f8f28b Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 9 Oct 2018 07:07:08 -0600 Subject: [PATCH 044/133] Fix a bug where vars were rendered under models/seeds in the config --- dbt/config.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dbt/config.py b/dbt/config.py index c59c365bdac..95b251c27c2 100644 --- a/dbt/config.py +++ b/dbt/config.py @@ -99,7 +99,7 @@ def __init__(self, cli_vars): self.context['var'] = Var(None, self.context, cli_vars) @staticmethod - def _is_hook_path(keypath): + def _is_hook_or_model_vars_path(keypath): if not keypath: return False @@ -107,10 +107,14 @@ def _is_hook_path(keypath): # run hooks if first in {'on-run-start', 'on-run-end'}: return True - # model hooks + # models have two things to avoid if first in {'seeds', 
'models'}: + # model-level hooks if 'pre-hook' in keypath or 'post-hook' in keypath: return True + # model-level 'vars' declarations + if 'vars' in keypath: + return True return False @@ -126,8 +130,9 @@ def _render_project_entry(self, value, keypath): :param key str: The key to convert on. :return Any: The rendered entry. """ - # hooks should be treated as raw sql, they'll get rendered later - if self._is_hook_path(keypath): + # hooks should be treated as raw sql, they'll get rendered later. + # Same goes for 'vars' declarations inside 'models'/'seeds'. + if self._is_hook_or_model_vars_path(keypath): return value return self.render_value(value) From ccee039c7693af016f1a78a2a34dcbb51d069adf Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Thu, 20 Sep 2018 11:53:57 -0600 Subject: [PATCH 045/133] First pass on caching --- dbt/adapters/bigquery/impl.py | 6 +- dbt/adapters/cache.py | 258 ++++++++++++++++++ dbt/adapters/default/impl.py | 134 +++++++-- dbt/adapters/postgres/impl.py | 21 +- dbt/adapters/snowflake/impl.py | 16 +- dbt/contracts/graph/manifest.py | 6 + .../global_project/macros/adapters/common.sql | 24 ++ .../macros/materializations/seed/seed.sql | 1 + .../operations/relations/get_relations.sql | 4 + .../macros/relations/postgres_relations.sql | 64 +++++ .../macros/relations/redshift_relations.sql | 4 + dbt/node_runners.py | 5 + test/unit/test_cache.py | 217 +++++++++++++++ test/unit/test_postgres_adapter.py | 11 +- 14 files changed, 736 insertions(+), 35 deletions(-) create mode 100644 dbt/adapters/cache.py create mode 100644 dbt/include/global_project/macros/operations/relations/get_relations.sql create mode 100644 dbt/include/global_project/macros/relations/postgres_relations.sql create mode 100644 dbt/include/global_project/macros/relations/redshift_relations.sql create mode 100644 test/unit/test_cache.py diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index f0722e40d7b..27b157ef3aa 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -180,7 +180,10 @@ def close(cls, connection): return connection - def list_relations(self, schema, model_name=None): + def _link_cached_relations(self, manifest, schemas): + pass + + def _list_relations(self, schema, model_name=None): connection = self.get_connection(model_name) client = connection.handle @@ -222,6 +225,7 @@ def get_relation(self, schema=None, identifier=None, model_name) def drop_relation(self, relation, model_name=None): + self.cache.drop(schema=relation.schema, identifier=relation.identifier) conn = self.get_connection(model_name) client = conn.handle diff --git a/dbt/adapters/cache.py b/dbt/adapters/cache.py new file mode 100644 index 00000000000..32b1d8b9d7f --- /dev/null +++ b/dbt/adapters/cache.py @@ -0,0 +1,258 @@ +from collections import namedtuple +import threading +from dbt.logger import GLOBAL_LOGGER as logger +from copy import deepcopy + +ReferenceKey = namedtuple('ReferenceKey', 'schema identifier') + + +class CachedRelation(object): + # TODO: should this more directly related to the Relation class in the + # adapters themselves? + """Nothing about CachedRelation is guaranteed to be thread-safe!""" + def __init__(self, schema, identifier, kind=None, inner=None): + self.schema = schema + self.identifier = identifier + # This might be None, if the table is only referenced _by_ things, or + # temporariliy during cache building + # TODO: I'm still not sure we need this + self.kind = kind + self.referenced_by = {} + # a inner to store on this cached relation. 
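+        # (in practice this is the adapter's Relation object, as passed in
+        # by cache_new_relation on the adapter)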
+ self.inner = inner + + def __str__(self): + return ( + 'CachedRelation(schema={}, identifier={}, kind={}, inner={})' + ).format(self.schema, self.identifier, self.kind, self.inner) + + def __copy__(self): + new = self.__class__(self.schema, self.identifier) + new.__dict__.update(self.__dict__) + return new + + def __deepcopy__(self, memo): + new = self.__class__(self.schema, self.identifier) + new.__dict__.update(self.__dict__) + new.referenced_by = deepcopy(self.referenced_by, memo) + new.inner = self.inner.incorporate() + + def key(self): + return ReferenceKey(self.schema, self.identifier) + + def add_reference(self, referrer): + self.referenced_by[referrer.key()] = referrer + + def collect_consequences(self): + """Recursively collect a set of ReferenceKeys that would + consequentially get dropped if this were dropped via + "drop ... cascade". + """ + consequences = {self.key()} + for relation in self.referenced_by.values(): + consequences.update(relation.collect_consequences()) + return consequences + + def release_references(self, keys): + """Non-recursively indicate that an iterable of ReferenceKey no longer + exist. Unknown keys are ignored. + """ + keys = set(self.referenced_by) & set(keys) + for key in keys: + self.referenced_by.pop(key) + + def rename(self, new_relation): + """Note that this will change the output of key(), all refs must be + updated! + """ + self.schema = new_relation.schema + self.identifier = new_relation.identifier + # rename our inner value as well + if self.inner: + # Relations store this stuff inside their `path` dict. + # but they also store a table_name, and conditionally use it in + # .render(), so we need to update that as well... + # TODO: is this an aliasing issue? Do I have to worry about this? + self.inner = self.inner.incorporate( + path={ + 'schema': new_relation.schema, + 'identifier': new_relation.identifier + }, + table_name = new_relation.identifier + ) + + def rename_key(self, old_key, new_key): + # we've lost track of the state of the world! + assert new_key not in self.referenced_by, \ + 'Internal consistency error: new name is in the cache already' + if old_key not in self.referenced_by: + return + value = self.referenced_by.pop(old_key) + self.referenced_by[new_key] = value + + +class RelationsCache(object): + def __init__(self): + # map (schema, name) -> CachedRelation object. + # I think we can ignore database/project? + self.relations = {} + # make this a reentrant lock so the adatper can hold it while buliding + # the cache. + self.lock = threading.RLock() + # the set of cached schemas + self.schemas = set() + + def _setdefault(self, relation): + self.schemas.add(relation.schema) + key = relation.key() + result = self.relations.setdefault(key, relation) + # if we previously only saw the dependent without any kind information, + # update the type info. + if relation.kind is not None: + if result.kind is None: + result.kind = relation.kind + # we've lost track of the state of the world! 
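+            # (two different non-None kinds recorded for the same key would
+            # mean the cache no longer matches the database)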
+ assert result.kind == relation.kind, \ + 'Internal consistency error: Different non-None relation kinds' + # ditto for inner, except overwriting is fine + if relation.inner is not None: + if result.inner is None: + result.inner = relation.inner + return result + + def _add_link(self, new_referenced, new_dependent): + # get the canonical referenced entries (our new one could be canonical) + referenced = self._setdefault(new_referenced) + dependent = self._setdefault(new_dependent) + + # link them up + referenced.add_reference(dependent) + + def add_link(self, referenced_schema, referenced_name, dependent_schema, + dependent_name): + """The dependent schema refers _to_ the referenced schema + + # given arguments of: + # (jake_test, bar, jake_test, foo, view) + # foo is a view that refers to bar -> "drop bar cascade" will drop foo + # and all of foo's dependencies, recursively + """ + referenced = CachedRelation( + schema=referenced_schema, + identifier=referenced_name + ) + dependent = CachedRelation( + schema=dependent_schema, + identifier=dependent_name + ) + logger.debug('adding link, {!s} references {!s}' + .format(dependent, referenced) + ) + with self.lock: + self._add_link(referenced, dependent) + + def add(self, schema, identifier, kind=None, inner=None): + relation = CachedRelation( + schema=schema, + identifier=identifier, + kind=kind, + inner=inner + ) + logger.debug('Adding relation: {!s}'.format(relation)) + with self.lock: + self._setdefault(relation) + + def _remove_refs(self, keys): + # remove direct refs + for key in keys: + del self.relations[key] + # then remove all entries from each child + for cached in self.relations.values(): + cached.release_references(keys) + + def _drop_cascade_relation(self, dropped): + key = dropped.key() + if key not in self.relations: + # dbt drops potentially non-existent relations all the time, so + # this is fine. + logger.debug('dropped a nonexistent relationship: {!s}' + .format(dropped.key())) + return + consequences = self.relations[key].collect_consequences() + logger.debug('drop {} is cascading to {}'.format(key, consequences)) + self._remove_refs(consequences) + + def drop(self, schema, identifier): + dropped = CachedRelation(schema=schema, identifier=identifier) + logger.debug('Dropping relation: {!s}'.format(dropped)) + with self.lock: + self._drop_cascade_relation(dropped) + + def _rename_relation(self, old_relation, new_relation): + old_key = old_relation.key() + new_key = new_relation.key() + # the old relation might not exist. In that case, dbt created this + # relation earlier in its run and we can ignore it, as we don't care + # about the rename either + if old_key not in self.relations: + return + # not good + if new_key in self.relations: + raise RuntimeError( + 'Internal consistency error!: {} in {}' + .format(new_key, list(self.relations.keys())) + ) + + # On the database level, a rename updates all values that were + # previously referenced by old_name to be referenced by new_name. + # basically, the name changes but some underlying ID moves. Kind of + # like an object reference! 
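+        # For example (illustrative names only): renaming ('analytics',
+        # 'users__tmp') to ('analytics', 'users') must also re-key that entry
+        # in every other cached relation's referenced_by map; rename_key()
+        # takes care of that below.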
+ # Get the canonical version of old_relation and remove it from the db + relation = self.relations.pop(old_key) + + # change the old_relation's name and schema to the new relation's + relation.rename(new_relation) + # update all the relations that refer to it + for cached in self.relations.values(): + if old_key in cached.referenced_by: + logger.debug( + 'updated reference from {0} -> {2} to {1} -> {2}' + .format(old_key, new_key, cached.key()) + ) + cached.rename_key(old_key, new_key) + + self.relations[new_key] = relation + + + def rename_relation(self, old_schema, old_identifier, new_schema, + new_identifier): + old_relation = CachedRelation( + schema=old_schema, + identifier=old_identifier + ) + new_relation = CachedRelation( + schema=new_schema, + identifier=new_identifier + ) + logger.debug('Renaming relation {!s} to {!s}'.format( + old_relation, new_relation) + ) + with self.lock: + self._rename_relation(old_relation, new_relation) + + def _get_relation(self, schema, identifier): + """Get the relation by name. Raises a KeyError if it does not exist""" + relation = CachedRelation(schema=schema, identifier=identifier) + return self.relations[relation.key()] + + def get_relations(self, schema): + """Case-insensitively yield all relations matching the given schema. + """ + # TODO: What do we do if the inner value is None? Should that be + # possible? + schema = schema.lower() + with self.lock: + return [ + r.inner for r in self.relations.values() + if r.schema.lower() == schema + ] diff --git a/dbt/adapters/default/impl.py b/dbt/adapters/default/impl.py index 8fe29e4a684..5a4ddde6edc 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -16,6 +16,7 @@ from dbt.utils import filter_null_values from dbt.adapters.default.relation import DefaultRelation +from dbt.adapters.cache import RelationsCache GET_CATALOG_OPERATION_NAME = 'get_catalog_data' @@ -24,25 +25,39 @@ connections_available = [] -def _filter_schemas(manifest): +def _expect_row_value(key, row): + if key not in row.keys(): + raise dbt.exceptions.InternalException( + 'Got a row without "{}" column, columns: {}' + .format(key, row.keys()) + ) + return row[key] + + +def _relations_filter_schemas(schemas): + def test(row): + referenced_schema = _expect_row_value('referenced_schema', row) + dependent_schema = _expect_row_value('dependent_schema', row) + # if you somehow depend on the null schema, we should return that stuff + # TODO: is that true? + if referenced_schema is not None: + referenced_schema = referenced_schema.lower() + if dependent_schema is not None: + dependent_schema = dependent_schema.lower() + return referenced_schema in schemas or dependent_schema in schemas + return test + + +def _catalog_filter_schemas(manifest): """Return a function that takes a row and decides if the row should be included in the catalog output. """ - schemas = frozenset({ - node.schema.lower() - for node in manifest.nodes.values() - }) + schemas = frozenset(s.lower() for s in manifest.get_used_schemas()) def test(row): - if 'table_schema' not in row.keys(): - # this means the get catalog operation is somehow not well formed! 
- raise dbt.exceptions.InternalException( - 'Got a row without "table_schema" column, columns: {}' - .format(row.keys()) - ) + table_schema = _expect_row_value('table_schema', row) # the schema may be present but None, which is not an error and should # be filtered out - table_schema = row['table_schema'] if table_schema is None: return False return table_schema.lower() in schemas @@ -85,13 +100,15 @@ class DefaultAdapter(object): "get_status", "get_result_from_cursor", "quote", - "convert_type" + "convert_type", + "cache_new_relation" ] Relation = DefaultRelation Column = Column def __init__(self, config): self.config = config + self.cache = RelationsCache() ### # ADAPTER-SPECIFIC FUNCTIONS -- each of these must be overridden in @@ -151,6 +168,17 @@ def cancel_connection(self, connection): ### # FUNCTIONS THAT SHOULD BE ABSTRACT ### + def cache_new_relation(self, relation): + """Cache a new relation in dbt. It will show up in `list relations`.""" + self.cache.add( + schema=relation.schema, + identifier=relation.identifier, + kind=relation.type, + inner=relation, + ) + # so jinja doesn't render things + return '' + @classmethod def get_result_from_cursor(cls, cursor): data = [] @@ -175,6 +203,7 @@ def drop(self, schema, relation, relation_type, model_name=None): return self.drop_relation(relation, model_name) def drop_relation(self, relation, model_name=None): + self.cache.drop(schema=relation.schema, identifier=relation.identifier) if relation.type is None: dbt.exceptions.raise_compiler_error( 'Tried to drop relation {}, but its type is null.' @@ -216,6 +245,12 @@ def rename(self, schema, from_name, to_name, model_name=None): def rename_relation(self, from_relation, to_relation, model_name=None): + self.cache.rename_relation( + old_schema=from_relation.schema, + old_identifier=from_relation.identifier, + new_schema=to_relation.schema, + new_identifier=to_relation.identifier + ) sql = 'alter table {} rename to {}'.format( from_relation, to_relation.include(schema=False)) @@ -324,10 +359,32 @@ def expand_target_column_types(self, ### # RELATIONS ### - def list_relations(self, schema, model_name=None): + def _list_relations(self, schema, model_name=None): raise dbt.exceptions.NotImplementedException( '`list_relations` is not implemented for this adapter!') + def list_relations(self, schema, model_name=None): + if schema in self.cache.schemas: + logger.debug('In list_relations, model_name={}, cache hit' + .format(model_name)) + relations = self.cache.get_relations(schema) + else: + # this indicates that we missed a schema when populating. Warn + # about it. + logger.warning( + 'Schema "{}" not in the cache while handling model "{}", this' + 'is inefficient' + .format(schema, model_name or '') + ) + # we can't build the relations cache because we don't have a + # manifest so we can't run any operations. + # TODO: Should the manifest be stored on the adapter itself? 
+ # Then we could call _relations_cache_for_schemas here + relations = self._list_relations(schema, model_name=model_name) + logger.debug('with schema={}, model_name={}, relations={}' + .format(schema, model_name, relations)) + return relations + def _make_match_kwargs(self, schema, identifier): quoting = self.config.quoting if identifier is not None and quoting['identifier'] is False: @@ -778,15 +835,54 @@ def run_operation(self, manifest, operation_name): # Abstract methods involving the manifest ### @classmethod - def _filter_table(cls, table, manifest): - return table.where(_filter_schemas(manifest)) + def _catalog_filter_table(cls, table, manifest): + return table.where(_catalog_filter_schemas(manifest)) def get_catalog(self, manifest): try: - table = self.run_operation(manifest, - GET_CATALOG_OPERATION_NAME) + table = self.run_operation(manifest, GET_CATALOG_OPERATION_NAME) finally: self.release_connection(GET_CATALOG_OPERATION_NAME) - results = self._filter_table(table, manifest) + results = self._catalog_filter_table(table, manifest) return results + + @classmethod + def _relations_filter_table(cls, table, schemas): + return table.where(_relations_filter_schemas(schemas)) + + def _link_cached_relations(self, manifest, schemas): + """This method has to exist because BigQueryAdapter and SnowflakeAdapter + inherit from the PostgresAdapter, so they need something to override + in order to disable linking. + """ + pass + + def _relations_cache_for_schemas(self, manifest, schemas=None): + if schemas is None: + schemas = manifest.get_used_schemas() + + relations = [] + # add all relations + for schema in schemas: + # bypass the cache, of course! + for relation in self._list_relations(schema): + self.cache.add( + schema=relation.schema, + identifier=relation.name, + kind=relation.type, + inner=relation + ) + self._link_cached_relations(manifest, schemas) + # it's possible that there were no relations in some schemas. We want + # to insert the schemas we query into the cache's `.schemas` attribute + # so we can check it later + self.cache.schemas.update(schemas) + + def set_relations_cache(self, manifest): + """Run a query that gets a populated cache of the relations in the + database and set the cache on this adapter. + """ + # TODO: ensure cache is empty? 
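For orientation, a rough sketch of the call sequence this patch wires up across the adapter, the node runners, and the global macros. The wrapper function run_with_cache is invented; the methods it calls are the ones introduced in this patch:

    def run_with_cache(adapter, manifest, schema):
        # before_run(): seed the cache with one listing query per used schema
        # (plus, on Postgres, a pg_depend query to link views to their sources)
        adapter.set_relations_cache(manifest)
        # subsequent lookups are served from the cache instead of re-querying
        # the database
        relations = adapter.list_relations(schema)
        # the create_table_as/create_view_as/create_archive_table macros register
        # whatever they build via adapter.cache_new_relation(relation), and
        # drop_relation/rename_relation keep the cached graph in sync with the
        # warehouse as the run progresses
        return relations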
+ with self.cache.lock: + self._relations_cache_for_schemas(manifest) diff --git a/dbt/adapters/postgres/impl.py b/dbt/adapters/postgres/impl.py index 29355df5a76..18fdf1c3ca0 100644 --- a/dbt/adapters/postgres/impl.py +++ b/dbt/adapters/postgres/impl.py @@ -10,6 +10,9 @@ from dbt.logger import GLOBAL_LOGGER as logger +GET_RELATIONS_OPERATION_NAME = 'get_relations_data' + + class PostgresAdapter(dbt.adapters.default.DefaultAdapter): DEFAULT_TCP_KEEPALIVE = 0 # 0 means to use the default value @@ -143,7 +146,20 @@ def alter_column_type(self, schema, table, column_name, return connection, cursor - def list_relations(self, schema, model_name=None): + def _link_cached_relations(self, manifest, schemas): + # now set up any links + try: + table = self.run_operation(manifest, GET_RELATIONS_OPERATION_NAME) + # avoid a rollback when releasing the connection + self.commit_if_has_connection(GET_RELATIONS_OPERATION_NAME) + finally: + self.release_connection(GET_RELATIONS_OPERATION_NAME) + table = self._relations_filter_table(table, schemas) + + for (refed_schema, refed_name, dep_schema, dep_name) in table: + self.cache.add_link(dep_schema, dep_name, refed_schema, refed_name) + + def _list_relations(self, schema, model_name=None): sql = """ select tablename as name, schemaname as schema, 'table' as type from pg_tables where schemaname ilike '{schema}' @@ -152,8 +168,7 @@ def list_relations(self, schema, model_name=None): where schemaname ilike '{schema}' """.format(schema=schema).strip() # noqa - connection, cursor = self.add_query(sql, model_name, - auto_begin=False) + connection, cursor = self.add_query(sql, model_name, auto_begin=False) results = cursor.fetchall() diff --git a/dbt/adapters/snowflake/impl.py b/dbt/adapters/snowflake/impl.py index 20d088e59c4..cad1760db44 100644 --- a/dbt/adapters/snowflake/impl.py +++ b/dbt/adapters/snowflake/impl.py @@ -103,7 +103,10 @@ def open_connection(cls, connection): return connection - def list_relations(self, schema, model_name=None): + def _link_cached_relations(self, manifest, schemas): + pass + + def _list_relations(self, schema, model_name=None): sql = """ select table_name as name, table_schema as schema, table_type as type @@ -133,6 +136,12 @@ def list_relations(self, schema, model_name=None): def rename_relation(self, from_relation, to_relation, model_name=None): + self.cache.rename_relation( + old_schema=from_relation.schema, + old_identifier=from_relation.identifier, + new_schema=to_relation.schema, + new_identifier=to_relation.identifier + ) sql = 'alter table {} rename to {}'.format( from_relation, to_relation) @@ -209,13 +218,14 @@ def add_query(self, sql, model_name=None, auto_begin=True, return connection, cursor @classmethod - def _filter_table(cls, table, manifest): + def _catalog_filter_table(cls, table, manifest): # On snowflake, users can set QUOTED_IDENTIFIERS_IGNORE_CASE, so force # the column names to their lowercased forms. 
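        # Without this, the shared _catalog_filter_schemas() lookup of
        # row['table_schema'] would miss an upper-cased TABLE_SCHEMA column
        # and raise an InternalException.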
lowered = table.rename( column_names=[c.lower() for c in table.column_names] ) - return super(SnowflakeAdapter, cls)._filter_table(lowered, manifest) + return super(SnowflakeAdapter, cls)._catalog_filter_table(lowered, + manifest) def _make_match_kwargs(self, schema, identifier): quoting = self.config.quoting diff --git a/dbt/contracts/graph/manifest.py b/dbt/contracts/graph/manifest.py index 9f99e0b7a8c..c72a54b77e2 100644 --- a/dbt/contracts/graph/manifest.py +++ b/dbt/contracts/graph/manifest.py @@ -370,3 +370,9 @@ def __getattr__(self, name): raise AttributeError("'{}' object has no attribute '{}'".format( type(self).__name__, name) ) + + def get_used_schemas(self): + return frozenset({ + node.schema + for node in self.nodes.values() + }) diff --git a/dbt/include/global_project/macros/adapters/common.sql b/dbt/include/global_project/macros/adapters/common.sql index 3ef6dc23341..ab56e242206 100644 --- a/dbt/include/global_project/macros/adapters/common.sql +++ b/dbt/include/global_project/macros/adapters/common.sql @@ -33,6 +33,10 @@ {% endmacro %} {% macro create_table_as(temporary, relation, sql) -%} + {%- if not temporary -%} + {{ adapter.cache_new_relation(relation) }} + {%- endif -%} + {{ adapter_macro('create_table_as', temporary, relation, sql) }} {%- endmacro %} @@ -46,6 +50,8 @@ {% macro create_view_as(relation, sql) -%} + {{ adapter.cache_new_relation(relation) }} + {{ adapter_macro('create_view_as', relation, sql) }} {%- endmacro %} @@ -57,6 +63,8 @@ {% macro create_archive_table(relation, columns) -%} + {{ adapter.cache_new_relation(relation) }} + {{ adapter_macro('create_archive_table', relation, columns) }} {%- endmacro %} @@ -80,3 +88,19 @@ {{ exceptions.raise_compiler_error(msg) }} {% endmacro %} + +{% macro get_relations() -%} + {{ return(adapter_macro('get_relations')) }} +{% endmacro %} + + +{% macro default__get_relations() -%} + {# TODO: should this just return an empty agate table? 
#} + + {% set typename = adapter.type() %} + {% set msg -%} + get_relations not implemented for {{ typename }} + {%- endset %} + + {{ exceptions.raise_compiler_error(msg) }} +{% endmacro %} diff --git a/dbt/include/global_project/macros/materializations/seed/seed.sql b/dbt/include/global_project/macros/materializations/seed/seed.sql index 08c06d7cde1..0f1d15fcf91 100644 --- a/dbt/include/global_project/macros/materializations/seed/seed.sql +++ b/dbt/include/global_project/macros/materializations/seed/seed.sql @@ -14,6 +14,7 @@ {% macro default__create_csv_table(model) %} {%- set agate_table = model['agate_table'] -%} {%- set column_override = model['config'].get('column_types', {}) -%} + {{ adapter.cache_new_relation(this) }} {% set sql %} create table {{ this.render(False) }} ( diff --git a/dbt/include/global_project/macros/operations/relations/get_relations.sql b/dbt/include/global_project/macros/operations/relations/get_relations.sql new file mode 100644 index 00000000000..ad943c2a84b --- /dev/null +++ b/dbt/include/global_project/macros/operations/relations/get_relations.sql @@ -0,0 +1,4 @@ +{% operation get_relations_data %} + {% set relations = dbt.get_relations() %} + {{ return(relations) }} +{% endoperation %} diff --git a/dbt/include/global_project/macros/relations/postgres_relations.sql b/dbt/include/global_project/macros/relations/postgres_relations.sql new file mode 100644 index 00000000000..e6ef11f93b0 --- /dev/null +++ b/dbt/include/global_project/macros/relations/postgres_relations.sql @@ -0,0 +1,64 @@ +{% macro postgres__get_relations () -%} + {%- call statement('relations', fetch_result=True) -%} + -- {# + -- in pg_depend, objid is the dependent, refobjid is the referenced object + -- "a pg_depend entry indicates that the referenced object cannot be dropped without also dropping the dependent object." 
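      -- Reading the joins below: relation.class (pg_rewrite.ev_class) is the
      -- view that owns the rewrite rule, while dependency.ref (pg_depend.refobjid)
      -- is the relation that view selects from, so the referenced_*/dependent_*
      -- aliases appear to read the opposite way round from the pg_depend wording
      -- above. _link_cached_relations() in the Postgres adapter passes the
      -- columns to cache.add_link() in swapped order, which restores the
      -- intended drop-cascade direction.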
+ -- #} + with relation as ( + select + pg_rewrite.ev_class as class, + pg_rewrite.oid as id + from pg_rewrite + ), + class as ( + select + oid as id, + relname as name, + relnamespace as schema, + relkind as kind + from pg_class + ), + dependency as ( + select + pg_depend.objid as id, + pg_depend.refobjid as ref + from pg_depend + ), + schema as ( + select + pg_namespace.oid as id, + pg_namespace.nspname as name + from pg_namespace + where nspname != 'information_schema' and nspname not like 'pg_%' + ), + relationships as ( + select + referenced_class.name as referenced_name, + referenced_class.schema as referenced_schema_id, + dependent_class.name as dependent_name, + dependent_class.schema as dependent_schema_id, + referenced_class.kind as kind + from relation + join class as referenced_class on relation.class=referenced_class.id + join dependency on relation.id=dependency.id + join class as dependent_class on dependency.ref=dependent_class.id + where + referenced_class.kind in ('r', 'v') and + (referenced_class.name != dependent_class.name or + referenced_class.schema != dependent_class.schema) + ) + + select + referenced_schema.name as referenced_schema, + relationships.referenced_name as referenced_name, + dependent_schema.name as dependent_schema, + relationships.dependent_name as dependent_name + from relationships + join schema as dependent_schema on relationships.dependent_schema_id=dependent_schema.id + join schema as referenced_schema on relationships.referenced_schema_id=referenced_schema.id + group by referenced_schema, referenced_name, dependent_schema, dependent_name + order by referenced_schema, referenced_name, dependent_schema, dependent_name; + {%- endcall -%} + + {{ return(load_result('relations').table) }} +{% endmacro %} diff --git a/dbt/include/global_project/macros/relations/redshift_relations.sql b/dbt/include/global_project/macros/relations/redshift_relations.sql new file mode 100644 index 00000000000..9d8f800755f --- /dev/null +++ b/dbt/include/global_project/macros/relations/redshift_relations.sql @@ -0,0 +1,4 @@ +{% macro redshift__get_relations () -%} + {# TODO: is this allowed? 
#} + {{ return(dbt.postgres__get_relations()) }} +{% endmacro %} diff --git a/dbt/node_runners.py b/dbt/node_runners.py index 20a25a835e1..78bfba7211f 100644 --- a/dbt/node_runners.py +++ b/dbt/node_runners.py @@ -345,8 +345,13 @@ def create_schemas(cls, config, adapter, manifest): for schema in (required_schemas - existing_schemas): adapter.create_schema(schema) + @classmethod + def populate_adapter_cache(cls, config, adapter, manifest): + adapter.set_relations_cache(manifest) + @classmethod def before_run(cls, config, adapter, manifest): + cls.populate_adapter_cache(config, adapter, manifest) cls.safe_run_hooks(config, adapter, manifest, RunHookType.Start) cls.create_schemas(config, adapter, manifest) diff --git a/test/unit/test_cache.py b/test/unit/test_cache.py new file mode 100644 index 00000000000..33bf7fc1e67 --- /dev/null +++ b/test/unit/test_cache.py @@ -0,0 +1,217 @@ +from unittest import TestCase +from dbt.adapters.cache import RelationsCache +from dbt.adapters.default.relation import DefaultRelation + + +def make_mock_relationship(schema, identifier): + return DefaultRelation.create( + database='test_db', schema=schema, identifier=identifier, + table_name=identifier, type='view' + ) + + +class TestCache(TestCase): + def setUp(self): + self.cache = RelationsCache() + + def test_empty(self): + self.assertEqual(len(self.cache.relations), 0) + relations = self.cache.get_relations('test') + self.assertEqual(len(relations), 0) + + # make sure drop() is ok + self.cache.drop('foo', 'bar') + + def test_retrieval(self): + obj = object() + self.cache.add('foo', 'bar', kind='view', inner=obj) + self.assertEqual(len(self.cache.relations), 1) + + relations = self.cache.get_relations('foo') + self.assertEqual(len(relations), 1) + self.assertIs(relations[0], obj) + + relations = self.cache.get_relations('FOO') + self.assertEqual(len(relations), 1) + self.assertIs(relations[0], obj) + + def test_additions(self): + obj = object() + self.cache.add('foo', 'bar', kind='view') + + relations = self.cache.get_relations('foo') + self.assertEqual(len(relations), 1) + self.assertIs(relations[0], None) + + self.cache.add('foo', 'bar', inner=obj) + self.assertEqual(len(self.cache.relations), 1) + self.assertEqual(self.cache.schemas, {'foo'}) + + relations = self.cache.get_relations('foo') + self.assertEqual(len(relations), 1) + self.assertIs(relations[0], obj) + + self.cache.add('FOO', 'baz', inner=object()) + self.assertEqual(len(self.cache.relations), 2) + + relations = self.cache.get_relations('foo') + self.assertEqual(len(relations), 2) + + self.assertEqual(self.cache.schemas, {'foo', 'FOO'}) + self.cache._get_relation('foo', 'bar') + self.cache._get_relation('FOO', 'baz') + + def test_rename(self): + obj = make_mock_relationship('foo', 'bar') + self.cache.add('foo', 'bar', kind='view', inner=obj) + self.cache._get_relation('foo', 'bar') + self.cache.rename_relation('foo', 'bar', 'foo', 'baz') + + relations = self.cache.get_relations('foo') + self.assertEqual(len(relations), 1) + self.assertEqual(relations[0].schema, 'foo') + self.assertEqual(relations[0].identifier, 'baz') + + relation = self.cache._get_relation('foo', 'baz') + self.assertEqual(relation.inner.schema, 'foo') + self.assertEqual(relation.inner.identifier, 'baz') + self.assertEqual(relation.kind, 'view') + self.assertEqual(relation.schema, 'foo') + self.assertEqual(relation.identifier, 'baz') + + with self.assertRaises(KeyError): + self.cache._get_relation('foo', 'bar') + + +class TestLikeDbt(TestCase): + def setUp(self): + self.cache = 
RelationsCache() + + self.stored_relations = {} + # add a bunch of cache entries + for ident in 'abcdef': + obj = self.stored_relations.setdefault( + ident, + make_mock_relationship('schema', ident) + ) + self.cache.add('schema', ident, kind='view', inner=obj) + # 'b' references 'a' + self.cache.add_link('schema', 'a', 'schema', 'b') + # and 'c' references 'b' + self.cache.add_link('schema', 'b', 'schema', 'c') + # and 'd' references 'b' + self.cache.add_link('schema', 'b', 'schema', 'd') + # and 'e' references 'a' + self.cache.add_link('schema', 'a', 'schema', 'e') + # and 'f' references 'd' + self.cache.add_link('schema', 'd', 'schema', 'f') + # so drop propagation goes (a -> (b -> (c (d -> f))) e) + + def assert_has_relations(self, expected): + current = set(r.identifier for r in self.cache.get_relations('schema')) + self.assertEqual(current, expected) + + def test_drop_inner(self): + self.assert_has_relations(set('abcdef')) + self.cache.drop('schema', 'b') + self.assert_has_relations({'a', 'e'}) + + def test_rename_and_drop(self): + self.assert_has_relations(set('abcdef')) + # drop the backup/tmp + self.cache.drop('schema', 'b__backup') + self.cache.drop('schema', 'b__tmp') + self.assert_has_relations(set('abcdef')) + # create a new b__tmp + self.cache.add('schema', 'b__tmp', kind='view', + inner=make_mock_relationship('schema', 'b__tmp') + ) + self.assert_has_relations(set('abcdef') | {'b__tmp'}) + # rename b -> b__backup + self.cache.rename_relation('schema', 'b', 'schema', 'b__backup') + self.assert_has_relations(set('acdef') | {'b__tmp', 'b__backup'}) + # rename temp to b + self.cache.rename_relation('schema', 'b__tmp', 'schema', 'b') + self.assert_has_relations(set('abcdef') | {'b__backup'}) + + +class TestComplexCache(TestCase): + def setUp(self): + self.cache = RelationsCache() + inputs = [ + ('foo', 'table1', 'table'), + ('bar', 'table2', 'view'), + ('foo', 'table3', 'view'), + ('foo', 'table4', 'view'), + ('bar', 'table3', 'view'), + ] + self.inputs = [ + (s, i, k, make_mock_relationship(s, i)) + for s, i, k in inputs + ] + for schema, ident, kind, inner in self.inputs: + self.cache.add(schema, ident, kind, inner) + + # foo.table3 references foo.table1 + # (create view table3 as (select * from table1...)) + self.cache.add_link( + 'foo', 'table1', + 'foo', 'table3' + ) + # bar.table3 references foo.table3 + # (create view bar.table5 as (select * from foo.table3...)) + self.cache.add_link( + 'foo', 'table3', + 'bar', 'table3' + ) + + # foo.table2 also references foo.table1 + self.cache.add_link( + 'foo', 'table1', + 'foo', 'table4', + ) + + def test_get_relations(self): + self.assertEqual(len(self.cache.get_relations('foo')), 3) + self.assertEqual(len(self.cache.get_relations('bar')), 2) + self.assertEqual(len(self.cache.relations), 5) + + def test_drop_one(self): + # dropping bar.table2 should only drop itself + self.cache.drop('bar', 'table2') + self.assertEqual(len(self.cache.get_relations('foo')), 3) + self.assertEqual(len(self.cache.get_relations('bar')), 1) + self.assertEqual(len(self.cache.relations), 4) + + def test_drop_many(self): + # dropping foo.table1 should drop everything but bar.table2. 
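        # the links above give foo.table1 -> {foo.table3, foo.table4} and
        # foo.table3 -> bar.table3, so the cascade removes four of the five
        # relations and only bar.table2 survives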
+ self.cache.drop('foo', 'table1') + self.assertEqual(len(self.cache.get_relations('foo')), 0) + self.assertEqual(len(self.cache.get_relations('bar')), 1) + self.assertEqual(len(self.cache.relations), 1) + + def test_rename_root(self): + self.cache.rename_relation('foo', 'table1', 'bar', 'table1') + retrieved = self.cache._get_relation('bar','table1').inner + self.assertEqual(retrieved.schema, 'bar') + self.assertEqual(retrieved.identifier, 'table1') + self.assertEqual(len(self.cache.get_relations('foo')), 2) + self.assertEqual(len(self.cache.get_relations('bar')), 3) + + # make sure drops still cascade from the renamed table + self.cache.drop('bar', 'table1') + self.assertEqual(len(self.cache.get_relations('foo')), 0) + self.assertEqual(len(self.cache.get_relations('bar')), 1) + self.assertEqual(len(self.cache.relations), 1) + + def test_rename_branch(self): + self.cache.rename_relation('foo', 'table3', 'foo', 'table2') + self.assertEqual(len(self.cache.get_relations('foo')), 3) + self.assertEqual(len(self.cache.get_relations('bar')), 2) + + # make sure drops still cascade through the renamed table + self.cache.drop('foo', 'table1') + self.assertEqual(len(self.cache.get_relations('foo')), 0) + self.assertEqual(len(self.cache.get_relations('bar')), 1) + self.assertEqual(len(self.cache.relations), 1) + diff --git a/test/unit/test_postgres_adapter.py b/test/unit/test_postgres_adapter.py index 7dea6a40a3c..8c8e789bbc6 100644 --- a/test/unit/test_postgres_adapter.py +++ b/test/unit/test_postgres_adapter.py @@ -116,15 +116,8 @@ def test_get_catalog_various_schemas(self, mock_run): mock_run.return_value = agate.Table(rows=rows, column_names=column_names) - # we should accept the lowercase matching 'foo's only. - mock_nodes = [ - mock.MagicMock(spec_set=['schema'], schema='foo') - for k in range(2) - ] - mock_nodes.append(mock.MagicMock(spec_set=['schema'], schema='quux')) - nodes = {str(idx): n for idx, n in enumerate(mock_nodes)} - # give manifest the dict it wants - mock_manifest = mock.MagicMock(spec_set=['nodes'], nodes=nodes) + mock_manifest = mock.MagicMock() + mock_manifest.get_used_schemas.return_value = {'foo', 'quux'} catalog = self.adapter.get_catalog(mock_manifest) self.assertEqual( From 32765ed7064a34d9fb766047c3349b9339b77b14 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 24 Sep 2018 14:16:51 -0600 Subject: [PATCH 046/133] add cache verification flag --- dbt/adapters/cache.py | 5 ++--- dbt/adapters/default/impl.py | 27 +++++++++++++++++++++++++++ dbt/main.py | 10 ++++++++++ 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/dbt/adapters/cache.py b/dbt/adapters/cache.py index 32b1d8b9d7f..71763ff654f 100644 --- a/dbt/adapters/cache.py +++ b/dbt/adapters/cache.py @@ -78,7 +78,7 @@ def rename(self, new_relation): 'schema': new_relation.schema, 'identifier': new_relation.identifier }, - table_name = new_relation.identifier + table_name=new_relation.identifier ) def rename_key(self, old_key, new_key): @@ -146,7 +146,7 @@ def add_link(self, referenced_schema, referenced_name, dependent_schema, identifier=dependent_name ) logger.debug('adding link, {!s} references {!s}' - .format(dependent, referenced) + .format(dependent, referenced) ) with self.lock: self._add_link(referenced, dependent) @@ -223,7 +223,6 @@ def _rename_relation(self, old_relation, new_relation): self.relations[new_key] = relation - def rename_relation(self, old_schema, old_identifier, new_schema, new_identifier): old_relation = CachedRelation( diff --git a/dbt/adapters/default/impl.py 
b/dbt/adapters/default/impl.py index 5a4ddde6edc..21f6a7d6b3d 100644 --- a/dbt/adapters/default/impl.py +++ b/dbt/adapters/default/impl.py @@ -359,6 +359,31 @@ def expand_target_column_types(self, ### # RELATIONS ### + def _verify_relation_cache(self, schema, model_name=None): + cached = self.cache.get_relations(schema) + retrieved = self._list_relations(schema, model_name=model_name) + retrieved_map = { + (r.schema, r.identifier): r for r in retrieved + } + extra = set() + + for relation in cached: + key = (relation.schema, relation.identifier) + try: + retrieved_map.pop(key) + except KeyError: + extra.add(key) + + missing = set(retrieved_map) + if extra or missing: + msg = ( + 'cache failure! cache has:\nextra entries:\n\t{}\n' + 'missing entries:\n\t{}' + .format('\n\t'.join(extra), '\n\t'.join(missing)) + ) + logger.error(msg) + raise RuntimeError(msg) + def _list_relations(self, schema, model_name=None): raise dbt.exceptions.NotImplementedException( '`list_relations` is not implemented for this adapter!') @@ -367,6 +392,8 @@ def list_relations(self, schema, model_name=None): if schema in self.cache.schemas: logger.debug('In list_relations, model_name={}, cache hit' .format(model_name)) + if dbt.flags.VERIFY_RELATION_CACHE: + self._verify_relation_cache(schema, model_name) relations = self.cache.get_relations(schema) else: # this indicates that we missed a schema when populating. Warn diff --git a/dbt/main.py b/dbt/main.py index 1f2ba98da53..518191cea6f 100644 --- a/dbt/main.py +++ b/dbt/main.py @@ -251,6 +251,8 @@ def invoke_dbt(parsed): return None flags.NON_DESTRUCTIVE = getattr(parsed, 'non_destructive', False) + flags.VERIFY_RELATION_CACHE = getattr(parsed, 'verify_relation_cache', + False) arg_drop_existing = getattr(parsed, 'drop_existing', False) arg_full_refresh = getattr(parsed, 'full_refresh', False) @@ -331,6 +333,14 @@ def parse_args(args): should be a YAML string, eg. '{my_variable: my_value}'""" ) + base_subparser.add_argument( + # if enabled, everything that would ordinarily hit the cache will + # instead perform the query and verify the result against the cache. 
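        # (hidden from --help via argparse.SUPPRESS below; presumably invoked
        # as e.g. dbt run --verify-relation-cache when debugging cache drift)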
+ '--verify-relation-cache', + action='store_true', + help=argparse.SUPPRESS, + ) + sub = subs.add_parser('init', parents=[base_subparser]) sub.add_argument('project_name', type=str, help='Name of the new project') sub.set_defaults(cls=init_task.InitTask, which='init') From cf77a9a74401ee48ba7d97f8591e3d628edbb4f5 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 24 Sep 2018 17:34:30 -0600 Subject: [PATCH 047/133] tons of logging --- dbt/adapters/cache.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/dbt/adapters/cache.py b/dbt/adapters/cache.py index 71763ff654f..4fa79398414 100644 --- a/dbt/adapters/cache.py +++ b/dbt/adapters/cache.py @@ -2,6 +2,7 @@ import threading from dbt.logger import GLOBAL_LOGGER as logger from copy import deepcopy +import pprint ReferenceKey = namedtuple('ReferenceKey', 'schema identifier') @@ -102,6 +103,16 @@ def __init__(self): # the set of cached schemas self.schemas = set() + def dump_graph(self): + return { + '{}.{}'.format(k.schema, k.identifier): + [ + '{}.{}'.format(x.schema, x.identifier) + for x in v.referenced_by + ] + for k, v in self.relations.items() + } + def _setdefault(self, relation): self.schemas.add(relation.schema) key = relation.key() @@ -148,8 +159,14 @@ def add_link(self, referenced_schema, referenced_name, dependent_schema, logger.debug('adding link, {!s} references {!s}' .format(dependent, referenced) ) + logger.debug('before adding link: {}'.format( + pprint.pformat(self.dump_graph())) + ) with self.lock: self._add_link(referenced, dependent) + logger.debug('after adding link: {}'.format( + pprint.pformat(self.dump_graph())) + ) def add(self, schema, identifier, kind=None, inner=None): relation = CachedRelation( @@ -159,8 +176,14 @@ def add(self, schema, identifier, kind=None, inner=None): inner=inner ) logger.debug('Adding relation: {!s}'.format(relation)) + logger.debug('before adding: {}'.format( + pprint.pformat(self.dump_graph())) + ) with self.lock: self._setdefault(relation) + logger.debug('after adding: {}'.format( + pprint.pformat(self.dump_graph())) + ) def _remove_refs(self, keys): # remove direct refs @@ -185,8 +208,14 @@ def _drop_cascade_relation(self, dropped): def drop(self, schema, identifier): dropped = CachedRelation(schema=schema, identifier=identifier) logger.debug('Dropping relation: {!s}'.format(dropped)) + logger.debug('before drop: {}'.format( + pprint.pformat(self.dump_graph())) + ) with self.lock: self._drop_cascade_relation(dropped) + logger.debug('after drop: {}'.format( + pprint.pformat(self.dump_graph())) + ) def _rename_relation(self, old_relation, new_relation): old_key = old_relation.key() @@ -195,6 +224,10 @@ def _rename_relation(self, old_relation, new_relation): # relation earlier in its run and we can ignore it, as we don't care # about the rename either if old_key not in self.relations: + logger.debug( + 'old key {} not found in self.relations, assuming temporary' + .format(old_key) + ) return # not good if new_key in self.relations: @@ -236,8 +269,14 @@ def rename_relation(self, old_schema, old_identifier, new_schema, logger.debug('Renaming relation {!s} to {!s}'.format( old_relation, new_relation) ) + logger.debug('before rename: {}'.format( + pprint.pformat(self.dump_graph())) + ) with self.lock: self._rename_relation(old_relation, new_relation) + logger.debug('after rename: {}'.format( + pprint.pformat(self.dump_graph())) + ) def _get_relation(self, schema, identifier): """Get the relation by name. 
Raises a KeyError if it does not exist""" From 9f5040d8cc2983329c6bf9ea93c2afb76bde85ef Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Mon, 24 Sep 2018 17:35:17 -0600 Subject: [PATCH 048/133] make jinja templates kinda debuggable by injecting ourselves into the linecache --- dbt/clients/jinja.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/dbt/clients/jinja.py b/dbt/clients/jinja.py index fa2fea4ecd7..8a13017c9a7 100644 --- a/dbt/clients/jinja.py +++ b/dbt/clients/jinja.py @@ -37,13 +37,26 @@ def _parse(self, source, name, filename): jinja2._compat.encode_filename(filename) ).parse() + def _compile(self, source, filename): + import linecache, codecs, os + if filename == '