Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move driver/fleet/state into superlink subpackage #2924

Merged
merged 6 commits into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/py/flwr/client/grpc_client/connection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
ServerMessage,
)
from flwr.server.client_manager import SimpleClientManager
from flwr.server.fleet.grpc_bidi.grpc_server import start_grpc_server
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import start_grpc_server

from .connection import grpc_connection

Expand Down Expand Up @@ -100,7 +100,9 @@ def mock_join( # type: ignore # pylint: disable=invalid-name


@patch(
"flwr.server.fleet.grpc_bidi.flower_service_servicer.FlowerServiceServicer.Join",
# pylint: disable=line-too-long
"flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer.FlowerServiceServicer.Join", # noqa: E501
# pylint: enable=line-too-long
mock_join,
)
def test_integration_connection() -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/driver/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def start_driver( # pylint: disable=too-many-arguments, too-many-locals
An implementation of the abstract base class
`flwr.server.strategy.Strategy`. If no strategy is provided, then
`start_server` will use `flwr.server.strategy.FedAvg`.
client_manager : Optional[flwr.server.DriverClientManager] (default: None)
client_manager : Optional[flwr.server.ClientManager] (default: None)
An implementation of the class `flwr.server.ClientManager`. If no
implementation is provided, then `start_driver` will use
`flwr.server.SimpleClientManager`.
Expand Down
18 changes: 9 additions & 9 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,16 @@
add_FleetServicer_to_server,
)
from flwr.server.client_manager import ClientManager, SimpleClientManager
from flwr.server.driver.driver_servicer import DriverServicer
from flwr.server.fleet.grpc_bidi.grpc_server import (
generic_create_grpc_server,
start_grpc_server,
)
from flwr.server.fleet.grpc_rere.fleet_servicer import FleetServicer
from flwr.server.history import History
from flwr.server.server import Server
from flwr.server.state import StateFactory
from flwr.server.strategy import FedAvg, Strategy
from flwr.server.superlink.driver.driver_servicer import DriverServicer
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import (
generic_create_grpc_server,
start_grpc_server,
)
from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
from flwr.server.superlink.state import StateFactory

ADDRESS_DRIVER_API = "0.0.0.0:9091"
ADDRESS_FLEET_API_GRPC_RERE = "0.0.0.0:9092"
Expand Down Expand Up @@ -561,7 +561,7 @@ def _run_fleet_api_rest(
try:
import uvicorn

from flwr.server.fleet.rest_rere.rest_api import app as fast_api_app
from flwr.server.superlink.fleet.rest_rere.rest_api import app as fast_api_app
except ModuleNotFoundError:
sys.exit(MISSING_EXTRA_REST)
if workers != 1:
Expand All @@ -584,7 +584,7 @@ def _run_fleet_api_rest(
raise ValueError(validation_exceptions)

uvicorn.run(
app="flwr.server.fleet.rest_rere.rest_api:app",
app="flwr.server.superlink.fleet.rest_rere.rest_api:app",
port=port,
host=host,
reload=False,
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/client_manager_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from unittest.mock import MagicMock

from flwr.server.client_manager import SimpleClientManager
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy


def test_simple_client_manager_register() -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/criterion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from flwr.server.client_manager import SimpleClientManager
from flwr.server.client_proxy import ClientProxy
from flwr.server.criterion import Criterion
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy


def test_criterion_applied() -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/strategy/fedadagrad_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy

from .fedadagrad import FedAdagrad

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/strategy/fedmedian_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy

from .fedmedian import FedMedian

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/strategy/krum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy

from .krum import Krum

Expand Down
2 changes: 1 addition & 1 deletion src/py/flwr/server/strategy/multikrum_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
parameters_to_ndarrays,
)
from flwr.server.client_proxy import ClientProxy
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy

from .krum import Krum

Expand Down
15 changes: 15 additions & 0 deletions src/py/flwr/server/superlink/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower SuperLink."""
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskRes # pylint: disable=E0611
from flwr.server.state import State, StateFactory
from flwr.server.superlink.state import State, StateFactory
from flwr.server.utils.validator import validate_task_ins_or_res


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""DriverServicer tests."""


from flwr.server.driver.driver_servicer import _raise_if
from flwr.server.superlink.driver.driver_servicer import _raise_if

# pylint: disable=broad-except

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,12 @@
ServerMessage,
)
from flwr.server.client_manager import ClientManager
from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import (
GrpcBridge,
InsWrapper,
ResWrapper,
)
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy


def default_bridge_factory() -> GrpcBridge:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@
ClientMessage,
ServerMessage,
)
from flwr.server.fleet.grpc_bidi.flower_service_servicer import (
from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
FlowerServiceServicer,
register_client_proxy,
)
from flwr.server.fleet.grpc_bidi.grpc_bridge import InsWrapper, ResWrapper
from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import InsWrapper, ResWrapper

CLIENT_MESSAGE = ClientMessage()
SERVER_MESSAGE = ServerMessage()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
ClientMessage,
ServerMessage,
)
from flwr.server.fleet.grpc_bidi.grpc_bridge import (
from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import (
GrpcBridge,
GrpcBridgeClosed,
InsWrapper,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
ServerMessage,
)
from flwr.server.client_proxy import ClientProxy
from flwr.server.fleet.grpc_bidi.grpc_bridge import GrpcBridge, InsWrapper, ResWrapper
from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import (
GrpcBridge,
InsWrapper,
ResWrapper,
)


class GrpcClientProxy(ClientProxy):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
Parameters,
Scalar,
)
from flwr.server.fleet.grpc_bidi.grpc_bridge import ResWrapper
from flwr.server.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy
from flwr.server.superlink.fleet.grpc_bidi.grpc_bridge import ResWrapper
from flwr.server.superlink.fleet.grpc_bidi.grpc_client_proxy import GrpcClientProxy

