Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fanout exchange #33

Merged
merged 8 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- name: Checkout
Expand Down
45 changes: 31 additions & 14 deletions amqp_mock/_storage.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,58 @@
from asyncio import Queue
from collections import OrderedDict, defaultdict
from typing import AsyncGenerator, DefaultDict, Dict, List
from collections import defaultdict
from typing import AsyncGenerator, DefaultDict, Dict, List, Tuple

from ._message import Message, MessageStatus, QueuedMessage


class Storage:
def __init__(self) -> None:
self._exchanges: Dict[str, List[Message]] = {}
self._exchange_types: Dict[str, str] = {}
self._queues: Dict[str, Queue[Message]] = {}
self._history: Dict[str, QueuedMessage] = OrderedDict()
self._history: List[Tuple[str, QueuedMessage]] = []
self._binds: DefaultDict[str, Dict[str, str]] = defaultdict(dict)

async def clear(self) -> None:
self._exchanges = {}
self._exchange_types = {}
self._queues = {}
self._history = OrderedDict()
self._history = []
self._binds = defaultdict(dict)

async def add_message_to_exchange(self, exchange: str, message: Message) -> None:
if exchange not in self._exchanges:
self._exchanges[exchange] = []
await self.declare_exchange(exchange)
self._exchanges[exchange].insert(0, message)

exchange_type = self._exchange_types[exchange]
binds = self._binds.get(exchange)
routing_key = message.routing_key

if binds and routing_key in binds:
await self.add_message_to_queue(binds[routing_key], message)
if binds:
if exchange_type == "direct":
routing_key = message.routing_key

if routing_key in binds:
await self.add_message_to_queue(binds[routing_key], message)
elif exchange_type == "fanout":
for queue in binds.values():
await self.add_message_to_queue(queue, message)
else:
raise RuntimeError(f"{exchange_type} exchanges not supported")

async def bind_queue_to_exchange(self, queue: str, exchange: str,
routing_key: str = "") -> None:
await self.declare_queue(queue)
self._binds[exchange][routing_key] = queue

async def declare_exchange(self, exchange: str, exchange_type: str = "direct") -> None:
if exchange not in self._exchanges:
self._exchanges[exchange] = []
self._exchange_types[exchange] = exchange_type

async def declare_queue(self, queue: str) -> None:
if queue not in self._queues:
self._queues[queue] = Queue()

await self.bind_queue_to_exchange(queue, "", routing_key=queue)
await self.bind_queue_to_exchange(queue, exchange="", routing_key=queue)

async def get_messages_from_exchange(self, exchange: str) -> List[Message]:
if exchange not in self._exchanges:
Expand All @@ -51,13 +66,15 @@ async def delete_messages_from_exchange(self, exchange: str) -> None:
async def add_message_to_queue(self, queue: str, message: Message) -> None:
await self.declare_queue(queue)
await self._queues[queue].put(message)
self._history[message.id] = QueuedMessage(message, queue)
self._history.append((message.id, QueuedMessage(message, queue)))

async def get_history(self) -> List[QueuedMessage]:
return [message for message in self._history.values()][::-1]
return [message[1] for message in self._history[::-1]]

async def change_message_status(self, message_id: str, status: MessageStatus) -> None:
self._history[message_id].set_status(status)
for msg_id, message in self._history:
if msg_id == message_id:
message.set_status(status)

async def get_next_message(self, queue: str) -> AsyncGenerator[Message, None]:
if queue not in self._queues:
Expand Down
8 changes: 8 additions & 0 deletions amqp_mock/amqp_server/_amqp_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, reader: StreamReader, writer: StreamWriter,
self._delivery_tag = 0
self._on_consume = on_consume
self._on_bind: Optional[Callable[[str, str, str], Awaitable[None]]] = None
self._on_declare_exchange: Optional[Callable[[str, str], Awaitable[None]]] = None
self._on_declare_queue: Optional[Callable[[str], Awaitable[None]]] = None
self._on_publish: Optional[Callable[[Message], Awaitable[None]]] = None
self._on_ack: Optional[Callable[[str], Awaitable[None]]] = None
Expand All @@ -45,6 +46,11 @@ def on_bind(self, callback: Callable[[str, str, str], Awaitable[None]]) -> 'Amqp
self._on_bind = callback
return self

def on_declare_exchange(self, callback: Callable[[str, str],
Awaitable[None]]) -> 'AmqpConnection':
self._on_declare_exchange = callback
return self

