Skip to content

Commit

Permalink
Check invalid input directory in nvflare config (NVIDIA#2295)
Browse files Browse the repository at this point in the history
* check invalid input directory

* check invalid input directory

add doc string

add doc string

rename receive_messages() to receive_message()

change doc str to google doc string style
  • Loading branch information
chesterxgchen committed Jan 24, 2024
1 parent d11a42e commit 005d483
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 21 deletions.
3 changes: 3 additions & 0 deletions nvflare/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def def_config_parser(sub_cmd):
def handle_config_cmd(args):
config_file_path, nvflare_config = get_hidden_config()

if not args.job_templates_dir or not os.path.isdir(args.job_templates_dir):
raise ValueError(f"job_templates_dir='{args.job_templates_dir}', it is not a directory")

nvflare_config = create_startup_kit_config(nvflare_config, args.startup_kit_dir)
nvflare_config = create_poc_workspace_config(nvflare_config, args.poc_workspace_dir)
nvflare_config = create_job_template_config(nvflare_config, args.job_templates_dir)
Expand Down
57 changes: 50 additions & 7 deletions nvflare/fuel/message/data_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,85 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Callable, Any
import threading
from typing import Callable, List


class DataBus:
"""
Singleton class for a simple data bus implementation.
This class allows components to subscribe to topics, publish messages to topics,
and store/retrieve messages associated with specific keys and topics.
"""

_instance = None
_lock = threading.Lock()

def __new__(cls):
def __new__(cls) -> 'DataBus':
"""
Create a new instance of the DataBus class.
This method ensures that only one instance of the class is created (singleton pattern).
"""
with cls._lock:
if not cls._instance:
cls._instance = super(DataBus, cls).__new__(cls)
cls._instance.subscribers = {}
cls._instance.message_store = {}
return cls._instance

def subscribe(self, topics: List[str], callback: Callable):
def subscribe(self, topics: List[str], callback: Callable) -> None:
"""
Subscribe a callback function to one or more topics.
Args:
topics (List[str]): A list of topics to subscribe to.
callback (Callable): The callback function to be called when messages are published to the subscribed topics.
"""
if topics:
for topic in topics:
if topic not in self.subscribers:
self.subscribers[topic] = []
self.subscribers[topic].append(callback)

def publish(self, topics: List[str], message: any):
def publish(self, topics: List[str], message: Any) -> None:
"""
Publish a message to one or more topics, notifying all subscribed callbacks.
Args:
topics (List[str]): A list of topics to publish the message to.
message (Any): The message to be published to the specified topics.
"""
if topics:
for topic in topics:
if topic in self.subscribers:
for callback in self.subscribers[topic]:
callback(message, topic)

def send_message(self, key, message, topic: str = "default"):
def send_message(self, key: Any, message: Any, topic: str = "default") -> None:
"""
Store a message associated with a key and topic.
Args:
key (Any): The key to associate with the stored message.
message (Any): The message to be stored.
topic (str): The topic under which the message is stored (default is "default").
"""
if topic not in self.message_store:
self.message_store[topic] = {}

self.message_store[topic][key] = message

def receive_messages(self, key, topic: str = "default"):
def receive_message(self, key: Any, topic: str = "default") -> Any:
"""
Retrieve a stored message associated with a key and topic.
Args:
key (Any): The key associated with the stored message.
topic (str): The topic under which the message is stored (default is "default").
Returns:
Any: The stored message if found, or None if not found.
"""
return self.message_store.get(topic, {}).get(key)
27 changes: 25 additions & 2 deletions nvflare/fuel/message/event_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Any

from nvflare.fuel.message.data_bus import DataBus


class EventManager:
def __init__(self, data_bus: DataBus):
"""
Class for managing events by interacting with a DataBus.
Args:
data_bus (DataBus): An instance of the DataBus class used for event communication.
"""

def __init__(self, data_bus: 'DataBus'):
"""
Initialize the EventManager with a DataBus instance.
Args:
data_bus (DataBus): An instance of the DataBus class used for event communication.
"""
self.data_bus = data_bus

def fire_event(self, event_name, event_data=None):
def fire_event(self, event_name: str, event_data: Optional[Any] = None) -> None:
"""
Fire an event by publishing it to the DataBus.
Args:
event_name (str): The name of the event to be fired.
event_data (Any, optional): Additional data associated with the event (default is None).
"""
self.data_bus.publish([event_name], event_data)
27 changes: 24 additions & 3 deletions nvflare/fuel/utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,19 @@
# limitations under the License.
import importlib
import inspect
from typing import Callable
from typing import Callable, Tuple


def find_task_fn(task_fn_path) -> Callable:
def find_task_fn(task_fn_path: str) -> Callable:
"""
Find and return a callable task function based on its module path.
Args:
task_fn_path (str): The path to the task function in the format "module_path.function_name".
Returns:
Callable: The callable task function.
"""
# Split the text by the last dot
tokens = task_fn_path.rsplit(".", 1)
module_name = tokens[0]
Expand All @@ -26,7 +35,19 @@ def find_task_fn(task_fn_path) -> Callable:
return fn


def require_arguments(func):
def require_arguments(func: Callable) -> Tuple[bool, int, int]:
"""
Check if a function requires arguments and provide information about its signature.
Args:
func (Callable): The function to be checked.
Returns:
Tuple[bool, int, int]: A tuple containing three elements:
1. A boolean indicating whether the function requires any arguments.
2. The total number of parameters in the function's signature.
3. The number of parameters with default values (i.e., optional parameters).
"""
signature = inspect.signature(func)
parameters = signature.parameters
req = any(p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD for p in parameters.values())
Expand Down
18 changes: 9 additions & 9 deletions tests/unit_test/fuel/message/data_bus_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,39 +37,39 @@ def callback_function(message, topic="default"):
def test_singleton_message_bus(self):
message_bus1 = DataBus()
message_bus1.send_message("user_1", "Hello from User 1!")
user_1_message = message_bus1.receive_messages("user_1")
user_1_message = message_bus1.receive_message("user_1")
self.assertEqual(user_1_message, "Hello from User 1!")

message_bus2 = DataBus()
user_1_message = message_bus2.receive_messages("user_1")
user_1_message = message_bus2.receive_message("user_1")
self.assertEqual(user_1_message, "Hello from User 1!")

def test_send_message_and_receive_messages(self):
self.message_bus.send_message("user_1", "Hello from User 1!")
self.message_bus.send_message("user_2", "Greetings from User 2!")

user_1_message = self.message_bus.receive_messages("user_1")
user_2_message = self.message_bus.receive_messages("user_2")
user_1_message = self.message_bus.receive_message("user_1")
user_2_message = self.message_bus.receive_message("user_2")

self.assertEqual(user_1_message, "Hello from User 1!")
self.assertEqual(user_2_message, "Greetings from User 2!")

self.message_bus.send_message("user_1", "2nd greetings from User 1!")
user_1_message = self.message_bus.receive_messages("user_1")
user_1_message = self.message_bus.receive_message("user_1")
self.assertEqual(user_1_message, "2nd greetings from User 1!")

self.message_bus.send_message("user_1", "3rd greetings from User 1!", topic="channel-3")
user_1_message = self.message_bus.receive_messages("user_1")
user_1_message = self.message_bus.receive_message("user_1")
self.assertEqual(user_1_message, "2nd greetings from User 1!")

user_1_message = self.message_bus.receive_messages("user_1", topic="channel-3")
user_1_message = self.message_bus.receive_message("user_1", topic="channel-3")
self.assertEqual(user_1_message, "3rd greetings from User 1!")

def test_send_message_and_receive_messages_abnormal(self):
user_3_message = self.message_bus.receive_messages("user_3")
user_3_message = self.message_bus.receive_message("user_3")
self.assertEqual(user_3_message, None)

user_3_message = self.message_bus.receive_messages("user_3", topic="channel")
user_3_message = self.message_bus.receive_message("user_3", topic="channel")
self.assertEqual(user_3_message, None)

def test_fire_event(self):
Expand Down

0 comments on commit 005d483

Please sign in to comment.