Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions dj_cqrs/management/commands/cqrs_consume.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,50 @@

from multiprocessing import Process

from dj_cqrs.registries import ReplicaRegistry
from dj_cqrs.transport import current_transport

from django.core.management.base import BaseCommand
from django.core.management.base import BaseCommand, CommandError


class Command(BaseCommand):
help = 'Starts CQRS worker, which consumes messages from message queue.'

def add_arguments(self, parser):
parser.add_argument('--workers', '-w', help='Number of workers', type=int, default=0)
parser.add_argument(
'--cqrs-id',
'-cid',
nargs='*',
type=str,
help='Choose model(s) by CQRS_ID for consuming',
)

def handle(self, *args, **options):
if options['workers'] == 0:
current_transport.consume()
else:
pool = []

for _ in range(options['workers']):
p = Process(target=current_transport.consume)
pool.append(p)
p.start()

for p in pool:
p.join()
consume_kwargs = {}

if options.get('cqrs_id'):
cqrs_ids = set()

for cqrs_id in options['cqrs_id']:
model = ReplicaRegistry.get_model_by_cqrs_id(cqrs_id)
if not model:
raise CommandError('Wrong CQRS ID: {0}!'.format(cqrs_id))

cqrs_ids.add(cqrs_id)

consume_kwargs['cqrs_ids'] = cqrs_ids

if options['workers'] <= 1:
current_transport.consume(**consume_kwargs)
return

pool = []

for _ in range(options['workers']):
p = Process(target=current_transport.consume, kwargs=consume_kwargs)
pool.append(p)
p.start()

for p in pool:
p.join()
42 changes: 23 additions & 19 deletions dj_cqrs/transport/kombu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class _KombuConsumer(ConsumerMixin):

