Commit 54ffd41

varun-edachali-dbx, jprakash-db, madhav-db, saishreeeee committed
Separate Session related functionality from Connection class (#571)

* decouple session class from existing Connection
  ensure maintenance of current APIs of Connection while delegating responsibility
* add open property to Connection to ensure maintenance of existing API
* update unit tests to address ThriftBackend through session instead of through Connection
* chore: move session specific tests from test_client to test_session
* formatting (black) as in CONTRIBUTING.md
* use connection open property instead of long chain through session
* trigger integration workflow
* fix: ensure open attribute of Connection never fails
  in case the openSession takes long, the initialisation of the session will not complete immediately. This could make the session attribute inaccessible. If the Connection is deleted in this time, the open() check will throw because the session attribute does not exist. Thus, we default to the Connection being closed in this case. This was not an issue before because open was a direct attribute of the Connection class. Caught in the integration tests.
* fix: de-complicate earlier connection open logic
  earlier, one of the integration tests was failing because 'session was not an attribute of Connection'. This is likely tied to a local configuration issue related to unittest that was causing an error in the test suite itself. The tests are now passing without checking for the session attribute.
  c676f9b
* Revert "fix: de-complicate earlier connection open logic"
  This reverts commit d6b1b19.
* [empty commit] attempt to trigger ci e2e workflow
* Update CODEOWNERS (#562)
  new codeowners
* Enhance Cursor close handling and context manager exception management to prevent server side resource leaks (#554)
  * Enhance Cursor close handling and context manager exception management
  * tests
  * fmt
  * Fix Cursor.close() to properly handle CursorAlreadyClosedError
  * Remove specific test message from Cursor.close() error handling
  * Improve error handling in connection and cursor context managers to ensure proper closure during exceptions, including KeyboardInterrupt. Add tests for nested cursor management and verify operation closure on server-side errors.
  * add
  * add
* PECOBLR-86 improve logging on python driver (#556)
  * PECOBLR-86 Improve logging for debug level
  * PECOBLR-86 Improve logging for debug level
  * fixed format
  * used lazy logging
  * changed debug to error logs
  * used lazy logging
* Revert "Merge remote-tracking branch 'upstream/sea-migration' into decouple-session"
  This reverts commit dbb2ec5, reversing changes made to 7192f11.
* Reapply "Merge remote-tracking branch 'upstream/sea-migration' into decouple-session"
  This reverts commit bdb8381.
* fix: separate session opening logic from instantiation
  ensures correctness of self.session.open call in Connection
* fix: use is_open attribute to denote session availability
* fix: access thrift backend through session
* chore: use get_handle() instead of private session attribute in client
* formatting (black)
* fix: remove accidentally removed assertions

Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
Signed-off-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
Co-authored-by: Jothi Prakash <jothi.prakash@databricks.com>
Co-authored-by: Madhav Sainanee <madhav.sainanee@databricks.com>
Co-authored-by: Sai Shree Pradhan <saishree.pradhan@databricks.com>
1 parent c123af3 commit 54ffd41
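
The commit message describes a fix in which `open` defaults to "closed" when the `session` attribute was never assigned. A minimal sketch of that idea follows. It is an illustration only, not the code in this commit: `FakeSession` and `SketchConnection` are hypothetical stand-ins, and the diff for `client.py` reads `self.session.is_open` directly.

```python
class FakeSession:
    """Hypothetical stand-in for databricks.sql.session.Session."""

    def __init__(self):
        self.is_open = False

    def open(self):
        self.is_open = True


class SketchConnection:
    """Hypothetical Connection-like class, for illustration only."""

    def __init__(self):
        self.session = FakeSession()
        self.session.open()

    @property
    def open(self):
        # Default to "closed" when the session attribute was never set,
        # so __del__ cannot raise AttributeError on a half-built object.
        session = getattr(self, "session", None)
        return session is not None and session.is_open

    def __del__(self):
        if self.open:
            self.session.is_open = False


conn = SketchConnection()
print(conn.open)  # True

# Simulate an instance whose __init__ never ran (no session attribute).
partial = SketchConnection.__new__(SketchConnection)
print(partial.open)  # False
```

Using `getattr` with a default keeps the check to one line; whether such a guard is needed depends on whether `__del__` can run before `__init__` finishes, which the commit message itself leaves as an open question.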

5 files changed

+412
-197
lines changed
src/databricks/sql/client.py

Lines changed: 45 additions & 102 deletions
@@ -45,6 +45,7 @@
 from databricks.sql.types import Row, SSLOptions
 from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
 from databricks.sql.experimental.oauth_persistence import OAuthPersistence
+from databricks.sql.session import Session
 
 from databricks.sql.thrift_api.TCLIService.ttypes import (
     TSparkParameter,
@@ -224,66 +225,28 @@ def read(self) -> Optional[OAuthToken]:
             access_token_kv = {"access_token": access_token}
             kwargs = {**kwargs, **access_token_kv}
 
-        self.open = False
-        self.host = server_hostname
-        self.port = kwargs.get("_port", 443)
         self.disable_pandas = kwargs.get("_disable_pandas", False)
         self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
+        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
+        self._cursors = []  # type: List[Cursor]
 
-        auth_provider = get_python_sql_connector_auth_provider(
-            server_hostname, **kwargs
-        )
-
-        user_agent_entry = kwargs.get("user_agent_entry")
-        if user_agent_entry is None:
-            user_agent_entry = kwargs.get("_user_agent_entry")
-            if user_agent_entry is not None:
-                logger.warning(
-                    "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
-                    "This parameter will be removed in the upcoming releases."
-                )
-
-        if user_agent_entry:
-            useragent_header = "{}/{} ({})".format(
-                USER_AGENT_NAME, __version__, user_agent_entry
-            )
-        else:
-            useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
-
-        base_headers = [("User-Agent", useragent_header)]
-
-        self._ssl_options = SSLOptions(
-            # Double negation is generally a bad thing, but we have to keep backward compatibility
-            tls_verify=not kwargs.get(
-                "_tls_no_verify", False
-            ),  # by default - verify cert and host
-            tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
-            tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
-            tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
-            tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
-            tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
-        )
-
-        self.thrift_backend = ThriftBackend(
-            self.host,
-            self.port,
+        # Create the session
+        self.session = Session(
+            server_hostname,
             http_path,
-            (http_headers or []) + base_headers,
-            auth_provider,
-            ssl_options=self._ssl_options,
-            _use_arrow_native_complex_types=_use_arrow_native_complex_types,
+            http_headers,
+            session_configuration,
+            catalog,
+            schema,
+            _use_arrow_native_complex_types,
             **kwargs,
         )
+        self.session.open()
 
-        self._open_session_resp = self.thrift_backend.open_session(
-            session_configuration, catalog, schema
+        logger.info(
+            "Successfully opened connection with session "
+            + str(self.get_session_id_hex())
         )
-        self._session_handle = self._open_session_resp.sessionHandle
-        self.protocol_version = self.get_protocol_version(self._open_session_resp)
-        self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
-        self.open = True
-        logger.info("Successfully opened session " + str(self.get_session_id_hex()))
-        self._cursors = []  # type: List[Cursor]
 
         self.use_inline_params = self._set_use_inline_params_with_warning(
             kwargs.get("use_inline_params", False)
@@ -342,34 +305,32 @@ def __del__(self):
                 logger.debug("Couldn't close unclosed connection: {}".format(e.message))
 
     def get_session_id(self):
-        return self.thrift_backend.handle_to_id(self._session_handle)
+        """Get the session ID from the Session object"""
+        return self.session.get_id()
 
-    @staticmethod
-    def get_protocol_version(openSessionResp):
-        """
-        Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
-        precedence over the serverProtocolVersion defined in the OpenSessionResponse.
-        """
-        if (
-            openSessionResp.sessionHandle
-            and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
-            and openSessionResp.sessionHandle.serverProtocolVersion
-        ):
-            return openSessionResp.sessionHandle.serverProtocolVersion
-        return openSessionResp.serverProtocolVersion
+    def get_session_id_hex(self):
+        """Get the session ID in hex format from the Session object"""
+        return self.session.get_id_hex()
 
     @staticmethod
     def server_parameterized_queries_enabled(protocolVersion):
-        if (
-            protocolVersion
-            and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
-        ):
-            return True
-        else:
-            return False
+        """Delegate to Session class static method"""
+        return Session.server_parameterized_queries_enabled(protocolVersion)
 
-    def get_session_id_hex(self):
-        return self.thrift_backend.handle_to_hex_id(self._session_handle)
+    @property
+    def protocol_version(self):
+        """Get the protocol version from the Session object"""
+        return self.session.protocol_version
+
+    @staticmethod
+    def get_protocol_version(openSessionResp):
+        """Delegate to Session class static method"""
+        return Session.get_protocol_version(openSessionResp)
+
+    @property
+    def open(self) -> bool:
+        """Return whether the connection is open by checking if the session is open."""
+        return self.session.is_open
 
     def cursor(
         self,
@@ -386,7 +347,7 @@ def cursor(
 
         cursor = Cursor(
             self,
-            self.thrift_backend,
+            self.session.thrift_backend,
             arraysize=arraysize,
             result_buffer_size_bytes=buffer_size_bytes,
         )
@@ -402,28 +363,10 @@ def _close(self, close_cursors=True) -> None:
             for cursor in self._cursors:
                 cursor.close()
 
-        logger.info(f"Closing session {self.get_session_id_hex()}")
-        if not self.open:
-            logger.debug("Session appears to have been closed already")
-
         try:
-            self.thrift_backend.close_session(self._session_handle)
-        except RequestError as e:
-            if isinstance(e.args[1], SessionAlreadyClosedError):
-                logger.info("Session was closed by a prior request")
-        except DatabaseError as e:
-            if "Invalid SessionHandle" in str(e):
-                logger.warning(
-                    f"Attempted to close session that was already closed: {e}"
-                )
-            else:
-                logger.warning(
-                    f"Attempt to close session raised an exception at the server: {e}"
-                )
+            self.session.close()
         except Exception as e:
-            logger.error(f"Attempt to close session raised a local exception: {e}")
-
-        self.open = False
+            logger.error(f"Attempt to close session raised an exception: {e}")
 
     def commit(self):
         """No-op because Databricks does not support transactions"""
@@ -833,7 +776,7 @@ def execute(
         self._close_and_clear_active_result_set()
         execute_response = self.thrift_backend.execute_command(
             operation=prepared_operation,
-            session_handle=self.connection._session_handle,
+            session_handle=self.connection.session.get_handle(),
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             lz4_compression=self.connection.lz4_compression,
@@ -896,7 +839,7 @@ def execute_async(
         self._close_and_clear_active_result_set()
         self.thrift_backend.execute_command(
             operation=prepared_operation,
-            session_handle=self.connection._session_handle,
+            session_handle=self.connection.session.get_handle(),
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             lz4_compression=self.connection.lz4_compression,
@@ -992,7 +935,7 @@ def catalogs(self) -> "Cursor":
         self._check_not_closed()
         self._close_and_clear_active_result_set()
         execute_response = self.thrift_backend.get_catalogs(
-            session_handle=self.connection._session_handle,
+            session_handle=self.connection.session.get_handle(),
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             cursor=self,
@@ -1018,7 +961,7 @@ def schemas(
         self._check_not_closed()
         self._close_and_clear_active_result_set()
         execute_response = self.thrift_backend.get_schemas(
-            session_handle=self.connection._session_handle,
+            session_handle=self.connection.session.get_handle(),
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             cursor=self,
@@ -1051,7 +994,7 @@ def tables(
         self._close_and_clear_active_result_set()
 
         execute_response = self.thrift_backend.get_tables(
-            session_handle=self.connection._session_handle,
+            session_handle=self.connection.session.get_handle(),
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             cursor=self,
@@ -1086,7 +1029,7 @@ def columns(
         self._close_and_clear_active_result_set()
 
         execute_response = self.thrift_backend.get_columns(
-            session_handle=self.connection._session_handle,
+            session_handle=self.connection.session.get_handle(),
             max_rows=self.arraysize,
             max_bytes=self.buffer_size_bytes,
             cursor=self,
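
The delegation pattern in the `client.py` changes (Connection keeps its public surface but forwards session work to a Session object) can be sketched in isolation as follows. All three classes are hypothetical stubs written for this illustration; they mirror the method names in the diff (`open`, `close`, `get_handle`, `is_open`) but do not contact any server, and the handle value is a made-up string.

```python
class StubBackend:
    """Hypothetical backend; the real ThriftBackend talks to the server."""

    def open_session(self):
        return "handle-1"  # made-up handle value

    def close_session(self, handle):
        pass


class StubSession:
    """Owns the backend and the session handle."""

    def __init__(self):
        self.thrift_backend = StubBackend()
        self.is_open = False
        self._handle = None

    def open(self):
        # Opening is a separate step from construction.
        self._handle = self.thrift_backend.open_session()
        self.is_open = True

    def get_handle(self):
        return self._handle

    def close(self):
        if not self.is_open:
            return
        self.thrift_backend.close_session(self._handle)
        self.is_open = False


class StubConnection:
    """Keeps the caller-facing API and delegates to the session."""

    def __init__(self):
        self.session = StubSession()
        self.session.open()

    @property
    def open(self):
        # Read-only view of the session state, replacing a plain attribute.
        return self.session.is_open

    def close(self):
        self.session.close()


conn = StubConnection()
print(conn.open)  # True
handle = conn.session.get_handle()
conn.close()
print(conn.open)  # False
```

Exposing `open` as a property means existing callers that read `connection.open` keep working, while the state itself lives in one place on the session.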

src/databricks/sql/session.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
+import logging
+from typing import Dict, Tuple, List, Optional, Any
+
+from databricks.sql.thrift_api.TCLIService import ttypes
+from databricks.sql.types import SSLOptions
+from databricks.sql.auth.auth import get_python_sql_connector_auth_provider
+from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError
+from databricks.sql import __version__
+from databricks.sql import USER_AGENT_NAME
+from databricks.sql.thrift_backend import ThriftBackend
+
+logger = logging.getLogger(__name__)
+
+
+class Session:
+    def __init__(
+        self,
+        server_hostname: str,
+        http_path: str,
+        http_headers: Optional[List[Tuple[str, str]]] = None,
+        session_configuration: Optional[Dict[str, Any]] = None,
+        catalog: Optional[str] = None,
+        schema: Optional[str] = None,
+        _use_arrow_native_complex_types: Optional[bool] = True,
+        **kwargs,
+    ) -> None:
+        """
+        Create a session to a Databricks SQL endpoint or a Databricks cluster.
+
+        This class handles all session-related behavior and communication with the backend.
+        """
+        self.is_open = False
+        self.host = server_hostname
+        self.port = kwargs.get("_port", 443)
+
+        self.session_configuration = session_configuration
+        self.catalog = catalog
+        self.schema = schema
+
+        auth_provider = get_python_sql_connector_auth_provider(
+            server_hostname, **kwargs
+        )
+
+        user_agent_entry = kwargs.get("user_agent_entry")
+        if user_agent_entry is None:
+            user_agent_entry = kwargs.get("_user_agent_entry")
+            if user_agent_entry is not None:
+                logger.warning(
+                    "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
+                    "This parameter will be removed in the upcoming releases."
+                )
+
+        if user_agent_entry:
+            useragent_header = "{}/{} ({})".format(
+                USER_AGENT_NAME, __version__, user_agent_entry
+            )
+        else:
+            useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
+
+        base_headers = [("User-Agent", useragent_header)]
+
+        self._ssl_options = SSLOptions(
+            # Double negation is generally a bad thing, but we have to keep backward compatibility
+            tls_verify=not kwargs.get(
+                "_tls_no_verify", False
+            ),  # by default - verify cert and host
+            tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
+            tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
+            tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
+            tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
+            tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
+        )
+
+        self.thrift_backend = ThriftBackend(
+            self.host,
+            self.port,
+            http_path,
+            (http_headers or []) + base_headers,
+            auth_provider,
+            ssl_options=self._ssl_options,
+            _use_arrow_native_complex_types=_use_arrow_native_complex_types,
+            **kwargs,
+        )
+
+        self._handle = None
+        self.protocol_version = None
+
+    def open(self) -> None:
+        self._open_session_resp = self.thrift_backend.open_session(
+            self.session_configuration, self.catalog, self.schema
+        )
+        self._handle = self._open_session_resp.sessionHandle
+        self.protocol_version = self.get_protocol_version(self._open_session_resp)
+        self.is_open = True
+        logger.info("Successfully opened session " + str(self.get_id_hex()))
+
+    @staticmethod
+    def get_protocol_version(openSessionResp):
+        """
+        Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
+        precedence over the serverProtocolVersion defined in the OpenSessionResponse.
+        """
+        if (
+            openSessionResp.sessionHandle
+            and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion")
+            and openSessionResp.sessionHandle.serverProtocolVersion
+        ):
+            return openSessionResp.sessionHandle.serverProtocolVersion
+        return openSessionResp.serverProtocolVersion
+
+    @staticmethod
+    def server_parameterized_queries_enabled(protocolVersion):
+        if (
+            protocolVersion
+            and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8
+        ):
+            return True
+        else:
+            return False
+
+    def get_handle(self):
+        return self._handle
+
+    def get_id(self):
+        handle = self.get_handle()
+        if handle is None:
+            return None
+        return self.thrift_backend.handle_to_id(handle)
+
+    def get_id_hex(self):
+        handle = self.get_handle()
+        if handle is None:
+            return None
+        return self.thrift_backend.handle_to_hex_id(handle)
+
+    def close(self) -> None:
+        """Close the underlying session."""
+        logger.info(f"Closing session {self.get_id_hex()}")
+        if not self.is_open:
+            logger.debug("Session appears to have been closed already")
+            return
+
+        try:
+            self.thrift_backend.close_session(self.get_handle())
+        except RequestError as e:
+            if isinstance(e.args[1], SessionAlreadyClosedError):
+                logger.info("Session was closed by a prior request")
+        except DatabaseError as e:
+            if "Invalid SessionHandle" in str(e):
+                logger.warning(
+                    f"Attempted to close session that was already closed: {e}"
+                )
+            else:
+                logger.warning(
+                    f"Attempt to close session raised an exception at the server: {e}"
+                )
+        except Exception as e:
+            logger.error(f"Attempt to close session raised a local exception: {e}")
+
+        self.is_open = False
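
The precedence rule documented in `get_protocol_version` (a version on the session handle wins over the one on the response) can be exercised without the driver installed. The sketch below re-implements the two static helpers as free functions over `types.SimpleNamespace` objects. `PROTOCOL_V8` and the version numbers 7 and 9 are placeholder integers chosen for this example; the real threshold is `ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8`.

```python
from types import SimpleNamespace

# Placeholder integer; the real value comes from ttypes.TProtocolVersion.
PROTOCOL_V8 = 8


def get_protocol_version(open_session_resp):
    """Prefer the handle's version, else fall back to the response's."""
    handle = open_session_resp.sessionHandle
    # getattr with a default covers handles that lack the attribute.
    version = getattr(handle, "serverProtocolVersion", None) if handle else None
    return version or open_session_resp.serverProtocolVersion


def server_parameterized_queries_enabled(protocol_version):
    """True only for a known version at or above the V8 threshold."""
    return bool(protocol_version and protocol_version >= PROTOCOL_V8)


# Response whose handle carries its own (newer) version.
resp_with_handle_version = SimpleNamespace(
    sessionHandle=SimpleNamespace(serverProtocolVersion=9),
    serverProtocolVersion=7,
)
# Response whose handle has no version attribute at all.
resp_without_handle_version = SimpleNamespace(
    sessionHandle=SimpleNamespace(),
    serverProtocolVersion=7,
)

v_handle = get_protocol_version(resp_with_handle_version)
v_resp = get_protocol_version(resp_without_handle_version)
print(v_handle, server_parameterized_queries_enabled(v_handle))  # 9 True
print(v_resp, server_parameterized_queries_enabled(v_resp))      # 7 False
```

Because both helpers are static and take plain objects, they can be unit-tested like this without constructing a Session or opening a network connection.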
