Skip to content

Commit

Permalink
Adding link to make bidrectional comms easier to configure. (#10)
Browse files Browse the repository at this point in the history
* Adding link to make bidrectional comms easier to configure.

* Added more tests.
  • Loading branch information
edavalosanaya authored Nov 7, 2023
1 parent 6862292 commit 07b9f32
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 4 deletions.
10 changes: 10 additions & 0 deletions aiodistbus/eventbus/eventbus.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ async def _wrapper(event):
# Store the entrypoint
self._dentrypoints[f"{ip}:{port}"] = e

async def link(
self,
ip: str,
port: int,
to_event_types: Optional[List[str]] = None,
from_event_types: Optional[List[str]] = None,
):
await self.listen(ip, port, event_types=from_event_types)
await self.forward(ip, port, event_types=to_event_types)

async def close(self):
"""Close the eventbus"""
# Emit first to allow for cleanup
Expand Down
5 changes: 1 addition & 4 deletions aiodistbus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,7 @@ async def reconstruct(event_str: str, dtype: Optional[Type] = None) -> Event:
event = reconstruct_event_data(event, dtype)
elif event.dtype and event.dtype != "builtins.NoneType":
l_dtype = locate(event.dtype)
if l_dtype is type:
event = reconstruct_event_data(event, l_dtype)
else:
logger.error(f"Could not find type {event.dtype}")
event = reconstruct_event_data(event, l_dtype) # type: ignore

return event

Expand Down
136 changes: 136 additions & 0 deletions test/test_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import pytest

from aiodistbus import DEventBus, EntryPoint, EventBus

from .conftest import (
ExampleEvent,
func,
Expand All @@ -28,6 +30,7 @@
("test_bool", func_bool, bool, True),
("test_none", func_none, None, None),
("test_dict", func_dict, dict, {"hello": "world"}),
("test.spaces", func, ExampleEvent, ExampleEvent(msg="world")),
],
)
async def test_forward_bus_to_dbus(
Expand Down Expand Up @@ -166,3 +169,136 @@ async def test_bus_listen_to_dbus_wildcard(bus, dbus, entrypoints, dentrypoints)

# Assert
assert event.id in e1._received


@pytest.mark.parametrize(
"event_type, func, dtype, dtype_instance",
[
("test", func, ExampleEvent, ExampleEvent("Hello")),
("test_str", func_str, str, "Hello"),
("test_bytes", func_bytes, bytes, b"Hello"),
("test_list", func_list, List, ["Hello"]),
("test_int", func_int, int, 1),
("test_float", func_float, float, 1.0),
("test_bool", func_bool, bool, True),
("test_none", func_none, None, None),
("test_dict", func_dict, dict, {"hello": "world"}),
],
)
async def test_local_buses_comms_server_to_client(
event_type, func, dtype, dtype_instance
):

# Local buses
sb = EventBus()
sdbus = DEventBus()
cb = EventBus()

# Create entrypoint
ce = EntryPoint()
await ce.connect(cb)
await ce.on(event_type, func, dtype)
se = EntryPoint()
await se.connect(sb)

# Link
await sb.forward(sdbus.ip, sdbus.port)
await cb.listen(sdbus.ip, sdbus.port)

# Send message
event = await se.emit(event_type, dtype_instance)

# Flush
await sdbus.flush()

# Assert
assert event and event.id in ce._received


@pytest.mark.parametrize(
"event_type, func, dtype, dtype_instance",
[
("test", func, ExampleEvent, ExampleEvent("Hello")),
("test_str", func_str, str, "Hello"),
("test_bytes", func_bytes, bytes, b"Hello"),
("test_list", func_list, List, ["Hello"]),
("test_int", func_int, int, 1),
("test_float", func_float, float, 1.0),
("test_bool", func_bool, bool, True),
("test_none", func_none, None, None),
("test_dict", func_dict, dict, {"hello": "world"}),
],
)
async def test_local_buses_comms_client_to_server(
event_type, func, dtype, dtype_instance
):

# Local buses
sb = EventBus()
sdbus = DEventBus()
cb = EventBus()

# Create entrypoint
ce = EntryPoint()
await ce.connect(cb)
se = EntryPoint()
await se.connect(sb)
await se.on(event_type, func, dtype)

# Link
await cb.forward(sdbus.ip, sdbus.port)
await sdbus.forward(sb)

# Send message
event = await ce.emit(event_type, dtype_instance)

# Flush
await sdbus.flush()

# Assert
assert event and event.id in se._received


@pytest.mark.parametrize(
"event_type, func, dtype, dtype_instance",
[
("test", func, ExampleEvent, ExampleEvent("Hello")),
("test_str", func_str, str, "Hello"),
("test_bytes", func_bytes, bytes, b"Hello"),
("test_list", func_list, List, ["Hello"]),
("test_int", func_int, int, 1),
("test_float", func_float, float, 1.0),
("test_bool", func_bool, bool, True),
("test_none", func_none, None, None),
("test_dict", func_dict, dict, {"hello": "world"}),
],
)
async def test_local_buses_comms_bidirectional(event_type, func, dtype, dtype_instance):

# Local buses
sb = EventBus()
sdbus = DEventBus()
cb = EventBus()

# Create entrypoint
ce = EntryPoint()
await ce.connect(cb)
await ce.on(f"client.{event_type}", func, dtype)
se = EntryPoint()
await se.connect(sb)
await se.on(f"server.{event_type}", func, dtype)

# Link
await cb.link(sdbus.ip, sdbus.port, ["server.*"], ["client.*"])
await sb.link(sdbus.ip, sdbus.port, ["client.*"], ["server.*"])

# Send message
cevent = await ce.emit(f"server.{event_type}", dtype_instance)
sevent = await se.emit(f"client.{event_type}", dtype_instance)

# Flush
await sdbus.flush()

# Assert
assert cevent and cevent.id in se._received
assert sevent and sevent.id in ce._received

0 comments on commit 07b9f32

Please sign in to comment.