def __init__(self, url, exchange_name, queue_name, prefetch_count, callback):
def __init__(self, url, exchange_name, queue_name, prefetch_count, callback, cqrs_ids=None):
self.connection = Connection(url)
self.exchange = Exchange(
exchange_name,
Expand All @@ -34,28 +34,31 @@ def __init__(self, url, exchange_name, queue_name, prefetch_count, callback):
self.prefetch_count = prefetch_count
self.callback = callback
self.queues = []
self.cqrs_ids = cqrs_ids

self._init_queues()

def _init_queues(self):
channel = self.connection.channel()
for cqrs_id in ReplicaRegistry.models.keys():
q = Queue(
self.queue_name,
exchange=self.exchange,
routing_key=cqrs_id,
)
q.maybe_bind(channel)
q.declare()
self.queues.append(q)

sync_q = Queue(
self.queue_name,
exchange=self.exchange,
routing_key='cqrs.{0}.{1}'.format(self.queue_name, cqrs_id),
)
sync_q.maybe_bind(channel)
sync_q.declare()
self.queues.append(sync_q)
if (not self.cqrs_ids) or (cqrs_id in self.cqrs_ids):
q = Queue(
self.queue_name,
exchange=self.exchange,
routing_key=cqrs_id,
)
q.maybe_bind(channel)
q.declare()
self.queues.append(q)

sync_q = Queue(
self.queue_name,
exchange=self.exchange,
routing_key='cqrs.{0}.{1}'.format(self.queue_name, cqrs_id),
)
sync_q.maybe_bind(channel)
sync_q.declare()
self.queues.append(sync_q)

def get_consumers(self, Consumer, channel):
return [
Expand All @@ -77,7 +80,7 @@ def clean_connection(cls):
pass

@classmethod
def consume(cls):
def consume(cls, cqrs_ids=None):
queue_name, prefetch_count = cls._get_consumer_settings()
url, exchange_name = cls._get_common_settings()

Expand All @@ -87,6 +90,7 @@ def consume(cls):
queue_name,
prefetch_count,
cls._consume_message,
cqrs_ids=cqrs_ids,
)
consumer.run()

Expand Down
4 changes: 2 additions & 2 deletions dj_cqrs/transport/mock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright © 2020 Ingram Micro Inc. All rights reserved.
# Copyright © 2021 Ingram Micro Inc. All rights reserved.

from dj_cqrs.transport import BaseTransport

Expand All @@ -9,5 +9,5 @@ def produce(payload):
return TransportMock.consume(payload)

@staticmethod
def consume(payload):
def consume(payload=None, **kwargs):
return payload
17 changes: 14 additions & 3 deletions dj_cqrs/transport/rabbit_mq.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def clean_connection(cls):
cls._producer_channel = None

@classmethod
def consume(cls):
def consume(cls, cqrs_ids=None):
consumer_rabbit_settings = cls._get_consumer_settings()
common_rabbit_settings = cls._get_common_settings()

Expand All @@ -54,7 +54,7 @@ def consume(cls):
try:
delay_queue = DelayQueue(max_size=get_delay_queue_max_size())
connection, channel, consumer_generator = cls._get_consumer_rmq_objects(
*(common_rabbit_settings + consumer_rabbit_settings),
*(common_rabbit_settings + consumer_rabbit_settings), cqrs_ids=cqrs_ids,
)

for method_frame, properties, body in consumer_generator:
Expand Down Expand Up @@ -239,7 +239,15 @@ def _get_produced_message_routing_key(cls, payload):

@classmethod
def _get_consumer_rmq_objects(
cls, host, port, creds, exchange, queue_name, dead_letter_queue_name, prefetch_count,
cls,
host,
port,
creds,
exchange,
queue_name,
dead_letter_queue_name,
prefetch_count,
cqrs_ids=None,
):
connection = BlockingConnection(
ConnectionParameters(host=host, port=port, credentials=creds),
Expand All @@ -252,6 +260,9 @@ def _get_consumer_rmq_objects(
channel.queue_declare(dead_letter_queue_name, durable=True, exclusive=False)

for cqrs_id, _ in ReplicaRegistry.models.items():
if cqrs_ids and cqrs_id not in cqrs_ids:
continue

channel.queue_bind(exchange=exchange, queue=queue_name, routing_key=cqrs_id)

# Every service must have specific SYNC or requeue routes
Expand Down
7 changes: 4 additions & 3 deletions tests/dj/transport.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright © 2020 Ingram Micro Inc. All rights reserved.
# Copyright © 2021 Ingram Micro Inc. All rights reserved.

import os

Expand All @@ -14,8 +14,9 @@ def produce(payload):
TransportStub.consume(payload)

@staticmethod
def consume(payload):
consumer.consume(payload)
def consume(payload=None):
if payload:
return consumer.consume(payload)


class RabbitMQTransportWithEvents(RabbitMQTransport):
Expand Down
50 changes: 50 additions & 0 deletions tests/test_commands/test_consume.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright © 2021 Ingram Micro Inc. All rights reserved.

from importlib import import_module, reload

from django.core.management import CommandError, call_command

import pytest


COMMAND_NAME = 'cqrs_consume'


@pytest.fixture
def reload_transport():
reload(import_module('dj_cqrs.transport'))


def test_no_arguments(mocker, reload_transport):
consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume')

call_command(COMMAND_NAME)

consume_mock.assert_called_once_with()


def test_several_workers(reload_transport):
call_command(COMMAND_NAME, '--workers=2')


def test_one_worker_one_cqrs_id(mocker, reload_transport):
consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume')

call_command(COMMAND_NAME, '--workers=1', '-cid=author')

consume_mock.assert_called_once_with(cqrs_ids={'author'})


def test_several_cqrs_id(mocker, reload_transport):
consume_mock = mocker.patch('tests.dj.transport.TransportStub.consume')

call_command(COMMAND_NAME, cqrs_id=['author', 'basic', 'author', 'no_db'])

consume_mock.assert_called_once_with(cqrs_ids={'author', 'basic', 'no_db'})


def test_wrong_cqrs_id(reload_transport):
with pytest.raises(CommandError) as e:
call_command(COMMAND_NAME, cqrs_id=['author', 'random', 'no_db'])

assert "Wrong CQRS ID: random!" in str(e)