diff --git a/aiodistbus/eventbus/eventbus.py b/aiodistbus/eventbus/eventbus.py index 56f9488..d535bb1 100644 --- a/aiodistbus/eventbus/eventbus.py +++ b/aiodistbus/eventbus/eventbus.py @@ -164,6 +164,16 @@ async def _wrapper(event): # Store the entrypoint self._dentrypoints[f"{ip}:{port}"] = e + async def link( + self, + ip: str, + port: int, + to_event_types: Optional[List[str]] = None, + from_event_types: Optional[List[str]] = None, + ): + await self.listen(ip, port, event_types=from_event_types) + await self.forward(ip, port, event_types=to_event_types) + async def close(self): """Close the eventbus""" # Emit first to allow for cleanup diff --git a/aiodistbus/utils.py b/aiodistbus/utils.py index 9c2aab2..ebbb4a3 100644 --- a/aiodistbus/utils.py +++ b/aiodistbus/utils.py @@ -134,10 +134,7 @@ async def reconstruct(event_str: str, dtype: Optional[Type] = None) -> Event: event = reconstruct_event_data(event, dtype) elif event.dtype and event.dtype != "builtins.NoneType": l_dtype = locate(event.dtype) - if l_dtype is type: - event = reconstruct_event_data(event, l_dtype) - else: - logger.error(f"Could not find type {event.dtype}") + event = reconstruct_event_data(event, l_dtype) # type: ignore return event diff --git a/test/test_bridge.py b/test/test_bridge.py index 6a63bd7..71ca939 100644 --- a/test/test_bridge.py +++ b/test/test_bridge.py @@ -2,6 +2,8 @@ import pytest +from aiodistbus import DEventBus, EntryPoint, EventBus + from .conftest import ( ExampleEvent, func, @@ -28,6 +30,7 @@ ("test_bool", func_bool, bool, True), ("test_none", func_none, None, None), ("test_dict", func_dict, dict, {"hello": "world"}), + ("test.spaces", func, ExampleEvent, ExampleEvent(msg="world")), ], ) async def test_forward_bus_to_dbus( @@ -166,3 +169,136 @@ async def test_bus_listen_to_dbus_wildcard(bus, dbus, entrypoints, dentrypoints) # Assert assert event.id in e1._received + + +@pytest.mark.parametrize( + "event_type, func, dtype, dtype_instance", + [ + ("test", func, ExampleEvent, ExampleEvent("Hello")), + ("test_str", func_str, str, "Hello"), + ("test_bytes", func_bytes, bytes, b"Hello"), + ("test_list", func_list, List, ["Hello"]), + ("test_int", func_int, int, 1), + ("test_float", func_float, float, 1.0), + ("test_bool", func_bool, bool, True), + ("test_none", func_none, None, None), + ("test_dict", func_dict, dict, {"hello": "world"}), + ], +) +async def test_local_buses_comms_server_to_client( + event_type, func, dtype, dtype_instance +): + + # Local buses + sb = EventBus() + sdbus = DEventBus() + cb = EventBus() + + # Create entrypoint + ce = EntryPoint() + await ce.connect(cb) + await ce.on(event_type, func, dtype) + se = EntryPoint() + await se.connect(sb) + + # Link + await sb.forward(sdbus.ip, sdbus.port) + await cb.listen(sdbus.ip, sdbus.port) + + # Send message + event = await se.emit(event_type, dtype_instance) + + # Flush + await sdbus.flush() + + # Assert + assert event and event.id in ce._received + + +@pytest.mark.parametrize( + "event_type, func, dtype, dtype_instance", + [ + ("test", func, ExampleEvent, ExampleEvent("Hello")), + ("test_str", func_str, str, "Hello"), + ("test_bytes", func_bytes, bytes, b"Hello"), + ("test_list", func_list, List, ["Hello"]), + ("test_int", func_int, int, 1), + ("test_float", func_float, float, 1.0), + ("test_bool", func_bool, bool, True), + ("test_none", func_none, None, None), + ("test_dict", func_dict, dict, {"hello": "world"}), + ], +) +async def test_local_buses_comms_client_to_server( + event_type, func, dtype, dtype_instance +): + + # Local buses + sb = EventBus() + sdbus = DEventBus() + cb = EventBus() + + # Create entrypoint + ce = EntryPoint() + await ce.connect(cb) + se = EntryPoint() + await se.connect(sb) + await se.on(event_type, func, dtype) + + # Link + await cb.forward(sdbus.ip, sdbus.port) + await sdbus.forward(sb) + + # Send message + event = await ce.emit(event_type, dtype_instance) + + # Flush + await sdbus.flush() + + # Assert + assert event and event.id in se._received + + +@pytest.mark.parametrize( + "event_type, func, dtype, dtype_instance", + [ + ("test", func, ExampleEvent, ExampleEvent("Hello")), + ("test_str", func_str, str, "Hello"), + ("test_bytes", func_bytes, bytes, b"Hello"), + ("test_list", func_list, List, ["Hello"]), + ("test_int", func_int, int, 1), + ("test_float", func_float, float, 1.0), + ("test_bool", func_bool, bool, True), + ("test_none", func_none, None, None), + ("test_dict", func_dict, dict, {"hello": "world"}), + ], +) +async def test_local_buses_comms_bidirectional(event_type, func, dtype, dtype_instance): + + # Local buses + sb = EventBus() + sdbus = DEventBus() + cb = EventBus() + + # Create entrypoint + ce = EntryPoint() + await ce.connect(cb) + await ce.on(f"client.{event_type}", func, dtype) + se = EntryPoint() + await se.connect(sb) + await se.on(f"server.{event_type}", func, dtype) + + # Link + await cb.link(sdbus.ip, sdbus.port, ["server.*"], ["client.*"]) + await sb.link(sdbus.ip, sdbus.port, ["client.*"], ["server.*"]) + + # Send message + cevent = await ce.emit(f"server.{event_type}", dtype_instance) + sevent = await se.emit(f"client.{event_type}", dtype_instance) + + # Flush + await sdbus.flush() + + # Assert + assert cevent and cevent.id in se._received + assert sevent and sevent.id in ce._received