Skip to content

Commit

Permalink
DataBus (#2285)
Browse files Browse the repository at this point in the history
* add message bus

address PR comments
formatting

* Check invalid input directory in nvflare config (#2295)

* check invalid input directory

* check invalid input directory

add doc string

add doc string

rename receive_messages() to receive_message()

change doc str to google doc string style

rebase and formats

* remove space

* restore space

* Address PR comments
1) remove data store scope/topic
2) add pub_sub interface and let databus implements the inferface
3) remove function_utils.py and unit tests for another PR

* Address PR comments
1) remove data store scope/topic
2) add pub_sub interface and let databus implements the inferface
3) remove function_utils.py and unit tests for another PR

* reduce lock scope

* make sure the publish in parallel instead of sequential

* rename send_data/receive_data() to put_data()/get_data()

---------

Co-authored-by: Sean Yang <seany314@gmail.com>
  • Loading branch information
chesterxgchen and SYangster authored Feb 5, 2024
1 parent dbdbdeb commit 8cb03ed
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 1 deletion.
13 changes: 13 additions & 0 deletions nvflare/fuel/data_event/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
106 changes: 106 additions & 0 deletions nvflare/fuel/data_event/data_bus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List

from nvflare.fuel.data_event.pub_sub import EventPubSub


class DataBus(EventPubSub):
"""
Singleton class for a simple data bus implementation.
This class allows components to subscribe to topics, publish messages to topics,
and store/retrieve messages associated with specific keys and topics.
"""

_instance = None
_lock = threading.Lock()

def __new__(cls) -> "DataBus":
"""
Create a new instance of the DataBus class.
This method ensures that only one instance of the class is created (singleton pattern).
The databus
"""
with cls._lock:
if not cls._instance:
cls._instance = super(DataBus, cls).__new__(cls)
cls._instance.subscribers = {}
cls._instance.data_store = {}
return cls._instance

def subscribe(self, topics: List[str], callback: Callable[[str, Any, "DataBus"], None]) -> None:
"""
Subscribe a callback function to one or more topics.
Args:
topics (List[str]): A list of topics to subscribe to.
callback (Callable): The callback function to be called when messages are published to the subscribed topics.
"""

if not topics:
raise ValueError("topics must non-empty")

for topic in topics:
if topic.isspace():
raise ValueError(f"topics {topics}contains white space topic")

with self._lock:
if topic not in self.subscribers:
self.subscribers[topic] = []
self.subscribers[topic].append(callback)

def publish(self, topics: List[str], datum: Any) -> None:
"""
Publish a data to one or more topics, notifying all subscribed callbacks.
Args:
topics (List[str]): A list of topics to publish the data to.
datum (Any): The data to be published to the specified topics.
"""
if topics:
for topic in topics:
if topic in self.subscribers:
with self._lock:
executor = ThreadPoolExecutor(max_workers=len(self.subscribers[topic]))
for callback in self.subscribers[topic]:
executor.submit(callback, topic, datum, self)
executor.shutdown()

def put_data(self, key: Any, datum: Any) -> None:
"""
Store a data associated with a key and topic.
Args:
key (Any): The key to associate with the stored message.
datum (Any): The message to be stored.
"""
with self._lock:
self.data_store[key] = datum

def get_data(self, key: Any) -> Any:
"""
Retrieve a stored data associated with a key and topic.
Args:
key (Any): The key associated with the stored message.
Returns:
Any: The stored datum if found, or None if not found.
"""
return self.data_store.get(key)
45 changes: 45 additions & 0 deletions nvflare/fuel/data_event/event_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

from nvflare.fuel.data_event.data_bus import DataBus


class EventManager:
"""
Class for managing events by interacting with a DataBus.
Args:
data_bus (DataBus): An instance of the DataBus class used for event communication.
"""

def __init__(self, data_bus: "DataBus"):
"""
Initialize the EventManager with a DataBus instance.
Args:
data_bus (DataBus): An instance of the DataBus class used for event communication.
"""
self.data_bus = data_bus

