From c5d8ca7f6c88351f1b8f3b230baa916cd7c3397b Mon Sep 17 00:00:00 2001 From: Wey Gu Date: Thu, 21 Mar 2024 17:44:42 +0800 Subject: [PATCH] fix: remove new exceptions (#326) * fix: remove new exceptions * fix: opt-in the the execution error retry * add logging for session.execute when retry execution enabled * address comments from nicco & optmizations * remove unnecessarily use space * fix: should return when session failed to be newed(but removed from pool) --- nebula3/gclient/net/Connection.py | 12 +-- nebula3/gclient/net/Session.py | 100 +++++++++++---------- nebula3/gclient/net/SessionPool.py | 134 +++++++++++++++-------------- tests/test_connection.py | 8 +- tests/test_session.py | 39 --------- tests/test_session_pool.py | 8 +- tests/test_ssl_connection.py | 14 ++- 7 files changed, 133 insertions(+), 182 deletions(-) diff --git a/nebula3/gclient/net/Connection.py b/nebula3/gclient/net/Connection.py index b4a306de..ad9071a6 100644 --- a/nebula3/gclient/net/Connection.py +++ b/nebula3/gclient/net/Connection.py @@ -27,8 +27,6 @@ AuthFailedException, IOErrorException, ClientServerIncompatibleException, - SessionException, - ExecutionErrorException, ) from nebula3.gclient.net.AuthResult import AuthResult @@ -198,12 +196,6 @@ def execute_parameter(self, session_id, stmt, params): """ try: resp = self._connection.executeWithParameter(session_id, stmt, params) - if resp.error_code == ErrorCode.E_SESSION_INVALID: - raise SessionException(resp.error_code, resp.error_msg) - if resp.error_code == ErrorCode.E_SESSION_TIMEOUT: - raise SessionException(resp.error_code, resp.error_msg) - if resp.error_code == ErrorCode.E_EXECUTION_ERROR: - raise ExecutionErrorException(resp.error_msg) return resp except Exception as te: if isinstance(te, TTransportException): @@ -274,7 +266,7 @@ def close(self): self._connection._iprot.trans.close() except Exception as e: logger.error( - 'Close connection to {}:{} failed:{}'.format(self._ip, self._port, e) + "Close connection to {}:{} failed:{}".format(self._ip, self._port, e) ) def ping(self): @@ -282,7 +274,7 @@ def ping(self): :return: True or False """ try: - resp = self._connection.execute(0, 'YIELD 1;') + resp = self._connection.execute(0, "YIELD 1;") return True except Exception: return False diff --git a/nebula3/gclient/net/Session.py b/nebula3/gclient/net/Session.py index 9cd3c2df..1faf4b58 100644 --- a/nebula3/gclient/net/Session.py +++ b/nebula3/gclient/net/Session.py @@ -5,14 +5,14 @@ # This source code is licensed under Apache 2.0 License. +import json import time from nebula3.Exception import ( IOErrorException, NotValidConnectionException, - ExecutionErrorException, ) - +from nebula3.common.ttypes import ErrorCode from nebula3.data.ResultSet import ResultSet from nebula3.gclient.net.AuthResult import AuthResult from nebula3.logger import logger @@ -25,8 +25,8 @@ def __init__( auth_result: AuthResult, pool, retry_connect=True, - retry_times=3, - retry_interval_sec=1, + execution_retry_count=0, + retry_interval_seconds=1, ): """ Initialize the Session object. @@ -35,8 +35,8 @@ def __init__( :param auth_result: The result of the authentication process. :param pool: The pool object where the session was created. :param retry_connect: A boolean indicating whether to retry the connection if it fails. - :param retry_times: The number of times to retry the connection. - :param retry_interval_sec: The interval between connection retries in seconds. + :param execution_retry_count: The number of attempts to retry the execution upon encountering an execution error(-1005), with the default being 0 (no retries). + :param retry_interval_seconds: The interval between connection retries in seconds. """ self._session_id = auth_result.get_session_id() self._timezone_offset = auth_result.get_timezone_offset() @@ -45,8 +45,8 @@ def __init__( # connection the where the session was created, if session pool was used self._pool = pool self._retry_connect = retry_connect - self._retry_times = retry_times - self._retry_interval_sec = retry_interval_sec + self._execution_retry_count = execution_retry_count + self._retry_interval_seconds = retry_interval_seconds # the time stamp when the session was added to the idle list of the session pool self._idle_time_start = 0 @@ -57,11 +57,27 @@ def execute_parameter(self, stmt, params): :return: ResultSet """ if self._connection is None: - raise RuntimeError('The session has been released') + raise RuntimeError("The session has been released") try: start_time = time.time() resp = self._connection.execute_parameter(self._session_id, stmt, params) end_time = time.time() + + if ( + self._execution_retry_count > 0 + and resp.error_code == ErrorCode.E_EXECUTION_ERROR + ): + for retry_count in range(1, self._execution_retry_count + 1): + logger.warning( + f"Execution error, retrying {retry_count}/{self._execution_retry_count} after {self._retry_interval_seconds}s" + ) + time.sleep(self._retry_interval_seconds) + resp = self._connection.execute_parameter( + self._session_id, stmt, params + ) + if resp.error_code != ErrorCode.E_EXECUTION_ERROR: + break + return ResultSet( resp, all_latency=int((end_time - start_time) * 1000000), @@ -72,7 +88,7 @@ def execute_parameter(self, stmt, params): self._pool.update_servers_status() if self._retry_connect: if not self._reconnect(): - logger.warning('Retry connect failed') + logger.warning("Retry connect failed") raise IOErrorException( IOErrorException.E_ALL_BROKEN, ie.message ) @@ -86,27 +102,6 @@ def execute_parameter(self, stmt, params): timezone_offset=self._timezone_offset, ) raise - except ExecutionErrorException as eee: - retry_count = 0 - while retry_count < self._retry_times: - try: - # TODO: add exponential backoff - time.sleep(self._retry_interval_sec) - resp = self._connection.execute_parameter( - self._session_id, stmt, params - ) - end_time = time.time() - return ResultSet( - resp, - all_latency=int((end_time - start_time) * 1000000), - timezone_offset=self._timezone_offset, - ) - except ExecutionErrorException: - if retry_count >= self._retry_times - 1: - raise eee - else: - retry_count += 1 - continue except Exception: raise @@ -244,18 +239,37 @@ def execute_json_with_parameter(self, stmt, params): :return: JSON string """ if self._connection is None: - raise RuntimeError('The session has been released') + raise RuntimeError("The session has been released") try: resp_json = self._connection.execute_json_with_parameter( self._session_id, stmt, params ) + if self._execution_retry_count > 0: + for retry_count in range(self._execution_retry_count): + if ( + json.loads(resp_json).get("errors", [{}])[0].get("code") + != ErrorCode.E_EXECUTION_ERROR + ): + break + logger.warning( + "Execute failed, retry count:{}/{} in {} seconds".format( + retry_count + 1, + self._execution_retry_count, + self._retry_interval_seconds, + ) + ) + time.sleep(self._retry_interval_seconds) + resp_json = self._connection.execute_json_with_parameter( + self._session_id, stmt, params + ) return resp_json + except IOErrorException as ie: if ie.type == IOErrorException.E_CONNECT_BROKEN: self._pool.update_servers_status() if self._retry_connect: if not self._reconnect(): - logger.warning('Retry connect failed') + logger.warning("Retry connect failed") raise IOErrorException( IOErrorException.E_ALL_BROKEN, ie.message ) @@ -264,22 +278,6 @@ def execute_json_with_parameter(self, stmt, params): ) return resp_json raise - except ExecutionErrorException as eee: - retry_count = 0 - while retry_count < self._retry_times: - try: - # TODO: add exponential backoff - time.sleep(self._retry_interval_sec) - resp = self._connection.execute_json_with_parameter( - self._session_id, stmt, params - ) - return resp - except ExecutionErrorException: - if retry_count >= self._retry_times - 1: - raise eee - else: - retry_count += 1 - continue except Exception: raise @@ -310,7 +308,7 @@ def ping_session(self): return True else: logger.error( - 'failed to ping the session: error code:{}, error message:{}'.format( + "failed to ping the session: error code:{}, error message:{}".format( resp.error_code, resp.error_msg ) ) @@ -342,5 +340,5 @@ def _idle_time(self): def _sign_out(self): """sign out the session""" if self._connection is None: - raise RuntimeError('The session has been released') + raise RuntimeError("The session has been released") self._connection.signout(self._session_id) diff --git a/nebula3/gclient/net/SessionPool.py b/nebula3/gclient/net/SessionPool.py index 652f7334..1cadb0e3 100644 --- a/nebula3/gclient/net/SessionPool.py +++ b/nebula3/gclient/net/SessionPool.py @@ -11,11 +11,11 @@ from threading import RLock, Timer import time +from nebula3.common.ttypes import ErrorCode from nebula3.Exception import ( AuthFailedException, NoValidSessionException, InValidHostname, - SessionException, ) from nebula3.gclient.net.Session import Session @@ -85,12 +85,12 @@ def init(self, configs): try: self._check_configs() except Exception as e: - logger.error('Invalid configs: {}'.format(e)) + logger.error("Invalid configs: {}".format(e)) return False if self._close: - logger.error('The pool has init or closed.') - raise RuntimeError('The pool has init or closed.') + logger.error("The pool has init or closed.") + raise RuntimeError("The pool has init or closed.") self._configs = configs # ping all servers @@ -102,14 +102,14 @@ def init(self, configs): ok_num = self.get_ok_servers_num() if ok_num < len(self._addresses): raise RuntimeError( - 'The services status exception: {}'.format(self._get_services_status()) + "The services status exception: {}".format(self._get_services_status()) ) # iterate all addresses and create sessions to fullfil the min_size for i in range(self._configs.min_size): session = self._new_session() if session is None: - raise RuntimeError('Get session failed') + raise RuntimeError("Get session failed") self._add_session_to_idle(session) return True @@ -143,7 +143,7 @@ def ping(self, address): return True except Exception as ex: logger.warning( - 'Connect {}:{} failed: {}'.format(address[0], address[1], ex) + "Connect {}:{} failed: {}".format(address[0], address[1], ex) ) return False @@ -170,34 +170,38 @@ def execute_parameter(self, stmt, params): """ session = self._get_idle_session() if session is None: - raise RuntimeError('Get session failed') + raise RuntimeError("Get session failed") self._add_session_to_active(session) try: resp = session.execute_parameter(stmt, params) - # reset the space name to the pool config - if resp.space_name() != self._space_name: - self._set_space_to_default(session) - - # move the session back to the idle list - self._return_session(session) - - return resp - except SessionException as se: - if se.type in [ - SessionException.E_SESSION_INVALID, - SessionException.E_SESSION_TIMEOUT, + # Check for session validity based on error code + if resp.error_code() in [ + ErrorCode.E_SESSION_INVALID, + ErrorCode.E_SESSION_TIMEOUT, ]: self._active_sessions.remove(session) session = self._get_idle_session() if session is None: - raise RuntimeError('Get session failed') + logger.warning( + "Session invalid or timeout, removed from the pool, but failed to get a new session." + ) + return resp + logger.warning("Session invalid or timeout, session has been recycled") self._add_session_to_idle(session) - raise se + else: + # reset the space name to the pool config + if resp.space_name() != self._space_name: + self._set_space_to_default(session) + + # move the session back to the idle list + self._return_session(session) + + return resp except Exception as e: - logger.error('Execute failed: {}'.format(e)) + logger.error("Execute failed: {}".format(e)) # remove the session from the pool if it is invalid self._active_sessions.remove(session) raise e @@ -268,34 +272,38 @@ def execute_json(self, stmt): def execute_json_with_parameter(self, stmt, params): session = self._get_idle_session() if session is None: - raise RuntimeError('Get session failed') + raise RuntimeError("Get session failed") self._add_session_to_active(session) try: resp = session.execute_json_with_parameter(stmt, params) - - # reset the space name to the pool config json_obj = json.loads(resp) - if json_obj["results"][0]["spaceName"] != self._space_name: - self._set_space_to_default(session) - - # move the session back to the idle list - self._return_session(session) - - return resp - except SessionException as se: - if se.type in [ - SessionException.E_SESSION_INVALID, - SessionException.E_SESSION_TIMEOUT, + # Check for session validity based on error code + if json_obj.get("errors", [{}])[0].get("code") in [ + ErrorCode.E_SESSION_INVALID, + ErrorCode.E_SESSION_TIMEOUT, ]: self._active_sessions.remove(session) session = self._get_idle_session() if session is None: - raise RuntimeError('Get session failed') + logger.warning( + "Session invalid or timeout, removed from the pool, but failed to get a new session." + ) + return resp self._add_session_to_idle(session) - raise se + logger.warning("Session invalid or timeout, session has been recycled") + + else: + # reset the space name to the pool config + if json_obj["results"][0]["spaceName"] != self._space_name: + self._set_space_to_default(session) + + # move the session back to the idle list + self._return_session(session) + + return resp except Exception as e: - logger.error('Execute failed: {}'.format(e)) + logger.error("Execute failed: {}".format(e)) # remove the session from the pool if it is invalid self._active_sessions.remove(session) raise e @@ -329,11 +337,11 @@ def get_ok_servers_num(self): def _get_services_status(self): msg_list = [] for addr in self._addresses_status.keys(): - status = 'OK' + status = "OK" if self._addresses_status[addr] != self.S_OK: - status = 'BAD' - msg_list.append('[services: {}, status: {}]'.format(addr, status)) - return ', '.join(msg_list) + status = "BAD" + msg_list.append("[services: {}, status: {}]".format(addr, status)) + return ", ".join(msg_list) def update_servers_status(self): """update the servers' status""" @@ -361,7 +369,7 @@ def _get_idle_session(self): return self._new_session() else: raise NoValidSessionException( - 'The total number of sessions reaches the pool max size {}'.format( + "The total number of sessions reaches the pool max size {}".format( self._configs.max_size ) ) @@ -373,7 +381,7 @@ def _new_session(self): :return: Session """ if self._ssl_configs is not None: - raise RuntimeError('SSL is not supported yet') + raise RuntimeError("SSL is not supported yet") self._pos = (self._pos + 1) % len(self._addresses) next_addr_index = self._pos @@ -386,7 +394,7 @@ def _new_session(self): # if the address is bad, skip it if self._addresses_status[addr] == self.S_BAD: - logger.warning('The graph service {} is not available'.format(addr)) + logger.warning("The graph service {} is not available".format(addr)) retries = retries - 1 next_addr_index = (next_addr_index + 1) % len(self._addresses) continue @@ -405,10 +413,10 @@ def _new_session(self): session = Session(connection, auth_result, self, False) # switch to the space specified in the configs - resp = session.execute('USE {}'.format(self._space_name)) + resp = session.execute("USE {}".format(self._space_name)) if not resp.is_succeeded(): raise RuntimeError( - 'Failed to get session, cannot set the session space to {} error: {} {}'.format( + "Failed to get session, cannot set the session space to {} error: {} {}".format( self._space_name, resp.error_code(), resp.error_msg() ) ) @@ -419,7 +427,7 @@ def _new_session(self): "User not exist" ): logger.error( - 'Authentication failed, because of bad credentials, close the pool {}'.format( + "Authentication failed, because of bad credentials, close the pool {}".format( e ) ) @@ -429,7 +437,7 @@ def _new_session(self): raise raise RuntimeError( - 'Failed to get a valid session, no graph service is available' + "Failed to get a valid session, no graph service is available" ) def _return_session(self, session): @@ -471,14 +479,14 @@ def _set_space_to_default(self, session): :return: void """ try: - resp = session.execute('USE {}'.format(self._space_name)) + resp = session.execute("USE {}".format(self._space_name)) if not resp.is_succeeded(): raise RuntimeError( - 'Failed to set the session space to {}'.format(self._space_name) + "Failed to set the session space to {}".format(self._space_name) ) except Exception: logger.warning( - 'Failed to set the session space to {}, the current session has been dropped'.format( + "Failed to set the session space to {}, the current session has been dropped".format( self._space_name ) ) @@ -517,23 +525,23 @@ def _period_detect(self): def _check_configs(self): """validate the configs""" if self._configs.min_size < 0: - raise RuntimeError('The min_size must be greater than 0') + raise RuntimeError("The min_size must be greater than 0") if self._configs.max_size < 0: - raise RuntimeError('The max_size must be greater than 0') + raise RuntimeError("The max_size must be greater than 0") if self._configs.min_size > self._configs.max_size: raise RuntimeError( - 'The min_size must be less than or equal to the max_size' + "The min_size must be less than or equal to the max_size" ) if self._configs.idle_time < 0: - raise RuntimeError('The idle_time must be greater or equal to 0') + raise RuntimeError("The idle_time must be greater or equal to 0") if self._configs.timeout < 0: - raise RuntimeError('The timeout must be greater or equal to 0') + raise RuntimeError("The timeout must be greater or equal to 0") if self._space_name == "": - raise RuntimeError('The space_name must be set') + raise RuntimeError("The space_name must be set") if self._username == "": - raise RuntimeError('The username must be set') + raise RuntimeError("The username must be set") if self._password == "": - raise RuntimeError('The password must be set') + raise RuntimeError("The password must be set") if self._addresses is None or len(self._addresses) == 0: - raise RuntimeError('The addresses must be set') + raise RuntimeError("The addresses must be set") diff --git a/tests/test_connection.py b/tests/test_connection.py index b55860da..fead1345 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -10,7 +10,7 @@ from unittest import TestCase from nebula3.common import ttypes -from nebula3.Exception import IOErrorException, SessionException +from nebula3.Exception import IOErrorException from nebula3.gclient.net import Connection AddrIp = ["127.0.0.1", "::1"] @@ -42,10 +42,8 @@ def test_release(self): conn.signout(session_id) # the session delete later time.sleep(12) - try: - conn.execute(session_id, "SHOW SPACES") - except Exception as ex: - assert isinstance(ex, SessionException), ex + resp = conn.execute(session_id, "SHOW SPACES") + assert resp.error_code != ttypes.ErrorCode.SUCCEEDED conn.close() except Exception as ex: assert False, ex diff --git a/tests/test_session.py b/tests/test_session.py index f966d213..4bf0942a 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -12,10 +12,6 @@ from nebula3.Config import Config from nebula3.gclient.net import ConnectionPool -from nebula3.Exception import ( - SessionException, - ExecutionErrorException, -) class TestSession(TestCase): @@ -93,38 +89,3 @@ def test_4_timeout(self): except Exception as ex: assert str(ex).find("timed out") > 0 assert True, ex - - def test_5_session_exception(self): - # test SessionException will be raised when session is invalid - try: - session = self.pool.get_session(self.user_name, self.password) - another_session = self.pool.get_session(self.user_name, self.password) - session_id = session._session_id - another_session.execute(f"KILL SESSION {session_id}") - session.execute("SHOW HOSTS") - except Exception as ex: - assert isinstance(ex, SessionException), "expect to get SessionException" - - def test_6_execute_exception(self): - # test ExecutionErrorException will be raised when execute error - # we need to mock a query's response code to trigger ExecutionErrorException - from unittest.mock import Mock, patch - - try: - session = self.pool.get_session(self.user_name, self.password) - # Mocking the Connection.execute_parameter method - with patch( - 'nebula3.graph.GraphService.Client.executeWithParameter' - ) as mock_execute: - mock_response = Mock() - mock_response.error_code = ExecutionErrorException.E_EXECUTION_ERROR - mock_execute.return_value = mock_response - session.execute("SHOW HOSTS") - # Assert that execute_parameter was called 3 times (retry mechanism) - assert ( - mock_execute.call_count == 3 - ), "execute_parameter was not retried 3 times" - except ExecutionErrorException as ex: - assert True, "ExecutionErrorException triggered as expected" - except Exception as ex: - assert False, f"Unexpected exception: {str(ex)}" diff --git a/tests/test_session_pool.py b/tests/test_session_pool.py index da0303ac..01d4d2b0 100644 --- a/tests/test_session_pool.py +++ b/tests/test_session_pool.py @@ -15,7 +15,6 @@ from nebula3.Config import SessionPoolConfig from nebula3.Exception import ( InValidHostname, - SessionException, ) from nebula3.gclient.net import Connection from nebula3.gclient.net.SessionPool import SessionPool @@ -160,11 +159,10 @@ def test_session_renew_when_invalid(self): session.execute(f"KILL SESSION {session_id}") try: session_pool.execute("SHOW HOSTS;") - except Exception as ex: - assert isinstance(ex, SessionException), "expect to get SessionException" - # The only session(size=1) should be renewed and usable + except Exception: + pass # - session_id is not in the pool - # - session_pool is usable + # - session_pool is still usable after renewing assert ( session_id not in session_pool._idle_sessions ), "session should be renewed" diff --git a/tests/test_ssl_connection.py b/tests/test_ssl_connection.py index 4f83ae2d..73b9dc5f 100644 --- a/tests/test_ssl_connection.py +++ b/tests/test_ssl_connection.py @@ -14,7 +14,7 @@ from nebula3.common import ttypes from nebula3.Config import SSL_config -from nebula3.Exception import IOErrorException, SessionException +from nebula3.Exception import IOErrorException from nebula3.gclient.net import Connection current_dir = os.path.dirname(os.path.abspath(__file__)) @@ -62,10 +62,8 @@ def test_release(self): conn.signout(session_id) # the session delete later time.sleep(12) - try: - conn.execute(session_id, "SHOW SPACES") - except Exception as ex: - assert isinstance(ex, SessionException), ex + resp = conn.execute(session_id, "SHOW SPACES") + assert resp.error_code != ttypes.ErrorCode.SUCCEEDED conn.close() except Exception as ex: assert False, ex @@ -106,10 +104,8 @@ def test_release_self_signed(self): conn.signout(session_id) # the session delete later time.sleep(12) - try: - conn.execute(session_id, "SHOW SPACES") - except Exception as ex: - assert isinstance(ex, SessionException), ex + resp = conn.execute(session_id, "SHOW SPACES") + assert resp.error_code != ttypes.ErrorCode.SUCCEEDED conn.close() except Exception as ex: assert False, ex