-
Notifications
You must be signed in to change notification settings - Fork 182
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add message bus address PR comments formatting * Check invalid input directory in nvflare config (#2295) * check invalid input directory * check invalid input directory add doc string add doc string rename receive_messages() to receive_message() change doc str to google doc string style rebase and formats * remove space * restore space * Address PR comments 1) remove data store scope/topic 2) add pub_sub interface and let databus implements the inferface 3) remove function_utils.py and unit tests for another PR * Address PR comments 1) remove data store scope/topic 2) add pub_sub interface and let databus implements the inferface 3) remove function_utils.py and unit tests for another PR * reduce lock scope * make sure the publish in parallel instead of sequential * rename send_data/receive_data() to put_data()/get_data() --------- Co-authored-by: Sean Yang <seany314@gmail.com>
- Loading branch information
1 parent
dbdbdeb
commit 8cb03ed
Showing
7 changed files
with
298 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import threading | ||
from concurrent.futures import ThreadPoolExecutor | ||
from typing import Any, Callable, List | ||
|
||
from nvflare.fuel.data_event.pub_sub import EventPubSub | ||
|
||
|
||
class DataBus(EventPubSub): | ||
""" | ||
Singleton class for a simple data bus implementation. | ||
This class allows components to subscribe to topics, publish messages to topics, | ||
and store/retrieve messages associated with specific keys and topics. | ||
""" | ||
|
||
_instance = None | ||
_lock = threading.Lock() | ||
|
||
def __new__(cls) -> "DataBus": | ||
""" | ||
Create a new instance of the DataBus class. | ||
This method ensures that only one instance of the class is created (singleton pattern). | ||
The databus | ||
""" | ||
with cls._lock: | ||
if not cls._instance: | ||
cls._instance = super(DataBus, cls).__new__(cls) | ||
cls._instance.subscribers = {} | ||
cls._instance.data_store = {} | ||
return cls._instance | ||
|
||
def subscribe(self, topics: List[str], callback: Callable[[str, Any, "DataBus"], None]) -> None: | ||
""" | ||
Subscribe a callback function to one or more topics. | ||
Args: | ||
topics (List[str]): A list of topics to subscribe to. | ||
callback (Callable): The callback function to be called when messages are published to the subscribed topics. | ||
""" | ||
|
||
if not topics: | ||
raise ValueError("topics must non-empty") | ||
|
||
for topic in topics: | ||
if topic.isspace(): | ||
raise ValueError(f"topics {topics}contains white space topic") | ||
|
||
with self._lock: | ||
if topic not in self.subscribers: | ||
self.subscribers[topic] = [] | ||
self.subscribers[topic].append(callback) | ||
|
||
def publish(self, topics: List[str], datum: Any) -> None: | ||
""" | ||
Publish a data to one or more topics, notifying all subscribed callbacks. | ||
Args: | ||
topics (List[str]): A list of topics to publish the data to. | ||
datum (Any): The data to be published to the specified topics. | ||
""" | ||
if topics: | ||
for topic in topics: | ||
if topic in self.subscribers: | ||
with self._lock: | ||
executor = ThreadPoolExecutor(max_workers=len(self.subscribers[topic])) | ||
for callback in self.subscribers[topic]: | ||
executor.submit(callback, topic, datum, self) | ||
executor.shutdown() | ||
|
||
def put_data(self, key: Any, datum: Any) -> None: | ||
""" | ||
Store a data associated with a key and topic. | ||
Args: | ||
key (Any): The key to associate with the stored message. | ||
datum (Any): The message to be stored. | ||
""" | ||
with self._lock: | ||
self.data_store[key] = datum | ||
|
||
def get_data(self, key: Any) -> Any: | ||
""" | ||
Retrieve a stored data associated with a key and topic. | ||
Args: | ||
key (Any): The key associated with the stored message. | ||
Returns: | ||
Any: The stored datum if found, or None if not found. | ||
""" | ||
return self.data_store.get(key) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Any, Optional | ||
|
||
from nvflare.fuel.data_event.data_bus import DataBus | ||
|
||
|
||
class EventManager: | ||
""" | ||
Class for managing events by interacting with a DataBus. | ||
Args: | ||
data_bus (DataBus): An instance of the DataBus class used for event communication. | ||
""" | ||
|
||
def __init__(self, data_bus: "DataBus"): | ||
""" | ||
Initialize the EventManager with a DataBus instance. | ||
Args: | ||
data_bus (DataBus): An instance of the DataBus class used for event communication. | ||
""" | ||
self.data_bus = data_bus | ||
|
||
def fire_event(self, event_name: str, event_data: Optional[Any] = None) -> None: | ||
""" | ||
Fire an event by publishing it to the DataBus. | ||
Args: | ||
event_name (str): The name of the event to be fired. | ||
event_data (Any, optional): Additional data associated with the event (default is None). | ||
""" | ||
self.data_bus.publish([event_name], event_data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Callable, List | ||
|
||
|
||
class EventPubSub: | ||
def subscribe(self, topics: List[str], callback: Callable[[str, Any, "DataBus"], None]) -> None: | ||
""" | ||
Subscribe a callback function to one or more topics. | ||
Args: | ||
topics (List[str]): A list of topics to subscribe to. | ||
callback (Callable): The callback function to be called when messages are published to the subscribed topics. | ||
""" | ||
|
||
def publish(self, topics: List[str], datum: Any) -> None: | ||
""" | ||
Publish a message to one or more topics, notifying all subscribed callbacks. | ||
Args: | ||
topics (List[str]): A list of topics to publish the message to. | ||
datum (Any): The message to be published to the specified topics. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import unittest | ||
|
||
from nvflare.fuel.data_event.data_bus import DataBus | ||
from nvflare.fuel.data_event.event_manager import EventManager | ||
|
||
|
||
class TestMessageBus(unittest.TestCase): | ||
def setUp(self): | ||
self.data_bus = DataBus() | ||
self.event_manager = EventManager(self.data_bus) | ||
|
||
def test_subscribe_and_publish(self): | ||
result = {"count": 0} | ||
|
||
def callback_function(topic, datum, data_bus): | ||
result["count"] += 1 | ||
|
||
self.data_bus.subscribe(["test_topic"], callback_function) | ||
self.data_bus.publish(["test_topic"], "Test Message 1") | ||
self.data_bus.publish(["test_topic"], "Test Message 2") | ||
|
||
self.assertEqual(result["count"], 2) | ||
|
||
def test_singleton_message_bus(self): | ||
data_bus1 = DataBus() | ||
data_bus1.put_data("user_1", "Hello from User 1!") | ||
user_1_message = data_bus1.get_data("user_1") | ||
self.assertEqual(user_1_message, "Hello from User 1!") | ||
|
||
message_bus2 = DataBus() | ||
user_1_message = message_bus2.get_data("user_1") | ||
self.assertEqual(user_1_message, "Hello from User 1!") | ||
|
||
def test_send_message_and_receive_messages(self): | ||
self.data_bus.put_data("user_1", "Hello from User 1!") | ||
self.data_bus.put_data("user_2", "Greetings from User 2!") | ||
|
||
user_1_message = self.data_bus.get_data("user_1") | ||
user_2_message = self.data_bus.get_data("user_2") | ||
|
||
self.assertEqual(user_1_message, "Hello from User 1!") | ||
self.assertEqual(user_2_message, "Greetings from User 2!") | ||
|
||
self.data_bus.put_data("user_1", "2nd greetings from User 1!") | ||
user_1_message = self.data_bus.get_data("user_1") | ||
self.assertEqual(user_1_message, "2nd greetings from User 1!") | ||
|
||
def test_send_message_and_receive_messages_abnormal(self): | ||
user_3_message = self.data_bus.get_data("user_3") | ||
self.assertEqual(user_3_message, None) | ||
|
||
def test_fire_event(self): | ||
|
||
result = { | ||
"test_event": {"event_received": False}, | ||
"dev_event": {"event_received": False}, | ||
"prod_event": {"event_received": False}, | ||
} | ||
|
||
def event_handler(topic, data, data_bus): | ||
result[topic]["event_received"] = True | ||
if data_bus.get_data("hi") == "hello": | ||
self.data_bus.put_data("hi", "hello-world") | ||
|
||
self.data_bus.put_data("hi", "hello") | ||
|
||
self.data_bus.subscribe(["test_event", "dev_event", "prod_event"], event_handler) | ||
self.event_manager.fire_event("test_event", {"key": "value"}) | ||
self.event_manager.fire_event("dev_event", {"key": "value"}) | ||
|
||
self.assertTrue(result["test_event"]["event_received"]) | ||
self.assertTrue(result["dev_event"]["event_received"]) | ||
self.assertFalse(result["prod_event"]["event_received"]) |