From da86314757dd6a9e6db260630c629fc33a93a262 Mon Sep 17 00:00:00 2001 From: Jacob Beck Date: Tue, 24 Sep 2019 09:40:55 -0600 Subject: [PATCH] Convert Relation types to hologram.JsonSchemaMixin Fix a lot of mypy things, add a number of adapter-ish modules to it Split relations and columns into separate files split context.common into base + common - base is all that's required for the config renderer Move Credentials into connection contracts since that's what they really are Removed model_name/table_name -> consolidated to identifier - I hope I did not break seeds, which claimed to care about render(False) Unify shared 'external' relation type with bigquery's own hack workarounds for some import cycles with plugin registration and config p arsing Assorted backwards compatibility fixes around types, deep_merge vs shallow merge Remove APIObject --- core/dbt/adapters/base/__init__.py | 8 +- core/dbt/adapters/base/column.py | 93 +++ core/dbt/adapters/base/connections.py | 159 ++--- core/dbt/adapters/base/impl.py | 16 +- core/dbt/adapters/base/plugin.py | 35 +- core/dbt/adapters/base/relation.py | 573 +++++++++--------- core/dbt/adapters/cache.py | 1 - core/dbt/adapters/factory.py | 51 +- core/dbt/adapters/sql/connections.py | 50 +- core/dbt/adapters/sql/impl.py | 24 +- core/dbt/api/__init__.py | 5 - core/dbt/api/object.py | 125 ---- core/dbt/config/__init__.py | 2 +- core/dbt/config/renderer.py | 2 +- core/dbt/config/runtime.py | 12 +- core/dbt/context/base.py | 138 +++++ core/dbt/context/common.py | 216 ++----- core/dbt/context/parser.py | 2 +- core/dbt/context/runtime.py | 3 +- core/dbt/contracts/connection.py | 64 +- core/dbt/contracts/util.py | 4 +- core/dbt/deprecations.py | 10 + .../global_project/macros/adapters/common.sql | 3 +- .../macros/materializations/seed/seed.sql | 4 +- core/dbt/parser/schemas.py | 2 +- core/dbt/tracking.py | 4 +- core/dbt/utils.py | 17 +- .../dbt/adapters/bigquery/__init__.py | 2 +- .../bigquery/dbt/adapters/bigquery/column.py | 121 ++++ .../dbt/adapters/bigquery/connections.py | 2 +- .../bigquery/dbt/adapters/bigquery/impl.py | 14 +- .../dbt/adapters/bigquery/relation.py | 209 +------ .../dbt/include/postgres/macros/adapters.sql | 1 - .../dbt/adapters/snowflake/relation.py | 54 +- .../dbt/include/snowflake/macros/adapters.sql | 4 +- .../test_concurrent_transaction.py | 3 + test/unit/test_bigquery_adapter.py | 24 +- test/unit/test_cache.py | 3 +- test/unit/utils.py | 1 - tox.ini | 5 +- 40 files changed, 1012 insertions(+), 1054 deletions(-) create mode 100644 core/dbt/adapters/base/column.py delete mode 100644 core/dbt/api/__init__.py delete mode 100644 core/dbt/api/object.py create mode 100644 core/dbt/context/base.py create mode 100644 plugins/bigquery/dbt/adapters/bigquery/column.py diff --git a/core/dbt/adapters/base/__init__.py b/core/dbt/adapters/base/__init__.py index 5edf237447b..39461477c69 100644 --- a/core/dbt/adapters/base/__init__.py +++ b/core/dbt/adapters/base/__init__.py @@ -1,8 +1,10 @@ # these are all just exports, #noqa them so flake8 will be happy + +# TODO: Should we still include this in the `adapters` namespace? +from dbt.contracts.connection import Credentials # noqa from dbt.adapters.base.meta import available # noqa -from dbt.adapters.base.relation import BaseRelation # noqa -from dbt.adapters.base.relation import Column # noqa from dbt.adapters.base.connections import BaseConnectionManager # noqa -from dbt.adapters.base.connections import Credentials # noqa +from dbt.adapters.base.relation import BaseRelation, RelationType # noqa +from dbt.adapters.base.column import Column # noqa from dbt.adapters.base.impl import BaseAdapter # noqa from dbt.adapters.base.plugin import AdapterPlugin # noqa diff --git a/core/dbt/adapters/base/column.py b/core/dbt/adapters/base/column.py new file mode 100644 index 00000000000..c6e6fcb3288 --- /dev/null +++ b/core/dbt/adapters/base/column.py @@ -0,0 +1,93 @@ +from dataclasses import dataclass + +from hologram import JsonSchemaMixin + +from typing import TypeVar, Dict, ClassVar, Any, Optional, Type + +Self = TypeVar('Self', bound='Column') + + +@dataclass +class Column(JsonSchemaMixin): + TYPE_LABELS: ClassVar[Dict[str, str]] = { + 'STRING': 'TEXT', + 'TIMESTAMP': 'TIMESTAMP', + 'FLOAT': 'FLOAT', + 'INTEGER': 'INT' + } + column: str + dtype: str + char_size: Optional[int] = None + numeric_precision: Optional[Any] = None + numeric_scale: Optional[Any] = None + + @classmethod + def translate_type(cls, dtype: str) -> str: + return cls.TYPE_LABELS.get(dtype.upper(), dtype) + + @classmethod + def create(cls: Type[Self], name, label_or_dtype: str) -> Self: + column_type = cls.translate_type(label_or_dtype) + return cls(name, column_type) + + @property + def name(self) -> str: + return self.column + + @property + def quoted(self) -> str: + return '"{}"'.format(self.column) + + @property + def data_type(self) -> str: + if self.is_string(): + return Column.string_type(self.string_size()) + elif self.is_numeric(): + return Column.numeric_type(self.dtype, self.numeric_precision, + self.numeric_scale) + else: + return self.dtype + + def is_string(self) -> bool: + return self.dtype.lower() in ['text', 'character varying', 'character', + 'varchar'] + + def is_numeric(self) -> bool: + return self.dtype.lower() in ['numeric', 'number'] + + def string_size(self) -> int: + if not self.is_string(): + raise RuntimeError("Called string_size() on non-string field!") + + if self.dtype == 'text' or self.char_size is None: + # char_size should never be None. Handle it reasonably just in case + return 256 + else: + return int(self.char_size) + + def can_expand_to(self: Self, other_column: Self) -> bool: + """returns True if this column can be expanded to the size of the + other column""" + if not self.is_string() or not other_column.is_string(): + return False + + return other_column.string_size() > self.string_size() + + def literal(self, value: Any) -> str: + return "{}::{}".format(value, self.data_type) + + @classmethod + def string_type(cls, size: int) -> str: + return "character varying({})".format(size) + + @classmethod + def numeric_type(cls, dtype: str, precision: Any, scale: Any) -> str: + # This could be decimal(...), numeric(...), number(...) + # Just use whatever was fed in here -- don't try to get too clever + if precision is None or scale is None: + return dtype + else: + return "{}({},{})".format(dtype, precision, scale) + + def __repr__(self) -> str: + return "".format(self.name, self.data_type) diff --git a/core/dbt/adapters/base/connections.py b/core/dbt/adapters/base/connections.py index ddb7a15522e..9f2cedf4616 100644 --- a/core/dbt/adapters/base/connections.py +++ b/core/dbt/adapters/base/connections.py @@ -2,66 +2,17 @@ import multiprocessing import os from threading import get_ident +from typing import ( + Dict, Tuple, Hashable, Optional, ContextManager, List +) + +import agate import dbt.exceptions import dbt.flags -from dbt.contracts.connection import Connection -from dbt.contracts.util import Replaceable +from dbt.config import Profile +from dbt.contracts.connection import Connection, Identifier, ConnectionState from dbt.logger import GLOBAL_LOGGER as logger -from dbt.utils import translate_aliases - -from hologram.helpers import ExtensibleJsonSchemaMixin - -from dataclasses import dataclass, field -from typing import Any, ClassVar, Dict, Tuple - - -@dataclass -class Credentials( - ExtensibleJsonSchemaMixin, - Replaceable, - metaclass=abc.ABCMeta -): - database: str - schema: str - _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False) - - @abc.abstractproperty - def type(self): - raise NotImplementedError( - 'type not implemented for base credentials class' - ) - - def connection_info(self): - """Return an ordered iterator of key/value pairs for pretty-printing. - """ - as_dict = self.to_dict() - for key in self._connection_keys(): - if key in as_dict: - yield key, as_dict[key] - - @abc.abstractmethod - def _connection_keys(self) -> Tuple[str, ...]: - raise NotImplementedError - - @classmethod - def from_dict(cls, data): - data = cls.translate_aliases(data) - return super().from_dict(data) - - @classmethod - def translate_aliases(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: - return translate_aliases(kwargs, cls._ALIASES) - - def to_dict(self, omit_none=True, validate=False, with_aliases=False): - serialized = super().to_dict(omit_none=omit_none, validate=validate) - if with_aliases: - serialized.update({ - new_name: serialized[canonical_name] - for new_name, canonical_name in self._ALIASES.items() - if canonical_name in serialized - }) - return serialized class BaseConnectionManager(metaclass=abc.ABCMeta): @@ -79,18 +30,18 @@ class BaseConnectionManager(metaclass=abc.ABCMeta): """ TYPE: str = NotImplemented - def __init__(self, profile): + def __init__(self, profile: Profile): self.profile = profile - self.thread_connections = {} + self.thread_connections: Dict[Hashable, Connection] = {} self.lock = multiprocessing.RLock() @staticmethod - def get_thread_identifier(): + def get_thread_identifier() -> Hashable: # note that get_ident() may be re-used, but we should never experience # that within a single process return (os.getpid(), get_ident()) - def get_thread_connection(self): + def get_thread_connection(self) -> Connection: key = self.get_thread_identifier() with self.lock: if key not in self.thread_connections: @@ -100,18 +51,18 @@ def get_thread_connection(self): ) return self.thread_connections[key] - def get_if_exists(self): + def get_if_exists(self) -> Optional[Connection]: key = self.get_thread_identifier() with self.lock: return self.thread_connections.get(key) - def clear_thread_connection(self): + def clear_thread_connection(self) -> None: key = self.get_thread_identifier() with self.lock: if key in self.thread_connections: del self.thread_connections[key] - def clear_transaction(self): + def clear_transaction(self) -> None: """Clear any existing transactions.""" conn = self.get_thread_connection() if conn is not None: @@ -121,7 +72,7 @@ def clear_transaction(self): self.commit() @abc.abstractmethod - def exception_handler(self, sql): + def exception_handler(self, sql: str) -> ContextManager: """Create a context manager that handles exceptions caused by database interactions. @@ -133,70 +84,73 @@ def exception_handler(self, sql): raise dbt.exceptions.NotImplementedException( '`exception_handler` is not implemented for this adapter!') - def set_connection_name(self, name=None): + def set_connection_name(self, name: Optional[str] = None) -> Connection: + conn_name: str if name is None: # if a name isn't specified, we'll re-use a single handle # named 'master' - name = 'master' + conn_name = 'master' + else: + assert isinstance(name, str) + conn_name = name conn = self.get_if_exists() thread_id_key = self.get_thread_identifier() if conn is None: conn = Connection( - type=self.TYPE, + type=Identifier(self.TYPE), name=None, - state='init', + state=ConnectionState.INIT, transaction_open=False, handle=None, credentials=self.profile.credentials ) self.thread_connections[thread_id_key] = conn - if conn.name == name and conn.state == 'open': + if conn.name == conn_name and conn.state == 'open': return conn - logger.debug('Acquiring new {} connection "{}".' - .format(self.TYPE, name)) + logger.debug( + 'Acquiring new {} connection "{}".'.format(self.TYPE, conn_name)) if conn.state == 'open': logger.debug( 'Re-using an available connection from the pool (formerly {}).' - .format(conn.name)) + .format(conn.name) + ) else: - logger.debug('Opening a new connection, currently in state {}' - .format(conn.state)) + logger.debug( + 'Opening a new connection, currently in state {}' + .format(conn.state) + ) self.open(conn) - conn.name = name + conn.name = conn_name return conn @abc.abstractmethod - def cancel_open(self): + def cancel_open(self) -> Optional[List[str]]: """Cancel all open connections on the adapter. (passable)""" raise dbt.exceptions.NotImplementedException( '`cancel_open` is not implemented for this adapter!' ) @abc.abstractclassmethod - def open(cls, connection): - """Open a connection on the adapter. + def open(cls, connection: Connection) -> Connection: + """Open the given connection on the adapter and return it. This may mutate the given connection (in particular, its state and its handle). This should be thread-safe, or hold the lock if necessary. The given connection should not be in either in_use or available. - - :param Connection connection: A connection object to open. - :return: A connection with a handle attached and an 'open' state. - :rtype: Connection """ raise dbt.exceptions.NotImplementedException( '`open` is not implemented for this adapter!' ) - def release(self): + def release(self) -> None: with self.lock: conn = self.get_if_exists() if conn is None: @@ -213,7 +167,7 @@ def release(self): self.clear_thread_connection() raise - def cleanup_all(self): + def cleanup_all(self) -> None: with self.lock: for connection in self.thread_connections.values(): if connection.state not in {'closed', 'init'}: @@ -228,24 +182,21 @@ def cleanup_all(self): self.thread_connections.clear() @abc.abstractmethod - def begin(self): - """Begin a transaction. (passable) - - :param str name: The name of the connection to use. - """ + def begin(self) -> None: + """Begin a transaction. (passable)""" raise dbt.exceptions.NotImplementedException( '`begin` is not implemented for this adapter!' ) @abc.abstractmethod - def commit(self): + def commit(self) -> None: """Commit a transaction. (passable)""" raise dbt.exceptions.NotImplementedException( '`commit` is not implemented for this adapter!' ) @classmethod - def _rollback_handle(cls, connection): + def _rollback_handle(cls, connection: Connection) -> None: """Perform the actual rollback operation.""" try: connection.handle.rollback() @@ -256,7 +207,7 @@ def _rollback_handle(cls, connection): ) @classmethod - def _close_handle(cls, connection): + def _close_handle(cls, connection: Connection) -> None: """Perform the actual close operation.""" # On windows, sometimes connection handles don't have a close() attr. if hasattr(connection.handle, 'close'): @@ -267,9 +218,8 @@ def _close_handle(cls, connection): .format(connection.name)) @classmethod - def _rollback(cls, connection): - """Roll back the given connection. - """ + def _rollback(cls, connection: Connection) -> None: + """Roll back the given connection.""" if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) @@ -283,15 +233,13 @@ def _rollback(cls, connection): connection.transaction_open = False - return connection - @classmethod - def close(cls, connection): + def close(cls, connection: Connection) -> Connection: if dbt.flags.STRICT_MODE: assert isinstance(connection, Connection) # if the connection is in closed or init, there's nothing to do - if connection.state in {'closed', 'init'}: + if connection.state in {ConnectionState.CLOSED, ConnectionState.INIT}: return connection if connection.transaction_open and connection.handle: @@ -299,21 +247,20 @@ def close(cls, connection): connection.transaction_open = False cls._close_handle(connection) - connection.state = 'closed' + connection.state = ConnectionState.CLOSED return connection - def commit_if_has_connection(self): - """If the named connection exists, commit the current transaction. - - :param str name: The name of the connection to use. - """ + def commit_if_has_connection(self) -> None: + """If the named connection exists, commit the current transaction.""" connection = self.get_if_exists() if connection: self.commit() @abc.abstractmethod - def execute(self, sql, auto_begin=False, fetch=False): + def execute( + self, sql: str, auto_begin: bool = False, fetch: bool = False + ) -> Tuple[str, agate.Table]: """Execute the given SQL. :param str sql: The sql to execute. diff --git a/core/dbt/adapters/base/impl.py b/core/dbt/adapters/base/impl.py index 93208be858b..a4ab29d51cc 100644 --- a/core/dbt/adapters/base/impl.py +++ b/core/dbt/adapters/base/impl.py @@ -21,7 +21,7 @@ from dbt.adapters.base.connections import BaseConnectionManager from dbt.adapters.base.meta import AdapterMeta, available -from dbt.adapters.base import BaseRelation +from dbt.adapters.base.relation import ComponentName, BaseRelation from dbt.adapters.base import Column as BaseColumn from dbt.adapters.cache import RelationsCache @@ -645,7 +645,7 @@ def list_relations(self, database: str, schema: str) -> List[BaseRelation]: information_schema = self.Relation.create( database=database, schema=schema, - model_name='', + identifier='', quote_policy=self.config.quoting ).information_schema() @@ -762,11 +762,13 @@ def quote_as_configured(self, identifier: str, quote_key: str) -> str: The quote key should be one of 'database' (on bigquery, 'profile'), 'identifier', or 'schema', or it will be treated as if you set `True`. """ - # TODO: Convert BaseRelation to a hologram.JsonSchemaMixin so mypy - # likes this - quotes = self.Relation.DEFAULTS['quote_policy'] - default = quotes.get(quote_key) # type: ignore - if self.config.quoting.get(quote_key, default): + try: + key = ComponentName(quote_key) + except ValueError: + return identifier + + default = self.Relation.get_default_quote_policy().get_part(key) + if self.config.quoting.get(key, default): return self.quote(identifier) else: return identifier diff --git a/core/dbt/adapters/base/plugin.py b/core/dbt/adapters/base/plugin.py index d731c3493c9..c307c97d62c 100644 --- a/core/dbt/adapters/base/plugin.py +++ b/core/dbt/adapters/base/plugin.py @@ -1,23 +1,30 @@ +from typing import List, Optional, Type + from dbt.config.project import Project +from dbt.adapters.base import BaseAdapter, Credentials class AdapterPlugin: """Defines the basic requirements for a dbt adapter plugin. - :param type adapter: An adapter class, derived from BaseAdapter - :param type credentials: A credentials object, derived from Credentials - :param str project_name: The name of this adapter plugin's associated dbt - project. - :param str include_path: The path to this adapter plugin's root - :param Optional[List[str]] dependencies: A list of adapter names that this - adapter depends upon. + :param include_path: The path to this adapter plugin's root + :param dependencies: A list of adapter names that this adapter depends + upon. """ - def __init__(self, adapter, credentials, include_path, dependencies=None): - self.adapter = adapter - self.credentials = credentials - self.include_path = include_path + def __init__( + self, + adapter: Type[BaseAdapter], + credentials: Type[Credentials], + include_path: str, + dependencies: Optional[List[str]] = None + ): + self.adapter: Type[BaseAdapter] = adapter + self.credentials: Type[Credentials] = credentials + self.include_path: str = include_path project = Project.from_project_root(include_path, {}) - self.project_name = project.project_name + self.project_name: str = project.project_name + self.dependencies: List[str] if dependencies is None: - dependencies = [] - self.dependencies = dependencies + self.dependencies = [] + else: + self.dependencies = dependencies diff --git a/core/dbt/adapters/base/relation.py b/core/dbt/adapters/base/relation.py index 84cd6bc7ecc..59728a55dff 100644 --- a/core/dbt/adapters/base/relation.py +++ b/core/dbt/adapters/base/relation.py @@ -1,102 +1,171 @@ -from dbt.api import APIObject -from dbt.utils import filter_null_values +from dbt.utils import filter_null_values, deep_merge, classproperty from dbt.node_types import NodeType import dbt.exceptions +from collections.abc import Mapping, Hashable +from dataclasses import dataclass, fields +from typing import ( + Optional, TypeVar, Generic, Any, Type, Dict, Union, List +) +from typing_extensions import Protocol -class BaseRelation(APIObject): - - Table = "table" - View = "view" - CTE = "cte" - MaterializedView = "materializedview" - ExternalTable = "externaltable" - - RelationTypes = [ - Table, - View, - CTE, - MaterializedView, - ExternalTable - ] - - DEFAULTS = { - 'metadata': { - 'type': 'BaseRelation' - }, - 'quote_character': '"', - 'quote_policy': { - 'database': True, - 'schema': True, - 'identifier': True, - }, - 'include_policy': { - 'database': True, - 'schema': True, - 'identifier': True, - }, - 'dbt_created': False, - } - - PATH_SCHEMA = { - 'type': 'object', - 'properties': { - 'database': {'type': ['string', 'null']}, - 'schema': {'type': ['string', 'null']}, - 'identifier': {'type': ['string', 'null']}, - }, - 'required': ['database', 'schema', 'identifier'], - } - - POLICY_SCHEMA = { - 'type': 'object', - 'properties': { - 'database': {'type': 'boolean'}, - 'schema': {'type': 'boolean'}, - 'identifier': {'type': 'boolean'}, - }, - 'required': ['database', 'schema', 'identifier'], - } - - SCHEMA = { - 'type': 'object', - 'properties': { - 'metadata': { - 'type': 'object', - 'properties': { - 'type': { - 'type': 'string', - 'const': 'BaseRelation', - }, - }, - }, - 'type': { - 'enum': RelationTypes + [None], - }, - 'path': PATH_SCHEMA, - 'include_policy': POLICY_SCHEMA, - 'quote_policy': POLICY_SCHEMA, - 'quote_character': {'type': 'string'}, - 'dbt_created': {'type': 'boolean'}, - }, - 'required': ['metadata', 'type', 'path', 'include_policy', - 'quote_policy', 'quote_character', 'dbt_created'] - } - - PATH_ELEMENTS = ['database', 'schema', 'identifier'] - - def _is_exactish_match(self, field, value): - if self.dbt_created and self.quote_policy.get(field) is False: - return self.get_path_part(field).lower() == value.lower() +from hologram import JsonSchemaMixin +from hologram.helpers import StrEnum + +from dbt.contracts.util import Replaceable +from dbt.contracts.graph.compiled import CompiledNode +from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode +from dbt import deprecations + + +class RelationType(StrEnum): + Table = 'table' + View = 'view' + CTE = 'cte' + MaterializedView = 'materializedview' + External = 'external' + + +class ComponentName(StrEnum): + Database = 'database' + Schema = 'schema' + Identifier = 'identifier' + + +class HasQuoting(Protocol): + quoting: Dict[str, bool] + + +class FakeAPIObject(JsonSchemaMixin, Replaceable, Mapping): + # override the mapping truthiness, len is always >1 + def __bool__(self): + return True + + def __getitem__(self, key): + # deprecations.warn('not-a-dictionary', obj=self) + try: + return getattr(self, key) + except AttributeError: + raise KeyError(key) from None + + def __iter__(self): + deprecations.warn('not-a-dictionary', obj=self) + for _, name in self._get_fields(): + yield name + + def __len__(self): + deprecations.warn('not-a-dictionary', obj=self) + return len(fields(self.__class__)) + + def incorporate(self, **kwargs): + value = self.to_dict() + value = deep_merge(value, kwargs) + return self.from_dict(value) + + +T = TypeVar('T') + + +@dataclass +class _ComponentObject(FakeAPIObject, Generic[T]): + database: T + schema: T + identifier: T + + def get_part(self, key: ComponentName) -> T: + if key == ComponentName.Database: + return self.database + elif key == ComponentName.Schema: + return self.schema + elif key == ComponentName.Identifier: + return self.identifier + else: + raise ValueError( + 'Got a key of {}, expected one of {}' + .format(key, list(ComponentName)) + ) + + def replace_dict(self, dct: Dict[ComponentName, T]): + kwargs: Dict[str, T] = {} + for k, v in dct.items(): + kwargs[str(k)] = v + return self.replace(**kwargs) + + +@dataclass +class Policy(_ComponentObject[bool]): + database: bool = True + schema: bool = True + identifier: bool = True + + +@dataclass +class Path(_ComponentObject[Optional[str]]): + database: Optional[str] + schema: Optional[str] + identifier: Optional[str] + + def get_lowered_part(self, key: ComponentName) -> Optional[str]: + part = self.get_part(key) + if part is not None: + part = part.lower() + return part + + +Self = TypeVar('Self', bound='BaseRelation') + + +@dataclass(frozen=True, eq=False, repr=False) +class BaseRelation(FakeAPIObject, Hashable): + type: Optional[RelationType] + path: Path + quote_character: str = '"' + include_policy: Policy = Policy() + quote_policy: Policy = Policy() + dbt_created: bool = False + + def _is_exactish_match(self, field: ComponentName, value: str) -> bool: + if self.dbt_created and self.quote_policy.get_part(field) is False: + return self.path.get_lowered_part(field) == value.lower() else: - return self.get_path_part(field) == value + return self.path.get_part(field) == value + + @classmethod + def _get_field_named(cls, field_name): + for field, _ in cls._get_fields(): + if field.name == field_name: + return field + # this should be unreachable + raise ValueError(f'BaseRelation has no {field_name} field!') + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + return self.to_dict() == other.to_dict() + + @classmethod + def get_default_quote_policy(cls: Type[Self]) -> Policy: + return cls._get_field_named('quote_policy').default + + @classmethod + def get_default_include_policy(cls: Type[Self]) -> Policy: + return cls._get_field_named('include_policy').default - def matches(self, database=None, schema=None, identifier=None): + @classmethod + def get_relation_type_class(cls: Type[Self]) -> Type[RelationType]: + return cls._get_field_named('type') + + def matches( + self, + database: Optional[str] = None, + schema: Optional[str] = None, + identifier: Optional[str] = None, + ) -> bool: search = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) if not search: @@ -111,7 +180,7 @@ def matches(self, database=None, schema=None, identifier=None): if not self._is_exactish_match(k, v): exact_match = False - if self.get_path_part(k).lower() != v.lower(): + if self.path.get_lowered_part(k) != v.lower(): approximate_match = False if approximate_match and not exact_match: @@ -122,107 +191,100 @@ def matches(self, database=None, schema=None, identifier=None): return exact_match - def get_path_part(self, part): - return self.path.get(part) - - def should_quote(self, part): - return self.quote_policy.get(part) - - def should_include(self, part): - return self.include_policy.get(part) - - def quote(self, database=None, schema=None, identifier=None): + def quote( + self: Self, + database: Optional[bool] = None, + schema: Optional[bool] = None, + identifier: Optional[bool] = None, + ) -> Self: policy = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) - return self.incorporate(quote_policy=policy) + new_quote_policy = self.quote_policy.replace_dict(policy) + return self.replace(quote_policy=new_quote_policy) - def include(self, database=None, schema=None, identifier=None): + def include( + self: Self, + database: Optional[bool] = None, + schema: Optional[bool] = None, + identifier: Optional[bool] = None, + ) -> Self: policy = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) - return self.incorporate(include_policy=policy) + new_include_policy = self.include_policy.replace_dict(policy) + return self.replace(include_policy=new_include_policy) - def information_schema(self, identifier=None): - include_db = self.database is not None - include_policy = filter_null_values({ - 'database': include_db, - 'schema': True, - 'identifier': identifier is not None - }) - quote_policy = filter_null_values({ - 'database': self.quote_policy['database'], - 'schema': False, - 'identifier': False, - }) + def information_schema(self: Self, identifier=None) -> Self: + include_policy = self.include_policy.replace( + database=self.database is not None, + schema=True, + identifier=identifier is not None + ) + quote_policy = self.quote_policy.replace( + schema=False, + identifier=False, + ) - path_update = { - 'schema': 'information_schema', - 'identifier': identifier - } + path = self.path.replace( + schema='information_schema', + identifier=identifier, + ) - return self.incorporate( + return self.replace( quote_policy=quote_policy, include_policy=include_policy, - path=path_update, - table_name=identifier) + path=path, + ) - def information_schema_only(self): + def information_schema_only(self: Self) -> Self: return self.information_schema() - def information_schema_table(self, identifier): + def information_schema_table(self: Self, identifier: str) -> Self: return self.information_schema(identifier) - def render(self, use_table_name=True): - parts = [] - - for k in self.PATH_ELEMENTS: - if self.should_include(k): - path_part = self.get_path_part(k) + def render(self) -> str: + parts: List[str] = [] - if path_part is None: - continue - elif k == 'identifier': - if use_table_name: - path_part = self.table - else: - path_part = self.identifier + for k in ComponentName: + if self.include_policy.get_part(k): + path_part = self.path.get_part(k) - parts.append( - self.quote_if( - path_part, - self.should_quote(k))) + if path_part is not None: + part: str = path_part + if self.quote_policy.get_part(k): + part = self.quoted(path_part) + parts.append(part) if len(parts) == 0: raise dbt.exceptions.RuntimeException( - "No path parts are included! Nothing to render.") + "No path parts are included! Nothing to render." + ) return '.'.join(parts) - def quote_if(self, identifier, should_quote): - if should_quote: - return self.quoted(identifier) - - return identifier - def quoted(self, identifier): return '{quote_char}{identifier}{quote_char}'.format( quote_char=self.quote_character, - identifier=identifier) + identifier=identifier, + ) @classmethod - def create_from_source(cls, source, **kwargs): - quote_policy = dbt.utils.deep_merge( - cls.DEFAULTS['quote_policy'], + def create_from_source( + cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any + ) -> Self: + quote_policy = deep_merge( + cls.get_default_quote_policy().to_dict(), source.quoting.to_dict(), - kwargs.get('quote_policy', {}) + kwargs.get('quote_policy', {}), ) + return cls.create( database=source.database, schema=source.schema, @@ -232,8 +294,13 @@ def create_from_source(cls, source, **kwargs): ) @classmethod - def create_from_node(cls, config, node, table_name=None, quote_policy=None, - **kwargs): + def create_from_node( + cls: Type[Self], + config: HasQuoting, + node: Union[ParsedNode, CompiledNode], + quote_policy: Optional[Dict[str, bool]] = None, + **kwargs: Any, + ) -> Self: if quote_policy is None: quote_policy = {} @@ -243,164 +310,96 @@ def create_from_node(cls, config, node, table_name=None, quote_policy=None, database=node.database, schema=node.schema, identifier=node.alias, - table_name=table_name, quote_policy=quote_policy, **kwargs) @classmethod - def create_from(cls, config, node, **kwargs): + def create_from( + cls: Type[Self], + config: HasQuoting, + node: Union[CompiledNode, ParsedNode, ParsedSourceDefinition], + **kwargs: Any, + ) -> Self: if node.resource_type == NodeType.Source: + assert isinstance(node, ParsedSourceDefinition) return cls.create_from_source(node, **kwargs) else: + assert isinstance(node, (ParsedNode, CompiledNode)) return cls.create_from_node(config, node, **kwargs) @classmethod - def create(cls, database=None, schema=None, - identifier=None, table_name=None, - type=None, **kwargs): - if table_name is None: - table_name = identifier - - return cls(type=type, - path={ - 'database': database, - 'schema': schema, - 'identifier': identifier - }, - table_name=table_name, - **kwargs) - - def __repr__(self): + def create( + cls: Type[Self], + database: Optional[str] = None, + schema: Optional[str] = None, + identifier: Optional[str] = None, + type: Optional[RelationType] = None, + **kwargs, + ) -> Self: + kwargs.update({ + 'path': { + 'database': database, + 'schema': schema, + 'identifier': identifier, + }, + 'type': type, + }) + return cls.from_dict(kwargs) + + def __repr__(self) -> str: return "<{} {}>".format(self.__class__.__name__, self.render()) - def __hash__(self): + def __hash__(self) -> int: return hash(self.render()) - def __str__(self): + def __str__(self) -> str: return self.render() @property - def path(self): - return self.get('path', {}) + def database(self) -> Optional[str]: + return self.path.database @property - def database(self): - return self.path.get('database') + def schema(self) -> Optional[str]: + return self.path.schema @property - def schema(self): - return self.path.get('schema') + def identifier(self) -> Optional[str]: + return self.path.identifier @property - def identifier(self): - return self.path.get('identifier') + def table(self) -> Optional[str]: + return self.path.identifier # Here for compatibility with old Relation interface @property - def name(self): + def name(self) -> Optional[str]: return self.identifier - # Here for compatibility with old Relation interface - @property - def table(self): - return self.table_name - - @property - def is_table(self): - return self.type == self.Table - @property - def is_cte(self): - return self.type == self.CTE + def is_table(self) -> bool: + return self.type == RelationType.Table @property - def is_view(self): - return self.type == self.View - - -class Column: - TYPE_LABELS = { - 'STRING': 'TEXT', - 'TIMESTAMP': 'TIMESTAMP', - 'FLOAT': 'FLOAT', - 'INTEGER': 'INT' - } - - def __init__(self, column, dtype, char_size=None, numeric_precision=None, - numeric_scale=None): - self.column = column - self.dtype = dtype - self.char_size = char_size - self.numeric_precision = numeric_precision - self.numeric_scale = numeric_scale - - @classmethod - def translate_type(cls, dtype): - return cls.TYPE_LABELS.get(dtype.upper(), dtype) - - @classmethod - def create(cls, name, label_or_dtype): - column_type = cls.translate_type(label_or_dtype) - return cls(name, column_type) + def is_cte(self) -> bool: + return self.type == RelationType.CTE @property - def name(self): - return self.column - - @property - def quoted(self): - return '"{}"'.format(self.column) - - @property - def data_type(self): - if self.is_string(): - return Column.string_type(self.string_size()) - elif self.is_numeric(): - return Column.numeric_type(self.dtype, self.numeric_precision, - self.numeric_scale) - else: - return self.dtype - - def is_string(self): - return self.dtype.lower() in ['text', 'character varying', 'character', - 'varchar'] + def is_view(self) -> bool: + return self.type == RelationType.View - def is_numeric(self): - return self.dtype.lower() in ['numeric', 'number'] + @classproperty + def Table(self) -> str: + return str(RelationType.Table) - def string_size(self): - if not self.is_string(): - raise RuntimeError("Called string_size() on non-string field!") + @classproperty + def CTE(self) -> str: + return str(RelationType.CTE) - if self.dtype == 'text' or self.char_size is None: - # char_size should never be None. Handle it reasonably just in case - return 256 - else: - return int(self.char_size) - - def can_expand_to(self, other_column): - """returns True if this column can be expanded to the size of the - other column""" - if not self.is_string() or not other_column.is_string(): - return False - - return other_column.string_size() > self.string_size() - - def literal(self, value): - return "{}::{}".format(value, self.data_type) - - @classmethod - def string_type(cls, size): - return "character varying({})".format(size) - - @classmethod - def numeric_type(cls, dtype, precision, scale): - # This could be decimal(...), numeric(...), number(...) - # Just use whatever was fed in here -- don't try to get too clever - if precision is None or scale is None: - return dtype - else: - return "{}({},{})".format(dtype, precision, scale) + @classproperty + def View(self) -> str: + return str(RelationType.View) - def __repr__(self): - return "".format(self.name, self.data_type) + @classproperty + def External(self) -> str: + return str(RelationType.External) diff --git a/core/dbt/adapters/cache.py b/core/dbt/adapters/cache.py index 521ff52d975..78472ecc674 100644 --- a/core/dbt/adapters/cache.py +++ b/core/dbt/adapters/cache.py @@ -130,7 +130,6 @@ def rename(self, new_relation): 'schema': new_relation.inner.schema, 'identifier': new_relation.inner.identifier }, - table_name=new_relation.inner.identifier ) def rename_key(self, old_key, new_key): diff --git a/core/dbt/adapters/factory.py b/core/dbt/adapters/factory.py index a380e393a62..9c7a78b2e6a 100644 --- a/core/dbt/adapters/factory.py +++ b/core/dbt/adapters/factory.py @@ -1,50 +1,66 @@ -import dbt.exceptions +import threading from importlib import import_module +from typing import Type, Dict, TypeVar + +from dbt.exceptions import RuntimeException from dbt.include.global_project import PACKAGES from dbt.logger import GLOBAL_LOGGER as logger +from dbt.contracts.connection import Credentials -import threading -ADAPTER_TYPES = {} +# TODO: we can't import these because they cause an import cycle. +# currently RuntimeConfig needs to figure out default quoting for its adapter. +# We should push that elsewhere when we fixup project/profile stuff +# Instead here are some import loop avoiding-hacks right now. And Profile has +# to call into load_plugin to get credentials, so adapter/relation don't work +RuntimeConfig = TypeVar('RuntimeConfig') +BaseAdapter = TypeVar('BaseAdapter') +BaseRelation = TypeVar('BaseRelation') -_ADAPTERS = {} +ADAPTER_TYPES: Dict[str, Type[BaseAdapter]] = {} + +_ADAPTERS: Dict[str, BaseAdapter] = {} _ADAPTER_LOCK = threading.Lock() -def get_adapter_class_by_name(adapter_name): +def get_adapter_class_by_name(adapter_name: str) -> Type[BaseAdapter]: with _ADAPTER_LOCK: if adapter_name in ADAPTER_TYPES: return ADAPTER_TYPES[adapter_name] + adapter_names = ", ".join(ADAPTER_TYPES.keys()) + message = "Invalid adapter type {}! Must be one of {}" - adapter_names = ", ".join(ADAPTER_TYPES.keys()) formatted_message = message.format(adapter_name, adapter_names) - raise dbt.exceptions.RuntimeException(formatted_message) + raise RuntimeException(formatted_message) -def get_relation_class_by_name(adapter_name): +def get_relation_class_by_name(adapter_name: str) -> Type[BaseRelation]: adapter = get_adapter_class_by_name(adapter_name) return adapter.Relation -def load_plugin(adapter_name): +def load_plugin(adapter_name: str) -> Credentials: + # this doesn't need a lock: in the worst case we'll overwrite PACKAGES and + # _ADAPTER_TYPE entries with the same value, as they're all singletons try: mod = import_module('.' + adapter_name, 'dbt.adapters') except ImportError as e: logger.info("Error importing adapter: {}".format(e)) - raise dbt.exceptions.RuntimeException( + raise RuntimeException( "Could not find adapter type {}!".format(adapter_name) ) plugin = mod.Plugin if plugin.adapter.type() != adapter_name: - raise dbt.exceptions.RuntimeException( + raise RuntimeException( 'Expected to find adapter with type named {}, got adapter with ' 'type {}' .format(adapter_name, plugin.adapter.type()) ) with _ADAPTER_LOCK: + # things do hold the lock to iterate over it so we need ot to add stuff ADAPTER_TYPES[adapter_name] = plugin.adapter PACKAGES[plugin.project_name] = plugin.include_path @@ -55,19 +71,16 @@ def load_plugin(adapter_name): return plugin.credentials -def get_adapter(config): +def get_adapter(config: RuntimeConfig) -> BaseAdapter: adapter_name = config.credentials.type + + # Atomically check to see if we already have an adapter if adapter_name in _ADAPTERS: return _ADAPTERS[adapter_name] - with _ADAPTER_LOCK: - if adapter_name not in ADAPTER_TYPES: - raise dbt.exceptions.RuntimeException( - "Could not find adapter type {}!".format(adapter_name) - ) - - adapter_type = ADAPTER_TYPES[adapter_name] + adapter_type = get_adapter_class_by_name(adapter_name) + with _ADAPTER_LOCK: # check again, in case something was setting it before if adapter_name in _ADAPTERS: return _ADAPTERS[adapter_name] diff --git a/core/dbt/adapters/sql/connections.py b/core/dbt/adapters/sql/connections.py index 5b9c7f459cd..e96a5eae5cb 100644 --- a/core/dbt/adapters/sql/connections.py +++ b/core/dbt/adapters/sql/connections.py @@ -1,5 +1,8 @@ import abc import time +from typing import List, Optional, Tuple, Any, Iterable, Dict + +import agate import dbt.clients.agate_helper import dbt.exceptions @@ -18,16 +21,13 @@ class SQLConnectionManager(BaseConnectionManager): - open """ @abc.abstractmethod - def cancel(self, connection): - """Cancel the given connection. - - :param Connection connection: The connection to cancel. - """ + def cancel(self, connection: Connection): + """Cancel the given connection.""" raise dbt.exceptions.NotImplementedException( '`cancel` is not implemented for this adapter!' ) - def cancel_open(self): + def cancel_open(self) -> List[str]: names = [] this_connection = self.get_if_exists() with self.lock: @@ -39,11 +39,17 @@ def cancel_open(self): # nothing to cancel. if connection.handle is not None: self.cancel(connection) - names.append(connection.name) + if connection.name is not None: + names.append(connection.name) return names - def add_query(self, sql, auto_begin=True, bindings=None, - abridge_sql_log=False): + def add_query( + self, + sql: str, + auto_begin: bool = True, + bindings: Optional[Any] = None, + abridge_sql_log: bool = False + ) -> Tuple[Connection, Any]: connection = self.get_thread_connection() if auto_begin and connection.transaction_open is False: self.begin() @@ -76,25 +82,25 @@ def add_query(self, sql, auto_begin=True, bindings=None, return connection, cursor @abc.abstractclassmethod - def get_status(cls, cursor): - """Get the status of the cursor. - - :param cursor: A database handle to get status from - :return: The current status - :rtype: str - """ + def get_status(cls, cursor: Any) -> str: + """Get the status of the cursor.""" raise dbt.exceptions.NotImplementedException( '`get_status` is not implemented for this adapter!' ) @classmethod - def process_results(cls, column_names, rows): + def process_results( + cls, + column_names: Iterable[str], + rows: Iterable[Any] + ) -> List[Dict[str, Any]]: + return [dict(zip(column_names, row)) for row in rows] @classmethod - def get_result_from_cursor(cls, cursor): - data = [] - column_names = [] + def get_result_from_cursor(cls, cursor: Any) -> agate.Table: + data: List[Any] = [] + column_names: List[str] = [] if cursor.description is not None: column_names = [col[0] for col in cursor.description] @@ -103,7 +109,9 @@ def get_result_from_cursor(cls, cursor): return dbt.clients.agate_helper.table_from_data(data, column_names) - def execute(self, sql, auto_begin=False, fetch=False): + def execute( + self, sql: str, auto_begin: bool = False, fetch: bool = False + ) -> Tuple[str, agate.Table]: _, cursor = self.add_query(sql, auto_begin) status = self.get_status(cursor) if fetch: diff --git a/core/dbt/adapters/sql/impl.py b/core/dbt/adapters/sql/impl.py index 87bbb39db20..260b17b49d8 100644 --- a/core/dbt/adapters/sql/impl.py +++ b/core/dbt/adapters/sql/impl.py @@ -1,9 +1,12 @@ import agate +from typing import Any, Optional, Tuple, Type import dbt.clients.agate_helper +from dbt.contracts.connection import Connection import dbt.exceptions import dbt.flags from dbt.adapters.base import BaseAdapter, available +from dbt.adapters.sql import SQLConnectionManager from dbt.logger import GLOBAL_LOGGER as logger @@ -35,18 +38,25 @@ class SQLAdapter(BaseAdapter): - list_relations_without_caching - get_columns_in_relation """ + ConnectionManager: Type[SQLConnectionManager] + connections: SQLConnectionManager + @available.parse(lambda *a, **k: (None, None)) - def add_query(self, sql, auto_begin=True, bindings=None, - abridge_sql_log=False): + def add_query( + self, + sql: str, + auto_begin: bool = True, + bindings: Optional[Any] = None, + abridge_sql_log: bool = False, + ) -> Tuple[Connection, Any]: """Add a query to the current transaction. A thin wrapper around ConnectionManager.add_query. - :param str sql: The SQL query to add - :param bool auto_begin: If set and there is no transaction in progress, + :param sql: The SQL query to add + :param auto_begin: If set and there is no transaction in progress, begin a new one. - :param Optional[List[object]]: An optional list of bindings for the - query. - :param bool abridge_sql_log: If set, limit the raw sql logged to 512 + :param bindings: An optional list of bindings for the query. + :param abridge_sql_log: If set, limit the raw sql logged to 512 characters """ return self.connections.add_query(sql, auto_begin, bindings, diff --git a/core/dbt/api/__init__.py b/core/dbt/api/__init__.py deleted file mode 100644 index a6fe655f9c8..00000000000 --- a/core/dbt/api/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from dbt.api.object import APIObject - -__all__ = [ - 'APIObject' -] diff --git a/core/dbt/api/object.py b/core/dbt/api/object.py deleted file mode 100644 index d6408f1160e..00000000000 --- a/core/dbt/api/object.py +++ /dev/null @@ -1,125 +0,0 @@ -import copy -from collections import Mapping -from jsonschema import Draft7Validator - -from dbt.exceptions import JSONValidationException -from dbt.utils import deep_merge -from dbt.clients.system import write_json - - -class APIObject(Mapping): - """ - A serializable / deserializable object intended for - use in a future dbt API. - - To create a new object, you'll want to extend this - class, and then implement the SCHEMA property (a - valid JSON schema), the DEFAULTS property (default - settings for this object), and a static method that - calls this constructor. - """ - - SCHEMA = { - 'type': 'object', - 'properties': {} - } - - DEFAULTS = {} - - def __init__(self, **kwargs): - """ - Create and validate an instance. Note that if you override this, you - will want to do so by modifying kwargs and only then calling - super().__init__(**kwargs). - """ - super().__init__() - # note: deep_merge does a deep copy on its arguments. - self._contents = deep_merge(self.DEFAULTS, kwargs) - self.validate() - - def __str__(self): - return '{}(**{})'.format(self.__class__.__name__, self._contents) - - def __repr__(self): - return '{}(**{})'.format(self.__class__.__name__, self._contents) - - def __eq__(self, other): - if not isinstance(other, self.__class__): - return False - return self.serialize() == other.serialize() - - def incorporate(self, **kwargs): - """ - Given a list of kwargs, incorporate these arguments - into a new copy of this instance, and return the new - instance after validating. - """ - return type(self)(**deep_merge(self._contents, kwargs)) - - def serialize(self): - """ - Return a dict representation of this object. - """ - return copy.deepcopy(self._contents) - - def write(self, path): - write_json(path, self.serialize()) - - @classmethod - def deserialize(cls, settings): - """ - Convert a dict representation of this object into - an actual object for internal use. - """ - return cls(**settings) - - def validate(self): - """ - Using the SCHEMA property, validate the attributes - of this instance. If any attributes are missing or - invalid, raise a ValidationException. - """ - validator = Draft7Validator(self.SCHEMA) - - errors = set() # make errors a set to avoid duplicates - - for error in validator.iter_errors(self.serialize()): - errors.add('.'.join( - list(map(str, error.path)) + [error.message] - )) - - if errors: - raise JSONValidationException(type(self).__name__, errors) - - # implement the Mapping protocol: - # https://docs.python.org/3/library/collections.abc.html - def __getitem__(self, key): - return self._contents[key] - - def __iter__(self): - return self._contents.__iter__() - - def __len__(self): - return self._contents.__len__() - - # implement this because everyone always expects it. - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def set(self, key, value): - self._contents[key] = value - - # most users of APIObject also expect the attributes to be available via - # dot-notation because the previous implementation assigned to __dict__. - # we should consider removing this if we fix all uses to have properties. - def __getattr__(self, name): - if name != '_contents' and name in self._contents: - return self._contents[name] - elif hasattr(self.__class__, name): - return getattr(self.__class__, name) - raise AttributeError(( - "'{}' object has no attribute '{}'" - ).format(type(self).__name__, name)) diff --git a/core/dbt/config/__init__.py b/core/dbt/config/__init__.py index 09b5523dae1..d18fd7f0790 100644 --- a/core/dbt/config/__init__.py +++ b/core/dbt/config/__init__.py @@ -1,5 +1,5 @@ # all these are just exports, they need "noqa" so flake8 will not complain. -from .renderer import ConfigRenderer # noqa from .profile import Profile, PROFILES_DIR, read_user_config # noqa from .project import Project # noqa from .runtime import RuntimeConfig # noqa +from .renderer import ConfigRenderer # noqa diff --git a/core/dbt/config/renderer.py b/core/dbt/config/renderer.py index 2ef892f6844..4f196419ab7 100644 --- a/core/dbt/config/renderer.py +++ b/core/dbt/config/renderer.py @@ -1,5 +1,5 @@ from dbt.clients.jinja import get_rendered -from dbt.context.common import generate_config_context +from dbt.context.base import generate_config_context from dbt.exceptions import DbtProfileError from dbt.exceptions import DbtProjectError from dbt.exceptions import RecursionException diff --git a/core/dbt/config/runtime.py b/core/dbt/config/runtime.py index e89b4c3fe95..1e3bdc7dbdb 100644 --- a/core/dbt/config/runtime.py +++ b/core/dbt/config/runtime.py @@ -1,13 +1,13 @@ from copy import deepcopy +from .profile import Profile +from .project import Project from dbt.utils import parse_cli_vars from dbt.contracts.project import Configuration from dbt.exceptions import DbtProjectError from dbt.exceptions import validator_error_message from dbt.adapters.factory import get_relation_class_by_name -from .profile import Profile -from .project import Project from hologram import ValidationError @@ -72,11 +72,11 @@ def from_parts(cls, project, profile, args): :param args argparse.Namespace: The parsed command-line arguments. :returns RuntimeConfig: The new configuration. """ - quoting = deepcopy( + quoting = ( get_relation_class_by_name(profile.credentials.type) - .DEFAULTS['quote_policy'] - ) - quoting.update(project.quoting) + .get_default_quote_policy() + .replace_dict(project.quoting) + ).to_dict() return cls( project_name=project.project_name, diff --git a/core/dbt/context/base.py b/core/dbt/context/base.py new file mode 100644 index 00000000000..3994694a13e --- /dev/null +++ b/core/dbt/context/base.py @@ -0,0 +1,138 @@ +import json +import os + +import dbt.tracking +from dbt.clients.jinja import undefined_error +from dbt.utils import merge + + +# These modules are added to the context. Consider alternative +# approaches which will extend well to potentially many modules +import pytz +import datetime + + +def add_tracking(context): + if dbt.tracking.active_user is not None: + context = merge(context, { + "run_started_at": dbt.tracking.active_user.run_started_at, + "invocation_id": dbt.tracking.active_user.invocation_id, + }) + else: + context = merge(context, { + "run_started_at": None, + "invocation_id": None + }) + + return context + + +def env_var(var, default=None): + if var in os.environ: + return os.environ[var] + elif default is not None: + return default + else: + msg = "Env var required but not provided: '{}'".format(var) + undefined_error(msg) + + +def debug_here(): + import sys + import ipdb + frame = sys._getframe(3) + ipdb.set_trace(frame) + + +class Var: + UndefinedVarError = "Required var '{}' not found in config:\nVars "\ + "supplied to {} = {}" + _VAR_NOTSET = object() + + def __init__(self, model, context, overrides): + self.model = model + self.context = context + + # These are hard-overrides (eg. CLI vars) that should take + # precedence over context-based var definitions + self.overrides = overrides + + if model is None: + # during config parsing we have no model and no local vars + self.model_name = '' + local_vars = {} + else: + self.model_name = model.name + local_vars = model.local_vars() + + self.local_vars = dbt.utils.merge(local_vars, overrides) + + def pretty_dict(self, data): + return json.dumps(data, sort_keys=True, indent=4) + + def get_missing_var(self, var_name): + pretty_vars = self.pretty_dict(self.local_vars) + msg = self.UndefinedVarError.format( + var_name, self.model_name, pretty_vars + ) + dbt.exceptions.raise_compiler_error(msg, self.model) + + def assert_var_defined(self, var_name, default): + if var_name not in self.local_vars and default is self._VAR_NOTSET: + return self.get_missing_var(var_name) + + def get_rendered_var(self, var_name): + raw = self.local_vars[var_name] + # if bool/int/float/etc are passed in, don't compile anything + if not isinstance(raw, str): + return raw + + return dbt.clients.jinja.get_rendered(raw, self.context) + + def __call__(self, var_name, default=_VAR_NOTSET): + if var_name in self.local_vars: + return self.get_rendered_var(var_name) + elif default is not self._VAR_NOTSET: + return default + else: + return self.get_missing_var(var_name) + + +def get_pytz_module_context(): + context_exports = pytz.__all__ + + return { + name: getattr(pytz, name) for name in context_exports + } + + +def get_datetime_module_context(): + context_exports = [ + 'date', + 'datetime', + 'time', + 'timedelta', + 'tzinfo' + ] + + return { + name: getattr(datetime, name) for name in context_exports + } + + +def get_context_modules(): + return { + 'pytz': get_pytz_module_context(), + 'datetime': get_datetime_module_context(), + } + + +def generate_config_context(cli_vars): + context = { + 'env_var': env_var, + 'modules': get_context_modules(), + } + context['var'] = Var(None, context, cli_vars) + if os.environ.get('DBT_MACRO_DEBUGGING'): + context['debug'] = debug_here + return add_tracking(context) diff --git a/core/dbt/context/common.py b/core/dbt/context/common.py index efc5b8d9f5e..b356392e56b 100644 --- a/core/dbt/context/common.py +++ b/core/dbt/context/common.py @@ -1,26 +1,21 @@ import json import os -from dbt.adapters.factory import get_adapter -from dbt.node_types import NodeType -from dbt.include.global_project import PACKAGES -from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME - -import dbt.clients.jinja import dbt.clients.agate_helper import dbt.exceptions import dbt.flags import dbt.tracking -import dbt.writer import dbt.utils - -from dbt.logger import GLOBAL_LOGGER as logger # noqa - - -# These modules are added to the context. Consider alternative -# approaches which will extend well to potentially many modules -import pytz -import datetime +import dbt.writer +from dbt.adapters.factory import get_adapter +from dbt.node_types import NodeType +from dbt.include.global_project import PACKAGES +from dbt.include.global_project import PROJECT_NAME as GLOBAL_PROJECT_NAME +from dbt.logger import GLOBAL_LOGGER as logger +from dbt.clients.jinja import get_rendered +from dbt.context.base import ( + debug_here, env_var, get_context_modules, add_tracking +) class RelationProxy: @@ -125,22 +120,29 @@ def _add_macros(context, model, manifest): return context -def _add_tracking(context): - if dbt.tracking.active_user is not None: - context = dbt.utils.merge(context, { - "run_started_at": dbt.tracking.active_user.run_started_at, - "invocation_id": dbt.tracking.active_user.invocation_id, - }) - else: - context = dbt.utils.merge(context, { - "run_started_at": None, - "invocation_id": None +def _store_result(sql_results): + def call(name, status, agate_table=None): + if agate_table is None: + agate_table = dbt.clients.agate_helper.empty_table() + + sql_results[name] = dbt.utils.AttrDict({ + 'status': status, + 'data': dbt.clients.agate_helper.as_matrix(agate_table), + 'table': agate_table }) + return '' - return context + return call -def _add_validation(context): +def _load_result(sql_results): + def call(name): + return sql_results.get(name) + + return call + + +def add_validation(context): def validate_any(*args): def inner(value): for arg in args: @@ -162,46 +164,7 @@ def inner(value): {'validation': validation_utils}) -def env_var(var, default=None): - if var in os.environ: - return os.environ[var] - elif default is not None: - return default - else: - msg = "Env var required but not provided: '{}'".format(var) - dbt.clients.jinja.undefined_error(msg) - - -def _store_result(sql_results): - def call(name, status, agate_table=None): - if agate_table is None: - agate_table = dbt.clients.agate_helper.empty_table() - - sql_results[name] = dbt.utils.AttrDict({ - 'status': status, - 'data': dbt.clients.agate_helper.as_matrix(agate_table), - 'table': agate_table - }) - return '' - - return call - - -def _load_result(sql_results): - def call(name): - return sql_results.get(name) - - return call - - -def _debug_here(): - import sys - import ipdb - frame = sys._getframe(3) - ipdb.set_trace(frame) - - -def _add_sql_handlers(context): +def add_sql_handlers(context): sql_results = {} return dbt.utils.merge(context, { '_sql_results': sql_results, @@ -210,68 +173,6 @@ def _add_sql_handlers(context): }) -def log(msg, info=False): - if info: - logger.info(msg) - else: - logger.debug(msg) - return '' - - -class Var: - UndefinedVarError = "Required var '{}' not found in config:\nVars "\ - "supplied to {} = {}" - _VAR_NOTSET = object() - - def __init__(self, model, context, overrides): - self.model = model - self.context = context - - # These are hard-overrides (eg. CLI vars) that should take - # precedence over context-based var definitions - self.overrides = overrides - - if model is None: - # during config parsing we have no model and no local vars - self.model_name = '' - local_vars = {} - else: - self.model_name = model.name - local_vars = model.local_vars() - - self.local_vars = dbt.utils.merge(local_vars, overrides) - - def pretty_dict(self, data): - return json.dumps(data, sort_keys=True, indent=4) - - def get_missing_var(self, var_name): - pretty_vars = self.pretty_dict(self.local_vars) - msg = self.UndefinedVarError.format( - var_name, self.model_name, pretty_vars - ) - dbt.exceptions.raise_compiler_error(msg, self.model) - - def assert_var_defined(self, var_name, default): - if var_name not in self.local_vars and default is self._VAR_NOTSET: - return self.get_missing_var(var_name) - - def get_rendered_var(self, var_name): - raw = self.local_vars[var_name] - # if bool/int/float/etc are passed in, don't compile anything - if not isinstance(raw, str): - return raw - - return dbt.clients.jinja.get_rendered(raw, self.context) - - def __call__(self, var_name, default=_VAR_NOTSET): - if var_name in self.local_vars: - return self.get_rendered_var(var_name) - elif default is not self._VAR_NOTSET: - return default - else: - return self.get_missing_var(var_name) - - def write(node, target_path, subdirectory): def fn(payload): node.build_path = dbt.writer.write_node( @@ -283,7 +184,7 @@ def fn(payload): def render(context, node): def fn(string): - return dbt.clients.jinja.get_rendered(string, context, node) + return get_rendered(string, context, node) return fn @@ -311,46 +212,17 @@ def impl(message_if_exception, func, *args, **kwargs): return impl -def _return(value): - raise dbt.exceptions.MacroReturn(value) - - -def get_pytz_module_context(): - context_exports = pytz.__all__ - - return { - name: getattr(pytz, name) for name in context_exports - } - - -def get_datetime_module_context(): - context_exports = [ - 'date', - 'datetime', - 'time', - 'timedelta', - 'tzinfo' - ] - - return { - name: getattr(datetime, name) for name in context_exports - } - - -def get_context_modules(): - return { - 'pytz': get_pytz_module_context(), - 'datetime': get_datetime_module_context(), - } +# Base context collection, used for parsing configs. +def log(msg, info=False): + if info: + logger.info(msg) + else: + logger.debug(msg) + return '' -def generate_config_context(cli_vars): - context = { - 'env_var': env_var, - 'modules': get_context_modules(), - } - context['var'] = Var(None, context, cli_vars) - return _add_tracking(context) +def _return(value): + raise dbt.exceptions.MacroReturn(value) def _build_load_agate_table(model): @@ -422,7 +294,7 @@ def generate_base(model, model_dict, config, manifest, source_config, "try_or_compiler_error": try_or_compiler_error(model) }) if os.environ.get('DBT_MACRO_DEBUGGING'): - context['debug'] = _debug_here + context['debug'] = debug_here return context @@ -430,9 +302,9 @@ def generate_base(model, model_dict, config, manifest, source_config, def modify_generated_context(context, model, config, manifest, provider): cli_var_overrides = config.cli_vars - context = _add_tracking(context) - context = _add_validation(context) - context = _add_sql_handlers(context) + context = add_tracking(context) + context = add_validation(context) + context = add_sql_handlers(context) # we make a copy of the context for each of these ^^ diff --git a/core/dbt/context/parser.py b/core/dbt/context/parser.py index 2d476c0595b..0e8d879ba4a 100644 --- a/core/dbt/context/parser.py +++ b/core/dbt/context/parser.py @@ -89,7 +89,7 @@ def __getattr__(self, name): ) -class Var(dbt.context.common.Var): +class Var(dbt.context.base.Var): def get_missing_var(self, var_name): # in the parser, just always return None. return None diff --git a/core/dbt/context/runtime.py b/core/dbt/context/runtime.py index de8cc730eb8..c97a75b50dc 100644 --- a/core/dbt/context/runtime.py +++ b/core/dbt/context/runtime.py @@ -1,6 +1,7 @@ from dbt.utils import get_materialization, add_ephemeral_model_prefix import dbt.clients.jinja +import dbt.context.base import dbt.context.common import dbt.flags from dbt.parser.util import ParserUtils @@ -144,7 +145,7 @@ def __getattr__(self, name): ) -class Var(dbt.context.common.Var): +class Var(dbt.context.base.Var): pass diff --git a/core/dbt/contracts/connection.py b/core/dbt/contracts/connection.py index 2ed88ae38c1..56b4be0adf2 100644 --- a/core/dbt/contracts/connection.py +++ b/core/dbt/contracts/connection.py @@ -1,11 +1,17 @@ +import abc +from dataclasses import dataclass, field +from typing import ( + Any, ClassVar, Dict, Tuple, Iterable, Optional, NewType +) + +from hologram import JsonSchemaMixin from hologram.helpers import ( StrEnum, register_pattern, ExtensibleJsonSchemaMixin ) -from hologram import JsonSchemaMixin + from dbt.contracts.util import Replaceable +from dbt.utils import translate_aliases -from dataclasses import dataclass -from typing import Any, Optional, NewType Identifier = NewType('Identifier', str) register_pattern(Identifier, r'^[A-Za-z_][A-Za-z0-9_]+$') @@ -58,3 +64,55 @@ def handle(self): @handle.setter def handle(self, value): self._handle = value + + +# see https://github.com/python/mypy/issues/4717#issuecomment-373932080 +# and https://github.com/python/mypy/issues/5374 +# for why we have type: ignore. Maybe someday dataclasses + abstract classes +# will work. +@dataclass +class Credentials( # type: ignore + ExtensibleJsonSchemaMixin, + Replaceable, + metaclass=abc.ABCMeta +): + database: str + schema: str + _ALIASES: ClassVar[Dict[str, str]] = field(default={}, init=False) + + @abc.abstractproperty + def type(self) -> str: + raise NotImplementedError( + 'type not implemented for base credentials class' + ) + + def connection_info(self) -> Iterable[Tuple[str, Any]]: + """Return an ordered iterator of key/value pairs for pretty-printing. + """ + as_dict = self.to_dict() + for key in self._connection_keys(): + if key in as_dict: + yield key, as_dict[key] + + @abc.abstractmethod + def _connection_keys(self) -> Tuple[str, ...]: + raise NotImplementedError + + @classmethod + def from_dict(cls, data): + data = cls.translate_aliases(data) + return super().from_dict(data) + + @classmethod + def translate_aliases(cls, kwargs: Dict[str, Any]) -> Dict[str, Any]: + return translate_aliases(kwargs, cls._ALIASES) + + def to_dict(self, omit_none=True, validate=False, with_aliases=False): + serialized = super().to_dict(omit_none=omit_none, validate=validate) + if with_aliases: + serialized.update({ + new_name: serialized[canonical_name] + for new_name, canonical_name in self._ALIASES.items() + if canonical_name in serialized + }) + return serialized diff --git a/core/dbt/contracts/util.py b/core/dbt/contracts/util.py index b2fbe834c12..842d9bbc87f 100644 --- a/core/dbt/contracts/util.py +++ b/core/dbt/contracts/util.py @@ -1,7 +1,7 @@ -from dbt.clients.system import write_json - import dataclasses +from dbt.clients.system import write_json + class Replaceable: def replace(self, **kwargs): diff --git a/core/dbt/deprecations.py b/core/dbt/deprecations.py index 5f4ff68a2d2..3a54e95395c 100644 --- a/core/dbt/deprecations.py +++ b/core/dbt/deprecations.py @@ -75,6 +75,15 @@ class MaterializationReturnDeprecation(DBTDeprecation): '''.lstrip() +class NotADictionaryDeprecation(DBTDeprecation): + _name = 'not-a-dictionary' + + _description = ''' + The object ("{obj}") was used as a dictionary. In a future version of dbt + this capability will be removed from objects of this type. + '''.lstrip() + + _adapter_renamed_description = """\ The adapter function `adapter.{old_name}` is deprecated and will be removed in a future release of dbt. Please use `adapter.{new_name}` instead. @@ -113,6 +122,7 @@ def warn(name, *args, **kwargs): DBTRepositoriesDeprecation(), GenerateSchemaNameSingleArgDeprecated(), MaterializationReturnDeprecation(), + NotADictionaryDeprecation(), ] deprecations: Dict[str, DBTDeprecation] = { diff --git a/core/dbt/include/global_project/macros/adapters/common.sql b/core/dbt/include/global_project/macros/adapters/common.sql index a6be4d8cbc1..7725f2e0981 100644 --- a/core/dbt/include/global_project/macros/adapters/common.sql +++ b/core/dbt/include/global_project/macros/adapters/common.sql @@ -265,8 +265,7 @@ {% macro default__make_temp_relation(base_relation, suffix) %} {% set tmp_identifier = base_relation.identifier ~ suffix %} {% set tmp_relation = base_relation.incorporate( - path={"identifier": tmp_identifier}, - table_name=tmp_identifier) -%} + path={"identifier": tmp_identifier}) -%} {% do return(tmp_relation) %} {% endmacro %} diff --git a/core/dbt/include/global_project/macros/materializations/seed/seed.sql b/core/dbt/include/global_project/macros/materializations/seed/seed.sql index ca836dc88e2..f83f845e3ea 100644 --- a/core/dbt/include/global_project/macros/materializations/seed/seed.sql +++ b/core/dbt/include/global_project/macros/materializations/seed/seed.sql @@ -15,7 +15,7 @@ {%- set column_override = model['config'].get('column_types', {}) -%} {% set sql %} - create table {{ this.render(False) }} ( + create table {{ this.render() }} ( {%- for col_name in agate_table.column_names -%} {%- set inferred_type = adapter.convert_type(agate_table, loop.index0) -%} {%- set type = column_override.get(col_name, inferred_type) -%} @@ -60,7 +60,7 @@ {% endfor %} {% set sql %} - insert into {{ this.render(False) }} ({{ cols_sql }}) values + insert into {{ this.render() }} ({{ cols_sql }}) values {% for row in chunk -%} ({%- for column in agate_table.column_names -%} %s diff --git a/core/dbt/parser/schemas.py b/core/dbt/parser/schemas.py index a94030554bf..0c9737e6056 100644 --- a/core/dbt/parser/schemas.py +++ b/core/dbt/parser/schemas.py @@ -3,7 +3,7 @@ from hologram import ValidationError -from dbt.context.common import generate_config_context +from dbt.context.base import generate_config_context from dbt.clients.jinja import get_rendered from dbt.clients.yaml_helper import load_yaml_text diff --git a/core/dbt/tracking.py b/core/dbt/tracking.py index 26e8eff18af..6c78d12b9b3 100644 --- a/core/dbt/tracking.py +++ b/core/dbt/tracking.py @@ -4,8 +4,6 @@ from snowplow_tracker import SelfDescribingJson from datetime import datetime -from dbt.adapters.factory import get_adapter - import pytz import platform import uuid @@ -125,6 +123,8 @@ def get_run_type(args): def get_invocation_context(user, config, args): + # put this in here to avoid an import cycle + from dbt.adapters.factory import get_adapter try: adapter_type = get_adapter(config).type() except Exception: diff --git a/core/dbt/utils.py b/core/dbt/utils.py index 10016c05999..652556cbff9 100644 --- a/core/dbt/utils.py +++ b/core/dbt/utils.py @@ -437,9 +437,8 @@ def parse_cli_vars(var_string): V_T = TypeVar('V_T') -def filter_null_values(input: Dict[K_T, V_T]) -> Dict[K_T, V_T]: - return dict((k, v) for (k, v) in input.items() - if v is not None) +def filter_null_values(input: Dict[K_T, Optional[V_T]]) -> Dict[K_T, V_T]: + return {k: v for k, v in input.items() if v is not None} def add_ephemeral_model_prefix(s: str) -> str: @@ -522,3 +521,15 @@ def env_set_truthy(key: str) -> Optional[str]: def restrict_to(*restrictions): """Create the metadata for a restricted dataclass field""" return {'restrict': list(restrictions)} + + +# some types need to make constants available to the jinja context as +# attributes, and regular properties only work with objects. maybe this should +# be handled by the RelationProxy? + +class classproperty(object): + def __init__(self, func): + self.func = func + + def __get__(self, obj, objtype): + return self.func(objtype) diff --git a/plugins/bigquery/dbt/adapters/bigquery/__init__.py b/plugins/bigquery/dbt/adapters/bigquery/__init__.py index c456567722c..daff48a32ee 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/__init__.py +++ b/plugins/bigquery/dbt/adapters/bigquery/__init__.py @@ -1,7 +1,7 @@ from dbt.adapters.bigquery.connections import BigQueryConnectionManager # noqa from dbt.adapters.bigquery.connections import BigQueryCredentials from dbt.adapters.bigquery.relation import BigQueryRelation # noqa -from dbt.adapters.bigquery.relation import BigQueryColumn # noqa +from dbt.adapters.bigquery.column import BigQueryColumn # noqa from dbt.adapters.bigquery.impl import BigQueryAdapter from dbt.adapters.base import AdapterPlugin diff --git a/plugins/bigquery/dbt/adapters/bigquery/column.py b/plugins/bigquery/dbt/adapters/bigquery/column.py new file mode 100644 index 00000000000..8c8a442b412 --- /dev/null +++ b/plugins/bigquery/dbt/adapters/bigquery/column.py @@ -0,0 +1,121 @@ +from dataclasses import dataclass +from typing import Optional, List, TypeVar, Iterable, Type + +from dbt.adapters.base.column import Column + +from google.cloud.bigquery import SchemaField + +Self = TypeVar('Self', bound='BigQueryColumn') + + +@dataclass(init=False) +class BigQueryColumn(Column): + TYPE_LABELS = { + 'STRING': 'STRING', + 'TIMESTAMP': 'TIMESTAMP', + 'FLOAT': 'FLOAT64', + 'INTEGER': 'INT64', + 'RECORD': 'RECORD', + } + fields: List[Self] + mode: str + + def __init__( + self, + column: str, + dtype: str, + fields: Optional[Iterable[SchemaField]] = None, + mode: str = 'NULLABLE', + ) -> None: + super().__init__(column, dtype) + + if fields is None: + fields = [] + + self.fields = self.wrap_subfields(fields) + self.mode = mode + + @classmethod + def wrap_subfields( + cls: Type[Self], fields: Iterable[SchemaField] + ) -> List[Self]: + return [cls.create_from_field(field) for field in fields] + + @classmethod + def create_from_field(cls: Type[Self], field: SchemaField) -> Self: + return cls( + field.name, + cls.translate_type(field.field_type), + field.fields, + field.mode, + ) + + @classmethod + def _flatten_recursive( + cls: Type[Self], col: Self, prefix: Optional[str] = None + ) -> List[Self]: + if prefix is None: + prefix = [] + + if len(col.fields) == 0: + prefixed_name = ".".join(prefix + [col.column]) + new_col = cls(prefixed_name, col.dtype, col.fields, col.mode) + return [new_col] + + new_fields = [] + for field in col.fields: + new_prefix = prefix + [col.column] + new_fields.extend(cls._flatten_recursive(field, new_prefix)) + + return new_fields + + def flatten(self): + return self._flatten_recursive(self) + + @property + def quoted(self): + return '`{}`'.format(self.column) + + def literal(self, value): + return "cast({} as {})".format(value, self.dtype) + + @property + def data_type(self) -> str: + if self.dtype.upper() == 'RECORD': + subcols = [ + "{} {}".format(col.name, col.data_type) for col in self.fields + ] + field_type = 'STRUCT<{}>'.format(", ".join(subcols)) + + else: + field_type = self.dtype + + if self.mode.upper() == 'REPEATED': + return 'ARRAY<{}>'.format(field_type) + + else: + return field_type + + def is_string(self) -> bool: + return self.dtype.lower() == 'string' + + def is_numeric(self) -> bool: + return False + + def can_expand_to(self: Self, other_column: Self) -> bool: + """returns True if both columns are strings""" + return self.is_string() and other_column.is_string() + + def __repr__(self) -> str: + return "".format(self.name, self.data_type, + self.mode) + + def column_to_bq_schema(self) -> SchemaField: + """Convert a column to a bigquery schema object. + """ + kwargs = {} + if len(self.fields) > 0: + fields = [field.column_to_bq_schema() for field in self.fields] + kwargs = {"fields": fields} + + return SchemaField(self.name, self.dtype, self.mode, **kwargs) diff --git a/plugins/bigquery/dbt/adapters/bigquery/connections.py b/plugins/bigquery/dbt/adapters/bigquery/connections.py index 00df1a82ace..b1bf8b44e6c 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/connections.py +++ b/plugins/bigquery/dbt/adapters/bigquery/connections.py @@ -92,7 +92,7 @@ def exception_handler(self, sql): raise raise dbt.exceptions.RuntimeException(str(e)) - def cancel_open(self): + def cancel_open(self) -> None: pass @classmethod diff --git a/plugins/bigquery/dbt/adapters/bigquery/impl.py b/plugins/bigquery/dbt/adapters/bigquery/impl.py index fe793ad5fc8..596aff2a66c 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/impl.py +++ b/plugins/bigquery/dbt/adapters/bigquery/impl.py @@ -6,8 +6,10 @@ import dbt.clients.gcloud import dbt.clients.agate_helper -from dbt.adapters.base import BaseAdapter, available -from dbt.adapters.bigquery import BigQueryRelation +from dbt.adapters.base import BaseAdapter, available, RelationType +from dbt.adapters.bigquery.relation import ( + BigQueryRelation +) from dbt.adapters.bigquery import BigQueryColumn from dbt.adapters.bigquery import BigQueryConnectionManager from dbt.contracts.connection import Connection @@ -36,9 +38,9 @@ def _stub_relation(*args, **kwargs): class BigQueryAdapter(BaseAdapter): RELATION_TYPES = { - 'TABLE': BigQueryRelation.Table, - 'VIEW': BigQueryRelation.View, - 'EXTERNAL': BigQueryRelation.External + 'TABLE': RelationType.Table, + 'VIEW': RelationType.View, + 'EXTERNAL': RelationType.External } Relation = BigQueryRelation @@ -102,7 +104,7 @@ def get_columns_in_relation(self, relation): table = self.connections.get_bq_table( database=relation.database, schema=relation.schema, - identifier=relation.table_name + identifier=relation.identifier ) return self._get_dbt_columns_from_bq_table(table) diff --git a/plugins/bigquery/dbt/adapters/bigquery/relation.py b/plugins/bigquery/dbt/adapters/bigquery/relation.py index 8110adc4c7a..509ae8e4e98 100644 --- a/plugins/bigquery/dbt/adapters/bigquery/relation.py +++ b/plugins/bigquery/dbt/adapters/bigquery/relation.py @@ -1,60 +1,26 @@ -from dbt.adapters.base.relation import BaseRelation, Column -from dbt.utils import filter_null_values +from dataclasses import dataclass +from typing import Optional -import google.cloud.bigquery +from dbt.adapters.base.relation import ( + BaseRelation, ComponentName +) +from dbt.utils import filter_null_values +@dataclass(frozen=True, eq=False, repr=False) class BigQueryRelation(BaseRelation): - External = "external" - - DEFAULTS = { - 'metadata': { - 'type': 'BigQueryRelation' - }, - 'quote_character': '`', - 'quote_policy': { - 'database': True, - 'schema': True, - 'identifier': True, - }, - 'include_policy': { - 'database': True, - 'schema': True, - 'identifier': True, - }, - 'dbt_created': False, - } - - SCHEMA = { - 'type': 'object', - 'properties': { - 'metadata': { - 'type': 'object', - 'properties': { - 'type': { - 'type': 'string', - 'const': 'BigQueryRelation', - }, - }, - }, - 'type': { - 'enum': BaseRelation.RelationTypes + [External, None], - }, - 'path': BaseRelation.PATH_SCHEMA, - 'include_policy': BaseRelation.POLICY_SCHEMA, - 'quote_policy': BaseRelation.POLICY_SCHEMA, - 'quote_character': {'type': 'string'}, - 'dbt_created': {'type': 'boolean'}, - }, - 'required': ['metadata', 'type', 'path', 'include_policy', - 'quote_policy', 'quote_character', 'dbt_created'] - } - - def matches(self, database=None, schema=None, identifier=None): + quote_character: str = '`' + + def matches( + self, + database: Optional[str] = None, + schema: Optional[str] = None, + identifier: Optional[str] = None, + ) -> bool: search = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier + ComponentName.Database: database, + ComponentName.Schema: schema, + ComponentName.Identifier: identifier }) if not search: @@ -67,145 +33,10 @@ def matches(self, database=None, schema=None, identifier=None): return True - @classmethod - def create(cls, database=None, schema=None, - identifier=None, table_name=None, - type=None, **kwargs): - if table_name is None: - table_name = identifier - - return cls(type=type, - path={ - 'database': database, - 'schema': schema, - 'identifier': identifier - }, - table_name=table_name, - **kwargs) - - def quote(self, database=None, schema=None, identifier=None): - policy = filter_null_values({ - 'database': database, - 'schema': schema, - 'identifier': identifier - }) - - return self.incorporate(quote_policy=policy) - - @property - def database(self): - return self.path.get('database') - @property def project(self): - return self.path.get('database') - - @property - def schema(self): - return self.path.get('schema') + return self.database @property def dataset(self): - return self.path.get('schema') - - @property - def identifier(self): - return self.path.get('identifier') - - -class BigQueryColumn(Column): - TYPE_LABELS = { - 'STRING': 'STRING', - 'TIMESTAMP': 'TIMESTAMP', - 'FLOAT': 'FLOAT64', - 'INTEGER': 'INT64', - 'RECORD': 'RECORD', - } - - def __init__(self, column, dtype, fields=None, mode='NULLABLE'): - super().__init__(column, dtype) - - if fields is None: - fields = [] - - self.fields = self.wrap_subfields(fields) - self.mode = mode - - @classmethod - def wrap_subfields(cls, fields): - return [BigQueryColumn.create_from_field(field) for field in fields] - - @classmethod - def create_from_field(cls, field): - return BigQueryColumn(field.name, cls.translate_type(field.field_type), - field.fields, field.mode) - - @classmethod - def _flatten_recursive(cls, col, prefix=None): - if prefix is None: - prefix = [] - - if len(col.fields) == 0: - prefixed_name = ".".join(prefix + [col.column]) - new_col = BigQueryColumn(prefixed_name, col.dtype, col.fields, - col.mode) - return [new_col] - - new_fields = [] - for field in col.fields: - new_prefix = prefix + [col.column] - new_fields.extend(cls._flatten_recursive(field, new_prefix)) - - return new_fields - - def flatten(self): - return self._flatten_recursive(self) - - @property - def quoted(self): - return '`{}`'.format(self.column) - - def literal(self, value): - return "cast({} as {})".format(value, self.dtype) - - @property - def data_type(self): - if self.dtype.upper() == 'RECORD': - subcols = [ - "{} {}".format(col.name, col.data_type) for col in self.fields - ] - field_type = 'STRUCT<{}>'.format(", ".join(subcols)) - - else: - field_type = self.dtype - - if self.mode.upper() == 'REPEATED': - return 'ARRAY<{}>'.format(field_type) - - else: - return field_type - - def is_string(self): - return self.dtype.lower() == 'string' - - def is_numeric(self): - return False - - def can_expand_to(self, other_column): - """returns True if both columns are strings""" - return self.is_string() and other_column.is_string() - - def __repr__(self): - return "".format(self.name, self.data_type, - self.mode) - - def column_to_bq_schema(self): - """Convert a column to a bigquery schema object. - """ - kwargs = {} - if len(self.fields) > 0: - fields = [field.column_to_bq_schema() for field in self.fields] - kwargs = {"fields": fields} - - return google.cloud.bigquery.SchemaField(self.name, self.dtype, - self.mode, **kwargs) + return self.schema diff --git a/plugins/postgres/dbt/include/postgres/macros/adapters.sql b/plugins/postgres/dbt/include/postgres/macros/adapters.sql index aa4852c7e56..f8892741686 100644 --- a/plugins/postgres/dbt/include/postgres/macros/adapters.sql +++ b/plugins/postgres/dbt/include/postgres/macros/adapters.sql @@ -111,7 +111,6 @@ {% macro postgres__make_temp_relation(base_relation, suffix) %} {% set tmp_identifier = base_relation.identifier ~ suffix ~ py_current_timestring() %} {% do return(base_relation.incorporate( - table_name=tmp_identifier, path={ "identifier": tmp_identifier, "schema": none, diff --git a/plugins/snowflake/dbt/adapters/snowflake/relation.py b/plugins/snowflake/dbt/adapters/snowflake/relation.py index 0c6b8555484..217292d8d17 100644 --- a/plugins/snowflake/dbt/adapters/snowflake/relation.py +++ b/plugins/snowflake/dbt/adapters/snowflake/relation.py @@ -1,46 +1,14 @@ -from dbt.adapters.base.relation import BaseRelation +from dataclasses import dataclass +from dbt.adapters.base.relation import BaseRelation, Policy -class SnowflakeRelation(BaseRelation): - DEFAULTS = { - 'metadata': { - 'type': 'SnowflakeRelation' - }, - 'quote_character': '"', - 'quote_policy': { - 'database': False, - 'schema': False, - 'identifier': False, - }, - 'include_policy': { - 'database': True, - 'schema': True, - 'identifier': True, - }, - 'dbt_created': False, - } +@dataclass +class SnowflakeQuotePolicy(Policy): + database: bool = False + schema: bool = False + identifier: bool = False + - SCHEMA = { - 'type': 'object', - 'properties': { - 'metadata': { - 'type': 'object', - 'properties': { - 'type': { - 'type': 'string', - 'const': 'SnowflakeRelation', - }, - }, - }, - 'type': { - 'enum': BaseRelation.RelationTypes + [None], - }, - 'path': BaseRelation.PATH_SCHEMA, - 'include_policy': BaseRelation.POLICY_SCHEMA, - 'quote_policy': BaseRelation.POLICY_SCHEMA, - 'quote_character': {'type': 'string'}, - 'dbt_created': {'type': 'boolean'}, - }, - 'required': ['metadata', 'type', 'path', 'include_policy', - 'quote_policy', 'quote_character', 'dbt_created'] - } +@dataclass(frozen=True, eq=False, repr=False) +class SnowflakeRelation(BaseRelation): + quote_policy: SnowflakeQuotePolicy = SnowflakeQuotePolicy() diff --git a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql index 65bfe435680..993c539638c 100644 --- a/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql +++ b/plugins/snowflake/dbt/include/snowflake/macros/adapters.sql @@ -16,7 +16,7 @@ temporary {%- elif transient -%} transient - {%- endif %} table {{ relation }} {% if copy_grants and not temporary -%} copy grants {%- endif %} as + {%- endif %} table {{ relation }} {% if copy_grants and not temporary -%} copy grants {%- endif %} as ( {%- if cluster_by_string is not none -%} select * from( @@ -83,7 +83,7 @@ case when table_type = 'BASE TABLE' then 'table' when table_type = 'VIEW' then 'view' when table_type = 'MATERIALIZED VIEW' then 'materializedview' - when table_type = 'EXTERNAL TABLE' then 'externaltable' + when table_type = 'EXTERNAL TABLE' then 'external' else table_type end as table_type from {{ information_schema }}.tables diff --git a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py index 579772b854b..471cda24d04 100644 --- a/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py +++ b/test/integration/032_concurrent_transaction_test/test_concurrent_transaction.py @@ -99,6 +99,7 @@ def run_test(self): self.assertEqual(self.query_state['view_model'], 'good') self.assertEqual(self.query_state['model_1'], 'good') + class TableTestConcurrentTransaction(BaseTestConcurrentTransaction): @property def models(self): @@ -109,6 +110,7 @@ def test__redshift__concurrent_transaction_table(self): self.reset() self.run_test() + class ViewTestConcurrentTransaction(BaseTestConcurrentTransaction): @property def models(self): @@ -119,6 +121,7 @@ def test__redshift__concurrent_transaction_view(self): self.reset() self.run_test() + class IncrementalTestConcurrentTransaction(BaseTestConcurrentTransaction): @property def models(self): diff --git a/test/unit/test_bigquery_adapter.py b/test/unit/test_bigquery_adapter.py index f54050bef91..e7fd4a9228f 100644 --- a/test/unit/test_bigquery_adapter.py +++ b/test/unit/test_bigquery_adapter.py @@ -1,9 +1,10 @@ import unittest from unittest.mock import patch, MagicMock +import hologram + import dbt.flags as flags -from dbt.adapters.bigquery import BigQueryCredentials from dbt.adapters.bigquery import BigQueryAdapter from dbt.adapters.bigquery import BigQueryRelation import dbt.exceptions @@ -164,7 +165,7 @@ def setUp(self): self.mock_connection_manager = self.conn_manager_cls.return_value self.conn_manager_cls.TYPE = 'bigquery' - self.relation_cls.DEFAULTS = BigQueryRelation.DEFAULTS + self.relation_cls.get_default_quote_policy.side_effect = BigQueryRelation.get_default_quote_policy self.adapter = self.get_adapter('oauth') @@ -190,7 +191,7 @@ def test_drop_schema(self, mock_check_schema): def test_get_columns_in_relation(self): self.mock_connection_manager.get_bq_table.side_effect = ValueError self.adapter.get_columns_in_relation( - MagicMock(database='db', schema='schema', table_name='ident'), + MagicMock(database='db', schema='schema', identifier='ident'), ) self.mock_connection_manager.get_bq_table.assert_called_once_with( database='db', schema='schema', identifier='ident' @@ -209,12 +210,11 @@ def test_view_temp_relation(self): 'schema': 'test_schema', 'identifier': 'my_view' }, - 'table_name': 'my_view__dbt_tmp', 'quote_policy': { 'identifier': False } } - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_view_relation(self): kwargs = { @@ -224,13 +224,12 @@ def test_view_relation(self): 'schema': 'test_schema', 'identifier': 'my_view' }, - 'table_name': 'my_view', 'quote_policy': { 'identifier': True, 'schema': True } } - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_table_relation(self): kwargs = { @@ -240,13 +239,12 @@ def test_table_relation(self): 'schema': 'test_schema', 'identifier': 'generic_table' }, - 'table_name': 'generic_table', 'quote_policy': { 'identifier': True, 'schema': True } } - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_external_source_relation(self): kwargs = { @@ -256,13 +254,12 @@ def test_external_source_relation(self): 'schema': 'test_schema', 'identifier': 'sheet' }, - 'table_name': 'sheet', 'quote_policy': { 'identifier': True, 'schema': True } } - BigQueryRelation(**kwargs) + BigQueryRelation.from_dict(kwargs) def test_invalid_relation(self): kwargs = { @@ -272,11 +269,10 @@ def test_invalid_relation(self): 'schema': 'test_schema', 'identifier': 'my_invalid_id' }, - 'table_name': 'my_invalid_id', 'quote_policy': { 'identifier': False, 'schema': True } } - with self.assertRaises(dbt.exceptions.ValidationException): - BigQueryRelation(**kwargs) + with self.assertRaises(hologram.ValidationError): + BigQueryRelation.from_dict(kwargs) diff --git a/test/unit/test_cache.py b/test/unit/test_cache.py index 86e14915b5c..0c314350b9d 100644 --- a/test/unit/test_cache.py +++ b/test/unit/test_cache.py @@ -14,8 +14,7 @@ def make_relation(database, schema, identifier): def make_mock_relationship(database, schema, identifier): return BaseRelation.create( - database=database, schema=schema, identifier=identifier, - table_name=identifier, type='view' + database=database, schema=schema, identifier=identifier, type='view' ) diff --git a/test/unit/utils.py b/test/unit/utils.py index 51d7cf45d8e..b0675f7efa5 100644 --- a/test/unit/utils.py +++ b/test/unit/utils.py @@ -60,7 +60,6 @@ def inject_adapter(value): artisanal adapter will be available from get_adapter() as if dbt loaded it. """ from dbt.adapters import factory - from dbt.adapters.base.connections import BaseConnectionManager key = value.type() factory._ADAPTERS[key] = value factory.ADAPTER_TYPES[key] = type(value) diff --git a/tox.ini b/tox.ini index 67466513ed8..9ac72d6586c 100644 --- a/tox.ini +++ b/tox.ini @@ -12,8 +12,9 @@ deps = [testenv:mypy] basepython = python3.6 commands = /bin/bash -c '$(which mypy) \ - core/dbt/adapters/base/impl.py \ - core/dbt/adapters/base/meta.py \ + core/dbt/adapters/base \ + core/dbt/adapters/sql \ + core/dbt/adapters/cache.py \ core/dbt/clients \ core/dbt/config \ core/dbt/deprecations.py \