From 1e1a8e4d6b6a5893a7f46c7b0a75050640513236 Mon Sep 17 00:00:00 2001 From: Maxim Kolyubyakin Date: Sun, 19 Dec 2021 09:39:14 +0100 Subject: [PATCH 1/2] LITE-21474 Added optional cqrs-id argument to consume command to choose models for cosuming --- dj_cqrs/management/commands/cqrs_consume.py | 37 ++++++++++++------ dj_cqrs/transport/kombu.py | 42 +++++++++++--------- dj_cqrs/transport/mock.py | 4 +- dj_cqrs/transport/rabbit_mq.py | 17 ++++++-- tests/dj/transport.py | 7 ++-- tests/test_commands/test_consume.py | 43 +++++++++++++++++++++ 6 files changed, 111 insertions(+), 39 deletions(-) create mode 100644 tests/test_commands/test_consume.py diff --git a/dj_cqrs/management/commands/cqrs_consume.py b/dj_cqrs/management/commands/cqrs_consume.py index e1fbf1a..b59acc7 100644 --- a/dj_cqrs/management/commands/cqrs_consume.py +++ b/dj_cqrs/management/commands/cqrs_consume.py @@ -12,17 +12,30 @@ class Command(BaseCommand): def add_arguments(self, parser): parser.add_argument('--workers', '-w', help='Number of workers', type=int, default=0) + parser.add_argument( + '--cqrs-id', + '-cid', + nargs='*', + type=str, + help='Choose model(s) by CQRS_ID for consuming', + ) def handle(self, *args, **options): - if options['workers'] == 0: - current_transport.consume() - else: - pool = [] - - for _ in range(options['workers']): - p = Process(target=current_transport.consume) - pool.append(p) - p.start() - - for p in pool: - p.join() + consume_kwargs = {} + + if options.get('cqrs_id'): + consume_kwargs['cqrs_ids'] = set(options['cqrs_id']) + + if options['workers'] <= 1: + current_transport.consume(**consume_kwargs) + return + + pool = [] + + for _ in range(options['workers']): + p = Process(target=current_transport.consume, kwargs=consume_kwargs) + pool.append(p) + p.start() + + for p in pool: + p.join() diff --git a/dj_cqrs/transport/kombu.py b/dj_cqrs/transport/kombu.py index a228fe1..93469b8 100644 --- a/dj_cqrs/transport/kombu.py +++ b/dj_cqrs/transport/kombu.py @@ -23,7 +23,7 @@ class _KombuConsumer(ConsumerMixin): - def __init__(self, url, exchange_name, queue_name, prefetch_count, callback): + def __init__(self, url, exchange_name, queue_name, prefetch_count, callback, cqrs_ids=None): self.connection = Connection(url) self.exchange = Exchange( exchange_name, @@ -34,28 +34,31 @@ def __init__(self, url, exchange_name, queue_name, prefetch_count, callback): self.prefetch_count = prefetch_count self.callback = callback self.queues = [] + self.cqrs_ids = cqrs_ids + self._init_queues() def _init_queues(self): channel = self.connection.channel() for cqrs_id in ReplicaRegistry.models.keys(): - q = Queue( - self.queue_name, - exchange=self.exchange, - routing_key=cqrs_id, - ) - q.maybe_bind(channel) - q.declare() - self.queues.append(q) - - sync_q = Queue( - self.queue_name, - exchange=self.exchange, - routing_key='cqrs.{0}.{1}'.format(self.queue_name, cqrs_id), - ) - sync_q.maybe_bind(channel) - sync_q.declare() - self.queues.append(sync_q) + if (not self.cqrs_ids) or (cqrs_id in self.cqrs_ids): + q = Queue( + self.queue_name, + exchange=self.exchange, + routing_key=cqrs_id, + ) + q.maybe_bind(channel) + q.declare() + self.queues.append(q) + + sync_q = Queue( + self.queue_name, + exchange=self.exchange, + routing_key='cqrs.{0}.{1}'.format(self.queue_name, cqrs_id), + ) + sync_q.maybe_bind(channel) + sync_q.declare() + self.queues.append(sync_q) def get_consumers(self, Consumer, channel): return [ @@ -77,7 +80,7 @@ def clean_connection(cls): pass @classmethod - def consume(cls): + def consume(cls, cqrs_ids=None): queue_name, prefetch_count = cls._get_consumer_settings() url, exchange_name = cls._get_common_settings() @@ -87,6 +90,7 @@ def consume(cls): queue_name, prefetch_count, cls._consume_message, + cqrs_ids=cqrs_ids, ) consumer.run() diff --git a/dj_cqrs/transport/mock.py b/dj_cqrs/transport/mock.py index b23acb1..3f3e2c6 100644 --- a/dj_cqrs/transport/mock.py +++ b/dj_cqrs/transport/mock.py @@ -1,4 +1,4 @@ -# Copyright © 2020 Ingram Micro Inc. All rights reserved. +# Copyright © 2021 Ingram Micro Inc. All rights reserved. from dj_cqrs.transport import BaseTransport @@ -9,5 +9,5 @@ def produce(payload): return TransportMock.consume(payload) @staticmethod - def consume(payload): + def consume(payload=None, **kwargs): return payload diff --git a/dj_cqrs/transport/rabbit_mq.py b/dj_cqrs/transport/rabbit_mq.py index 35e4bab..2dd9441 100644 --- a/dj_cqrs/transport/rabbit_mq.py +++ b/dj_cqrs/transport/rabbit_mq.py @@ -45,7 +45,7 @@ def clean_connection(cls): cls._producer_channel = None @classmethod - def consume(cls): + def consume(cls, cqrs_ids=None): consumer_rabbit_settings = cls._get_consumer_settings() common_rabbit_settings = cls._get_common_settings() @@ -54,7 +54,7 @@ def consume(cls): try: delay_queue = DelayQueue(max_size=get_delay_queue_max_size()) connection, channel, consumer_generator = cls._get_consumer_rmq_objects( - *(common_rabbit_settings + consumer_rabbit_settings), + *(common_rabbit_settings + consumer_rabbit_settings), cqrs_ids=cqrs_ids, ) for method_frame, properties, body in consumer_generator: @@ -239,7 +239,15 @@ def _get_produced_message_routing_key(cls, payload): @classmethod def _get_consumer_rmq_objects( - cls, host, port, creds, exchange, queue_name, dead_letter_queue_name, prefetch_count, + cls, + host, + port, + creds, + exchange, + queue_name, + dead_letter_queue_name, + prefetch_count, + cqrs_ids=None, ): connection = BlockingConnection( ConnectionParameters(host=host, port=port, credentials=creds), @@ -252,6 +260,9 @@ def _get_consumer_rmq_objects( channel.queue_declare(dead_letter_queue_name, durable=True, exclusive=False) for cqrs_id, _ in ReplicaRegistry.models.items(): + if cqrs_ids and cqrs_id not in cqrs_ids: + continue + channel.queue_bind(exchange=exchange, queue=queue_name, routing_key=cqrs_id) # Every service must have specific SYNC or requeue routes diff --git a/tests/dj/transport.py b/tests/dj/transport.py index b786cf5..4641bf2 100644 --- a/tests/dj/transport.py +++ b/tests/dj/transport.py @@ -1,4 +1,4 @@ -# Copyright © 2020 Ingram Micro Inc. All rights reserved. +# Copyright © 2021 Ingram Micro Inc. All rights reserved. import os @@ -14,8 +14,9 @@ def produce(payload): TransportStub.consume(payload) @staticmethod - def consume(payload): - consumer.consume(payload) + def consume(payload=None): + if payload: + return consumer.consume(payload) class RabbitMQTransportWithEvents(RabbitMQTransport): diff --git a/tests/test_commands/test_consume.py b/tests/test_commands/test_consume.py new file mode 100644 index 0000000..e7c8ed8 --- /dev/null +++ b/tests/test_commands/test_consume.py @@ -0,0 +1,43 @@ +# Copyright © 2021 Ingram Micro Inc. All rights reserved. + +from importlib import import_module, reload + +from django.core.management import call_command + +import pytest + + +COMMAND_NAME = 'cqrs_consume' + + +@pytest.fixture +def reload_transport(): + reload(import_module('dj_cqrs.transport')) + + +def test_no_arguments(mocker, reload_transport): + consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume') + + call_command(COMMAND_NAME) + + consume_mock.assert_called_once_with() + + +def test_several_workers(reload_transport): + call_command(COMMAND_NAME, '--workers=2') + + +def test_one_worker_one_cqrs_id(mocker, reload_transport): + consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume') + + call_command(COMMAND_NAME, '--workers=1', '-cid=author') + + consume_mock.assert_called_once_with(cqrs_ids={'author'}) + + +def test_several_cqrs_id(mocker, reload_transport): + consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume') + + call_command(COMMAND_NAME, cqrs_id=['a', 'b', 'a', 'c']) + + consume_mock.assert_called_once_with(cqrs_ids={'a', 'b', 'c'}) From c98c4011804dcadb171620f90e012862b2a0cb6a Mon Sep 17 00:00:00 2001 From: Maxim Kolyubyakin Date: Tue, 21 Dec 2021 11:01:30 +0100 Subject: [PATCH 2/2] LITE-21474 Added validation to cqrs-id arguments in consume command --- dj_cqrs/management/commands/cqrs_consume.py | 14 ++++++++++++-- tests/test_commands/test_consume.py | 13 ++++++++++--- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/dj_cqrs/management/commands/cqrs_consume.py b/dj_cqrs/management/commands/cqrs_consume.py index b59acc7..f742d34 100644 --- a/dj_cqrs/management/commands/cqrs_consume.py +++ b/dj_cqrs/management/commands/cqrs_consume.py @@ -2,9 +2,10 @@ from multiprocessing import Process +from dj_cqrs.registries import ReplicaRegistry from dj_cqrs.transport import current_transport -from django.core.management.base import BaseCommand +from django.core.management.base import BaseCommand, CommandError class Command(BaseCommand): @@ -24,7 +25,16 @@ def handle(self, *args, **options): consume_kwargs = {} if options.get('cqrs_id'): - consume_kwargs['cqrs_ids'] = set(options['cqrs_id']) + cqrs_ids = set() + + for cqrs_id in options['cqrs_id']: + model = ReplicaRegistry.get_model_by_cqrs_id(cqrs_id) + if not model: + raise CommandError('Wrong CQRS ID: {0}!'.format(cqrs_id)) + + cqrs_ids.add(cqrs_id) + + consume_kwargs['cqrs_ids'] = cqrs_ids if options['workers'] <= 1: current_transport.consume(**consume_kwargs) diff --git a/tests/test_commands/test_consume.py b/tests/test_commands/test_consume.py index e7c8ed8..a120c14 100644 --- a/tests/test_commands/test_consume.py +++ b/tests/test_commands/test_consume.py @@ -2,7 +2,7 @@ from importlib import import_module, reload -from django.core.management import call_command +from django.core.management import CommandError, call_command import pytest @@ -38,6 +38,13 @@ def test_one_worker_one_cqrs_id(mocker, reload_transport): def test_several_cqrs_id(mocker, reload_transport): consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume') - call_command(COMMAND_NAME, cqrs_id=['a', 'b', 'a', 'c']) + call_command(COMMAND_NAME, cqrs_id=['author', 'basic', 'author', 'no_db']) - consume_mock.assert_called_once_with(cqrs_ids={'a', 'b', 'c'}) + consume_mock.assert_called_once_with(cqrs_ids={'author', 'basic', 'no_db'}) + + +def test_wrong_cqrs_id(reload_transport): + with pytest.raises(CommandError) as e: + call_command(COMMAND_NAME, cqrs_id=['author', 'random', 'no_db']) + + assert "Wrong CQRS ID: random!" in str(e)