From 82bd794214fd48b88652862fe54b6f38f454b195 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Fri, 25 Oct 2019 08:37:43 -0600 Subject: [PATCH] Fix tests to ignore rollbacks for bq, fix redshift tests Trim down the "target" context value to use the opt-in connection_info Make sure it contains a superset of the documented stuff Make sure it does not contain any blacklisted items Change some asserts to raise InternalExceptions because assert error messages are bad --- core/dbt/adapters/base/connections.py | 14 +++++++-- core/dbt/adapters/base/relation.py | 13 ++++++-- core/dbt/adapters/sql/connections.py | 10 +++++-- core/dbt/clients/jinja.py | 4 +++ core/dbt/compilation.py | 6 ++-- core/dbt/context/base.py | 20 ++++++++----- core/dbt/contracts/connection.py | 19 ++++++++---- .../dbt/adapters/bigquery/connections.py | 3 +- .../bigquery/dbt/adapters/bigquery/impl.py | 12 ++++++-- .../dbt/adapters/postgres/connections.py | 3 +- .../dbt/adapters/redshift/connections.py | 5 ++-- .../dbt/adapters/snowflake/connections.py | 3 +- .../051_query_comments_test/models/x.sql | 30 +++++++++++++++++++ .../test_query_comments.py | 7 +++-- 14 files changed, 118 insertions(+), 31 deletions(-) diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index e4ca78ac52e..fd0cddc6ca7 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -93,6 +93,10 @@ def set_connection_name(self, name: Optional[str] = None) -> Connection: # named 'master' conn_name = 'master' else: + if not isinstance(name, str): + raise dbt.exceptions.CompilerException( + f'For connection name, got {name} - not a string!' + ) assert isinstance(name, str) conn_name = name @@ -223,7 +227,10 @@ def _close_handle(cls, connection: Connection) -> None: def _rollback(cls, connection: Connection) -> None: """Roll back the given connection.""" if dbt.flags.STRICT_MODE: - assert isinstance(connection, Connection) + if not isinstance(connection, Connection): + raise dbt.exceptions.CompilerException( + f'In _rollback, got {connection} - not a Connection!' + ) if connection.transaction_open is False: raise dbt.exceptions.InternalException( @@ -238,7 +245,10 @@ def _rollback(cls, connection: Connection) -> None: @classmethod def close(cls, connection: Connection) -> Connection: if dbt.flags.STRICT_MODE: - assert isinstance(connection, Connection) + if not isinstance(connection, Connection): + raise dbt.exceptions.CompilerException( + f'In close, got {connection} - not a Connection!' + ) # if the connection is in closed or init, there's nothing to do if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}: diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index d44f7b28cfe..19fc78f9951 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -16,6 +16,7 @@ from dbt.contracts.util import Replaceable from dbt.contracts.graph.compiled import CompiledNode from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode +from dbt.exceptions import InternalException from dbt import deprecations @@ -330,10 +331,18 @@ def create_from( **kwargs: Any, ) -> Self: if node.resource_type == NodeType.Source: - assert isinstance(node, ParsedSourceDefinition) + if not isinstance(node, ParsedSourceDefinition): + raise InternalException( + 'type mismatch, expected ParsedSourceDefinition but got {}' + .format(type(node)) + ) return cls.create_from_source(node, **kwargs) else: - assert isinstance(node, (ParsedNode, CompiledNode)) + if not isinstance(node, (ParsedNode, CompiledNode)): + raise InternalException( + 'type mismatch, expected ParsedNode or CompiledNode but ' + 'got {}'.format(type(node)) + ) return cls.create_from_node(config, node, **kwargs) @classmethod diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py index c4230ebd5a0..d6f192302fe 100644 --- a/core/dbt/adapters/sql/connections.py +++ b/core/dbt/adapters/sql/connections.py @@ -131,7 +131,10 @@ def begin(self): connection = self.get_thread_connection() if dbt.flags.STRICT_MODE: - assert isinstance(connection, Connection) + if not isinstance(connection, Connection): + raise dbt.exceptions.CompilerException( + f'In begin, got {connection} - not a Connection!' + ) if connection.transaction_open is True: raise dbt.exceptions.InternalException( @@ -146,7 +149,10 @@ def begin(self): def commit(self): connection = self.get_thread_connection() if dbt.flags.STRICT_MODE: - assert isinstance(connection, Connection) + if not isinstance(connection, Connection): + raise dbt.exceptions.CompilerException( + f'In commit, got {connection} - not a Connection!' + ) if connection.transaction_open is False: raise dbt.exceptions.InternalException( diff --git a/core/dbt/clients/jinja.py b/core/dbt/clients/jinja.py index 39e35f10a22..b684ffb8319 100644 --- a/core/dbt/clients/jinja.py +++ b/core/dbt/clients/jinja.py @@ -142,6 +142,10 @@ def exception_handler(self) -> Iterator[None]: dbt.exceptions.raise_compiler_error(str(e)) def call_macro(self, *args, **kwargs): + if self.context is None: + raise dbt.exceptions.InternalException( + 'Context is still None in call_macro!' + ) assert self.context is not None macro = self.get_macro() diff --git a/core/dbt/compilation.py b/core/dbt/compilation.py index b6bbc5442b6..c3af181e93e 100644 --- a/core/dbt/compilation.py +++ b/core/dbt/compilation.py @@ -76,8 +76,10 @@ def recursively_prepend_ctes(model, manifest): return (model, model.extra_ctes, manifest) if dbt.flags.STRICT_MODE: - assert isinstance(model, tuple(COMPILED_TYPES.values())), \ - 'Bad model type: {}'.format(type(model)) + if not isinstance(model, tuple(COMPILED_TYPES.values())): + raise dbt.exceptions.InternalException( + 'Bad model type: {}'.format(type(model)) + ) prepended_ctes = [] diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py index 2c48183e577..f0b3c18a428 100644 --- a/core/dbt/context/base.py +++ b/core/dbt/context/base.py @@ -193,14 +193,18 @@ def __init__(self, config): self.config = config def get_target(self) -> Dict[str, Any]: - target_name = self.config.target_name - target = self.config.to_profile_info() - del target['credentials'] - target.update(self.config.credentials.to_dict(with_aliases=True)) - target['type'] = self.config.credentials.type - target.pop('pass', None) - target.pop('password', None) - target['name'] = target_name + target = dict( + self.config.credentials.connection_info(with_aliases=True) + ) + target.update({ + 'type': self.config.credentials.type, + 'threads': self.config.threads, + 'name': self.config.target_name, + # not specified, but present for compatibility + 'target_name': self.config.target_name, + 'profile_name': self.config.profile_name, + 'config': self.config.config.to_dict(), + }) return target @property diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index d8a83f4f9ae..7ca55af664a 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -1,7 +1,8 @@ import abc +import itertools from dataclasses import dataclass, field from typing import ( - Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType + Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType, List ) from typing_extensions import Protocol @@ -88,11 +89,19 @@ def type(self) -> str: 'type not implemented for base credentials class' ) - def connection_info(self) -> Iterable[Tuple[str, Any]]: + def connection_info( + self, *, with_aliases: bool = False + ) -> Iterable[Tuple[str, Any]]: """Return an ordered iterator of key/value pairs for pretty-printing. """ - as_dict = self.to_dict() - for key in self._connection_keys(): + as_dict = self.to_dict(omit_none=False, with_aliases=with_aliases) + connection_keys = set(self._connection_keys()) + aliases: List[str] = [] + if with_aliases: + aliases = [ + k for k, v in self._ALIASES.items() if v in connection_keys + ] + for key in itertools.chain(self._connection_keys(), aliases): if key in as_dict: yield key, as_dict[key] @@ -109,7 +118,7 @@ def from_dict(cls, data): def translate_aliases(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: return translate_aliases(kwargs, cls._ALIASES) - def to_dict(self, omit_none=True, validate=False, with_aliases=False): + def to_dict(self, omit_none=True, validate=False, *, with_aliases=False): serialized = super().to_dict(omit_none=omit_none, validate=validate) if with_aliases: serialized.update({ diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index a888be70e33..51f37b5bfb2 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -45,7 +45,8 @@ def type(self): return 'bigquery' def _connection_keys(self): - return ('method', 'database', 'schema', 'location') + return ('method', 'database', 'schema', 'location', 'priority', + 'timeout_seconds') class BigQueryConnectionManager(BaseConnectionManager): diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py index d25cc984e77..58464882b24 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/impl.py +++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py @@ -331,8 +331,16 @@ def execute_model(self, model, materialization, sql_override=None, if flags.STRICT_MODE: connection = self.connections.get_thread_connection() - assert isinstance(connection, Connection) - assert(connection.name == model.get('name')) + if not isinstance(connection, Connection): + raise dbt.exceptions.CompilerException( + f'Got {connection} - not a Connection!' + ) + model_uid = model.get('unique_id') + if connection.name != model_uid: + raise dbt.exceptions.InternalException( + f'Connection had name "{connection.name}", expected model ' + f'unique id of "{model_uid}"' + ) if materialization == 'view': res = self._materialize_as_view(model) diff --git a/plugins/postgres/dbt/adapters/postgres/connections.py b/plugins/postgres/dbt/adapters/postgres/connections.py index decfff0e59b..260ffb44fc4 100644 --- a/plugins/postgres/dbt/adapters/postgres/connections.py +++ b/plugins/postgres/dbt/adapters/postgres/connections.py @@ -31,7 +31,8 @@ def type(self): return 'postgres' def _connection_keys(self): - return ('host', 'port', 'user', 'database', 'schema', 'search_path') + return ('host', 'port', 'user', 'database', 'schema', 'search_path', + 'keepalives_idle') class PostgresConnectionManager(SQLConnectionManager): diff --git a/plugins/redshift/dbt/adapters/redshift/connections.py b/plugins/redshift/dbt/adapters/redshift/connections.py index d35ba7931ea..f69ca8aa29f 100644 --- a/plugins/redshift/dbt/adapters/redshift/connections.py +++ b/plugins/redshift/dbt/adapters/redshift/connections.py @@ -52,9 +52,8 @@ def type(self): return 'redshift' def _connection_keys(self): - return ( - 'host', 'port', 'user', 'database', 'schema', 'method', - 'search_path') + keys = super()._connection_keys() + return keys + ('method', 'cluster_id', 'iam_duration_seconds') class RedshiftConnectionManager(PostgresConnectionManager): diff --git a/plugins/snowflake/dbt/adapters/snowflake/connections.py b/plugins/snowflake/dbt/adapters/snowflake/connections.py index 9aa06cca0d0..9862f393afc 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/connections.py +++ b/plugins/snowflake/dbt/adapters/snowflake/connections.py @@ -35,7 +35,8 @@ def type(self): return 'snowflake' def _connection_keys(self): - return ('account', 'user', 'database', 'schema', 'warehouse', 'role') + return ('account', 'user', 'database', 'schema', 'warehouse', 'role', + 'client_session_keep_alive') def auth_args(self): # Pull all of the optional authentication args for the connector, diff --git a/test/integration/051_query_comments_test/models/x.sql b/test/integration/051_query_comments_test/models/x.sql index 589430109d9..d103fc6f92c 100644 --- a/test/integration/051_query_comments_test/models/x.sql +++ b/test/integration/051_query_comments_test/models/x.sql @@ -1,3 +1,33 @@ +{% set blacklist = ['pass', 'password', 'keyfile', 'keyfile.json', 'password', 'private_key_passphrase'] %} +{% for key in blacklist %} + {% if key in blacklist and blacklist[key] %} + {% do exceptions.raise_compiler_error('invalid target, found banned key "' ~ key ~ '"') %} + {% endif %} +{% endfor %} + +{% if 'type' not in target %} + {% do exceptions.raise_compiler_error('invalid target, missing "type"') %} +{% endif %} + +{% set required = ['name', 'schema', 'type', 'threads'] %} + +{# Require what we docuement at https://docs.getdbt.com/docs/target #} +{% if target.type == 'postgres' or target.type == 'redshift' %} + {% do required.extend(['dbname', 'host', 'user', 'port']) %} +{% elif target.type == 'snowflake' %} + {% do required.extend(['database', 'warehouse', 'user', 'role', 'account']) %} +{% elif target.type == 'bigquery' %} + {% do required.extend(['project']) %} +{% else %} + {% do exceptions.raise_compiler_error('invalid target, got unknown type "' ~ target.type ~ '"') %} + +{% endif %} + +{% for value in required %} + {% if value not in target %} + {% do exceptions.raise_compiler_error('invalid target, missing "' ~ value ~ '"') %} + {% endif %} +{% endfor %} {% do run_query('select 2 as inner_id') %} select 1 as outer_id diff --git a/test/integration/051_query_comments_test/test_query_comments.py b/test/integration/051_query_comments_test/test_query_comments.py index ea277f3522b..a66a8c84341 100644 --- a/test/integration/051_query_comments_test/test_query_comments.py +++ b/test/integration/051_query_comments_test/test_query_comments.py @@ -51,7 +51,7 @@ def query_comment(self, model_name, log): if log['message'].startswith(prefix): msg = log['message'][len(prefix):] - if msg in {'COMMIT', 'BEGIN'}: + if msg in {'COMMIT', 'BEGIN', 'ROLLBACK'}: return None return msg return None @@ -78,7 +78,10 @@ def profile_config(self): return {'config': {'query_comment': 'dbt\nrules!\n'}} def matches_comment(self, msg) -> bool: - self.assertTrue(msg.startswith('-- dbt\n-- rules!\n')) + self.assertTrue( + msg.startswith('-- dbt\n-- rules!\n'), + f'{msg} did not start with query comment' + ) @use_profile('postgres') def test_postgres_comments(self):