Skip to content
This repository has been archived by the owner on Nov 5, 2019. It is now read-only.

Commit

Permalink
Clean up callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
Denis Bychkov committed Apr 24, 2013
1 parent 1973feb commit 09f6433
Show file tree
Hide file tree
Showing 6 changed files with 166 additions and 72 deletions.
3 changes: 2 additions & 1 deletion asyncmongo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
GEO2D = "2d"
"""Index specifier for a 2-dimensional `geospatial index`"""

from errors import Error, InterfaceError, DatabaseError, DataError, IntegrityError, ProgrammingError, NotSupportedError
from errors import (Error, InterfaceError, AuthenticationError, DatabaseError, RSConnectionError,
DataError, IntegrityError, ProgrammingError, NotSupportedError)

from client import Client
129 changes: 79 additions & 50 deletions asyncmongo/asyncjobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,38 @@ def process(self, *args, **kwargs):
else:
self.callback(None, e)

class AuthorizeJob(object):
def __init__(self, connection, dbuser, dbpass, pool):
super(AuthorizeJob, self).__init__()

class AsyncJob(object):
def __init__(self, connection, state, err_callback):
super(AsyncJob, self).__init__()
self.connection = connection
self._state = "start"
self._err_callback = err_callback
self._state = state

def _error(self, e):
self.connection.close()
if self._err_callback:
self._err_callback(e)

def update_err_callback(self, err_callback):
self._err_callback = err_callback

def __repr__(self):
return "%s at 0x%X, state = %r" % (self.__class__.__name__, id(self), self._state)


class AuthorizeJob(AsyncJob):
def __init__(self, connection, dbuser, dbpass, pool, err_callback):
super(AuthorizeJob, self).__init__(connection, "start", err_callback)
self.dbuser = dbuser
self.dbpass = dbpass
self.pool = pool

def __repr__(self):
return "AuthorizeJob at 0x%X, state = %r" % (id(self), self._state)

def process(self, response=None, error=None):
if error:
logging.debug(error)
logging.debug(response)
raise AuthenticationError(error)
logging.debug("Error during authentication: %r", error)
self._error(AuthenticationError(error))
return

if self._state == "start":
self._state = "nonce"
Expand All @@ -80,9 +95,13 @@ def process(self, response=None, error=None):
elif self._state == "nonce":
# this is the nonce response
self._state = "finish"
nonce = response['data'][0]['nonce']
logging.debug("Nonce received: %r", nonce)
key = helpers._auth_key(nonce, self.dbuser, self.dbpass)
try:
nonce = response['data'][0]['nonce']
logging.debug("Nonce received: %r", nonce)
key = helpers._auth_key(nonce, self.dbuser, self.dbpass)
except Exception, e:
self._error(AuthenticationError(e))
return