def on_declare_queue(self, callback: Callable[[str], Awaitable[None]]) -> 'AmqpConnection':
self._on_declare_queue = callback
return self
Expand Down Expand Up @@ -229,6 +235,8 @@ async def _send_queue_declare_ok(self, channel_id: int,

async def _send_exchange_declare_ok(self, channel_id: int,
frame_in: commands.Exchange.Declare) -> None:
if self._on_declare_exchange:
await self._on_declare_exchange(frame_in.exchange, frame_in.exchange_type)
frame_out = commands.Exchange.DeclareOk()
return await self._send_frame(channel_id, frame_out)

Expand Down
4 changes: 4 additions & 0 deletions amqp_mock/amqp_server/_amqp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def port(self, value: int) -> None:
async def _on_bind(self, queue: str, exchange: str, routing_key: str) -> None:
await self._storage.bind_queue_to_exchange(queue, exchange, routing_key)

async def _on_declare_exchange(self, exchange: str, exchange_type: str) -> None:
await self._storage.declare_exchange(exchange, exchange_type)

async def _on_declare_queue(self, queue: str) -> None:
await self._storage.declare_queue(queue)

Expand Down Expand Up @@ -79,6 +82,7 @@ def __call__(self, reader: StreamReader, writer: StreamWriter) -> AmqpConnection
connection = AmqpConnection(reader, writer, self._on_consume, self._server_properties)
connection.on_publish(self._on_publish) \
.on_bind(self._on_bind) \
.on_declare_exchange(self._on_declare_exchange) \
.on_declare_queue(self._on_declare_queue) \
.on_ack(self._on_ack) \
.on_nack(self._on_nack) \
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ pytest-asyncio==0.21.1
pytest-cov==4.1.0
pytest==7.4.0
rtry==1.5.0
d42==1.7.0
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def find_dev_required():
long_description_content_type="text/markdown",
author="Nikita Tsvetkov",
author_email="tsv1@fastmail.com",
python_requires=">=3.7.0",
python_requires=">=3.8.0",
url="https://github.com/tsv1/amqp-mock",
license="Apache-2.0",
packages=find_packages(exclude=("tests",)),
Expand All @@ -28,7 +28,6 @@ def find_dev_required():
tests_require=find_dev_required(),
classifiers=[
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
Expand Down
5 changes: 3 additions & 2 deletions tests/_test_utils/amqp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ async def declare_queue(self, queue_name: str) -> None:
res = await self._channel.queue_declare(queue_name)
assert isinstance(res, commands.Queue.DeclareOk)

async def queue_bind(self, queue_name: str, exchange_name: str) -> None:
res = await self._channel.queue_bind(queue_name, exchange_name, routing_key="")
async def queue_bind(self, queue_name: str, exchange_name: str,
routing_key: str = "") -> None:
res = await self._channel.queue_bind(queue_name, exchange_name, routing_key=routing_key)
assert isinstance(res, commands.Queue.BindOk)

async def publish(self, message: bytes, exchange_name: str, routing_key: str = "") -> None:
Expand Down
11 changes: 10 additions & 1 deletion tests/_test_utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import json
from typing import Any
from typing import Any, Dict, List, Union
from uuid import uuid4

from amqp_mock import Message, QueuedMessage


def to_binary(message: Any) -> bytes:
return json.dumps(message).encode()


def random_uuid() -> str:
return str(uuid4())


_MessageType = Union[QueuedMessage, Message]


def to_dict(smth: List[_MessageType]) -> List[Dict[str, Any]]:
return [x.to_dict() for x in smth]
22 changes: 22 additions & 0 deletions tests/_test_utils/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from d42 import schema

from amqp_mock import MessageStatus

MessageSchema = schema.dict({
"id": schema.str.len(1, ...),
"value": schema.any,
"exchange": schema.str,
"routing_key": schema.str,
"properties": schema.dict,
})

QueuedMessageSchema = schema.dict({
"message": MessageSchema,
"queue": schema.str,
"status": schema.any(
schema.str(MessageStatus.INIT),
schema.str(MessageStatus.CONSUMING),
schema.str(MessageStatus.ACKED),
schema.str(MessageStatus.NACKED),
),
})
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from district42.types import Schema
from valera import Formatter, validate


def pytest_assertrepr_compare(op, left, right):
if isinstance(right, Schema):
result = validate(right, left)
formatter = Formatter()
errors = ["- " + e.format(formatter) for e in result.get_errors()]
return ["valera.ValidationException"] + errors
23 changes: 23 additions & 0 deletions tests/test_publish_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,29 @@ async def test_publish_message_specific_queue(*, mock_server, mock_client, amqp_
assert messages[0].body == to_binary(message1)


@pytest.mark.asyncio
async def test_publish_to_fanout_exchange(*, mock_server, amqp_client):
with given:
exchange = "test_exchange"
queue1, queue2 = "test_queue1", "test_queue2"
message = {"value": "text"}

with when:
await amqp_client.declare_exchange(exchange, "fanout")

for queue in [queue1, queue2]:
await amqp_client.queue_bind(queue, exchange, routing_key=queue)
await amqp_client.consume(queue)

await amqp_client.publish(to_binary(message), exchange)

with then:
messages = await amqp_client.wait_for(message_count=2)
assert len(messages) == 2
for message_ in messages:
assert message_.body == to_binary(message)


@pytest.mark.asyncio
async def test_publish_no_messages(*, mock_server, amqp_client):
with given:
Expand Down
Loading