diff --git a/dj_cqrs/management/commands/cqrs_consume.py b/dj_cqrs/management/commands/cqrs_consume.py index e1fbf1a..f742d34 100644 --- a/dj_cqrs/management/commands/cqrs_consume.py +++ b/dj_cqrs/management/commands/cqrs_consume.py @@ -2,9 +2,10 @@ from multiprocessing import Process +from dj_cqrs.registries import ReplicaRegistry from dj_cqrs.transport import current_transport -from django.core.management.base import BaseCommand +from django.core.management.base import BaseCommand, CommandError class Command(BaseCommand): @@ -12,17 +13,39 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--workers', '-w', help='Number of workers', type=int, default=0) + parser.add_argument( + '--cqrs-id', + '-cid', + nargs='*', + type=str, + help='Choose model(s) by CQRS_ID for consuming', + ) def handle(self, *args, **options): - if options['workers'] == 0: - current_transport.consume() - else: - pool = [] - - for _ in range(options['workers']): - p = Process(target=current_transport.consume) - pool.append(p) - p.start() - - for p in pool: - p.join() + consume_kwargs = {} + + if options.get('cqrs_id'): + cqrs_ids = set() + + for cqrs_id in options['cqrs_id']: + model = ReplicaRegistry.get_model_by_cqrs_id(cqrs_id) + if not model: + raise CommandError('Wrong CQRS ID: {0}!'.format(cqrs_id)) + + cqrs_ids.add(cqrs_id) + + consume_kwargs['cqrs_ids'] = cqrs_ids + + if options['workers'] <= 1: + current_transport.consume(**consume_kwargs) + return + + pool = [] + + for _ in range(options['workers']): + p = Process(target=current_transport.consume, kwargs=consume_kwargs) + pool.append(p) + p.start() + + for p in pool: + p.join() diff --git a/dj_cqrs/transport/kombu.py b/dj_cqrs/transport/kombu.py index a228fe1..93469b8 100644 --- a/dj_cqrs/transport/kombu.py +++ b/dj_cqrs/transport/kombu.py @@ -23,7 +23,7 @@ class _KombuConsumer(ConsumerMixin): - def __init__(self, url, exchange_name, queue_name, prefetch_count, callback): + def __init__(self, url, exchange_name, queue_name, prefetch_count, callback, cqrs_ids=None): self.connection = Connection(url) self.exchange = Exchange( exchange_name, @@ -34,28 +34,31 @@ def __init__(self, url, exchange_name, queue_name, prefetch_count, callback): self.prefetch_count = prefetch_count self.callback = callback self.queues = [] + self.cqrs_ids = cqrs_ids + self._init_queues() def _init_queues(self): channel = self.connection.channel() for cqrs_id in ReplicaRegistry.models.keys(): - q = Queue( - self.queue_name, - exchange=self.exchange, - routing_key=cqrs_id, - ) - q.maybe_bind(channel) - q.declare() - self.queues.append(q) - - sync_q = Queue( - self.queue_name, - exchange=self.exchange, - routing_key='cqrs.{0}.{1}'.format(self.queue_name, cqrs_id), - ) - sync_q.maybe_bind(channel) - sync_q.declare() - self.queues.append(sync_q) + if (not self.cqrs_ids) or (cqrs_id in self.cqrs_ids): + q = Queue( + self.queue_name, + exchange=self.exchange, + routing_key=cqrs_id, + ) + q.maybe_bind(channel) + q.declare() + self.queues.append(q) + + sync_q = Queue( + self.queue_name, + exchange=self.exchange, + routing_key='cqrs.{0}.{1}'.format(self.queue_name, cqrs_id), + ) + sync_q.maybe_bind(channel) + sync_q.declare() + self.queues.append(sync_q) def get_consumers(self, Consumer, channel): return [ @@ -77,7 +80,7 @@ def clean_connection(cls): pass @classmethod - def consume(cls): + def consume(cls, cqrs_ids=None): queue_name, prefetch_count = cls._get_consumer_settings() url, exchange_name = cls._get_common_settings() @@ -87,6 +90,7 @@ def consume(cls): queue_name, prefetch_count, cls._consume_message, + cqrs_ids=cqrs_ids, ) consumer.run() diff --git a/dj_cqrs/transport/mock.py b/dj_cqrs/transport/mock.py index b23acb1..3f3e2c6 100644 --- a/dj_cqrs/transport/mock.py +++ b/dj_cqrs/transport/mock.py @@ -1,4 +1,4 @@ -# Copyright © 2020 Ingram Micro Inc. All rights reserved. +# Copyright © 2021 Ingram Micro Inc. All rights reserved. from dj_cqrs.transport import BaseTransport @@ -9,5 +9,5 @@ def produce(payload): return TransportMock.consume(payload) @staticmethod - def consume(payload): + def consume(payload=None, **kwargs): return payload diff --git a/dj_cqrs/transport/rabbit_mq.py b/dj_cqrs/transport/rabbit_mq.py index 35e4bab..2dd9441 100644 --- a/dj_cqrs/transport/rabbit_mq.py +++ b/dj_cqrs/transport/rabbit_mq.py @@ -45,7 +45,7 @@ def clean_connection(cls): cls._producer_channel = None @classmethod - def consume(cls): + def consume(cls, cqrs_ids=None): consumer_rabbit_settings = cls._get_consumer_settings() common_rabbit_settings = cls._get_common_settings() @@ -54,7 +54,7 @@ def consume(cls): try: delay_queue = DelayQueue(max_size=get_delay_queue_max_size()) connection, channel, consumer_generator = cls._get_consumer_rmq_objects( - *(common_rabbit_settings + consumer_rabbit_settings), + *(common_rabbit_settings + consumer_rabbit_settings), cqrs_ids=cqrs_ids, ) for method_frame, properties, body in consumer_generator: @@ -239,7 +239,15 @@ def _get_produced_message_routing_key(cls, payload): @classmethod def _get_consumer_rmq_objects( - cls, host, port, creds, exchange, queue_name, dead_letter_queue_name, prefetch_count, + cls, + host, + port, + creds, + exchange, + queue_name, + dead_letter_queue_name, + prefetch_count, + cqrs_ids=None, ): connection = BlockingConnection( ConnectionParameters(host=host, port=port, credentials=creds), @@ -252,6 +260,9 @@ def _get_consumer_rmq_objects( channel.queue_declare(dead_letter_queue_name, durable=True, exclusive=False) for cqrs_id, _ in ReplicaRegistry.models.items(): + if cqrs_ids and cqrs_id not in cqrs_ids: + continue + channel.queue_bind(exchange=exchange, queue=queue_name, routing_key=cqrs_id) # Every service must have specific SYNC or requeue routes diff --git a/tests/dj/transport.py b/tests/dj/transport.py index b786cf5..4641bf2 100644 --- a/tests/dj/transport.py +++ b/tests/dj/transport.py @@ -1,4 +1,4 @@ -# Copyright © 2020 Ingram Micro Inc. All rights reserved. +# Copyright © 2021 Ingram Micro Inc. All rights reserved. import os @@ -14,8 +14,9 @@ def produce(payload): TransportStub.consume(payload) @staticmethod - def consume(payload): - consumer.consume(payload) + def consume(payload=None): + if payload: + return consumer.consume(payload) class RabbitMQTransportWithEvents(RabbitMQTransport): diff --git a/tests/test_commands/test_consume.py b/tests/test_commands/test_consume.py new file mode 100644 index 0000000..a120c14 --- /dev/null +++ b/tests/test_commands/test_consume.py @@ -0,0 +1,50 @@ +# Copyright © 2021 Ingram Micro Inc. All rights reserved. + +from importlib import import_module, reload + +from django.core.management import CommandError, call_command + +import pytest + + +COMMAND_NAME = 'cqrs_consume' + + +@pytest.fixture +def reload_transport(): + reload(import_module('dj_cqrs.transport')) + + +def test_no_arguments(mocker, reload_transport): + consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume') + + call_command(COMMAND_NAME) + + consume_mock.assert_called_once_with() + + +def test_several_workers(reload_transport): + call_command(COMMAND_NAME, '--workers=2') + + +def test_one_worker_one_cqrs_id(mocker, reload_transport): + consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume') + + call_command(COMMAND_NAME, '--workers=1', '-cid=author') + + consume_mock.assert_called_once_with(cqrs_ids={'author'}) + + +def test_several_cqrs_id(mocker, reload_transport): + consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume') + + call_command(COMMAND_NAME, cqrs_id=['author', 'basic', 'author', 'no_db']) + + consume_mock.assert_called_once_with(cqrs_ids={'author', 'basic', 'no_db'}) + + +def test_wrong_cqrs_id(reload_transport): + with pytest.raises(CommandError) as e: + call_command(COMMAND_NAME, cqrs_id=['author', 'random', 'no_db']) + + assert "Wrong CQRS ID: random!" in str(e)