Koala 1423 support session connection #1

Merged · 4 commits · Dec 1, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -19,3 +19,4 @@ logs/
.venv*
*.sublime*
.python-version
venv/
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/__version__.py
@@ -1 +1 @@
version: str = "1.8.6"
version: str = "1.8.6a14"
60 changes: 59 additions & 1 deletion dbt/adapters/databricks/connections.py
@@ -10,7 +10,7 @@
from multiprocessing.context import SpawnContext
from numbers import Number
from threading import get_ident
from typing import Any
from typing import Any, Generator
from typing import Callable
from typing import cast
from typing import Dict
@@ -93,6 +93,8 @@

# toggle for session management that minimizes the number of sessions opened/closed
USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
# toggle for session management that assumes the adapter is running in a Databricks session
USE_SESSION_CONNECTION = os.getenv("DBT_DATABRICKS_SESSION_CONNECTION", "False").upper() == "TRUE"

# Number of idle seconds before a connection is automatically closed. Only applicable if
# USE_LONG_SESSIONS is true.
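For reference, both toggles above are plain environment variables read once at module import time. A minimal standalone sketch of the parsing convention (the value set here is hypothetical):

import os

# Hypothetical: enable the session-connection mode before the adapter module is imported.
os.environ["DBT_DATABRICKS_SESSION_CONNECTION"] = "true"

# Mirrors the parsing above: any casing of "true" enables the toggle, everything else disables it.
use_session_connection = (
    os.getenv("DBT_DATABRICKS_SESSION_CONNECTION", "False").upper() == "TRUE"
)
assert use_session_connection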
@@ -1086,6 +1088,62 @@ def exponential_backoff(attempt: int) -> int:
)


class DatabricksSessionConnectionManager(DatabricksConnectionManager):

    def cancel_open(self) -> List[str]:
        return SparkConnectionManager.cancel_open(self)

    def compare_dbr_version(self, major: int, minor: int) -> int:
        version = (major, minor)
        connection = self.get_thread_connection().handle
        dbr_version = connection.dbr_version
        return (dbr_version > version) - (dbr_version < version)

    def set_query_header(self, query_header_context: Dict[str, Any]) -> None:
        SparkConnectionManager.set_query_header(self, query_header_context)

    def set_connection_name(
        self, name: Optional[str] = None, query_header_context: Any = None
    ) -> Connection:
        return SparkConnectionManager.set_connection_name(self, name)

    def add_query(
        self,
        sql: str,
        auto_begin: bool = True,
        bindings: Optional[Any] = None,
        abridge_sql_log: bool = False,
        *,
        close_cursor: bool = False,
    ) -> Tuple[Connection, Any]:
        return SparkConnectionManager.add_query(self, sql, auto_begin, bindings, abridge_sql_log)

    def list_schemas(self, database: str, schema: Optional[str] = None) -> "Table":
        raise NotImplementedError(
            "list_schemas is not implemented for DatabricksSessionConnectionManager - should call the list_schemas macro instead"
        )

    def list_tables(self, database: str, schema: str, identifier: Optional[str] = None) -> "Table":
        raise NotImplementedError(
            "list_tables is not implemented for DatabricksSessionConnectionManager - should call the list_tables macro instead"
        )

    @classmethod
    def open(cls, connection: Connection) -> Connection:
        from dbt.adapters.spark.session import Connection as SparkSessionConnection
        from dbt.adapters.databricks.session_connection import DatabricksSessionConnectionWrapper

        handle = DatabricksSessionConnectionWrapper(SparkSessionConnection())
        connection.handle = handle
        connection.state = ConnectionState.OPEN
        return connection

    @classmethod
    def get_response(cls, cursor: Any) -> DatabricksAdapterResponse:
        response = SparkConnectionManager.get_response(cursor)
        return DatabricksAdapterResponse(_message=response._message, query_id=None)


def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict:
    pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}"

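As an aside, the `compare_dbr_version` override above uses the classic three-way comparison idiom on version tuples. A small standalone illustration with hypothetical versions:

from typing import Tuple


def three_way_compare(dbr_version: Tuple[int, int], target: Tuple[int, int]) -> int:
    # Returns 1, 0, or -1 depending on whether dbr_version is newer than,
    # equal to, or older than the target, exactly like the method above.
    return (dbr_version > target) - (dbr_version < target)


assert three_way_compare((14, 3), (12, 2)) == 1  # hypothetical DBR 14.3 vs. a 12.2 requirement
assert three_way_compare((12, 2), (12, 2)) == 0
assert three_way_compare((11, 0), (12, 2)) == -1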
15 changes: 10 additions & 5 deletions dbt/adapters/databricks/impl.py
@@ -38,11 +38,14 @@
from dbt.adapters.contracts.relation import RelationConfig
from dbt.adapters.contracts.relation import RelationType
from dbt.adapters.databricks.column import DatabricksColumn
from dbt.adapters.databricks.connections import DatabricksConnectionManager
from dbt.adapters.databricks.connections import (
    DatabricksConnectionManager,
    DatabricksSessionConnectionManager,
)
from dbt.adapters.databricks.connections import DatabricksDBTConnection
from dbt.adapters.databricks.connections import DatabricksSQLConnectionWrapper
from dbt.adapters.databricks.connections import ExtendedSessionConnectionManager
from dbt.adapters.databricks.connections import USE_LONG_SESSIONS
from dbt.adapters.databricks.connections import USE_LONG_SESSIONS, USE_SESSION_CONNECTION
from dbt.adapters.databricks.python_submissions import (
    DbtDatabricksAllPurposeClusterPythonJobHelper,
)
@@ -148,7 +151,9 @@ class DatabricksAdapter(SparkAdapter):
    Relation = DatabricksRelation
    Column = DatabricksColumn

    if USE_LONG_SESSIONS:
    if USE_SESSION_CONNECTION:
        ConnectionManager: Type[DatabricksConnectionManager] = DatabricksSessionConnectionManager
    elif USE_LONG_SESSIONS:
        ConnectionManager: Type[DatabricksConnectionManager] = ExtendedSessionConnectionManager
    else:
        ConnectionManager = DatabricksConnectionManager
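To spell out the selection order implied by this block: the session connection takes precedence over long sessions, which in turn takes precedence over the default manager. An illustrative sketch (the booleans stand in for the two environment toggles; class names mirror the imports above):

def select_connection_manager(use_session_connection: bool, use_long_sessions: bool) -> str:
    # Same precedence as the class-level block above.
    if use_session_connection:
        return "DatabricksSessionConnectionManager"
    elif use_long_sessions:
        return "ExtendedSessionConnectionManager"
    return "DatabricksConnectionManager"


assert select_connection_manager(True, True) == "DatabricksSessionConnectionManager"
assert select_connection_manager(False, True) == "ExtendedSessionConnectionManager"
assert select_connection_manager(False, False) == "DatabricksConnectionManager"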
@@ -207,7 +212,7 @@ def list_schemas(self, database: Optional[str]) -> List[str]:
        If `database` is `None`, fall back to executing `show databases` because
        `list_schemas` tries to collect schemas from all catalogs when `database` is `None`.
        """
        if database is not None:
        if database is not None and not USE_SESSION_CONNECTION:
            results = self.connections.list_schemas(database=database)
        else:
            results = self.execute_macro(LIST_SCHEMAS_MACRO_NAME, kwargs={"database": database})
@@ -281,7 +286,7 @@ def _get_hive_relations(
kwargs = {"relation": relation}

new_rows: List[Tuple[str, Optional[str]]]
if all([relation.database, relation.schema]):
if all([relation.database, relation.schema]) and not USE_SESSION_CONNECTION:
tables = self.connections.list_tables(
database=relation.database, schema=relation.schema # type: ignore[arg-type]
)
42 changes: 42 additions & 0 deletions dbt/adapters/databricks/session_connection.py
@@ -0,0 +1,42 @@
import re
import sys
from typing import Tuple
from dbt.adapters.spark.session import SessionConnectionWrapper, Connection


DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")


class DatabricksSessionConnectionWrapper(SessionConnectionWrapper):

    _is_cluster: bool
    _dbr_version: Tuple[int, int]

    def __init__(self, handle: Connection) -> None:
        super().__init__(handle)
        self._is_cluster = True
        self.cursor()

    @property
    def dbr_version(self) -> Tuple[int, int]:
        if not hasattr(self, "_dbr_version"):
            if self._is_cluster:
                with self._cursor() as cursor:
                    cursor.execute("SET spark.databricks.clusterUsageTags.sparkVersion")
                    results = cursor.fetchone()
                    if results:
                        dbr_version: str = results[1]

                m = DBR_VERSION_REGEX.search(dbr_version)
                assert m, f"Unknown DBR version: {dbr_version}"
                major = int(m.group(1))
                try:
                    minor = int(m.group(2))
                except ValueError:
                    minor = sys.maxsize
                self._dbr_version = (major, minor)
            else:
                # Assuming SQL Warehouse uses the latest version.
                self._dbr_version = (sys.maxsize, sys.maxsize)

        return self._dbr_version
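The regex and the `ValueError` fallback above mean that an "x" minor (as in a rolling runtime tag) sorts after every numeric minor. A standalone sketch of the same parsing, using hypothetical spark-version tags:

import re
import sys
from typing import Tuple

DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)")


def parse_dbr_version(tag: str) -> Tuple[int, int]:
    # Same logic as the property above, extracted for illustration.
    m = DBR_VERSION_REGEX.search(tag)
    assert m, f"Unknown DBR version: {tag}"
    major = int(m.group(1))
    try:
        minor = int(m.group(2))
    except ValueError:
        minor = sys.maxsize  # "x" minors sort last
    return (major, minor)


assert parse_dbr_version("14.3.x-scala2.12") == (14, 3)  # hypothetical tag
assert parse_dbr_version("15.x-scala2.12") == (15, sys.maxsize)  # hypothetical tag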
34 changes: 31 additions & 3 deletions dbt/include/databricks/macros/adapters/metadata.sql
@@ -19,9 +19,18 @@
{% endmacro %}

{% macro databricks__show_tables(relation) %}
  {% call statement('show_tables', fetch_result=True) -%}
    show tables in {{ relation|lower }}
  {% endcall %}
  {% set database = (relation.database | default('')) | lower | replace('`', '') %}
  {% set schema = relation.schema | lower | replace('`', '') %}

  {% if database and schema -%}
    {% call statement('show_tables', fetch_result=True) -%}
      SHOW TABLES IN {{ database }}.{{ schema }}
    {% endcall %}
  {% else -%}
    {% call statement('show_tables', fetch_result=True) -%}
      SHOW TABLES IN {{ relation | lower }}
    {% endcall %}
  {% endif %}

  {% do return(load_result('show_tables').table) %}
{% endmacro %}
@@ -103,4 +112,23 @@
  {% endcall %}

  {% do return(load_result('get_uc_tables').table) %}
{% endmacro %}

{% macro list_schemas(database) %}
  {{ return(adapter.dispatch('list_schemas', 'dbt')(database)) }}
{% endmacro %}

{% macro databricks__list_schemas(database) -%}
  {% set database_clean = (database | default('')) | replace('`', '') %}
  {% if database_clean -%}
    {% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
      SHOW DATABASES IN {{ database_clean }}
    {% endcall %}
  {% else -%}
    {% call statement('list_schemas', fetch_result=True, auto_begin=False) %}
      SHOW DATABASES
    {% endcall %}
  {% endif -%}

  {{ return(load_result('list_schemas').table) }}
{% endmacro %}
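Both macros above follow the same pattern: build a fully qualified target when the identifiers are available, otherwise fall back to the unqualified statement. A small Python sketch of the equivalent branching (the relation values are hypothetical):

from typing import Optional


def show_tables_sql(database: Optional[str], schema: str, relation: str) -> str:
    # Mirrors databricks__show_tables: qualify with database.schema when both are known.
    database = (database or "").lower().replace("`", "")
    schema = schema.lower().replace("`", "")
    if database and schema:
        return f"SHOW TABLES IN {database}.{schema}"
    return f"SHOW TABLES IN {relation.lower()}"


def list_schemas_sql(database: Optional[str]) -> str:
    # Mirrors databricks__list_schemas: scope to the catalog only when one is given.
    database_clean = (database or "").replace("`", "")
    if database_clean:
        return f"SHOW DATABASES IN {database_clean}"
    return "SHOW DATABASES"


assert show_tables_sql("main", "analytics", "main.analytics") == "SHOW TABLES IN main.analytics"
assert list_schemas_sql(None) == "SHOW DATABASES"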
10 changes: 10 additions & 0 deletions tests/unit/test_adapter.py
@@ -113,6 +113,7 @@ def test_http_headers(http_header):
test_http_headers(["a", "b"])
test_http_headers({"a": 1, "b": 2})

@pytest.mark.skip_profile("session_connection")
def test_invalid_custom_user_agent(self):
with pytest.raises(DbtValidationError) as excinfo:
config = self._get_config()
@@ -123,6 +124,7 @@ def test_invalid_custom_user_agent(self):

assert "Invalid invocation environment" in str(excinfo.value)

    @pytest.mark.skip_profile("session_connection")
    def test_custom_user_agent(self):
        config = self._get_config()
        adapter = DatabricksAdapter(config, get_context("spawn"))
@@ -137,12 +139,14 @@ def test_custom_user_agent(self):
            connection = adapter.acquire_connection("dummy")
            connection.handle  # trigger lazy-load

    @pytest.mark.skip_profile("session_connection")
    def test_environment_single_http_header(self):
        self._test_environment_http_headers(
            http_headers_str='{"test":{"jobId":1,"runId":12123}}',
            expected_http_headers=[("test", '{"jobId": 1, "runId": 12123}')],
        )

    @pytest.mark.skip_profile("session_connection")
    def test_environment_multiple_http_headers(self):
        self._test_environment_http_headers(
            http_headers_str='{"test":{"jobId":1,"runId":12123},"dummy":{"jobId":1,"runId":12123}}',
@@ -152,6 +156,7 @@ def test_environment_multiple_http_headers(self):
            ],
        )

    @pytest.mark.skip_profile("session_connection")
    def test_environment_users_http_headers_intersection_error(self):
        with pytest.raises(DbtValidationError) as excinfo:
            self._test_environment_http_headers(
@@ -162,6 +167,7 @@ def test_environment_users_http_headers_intersection_error(self):

assert "Intersection with reserved http_headers in keys: {'t'}" in str(excinfo.value)

    @pytest.mark.skip_profile("session_connection")
    def test_environment_users_http_headers_union_success(self):
        self._test_environment_http_headers(
            http_headers_str='{"t":{"jobId":1,"runId":12123},"d":{"jobId":1,"runId":12123}}',
@@ -173,6 +179,7 @@ def test_environment_users_http_headers_union_success(self):
            ],
        )

    @pytest.mark.skip_profile("session_connection")
    def test_environment_http_headers_string(self):
        self._test_environment_http_headers(
            http_headers_str='{"string":"some-string"}',
@@ -272,6 +279,7 @@ def connect(

        return connect

    @pytest.mark.skip_profile("session_connection")
    def test_databricks_sql_connector_connection(self):
        self._test_databricks_sql_connector_connection(self._connect_func())

@@ -294,6 +302,7 @@ def _test_databricks_sql_connector_connection(self, connect):
            assert len(connection.credentials.session_properties) == 1
            assert connection.credentials.session_properties["spark.sql.ansi.enabled"] == "true"

    @pytest.mark.skip_profile("session_connection")
    def test_databricks_sql_connector_catalog_connection(self):
        self._test_databricks_sql_connector_catalog_connection(
            self._connect_func(expected_catalog="main")
@@ -317,6 +326,7 @@ def _test_databricks_sql_connector_catalog_connection(self, connect):
            assert connection.credentials.schema == "analytics"
            assert connection.credentials.database == "main"

    @pytest.mark.skip_profile("session_connection")
    def test_databricks_sql_connector_http_header_connection(self):
        self._test_databricks_sql_connector_http_header_connection(
            {"aaa": "xxx"}, self._connect_func(expected_http_headers=[("aaa", "xxx")])
14 changes: 14 additions & 0 deletions tox.ini
@@ -27,6 +27,20 @@ deps =
[testenv:unit]
basepython = python3
commands = {envpython} -m pytest --color=yes -v {posargs} tests/unit
setenv =
    DBT_DATABRICKS_SESSION_CONNECTION = False
passenv =
    DBT_*
    PYTEST_ADDOPTS
deps =
    -r{toxinidir}/dev-requirements.txt
    -r{toxinidir}/requirements.txt

[testenv:unit-session]
basepython = python3
commands = {envpython} -m pytest --color=yes -v {posargs} --profile session_connection tests/unit
setenv =
    DBT_DATABRICKS_SESSION_CONNECTION = True
passenv =
    DBT_*
    PYTEST_ADDOPTS
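With these two environments in place, the standard suite keeps the flag disabled while the session-connection variant can be run on its own, e.g. `tox -e unit-session` (and `tox -e unit` for the default path).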