Skip to content

Commit

Permalink
Add Simulator SP Aux message support (NVIDIA#3076)
Browse files Browse the repository at this point in the history
* Added Simulator SP Aux message support.

* Added unit test for test_run_manager_creation.

* Added handling for workspace creation.

* Changed to use SiteType.SERVER.

* Changed the scope=module to scope=session in streaming/streaming_test.py.

* Renamed variable.

* Used a mock Cell for test_run_manager_creation.

* Refactored.

* Added more assert tests for test_create_server.

---------

Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <yuantingh@nvidia.com>
  • Loading branch information
yhwen and YuanTingHsieh authored Nov 27, 2024
1 parent db0bead commit da6f808
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 17 deletions.
19 changes: 17 additions & 2 deletions nvflare/private/fed/simulator/simulator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict, List, Optional

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ServerCommandKey
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ServerCommandKey, SiteType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.message import Message
from nvflare.private.fed.server.run_manager import RunManager
from nvflare.private.fed.server.server_state import HotState
Expand Down Expand Up @@ -144,6 +145,20 @@ def _create_server_engine(self, args, snapshot_persistor):

def deploy(self, args, grpc_args=None, secure_train=False):
    """Deploy the simulator server and give its engine a RunManager.

    Steps, in order:
      1. Run the inherited ``deploy`` (looked up after ``FederatedServer``
         in the MRO).
      2. Make sure the ``local`` and ``startup`` folders exist under
         ``args.workspace``.
      3. Build a server ``Workspace`` and a ``RunManager`` with an empty
         job id, no components and no handlers, and hand it to the engine.
      4. Initialize the engine's communication with ``self.cell``.
      5. Register the cellnet callbacks.

    Args:
        args: parsed arguments; ``workspace`` and ``config_folder`` are read
            here, and the object is forwarded to the inherited ``deploy``.
        grpc_args: forwarded unchanged to the inherited ``deploy``.
        secure_train: forwarded unchanged to the inherited ``deploy``.
    """
    super(FederatedServer, self).deploy(args, grpc_args, secure_train)

    # NOTE(review): presumably Workspace expects these folders to exist
    # already -- confirm against the Workspace implementation.
    for sub_dir in ("local", "startup"):
        os.makedirs(os.path.join(args.workspace, sub_dir), exist_ok=True)

    server_workspace = Workspace(args.workspace, "server", args.config_folder)
    manager = RunManager(
        server_name=SiteType.SERVER,
        engine=self.engine,
        job_id="",
        workspace=server_workspace,
        components={},
        handlers=[],
    )

    # The engine gets its run manager before communication is initialized,
    # and callbacks are registered only after both are in place.
    self.engine.set_run_manager(manager)
    self.engine.initialize_comm(self.cell)

    self._register_cellnet_cbs()

def stop_training(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_test/fuel/f3/streaming/streaming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ def __init__(self):


class TestStreamCell:
@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def port(self):
return get_open_ports(1)[0]

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def state(self):
return State()

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def server_cell(self, port, state):
listening_url = f"tcp://localhost:{port}"
cell = CoreCell(RX_CELL, listening_url, secure=False, credentials={})
Expand All @@ -51,7 +51,7 @@ def server_cell(self, port, state):
yield stream_cell
cell.stop()

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def client_cell(self, port, state):
connect_url = f"tcp://localhost:{port}"
cell = CoreCell(TX_CELL, connect_url, secure=False, credentials={})
Expand Down
33 changes: 22 additions & 11 deletions tests/unit_test/private/fed/app/deployer/simulator_deployer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import argparse
import os
import shutil
import tempfile
import unittest
Expand All @@ -26,6 +27,8 @@
from nvflare.private.fed.app.deployer.simulator_deployer import SimulatorDeployer
from nvflare.private.fed.app.simulator.simulator import define_simulator_parser
from nvflare.private.fed.client.fed_client import FederatedClient
from nvflare.private.fed.server.run_manager import RunManager
from nvflare.private.fed.simulator.simulator_server import SimulatorServer

# from nvflare.private.fed.simulator.simulator_server import SimulatorServer
from nvflare.security.security import EmptyAuthorizer
Expand All @@ -49,17 +52,6 @@ def _create_parser(self):

return parser

# Disable this test temporarily since it conflicts with other tests.
# def test_create_server(self):
# with patch("nvflare.private.fed.app.utils.FedAdminServer") as mock_admin:
# workspace = tempfile.mkdtemp()
# parser = self._create_parser()
# args = parser.parse_args(["job_folder", "-w" + workspace, "-n 2", "-t 1"])
# _, server = self.deployer.create_fl_server(args)
# assert isinstance(server, SimulatorServer)
# server.cell.stop()
# shutil.rmtree(workspace)

@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
# @patch("nvflare.private.fed.app.deployer.simulator_deployer.FederatedClient.start_heartbeat")
# @patch("nvflare.private.fed.app.deployer.simulator_deployer.FedAdminAgent")
Expand All @@ -71,3 +63,22 @@ def test_create_client(self, mock_register):
assert isinstance(client, FederatedClient)
client.cell.stop()
shutil.rmtree(workspace)

@patch("nvflare.private.fed.server.admin.FedAdminServer.start")
@patch("nvflare.private.fed.simulator.simulator_server.SimulatorServer._register_cellnet_cbs")
@patch("nvflare.private.fed.server.fed_server.Cell")
def test_create_server(self, mock_cell, mock_register_cellnet_cbs, mock_admin_start):
    """Check that the deployer builds a SimulatorServer whose engine has a RunManager.

    Stacked ``@patch`` decorators inject their mocks bottom-up, so the
    parameters are named in that order: the ``Cell`` patch first, then
    ``_register_cellnet_cbs``, then ``FedAdminServer.start``. The mocks are
    not inspected; they only keep the patched targets from running.

    Cleanup is done in ``finally`` blocks so that a failing assertion or a
    failing ``create_fl_server`` call does not leave the temporary workspace
    or the server behind for later tests.
    """
    workspace = tempfile.mkdtemp()
    try:
        os.mkdir(os.path.join(workspace, "local"))
        os.mkdir(os.path.join(workspace, "startup"))
        parser = self._create_parser()
        args = parser.parse_args(["job_folder", "-w" + workspace, "-n 2", "-t 1"])
        # The parser does not set config_folder; it is assigned here before
        # the server is created.
        args.config_folder = "config"
        _, server = self.deployer.create_fl_server(args)

        try:
            assert isinstance(server, SimulatorServer)
            assert isinstance(server.engine.run_manager, RunManager)
        finally:
            # Always shut the server down, even when an assertion fails.
            server.cell.stop()
            server.close()
    finally:
        shutil.rmtree(workspace)

0 comments on commit da6f808

Please sign in to comment.