From f97c81d28e03c09c407034992b364dc7e774a492 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 09:55:05 +0530 Subject: [PATCH 01/21] decouple session class from existing Connection ensure maintenance of current APIs of Connection while delegating responsibility Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 152 ++++++++++------------------------ src/databricks/sql/session.py | 146 ++++++++++++++++++++++++++++++++ 2 files changed, 190 insertions(+), 108 deletions(-) create mode 100644 src/databricks/sql/session.py diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index f24a6584a..0fbd20df5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -19,6 +19,8 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, + Error, + NotSupportedError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -45,6 +47,7 @@ from databricks.sql.types import Row, SSLOptions from databricks.sql.auth.auth import get_python_sql_connector_auth_provider from databricks.sql.experimental.oauth_persistence import OAuthPersistence +from databricks.sql.session import Session from databricks.sql.thrift_api.TCLIService.ttypes import ( TSparkParameter, @@ -218,66 +221,24 @@ def read(self) -> Optional[OAuthToken]: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} - self.open = False - self.host = server_hostname - self.port = kwargs.get("_port", 443) self.disable_pandas = kwargs.get("_disable_pandas", False) self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True) + self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) + self._cursors = [] # type: List[Cursor] - auth_provider = get_python_sql_connector_auth_provider( - server_hostname, **kwargs - ) - - user_agent_entry = kwargs.get("user_agent_entry") - if user_agent_entry is None: - user_agent_entry = kwargs.get("_user_agent_entry") - if user_agent_entry is not None: - logger.warning( - "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " - "This parameter will be removed in the upcoming releases." - ) - - if user_agent_entry: - useragent_header = "{}/{} ({})".format( - USER_AGENT_NAME, __version__, user_agent_entry - ) - else: - useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) - - base_headers = [("User-Agent", useragent_header)] - - self._ssl_options = SSLOptions( - # Double negation is generally a bad thing, but we have to keep backward compatibility - tls_verify=not kwargs.get( - "_tls_no_verify", False - ), # by default - verify cert and host - tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), - tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), - tls_client_cert_file=kwargs.get("_tls_client_cert_file"), - tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), - tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), - ) - - self.thrift_backend = ThriftBackend( - self.host, - self.port, + # Create the session + self.session = Session( + server_hostname, http_path, - (http_headers or []) + base_headers, - auth_provider, - ssl_options=self._ssl_options, - _use_arrow_native_complex_types=_use_arrow_native_complex_types, - **kwargs, - ) - - self._open_session_resp = self.thrift_backend.open_session( - session_configuration, catalog, schema + http_headers, + session_configuration, + catalog, + schema, + _use_arrow_native_complex_types, + **kwargs ) - self._session_handle = self._open_session_resp.sessionHandle - self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True) - self.open = True - logger.info("Successfully opened session " + str(self.get_session_id_hex())) - self._cursors = [] # type: List[Cursor] + + logger.info("Successfully opened connection with session " + str(self.get_session_id_hex())) self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) @@ -318,7 +279,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - if self.open: + if self.session.open: logger.debug( "Closing unclosed connection for session " "{}".format(self.get_session_id_hex()) @@ -330,34 +291,27 @@ def __del__(self): logger.debug("Couldn't close unclosed connection: {}".format(e.message)) def get_session_id(self): - return self.thrift_backend.handle_to_id(self._session_handle) + """Get the session ID from the Session object""" + return self.session.get_session_id() - @staticmethod - def get_protocol_version(openSessionResp): - """ - Since the sessionHandle will sometimes have a serverProtocolVersion, it takes - precedence over the serverProtocolVersion defined in the OpenSessionResponse. - """ - if ( - openSessionResp.sessionHandle - and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") - and openSessionResp.sessionHandle.serverProtocolVersion - ): - return openSessionResp.sessionHandle.serverProtocolVersion - return openSessionResp.serverProtocolVersion + def get_session_id_hex(self): + """Get the session ID in hex format from the Session object""" + return self.session.get_session_id_hex() @staticmethod def server_parameterized_queries_enabled(protocolVersion): - if ( - protocolVersion - and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 - ): - return True - else: - return False + """Delegate to Session class static method""" + return Session.server_parameterized_queries_enabled(protocolVersion) - def get_session_id_hex(self): - return self.thrift_backend.handle_to_hex_id(self._session_handle) + @property + def protocol_version(self): + """Get the protocol version from the Session object""" + return self.session.protocol_version + + @staticmethod + def get_protocol_version(openSessionResp): + """Delegate to Session class static method""" + return Session.get_protocol_version(openSessionResp) def cursor( self, @@ -369,12 +323,12 @@ def cursor( Will throw an Error if the connection has been closed. """ - if not self.open: + if not self.session.open: raise Error("Cannot create cursor from closed connection") cursor = Cursor( self, - self.thrift_backend, + self.session.thrift_backend, arraysize=arraysize, result_buffer_size_bytes=buffer_size_bytes, ) @@ -390,28 +344,10 @@ def _close(self, close_cursors=True) -> None: for cursor in self._cursors: cursor.close() - logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: - logger.debug("Session appears to have been closed already") - try: - self.thrift_backend.close_session(self._session_handle) - except RequestError as e: - if isinstance(e.args[1], SessionAlreadyClosedError): - logger.info("Session was closed by a prior request") - except DatabaseError as e: - if "Invalid SessionHandle" in str(e): - logger.warning( - f"Attempted to close session that was already closed: {e}" - ) - else: - logger.warning( - f"Attempt to close session raised an exception at the server: {e}" - ) + self.session.close() except Exception as e: - logger.error(f"Attempt to close session raised a local exception: {e}") - - self.open = False + logger.error(f"Attempt to close session raised an exception: {e}") def commit(self): """No-op because Databricks does not support transactions""" @@ -811,7 +747,7 @@ def execute( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -874,7 +810,7 @@ def execute_async( self._close_and_clear_active_result_set() self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -970,7 +906,7 @@ def catalogs(self) -> "Cursor": self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -996,7 +932,7 @@ def schemas( self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1029,7 +965,7 @@ def tables( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_tables( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1064,7 +1000,7 @@ def columns( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_columns( - session_handle=self.connection._session_handle, + session_handle=self.connection.session._session_handle, max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1493,7 +1429,7 @@ def close(self) -> None: if ( self.op_state != self.thrift_backend.CLOSED_OP_STATE and not self.has_been_closed_server_side - and self.connection.open + and self.connection.session.open ): self.thrift_backend.close_command(self.command_id) except RequestError as e: diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py new file mode 100644 index 000000000..4920550e7 --- /dev/null +++ b/src/databricks/sql/session.py @@ -0,0 +1,146 @@ +import logging +from typing import Dict, Tuple, List, Optional, Any + +from databricks.sql.thrift_api.TCLIService import ttypes +from databricks.sql.types import SSLOptions +from databricks.sql.auth.auth import get_python_sql_connector_auth_provider +from databricks.sql.exc import SessionAlreadyClosedError, DatabaseError, RequestError +from databricks.sql import __version__ +from databricks.sql import USER_AGENT_NAME +from databricks.sql.thrift_backend import ThriftBackend + +logger = logging.getLogger(__name__) + + +class Session: + def __init__( + self, + server_hostname: str, + http_path: str, + http_headers: Optional[List[Tuple[str, str]]] = None, + session_configuration: Optional[Dict[str, Any]] = None, + catalog: Optional[str] = None, + schema: Optional[str] = None, + _use_arrow_native_complex_types: Optional[bool] = True, + **kwargs, + ) -> None: + """ + Create a session to a Databricks SQL endpoint or a Databricks cluster. + + This class handles all session-related behavior and communication with the backend. + """ + self.open = False + self.host = server_hostname + self.port = kwargs.get("_port", 443) + + auth_provider = get_python_sql_connector_auth_provider( + server_hostname, **kwargs + ) + + user_agent_entry = kwargs.get("user_agent_entry") + if user_agent_entry is None: + user_agent_entry = kwargs.get("_user_agent_entry") + if user_agent_entry is not None: + logger.warning( + "[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. " + "This parameter will be removed in the upcoming releases." + ) + + if user_agent_entry: + useragent_header = "{}/{} ({})".format( + USER_AGENT_NAME, __version__, user_agent_entry + ) + else: + useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__) + + base_headers = [("User-Agent", useragent_header)] + + self._ssl_options = SSLOptions( + # Double negation is generally a bad thing, but we have to keep backward compatibility + tls_verify=not kwargs.get( + "_tls_no_verify", False + ), # by default - verify cert and host + tls_verify_hostname=kwargs.get("_tls_verify_hostname", True), + tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"), + tls_client_cert_file=kwargs.get("_tls_client_cert_file"), + tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"), + tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"), + ) + + self.thrift_backend = ThriftBackend( + self.host, + self.port, + http_path, + (http_headers or []) + base_headers, + auth_provider, + ssl_options=self._ssl_options, + _use_arrow_native_complex_types=_use_arrow_native_complex_types, + **kwargs, + ) + + self._open_session_resp = self.thrift_backend.open_session( + session_configuration, catalog, schema + ) + self._session_handle = self._open_session_resp.sessionHandle + self.protocol_version = self.get_protocol_version(self._open_session_resp) + self.open = True + logger.info("Successfully opened session " + str(self.get_session_id_hex())) + + @staticmethod + def get_protocol_version(openSessionResp): + """ + Since the sessionHandle will sometimes have a serverProtocolVersion, it takes + precedence over the serverProtocolVersion defined in the OpenSessionResponse. + """ + if ( + openSessionResp.sessionHandle + and hasattr(openSessionResp.sessionHandle, "serverProtocolVersion") + and openSessionResp.sessionHandle.serverProtocolVersion + ): + return openSessionResp.sessionHandle.serverProtocolVersion + return openSessionResp.serverProtocolVersion + + @staticmethod + def server_parameterized_queries_enabled(protocolVersion): + if ( + protocolVersion + and protocolVersion >= ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8 + ): + return True + else: + return False + + def get_session_handle(self): + return self._session_handle + + def get_session_id(self): + return self.thrift_backend.handle_to_id(self._session_handle) + + def get_session_id_hex(self): + return self.thrift_backend.handle_to_hex_id(self._session_handle) + + def close(self) -> None: + """Close the underlying session.""" + logger.info(f"Closing session {self.get_session_id_hex()}") + if not self.open: + logger.debug("Session appears to have been closed already") + return + + try: + self.thrift_backend.close_session(self._session_handle) + except RequestError as e: + if isinstance(e.args[1], SessionAlreadyClosedError): + logger.info("Session was closed by a prior request") + except DatabaseError as e: + if "Invalid SessionHandle" in str(e): + logger.warning( + f"Attempted to close session that was already closed: {e}" + ) + else: + logger.warning( + f"Attempt to close session raised an exception at the server: {e}" + ) + except Exception as e: + logger.error(f"Attempt to close session raised a local exception: {e}") + + self.open = False \ No newline at end of file From fe0af8777de92946fa87a5598ae3b3680209dd1e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:35:55 +0530 Subject: [PATCH 02/21] add open property to Connection to ensure maintenance of existing API Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 0fbd20df5..6c89ef0a1 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -313,6 +313,11 @@ def get_protocol_version(openSessionResp): """Delegate to Session class static method""" return Session.get_protocol_version(openSessionResp) + @property + def open(self) -> bool: + """Return whether the connection is open by checking if the session is open.""" + return self.session.open + def cursor( self, arraysize: int = DEFAULT_ARRAY_SIZE, From 18f8f67a985c96384bb18e4b4922b3891ca95d59 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:36:36 +0530 Subject: [PATCH 03/21] update unit tests to address ThriftBackend through session instead of through Connection Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 49 ++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index c39aeb524..58607cf49 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -80,7 +80,7 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_close_uses_the_correct_session_id(self, mock_client_class): instance = mock_client_class.return_value @@ -95,7 +95,7 @@ def test_close_uses_the_correct_session_id(self, mock_client_class): close_session_id = instance.close_session.call_args[0][0].sessionId self.assertEqual(close_session_id, b"\x22") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_auth_args(self, mock_client_class): # Test that the following auth args work: # token = foo, @@ -122,7 +122,7 @@ def test_auth_args(self, mock_client_class): self.assertEqual(args["http_path"], http_path) connection.close() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_http_header_passthrough(self, mock_client_class): http_headers = [("foo", "bar")] databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) @@ -130,7 +130,7 @@ def test_http_header_passthrough(self, mock_client_class): call_args = mock_client_class.call_args[0][3] self.assertIn(("foo", "bar"), call_args) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_tls_arg_passthrough(self, mock_client_class): databricks.sql.connect( **self.DUMMY_CONNECTION_ARGS, @@ -146,7 +146,7 @@ def test_tls_arg_passthrough(self, mock_client_class): self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_useragent_header(self, mock_client_class): databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) @@ -167,7 +167,7 @@ def test_useragent_header(self, mock_client_class): http_headers = mock_client_class.call_args[0][3] self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): # Test once with has_been_closed_server side, once without @@ -184,7 +184,7 @@ def test_closing_connection_closes_commands(self, mock_result_set_class): ) mock_result_set_class.return_value.close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) self.assertTrue(connection.open) @@ -194,7 +194,7 @@ def test_cant_open_cursor_on_closed_connection(self, mock_client_class): connection.cursor() self.assertIn("closed", str(cm.exception)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) @patch("%s.client.Cursor" % PACKAGE_NAME) def test_arraysize_buffer_size_passthrough( self, mock_cursor_class, mock_client_class @@ -214,7 +214,10 @@ def test_closing_result_set_with_closed_connection_soft_closes_commands(self): thrift_backend=mock_backend, execute_response=Mock(), ) - mock_connection.open = False + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = False + type(mock_connection).session = PropertyMock(return_value=mock_session) result_set.close() @@ -226,7 +229,11 @@ def test_closing_result_set_hard_closes_commands(self): mock_results_response.has_been_closed_server_side = False mock_connection = Mock() mock_thrift_backend = Mock() - mock_connection.open = True + # Setup session mock on the mock_connection + mock_session = Mock() + mock_session.open = True + type(mock_connection).session = PropertyMock(return_value=mock_session) + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) @@ -283,7 +290,7 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_context_manager_closes_connection(self, mock_client_class): instance = mock_client_class.return_value @@ -396,7 +403,7 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_max_number_of_retries_passthrough(self, mock_client_class): databricks.sql.connect( _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS @@ -406,7 +413,7 @@ def test_max_number_of_retries_passthrough(self, mock_client_class): mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_socket_timeout_passthrough(self, mock_client_class): databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) @@ -419,7 +426,7 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_configuration_passthrough(self, mock_client_class): mock_session_config = Mock() databricks.sql.connect( @@ -431,7 +438,7 @@ def test_configuration_passthrough(self, mock_client_class): mock_session_config, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_initial_namespace_passthrough(self, mock_client_class): mock_cat = Mock() mock_schem = Mock() @@ -505,7 +512,7 @@ def test_executemany_parameter_passhthrough_and_uses_last_result_set( "last operation", ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_commit_a_noop(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) c.commit() @@ -518,7 +525,7 @@ def test_setoutputsizes_a_noop(self): cursor = client.Cursor(Mock(), Mock()) cursor.setoutputsize(1) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_rollback_not_supported(self, mock_thrift_backend_class): c = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) with self.assertRaises(NotSupportedError): @@ -603,7 +610,7 @@ def test_column_name_api(self): }, ) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_finalizer_closes_abandoned_connection(self, mock_client_class): instance = mock_client_class.return_value @@ -620,7 +627,7 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): close_session_id = instance.close_session.call_args[0][0].sessionId self.assertEqual(close_session_id, b"\x22") - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value @@ -639,7 +646,7 @@ def test_cursor_keeps_connection_alive(self, mock_client_class): @patch("%s.utils.ExecuteResponse" % PACKAGE_NAME, autospec=True) @patch("%s.client.Cursor._handle_staging_operation" % PACKAGE_NAME) - @patch("%s.client.ThriftBackend" % PACKAGE_NAME) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_staging_operation_response_is_handled( self, mock_client_class, mock_handle_staging_operation, mock_execute_response ): @@ -658,7 +665,7 @@ def test_staging_operation_response_is_handled( mock_handle_staging_operation.call_count == 1 - @patch("%s.client.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) + @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) def test_access_current_query_id(self): operation_id = "EE6A8778-21FC-438B-92D8-96AC51EE3821" From fd8decb3100cff7ec0c9ec86e4a63195f0c008d3 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 10:43:24 +0530 Subject: [PATCH 04/21] chore: move session specific tests from test_client to test_session Signed-off-by: varun-edachali-dbx --- tests/unit/test_client.py | 161 ------------------------------- tests/unit/test_session.py | 187 +++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+), 161 deletions(-) create mode 100644 tests/unit/test_session.py diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 58607cf49..ecbf3493b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -80,93 +80,6 @@ class ClientTestSuite(unittest.TestCase): "access_token": "tok", } - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_close_uses_the_correct_session_id(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - connection.close() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_auth_args(self, mock_client_class): - # Test that the following auth args work: - # token = foo, - # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True - connection_args = [ - { - "server_hostname": "foo", - "http_path": None, - "access_token": "tok", - }, - { - "server_hostname": "foo", - "http_path": None, - "_tls_client_cert_file": "something", - "_use_cert_as_auth": True, - "access_token": None, - }, - ] - - for args in connection_args: - connection = databricks.sql.connect(**args) - host, port, http_path, *_ = mock_client_class.call_args[0] - self.assertEqual(args["server_hostname"], host) - self.assertEqual(args["http_path"], http_path) - connection.close() - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_http_header_passthrough(self, mock_client_class): - http_headers = [("foo", "bar")] - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) - - call_args = mock_client_class.call_args[0][3] - self.assertIn(("foo", "bar"), call_args) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_tls_arg_passthrough(self, mock_client_class): - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, - _tls_verify_hostname="hostname", - _tls_trusted_ca_file="trusted ca file", - _tls_client_cert_key_file="trusted client cert", - _tls_client_cert_key_password="key password", - ) - - kwargs = mock_client_class.call_args[1] - self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") - self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") - self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") - self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_useragent_header(self, mock_client_class): - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - http_headers = mock_client_class.call_args[0][3] - user_agent_header = ( - "User-Agent", - "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), - ) - self.assertIn(user_agent_header, http_headers) - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") - user_agent_header_with_entry = ( - "User-Agent", - "{}/{} ({})".format( - databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" - ), - ) - http_headers = mock_client_class.call_args[0][3] - self.assertIn(user_agent_header_with_entry, http_headers) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME, ThriftBackendMockFactory.new()) @patch("%s.client.ResultSet" % PACKAGE_NAME) def test_closing_connection_closes_commands(self, mock_result_set_class): @@ -290,21 +203,6 @@ def test_context_manager_closes_cursor(self): cursor.close = mock_close mock_close.assert_called_once_with() - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_context_manager_closes_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: - pass - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - def dict_product(self, dicts): """ Generate cartesion product of values in input dictionary, outputting a dictionary @@ -403,21 +301,6 @@ def test_cancel_command_will_issue_warning_for_cancel_with_no_executing_command( self.assertTrue(logger_instance.warning.called) self.assertFalse(mock_thrift_backend.cancel_command.called) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_max_number_of_retries_passthrough(self, mock_client_class): - databricks.sql.connect( - _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 - ) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_socket_timeout_passthrough(self, mock_client_class): - databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) - self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) - def test_version_is_canonical(self): version = databricks.sql.__version__ canonical_version_re = ( @@ -426,33 +309,6 @@ def test_version_is_canonical(self): ) self.assertIsNotNone(re.match(canonical_version_re, version)) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_configuration_passthrough(self, mock_client_class): - mock_session_config = Mock() - databricks.sql.connect( - session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS - ) - - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][0], - mock_session_config, - ) - - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_initial_namespace_passthrough(self, mock_client_class): - mock_cat = Mock() - mock_schem = Mock() - - databricks.sql.connect( - **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][1], mock_cat - ) - self.assertEqual( - mock_client_class.return_value.open_session.call_args[0][2], mock_schem - ) - def test_execute_parameter_passthrough(self): mock_thrift_backend = ThriftBackendMockFactory.new() cursor = client.Cursor(Mock(), mock_thrift_backend) @@ -610,23 +466,6 @@ def test_column_name_api(self): }, ) - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) - def test_finalizer_closes_abandoned_connection(self, mock_client_class): - instance = mock_client_class.return_value - - mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() - mock_open_session_resp.sessionHandle.sessionId = b"\x22" - instance.open_session.return_value = mock_open_session_resp - - databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) - - # not strictly necessary as the refcount is 0, but just to be sure - gc.collect() - - # Check the close session request has an id of x22 - close_session_id = instance.close_session.call_args[0][0].sessionId - self.assertEqual(close_session_id, b"\x22") - @patch("%s.session.ThriftBackend" % PACKAGE_NAME) def test_cursor_keeps_connection_alive(self, mock_client_class): instance = mock_client_class.return_value diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 000000000..6a49abef6 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,187 @@ +import unittest +from unittest.mock import patch, MagicMock, Mock, PropertyMock +import gc + +from databricks.sql.thrift_api.TCLIService.ttypes import ( + TOpenSessionResp, +) + +import databricks.sql + + +class SessionTestSuite(unittest.TestCase): + """ + Unit tests for Session functionality + """ + + PACKAGE_NAME = "databricks.sql" + DUMMY_CONNECTION_ARGS = { + "server_hostname": "foo", + "http_path": "dummy_path", + "access_token": "tok", + } + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_close_uses_the_correct_session_id(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + connection = databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + connection.close() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_auth_args(self, mock_client_class): + # Test that the following auth args work: + # token = foo, + # token = None, _tls_client_cert_file = something, _use_cert_as_auth = True + connection_args = [ + { + "server_hostname": "foo", + "http_path": None, + "access_token": "tok", + }, + { + "server_hostname": "foo", + "http_path": None, + "_tls_client_cert_file": "something", + "_use_cert_as_auth": True, + "access_token": None, + }, + ] + + for args in connection_args: + connection = databricks.sql.connect(**args) + host, port, http_path, *_ = mock_client_class.call_args[0] + self.assertEqual(args["server_hostname"], host) + self.assertEqual(args["http_path"], http_path) + connection.close() + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_http_header_passthrough(self, mock_client_class): + http_headers = [("foo", "bar")] + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, http_headers=http_headers) + + call_args = mock_client_class.call_args[0][3] + self.assertIn(("foo", "bar"), call_args) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_tls_arg_passthrough(self, mock_client_class): + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, + _tls_verify_hostname="hostname", + _tls_trusted_ca_file="trusted ca file", + _tls_client_cert_key_file="trusted client cert", + _tls_client_cert_key_password="key password", + ) + + kwargs = mock_client_class.call_args[1] + self.assertEqual(kwargs["_tls_verify_hostname"], "hostname") + self.assertEqual(kwargs["_tls_trusted_ca_file"], "trusted ca file") + self.assertEqual(kwargs["_tls_client_cert_key_file"], "trusted client cert") + self.assertEqual(kwargs["_tls_client_cert_key_password"], "key password") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_useragent_header(self, mock_client_class): + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + http_headers = mock_client_class.call_args[0][3] + user_agent_header = ( + "User-Agent", + "{}/{}".format(databricks.sql.USER_AGENT_NAME, databricks.sql.__version__), + ) + self.assertIn(user_agent_header, http_headers) + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar") + user_agent_header_with_entry = ( + "User-Agent", + "{}/{} ({})".format( + databricks.sql.USER_AGENT_NAME, databricks.sql.__version__, "foobar" + ), + ) + http_headers = mock_client_class.call_args[0][3] + self.assertIn(user_agent_header_with_entry, http_headers) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_context_manager_closes_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + with databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) as connection: + pass + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_max_number_of_retries_passthrough(self, mock_client_class): + databricks.sql.connect( + _retry_stop_after_attempts_count=54, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.call_args[1]["_retry_stop_after_attempts_count"], 54 + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_socket_timeout_passthrough(self, mock_client_class): + databricks.sql.connect(_socket_timeout=234, **self.DUMMY_CONNECTION_ARGS) + self.assertEqual(mock_client_class.call_args[1]["_socket_timeout"], 234) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_configuration_passthrough(self, mock_client_class): + mock_session_config = Mock() + databricks.sql.connect( + session_configuration=mock_session_config, **self.DUMMY_CONNECTION_ARGS + ) + + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][0], + mock_session_config, + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_initial_namespace_passthrough(self, mock_client_class): + mock_cat = Mock() + mock_schem = Mock() + + databricks.sql.connect( + **self.DUMMY_CONNECTION_ARGS, catalog=mock_cat, schema=mock_schem + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][1], mock_cat + ) + self.assertEqual( + mock_client_class.return_value.open_session.call_args[0][2], mock_schem + ) + + @patch("%s.session.ThriftBackend" % PACKAGE_NAME) + def test_finalizer_closes_abandoned_connection(self, mock_client_class): + instance = mock_client_class.return_value + + mock_open_session_resp = MagicMock(spec=TOpenSessionResp)() + mock_open_session_resp.sessionHandle.sessionId = b"\x22" + instance.open_session.return_value = mock_open_session_resp + + databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS) + + # not strictly necessary as the refcount is 0, but just to be sure + gc.collect() + + # Check the close session request has an id of x22 + close_session_id = instance.close_session.call_args[0][0].sessionId + self.assertEqual(close_session_id, b"\x22") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 1a92b7782910d8127baae36e671a821a453446ab Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 11:14:16 +0530 Subject: [PATCH 05/21] formatting (black) as in CONTRIBUTING.md Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 9 ++++++--- src/databricks/sql/session.py | 6 +++--- tests/unit/test_client.py | 2 +- tests/unit/test_session.py | 4 ++-- tests/unit/test_thrift_backend.py | 4 +++- 5 files changed, 15 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 6c89ef0a1..098aa0548 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -235,10 +235,13 @@ def read(self) -> Optional[OAuthToken]: catalog, schema, _use_arrow_native_complex_types, - **kwargs + **kwargs, + ) + + logger.info( + "Successfully opened connection with session " + + str(self.get_session_id_hex()) ) - - logger.info("Successfully opened connection with session " + str(self.get_session_id_hex())) self.use_inline_params = self._set_use_inline_params_with_warning( kwargs.get("use_inline_params", False) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 4920550e7..a308b71d5 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -26,13 +26,13 @@ def __init__( ) -> None: """ Create a session to a Databricks SQL endpoint or a Databricks cluster. - + This class handles all session-related behavior and communication with the backend. """ self.open = False self.host = server_hostname self.port = kwargs.get("_port", 443) - + auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs ) @@ -143,4 +143,4 @@ def close(self) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - self.open = False \ No newline at end of file + self.open = False diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index ecbf3493b..b67101943 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -146,7 +146,7 @@ def test_closing_result_set_hard_closes_commands(self): mock_session = Mock() mock_session.open = True type(mock_connection).session = PropertyMock(return_value=mock_session) - + result_set = client.ResultSet( mock_connection, mock_results_response, mock_thrift_backend ) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 6a49abef6..eb392a229 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -11,7 +11,7 @@ class SessionTestSuite(unittest.TestCase): """ - Unit tests for Session functionality + Unit tests for Session functionality """ PACKAGE_NAME = "databricks.sql" @@ -184,4 +184,4 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/tests/unit/test_thrift_backend.py b/tests/unit/test_thrift_backend.py index 7fe318446..458ea9a82 100644 --- a/tests/unit/test_thrift_backend.py +++ b/tests/unit/test_thrift_backend.py @@ -86,7 +86,9 @@ def test_make_request_checks_thrift_status_code(self): def _make_type_desc(self, type): return ttypes.TTypeDesc( - types=[ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type))] + types=[ + ttypes.TTypeEntry(primitiveEntry=ttypes.TPrimitiveTypeEntry(type=type)) + ] ) def _make_fake_thrift_backend(self): From 1b9a50acc0631df22d8e8cd04f2d5021f2505dbb Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 16:18:44 +0530 Subject: [PATCH 06/21] use connection open property instead of long chain through session Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 098aa0548..54a097641 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -282,7 +282,7 @@ def __exit__(self, exc_type, exc_value, traceback): self.close() def __del__(self): - if self.session.open: + if self.open: logger.debug( "Closing unclosed connection for session " "{}".format(self.get_session_id_hex()) @@ -331,7 +331,7 @@ def cursor( Will throw an Error if the connection has been closed. """ - if not self.session.open: + if not self.open: raise Error("Cannot create cursor from closed connection") cursor = Cursor( @@ -1437,7 +1437,7 @@ def close(self) -> None: if ( self.op_state != self.thrift_backend.CLOSED_OP_STATE and not self.has_been_closed_server_side - and self.connection.session.open + and self.connection.open ): self.thrift_backend.close_command(self.command_id) except RequestError as e: From 0bf27940240b02b80838ba342f3bc3255229216a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Tue, 20 May 2025 16:47:02 +0530 Subject: [PATCH 07/21] trigger integration workflow Signed-off-by: varun-edachali-dbx From ff35165b28f53988c4077fbf4aaff5eb8bce68c5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Wed, 21 May 2025 10:36:31 +0530 Subject: [PATCH 08/21] fix: ensure open attribute of Connection never fails in case the openSession takes long, the initialisation of the session will not complete immediately. This could make the session attribute inaccessible. If the Connection is deleted in this time, the open() check will throw because the session attribute does not exist. Thus, we default to the Connection being closed in this case. This was not an issue before because open was a direct attribute of the Connection class. Caught in the integration tests. Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 54a097641..4ddc5069b 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -19,8 +19,6 @@ OperationalError, SessionAlreadyClosedError, CursorAlreadyClosedError, - Error, - NotSupportedError, ) from databricks.sql.thrift_api.TCLIService import ttypes from databricks.sql.thrift_backend import ThriftBackend @@ -319,7 +317,9 @@ def get_protocol_version(openSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return self.session.open + # NOTE: we have to check for the existence of session in case the __del__ is called + # before the session is instantiated + return hasattr(self, "session") and self.session.open def cursor( self, From 0df486ae08e9b059f6c3f12e08ae48c4f5a23a3e Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 04:05:12 +0000 Subject: [PATCH 09/21] fix: de-complicate earlier connection open logic earlier, one of the integration tests was failing because 'session was not an attribute of Connection'. This is likely tied to a local configuration issue related to unittest that was causing an error in the test suite itself. The tests are now passing without checking for the session attribute. https://github.com/databricks/databricks-sql-python/pull/567/commits/c676f9b0281cc3e4fe9c6d8216cc62fc75eade3b Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 4 +--- src/databricks/sql/session.py | 9 ++++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4ddc5069b..c227fdf19 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -317,9 +317,7 @@ def get_protocol_version(openSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - # NOTE: we have to check for the existence of session in case the __del__ is called - # before the session is instantiated - return hasattr(self, "session") and self.session.open + return self.session.is_open def cursor( self, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index a308b71d5..6beb694eb 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -29,7 +29,7 @@ def __init__( This class handles all session-related behavior and communication with the backend. """ - self.open = False + self.is_open = False self.host = server_hostname self.port = kwargs.get("_port", 443) @@ -77,13 +77,12 @@ def __init__( _use_arrow_native_complex_types=_use_arrow_native_complex_types, **kwargs, ) - self._open_session_resp = self.thrift_backend.open_session( session_configuration, catalog, schema ) self._session_handle = self._open_session_resp.sessionHandle self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.open = True + self.is_open = True logger.info("Successfully opened session " + str(self.get_session_id_hex())) @staticmethod @@ -122,7 +121,7 @@ def get_session_id_hex(self): def close(self) -> None: """Close the underlying session.""" logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: + if not self.is_open: logger.debug("Session appears to have been closed already") return @@ -143,4 +142,4 @@ def close(self) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - self.open = False + self.is_open = False From 63b10c3bdefd876d077c8d6c8bf4e9b1809412a7 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 04:48:59 +0000 Subject: [PATCH 10/21] Revert "fix: de-complicate earlier connection open logic" This reverts commit d6b1b196c98a6e9d8e593a88c34bbde010519ef4. Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 4 +++- src/databricks/sql/session.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c227fdf19..4ddc5069b 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -317,7 +317,9 @@ def get_protocol_version(openSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - return self.session.is_open + # NOTE: we have to check for the existence of session in case the __del__ is called + # before the session is instantiated + return hasattr(self, "session") and self.session.open def cursor( self, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 6beb694eb..a308b71d5 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -29,7 +29,7 @@ def __init__( This class handles all session-related behavior and communication with the backend. """ - self.is_open = False + self.open = False self.host = server_hostname self.port = kwargs.get("_port", 443) @@ -77,12 +77,13 @@ def __init__( _use_arrow_native_complex_types=_use_arrow_native_complex_types, **kwargs, ) + self._open_session_resp = self.thrift_backend.open_session( session_configuration, catalog, schema ) self._session_handle = self._open_session_resp.sessionHandle self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.is_open = True + self.open = True logger.info("Successfully opened session " + str(self.get_session_id_hex())) @staticmethod @@ -121,7 +122,7 @@ def get_session_id_hex(self): def close(self) -> None: """Close the underlying session.""" logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.is_open: + if not self.open: logger.debug("Session appears to have been closed already") return @@ -142,4 +143,4 @@ def close(self) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - self.is_open = False + self.open = False From f2b3fd5f327be4a3f8a85bf8ccb1f31d1e02b745 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 05:00:46 +0000 Subject: [PATCH 11/21] [empty commit] attempt to trigger ci e2e workflow Signed-off-by: varun-edachali-dbx From 53f16ab627187224760750da57709f4b7130133a Mon Sep 17 00:00:00 2001 From: Jothi Prakash Date: Wed, 21 May 2025 11:57:04 +0530 Subject: [PATCH 12/21] Update CODEOWNERS (#562) new codeowners Signed-off-by: varun-edachali-dbx --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0d074c07b..11d5aeb0a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,4 +2,4 @@ # the repo. Unless a later match takes precedence, these # users will be requested for review when someone opens a # pull request. -* @deeksha-db @samikshya-db @jprakash-db @yunbodeng-db @jackyhu-db @benc-db +* @deeksha-db @samikshya-db @jprakash-db @jackyhu-db @madhav-db @gopalldb @jayantsing-db @vikrantpuppala @shivam2680 From a026751b2d7189e18ef538ebc98f015fd27e4771 Mon Sep 17 00:00:00 2001 From: Madhav Sainanee Date: Wed, 21 May 2025 14:28:10 +0530 Subject: [PATCH 13/21] Enhance Cursor close handling and context manager exception management to prevent server side resource leaks (#554) * Enhance Cursor close handling and context manager exception management * tests * fmt * Fix Cursor.close() to properly handle CursorAlreadyClosedError * Remove specific test message from Cursor.close() error handling * Improve error handling in connection and cursor context managers to ensure proper closure during exceptions, including KeyboardInterrupt. Add tests for nested cursor management and verify operation closure on server-side errors. * add * add Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 33 ++++++++++- tests/e2e/test_driver.py | 100 ++++++++++++++++++++++++++++++- tests/unit/test_client.py | 111 +++++++++++++++++++++++++++++++++++ 3 files changed, 238 insertions(+), 6 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4ddc5069b..9f095d573 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -277,7 +277,13 @@ def __enter__(self) -> "Connection": return self def __exit__(self, exc_type, exc_value, traceback): - self.close() + try: + self.close() + except BaseException as e: + logger.warning(f"Exception during connection close in __exit__: {e}") + if exc_type is None: + raise + return False def __del__(self): if self.open: @@ -400,7 +406,14 @@ def __enter__(self) -> "Cursor": return self def __exit__(self, exc_type, exc_value, traceback): - self.close() + try: + logger.debug("Cursor context manager exiting, calling close()") + self.close() + except BaseException as e: + logger.warning(f"Exception during cursor close in __exit__: {e}") + if exc_type is None: + raise + return False def __iter__(self): if self.active_result_set: @@ -1107,7 +1120,21 @@ def cancel(self) -> None: def close(self) -> None: """Close cursor""" self.open = False - self.active_op_handle = None + + # Close active operation handle if it exists + if self.active_op_handle: + try: + self.thrift_backend.close_command(self.active_op_handle) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + else: + logging.warning(f"Error closing operation handle: {e}") + except Exception as e: + logging.warning(f"Error closing operation handle: {e}") + finally: + self.active_op_handle = None + if self.active_result_set: self._close_and_clear_active_result_set() diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index cfd561400..440d4efb3 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -50,7 +50,7 @@ from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin -from databricks.sql.exc import SessionAlreadyClosedError +from databricks.sql.exc import SessionAlreadyClosedError, CursorAlreadyClosedError log = logging.getLogger(__name__) @@ -820,7 +820,6 @@ def test_close_connection_closes_cursors(self): ars = cursor.active_result_set # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True - # Cursor op state should be open before connection is closed status_request = ttypes.TGetOperationStatusReq( operationHandle=ars.command_id, getProgressUpdate=False @@ -847,9 +846,104 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog): with self.connection() as conn: # First .close() call is explicit here conn.close() - assert "Session appears to have been closed already" in caplog.text + conn = None + try: + with pytest.raises(KeyboardInterrupt): + with self.connection() as c: + conn = c + raise KeyboardInterrupt("Simulated interrupt") + finally: + if conn is not None: + assert not conn.open, "Connection should be closed after KeyboardInterrupt" + + def test_cursor_close_properly_closes_operation(self): + """Test that Cursor.close() properly closes the active operation handle on the server.""" + with self.connection() as conn: + cursor = conn.cursor() + try: + cursor.execute("SELECT 1 AS test") + assert cursor.active_op_handle is not None + cursor.close() + assert cursor.active_op_handle is None + assert not cursor.open + finally: + if cursor.open: + cursor.close() + + conn = None + cursor = None + try: + with self.connection() as c: + conn = c + with pytest.raises(KeyboardInterrupt): + with conn.cursor() as cur: + cursor = cur + raise KeyboardInterrupt("Simulated interrupt") + finally: + if cursor is not None: + assert not cursor.open, "Cursor should be closed after KeyboardInterrupt" + + def test_nested_cursor_context_managers(self): + """Test that nested cursor context managers properly close operations on the server.""" + with self.connection() as conn: + with conn.cursor() as cursor1: + cursor1.execute("SELECT 1 AS test1") + assert cursor1.active_op_handle is not None + + with conn.cursor() as cursor2: + cursor2.execute("SELECT 2 AS test2") + assert cursor2.active_op_handle is not None + + # After inner context manager exit, cursor2 should be not open + assert not cursor2.open + assert cursor2.active_op_handle is None + + # After outer context manager exit, cursor1 should be not open + assert not cursor1.open + assert cursor1.active_op_handle is None + + def test_cursor_error_handling(self): + """Test that cursor close handles errors properly to prevent orphaned operations.""" + with self.connection() as conn: + cursor = conn.cursor() + + cursor.execute("SELECT 1 AS test") + + op_handle = cursor.active_op_handle + + assert op_handle is not None + + # Manually close the operation to simulate server-side closure + conn.thrift_backend.close_command(op_handle) + + cursor.close() + + assert not cursor.open + + def test_result_set_close(self): + """Test that ResultSet.close() properly closes operations on the server and handles state correctly.""" + with self.connection() as conn: + cursor = conn.cursor() + try: + cursor.execute("SELECT * FROM RANGE(10)") + + result_set = cursor.active_result_set + assert result_set is not None + + initial_op_state = result_set.op_state + + result_set.close() + + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state != initial_op_state + + # Closing the result set again should be a no-op and not raise exceptions + result_set.close() + finally: + cursor.close() + # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep # the 429/503 subsuites separate since they execute under different circumstances. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index b67101943..30b1fd96a 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -20,6 +20,7 @@ import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row from tests.unit.test_fetches import FetchTests @@ -522,6 +523,116 @@ def test_access_current_query_id(self): cursor.close() self.assertIsNone(cursor.query_id) + def test_cursor_close_handles_exception(self): + """Test that Cursor.close() handles exceptions from close_command properly.""" + mock_backend = Mock() + mock_connection = Mock() + mock_op_handle = Mock() + + mock_backend.close_command.side_effect = Exception("Test error") + + cursor = client.Cursor(mock_connection, mock_backend) + cursor.active_op_handle = mock_op_handle + + cursor.close() + + mock_backend.close_command.assert_called_once_with(mock_op_handle) + + self.assertIsNone(cursor.active_op_handle) + + self.assertFalse(cursor.open) + + def test_cursor_context_manager_handles_exit_exception(self): + """Test that cursor's context manager handles exceptions during __exit__.""" + mock_backend = Mock() + mock_connection = Mock() + + cursor = client.Cursor(mock_connection, mock_backend) + original_close = cursor.close + cursor.close = Mock(side_effect=Exception("Test error during close")) + + try: + with cursor: + raise ValueError("Test error inside context") + except ValueError: + pass + + cursor.close.assert_called_once() + + def test_connection_close_handles_cursor_close_exception(self): + """Test that _close handles exceptions from cursor.close() properly.""" + cursors_closed = [] + + def mock_close_with_exception(): + cursors_closed.append(1) + raise Exception("Test error during close") + + cursor1 = Mock() + cursor1.close = mock_close_with_exception + + def mock_close_normal(): + cursors_closed.append(2) + + cursor2 = Mock() + cursor2.close = mock_close_normal + + mock_backend = Mock() + mock_session_handle = Mock() + + try: + for cursor in [cursor1, cursor2]: + try: + cursor.close() + except Exception: + pass + + mock_backend.close_session(mock_session_handle) + except Exception as e: + self.fail(f"Connection close should handle exceptions: {e}") + + self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") + + def test_resultset_close_handles_cursor_already_closed_error(self): + """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" + result_set = client.ResultSet.__new__(client.ResultSet) + result_set.thrift_backend = Mock() + result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED' + result_set.connection = Mock() + result_set.connection.open = True + result_set.op_state = 'RUNNING' + result_set.has_been_closed_server_side = False + result_set.command_id = Mock() + + class MockRequestError(Exception): + def __init__(self): + self.args = ["Error message", CursorAlreadyClosedError()] + + result_set.thrift_backend.close_command.side_effect = MockRequestError() + + original_close = client.ResultSet.close + try: + try: + if ( + result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE + and not result_set.has_been_closed_server_side + and result_set.connection.open + ): + result_set.thrift_backend.close_command(result_set.command_id) + except MockRequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + pass + finally: + result_set.has_been_closed_server_side = True + result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE + + result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id) + + assert result_set.has_been_closed_server_side is True + + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + finally: + pass + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) From 0d6995c988360e4f47319f85d65ca57588c5a134 Mon Sep 17 00:00:00 2001 From: Sai Shree Pradhan Date: Fri, 23 May 2025 00:31:40 +0530 Subject: [PATCH 14/21] PECOBLR-86 improve logging on python driver (#556) * PECOBLR-86 Improve logging for debug level Signed-off-by: Sai Shree Pradhan * PECOBLR-86 Improve logging for debug level Signed-off-by: Sai Shree Pradhan * fixed format Signed-off-by: Sai Shree Pradhan * used lazy logging Signed-off-by: Sai Shree Pradhan * changed debug to error logs Signed-off-by: Sai Shree Pradhan * used lazy logging Signed-off-by: Sai Shree Pradhan --------- Signed-off-by: Sai Shree Pradhan Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 9 +++++++++ src/databricks/sql/thrift_backend.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 9f095d573..82c74a1d5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -215,6 +215,12 @@ def read(self) -> Optional[OAuthToken]: # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage + logger.debug( + "Connection.__init__(server_hostname=%s, http_path=%s)", + server_hostname, + http_path, + ) + if access_token: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} @@ -744,6 +750,9 @@ def execute( :returns self """ + logger.debug( + "Cursor.execute(operation=%s, parameters=%s)", operation, parameters + ) param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 2e3478d77..e3dc38ad5 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -131,6 +131,13 @@ def __init__( # max_download_threads # Number of threads for handling cloud fetch downloads. Defaults to 10 + logger.debug( + "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)", + server_hostname, + port, + http_path, + ) + port = port or 443 if kwargs.get("_connection_uri"): uri = kwargs.get("_connection_uri") @@ -390,6 +397,8 @@ def attempt_request(attempt): # TODO: don't use exception handling for GOS polling... + logger.error("ThriftBackend.attempt_request: HTTPError: %s", err) + gos_name = TCLIServiceClient.GetOperationStatus.__name__ if method.__name__ == gos_name: delay_default = ( @@ -434,6 +443,7 @@ def attempt_request(attempt): else: logger.warning(log_string) except Exception as err: + logger.error("ThriftBackend.attempt_request: Exception: %s", err) error = err retry_delay = extract_retry_delay(attempt) error_message = ThriftBackend._extract_error_message_from_headers( @@ -888,6 +898,12 @@ def execute_command( ): assert session_handle is not None + logger.debug( + "ThriftBackend.execute_command(operation=%s, session_handle=%s)", + operation, + session_handle, + ) + spark_arrow_types = ttypes.TSparkArrowTypes( timestampAsArrow=self._use_arrow_native_timestamps, decimalAsArrow=self._use_arrow_native_decimals, @@ -1074,6 +1090,7 @@ def fetch_results( return queue, resp.hasMoreRows def close_command(self, op_handle): + logger.debug("ThriftBackend.close_command(op_handle=%s)", op_handle) req = ttypes.TCloseOperationReq(operationHandle=op_handle) resp = self.make_request(self._client.CloseOperation, req) return resp.status From 923bbb62d2ff12e6592e925f4c75119707e4a4cc Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 06:01:45 +0000 Subject: [PATCH 15/21] Revert "Merge remote-tracking branch 'upstream/sea-migration' into decouple-session" This reverts commit dbb2ec52306b91072a2ee842270c7113aece9aff, reversing changes made to 7192f117279d4f0adcbafcdf2238c18663324515. Signed-off-by: varun-edachali-dbx --- .github/CODEOWNERS | 2 +- src/databricks/sql/client.py | 42 +--------- src/databricks/sql/thrift_backend.py | 17 ---- tests/e2e/test_driver.py | 100 +----------------------- tests/unit/test_client.py | 111 --------------------------- 5 files changed, 7 insertions(+), 265 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 11d5aeb0a..0d074c07b 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,4 +2,4 @@ # the repo. Unless a later match takes precedence, these # users will be requested for review when someone opens a # pull request. -* @deeksha-db @samikshya-db @jprakash-db @jackyhu-db @madhav-db @gopalldb @jayantsing-db @vikrantpuppala @shivam2680 +* @deeksha-db @samikshya-db @jprakash-db @yunbodeng-db @jackyhu-db @benc-db diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 82c74a1d5..4ddc5069b 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -215,12 +215,6 @@ def read(self) -> Optional[OAuthToken]: # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage - logger.debug( - "Connection.__init__(server_hostname=%s, http_path=%s)", - server_hostname, - http_path, - ) - if access_token: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} @@ -283,13 +277,7 @@ def __enter__(self) -> "Connection": return self def __exit__(self, exc_type, exc_value, traceback): - try: - self.close() - except BaseException as e: - logger.warning(f"Exception during connection close in __exit__: {e}") - if exc_type is None: - raise - return False + self.close() def __del__(self): if self.open: @@ -412,14 +400,7 @@ def __enter__(self) -> "Cursor": return self def __exit__(self, exc_type, exc_value, traceback): - try: - logger.debug("Cursor context manager exiting, calling close()") - self.close() - except BaseException as e: - logger.warning(f"Exception during cursor close in __exit__: {e}") - if exc_type is None: - raise - return False + self.close() def __iter__(self): if self.active_result_set: @@ -750,9 +731,6 @@ def execute( :returns self """ - logger.debug( - "Cursor.execute(operation=%s, parameters=%s)", operation, parameters - ) param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: @@ -1129,21 +1107,7 @@ def cancel(self) -> None: def close(self) -> None: """Close cursor""" self.open = False - - # Close active operation handle if it exists - if self.active_op_handle: - try: - self.thrift_backend.close_command(self.active_op_handle) - except RequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - logger.info("Operation was canceled by a prior request") - else: - logging.warning(f"Error closing operation handle: {e}") - except Exception as e: - logging.warning(f"Error closing operation handle: {e}") - finally: - self.active_op_handle = None - + self.active_op_handle = None if self.active_result_set: self._close_and_clear_active_result_set() diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index e3dc38ad5..2e3478d77 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -131,13 +131,6 @@ def __init__( # max_download_threads # Number of threads for handling cloud fetch downloads. Defaults to 10 - logger.debug( - "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)", - server_hostname, - port, - http_path, - ) - port = port or 443 if kwargs.get("_connection_uri"): uri = kwargs.get("_connection_uri") @@ -397,8 +390,6 @@ def attempt_request(attempt): # TODO: don't use exception handling for GOS polling... - logger.error("ThriftBackend.attempt_request: HTTPError: %s", err) - gos_name = TCLIServiceClient.GetOperationStatus.__name__ if method.__name__ == gos_name: delay_default = ( @@ -443,7 +434,6 @@ def attempt_request(attempt): else: logger.warning(log_string) except Exception as err: - logger.error("ThriftBackend.attempt_request: Exception: %s", err) error = err retry_delay = extract_retry_delay(attempt) error_message = ThriftBackend._extract_error_message_from_headers( @@ -898,12 +888,6 @@ def execute_command( ): assert session_handle is not None - logger.debug( - "ThriftBackend.execute_command(operation=%s, session_handle=%s)", - operation, - session_handle, - ) - spark_arrow_types = ttypes.TSparkArrowTypes( timestampAsArrow=self._use_arrow_native_timestamps, decimalAsArrow=self._use_arrow_native_decimals, @@ -1090,7 +1074,6 @@ def fetch_results( return queue, resp.hasMoreRows def close_command(self, op_handle): - logger.debug("ThriftBackend.close_command(op_handle=%s)", op_handle) req = ttypes.TCloseOperationReq(operationHandle=op_handle) resp = self.make_request(self._client.CloseOperation, req) return resp.status diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 440d4efb3..cfd561400 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -50,7 +50,7 @@ from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin -from databricks.sql.exc import SessionAlreadyClosedError, CursorAlreadyClosedError +from databricks.sql.exc import SessionAlreadyClosedError log = logging.getLogger(__name__) @@ -820,6 +820,7 @@ def test_close_connection_closes_cursors(self): ars = cursor.active_result_set # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True + # Cursor op state should be open before connection is closed status_request = ttypes.TGetOperationStatusReq( operationHandle=ars.command_id, getProgressUpdate=False @@ -846,103 +847,8 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog): with self.connection() as conn: # First .close() call is explicit here conn.close() - assert "Session appears to have been closed already" in caplog.text - - conn = None - try: - with pytest.raises(KeyboardInterrupt): - with self.connection() as c: - conn = c - raise KeyboardInterrupt("Simulated interrupt") - finally: - if conn is not None: - assert not conn.open, "Connection should be closed after KeyboardInterrupt" - - def test_cursor_close_properly_closes_operation(self): - """Test that Cursor.close() properly closes the active operation handle on the server.""" - with self.connection() as conn: - cursor = conn.cursor() - try: - cursor.execute("SELECT 1 AS test") - assert cursor.active_op_handle is not None - cursor.close() - assert cursor.active_op_handle is None - assert not cursor.open - finally: - if cursor.open: - cursor.close() - - conn = None - cursor = None - try: - with self.connection() as c: - conn = c - with pytest.raises(KeyboardInterrupt): - with conn.cursor() as cur: - cursor = cur - raise KeyboardInterrupt("Simulated interrupt") - finally: - if cursor is not None: - assert not cursor.open, "Cursor should be closed after KeyboardInterrupt" - - def test_nested_cursor_context_managers(self): - """Test that nested cursor context managers properly close operations on the server.""" - with self.connection() as conn: - with conn.cursor() as cursor1: - cursor1.execute("SELECT 1 AS test1") - assert cursor1.active_op_handle is not None - - with conn.cursor() as cursor2: - cursor2.execute("SELECT 2 AS test2") - assert cursor2.active_op_handle is not None - - # After inner context manager exit, cursor2 should be not open - assert not cursor2.open - assert cursor2.active_op_handle is None - # After outer context manager exit, cursor1 should be not open - assert not cursor1.open - assert cursor1.active_op_handle is None - - def test_cursor_error_handling(self): - """Test that cursor close handles errors properly to prevent orphaned operations.""" - with self.connection() as conn: - cursor = conn.cursor() - - cursor.execute("SELECT 1 AS test") - - op_handle = cursor.active_op_handle - - assert op_handle is not None - - # Manually close the operation to simulate server-side closure - conn.thrift_backend.close_command(op_handle) - - cursor.close() - - assert not cursor.open - - def test_result_set_close(self): - """Test that ResultSet.close() properly closes operations on the server and handles state correctly.""" - with self.connection() as conn: - cursor = conn.cursor() - try: - cursor.execute("SELECT * FROM RANGE(10)") - - result_set = cursor.active_result_set - assert result_set is not None - - initial_op_state = result_set.op_state - - result_set.close() - - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE - assert result_set.op_state != initial_op_state - - # Closing the result set again should be a no-op and not raise exceptions - result_set.close() - finally: - cursor.close() + assert "Session appears to have been closed already" in caplog.text # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 30b1fd96a..b67101943 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -20,7 +20,6 @@ import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError -from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row from tests.unit.test_fetches import FetchTests @@ -523,116 +522,6 @@ def test_access_current_query_id(self): cursor.close() self.assertIsNone(cursor.query_id) - def test_cursor_close_handles_exception(self): - """Test that Cursor.close() handles exceptions from close_command properly.""" - mock_backend = Mock() - mock_connection = Mock() - mock_op_handle = Mock() - - mock_backend.close_command.side_effect = Exception("Test error") - - cursor = client.Cursor(mock_connection, mock_backend) - cursor.active_op_handle = mock_op_handle - - cursor.close() - - mock_backend.close_command.assert_called_once_with(mock_op_handle) - - self.assertIsNone(cursor.active_op_handle) - - self.assertFalse(cursor.open) - - def test_cursor_context_manager_handles_exit_exception(self): - """Test that cursor's context manager handles exceptions during __exit__.""" - mock_backend = Mock() - mock_connection = Mock() - - cursor = client.Cursor(mock_connection, mock_backend) - original_close = cursor.close - cursor.close = Mock(side_effect=Exception("Test error during close")) - - try: - with cursor: - raise ValueError("Test error inside context") - except ValueError: - pass - - cursor.close.assert_called_once() - - def test_connection_close_handles_cursor_close_exception(self): - """Test that _close handles exceptions from cursor.close() properly.""" - cursors_closed = [] - - def mock_close_with_exception(): - cursors_closed.append(1) - raise Exception("Test error during close") - - cursor1 = Mock() - cursor1.close = mock_close_with_exception - - def mock_close_normal(): - cursors_closed.append(2) - - cursor2 = Mock() - cursor2.close = mock_close_normal - - mock_backend = Mock() - mock_session_handle = Mock() - - try: - for cursor in [cursor1, cursor2]: - try: - cursor.close() - except Exception: - pass - - mock_backend.close_session(mock_session_handle) - except Exception as e: - self.fail(f"Connection close should handle exceptions: {e}") - - self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") - - def test_resultset_close_handles_cursor_already_closed_error(self): - """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" - result_set = client.ResultSet.__new__(client.ResultSet) - result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED' - result_set.connection = Mock() - result_set.connection.open = True - result_set.op_state = 'RUNNING' - result_set.has_been_closed_server_side = False - result_set.command_id = Mock() - - class MockRequestError(Exception): - def __init__(self): - self.args = ["Error message", CursorAlreadyClosedError()] - - result_set.thrift_backend.close_command.side_effect = MockRequestError() - - original_close = client.ResultSet.close - try: - try: - if ( - result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE - and not result_set.has_been_closed_server_side - and result_set.connection.open - ): - result_set.thrift_backend.close_command(result_set.command_id) - except MockRequestError as e: - if isinstance(e.args[1], CursorAlreadyClosedError): - pass - finally: - result_set.has_been_closed_server_side = True - result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE - - result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id) - - assert result_set.has_been_closed_server_side is True - - assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE - finally: - pass - if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) From 8df8c3382d1147e7d4bf464fb3821e2a85979198 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 06:11:05 +0000 Subject: [PATCH 16/21] Reapply "Merge remote-tracking branch 'upstream/sea-migration' into decouple-session" This reverts commit bdb83817f49e1d88a01679b11da8e55e8e80b42f. Signed-off-by: varun-edachali-dbx --- .github/CODEOWNERS | 2 +- src/databricks/sql/client.py | 42 +++++++++- src/databricks/sql/thrift_backend.py | 17 ++++ tests/e2e/test_driver.py | 100 +++++++++++++++++++++++- tests/unit/test_client.py | 111 +++++++++++++++++++++++++++ 5 files changed, 265 insertions(+), 7 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 0d074c07b..11d5aeb0a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,4 +2,4 @@ # the repo. Unless a later match takes precedence, these # users will be requested for review when someone opens a # pull request. -* @deeksha-db @samikshya-db @jprakash-db @yunbodeng-db @jackyhu-db @benc-db +* @deeksha-db @samikshya-db @jprakash-db @jackyhu-db @madhav-db @gopalldb @jayantsing-db @vikrantpuppala @shivam2680 diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4ddc5069b..82c74a1d5 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -215,6 +215,12 @@ def read(self) -> Optional[OAuthToken]: # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage + logger.debug( + "Connection.__init__(server_hostname=%s, http_path=%s)", + server_hostname, + http_path, + ) + if access_token: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} @@ -277,7 +283,13 @@ def __enter__(self) -> "Connection": return self def __exit__(self, exc_type, exc_value, traceback): - self.close() + try: + self.close() + except BaseException as e: + logger.warning(f"Exception during connection close in __exit__: {e}") + if exc_type is None: + raise + return False def __del__(self): if self.open: @@ -400,7 +412,14 @@ def __enter__(self) -> "Cursor": return self def __exit__(self, exc_type, exc_value, traceback): - self.close() + try: + logger.debug("Cursor context manager exiting, calling close()") + self.close() + except BaseException as e: + logger.warning(f"Exception during cursor close in __exit__: {e}") + if exc_type is None: + raise + return False def __iter__(self): if self.active_result_set: @@ -731,6 +750,9 @@ def execute( :returns self """ + logger.debug( + "Cursor.execute(operation=%s, parameters=%s)", operation, parameters + ) param_approach = self._determine_parameter_approach(parameters) if param_approach == ParameterApproach.NONE: @@ -1107,7 +1129,21 @@ def cancel(self) -> None: def close(self) -> None: """Close cursor""" self.open = False - self.active_op_handle = None + + # Close active operation handle if it exists + if self.active_op_handle: + try: + self.thrift_backend.close_command(self.active_op_handle) + except RequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + logger.info("Operation was canceled by a prior request") + else: + logging.warning(f"Error closing operation handle: {e}") + except Exception as e: + logging.warning(f"Error closing operation handle: {e}") + finally: + self.active_op_handle = None + if self.active_result_set: self._close_and_clear_active_result_set() diff --git a/src/databricks/sql/thrift_backend.py b/src/databricks/sql/thrift_backend.py index 2e3478d77..e3dc38ad5 100644 --- a/src/databricks/sql/thrift_backend.py +++ b/src/databricks/sql/thrift_backend.py @@ -131,6 +131,13 @@ def __init__( # max_download_threads # Number of threads for handling cloud fetch downloads. Defaults to 10 + logger.debug( + "ThriftBackend.__init__(server_hostname=%s, port=%s, http_path=%s)", + server_hostname, + port, + http_path, + ) + port = port or 443 if kwargs.get("_connection_uri"): uri = kwargs.get("_connection_uri") @@ -390,6 +397,8 @@ def attempt_request(attempt): # TODO: don't use exception handling for GOS polling... + logger.error("ThriftBackend.attempt_request: HTTPError: %s", err) + gos_name = TCLIServiceClient.GetOperationStatus.__name__ if method.__name__ == gos_name: delay_default = ( @@ -434,6 +443,7 @@ def attempt_request(attempt): else: logger.warning(log_string) except Exception as err: + logger.error("ThriftBackend.attempt_request: Exception: %s", err) error = err retry_delay = extract_retry_delay(attempt) error_message = ThriftBackend._extract_error_message_from_headers( @@ -888,6 +898,12 @@ def execute_command( ): assert session_handle is not None + logger.debug( + "ThriftBackend.execute_command(operation=%s, session_handle=%s)", + operation, + session_handle, + ) + spark_arrow_types = ttypes.TSparkArrowTypes( timestampAsArrow=self._use_arrow_native_timestamps, decimalAsArrow=self._use_arrow_native_decimals, @@ -1074,6 +1090,7 @@ def fetch_results( return queue, resp.hasMoreRows def close_command(self, op_handle): + logger.debug("ThriftBackend.close_command(op_handle=%s)", op_handle) req = ttypes.TCloseOperationReq(operationHandle=op_handle) resp = self.make_request(self._client.CloseOperation, req) return resp.status diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index cfd561400..440d4efb3 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -50,7 +50,7 @@ from tests.e2e.common.uc_volume_tests import PySQLUCVolumeTestSuiteMixin -from databricks.sql.exc import SessionAlreadyClosedError +from databricks.sql.exc import SessionAlreadyClosedError, CursorAlreadyClosedError log = logging.getLogger(__name__) @@ -820,7 +820,6 @@ def test_close_connection_closes_cursors(self): ars = cursor.active_result_set # We must manually run this check because thrift_backend always forces `has_been_closed_server_side` to True - # Cursor op state should be open before connection is closed status_request = ttypes.TGetOperationStatusReq( operationHandle=ars.command_id, getProgressUpdate=False @@ -847,9 +846,104 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog): with self.connection() as conn: # First .close() call is explicit here conn.close() - assert "Session appears to have been closed already" in caplog.text + conn = None + try: + with pytest.raises(KeyboardInterrupt): + with self.connection() as c: + conn = c + raise KeyboardInterrupt("Simulated interrupt") + finally: + if conn is not None: + assert not conn.open, "Connection should be closed after KeyboardInterrupt" + + def test_cursor_close_properly_closes_operation(self): + """Test that Cursor.close() properly closes the active operation handle on the server.""" + with self.connection() as conn: + cursor = conn.cursor() + try: + cursor.execute("SELECT 1 AS test") + assert cursor.active_op_handle is not None + cursor.close() + assert cursor.active_op_handle is None + assert not cursor.open + finally: + if cursor.open: + cursor.close() + + conn = None + cursor = None + try: + with self.connection() as c: + conn = c + with pytest.raises(KeyboardInterrupt): + with conn.cursor() as cur: + cursor = cur + raise KeyboardInterrupt("Simulated interrupt") + finally: + if cursor is not None: + assert not cursor.open, "Cursor should be closed after KeyboardInterrupt" + + def test_nested_cursor_context_managers(self): + """Test that nested cursor context managers properly close operations on the server.""" + with self.connection() as conn: + with conn.cursor() as cursor1: + cursor1.execute("SELECT 1 AS test1") + assert cursor1.active_op_handle is not None + + with conn.cursor() as cursor2: + cursor2.execute("SELECT 2 AS test2") + assert cursor2.active_op_handle is not None + + # After inner context manager exit, cursor2 should be not open + assert not cursor2.open + assert cursor2.active_op_handle is None + + # After outer context manager exit, cursor1 should be not open + assert not cursor1.open + assert cursor1.active_op_handle is None + + def test_cursor_error_handling(self): + """Test that cursor close handles errors properly to prevent orphaned operations.""" + with self.connection() as conn: + cursor = conn.cursor() + + cursor.execute("SELECT 1 AS test") + + op_handle = cursor.active_op_handle + + assert op_handle is not None + + # Manually close the operation to simulate server-side closure + conn.thrift_backend.close_command(op_handle) + + cursor.close() + + assert not cursor.open + + def test_result_set_close(self): + """Test that ResultSet.close() properly closes operations on the server and handles state correctly.""" + with self.connection() as conn: + cursor = conn.cursor() + try: + cursor.execute("SELECT * FROM RANGE(10)") + + result_set = cursor.active_result_set + assert result_set is not None + + initial_op_state = result_set.op_state + + result_set.close() + + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + assert result_set.op_state != initial_op_state + + # Closing the result set again should be a no-op and not raise exceptions + result_set.close() + finally: + cursor.close() + # use a RetrySuite to encapsulate these tests which we'll typically want to run together; however keep # the 429/503 subsuites separate since they execute under different circumstances. diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index b67101943..30b1fd96a 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -20,6 +20,7 @@ import databricks.sql import databricks.sql.client as client from databricks.sql import InterfaceError, DatabaseError, Error, NotSupportedError +from databricks.sql.exc import RequestError, CursorAlreadyClosedError from databricks.sql.types import Row from tests.unit.test_fetches import FetchTests @@ -522,6 +523,116 @@ def test_access_current_query_id(self): cursor.close() self.assertIsNone(cursor.query_id) + def test_cursor_close_handles_exception(self): + """Test that Cursor.close() handles exceptions from close_command properly.""" + mock_backend = Mock() + mock_connection = Mock() + mock_op_handle = Mock() + + mock_backend.close_command.side_effect = Exception("Test error") + + cursor = client.Cursor(mock_connection, mock_backend) + cursor.active_op_handle = mock_op_handle + + cursor.close() + + mock_backend.close_command.assert_called_once_with(mock_op_handle) + + self.assertIsNone(cursor.active_op_handle) + + self.assertFalse(cursor.open) + + def test_cursor_context_manager_handles_exit_exception(self): + """Test that cursor's context manager handles exceptions during __exit__.""" + mock_backend = Mock() + mock_connection = Mock() + + cursor = client.Cursor(mock_connection, mock_backend) + original_close = cursor.close + cursor.close = Mock(side_effect=Exception("Test error during close")) + + try: + with cursor: + raise ValueError("Test error inside context") + except ValueError: + pass + + cursor.close.assert_called_once() + + def test_connection_close_handles_cursor_close_exception(self): + """Test that _close handles exceptions from cursor.close() properly.""" + cursors_closed = [] + + def mock_close_with_exception(): + cursors_closed.append(1) + raise Exception("Test error during close") + + cursor1 = Mock() + cursor1.close = mock_close_with_exception + + def mock_close_normal(): + cursors_closed.append(2) + + cursor2 = Mock() + cursor2.close = mock_close_normal + + mock_backend = Mock() + mock_session_handle = Mock() + + try: + for cursor in [cursor1, cursor2]: + try: + cursor.close() + except Exception: + pass + + mock_backend.close_session(mock_session_handle) + except Exception as e: + self.fail(f"Connection close should handle exceptions: {e}") + + self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") + + def test_resultset_close_handles_cursor_already_closed_error(self): + """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" + result_set = client.ResultSet.__new__(client.ResultSet) + result_set.thrift_backend = Mock() + result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED' + result_set.connection = Mock() + result_set.connection.open = True + result_set.op_state = 'RUNNING' + result_set.has_been_closed_server_side = False + result_set.command_id = Mock() + + class MockRequestError(Exception): + def __init__(self): + self.args = ["Error message", CursorAlreadyClosedError()] + + result_set.thrift_backend.close_command.side_effect = MockRequestError() + + original_close = client.ResultSet.close + try: + try: + if ( + result_set.op_state != result_set.thrift_backend.CLOSED_OP_STATE + and not result_set.has_been_closed_server_side + and result_set.connection.open + ): + result_set.thrift_backend.close_command(result_set.command_id) + except MockRequestError as e: + if isinstance(e.args[1], CursorAlreadyClosedError): + pass + finally: + result_set.has_been_closed_server_side = True + result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE + + result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id) + + assert result_set.has_been_closed_server_side is True + + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE + finally: + pass + if __name__ == "__main__": suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) From bcf5994507aa25d3e5cff01c9f8beae6c202c8d5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 12:19:02 +0000 Subject: [PATCH 17/21] fix: separate session opening logic from instantiation ensures correctness of self.session.open call in Connection Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 5 ++--- src/databricks/sql/session.py | 28 +++++++++++++++++++++------- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 82c74a1d5..5175bf6b2 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -241,6 +241,7 @@ def read(self) -> Optional[OAuthToken]: _use_arrow_native_complex_types, **kwargs, ) + self.session.open() logger.info( "Successfully opened connection with session " @@ -329,9 +330,7 @@ def get_protocol_version(openSessionResp): @property def open(self) -> bool: """Return whether the connection is open by checking if the session is open.""" - # NOTE: we have to check for the existence of session in case the __del__ is called - # before the session is instantiated - return hasattr(self, "session") and self.session.open + return self.session.is_open def cursor( self, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index a308b71d5..f5b58686c 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -29,10 +29,14 @@ def __init__( This class handles all session-related behavior and communication with the backend. """ - self.open = False + self.is_open = False self.host = server_hostname self.port = kwargs.get("_port", 443) + self.session_configuration = session_configuration + self.catalog = catalog + self.schema = schema + auth_provider = get_python_sql_connector_auth_provider( server_hostname, **kwargs ) @@ -78,12 +82,16 @@ def __init__( **kwargs, ) + self._session_handle = None + self.protocol_version = None + + def open(self) -> None: self._open_session_resp = self.thrift_backend.open_session( - session_configuration, catalog, schema + self.session_configuration, self.catalog, self.schema ) self._session_handle = self._open_session_resp.sessionHandle self.protocol_version = self.get_protocol_version(self._open_session_resp) - self.open = True + self.is_open = True logger.info("Successfully opened session " + str(self.get_session_id_hex())) @staticmethod @@ -114,10 +122,16 @@ def get_session_handle(self): return self._session_handle def get_session_id(self): - return self.thrift_backend.handle_to_id(self._session_handle) + session_handle = self.get_session_handle() + if session_handle is None: + return None + return self.thrift_backend.handle_to_id(session_handle) def get_session_id_hex(self): - return self.thrift_backend.handle_to_hex_id(self._session_handle) + session_handle = self.get_session_handle() + if session_handle is None: + return None + return self.thrift_backend.handle_to_hex_id(session_handle) def close(self) -> None: """Close the underlying session.""" @@ -127,7 +141,7 @@ def close(self) -> None: return try: - self.thrift_backend.close_session(self._session_handle) + self.thrift_backend.close_session(self.get_session_handle()) except RequestError as e: if isinstance(e.args[1], SessionAlreadyClosedError): logger.info("Session was closed by a prior request") @@ -143,4 +157,4 @@ def close(self) -> None: except Exception as e: logger.error(f"Attempt to close session raised a local exception: {e}") - self.open = False + self.is_open = False From 500dd0b5d1649186536c6d9a1acf59f2c530d917 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 13:22:42 +0000 Subject: [PATCH 18/21] fix: use is_open attribute to denote session availability Signed-off-by: varun-edachali-dbx --- src/databricks/sql/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index f5b58686c..57338c61d 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -136,7 +136,7 @@ def get_session_id_hex(self): def close(self) -> None: """Close the underlying session.""" logger.info(f"Closing session {self.get_session_id_hex()}") - if not self.open: + if not self.is_open: logger.debug("Session appears to have been closed already") return From 510b454f07ac328289c4dfc114afd3c5090c35b5 Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Fri, 23 May 2025 13:22:55 +0000 Subject: [PATCH 19/21] fix: access thrift backend through session Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index 440d4efb3..a293cb8d3 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -916,7 +916,7 @@ def test_cursor_error_handling(self): assert op_handle is not None # Manually close the operation to simulate server-side closure - conn.thrift_backend.close_command(op_handle) + conn.session.thrift_backend.close_command(op_handle) cursor.close() From 634faa956e42146dd8766bc560f1aae185d7e76a Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sat, 24 May 2025 02:22:28 +0000 Subject: [PATCH 20/21] chore: use get_handle() instead of private session attribute in client Signed-off-by: varun-edachali-dbx --- src/databricks/sql/client.py | 16 ++++++++-------- src/databricks/sql/session.py | 30 +++++++++++++++--------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 5175bf6b2..d6a9e6b08 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -306,11 +306,11 @@ def __del__(self): def get_session_id(self): """Get the session ID from the Session object""" - return self.session.get_session_id() + return self.session.get_id() def get_session_id_hex(self): """Get the session ID in hex format from the Session object""" - return self.session.get_session_id_hex() + return self.session.get_id_hex() @staticmethod def server_parameterized_queries_enabled(protocolVersion): @@ -776,7 +776,7 @@ def execute( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -839,7 +839,7 @@ def execute_async( self._close_and_clear_active_result_set() self.thrift_backend.execute_command( operation=prepared_operation, - session_handle=self.connection.session._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, lz4_compression=self.connection.lz4_compression, @@ -935,7 +935,7 @@ def catalogs(self) -> "Cursor": self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_catalogs( - session_handle=self.connection.session._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -961,7 +961,7 @@ def schemas( self._check_not_closed() self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_schemas( - session_handle=self.connection.session._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -994,7 +994,7 @@ def tables( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_tables( - session_handle=self.connection.session._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, @@ -1029,7 +1029,7 @@ def columns( self._close_and_clear_active_result_set() execute_response = self.thrift_backend.get_columns( - session_handle=self.connection.session._session_handle, + session_handle=self.connection.session.get_handle(), max_rows=self.arraysize, max_bytes=self.buffer_size_bytes, cursor=self, diff --git a/src/databricks/sql/session.py b/src/databricks/sql/session.py index 57338c61d..f2f38d572 100644 --- a/src/databricks/sql/session.py +++ b/src/databricks/sql/session.py @@ -82,17 +82,17 @@ def __init__( **kwargs, ) - self._session_handle = None + self._handle = None self.protocol_version = None def open(self) -> None: self._open_session_resp = self.thrift_backend.open_session( self.session_configuration, self.catalog, self.schema ) - self._session_handle = self._open_session_resp.sessionHandle + self._handle = self._open_session_resp.sessionHandle self.protocol_version = self.get_protocol_version(self._open_session_resp) self.is_open = True - logger.info("Successfully opened session " + str(self.get_session_id_hex())) + logger.info("Successfully opened session " + str(self.get_id_hex())) @staticmethod def get_protocol_version(openSessionResp): @@ -118,30 +118,30 @@ def server_parameterized_queries_enabled(protocolVersion): else: return False - def get_session_handle(self): - return self._session_handle + def get_handle(self): + return self._handle - def get_session_id(self): - session_handle = self.get_session_handle() - if session_handle is None: + def get_id(self): + handle = self.get_handle() + if handle is None: return None - return self.thrift_backend.handle_to_id(session_handle) + return self.thrift_backend.handle_to_id(handle) - def get_session_id_hex(self): - session_handle = self.get_session_handle() - if session_handle is None: + def get_id_hex(self): + handle = self.get_handle() + if handle is None: return None - return self.thrift_backend.handle_to_hex_id(session_handle) + return self.thrift_backend.handle_to_hex_id(handle) def close(self) -> None: """Close the underlying session.""" - logger.info(f"Closing session {self.get_session_id_hex()}") + logger.info(f"Closing session {self.get_id_hex()}") if not self.is_open: logger.debug("Session appears to have been closed already") return try: - self.thrift_backend.close_session(self.get_session_handle()) + self.thrift_backend.close_session(self.get_handle()) except RequestError as e: if isinstance(e.args[1], SessionAlreadyClosedError): logger.info("Session was closed by a prior request") From a32862b0eb8b7abbe4726e2f879a2b70d4023f3d Mon Sep 17 00:00:00 2001 From: varun-edachali-dbx Date: Sat, 24 May 2025 02:22:42 +0000 Subject: [PATCH 21/21] formatting (black) Signed-off-by: varun-edachali-dbx --- tests/e2e/test_driver.py | 8 +++++-- tests/unit/test_client.py | 50 +++++++++++++++++++++------------------ 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/tests/e2e/test_driver.py b/tests/e2e/test_driver.py index a293cb8d3..abe0e22d2 100644 --- a/tests/e2e/test_driver.py +++ b/tests/e2e/test_driver.py @@ -856,7 +856,9 @@ def test_closing_a_closed_connection_doesnt_fail(self, caplog): raise KeyboardInterrupt("Simulated interrupt") finally: if conn is not None: - assert not conn.open, "Connection should be closed after KeyboardInterrupt" + assert ( + not conn.open + ), "Connection should be closed after KeyboardInterrupt" def test_cursor_close_properly_closes_operation(self): """Test that Cursor.close() properly closes the active operation handle on the server.""" @@ -883,7 +885,9 @@ def test_cursor_close_properly_closes_operation(self): raise KeyboardInterrupt("Simulated interrupt") finally: if cursor is not None: - assert not cursor.open, "Cursor should be closed after KeyboardInterrupt" + assert ( + not cursor.open + ), "Cursor should be closed after KeyboardInterrupt" def test_nested_cursor_context_managers(self): """Test that nested cursor context managers properly close operations on the server.""" diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 30b1fd96a..47361206d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -528,7 +528,7 @@ def test_cursor_close_handles_exception(self): mock_backend = Mock() mock_connection = Mock() mock_op_handle = Mock() - + mock_backend.close_command.side_effect = Exception("Test error") cursor = client.Cursor(mock_connection, mock_backend) @@ -537,78 +537,80 @@ def test_cursor_close_handles_exception(self): cursor.close() mock_backend.close_command.assert_called_once_with(mock_op_handle) - + self.assertIsNone(cursor.active_op_handle) - + self.assertFalse(cursor.open) def test_cursor_context_manager_handles_exit_exception(self): """Test that cursor's context manager handles exceptions during __exit__.""" mock_backend = Mock() mock_connection = Mock() - + cursor = client.Cursor(mock_connection, mock_backend) original_close = cursor.close cursor.close = Mock(side_effect=Exception("Test error during close")) - + try: with cursor: raise ValueError("Test error inside context") except ValueError: pass - + cursor.close.assert_called_once() def test_connection_close_handles_cursor_close_exception(self): """Test that _close handles exceptions from cursor.close() properly.""" cursors_closed = [] - + def mock_close_with_exception(): cursors_closed.append(1) raise Exception("Test error during close") - + cursor1 = Mock() cursor1.close = mock_close_with_exception - + def mock_close_normal(): cursors_closed.append(2) - + cursor2 = Mock() cursor2.close = mock_close_normal - + mock_backend = Mock() mock_session_handle = Mock() - + try: for cursor in [cursor1, cursor2]: try: cursor.close() except Exception: pass - + mock_backend.close_session(mock_session_handle) except Exception as e: self.fail(f"Connection close should handle exceptions: {e}") - - self.assertEqual(cursors_closed, [1, 2], "Both cursors should have close called") + + self.assertEqual( + cursors_closed, [1, 2], "Both cursors should have close called" + ) def test_resultset_close_handles_cursor_already_closed_error(self): """Test that ResultSet.close() handles CursorAlreadyClosedError properly.""" result_set = client.ResultSet.__new__(client.ResultSet) result_set.thrift_backend = Mock() - result_set.thrift_backend.CLOSED_OP_STATE = 'CLOSED' + result_set.thrift_backend.CLOSED_OP_STATE = "CLOSED" result_set.connection = Mock() result_set.connection.open = True - result_set.op_state = 'RUNNING' + result_set.op_state = "RUNNING" result_set.has_been_closed_server_side = False result_set.command_id = Mock() class MockRequestError(Exception): def __init__(self): self.args = ["Error message", CursorAlreadyClosedError()] - + result_set.thrift_backend.close_command.side_effect = MockRequestError() - + original_close = client.ResultSet.close try: try: @@ -624,11 +626,13 @@ def __init__(self): finally: result_set.has_been_closed_server_side = True result_set.op_state = result_set.thrift_backend.CLOSED_OP_STATE - - result_set.thrift_backend.close_command.assert_called_once_with(result_set.command_id) - + + result_set.thrift_backend.close_command.assert_called_once_with( + result_set.command_id + ) + assert result_set.has_been_closed_server_side is True - + assert result_set.op_state == result_set.thrift_backend.CLOSED_OP_STATE finally: pass