MESSAGE_PARAMETERS = Parameters(tensors=[], tensor_type="np")
MESSAGE_FIT_RES = ClientMessage(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,11 @@
add_FlowerServiceServicer_to_server,
)
from flwr.server.client_manager import ClientManager
from flwr.server.driver.driver_servicer import DriverServicer
from flwr.server.fleet.grpc_bidi.flower_service_servicer import FlowerServiceServicer
from flwr.server.fleet.grpc_rere.fleet_servicer import FleetServicer
from flwr.server.superlink.driver.driver_servicer import DriverServicer
from flwr.server.superlink.fleet.grpc_bidi.flower_service_servicer import (
FlowerServiceServicer,
)
from flwr.server.superlink.fleet.grpc_rere.fleet_servicer import FleetServicer

INVALID_CERTIFICATES_ERR_MSG = """
When setting any of root_certificate, certificate, or private_key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
from typing import Tuple, cast

from flwr.server.client_manager import SimpleClientManager
from flwr.server.fleet.grpc_bidi.grpc_server import (
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import (
start_grpc_server,
valid_certificates,
)

root_dir = dirname(abspath(join(__file__, "../../../../../..")))
root_dir = dirname(abspath(join(__file__, "../../../../../../..")))


def load_certificates() -> Tuple[str, str, str]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
PushTaskResRequest,
PushTaskResResponse,
)
from flwr.server.fleet.message_handler import message_handler
from flwr.server.state import StateFactory
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.state import StateFactory


class FleetServicer(fleet_pb2_grpc.FleetServicer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.state import State
from flwr.server.superlink.state import State


def create_node(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
PullTaskInsRequest,
PushTaskResRequest,
)
from flwr.server.fleet.message_handler import message_handler
from flwr.server.state import State
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.state import State

try:
from starlette.applications import Starlette
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from flwr.common import log, now
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.state.state import State
from flwr.server.superlink.state.state import State
from flwr.server.utils import validate_task_ins_or_res


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

import unittest

from flwr.server.state.sqlite_state import task_ins_to_dict
from flwr.server.state.state_test import create_task_ins
from flwr.server.superlink.state.sqlite_state import task_ins_to_dict
from flwr.server.superlink.state.state_test import create_task_ins


class SqliteStateTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.recordset_pb2 import RecordSet # pylint: disable=E0611
from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.state import InMemoryState, SqliteState, State
from flwr.server.superlink.state import InMemoryState, SqliteState, State


class StateTest(unittest.TestCase):
Expand Down
Loading