From 09f6433d65f5d66766c7dd75f7788025a01bcd60 Mon Sep 17 00:00:00 2001 From: Denis Bychkov Date: Wed, 24 Apr 2013 18:08:15 +0000 Subject: [PATCH] Clean up callbacks --- asyncmongo/__init__.py | 3 +- asyncmongo/asyncjobs.py | 129 ++++++++++++++++++++++-------------- asyncmongo/connection.py | 60 ++++++++++++----- asyncmongo/errors.py | 2 +- test/test_authentication.py | 28 +++++++- test/test_replica_set.py | 16 +++++ 6 files changed, 166 insertions(+), 72 deletions(-) diff --git a/asyncmongo/__init__.py b/asyncmongo/__init__.py index 06f8dab..4d8c5f4 100755 --- a/asyncmongo/__init__.py +++ b/asyncmongo/__init__.py @@ -34,6 +34,7 @@ GEO2D = "2d" """Index specifier for a 2-dimensional `geospatial index`""" -from errors import Error, InterfaceError, DatabaseError, DataError, IntegrityError, ProgrammingError, NotSupportedError +from errors import (Error, InterfaceError, AuthenticationError, DatabaseError, RSConnectionError, + DataError, IntegrityError, ProgrammingError, NotSupportedError) from client import Client diff --git a/asyncmongo/asyncjobs.py b/asyncmongo/asyncjobs.py index 139051a..6430a13 100644 --- a/asyncmongo/asyncjobs.py +++ b/asyncmongo/asyncjobs.py @@ -47,23 +47,38 @@ def process(self, *args, **kwargs): else: self.callback(None, e) -class AuthorizeJob(object): - def __init__(self, connection, dbuser, dbpass, pool): - super(AuthorizeJob, self).__init__() + +class AsyncJob(object): + def __init__(self, connection, state, err_callback): + super(AsyncJob, self).__init__() self.connection = connection - self._state = "start" + self._err_callback = err_callback + self._state = state + + def _error(self, e): + self.connection.close() + if self._err_callback: + self._err_callback(e) + + def update_err_callback(self, err_callback): + self._err_callback = err_callback + + def __repr__(self): + return "%s at 0x%X, state = %r" % (self.__class__.__name__, id(self), self._state) + + +class AuthorizeJob(AsyncJob): + def __init__(self, connection, dbuser, dbpass, pool, err_callback): + super(AuthorizeJob, self).__init__(connection, "start", err_callback) self.dbuser = dbuser self.dbpass = dbpass self.pool = pool - def __repr__(self): - return "AuthorizeJob at 0x%X, state = %r" % (id(self), self._state) - def process(self, response=None, error=None): if error: - logging.debug(error) - logging.debug(response) - raise AuthenticationError(error) + logging.debug("Error during authentication: %r", error) + self._error(AuthenticationError(error)) + return if self._state == "start": self._state = "nonce" @@ -80,9 +95,13 @@ def process(self, response=None, error=None): elif self._state == "nonce": # this is the nonce response self._state = "finish" - nonce = response['data'][0]['nonce'] - logging.debug("Nonce received: %r", nonce) - key = helpers._auth_key(nonce, self.dbuser, self.dbpass) + try: + nonce = response['data'][0]['nonce'] + logging.debug("Nonce received: %r", nonce) + key = helpers._auth_key(nonce, self.dbuser, self.dbpass) + except Exception, e: + self._error(AuthenticationError(e)) + return msg = message.query( 0, @@ -98,28 +117,31 @@ def process(self, response=None, error=None): self.connection._send_message(msg, self.process) elif self._state == "finish": self._state = "done" - assert response['number_returned'] == 1 - response = response['data'][0] - if response['ok'] != 1: - logging.debug('Failed authentication %s' % response['errmsg']) - raise AuthenticationError(response['errmsg']) + try: + assert response['number_returned'] == 1 + response = response['data'][0] + except Exception, e: + self._error(AuthenticationError(e)) + return + + if response.get("ok") != 1: + logging.debug("Failed authentication %s", response.get("errmsg")) + self._error(AuthenticationError(response.get("errmsg"))) + return self.connection._next_job() else: - raise ValueError("Unexpected state: %s" % self._state) + self._error(ValueError("Unexpected state: %s" % self._state)) -class ConnectRSJob(object): - def __init__(self, connection, seed, rs, secondary_only): - self.connection = connection + +class ConnectRSJob(AsyncJob): + def __init__(self, connection, seed, rs, secondary_only, err_callback): + super(ConnectRSJob, self).__init__(connection, "seed", err_callback) self.known_hosts = set(seed) self.rs = rs self._blacklisted = set() - self._state = "seed" self._primary = None self._sec_only = secondary_only - def __repr__(self): - return "ConnectRSJob at 0x%X, state = %s" % (id(self), self._state) - def process(self, response=None, error=None): if error: logging.debug("Problem connecting: %s", error) @@ -159,7 +181,8 @@ def process(self, response=None, error=None): break else: - raise RSConnectionError("No more hosts to try, tried: %s" % self.known_hosts) + self._error(RSConnectionError("No more hosts to try, tried: %s" % self.known_hosts)) + return self._state = "ismaster" msg = message.query( @@ -174,36 +197,42 @@ def process(self, response=None, error=None): elif self._state == "ismaster": logging.debug("ismaster response: %r", response) - if len(response["data"]) == 1: + try: + assert len(response["data"]) == 1 res = response["data"][0] - else: - raise RSConnectionError("Invalid response data: %r" % response["data"]) + except Exception, e: + self._error(RSConnectionError("Invalid response data: %r" % response.get("data"))) + return rs_name = res.get("setName") - if rs_name: - if rs_name != self.rs: - raise RSConnectionError("Wrong replica set: %s, expected: %s" % - (rs_name, self.rs)) + if rs_name and rs_name != self.rs: + self._error(RSConnectionError("Wrong replica set: %s, expected: %s" % (rs_name, self.rs))) + return + hosts = res.get("hosts") if hosts: self.known_hosts.update(helpers._parse_host(h) for h in hosts) ismaster = res.get("ismaster") hidden = res.get("hidden") - if ismaster and not self._sec_only: # master and required to connect to primary - assert not hidden, "Primary cannot be hidden" - logging.debug("Connected to master (%s)", res.get("me", "unknown")) - self._state = "done" - self.connection._next_job() - elif not ismaster and self._sec_only and not hidden: # not master and required to connect to secondary - assert res.get("secondary"), "Secondary must self-report as secondary" - logging.debug("Connected to secondary (%s)", res.get("me", "unknown")) - self._state = "done" - self.connection._next_job() - else: # either not master and primary connection required or master and secondary required - primary = res.get("primary") - if primary: - self._primary = helpers._parse_host(primary) + try: + if ismaster and not self._sec_only: # master and required to connect to primary + assert not hidden, "Primary cannot be hidden" + logging.debug("Connected to master (%s)", res.get("me", "unknown")) + self._state = "done" + self.connection._next_job() + elif not ismaster and self._sec_only and not hidden: # not master and required to connect to secondary + assert res.get("secondary"), "Secondary must self-report as secondary" + logging.debug("Connected to secondary (%s)", res.get("me", "unknown")) + self._state = "done" + self.connection._next_job() + else: # either not master and primary connection required or master and secondary required + primary = res.get("primary") + if primary: + self._primary = helpers._parse_host(primary) + self._state = "seed" + self.process() + except Exception, e: + self._error(RSConnectionError(e)) + return - self._state = "seed" - self.process() diff --git a/asyncmongo/connection.py b/asyncmongo/connection.py index 500e57d..5c8ac69 100644 --- a/asyncmongo/connection.py +++ b/asyncmongo/connection.py @@ -19,6 +19,7 @@ import struct import logging from types import NoneType +import functools from errors import ProgrammingError, IntegrityError, InterfaceError import helpers @@ -84,19 +85,24 @@ def __init__(self, self.__backend = self.__load_backend(backend) self.__job_queue = [] self.usage_count = 0 - self.__connect() + + self.__connect(self.connection_error) + + def connection_error(self, error): + raise error def __load_backend(self, name): __import__('asyncmongo.backends.%s_backend' % name) mod = sys.modules['asyncmongo.backends.%s_backend' % name] return mod.AsyncBackend() - def __connect(self): + def __connect(self, err_callback): + # The callback is only called in case of exception by async jobs if self.__dbuser and self.__dbpass: - self._put_job(asyncjobs.AuthorizeJob(self, self.__dbuser, self.__dbpass, self.__pool)) + self._put_job(asyncjobs.AuthorizeJob(self, self.__dbuser, self.__dbpass, self.__pool, err_callback)) if self.__rs: - self._put_job(asyncjobs.ConnectRSJob(self, self.__seed, self.__rs, self.__secondary_only)) + self._put_job(asyncjobs.ConnectRSJob(self, self.__seed, self.__rs, self.__secondary_only, err_callback)) # Mark the connection as alive, even though it's not alive yet to prevent double-connecting self.__alive = True else: @@ -116,34 +122,54 @@ def _socket_connect(self): def _socket_close(self): """cleanup after the socket is closed by the other end""" - if self.__callback: - self.__callback(None, InterfaceError('connection closed')) + callback = self.__callback self.__callback = None - self.__alive = False - self.__pool.cache(self) + try: + if callback: + callback(None, InterfaceError('connection closed')) + finally: + self.__alive = False + self.__pool.cache(self) def _close(self): """close the socket and cleanup""" - if self.__callback: - self.__callback(None, InterfaceError('connection closed')) + callback = self.__callback self.__callback = None - self.__alive = False - self.__stream.close() - + try: + if callback: + callback(None, InterfaceError('connection closed')) + finally: + self.__alive = False + self.__stream.close() + def close(self): """close this connection; re-cache this connection object""" - self._close() - self.__pool.cache(self) + try: + self._close() + finally: + self.__pool.cache(self) def send_message(self, message, callback): """ send a message over the wire; callback=None indicates a safe=False call where we write and forget about it""" if self.__callback is not None: raise ProgrammingError('connection already in use') - + + if callback: + err_callback = functools.partial(callback, None) + else: + err_callback = None + + # Go and update err_callback for async jobs in queue if any + for job in self.__job_queue: + # this is a dirty hack and I hate it, but there is no way of setting the correct + # err_callback during the connection time + if isinstance(job, asyncjobs.AsyncJob): + job.update_err_callback(err_callback) + if not self.__alive: if self.__autoreconnect: - self.__connect() + self.__connect(err_callback) else: raise InterfaceError('connection invalid. autoreconnect=False') diff --git a/asyncmongo/errors.py b/asyncmongo/errors.py index 0000592..c6acd1f 100644 --- a/asyncmongo/errors.py +++ b/asyncmongo/errors.py @@ -29,7 +29,7 @@ class Error(StandardError): class InterfaceError(Error): pass -class RSConnectionError(Error): +class RSConnectionError(InterfaceError): pass class DatabaseError(Error): diff --git a/test/test_authentication.py b/test/test_authentication.py index 2198458..11ecd42 100644 --- a/test/test_authentication.py +++ b/test/test_authentication.py @@ -18,7 +18,8 @@ def setUp(self): def test_authentication(self): try: test_shunt.setup() - db = asyncmongo.Client(pool_id='testauth', host='127.0.0.1', port=27018, dbname='test', dbuser='testuser', dbpass='testpass', maxconnections=2) + db = asyncmongo.Client(pool_id='testauth', host='127.0.0.1', port=27018, dbname='test', dbuser='testuser', + dbpass='testpass', maxconnections=2) def update_callback(response, error): logging.info("UPDATE:") @@ -27,8 +28,9 @@ def update_callback(response, error): assert len(response) == 1 test_shunt.register_called('update') - db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, upsert=True, callback=update_callback) - + db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, upsert=True, + callback=update_callback) + tornado.ioloop.IOLoop.instance().start() test_shunt.assert_called('update') @@ -49,3 +51,23 @@ def query_callback(response, error): tornado.ioloop.IOLoop.instance().stop() raise + def test_failed_auth(self): + try: + test_shunt.setup() + db = asyncmongo.Client(pool_id='testauth_f', host='127.0.0.1', port=27018, dbname='test', dbuser='testuser', + dbpass='wrong', maxconnections=2) + + def query_callback(response, error): + tornado.ioloop.IOLoop.instance().stop() + logging.info(response) + logging.info(error) + assert isinstance(error, asyncmongo.AuthenticationError) + assert response is None + test_shunt.register_called('auth_failed') + + db.test_stats.find_one({"_id" : TEST_TIMESTAMP}, callback=query_callback) + tornado.ioloop.IOLoop.instance().start() + test_shunt.assert_called('auth_failed') + except: + tornado.ioloop.IOLoop.instance().stop() + raise diff --git a/test/test_replica_set.py b/test/test_replica_set.py index 3540a79..54b9034 100644 --- a/test/test_replica_set.py +++ b/test/test_replica_set.py @@ -117,6 +117,22 @@ def process(self, *args, **kwargs): def test_update(self): try: test_shunt.setup() + + db = asyncmongo.Client(pool_id='testrs_f', rs="wrong_rs", seed=[("127.0.0.1", 27020)], dbname='test', maxconnections=2) + + # Try to update with a wrong replica set name + def update_callback(response, error): + tornado.ioloop.IOLoop.instance().stop() + logging.info(response) + logging.info(error) + assert isinstance(error, asyncmongo.RSConnectionError) + test_shunt.register_called('update_f') + + db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, callback=update_callback) + + tornado.ioloop.IOLoop.instance().start() + test_shunt.assert_called('update_f') + db = asyncmongo.Client(pool_id='testrs', rs="rs0", seed=[("127.0.0.1", 27020)], dbname='test', maxconnections=2) # Update