diff --git a/.bumpversion-dbt.cfg b/.bumpversion-dbt.cfg
index 62907f23b..a169d8179 100644
--- a/.bumpversion-dbt.cfg
+++ b/.bumpversion-dbt.cfg
@@ -1,10 +1,10 @@
 [bumpversion]
-current_version = 0.16.1
+current_version = 0.17.0rc2
 parse = (?P<major>\d+)
     \.(?P<minor>\d+)
     \.(?P<patch>\d+)
     ((?P<prerelease>[a-z]+)(?P<num>\d+))?
-serialize = 
+serialize =
     {major}.{minor}.{patch}{prerelease}{num}
     {major}.{minor}.{patch}
 commit = False
@@ -12,7 +12,7 @@ tag = False
 [bumpversion:part:prerelease]
 first_value = a
-values = 
+values =
     a
     b
     rc
@@ -23,3 +23,4 @@ first_value = 1
 
 [bumpversion:file:setup.py]
 [bumpversion:file:requirements.txt]
+
diff --git a/.bumpversion.cfg b/.bumpversion.cfg
index e21d4d971..5c3ccfa24 100644
--- a/.bumpversion.cfg
+++ b/.bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.16.1
+current_version = 0.17.0rc2
 parse = (?P<major>\d+)
     \.(?P<minor>\d+)
     \.(?P<patch>\d+)
diff --git a/dbt/adapters/spark/__version__.py b/dbt/adapters/spark/__version__.py
index bbbab11e4..ef664b00e 100644
--- a/dbt/adapters/spark/__version__.py
+++ b/dbt/adapters/spark/__version__.py
@@ -1 +1 @@
-version = "0.16.1"
+version = "0.17.0rc2"
diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index 0a3110bd4..2c152fc82 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -32,10 +32,10 @@ class SparkCredentials(Credentials):
     host: str
     method: SparkConnectionMethod
     schema: str
-    cluster: Optional[str]
-    token: Optional[str]
-    user: Optional[str]
     database: Optional[str]
+    cluster: Optional[str] = None
+    token: Optional[str] = None
+    user: Optional[str] = None
     port: int = 443
     organization: str = '0'
     connect_retries: int = 0
@@ -43,6 +43,16 @@
 
     def __post_init__(self):
         # spark classifies database and schema as the same thing
+        if (
+            self.database is not None and
+            self.database != self.schema
+        ):
+            raise dbt.exceptions.RuntimeException(
+                f'    schema: {self.schema} \n'
+                f'    database: {self.database} \n'
+                f'On Spark, database must be omitted or have the same value as'
+                f' schema.'
+            )
         self.database = self.schema
 
     @property
@@ -267,21 +277,26 @@ def open(cls, connection):
                 break
             except Exception as e:
                 exc = e
-                if getattr(e, 'message', None) is None:
-                    raise dbt.exceptions.FailedToConnectException(str(e))
-
-                message = e.message.lower()
-                is_pending = 'pending' in message
-                is_starting = 'temporarily_unavailable' in message
-
-                warning = "Warning: {}\n\tRetrying in {} seconds ({} of {})"
-                if is_pending or is_starting:
-                    msg = warning.format(e.message, creds.connect_timeout,
-                                         i, creds.connect_retries)
+                if isinstance(e, EOFError):
+                    # The user almost certainly has invalid credentials.
+                    # Perhaps a token expired, or something
+                    msg = 'Failed to connect'
+                    if creds.token is not None:
+                        msg += ', is your token valid?'
+                    raise dbt.exceptions.FailedToConnectException(msg) from e
+
+                retryable_message = _is_retryable_error(e)
+                if retryable_message:
+                    msg = (
+                        f"Warning: {retryable_message}\n\tRetrying in "
+                        f"{creds.connect_timeout} seconds "
+                        f"({i} of {creds.connect_retries})"
+                    )
                     logger.warning(msg)
                     time.sleep(creds.connect_timeout)
                 else:
-                    raise dbt.exceptions.FailedToConnectException(str(e))
+                    raise dbt.exceptions.FailedToConnectException(
+                        'failed to connect'
+                    ) from e
         else:
             raise exc
@@ -289,3 +304,15 @@
         connection.handle = handle
         connection.state = ConnectionState.OPEN
         return connection
+
+
+def _is_retryable_error(exc: Exception) -> Optional[str]:
+    message = getattr(exc, 'message', None)
+    if message is None:
+        return None
+    message = message.lower()
+    if 'pending' in message:
+        return exc.message
+    if 'temporarily_unavailable' in message:
+        return exc.message
+    return None
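
(Not part of the diff: a quick sketch of how the new retry classification behaves. FakeThriftError and the messages below are made up for illustration; only _is_retryable_error comes from the change above.)

from dbt.adapters.spark.connections import _is_retryable_error

# Hypothetical stand-in for the thrift/PyHive errors that carry a `message`
# attribute, which is what the helper inspects.
class FakeThriftError(Exception):
    def __init__(self, message):
        super().__init__(message)
        self.message = message

# A cluster that is still starting up is treated as retryable...
assert _is_retryable_error(FakeThriftError('Cluster is in PENDING state'))
assert _is_retryable_error(FakeThriftError('TEMPORARILY_UNAVAILABLE: try later'))
# ...while errors without a recognizable `message` fail fast.
assert _is_retryable_error(ValueError('bad credentials')) is None
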
diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py
index 6140be650..95ba44841 100644
--- a/dbt/adapters/spark/impl.py
+++ b/dbt/adapters/spark/impl.py
@@ -1,11 +1,11 @@
-from typing import Optional, List, Dict, Any
+from dataclasses import dataclass
+from typing import Optional, List, Dict, Any, Union
 
 import agate
 import dbt.exceptions
 import dbt
-from dbt.adapters.base.relation import SchemaSearchMap
+from dbt.adapters.base import AdapterConfig
 from dbt.adapters.sql import SQLAdapter
-from dbt.node_types import NodeType
 
 from dbt.adapters.spark import SparkConnectionManager
 from dbt.adapters.spark import SparkRelation
@@ -25,6 +25,15 @@
 KEY_TABLE_STATISTICS = 'Statistics'
 
 
+@dataclass
+class SparkConfig(AdapterConfig):
+    file_format: str = 'parquet'
+    location_root: Optional[str] = None
+    partition_by: Optional[Union[List[str], str]] = None
+    clustered_by: Optional[Union[List[str], str]] = None
+    buckets: Optional[int] = None
+
+
 class SparkAdapter(SQLAdapter):
     COLUMN_NAMES = (
         'table_database',
@@ -52,10 +61,7 @@ class SparkAdapter(SQLAdapter):
     Relation = SparkRelation
     Column = SparkColumn
     ConnectionManager = SparkConnectionManager
-
-    AdapterSpecificConfigs = frozenset({"file_format", "location_root",
-                                        "partition_by", "clustered_by",
-                                        "buckets"})
+    AdapterSpecificConfigs = SparkConfig
 
     @classmethod
     def date_function(cls) -> str:
@@ -98,9 +104,9 @@ def add_schema_to_cache(self, schema) -> str:
         return ''
 
     def list_relations_without_caching(
-        self, information_schema, schema
+        self, schema_relation: SparkRelation
     ) -> List[SparkRelation]:
-        kwargs = {'information_schema': information_schema, 'schema': schema}
+        kwargs = {'schema_relation': schema_relation}
         try:
             results = self.execute_macro(
                 LIST_RELATIONS_MACRO_NAME,
@@ -108,11 +114,12 @@ def list_relations_without_caching(
                 release=True
             )
         except dbt.exceptions.RuntimeException as e:
-            if hasattr(e, 'msg') and f"Database '{schema}' not found" in e.msg:
+            errmsg = getattr(e, 'msg', '')
+            if f"Database '{schema_relation}' not found" in errmsg:
                 return []
             else:
                 description = "Error while retrieving information about"
-                logger.debug(f"{description} {schema}: {e.msg}")
+                logger.debug(f"{description} {schema_relation}: {e.msg}")
                 return []
 
         relations = []
@@ -279,21 +286,6 @@ def _get_catalog_for_relations(self, database: str, schema: str):
         )
         return agate.Table.from_object(columns)
 
-    def _get_cache_schemas(self, manifest, exec_only=False):
-        info_schema_name_map = SchemaSearchMap()
-        for node in manifest.nodes.values():
-            if exec_only and node.resource_type not in NodeType.executable():
-                continue
-            relation = self.Relation.create(
-                database=node.database,
-                schema=node.schema,
-                identifier='information_schema',
-                quote_policy=self.config.quoting,
-            )
-            key = relation.information_schema_only()
-            info_schema_name_map[key] = {node.schema}
-        return info_schema_name_map
-
     def _get_one_catalog(
         self, information_schema, schemas, manifest,
     ) -> agate.Table:
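
(Not part of the diff: the typed SparkConfig replaces the old frozenset, so the adapter-specific model settings and their defaults can now be inspected directly. A minimal sketch; the printout will also include whatever fields AdapterConfig itself defines.)

import dataclasses

from dbt.adapters.spark.impl import SparkConfig

# List the Spark-specific model configs with their types and defaults:
# file_format defaults to 'parquet'; location_root, partition_by,
# clustered_by and buckets default to None.
for field in dataclasses.fields(SparkConfig):
    print(field.name, field.type, field.default)
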
diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py
index 92473d55e..2106e5cba 100644
--- a/dbt/adapters/spark/relation.py
+++ b/dbt/adapters/spark/relation.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 
 from dbt.adapters.base.relation import BaseRelation, Policy
+from dbt.exceptions import RuntimeException
 
 
 @dataclass
@@ -22,3 +23,23 @@ class SparkRelation(BaseRelation):
     quote_policy: SparkQuotePolicy = SparkQuotePolicy()
     include_policy: SparkIncludePolicy = SparkIncludePolicy()
     quote_character: str = '`'
+
+    def __post_init__(self):
+        # some core things set database='', which we should ignore.
+        if self.database and self.database != self.schema:
+            raise RuntimeException(
+                f'Error while parsing relation {self.name}: \n'
+                f'    identifier: {self.identifier} \n'
+                f'    schema: {self.schema} \n'
+                f'    database: {self.database} \n'
+                f'On Spark, database should not be set. Use the schema '
+                f'config to set a custom schema/database for this relation.'
+            )
+
+    def render(self):
+        if self.include_policy.database and self.include_policy.schema:
+            raise RuntimeException(
+                'Got a spark relation with schema and database set to '
+                'include, but only one can be set'
+            )
+        return super().render()
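
(Not part of the diff: a sketch of the new validation, mirroring the unit tests added further down; the schema and table names are placeholders.)

from dbt.adapters.spark import SparkRelation
from dbt.exceptions import RuntimeException

# A schema-qualified relation is fine; the database is never rendered.
rel = SparkRelation.create(schema='analytics', identifier='events')
print(rel.render())

# Passing a database that differs from the schema now fails at creation time.
try:
    SparkRelation.create(database='raw', schema='analytics', identifier='events')
except RuntimeException as exc:
    print(exc)
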
diff --git a/dbt/include/spark/dbt_project.yml b/dbt/include/spark/dbt_project.yml
index 2294c23d1..36d69b415 100644
--- a/dbt/include/spark/dbt_project.yml
+++ b/dbt/include/spark/dbt_project.yml
@@ -1,5 +1,5 @@
-
 name: dbt_spark
 version: 1.0
+config-version: 2
 
 macro-paths: ["macros"]
diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql
index 3e6ce2369..d0e11fca6 100644
--- a/dbt/include/spark/macros/adapters.sql
+++ b/dbt/include/spark/macros/adapters.sql
@@ -20,7 +20,7 @@
   {%- if raw_persist_docs is mapping -%}
     {%- set raw_relation = raw_persist_docs.get('relation', false) -%}
       {%- if raw_relation -%}
-      comment '{{ model.description }}'
+      comment '{{ model.description | replace("'", "\\'") }}'
     {% endif %}
   {%- else -%}
     {{ exceptions.raise_compiler_error("Invalid value provided for 'persist_docs'. Expected dict but got value: " ~ raw_persist_docs) }}
@@ -96,15 +96,15 @@
   {{ sql }}
 {% endmacro %}
 
-{% macro spark__create_schema(database_name, schema_name) -%}
+{% macro spark__create_schema(relation) -%}
   {%- call statement('create_schema') -%}
-    create schema if not exists {{schema_name}}
+    create schema if not exists {{relation}}
   {% endcall %}
 {% endmacro %}
 
-{% macro spark__drop_schema(database_name, schema_name) -%}
+{% macro spark__drop_schema(relation) -%}
   {%- call statement('drop_schema') -%}
-    drop schema if exists {{ schema_name }} cascade
+    drop schema if exists {{ relation }} cascade
   {%- endcall -%}
 {% endmacro %}
@@ -115,9 +115,9 @@
   {% do return(load_result('get_columns_in_relation').table) %}
 {% endmacro %}
 
-{% macro spark__list_relations_without_caching(information_schema, schema) %}
+{% macro spark__list_relations_without_caching(relation) %}
   {% call statement('list_relations_without_caching', fetch_result=True) -%}
-    show table extended in {{ schema }} like '*'
+    show table extended in {{ relation }} like '*'
   {% endcall %}
 
   {% do return(load_result('list_relations_without_caching').table) %}
diff --git a/requirements.txt b/requirements.txt
index 2f2c177d9..219e52ff3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
-dbt-core==0.16.1
+dbt-core==0.17.0rc2
 PyHive[hive]>=0.6.0,<0.7.0
 thrift>=0.11.0,<0.12.0
diff --git a/setup.py b/setup.py
index 9eb520f6f..00c26e511 100644
--- a/setup.py
+++ b/setup.py
@@ -28,9 +28,9 @@ def _dbt_spark_version():
 package_version = _dbt_spark_version()
 description = """The SparkSQL plugin for dbt (data build tool)"""
 
-dbt_version = '0.16.1'
+dbt_version = '0.17.0rc2'
 # the package version should be the dbt version, with maybe some things on the
-# ends of it. (0.16.1 vs 0.16.1a1, 0.16.1.1, ...)
+# ends of it. (0.17.0rc2 vs 0.17.0rc2a1, 0.17.0rc2.1, ...)
 if not package_version.startswith(dbt_version):
     raise ValueError(
         f'Invalid setup.py: package_version={package_version} must start with '
diff --git a/test/unit/test_adapter.py b/test/unit/test_adapter.py
index 72e78ea5c..e453c12b1 100644
--- a/test/unit/test_adapter.py
+++ b/test/unit/test_adapter.py
@@ -2,6 +2,7 @@
 from unittest import mock
 
 import dbt.flags as flags
+from dbt.exceptions import RuntimeException
 from agate import Row
 from pyhive import hive
 from dbt.adapters.spark import SparkAdapter, SparkRelation
@@ -101,7 +102,6 @@ def test_parse_relation(self):
         rel_type = SparkRelation.get_relation_type.Table
 
         relation = SparkRelation.create(
-            database='default_database',
             schema='default_schema',
             identifier='mytable',
             type=rel_type
@@ -182,7 +182,6 @@ def test_parse_relation_with_statistics(self):
         rel_type = SparkRelation.get_relation_type.Table
 
         relation = SparkRelation.create(
-            database='default_database',
             schema='default_schema',
             identifier='mytable',
             type=rel_type
@@ -236,3 +235,33 @@ def test_parse_relation_with_statistics(self):
             'stats:rows:label': 'rows',
             'stats:rows:value': 14093476,
         })
+
+    def test_relation_with_database(self):
+        config = self._get_target_http(self.project_cfg)
+        adapter = SparkAdapter(config)
+        # fine
+        adapter.Relation.create(schema='different', identifier='table')
+        with self.assertRaises(RuntimeException):
+            # not fine - database set
+            adapter.Relation.create(database='something', schema='different', identifier='table')
+
+    def test_profile_with_database(self):
+        profile = {
+            'outputs': {
+                'test': {
+                    'type': 'spark',
+                    'method': 'http',
+                    # not allowed
+                    'database': 'analytics2',
+                    'schema': 'analytics',
+                    'host': 'myorg.sparkhost.com',
+                    'port': 443,
+                    'token': 'abc123',
+                    'organization': '0123456789',
+                    'cluster': '01234-23423-coffeetime',
+                }
+            },
+            'target': 'test'
+        }
+        with self.assertRaises(RuntimeException):
+            config_from_parts_or_dicts(self.project_cfg, profile)
diff --git a/test/unit/utils.py b/test/unit/utils.py
index affb6c375..53630bba0 100644
--- a/test/unit/utils.py
+++ b/test/unit/utils.py
@@ -35,13 +35,14 @@ def mock_connection(name):
 
 
 def profile_from_dict(profile, profile_name, cli_vars='{}'):
-    from dbt.config import Profile, ConfigRenderer
+    from dbt.config import Profile
+    from dbt.config.renderer import ProfileRenderer
     from dbt.context.base import generate_base_context
     from dbt.utils import parse_cli_vars
     if not isinstance(cli_vars, dict):
         cli_vars = parse_cli_vars(cli_vars)
 
-    renderer = ConfigRenderer(generate_base_context(cli_vars))
+    renderer = ProfileRenderer(generate_base_context(cli_vars))
     return Profile.from_raw_profile_info(
         profile,
         profile_name,
@@ -51,12 +52,13 @@ def profile_from_dict(profile, profile_name, cli_vars='{}'):
 
 def project_from_dict(project, profile, packages=None, cli_vars='{}'):
     from dbt.context.target import generate_target_context
-    from dbt.config import Project, ConfigRenderer
+    from dbt.config import Project
+    from dbt.config.renderer import DbtProjectYamlRenderer
     from dbt.utils import parse_cli_vars
     if not isinstance(cli_vars, dict):
         cli_vars = parse_cli_vars(cli_vars)
 
-    renderer = ConfigRenderer(generate_target_context(profile, cli_vars))
+    renderer = DbtProjectYamlRenderer(generate_target_context(profile, cli_vars))
 
     project_root = project.pop('project-root', os.getcwd())