Use builtin timestampadd and timestampdiff functions for dateadd/datediff macros if available. #185

Merged · 9 commits · Sep 26, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
### Features
- Support python model through run command API, currently supported materializations are table and incremental. ([dbt-labs/dbt-spark#377](https://github.com/dbt-labs/dbt-spark/pull/377), [#126](https://github.com/databricks/dbt-databricks/pull/126))
- Enable Pandas and Pandas-on-Spark DataFrames for dbt python models ([dbt-labs/dbt-spark#469](https://github.com/dbt-labs/dbt-spark/pull/469), [#181](https://github.com/databricks/dbt-databricks/pull/181))
- Use builtin timestampadd and timestampdiff functions for dateadd/datediff macros if available ([#185](https://github.com/databricks/dbt-databricks/pull/185))
- Implement testing for various Python models ([#189](https://github.com/databricks/dbt-databricks/pull/189))
- Implement testing for `type_boolean` in Databricks ([dbt-labs/dbt-spark#471](https://github.com/dbt-labs/dbt-spark/pull/471), [#188](https://github.com/databricks/dbt-databricks/pull/188))

41 changes: 39 additions & 2 deletions dbt/adapters/databricks/connections.py
@@ -5,6 +5,7 @@
import itertools
import os
import re
import sys
import time
from typing import (
Any,
@@ -50,6 +51,7 @@
logger = AdapterLogger("Databricks")

CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog"
DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")
DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV"
DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")
EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)")
@@ -176,10 +178,12 @@ class DatabricksSQLConnectionWrapper:
"""Wrap a Databricks SQL connector in a way that no-ops transactions"""

_conn: DatabricksSQLConnection
_is_cluster: bool
_cursors: List[DatabricksSQLCursor]

def __init__(self, conn: DatabricksSQLConnection):
def __init__(self, conn: DatabricksSQLConnection, *, is_cluster: bool):
self._conn = conn
self._is_cluster = is_cluster
self._cursors = []

def cursor(self) -> "DatabricksSQLCursorWrapper":
@@ -207,6 +211,30 @@ def close(self) -> None:
def rollback(self, *args: Any, **kwargs: Any) -> None:
logger.debug("NotImplemented: rollback")

_dbr_version: Tuple[int, int]

@property
def dbr_version(self) -> Tuple[int, int]:
if not hasattr(self, "_dbr_version"):
if self._is_cluster:
with self._conn.cursor() as cursor:
cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion")
dbr_version: str = cursor.fetchone()[1]

m = DBR_VERSION_REGEX.match(dbr_version)
assert m, f"Unknown DBR version: {dbr_version}"
major = int(m.group(1))
try:
minor = int(m.group(2))
except ValueError:
minor = sys.maxsize
self._dbr_version = (major, minor)
else:
# Assuming SQL Warehouse uses the latest version.
self._dbr_version = (sys.maxsize, sys.maxsize)

return self._dbr_version


class DatabricksSQLCursorWrapper:
"""Wrap a Databricks SQL cursor in a way that no-ops transactions"""
@@ -313,6 +341,13 @@ def _get_comment_macro(self) -> Optional[str]:
class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"

def compare_dbr_version(self, major: int, minor: int) -> int:
version = (major, minor)

connection: DatabricksSQLConnectionWrapper = self.get_thread_connection().handle
dbr_version = connection.dbr_version
return (dbr_version > version) - (dbr_version < version)

def set_query_header(self, manifest: Manifest) -> None:
self.query_header = DatabricksMacroQueryStringSetter(self.profile, manifest)

@@ -483,7 +518,9 @@ def open(cls, connection: Connection) -> Connection:
_user_agent_entry=user_agent_entry,
**connection_parameters,
)
handle = DatabricksSQLConnectionWrapper(conn)
handle = DatabricksSQLConnectionWrapper(
conn, is_cluster=creds.cluster_id is not None
)
break
except Exception as e:
exc = e
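
For reference, a self-contained sketch of the version handling added above: the regex parse of the cluster's Spark version tag and the tuple-based three-way comparison behind compare_dbr_version. The sample tag strings are assumptions about the sparkVersion format; only the parsing and comparison logic mirror the diff.

import re
import sys

DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")

def parse_dbr_version(raw: str) -> tuple:
    # raw comes from `SET spark.databricks.clusterUsageTags.sparkVersion`;
    # a value like "10.4.x-scala2.12" is assumed here purely for illustration
    m = DBR_VERSION_REGEX.match(raw)
    assert m, f"Unknown DBR version: {raw}"
    major = int(m.group(1))
    try:
        minor = int(m.group(2))
    except ValueError:
        minor = sys.maxsize  # an "x" minor version is treated as the highest possible
    return (major, minor)

def compare(dbr_version: tuple, version: tuple) -> int:
    # three-way comparison via tuple ordering: negative, zero, or positive
    return (dbr_version > version) - (dbr_version < version)

print(compare(parse_dbr_version("10.4.x-scala2.12"), (10, 4)))  # 0  -> builtin timestampadd is usable
print(compare(parse_dbr_version("9.1.x-scala2.12"), (10, 4)))   # -1 -> fall back to spark__dateadd
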
14 changes: 14 additions & 0 deletions dbt/adapters/databricks/impl.py
@@ -8,6 +8,7 @@

from dbt.adapters.base import AdapterConfig, PythonJobHelper
from dbt.adapters.base.impl import catch_as_completed
from dbt.adapters.base.meta import available
from dbt.adapters.base.relation import BaseRelation
from dbt.adapters.spark.impl import (
SparkAdapter,
@@ -59,6 +60,19 @@ class DatabricksAdapter(SparkAdapter):

AdapterSpecificConfigs = DatabricksConfig

@available.parse(lambda *a, **k: 0)
Collaborator:
Just curious, why do we need this @available.parse?

ueshin (Collaborator, Author) · Sep 26, 2022:
Methods to be exposed to macros need to be marked with @available or one of its variants.

And without @available.parse, the function would also be executed at parse time, which causes an InvalidConnectionException because connections are not available at parse time. From the decorator's docstring:

    A decorator factory to indicate that a method on the adapter will be
    exposed to the database wrapper, and will be stubbed out at parse time
    with the given function.

(A minimal conceptual sketch of this parse-time stubbing follows the impl.py diff below.)

def compare_dbr_version(self, major: int, minor: int) -> int:
"""
Returns the comparison result between the version of the cluster and the specified version.

- positive number if the cluster version is greater than the specified version.
- 0 if the versions are the same
- negative number if the cluster version is less than the specified version.

Always returns positive number if trying to connect to SQL Warehouse.
"""
return self.connections.compare_dbr_version(major, minor)

def list_schemas(self, database: Optional[str]) -> List[str]:
"""
Get a list of existing schemas in database.
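
Picking up the review thread above, a minimal conceptual sketch of what @available.parse(lambda *a, **k: 0) buys us. This is not the dbt-core implementation; the available_parse helper, the _during_parse flag, and FakeAdapter are illustrative stand-ins for the behavior described in the docstring: at parse time the stub value is returned, at execution time the real method runs.

def available_parse(parse_replacement):
    """Simplified stand-in for dbt's @available.parse decorator factory."""
    def decorator(func):
        def wrapper(self, *args, **kwargs):
            if getattr(self, "_during_parse", False):  # hypothetical flag, for illustration only
                return parse_replacement(*args, **kwargs)
            return func(self, *args, **kwargs)
        return wrapper
    return decorator


class FakeAdapter:
    _during_parse = True  # pretend we are inside dbt's parse phase

    @available_parse(lambda *a, **k: 0)
    def compare_dbr_version(self, major, minor):
        raise RuntimeError("would need a live connection")


print(FakeAdapter().compare_dbr_version(10, 4))  # prints 0 -- no connection needed while parsing
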
7 changes: 7 additions & 0 deletions dbt/include/databricks/macros/utils/dateadd.sql
@@ -0,0 +1,7 @@
{% macro databricks__dateadd(datepart, interval, from_date_or_timestamp) %}
{%- if adapter.compare_dbr_version(10, 4) >= 0 -%}
timestampadd({{datepart}}, {{interval}}, {{from_date_or_timestamp}})
{%- else -%}
{{ spark__dateadd(datepart, interval, from_date_or_timestamp) }}
{%- endif -%}
{%- endmacro %}
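
For a concrete sense of the DBR 10.4+ branch above, a call like {{ dateadd('day', 1, 'order_date') }} (the column name is just an example) renders to roughly:

timestampadd(day, 1, order_date)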
7 changes: 7 additions & 0 deletions dbt/include/databricks/macros/utils/datediff.sql
@@ -0,0 +1,7 @@
{% macro databricks__datediff(first_date, second_date, datepart) %}
{%- if adapter.compare_dbr_version(10, 4) >= 0 -%}
timestampdiff({{datepart}}, {{date_trunc(datepart, first_date)}}, {{date_trunc(datepart, second_date)}})
{%- else -%}
{{ spark__datediff(first_date, second_date, datepart) }}
{%- endif -%}
{%- endmacro %}
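
Similarly for datediff on 10.4+: the date_trunc wrapping keeps timestampdiff's whole-unit counting in line with datediff's boundary-crossing semantics. With example column names, {{ datediff('start_at', 'end_at', 'day') }} renders to roughly:

timestampdiff(day, date_trunc('day', start_at), date_trunc('day', end_at))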
62 changes: 62 additions & 0 deletions tests/functional/adapter/utils/fixture_dateadd.py
@@ -0,0 +1,62 @@
# dateadd

seeds__data_dateadd_csv = """from_time,interval_length,datepart,result
2018-01-01 01:00:00,1,day,2018-01-02 01:00:00
2018-01-01 01:00:00,1,week,2018-01-08 01:00:00
2018-01-01 01:00:00,1,month,2018-02-01 01:00:00
2018-01-01 01:00:00,1,quarter,2018-04-01 01:00:00
2018-01-01 01:00:00,1,year,2019-01-01 01:00:00
2018-01-01 01:00:00,1,hour,2018-01-01 02:00:00
2018-01-01 01:00:00,1,minute,2018-01-01 01:01:00
2018-01-01 01:00:00,1,second,2018-01-01 01:00:01
,1,day,
"""


models__test_dateadd_sql = """
with data as (

select * from {{ ref('data_dateadd') }}

)

select
case
when datepart = 'day' then cast(
{{ dateadd('day', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
when datepart = 'week' then cast(
{{ dateadd('week', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
when datepart = 'month' then cast(
{{ dateadd('month', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
when datepart = 'quarter' then cast(
{{ dateadd('quarter', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
when datepart = 'year' then cast(
{{ dateadd('year', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
when datepart = 'hour' then cast(
{{ dateadd('hour', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
when datepart = 'minute' then cast(
{{ dateadd('minute', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
when datepart = 'second' then cast(
{{ dateadd('second', 'interval_length', 'from_time') }}
as {{ api.Column.translate_type('timestamp') }}
)
else null
end as actual,
result as expected

from data
"""
20 changes: 17 additions & 3 deletions tests/functional/adapter/utils/test_utils.py
@@ -22,7 +22,12 @@

# requires modification
from dbt.tests.adapter.utils.test_listagg import BaseListagg
from dbt.tests.adapter.utils.fixture_dateadd import models__test_dateadd_yml
from dbt.tests.adapter.utils.fixture_listagg import models__test_listagg_yml
from tests.functional.adapter.utils.fixture_dateadd import (
models__test_dateadd_sql,
seeds__data_dateadd_csv,
)
from tests.functional.adapter.utils.fixture_listagg import models__test_listagg_no_order_by_sql


@@ -43,11 +48,20 @@ class TestConcat(BaseConcat):


class TestDateAdd(BaseDateAdd):
pass
@pytest.fixture(scope="class")
def seeds(self):
return {"data_dateadd.csv": seeds__data_dateadd_csv}

@pytest.fixture(scope="class")
def models(self):
return {
"test_dateadd.yml": models__test_dateadd_yml,
"test_dateadd.sql": self.interpolate_macro_namespace(
models__test_dateadd_sql, "dateadd"
),
}


# This test generates a super long create table statement that exceeds HiveMetastore's limit
@pytest.mark.skip_profile("databricks_cluster", "databricks_sql_endpoint")
class TestDateDiff(BaseDateDiff):
pass
