Skip to content

Commit

Permalink
add message bus
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Jan 15, 2024
1 parent c806e1e commit 35c4c3a
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 1 deletion.
13 changes: 13 additions & 0 deletions nvflare/fuel/message/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
21 changes: 21 additions & 0 deletions nvflare/fuel/message/event_manger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class EventManager:
def __init__(self, message_bus):
self.message_bus = message_bus

def fire_event(self, event_name, event_data=None):
self.message_bus.publish(event_name, event_data)
49 changes: 49 additions & 0 deletions nvflare/fuel/message/message_bus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import threading


class MessageBus:
_instance = None
_lock = threading.Lock()

def __new__(cls):
with cls._lock:
if not cls._instance:
cls._instance = super(MessageBus, cls).__new__(cls)
cls._instance.subscribers = {}
cls._instance.message_store = {}
return cls._instance

def subscribe(self, topic, callback):
if topic not in self.subscribers:
self.subscribers[topic] = []
self.subscribers[topic].append(callback)

def publish(self, topic, message):
if topic in self.subscribers:
for callback in self.subscribers[topic]:
callback(message)

def send_message(self, key, message, topic: str = "default"):
if topic not in self.message_store:
self.message_store[topic] = {}

self.message_store[topic][key] = message

self.publish(key, message) # Notify subscribers about the new message

def receive_messages(self, key, topic: str = "default"):
return self.message_store.get(topic, {}).get(key)
36 changes: 36 additions & 0 deletions nvflare/fuel/utils/function_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
from typing import Callable


def find_task_fn(task_fn_path) -> Callable:
# Split the text by the last dot
tokens = task_fn_path.rsplit(".", 1)
module_name = tokens[0]
fn_name = tokens[1] if len(tokens) > 1 else ""
module = importlib.import_module(module_name)
fn = getattr(module, fn_name)
return fn


def require_arguments(func):
signature = inspect.signature(func)
parameters = signature.parameters
req = any(p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD for p in parameters.values())
size = len(parameters)
args_with_defaults = [param for param in parameters.values() if param.default != inspect.Parameter.empty]
default_args_size = len(args_with_defaults)
return req, size, default_args_size
2 changes: 1 addition & 1 deletion runtest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function check_license() {
folders_to_check_license="nvflare examples tests integration research"
echo "checking license header in folder: $folders_to_check_license"
(grep -r --include "*.py" --exclude-dir "*protos*" -L \
"\(# Copyright (c) \(2021\|2022\|2023\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \
"\(# Copyright (c) \(2021\|2022\|2023\|2024\), NVIDIA CORPORATION. All rights reserved.\)\|\(This file is released into the public domain.\)" \
${folders_to_check_license} || true) > no_license.lst
if [ -s no_license.lst ]; then
# The file is not-empty.
Expand Down
13 changes: 13 additions & 0 deletions tests/unit_test/fuel/message/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
84 changes: 84 additions & 0 deletions tests/unit_test/fuel/message/message_bus_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from nvflare.fuel.message.event_manger import EventManager
from nvflare.fuel.message.message_bus import MessageBus


class TestMessageBus(unittest.TestCase):
def setUp(self):
self.message_bus = MessageBus()
self.event_manager = EventManager(self.message_bus)

def test_subscribe_and_publish(self):
result = {"count": 0}

def callback_function(message):
result["count"] += 1

self.message_bus.subscribe("test_topic", callback_function)
self.message_bus.publish("test_topic", "Test Message 1")
self.message_bus.publish("test_topic", "Test Message 2")

self.assertEqual(result["count"], 2)

def test_singleton_message_bus(self):
message_bus1 = MessageBus()
message_bus1.send_message("user_1", "Hello from User 1!")
user_1_message = message_bus1.receive_messages("user_1")
self.assertEqual(user_1_message, "Hello from User 1!")

message_bus2 = MessageBus()
user_1_message = message_bus2.receive_messages("user_1")
self.assertEqual(user_1_message, "Hello from User 1!")

def test_send_message_and_receive_messages(self):
self.message_bus.send_message("user_1", "Hello from User 1!")
self.message_bus.send_message("user_2", "Greetings from User 2!")

user_1_message = self.message_bus.receive_messages("user_1")
user_2_message = self.message_bus.receive_messages("user_2")

self.assertEqual(user_1_message, "Hello from User 1!")
self.assertEqual(user_2_message, "Greetings from User 2!")

self.message_bus.send_message("user_1", "2nd greetings from User 1!")
user_1_message = self.message_bus.receive_messages("user_1")
self.assertEqual(user_1_message, "2nd greetings from User 1!")

self.message_bus.send_message("user_1", "3rd greetings from User 1!", topic="channel-3")
user_1_message = self.message_bus.receive_messages("user_1")
self.assertEqual(user_1_message, "2nd greetings from User 1!")

user_1_message = self.message_bus.receive_messages("user_1", topic="channel-3")
self.assertEqual(user_1_message, "3rd greetings from User 1!")

def test_send_message_and_receive_messages_abnormal(self):
user_3_message = self.message_bus.receive_messages("user_3")
self.assertEqual(user_3_message, None)

user_3_message = self.message_bus.receive_messages("user_3", topic="channel")
self.assertEqual(user_3_message, None)

def test_fire_event(self):
result = {"event_received": False}

def event_handler(data):
result["event_received"] = True

self.message_bus.subscribe("test_event", event_handler)
self.event_manager.fire_event("test_event", {"key": "value"})

self.assertTrue(result["event_received"])
39 changes: 39 additions & 0 deletions tests/unit_test/fuel/utils/function_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from unittest.mock import MagicMock, patch

from nvflare.fuel.utils.function_utils import find_task_fn


class TestFindTaskFn(unittest.TestCase):
@patch("importlib.import_module")
def test_find_task_fn_with_module(self, mock_import_module):
# Test find_task_fn when a module is specified in task_fn_path
task_fn_path = "nvflare.utils.cli_utils.get_home_dir"
mock_module = MagicMock()
mock_import_module.return_value = mock_module

result = find_task_fn(task_fn_path)

mock_import_module.assert_called_once_with("nvflare.utils.cli_utils")
self.assertTrue(callable(result))

def test_find_task_fn_without_module(self):
# Test find_task_fn when no module is specified in task_fn_path
task_fn_path = "get_home_dir"
with self.assertRaises(ModuleNotFoundError) as context:
result = find_task_fn(task_fn_path)

0 comments on commit 35c4c3a

Please sign in to comment.