Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(framework) Fix the usage of FfsFactory in rest_api.py #4493

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/py/flwr/server/superlink/fleet/rest_rere/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import sys
from collections.abc import Awaitable
from typing import Callable, TypeVar
from typing import Callable, TypeVar, cast

from google.protobuf.message import Message as GrpcMessage

Expand All @@ -39,8 +39,9 @@
)
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.server.superlink.ffs.ffs import Ffs
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.fleet.message_handler import message_handler
from flwr.server.superlink.linkstate import LinkState
from flwr.server.superlink.linkstate import LinkState, LinkStateFactory

try:
from starlette.applications import Starlette
Expand Down Expand Up @@ -90,7 +91,7 @@ async def wrapper(request: Request) -> Response:
async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
"""Create Node."""
# Get state from app
state: LinkState = app.state.STATE_FACTORY.state()
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()

# Handle message
return message_handler.create_node(request=request, state=state)
Expand All @@ -100,7 +101,7 @@ async def create_node(request: CreateNodeRequest) -> CreateNodeResponse:
async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
"""Delete Node Id."""
# Get state from app
state: LinkState = app.state.STATE_FACTORY.state()
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()

# Handle message
return message_handler.delete_node(request=request, state=state)
Expand All @@ -110,7 +111,7 @@ async def delete_node(request: DeleteNodeRequest) -> DeleteNodeResponse:
async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
"""Pull TaskIns."""
# Get state from app
state: LinkState = app.state.STATE_FACTORY.state()
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()

# Handle message
return message_handler.pull_task_ins(request=request, state=state)
Expand All @@ -121,7 +122,7 @@ async def pull_task_ins(request: PullTaskInsRequest) -> PullTaskInsResponse:
async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
"""Push TaskRes."""
# Get state from app
state: LinkState = app.state.STATE_FACTORY.state()
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()

# Handle message
return message_handler.push_task_res(request=request, state=state)
Expand All @@ -131,7 +132,7 @@ async def push_task_res(request: PushTaskResRequest) -> PushTaskResResponse:
async def ping(request: PingRequest) -> PingResponse:
"""Ping."""
# Get state from app
state: LinkState = app.state.STATE_FACTORY.state()
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()

# Handle message
return message_handler.ping(request=request, state=state)
Expand All @@ -141,7 +142,7 @@ async def ping(request: PingRequest) -> PingResponse:
async def get_run(request: GetRunRequest) -> GetRunResponse:
"""GetRun."""
# Get state from app
state: LinkState = app.state.STATE_FACTORY.state()
state: LinkState = cast(LinkStateFactory, app.state.STATE_FACTORY).state()

# Handle message
return message_handler.get_run(request=request, state=state)
Expand All @@ -151,7 +152,7 @@ async def get_run(request: GetRunRequest) -> GetRunResponse:
async def get_fab(request: GetFabRequest) -> GetFabResponse:
"""GetRun."""
# Get ffs from app
ffs: Ffs = app.state.FFS_FACTORY.state()
ffs: Ffs = cast(FfsFactory, app.state.FFS_FACTORY).ffs()

# Handle message
return message_handler.get_fab(request=request, ffs=ffs)
Expand Down