Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Simulator SP Aux message support #3076

Merged
merged 11 commits into from
Nov 27, 2024
19 changes: 17 additions & 2 deletions nvflare/private/fed/simulator/simulator_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Dict, List, Optional

from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ServerCommandKey
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReservedTopic, ServerCommandKey, SiteType
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import ReturnCode, Shareable, make_reply
from nvflare.apis.workspace import Workspace
from nvflare.fuel.f3.message import Message
from nvflare.private.fed.server.run_manager import RunManager
from nvflare.private.fed.server.server_state import HotState
Expand Down Expand Up @@ -144,6 +145,20 @@ def _create_server_engine(self, args, snapshot_persistor):

def deploy(self, args, grpc_args=None, secure_train=False):
super(FederatedServer, self).deploy(args, grpc_args, secure_train)
os.makedirs(os.path.join(args.workspace, "local"), exist_ok=True)
os.makedirs(os.path.join(args.workspace, "startup"), exist_ok=True)
workspace = Workspace(args.workspace, "server", args.config_folder)
run_manager = RunManager(
yhwen marked this conversation as resolved.
Show resolved Hide resolved
server_name=SiteType.SERVER,
engine=self.engine,
job_id="",
workspace=workspace,
components={},
handlers=[],
)
self.engine.set_run_manager(run_manager)
self.engine.initialize_comm(self.cell)

self._register_cellnet_cbs()

def stop_training(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/unit_test/fuel/f3/streaming/streaming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ def __init__(self):


class TestStreamCell:
@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def port(self):
return get_open_ports(1)[0]

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def state(self):
return State()

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def server_cell(self, port, state):
listening_url = f"tcp://localhost:{port}"
cell = CoreCell(RX_CELL, listening_url, secure=False, credentials={})
Expand All @@ -51,7 +51,7 @@ def server_cell(self, port, state):
yield stream_cell
cell.stop()

@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def client_cell(self, port, state):
connect_url = f"tcp://localhost:{port}"
cell = CoreCell(TX_CELL, connect_url, secure=False, credentials={})
Expand Down
33 changes: 22 additions & 11 deletions tests/unit_test/private/fed/app/deployer/simulator_deployer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import argparse
import os
import shutil
import tempfile
import unittest
Expand All @@ -26,6 +27,8 @@
from nvflare.private.fed.app.deployer.simulator_deployer import SimulatorDeployer
from nvflare.private.fed.app.simulator.simulator import define_simulator_parser
from nvflare.private.fed.client.fed_client import FederatedClient
from nvflare.private.fed.server.run_manager import RunManager
from nvflare.private.fed.simulator.simulator_server import SimulatorServer

# from nvflare.private.fed.simulator.simulator_server import SimulatorServer
from nvflare.security.security import EmptyAuthorizer
Expand All @@ -49,17 +52,6 @@ def _create_parser(self):

return parser

# Disable this test temporarily since it conflicts with other tests.
# def test_create_server(self):
# with patch("nvflare.private.fed.app.utils.FedAdminServer") as mock_admin:
# workspace = tempfile.mkdtemp()
# parser = self._create_parser()
# args = parser.parse_args(["job_folder", "-w" + workspace, "-n 2", "-t 1"])
# _, server = self.deployer.create_fl_server(args)
# assert isinstance(server, SimulatorServer)
# server.cell.stop()
# shutil.rmtree(workspace)

@patch("nvflare.private.fed.client.fed_client.FederatedClient.register")
# @patch("nvflare.private.fed.app.deployer.simulator_deployer.FederatedClient.start_heartbeat")
# @patch("nvflare.private.fed.app.deployer.simulator_deployer.FedAdminAgent")
Expand All @@ -71,3 +63,22 @@ def test_create_client(self, mock_register):
assert isinstance(client, FederatedClient)
client.cell.stop()
shutil.rmtree(workspace)

@patch("nvflare.private.fed.server.admin.FedAdminServer.start")
@patch("nvflare.private.fed.simulator.simulator_server.SimulatorServer._register_cellnet_cbs")
@patch("nvflare.private.fed.server.fed_server.Cell")
def test_create_server(self, mock_admin, mock_simulator_server, mock_cell):
workspace = tempfile.mkdtemp()
os.mkdir(os.path.join(workspace, "local"))
os.mkdir(os.path.join(workspace, "startup"))
parser = self._create_parser()
args = parser.parse_args(["job_folder", "-w" + workspace, "-n 2", "-t 1"])
args.config_folder = "config"
_, server = self.deployer.create_fl_server(args)

assert isinstance(server, SimulatorServer)
assert isinstance(server.engine.run_manager, RunManager)

server.cell.stop()
server.close()
shutil.rmtree(workspace)
Loading