def fire_event(self, event_name: str, event_data: Optional[Any] = None) -> None:
"""
Fire an event by publishing it to the DataBus.
Args:
event_name (str): The name of the event to be fired.
event_data (Any, optional): Additional data associated with the event (default is None).
"""
self.data_bus.publish([event_name], event_data)
34 changes: 34 additions & 0 deletions nvflare/fuel/data_event/pub_sub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, List


class EventPubSub:
def subscribe(self, topics: List[str], callback: Callable[[str, Any, "DataBus"], None]) -> None:
"""
Subscribe a callback function to one or more topics.
Args:
topics (List[str]): A list of topics to subscribe to.
callback (Callable): The callback function to be called when messages are published to the subscribed topics.
"""

def publish(self, topics: List[str], datum: Any) -> None:
"""
Publish a message to one or more topics, notifying all subscribed callbacks.
Args:
topics (List[str]): A list of topics to publish the message to.
datum (Any): The message to be published to the specified topics.
"""
2 changes: 1 addition & 1 deletion runtest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function check_license() {
folders_to_check_license="nvflare examples tests integration research"
echo "checking license header in folder: $folders_to_check_license"
(grep -r --include "*.py" --exclude-dir "*protos*" -L \
"\(# Copyright (c) \(2021\|2022\|2023\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \
"\(# Copyright (c) \(2021\|2022\|2023\|2024\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \
${folders_to_check_license} || true) > no_license.lst
if [ -s no_license.lst ]; then
# The file is not-empty.
Expand Down
13 changes: 13 additions & 0 deletions tests/unit_test/fuel/data_event/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
86 changes: 86 additions & 0 deletions tests/unit_test/fuel/data_event/data_bus_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from nvflare.fuel.data_event.data_bus import DataBus
from nvflare.fuel.data_event.event_manager import EventManager


class TestMessageBus(unittest.TestCase):
def setUp(self):
self.data_bus = DataBus()
self.event_manager = EventManager(self.data_bus)

def test_subscribe_and_publish(self):
result = {"count": 0}

def callback_function(topic, datum, data_bus):
result["count"] += 1

self.data_bus.subscribe(["test_topic"], callback_function)
self.data_bus.publish(["test_topic"], "Test Message 1")
self.data_bus.publish(["test_topic"], "Test Message 2")

self.assertEqual(result["count"], 2)

def test_singleton_message_bus(self):
data_bus1 = DataBus()
data_bus1.put_data("user_1", "Hello from User 1!")
user_1_message = data_bus1.get_data("user_1")
self.assertEqual(user_1_message, "Hello from User 1!")

message_bus2 = DataBus()
user_1_message = message_bus2.get_data("user_1")
self.assertEqual(user_1_message, "Hello from User 1!")

def test_send_message_and_receive_messages(self):
self.data_bus.put_data("user_1", "Hello from User 1!")
self.data_bus.put_data("user_2", "Greetings from User 2!")

user_1_message = self.data_bus.get_data("user_1")
user_2_message = self.data_bus.get_data("user_2")

self.assertEqual(user_1_message, "Hello from User 1!")
self.assertEqual(user_2_message, "Greetings from User 2!")

self.data_bus.put_data("user_1", "2nd greetings from User 1!")
user_1_message = self.data_bus.get_data("user_1")
self.assertEqual(user_1_message, "2nd greetings from User 1!")

def test_send_message_and_receive_messages_abnormal(self):
user_3_message = self.data_bus.get_data("user_3")
self.assertEqual(user_3_message, None)

def test_fire_event(self):

result = {
"test_event": {"event_received": False},
"dev_event": {"event_received": False},
"prod_event": {"event_received": False},
}

def event_handler(topic, data, data_bus):
result[topic]["event_received"] = True
if data_bus.get_data("hi") == "hello":
self.data_bus.put_data("hi", "hello-world")

self.data_bus.put_data("hi", "hello")

self.data_bus.subscribe(["test_event", "dev_event", "prod_event"], event_handler)
self.event_manager.fire_event("test_event", {"key": "value"})
self.event_manager.fire_event("dev_event", {"key": "value"})

self.assertTrue(result["test_event"]["event_received"])
self.assertTrue(result["dev_event"]["event_received"])
self.assertFalse(result["prod_event"]["event_received"])

0 comments on commit 8cb03ed

Please sign in to comment.