msg = message.query(
0,
Expand All @@ -98,28 +117,31 @@ def process(self, response=None, error=None):
self.connection._send_message(msg, self.process)
elif self._state == "finish":
self._state = "done"
assert response['number_returned'] == 1
response = response['data'][0]
if response['ok'] != 1:
logging.debug('Failed authentication %s' % response['errmsg'])
raise AuthenticationError(response['errmsg'])
try:
assert response['number_returned'] == 1
response = response['data'][0]
except Exception, e:
self._error(AuthenticationError(e))
return

if response.get("ok") != 1:
logging.debug("Failed authentication %s", response.get("errmsg"))
self._error(AuthenticationError(response.get("errmsg")))
return
self.connection._next_job()
else:
raise ValueError("Unexpected state: %s" % self._state)
self._error(ValueError("Unexpected state: %s" % self._state))

class ConnectRSJob(object):
def __init__(self, connection, seed, rs, secondary_only):
self.connection = connection

class ConnectRSJob(AsyncJob):
def __init__(self, connection, seed, rs, secondary_only, err_callback):
super(ConnectRSJob, self).__init__(connection, "seed", err_callback)
self.known_hosts = set(seed)
self.rs = rs
self._blacklisted = set()
self._state = "seed"
self._primary = None
self._sec_only = secondary_only

def __repr__(self):
return "ConnectRSJob at 0x%X, state = %s" % (id(self), self._state)

def process(self, response=None, error=None):
if error:
logging.debug("Problem connecting: %s", error)
Expand Down Expand Up @@ -159,7 +181,8 @@ def process(self, response=None, error=None):
break

else:
raise RSConnectionError("No more hosts to try, tried: %s" % self.known_hosts)
self._error(RSConnectionError("No more hosts to try, tried: %s" % self.known_hosts))
return

self._state = "ismaster"
msg = message.query(
Expand All @@ -174,36 +197,42 @@ def process(self, response=None, error=None):
elif self._state == "ismaster":
logging.debug("ismaster response: %r", response)

if len(response["data"]) == 1:
try:
assert len(response["data"]) == 1
res = response["data"][0]
else:
raise RSConnectionError("Invalid response data: %r" % response["data"])
except Exception, e:
self._error(RSConnectionError("Invalid response data: %r" % response.get("data")))
return

rs_name = res.get("setName")
if rs_name:
if rs_name != self.rs:
raise RSConnectionError("Wrong replica set: %s, expected: %s" %
(rs_name, self.rs))
if rs_name and rs_name != self.rs:
self._error(RSConnectionError("Wrong replica set: %s, expected: %s" % (rs_name, self.rs)))
return

hosts = res.get("hosts")
if hosts:
self.known_hosts.update(helpers._parse_host(h) for h in hosts)

ismaster = res.get("ismaster")
hidden = res.get("hidden")
if ismaster and not self._sec_only: # master and required to connect to primary
assert not hidden, "Primary cannot be hidden"
logging.debug("Connected to master (%s)", res.get("me", "unknown"))
self._state = "done"
self.connection._next_job()
elif not ismaster and self._sec_only and not hidden: # not master and required to connect to secondary
assert res.get("secondary"), "Secondary must self-report as secondary"
logging.debug("Connected to secondary (%s)", res.get("me", "unknown"))
self._state = "done"
self.connection._next_job()
else: # either not master and primary connection required or master and secondary required
primary = res.get("primary")
if primary:
self._primary = helpers._parse_host(primary)
try:
if ismaster and not self._sec_only: # master and required to connect to primary
assert not hidden, "Primary cannot be hidden"
logging.debug("Connected to master (%s)", res.get("me", "unknown"))
self._state = "done"
self.connection._next_job()
elif not ismaster and self._sec_only and not hidden: # not master and required to connect to secondary
assert res.get("secondary"), "Secondary must self-report as secondary"
logging.debug("Connected to secondary (%s)", res.get("me", "unknown"))
self._state = "done"
self.connection._next_job()
else: # either not master and primary connection required or master and secondary required
primary = res.get("primary")
if primary:
self._primary = helpers._parse_host(primary)
self._state = "seed"
self.process()
except Exception, e:
self._error(RSConnectionError(e))
return

self._state = "seed"
self.process()
60 changes: 43 additions & 17 deletions asyncmongo/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import struct
import logging
from types import NoneType
import functools

from errors import ProgrammingError, IntegrityError, InterfaceError
import helpers
Expand Down Expand Up @@ -84,19 +85,24 @@ def __init__(self,
self.__backend = self.__load_backend(backend)
self.__job_queue = []
self.usage_count = 0
self.__connect()

self.__connect(self.connection_error)

def connection_error(self, error):
raise error

def __load_backend(self, name):
__import__('asyncmongo.backends.%s_backend' % name)
mod = sys.modules['asyncmongo.backends.%s_backend' % name]
return mod.AsyncBackend()

def __connect(self):
def __connect(self, err_callback):
# The callback is only called in case of exception by async jobs
if self.__dbuser and self.__dbpass:
self._put_job(asyncjobs.AuthorizeJob(self, self.__dbuser, self.__dbpass, self.__pool))
self._put_job(asyncjobs.AuthorizeJob(self, self.__dbuser, self.__dbpass, self.__pool, err_callback))

if self.__rs:
self._put_job(asyncjobs.ConnectRSJob(self, self.__seed, self.__rs, self.__secondary_only))
self._put_job(asyncjobs.ConnectRSJob(self, self.__seed, self.__rs, self.__secondary_only, err_callback))
# Mark the connection as alive, even though it's not alive yet to prevent double-connecting
self.__alive = True
else:
Expand All @@ -116,34 +122,54 @@ def _socket_connect(self):

def _socket_close(self):
"""cleanup after the socket is closed by the other end"""
if self.__callback:
self.__callback(None, InterfaceError('connection closed'))
callback = self.__callback
self.__callback = None
self.__alive = False
self.__pool.cache(self)
try:
if callback:
callback(None, InterfaceError('connection closed'))
finally:
self.__alive = False
self.__pool.cache(self)

def _close(self):
"""close the socket and cleanup"""
if self.__callback:
self.__callback(None, InterfaceError('connection closed'))
callback = self.__callback
self.__callback = None
self.__alive = False
self.__stream.close()

try:
if callback:
callback(None, InterfaceError('connection closed'))
finally:
self.__alive = False
self.__stream.close()

def close(self):
"""close this connection; re-cache this connection object"""
self._close()
self.__pool.cache(self)
try:
self._close()
finally:
self.__pool.cache(self)

def send_message(self, message, callback):
""" send a message over the wire; callback=None indicates a safe=False call where we write and forget about it"""

if self.__callback is not None:
raise ProgrammingError('connection already in use')


if callback:
err_callback = functools.partial(callback, None)
else:
err_callback = None

# Go and update err_callback for async jobs in queue if any
for job in self.__job_queue:
# this is a dirty hack and I hate it, but there is no way of setting the correct
# err_callback during the connection time
if isinstance(job, asyncjobs.AsyncJob):
job.update_err_callback(err_callback)

if not self.__alive:
if self.__autoreconnect:
self.__connect()
self.__connect(err_callback)
else:
raise InterfaceError('connection invalid. autoreconnect=False')

Expand Down
2 changes: 1 addition & 1 deletion asyncmongo/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Error(StandardError):
class InterfaceError(Error):
pass

class RSConnectionError(Error):
class RSConnectionError(InterfaceError):
pass

class DatabaseError(Error):
Expand Down
28 changes: 25 additions & 3 deletions test/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def setUp(self):
def test_authentication(self):
try:
test_shunt.setup()
db = asyncmongo.Client(pool_id='testauth', host='127.0.0.1', port=27018, dbname='test', dbuser='testuser', dbpass='testpass', maxconnections=2)
db = asyncmongo.Client(pool_id='testauth', host='127.0.0.1', port=27018, dbname='test', dbuser='testuser',
dbpass='testpass', maxconnections=2)

def update_callback(response, error):
logging.info("UPDATE:")
Expand All @@ -27,8 +28,9 @@ def update_callback(response, error):
assert len(response) == 1
test_shunt.register_called('update')

db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, upsert=True, callback=update_callback)

db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, upsert=True,
callback=update_callback)

tornado.ioloop.IOLoop.instance().start()
test_shunt.assert_called('update')

Expand All @@ -49,3 +51,23 @@ def query_callback(response, error):
tornado.ioloop.IOLoop.instance().stop()
raise

def test_failed_auth(self):
try:
test_shunt.setup()
db = asyncmongo.Client(pool_id='testauth_f', host='127.0.0.1', port=27018, dbname='test', dbuser='testuser',
dbpass='wrong', maxconnections=2)

def query_callback(response, error):
tornado.ioloop.IOLoop.instance().stop()
logging.info(response)
logging.info(error)
assert isinstance(error, asyncmongo.AuthenticationError)
assert response is None
test_shunt.register_called('auth_failed')

db.test_stats.find_one({"_id" : TEST_TIMESTAMP}, callback=query_callback)
tornado.ioloop.IOLoop.instance().start()
test_shunt.assert_called('auth_failed')
except:
tornado.ioloop.IOLoop.instance().stop()
raise
16 changes: 16 additions & 0 deletions test/test_replica_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ def process(self, *args, **kwargs):
def test_update(self):
try:
test_shunt.setup()

db = asyncmongo.Client(pool_id='testrs_f', rs="wrong_rs", seed=[("127.0.0.1", 27020)], dbname='test', maxconnections=2)

# Try to update with a wrong replica set name
def update_callback(response, error):
tornado.ioloop.IOLoop.instance().stop()
logging.info(response)
logging.info(error)
assert isinstance(error, asyncmongo.RSConnectionError)
test_shunt.register_called('update_f')

db.test_stats.update({"_id" : TEST_TIMESTAMP}, {'$inc' : {'test_count' : 1}}, callback=update_callback)

tornado.ioloop.IOLoop.instance().start()
test_shunt.assert_called('update_f')

db = asyncmongo.Client(pool_id='testrs', rs="rs0", seed=[("127.0.0.1", 27020)], dbname='test', maxconnections=2)

# Update
Expand Down

0 comments on commit 09f6433

Please sign in to comment.