Merge pull request #83 from fishtown-analytics/feature/0.17.0

Upgrade to 0.17.0rc2

jtcohen6 authored May 22, 2020
2 parents 8a2b24e + 07f9bf8 commit a4d6874
Showing 12 changed files with 135 additions and 63 deletions.
7 changes: 4 additions & 3 deletions .bumpversion-dbt.cfg
@@ -1,18 +1,18 @@
 [bumpversion]
-current_version = 0.16.1
+current_version = 0.17.0rc2
 parse = (?P<major>\d+)
 	\.(?P<minor>\d+)
 	\.(?P<patch>\d+)
 	((?P<prerelease>[a-z]+)(?P<num>\d+))?
-serialize = 
+serialize =
 	{major}.{minor}.{patch}{prerelease}{num}
 	{major}.{minor}.{patch}
 commit = False
 tag = False
 
 [bumpversion:part:prerelease]
 first_value = a
-values = 
+values =
 	a
 	b
 	rc
@@ -23,3 +23,4 @@ first_value = 1
 [bumpversion:file:setup.py]
 
 [bumpversion:file:requirements.txt]
+
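As a quick sanity check (not part of the commit), the parse pattern above still matches the new prerelease scheme; a minimal sketch using Python's re module:

    import re

    # same pattern as the [bumpversion] parse setting above
    pattern = re.compile(
        r'(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)'
        r'((?P<prerelease>[a-z]+)(?P<num>\d+))?'
    )
    m = pattern.match('0.17.0rc2')
    assert m.group('prerelease') == 'rc' and m.group('num') == '2'
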
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.16.1
+current_version = 0.17.0rc2
 parse = (?P<major>\d+)
 	\.(?P<minor>\d+)
 	\.(?P<patch>\d+)
2 changes: 1 addition & 1 deletion dbt/adapters/spark/__version__.py
@@ -1 +1 @@
-version = "0.16.1"
+version = "0.17.0rc2"
57 changes: 42 additions & 15 deletions dbt/adapters/spark/connections.py
@@ -32,17 +32,27 @@ class SparkCredentials(Credentials):
     host: str
     method: SparkConnectionMethod
     schema: str
-    cluster: Optional[str]
-    token: Optional[str]
-    user: Optional[str]
     database: Optional[str]
+    cluster: Optional[str] = None
+    token: Optional[str] = None
+    user: Optional[str] = None
     port: int = 443
     organization: str = '0'
     connect_retries: int = 0
     connect_timeout: int = 10
 
+    def __post_init__(self):
+        # spark classifies database and schema as the same thing
+        if (
+            self.database is not None and
+            self.database != self.schema
+        ):
+            raise dbt.exceptions.RuntimeException(
+                f'    schema: {self.schema} \n'
+                f'    database: {self.database} \n'
+                f'On Spark, database must be omitted or have the same value as'
+                f' schema.'
+            )
+        self.database = self.schema
+
     @property
@@ -267,25 +277,42 @@ def open(cls, connection):
                 break
             except Exception as e:
                 exc = e
-                if getattr(e, 'message', None) is None:
-                    raise dbt.exceptions.FailedToConnectException(str(e))
-
-                message = e.message.lower()
-                is_pending = 'pending' in message
-                is_starting = 'temporarily_unavailable' in message
-
-                warning = "Warning: {}\n\tRetrying in {} seconds ({} of {})"
-                if is_pending or is_starting:
-                    msg = warning.format(e.message, creds.connect_timeout,
-                                         i, creds.connect_retries)
+                if isinstance(e, EOFError):
+                    # The user almost certainly has invalid credentials.
+                    # Perhaps a token expired, or something
+                    msg = 'Failed to connect'
+                    if creds.token is not None:
+                        msg += ', is your token valid?'
+                    raise dbt.exceptions.FailedToConnectException(msg) from e
+                retryable_message = _is_retryable_error(e)
+                if retryable_message:
+                    msg = (
+                        f"Warning: {retryable_message}\n\tRetrying in "
+                        f"{creds.connect_timeout} seconds "
+                        f"({i} of {creds.connect_retries})"
+                    )
                     logger.warning(msg)
                     time.sleep(creds.connect_timeout)
                 else:
-                    raise dbt.exceptions.FailedToConnectException(str(e))
+                    raise dbt.exceptions.FailedToConnectException(
+                        'failed to connect'
+                    ) from e
         else:
             raise exc
 
         handle = ConnectionWrapper(conn)
         connection.handle = handle
         connection.state = ConnectionState.OPEN
         return connection
 
 
+def _is_retryable_error(exc: Exception) -> Optional[str]:
+    message = getattr(exc, 'message', None)
+    if message is None:
+        return None
+    message = message.lower()
+    if 'pending' in message:
+        return exc.message
+    if 'temporarily_unavailable' in message:
+        return exc.message
+    return None
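To illustrate the retry classification added above, a self-contained sketch; FakeThriftError is a hypothetical stand-in for server errors that carry a .message attribute (for instance while a Databricks cluster is still starting):

    from typing import Optional

    def _is_retryable_error(exc: Exception) -> Optional[str]:
        # copy of the helper in the diff above
        message = getattr(exc, 'message', None)
        if message is None:
            return None
        message = message.lower()
        if 'pending' in message:
            return exc.message
        if 'temporarily_unavailable' in message:
            return exc.message
        return None

    class FakeThriftError(Exception):
        # hypothetical: real Thrift server errors expose .message
        def __init__(self, message):
            super().__init__(message)
            self.message = message

    assert _is_retryable_error(FakeThriftError('Cluster is in PENDING state'))
    assert _is_retryable_error(FakeThriftError('TEMPORARILY_UNAVAILABLE'))
    assert _is_retryable_error(ValueError('syntax error')) is None  # no .message
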
44 changes: 18 additions & 26 deletions dbt/adapters/spark/impl.py
@@ -1,11 +1,11 @@
-from typing import Optional, List, Dict, Any
+from dataclasses import dataclass
+from typing import Optional, List, Dict, Any, Union
 
 import agate
 import dbt.exceptions
 import dbt
-from dbt.adapters.base.relation import SchemaSearchMap
+from dbt.adapters.base import AdapterConfig
 from dbt.adapters.sql import SQLAdapter
-from dbt.node_types import NodeType
 
 from dbt.adapters.spark import SparkConnectionManager
 from dbt.adapters.spark import SparkRelation
@@ -25,6 +25,15 @@
 KEY_TABLE_STATISTICS = 'Statistics'
 
 
+@dataclass
+class SparkConfig(AdapterConfig):
+    file_format: str = 'parquet'
+    location_root: Optional[str] = None
+    partition_by: Optional[Union[List[str], str]] = None
+    clustered_by: Optional[Union[List[str], str]] = None
+    buckets: Optional[int] = None
+
+
 class SparkAdapter(SQLAdapter):
     COLUMN_NAMES = (
         'table_database',
@@ -52,10 +61,7 @@ class SparkAdapter(SQLAdapter):
     Relation = SparkRelation
     Column = SparkColumn
     ConnectionManager = SparkConnectionManager
-
-    AdapterSpecificConfigs = frozenset({"file_format", "location_root",
-                                        "partition_by", "clustered_by",
-                                        "buckets"})
+    AdapterSpecificConfigs = SparkConfig
 
     @classmethod
     def date_function(cls) -> str:
@@ -98,21 +104,22 @@ def add_schema_to_cache(self, schema) -> str:
         return ''
 
     def list_relations_without_caching(
-        self, information_schema, schema
+        self, schema_relation: SparkRelation
     ) -> List[SparkRelation]:
-        kwargs = {'information_schema': information_schema, 'schema': schema}
+        kwargs = {'schema_relation': schema_relation}
         try:
             results = self.execute_macro(
                 LIST_RELATIONS_MACRO_NAME,
                 kwargs=kwargs,
                 release=True
             )
         except dbt.exceptions.RuntimeException as e:
-            if hasattr(e, 'msg') and f"Database '{schema}' not found" in e.msg:
+            errmsg = getattr(e, 'msg', '')
+            if f"Database '{schema_relation}' not found" in errmsg:
                 return []
             else:
                 description = "Error while retrieving information about"
-                logger.debug(f"{description} {schema}: {e.msg}")
+                logger.debug(f"{description} {schema_relation}: {e.msg}")
                 return []
 
         relations = []
@@ -279,21 +286,6 @@ def _get_catalog_for_relations(self, database: str, schema: str):
         )
         return agate.Table.from_object(columns)
 
-    def _get_cache_schemas(self, manifest, exec_only=False):
-        info_schema_name_map = SchemaSearchMap()
-        for node in manifest.nodes.values():
-            if exec_only and node.resource_type not in NodeType.executable():
-                continue
-            relation = self.Relation.create(
-                database=node.database,
-                schema=node.schema,
-                identifier='information_schema',
-                quote_policy=self.config.quoting,
-            )
-            key = relation.information_schema_only()
-            info_schema_name_map[key] = {node.schema}
-        return info_schema_name_map
-
     def _get_one_catalog(
         self, information_schema, schemas, manifest,
     ) -> agate.Table:
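The move from a frozenset of config names to a typed dataclass means adapter-specific model configs now carry declared types and defaults. A minimal sketch of the effect (SparkConfig reproduced standalone here, without the AdapterConfig base):

    from dataclasses import dataclass
    from typing import List, Optional, Union

    @dataclass
    class SparkConfig:
        file_format: str = 'parquet'
        location_root: Optional[str] = None
        partition_by: Optional[Union[List[str], str]] = None
        clustered_by: Optional[Union[List[str], str]] = None
        buckets: Optional[int] = None

    cfg = SparkConfig(partition_by=['dt'], buckets=8)
    assert cfg.file_format == 'parquet'  # default applies when a model sets nothing
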
21 changes: 21 additions & 0 deletions dbt/adapters/spark/relation.py
@@ -1,6 +1,7 @@
 from dataclasses import dataclass
 
 from dbt.adapters.base.relation import BaseRelation, Policy
+from dbt.exceptions import RuntimeException
 
 
 @dataclass
@@ -22,3 +23,23 @@ class SparkRelation(BaseRelation):
     quote_policy: SparkQuotePolicy = SparkQuotePolicy()
     include_policy: SparkIncludePolicy = SparkIncludePolicy()
     quote_character: str = '`'
+
+    def __post_init__(self):
+        # some core things set database='', which we should ignore.
+        if self.database and self.database != self.schema:
+            raise RuntimeException(
+                f'Error while parsing relation {self.name}: \n'
+                f'    identifier: {self.identifier} \n'
+                f'    schema: {self.schema} \n'
+                f'    database: {self.database} \n'
+                f'On Spark, database should not be set. Use the schema '
+                f'config to set a custom schema/database for this relation.'
+            )
+
+    def render(self):
+        if self.include_policy.database and self.include_policy.schema:
+            raise RuntimeException(
+                'Got a spark relation with schema and database set to '
+                'include, but only one can be set'
+            )
+        return super().render()
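A rough sketch of the new relation contract (hypothetical usage; assumes a dbt environment with this adapter installed):

    from dbt.adapters.spark import SparkRelation

    # fine: schema and identifier only; SparkIncludePolicy omits database,
    # so this renders roughly as analytics.events (backtick-quoted if the
    # quote policy enables it)
    rel = SparkRelation.create(schema='analytics', identifier='events')

    # raises: __post_init__ rejects a database that differs from the schema
    try:
        SparkRelation.create(
            database='warehouse', schema='analytics', identifier='events'
        )
    except Exception:
        pass
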
2 changes: 1 addition & 1 deletion dbt/include/spark/dbt_project.yml
@@ -1,5 +1,5 @@
-
 name: dbt_spark
 version: 1.0
+config-version: 2
 
 macro-paths: ["macros"]
14 changes: 7 additions & 7 deletions dbt/include/spark/macros/adapters.sql
@@ -20,7 +20,7 @@
     {%- if raw_persist_docs is mapping -%}
       {%- set raw_relation = raw_persist_docs.get('relation', false) -%}
       {%- if raw_relation -%}
-      comment '{{ model.description }}'
+      comment '{{ model.description | replace("'", "\\'") }}'
       {% endif %}
     {%- else -%}
       {{ exceptions.raise_compiler_error("Invalid value provided for 'persist_docs'. Expected dict but got value: " ~ raw_persist_docs) }}
@@ -96,15 +96,15 @@
   {{ sql }}
 {% endmacro %}
 
-{% macro spark__create_schema(database_name, schema_name) -%}
+{% macro spark__create_schema(relation) -%}
   {%- call statement('create_schema') -%}
-    create schema if not exists {{schema_name}}
+    create schema if not exists {{relation}}
   {% endcall %}
 {% endmacro %}
 
-{% macro spark__drop_schema(database_name, schema_name) -%}
+{% macro spark__drop_schema(relation) -%}
   {%- call statement('drop_schema') -%}
-    drop schema if exists {{ schema_name }} cascade
+    drop schema if exists {{ relation }} cascade
   {%- endcall -%}
 {% endmacro %}
@@ -115,9 +115,9 @@
   {% do return(load_result('get_columns_in_relation').table) %}
 {% endmacro %}
 
-{% macro spark__list_relations_without_caching(information_schema, schema) %}
+{% macro spark__list_relations_without_caching(relation) %}
   {% call statement('list_relations_without_caching', fetch_result=True) -%}
-    show table extended in {{ schema }} like '*'
+    show table extended in {{ relation }} like '*'
   {% endcall %}
 
   {% do return(load_result('list_relations_without_caching').table) %}
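The comment-escaping fix above mirrors Jinja's replace filter; the transformation in plain Python (a sketch, not part of the commit):

    description = "Joe's favorite model"
    escaped = description.replace("'", "\\'")
    assert escaped == "Joe\\'s favorite model"
    # rendered DDL becomes: comment 'Joe\'s favorite model'
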
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
-dbt-core==0.16.1
+dbt-core==0.17.0rc2
 PyHive[hive]>=0.6.0,<0.7.0
 thrift>=0.11.0,<0.12.0
4 changes: 2 additions & 2 deletions setup.py
@@ -28,9 +28,9 @@ def _dbt_spark_version():
 package_version = _dbt_spark_version()
 description = """The SparkSQL plugin for dbt (data build tool)"""
 
-dbt_version = '0.16.1'
+dbt_version = '0.17.0rc2'
 # the package version should be the dbt version, with maybe some things on the
-# ends of it. (0.16.1 vs 0.16.1a1, 0.16.1.1, ...)
+# ends of it. (0.17.0rc2 vs 0.17.0rc2a1, 0.17.0rc2.1, ...)
 if not package_version.startswith(dbt_version):
     raise ValueError(
         f'Invalid setup.py: package_version={package_version} must start with '
33 changes: 31 additions & 2 deletions test/unit/test_adapter.py
@@ -2,6 +2,7 @@
 from unittest import mock
 
 import dbt.flags as flags
+from dbt.exceptions import RuntimeException
 from agate import Row
 from pyhive import hive
 from dbt.adapters.spark import SparkAdapter, SparkRelation
@@ -101,7 +102,6 @@ def test_parse_relation(self):
         rel_type = SparkRelation.get_relation_type.Table
 
         relation = SparkRelation.create(
-            database='default_database',
             schema='default_schema',
             identifier='mytable',
             type=rel_type
@@ -182,7 +182,6 @@ def test_parse_relation_with_statistics(self):
         rel_type = SparkRelation.get_relation_type.Table
 
         relation = SparkRelation.create(
-            database='default_database',
             schema='default_schema',
             identifier='mytable',
             type=rel_type
@@ -236,3 +235,33 @@ def test_parse_relation_with_statistics(self):
             'stats:rows:label': 'rows',
             'stats:rows:value': 14093476,
         })
+
+    def test_relation_with_database(self):
+        config = self._get_target_http(self.project_cfg)
+        adapter = SparkAdapter(config)
+        # fine
+        adapter.Relation.create(schema='different', identifier='table')
+        with self.assertRaises(RuntimeException):
+            # not fine - database set
+            adapter.Relation.create(database='something', schema='different', identifier='table')
+
+    def test_profile_with_database(self):
+        profile = {
+            'outputs': {
+                'test': {
+                    'type': 'spark',
+                    'method': 'http',
+                    # not allowed
+                    'database': 'analytics2',
+                    'schema': 'analytics',
+                    'host': 'myorg.sparkhost.com',
+                    'port': 443,
+                    'token': 'abc123',
+                    'organization': '0123456789',
+                    'cluster': '01234-23423-coffeetime',
+                }
+            },
+            'target': 'test'
+        }
+        with self.assertRaises(RuntimeException):
+            config_from_parts_or_dicts(self.project_cfg, profile)