diff --git a/adsputils/__init__.py b/adsputils/__init__.py index 7da1453..ddc1701 100644 --- a/adsputils/__init__.py +++ b/adsputils/__init__.py @@ -381,8 +381,34 @@ def __init__(self, app_name, *args, **kwargs): self.exchange = Exchange(self._config.get('CELERY_DEFAULT_EXCHANGE', 'ads-pipeline'), type=self._config.get('CELERY_DEFAULT_EXCHANGE_TYPE', 'topic')) - self.forwarding_connection = None - self._forward_message = None + self.forward_message_dict = {} + def setup_forward_message(output_celery_broker=None, output_taskname=None): + broker = output_celery_broker or self._config.get('OUTPUT_CELERY_BROKER') + + if broker: + task_name = output_taskname or self._config.get('OUTPUT_TASKNAME') + if not task_name: + raise NotImplementedError('Sorry, your app is not properly configured (no task handler).') + + @self.task(name=task_name, + exchange=self._config.get('OUTPUT_EXCHANGE', 'ads-pipeline'), + queue=self._config.get('OUTPUT_QUEUE', 'update-record'), + routing_key=self._config.get('OUTPUT_QUEUE', 'update-record')) + def _forward_message(self, *args, **kwargs): + """A handler that can be used to forward stuff out of our + queue. It does nothing (it doesn't process data)""" + self.logger.error('We should have never been called directly! %s' % \ + (args, kwargs)) + + return {'broker': broker, 'forward message': _forward_message} + + if not self._config.get('FORWARD_MSG_DICT'): + self.forward_message_dict['default'] = setup_forward_message() + else: + for setup in self._config.get('FORWARD_MSG_DICT'): + if not setup.get('OUTPUT_PIPELINE') or not setup.get('OUTPUT_CELERY_BROKER') or not setup.get('OUTPUT_TASKNAME'): + raise NotImplementedError('Sorry, your app is not properly configured (setup for multiple pipelines missing keys).') + self.forward_message_dict[setup.get('OUTPUT_PIPELINE')] = setup_forward_message(output_celery_broker=setup.get('OUTPUT_CELERY_BROKER'), output_taskname=setup.get('OUTPUT_TASKNAME')) # HTTP connection pool # - The maximum number of retries each connection should attempt: this @@ -409,39 +435,33 @@ def _set_serializer(self): self.conf['CELERY_TASK_SERIALIZER'] = 'adsmsg' self.conf['CELERY_RESULT_SERIALIZER'] = 'adsmsg' - def forward_message(self, output_taskname=None, output_celery_broker=None, *args, **kwargs): - """Class method that sets up the message forwarding handler dynamically based on - OUTPUT_TASKNAME and OUTPUT_CELERY_BROKER.""" - - # Use OUTPUT_CELERY_BROKER from config if not provided at call time - broker = output_celery_broker or self._config.get('OUTPUT_CELERY_BROKER') - if broker: - # kombu connection is lazy loaded, so it's ok to create now - self.forwarding_connection = BrokerConnection(broker) - - if not self.forwarding_connection: - raise NotImplementedError('Sorry, your app is not properly configured (no broker).') - - # Use OUTPUT_TASKNAME from config if not provided at call time - task_name = output_taskname or self._config.get('OUTPUT_TASKNAME') - - if task_name: - @self.task(name=task_name, - exchange=self._config.get('OUTPUT_EXCHANGE', 'ads-pipeline'), - queue=self._config.get('OUTPUT_QUEUE', 'update-record'), - routing_key=self._config.get('OUTPUT_QUEUE', 'update-record')) - def _forward_message(*args, **kwargs): - """A handler that can be used to forward stuff out of our queue. It does nothing (it doesn't process data).""" - self.logger.error('We should have never been called directly! %s' % (args, kwargs)) - - self._forward_message = _forward_message - - if not self._forward_message: - raise NotImplementedError('Sorry, your app is not properly configured (no task handler).') - - self.logger.debug('Forwarding results out to: %s', self.forwarding_connection) - return self._forward_message.apply_async(args, kwargs, - connection=self.forwarding_connection) + def forward_message(self, *args, **kwargs): + """Class method that is replaced during initializiton with the real + implementation (IFF) the OUTPUT_TASKNAME and other OUTPUT_ parameters + are specified. + + To set in config: + - For a single output destination: + - OUTPUT_CELERY_BROKER + - OUTPUT_TASKNAME + At call time: + self.forward_message(message) + + - For multiple output destinations: + - FORWARD_MSG_DICT = [{OUTPUT_PIPELINE: , OUTPUT_CELERY_BROKER: , OUTPUT_TASKNAME: }, ...] + where OUTPUT_PIPELINE is a string that will need to be specified in the call to forward_message as: + self.forward_message(message, pipeline=OUTPUT_PIPELINE) + """ + pipeline = kwargs.get('pipeline', 'default') + + if self.forward_message_dict and pipeline: + if not self.forward_message_dict[pipeline].get('broker'): + raise NotImplementedError('Sorry, your app is not properly configured (no broker).') + forwarding_connection = BrokerConnection(self.forward_message_dict[pipeline].get('broker')) + self.logger.debug('Forwarding results out to: %s', self.forward_message_dict[pipeline].get('broker')) + return self.forward_message_dict[pipeline]['forward message'].apply_async(args, kwargs, connection=forwarding_connection) + else: + raise NotImplementedError('Sorry, your app is not properly configured.') def _get_callers_module(self): frame = inspect.stack()[2] diff --git a/adsputils/tests/test_init.py b/adsputils/tests/test_init.py index f71fd21..dd508a8 100644 --- a/adsputils/tests/test_init.py +++ b/adsputils/tests/test_init.py @@ -69,5 +69,32 @@ def test_u2asc(self): input3 = input2.encode('utf16') self.assertRaises(UnicodeHandlerError, adsputils.u2asc, input3) +class TestCelery(unittest.TestCase): + + def test_forward_message_single(self): + app = adsputils.ADSCelery('test',local_config={ + 'OUTPUT_CELERY_BROKER': 'testbroker', + 'OUTPUT_TASKNAME': 'testtaskname' + }) + + self.assertIn('default',app.forward_message_dict.keys()) + self.assertEqual(app.forward_message_dict['default'].get('broker'), 'testbroker') + + def test_forward_message_multiple(self): + app = adsputils.ADSCelery('test', local_config={ + 'FORWARD_MSG_DICT': [{'OUTPUT_PIPELINE': 'augment', + 'OUTPUT_CELERY_BROKER': 'testbroker', + 'OUTPUT_TASKNAME': 'testtaskname'}, + {'OUTPUT_PIPELINE': 'classifier', + 'OUTPUT_CELERY_BROKER': 'testbroker2', + 'OUTPUT_TASKNAME': 'testtaskname2'}] + }) + + self.assertEqual(len(app.forward_message_dict.keys()), 2) + self.assertIn('augment', app.forward_message_dict.keys()) + self.assertIn('classifier', app.forward_message_dict.keys()) + self.assertIn('broker', app.forward_message_dict['augment'].keys()) + self.assertEqual(app.forward_message_dict['augment']['broker'], 'testbroker') + if __name__ == '__main__': unittest.main()