Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow binding a queue to an exchange #9

Merged
merged 2 commits into from
Jan 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions amqp_mock/_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from asyncio import Queue
from collections import OrderedDict
from typing import AsyncGenerator, Dict, List
from collections import defaultdict, OrderedDict
from typing import AsyncGenerator, DefaultDict, Dict, List

from ._message import Message, MessageStatus, QueuedMessage

Expand All @@ -10,17 +10,35 @@ def __init__(self) -> None:
self._exchanges: Dict[str, List[Message]] = {}
self._queues: Dict[str, Queue[Message]] = {}
self._history: Dict[str, QueuedMessage] = OrderedDict()
self._binds: DefaultDict[str, Dict[str, str]] = defaultdict(dict)

async def clear(self) -> None:
self._exchanges = {}
self._queues = {}
self._history = OrderedDict()
self._binds = defaultdict(dict)

async def add_message_to_exchange(self, exchange: str, message: Message) -> None:
if exchange not in self._exchanges:
self._exchanges[exchange] = []
self._exchanges[exchange].insert(0, message)

binds = self._binds.get(exchange)
routing_key = message.routing_key

if binds and routing_key in binds:
await self.add_message_to_queue(binds[routing_key], message)

async def bind_queue_to_exchange(self, queue: str, exchange: str,
routing_key: str = "") -> None:
self._binds[exchange][routing_key] = queue

async def declare_queue(self, queue: str) -> None:
if queue not in self._queues:
self._queues[queue] = Queue()

await self.bind_queue_to_exchange(queue, "", routing_key=queue)

async def get_messages_from_exchange(self, exchange: str) -> List[Message]:
if exchange not in self._exchanges:
return []
Expand All @@ -31,8 +49,7 @@ async def delete_messages_from_exchange(self, exchange: str) -> None:
self._exchanges[exchange] = []

async def add_message_to_queue(self, queue: str, message: Message) -> None:
if queue not in self._queues:
self._queues[queue] = Queue()
await self.declare_queue(queue)
await self._queues[queue].put(message)
self._history[message.id] = QueuedMessage(message, queue)

Expand Down
16 changes: 16 additions & 0 deletions amqp_mock/amqp_server/_amqp_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,21 @@ def __init__(self, reader: StreamReader, writer: StreamWriter,
self._incoming_message: Union[Message, None] = None
self._delivery_tag = 0
self._on_consume = on_consume
self._on_bind: Optional[Callable[[str, str, str], Awaitable[None]]] = None
self._on_declare_queue: Optional[Callable[[str], Awaitable[None]]] = None
self._on_publish: Optional[Callable[[Message], Awaitable[None]]] = None
self._on_ack: Optional[Callable[[str], Awaitable[None]]] = None
self._on_nack: Optional[Callable[[str], Awaitable[None]]] = None
self._on_close: Optional[Callable[['AmqpConnection'], Awaitable[None]]] = None

def on_bind(self, callback: Callable[[str, str, str], Awaitable[None]]) -> 'AmqpConnection':
self._on_bind = callback
return self

def on_declare_queue(self, callback: Callable[[str], Awaitable[None]]) -> 'AmqpConnection':
self._on_declare_queue = callback
return self

def on_publish(self, callback: Callable[[Message], Awaitable[None]]) -> 'AmqpConnection':
self._on_publish = callback
return self
Expand Down Expand Up @@ -190,6 +200,9 @@ async def _send_channel_open_ok(self, channel_id: int, frame_in: spec.Channel.Op
await self._send_frame(channel_id, frame_out)

async def _send_queue_declare_ok(self, channel_id: int, frame_in: spec.Queue.Declare) -> None:
if self._on_declare_queue:
await self._on_declare_queue(frame_in.queue)

frame_out = spec.Queue.DeclareOk(queue=frame_in.queue, message_count=0, consumer_count=0)
return await self._send_frame(channel_id, frame_out)

Expand All @@ -199,6 +212,9 @@ async def _send_exchange_declare_ok(self, channel_id: int,
return await self._send_frame(channel_id, frame_out)

async def _send_queue_bind_ok(self, channel_id: int, frame_in: spec.Queue.Bind) -> None:
if self._on_bind:
await self._on_bind(frame_in.queue, frame_in.exchange, frame_in.routing_key)

frame_out = spec.Queue.BindOk()
return await self._send_frame(channel_id, frame_out)

Expand Down
8 changes: 8 additions & 0 deletions amqp_mock/amqp_server/_amqp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def port(self) -> Optional[int]:
def port(self, value: int) -> None:
self._port = value

async def _on_bind(self, queue: str, exchange: str, routing_key: str) -> None:
await self._storage.bind_queue_to_exchange(queue, exchange, routing_key)

async def _on_declare_queue(self, queue: str):
await self._storage.declare_queue(queue)

async def _on_publish(self, message: Message) -> None:
try:
message.value = json.loads(message.value.decode())
Expand All @@ -72,6 +78,8 @@ async def _on_close(self, connection: AmqpConnection) -> None:
def __call__(self, reader: StreamReader, writer: StreamWriter) -> AmqpConnection:
connection = AmqpConnection(reader, writer, self._on_consume, self._server_properties)
connection.on_publish(self._on_publish) \
.on_bind(self._on_bind) \
.on_declare_queue(self._on_declare_queue) \
.on_ack(self._on_ack) \
.on_nack(self._on_nack) \
.on_close(self._on_close)
Expand Down
5 changes: 3 additions & 2 deletions tests/_test_utils/amqp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ async def queue_bind(self, queue_name: str, exchange_name: str) -> None:
res = await self._channel.queue_bind(queue_name, exchange_name, routing_key="")
assert isinstance(res, spec.Queue.BindOk)

async def publish(self, message: bytes, exchange_name: str) -> None:
res = await self._channel.basic_publish(message, exchange=exchange_name, routing_key="")
async def publish(self, message: bytes, exchange_name: str, routing_key: str = "") -> None:
res = await self._channel.basic_publish(message, exchange=exchange_name,
routing_key=routing_key)
assert isinstance(res, spec.Basic.Ack)

async def _on_message(self, message: DeliveredMessage) -> None:
Expand Down
38 changes: 38 additions & 0 deletions tests/test_publish_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ async def test_publish_message_specific_queue(*, mock_server, mock_client, amqp_
assert messages[0].body == to_binary(message1)


@pytest.mark.asyncio
async def test_publish_to_default_exchange(*, mock_server, mock_client, amqp_client):
with given:
exchange = ""
queue = "test_queue"
message = {"value": "text"}

with when:
await amqp_client.declare_queue(queue)
await amqp_client.publish(to_binary(message), exchange, routing_key=queue)
await amqp_client.consume(queue)

with then:
messages = await amqp_client.wait_for(message_count=1)
assert len(messages) == 1
assert messages[0].body == to_binary(message)
messages = await mock_client.get_exchange_messages("")
assert len(messages) == 1
assert messages[0].value == message


@pytest.mark.asyncio
async def test_publish_to_exchange_with_bound_queue(*, mock_server, amqp_client):
with given:
exchange = "test_exchange"
queue = "test_queue"
message = b"text"

with when:
await amqp_client.queue_bind(queue, exchange)
await amqp_client.publish(message, exchange)
await amqp_client.consume(queue)

with then:
messages = await amqp_client.wait_for(message_count=1)
assert len(messages) == 1


@pytest.mark.asyncio
async def test_publish_no_messages(*, mock_server, amqp_client):
with given:
Expand Down