Skip to content

Separate Session related functionality from Connection class #567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f97c81d
decouple session class from existing Connection
varun-edachali-dbx May 20, 2025
fe0af87
add open property to Connection to ensure maintenance of existing API
varun-edachali-dbx May 20, 2025
18f8f67
update unit tests to address ThriftBackend through session instead of…
varun-edachali-dbx May 20, 2025
fd8decb
chore: move session specific tests from test_client to test_session
varun-edachali-dbx May 20, 2025
1a92b77
formatting (black)
varun-edachali-dbx May 20, 2025
1b9a50a
use connection open property instead of long chain through session
varun-edachali-dbx May 20, 2025
0bf2794
trigger integration workflow
varun-edachali-dbx May 20, 2025
ff35165
fix: ensure open attribute of Connection never fails
varun-edachali-dbx May 21, 2025
0df486a
fix: de-complicate earlier connection open logic
varun-edachali-dbx May 23, 2025
63b10c3
Revert "fix: de-complicate earlier connection open logic"
varun-edachali-dbx May 23, 2025
f2b3fd5
[empty commit] attempt to trigger ci e2e workflow
varun-edachali-dbx May 23, 2025
53f16ab
Update CODEOWNERS (#562)
jprakash-db May 21, 2025
a026751
Enhance Cursor close handling and context manager exception managemen…
madhav-db May 21, 2025
0d6995c
PECOBLR-86 improve logging on python driver (#556)
saishreeeee May 22, 2025
923bbb6
Revert "Merge remote-tracking branch 'upstream/sea-migration' into de…
varun-edachali-dbx May 23, 2025
8df8c33
Reapply "Merge remote-tracking branch 'upstream/sea-migration' into d…
varun-edachali-dbx May 23, 2025
bcf5994
fix: separate session opening logic from instantiation
varun-edachali-dbx May 23, 2025
500dd0b
fix: use is_open attribute to denote session availability
varun-edachali-dbx May 23, 2025
510b454
fix: access thrift backend through session
varun-edachali-dbx May 23, 2025
634faa9
chore: use get_handle() instead of private session attribute in client
varun-edachali-dbx May 24, 2025
a32862b
formatting (black)
varun-edachali-dbx May 24, 2025
88b728d
Merge remote-tracking branch 'upstream/sea-migration' into decouple-s…
varun-edachali-dbx May 24, 2025
ed04584
Merge remote-tracking branch 'upstream/sea-migration' into decouple-s…
varun-edachali-dbx May 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 45 additions & 102 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from databricks.sql.types import Row, SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.experimental.oauth_persistence import OAuthPersistence
from databricks.sql.session import Session

from databricks.sql.thrift_api.TCLIService.ttypes import (
TSparkParameter,
Expand Down Expand Up @@ -224,66 +225,28 @@ def read(self) -> Optional[OAuthToken]:
access_token_kv = {"access_token": access_token}
kwargs = {**kwargs, **access_token_kv}

self.open = False
self.host = server_hostname
self.port = kwargs.get("_port", 443)
self.disable_pandas = kwargs.get("_disable_pandas", False)
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self._cursors = [] # type: List[Cursor]

auth_provider = get_python_sql_connector_auth_provider(
server_hostname, **kwargs
)

user_agent_entry = kwargs.get("user_agent_entry")
if user_agent_entry is None:
user_agent_entry = kwargs.get("_user_agent_entry")
if user_agent_entry is not None:
logger.warning(
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
"This parameter will be removed in the upcoming releases."
)

if user_agent_entry:
useragent_header = "{}/{} ({})".format(
USER_AGENT_NAME, __version__, user_agent_entry
)
else:
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

base_headers = [("User-Agent", useragent_header)]

self._ssl_options = SSLOptions(
# Double negation is generally a bad thing, but we have to keep backward compatibility
tls_verify=not kwargs.get(
"_tls_no_verify", False
), # by default - verify cert and host
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
)

self.thrift_backend = ThriftBackend(
self.host,
self.port,
# Create the session
self.session = Session(
server_hostname,
http_path,
(http_headers or []) + base_headers,
auth_provider,
ssl_options=self._ssl_options,
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
http_headers,
session_configuration,
catalog,
schema,
_use_arrow_native_complex_types,
**kwargs,
)
self.session.open()

self._open_session_resp = self.thrift_backend.open_session(
session_configuration, catalog, schema
logger.info(
"Successfully opened connection with session "
+ str(self.get_session_id_hex())
)
self._session_handle = self._open_session_resp.sessionHandle
self.protocol_version = self.get_protocol_version(self._open_session_resp)
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self.open = True
logger.info("Successfully opened session " + str(self.get_session_id_hex()))
self._cursors = [] # type: List[Cursor]

self.use_inline_params = self._set_use_inline_params_with_warning(
kwargs.get("use_inline_params", False)
Expand Down Expand Up @@ -342,34 +305,32 @@ def __del__(self):
logger.debug("Couldn't close unclosed connection: {}".format(e.message))

def get_session_id(self):
return self.thrift_backend.handle_to_id(self._session_handle)
"""Get the session ID from the Session object"""
return self.session.get_id()

@staticmethod
def get_protocol_version(openSessionResp):
"""
Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
precedence over the serverProtocolVersion defined in the OpenSessionResponse.
"""
if (
openSessionResp.sessionHandle
and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
and openSessionResp.sessionHandle.serverProtocolVersion
):
return openSessionResp.sessionHandle.serverProtocolVersion
return openSessionResp.serverProtocolVersion
def get_session_id_hex(self):
    """Return the hex-formatted session id, as reported by the owned Session."""
    session = self.session
    return session.get_id_hex()

@staticmethod
def server_parameterized_queries_enabled(protocolVersion):
if (
protocolVersion
and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
):
return True
else:
return False
"""Delegate to Session class static method"""
return Session.server_parameterized_queries_enabled(protocolVersion)

def get_session_id_hex(self):
return self.thrift_backend.handle_to_hex_id(self._session_handle)
@property
def protocol_version(self):
    """Expose the protocol version stored on the owned Session object."""
    session = self.session
    return session.protocol_version

@staticmethod
def get_protocol_version(openSessionResp):
    """Resolve the protocol version of an open-session response.

    Kept on this class so existing callers keep working; the actual
    resolution is performed by ``Session.get_protocol_version``.
    """
    resolve = Session.get_protocol_version
    return resolve(openSessionResp)

@property
def open(self) -> bool:
    """Return whether the connection is open, i.e. whether its session is open.

    ``self.session`` is only assigned once ``Session(...)`` has returned
    successfully in the constructor. If that call raised, the attribute does
    not exist; in that case the connection is reported as not open instead of
    raising ``AttributeError`` (e.g. when a partially constructed object is
    inspected during cleanup).
    """
    session = getattr(self, "session", None)
    return session is not None and session.is_open

def cursor(
self,
Expand All @@ -386,7 +347,7 @@ def cursor(

cursor = Cursor(
self,
self.thrift_backend,
self.session.thrift_backend,
arraysize=arraysize,
result_buffer_size_bytes=buffer_size_bytes,
)
Expand All @@ -402,28 +363,10 @@ def _close(self, close_cursors=True) -> None:
for cursor in self._cursors:
cursor.close()

logger.info(f"Closing session {self.get_session_id_hex()}")
if not self.open:
logger.debug("Session appears to have been closed already")

try:
self.thrift_backend.close_session(self._session_handle)
except RequestError as e:
if isinstance(e.args[1], SessionAlreadyClosedError):
logger.info("Session was closed by a prior request")
except DatabaseError as e:
if "Invalid SessionHandle" in str(e):
logger.warning(
f"Attempted to close session that was already closed: {e}"
)
else:
logger.warning(
f"Attempt to close session raised an exception at the server: {e}"
)
self.session.close()
except Exception as e:
logger.error(f"Attempt to close session raised a local exception: {e}")

self.open = False
logger.error(f"Attempt to close session raised an exception: {e}")

def commit(self):
"""No-op because Databricks does not support transactions"""
Expand Down Expand Up @@ -833,7 +776,7 @@ def execute(
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
Expand Down Expand Up @@ -896,7 +839,7 @@ def execute_async(
self._close_and_clear_active_result_set()
self.thrift_backend.execute_command(
operation=prepared_operation,
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
lz4_compression=self.connection.lz4_compression,
Expand Down Expand Up @@ -992,7 +935,7 @@ def catalogs(self) -> "Cursor":
self._check_not_closed()
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.get_catalogs(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand All @@ -1018,7 +961,7 @@ def schemas(
self._check_not_closed()
self._close_and_clear_active_result_set()
execute_response = self.thrift_backend.get_schemas(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down Expand Up @@ -1051,7 +994,7 @@ def tables(
self._close_and_clear_active_result_set()

execute_response = self.thrift_backend.get_tables(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down Expand Up @@ -1086,7 +1029,7 @@ def columns(
self._close_and_clear_active_result_set()

execute_response = self.thrift_backend.get_columns(
session_handle=self.connection._session_handle,
session_handle=self.connection.session.get_handle(),
max_rows=self.arraysize,
max_bytes=self.buffer_size_bytes,
cursor=self,
Expand Down
160 changes: 160 additions & 0 deletions src/databricks/sql/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import logging
from typing import Dict, Tuple, List, Optional, Any

from databricks.sql.thrift_api.TCLIService import ttypes
from databricks.sql.types import SSLOptions
from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError
from databricks.sql import __version__
from databricks.sql import USER_AGENT_NAME
from databricks.sql.thrift_backend import ThriftBackend

logger = logging.getLogger(__name__)


class Session:
    """A backend session, decoupled from the DB-API ``Connection`` object.

    The constructor only builds the ``ThriftBackend`` client; no session is
    requested from the server until :meth:`open` is called. :meth:`close`
    releases the session again.
    """

    def __init__(
        self,
        server_hostname: str,
        http_path: str,
        http_headers: Optional[List[Tuple[str, str]]] = None,
        session_configuration: Optional[Dict[str, Any]] = None,
        catalog: Optional[str] = None,
        schema: Optional[str] = None,
        _use_arrow_native_complex_types: Optional[bool] = True,
        **kwargs,
    ) -> None:
        """
        Create a session to a Databricks SQL endpoint or a Databricks cluster.

        This class handles all session-related behavior and communication with the backend.

        Args:
            server_hostname: Host passed to the auth provider and the backend.
            http_path: HTTP path passed to the backend.
            http_headers: Extra headers; a ``User-Agent`` header is appended.
            session_configuration: Stored and passed to ``open_session``.
            catalog: Stored and passed to ``open_session``.
            schema: Stored and passed to ``open_session``.
            _use_arrow_native_complex_types: Forwarded to the backend.
            **kwargs: Read for ``_port`` (default 443), ``user_agent_entry``,
                the deprecated ``_user_agent_entry`` and the ``_tls_*``
                options; also forwarded unchanged to the auth provider and
                to the backend.
        """
        self.is_open = False
        self.host = server_hostname
        self.port = kwargs.get("_port", 443)

        # Kept so that open() can request the session later.
        self.session_configuration = session_configuration
        self.catalog = catalog
        self.schema = schema

        auth_provider = get_python_sql_connector_auth_provider(
            server_hostname, **kwargs
        )

        # Prefer the public option; fall back to the deprecated private one.
        user_agent_entry = kwargs.get("user_agent_entry")
        if user_agent_entry is None:
            user_agent_entry = kwargs.get("_user_agent_entry")
            if user_agent_entry is not None:
                logger.warning(
                    "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
                    "This parameter will be removed in the upcoming releases."
                )

        if user_agent_entry:
            useragent_header = "{}/{} ({})".format(
                USER_AGENT_NAME, __version__, user_agent_entry
            )
        else:
            useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

        base_headers = [("User-Agent", useragent_header)]

        self._ssl_options = SSLOptions(
            # Double negation is generally a bad thing, but we have to keep backward compatibility
            tls_verify=not kwargs.get(
                "_tls_no_verify", False
            ),  # by default - verify cert and host
            tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
            tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
            tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
            tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
            tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
        )

        self.thrift_backend = ThriftBackend(
            self.host,
            self.port,
            http_path,
            (http_headers or []) + base_headers,
            auth_provider,
            ssl_options=self._ssl_options,
            _use_arrow_native_complex_types=_use_arrow_native_complex_types,
            **kwargs,
        )

        # Populated by open(); both stay None until a session exists.
        self._handle = None
        self.protocol_version = None

    def open(self) -> None:
        """Request a session from the backend and record its handle and protocol version."""
        self._open_session_resp = self.thrift_backend.open_session(
            self.session_configuration, self.catalog, self.schema
        )
        self._handle = self._open_session_resp.sessionHandle
        self.protocol_version = self.get_protocol_version(self._open_session_resp)
        self.is_open = True
        logger.info("Successfully opened session %s", self.get_id_hex())

    @staticmethod
    def get_protocol_version(openSessionResp):
        """
        Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
        precedence over the serverProtocolVersion defined in the OpenSessionResponse.
        """
        handle = openSessionResp.sessionHandle
        handle_version = (
            getattr(handle, "serverProtocolVersion", None) if handle else None
        )
        if handle_version:
            return handle_version
        return openSessionResp.serverProtocolVersion

    @staticmethod
    def server_parameterized_queries_enabled(protocolVersion) -> bool:
        """Return True when ``protocolVersion`` is truthy and at least protocol V8."""
        return bool(
            protocolVersion
            and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
        )

    def get_handle(self):
        """Return the session handle, or None if no session has been opened."""
        return self._handle

    def get_id(self):
        """Return the backend's id for the session handle, or None without a handle."""
        handle = self.get_handle()
        if handle is None:
            return None
        return self.thrift_backend.handle_to_id(handle)

    def get_id_hex(self):
        """Return the backend's hex id for the session handle, or None without a handle."""
        handle = self.get_handle()
        if handle is None:
            return None
        return self.thrift_backend.handle_to_hex_id(handle)

    def close(self) -> None:
        """Close the underlying session.

        Closing is best effort: errors raised by the backend are logged and
        not propagated, and the session is marked closed afterwards. Calling
        this on a session that is not open only logs and returns.
        """
        logger.info("Closing session %s", self.get_id_hex())
        if not self.is_open:
            logger.debug("Session appears to have been closed already")
            return

        try:
            self.thrift_backend.close_session(self.get_handle())
        except RequestError as e:
            # The underlying cause is expected as the second positional
            # argument; guard the lookup so a shorter args tuple cannot raise
            # a secondary IndexError inside this handler.
            cause = e.args[1] if len(e.args) > 1 else None
            if isinstance(cause, SessionAlreadyClosedError):
                logger.info("Session was closed by a prior request")
            else:
                # Previously dropped without a trace; keep it best effort
                # but make the failure visible in the logs.
                logger.warning(
                    "Attempt to close session raised a request error: %s", e
                )
        except DatabaseError as e:
            if "Invalid SessionHandle" in str(e):
                logger.warning(
                    "Attempted to close session that was already closed: %s", e
                )
            else:
                logger.warning(
                    "Attempt to close session raised an exception at the server: %s",
                    e,
                )
        except Exception as e:
            logger.error("Attempt to close session raised a local exception: %s", e)

        self.is_open = False
Loading
Loading