Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) Add get_pending_run_id method to LinkState #4357

Merged
merged 39 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
049e3d4
init
panh99 Sep 18, 2024
52721e0
rename and introduce constants
panh99 Sep 18, 2024
6779654
fix field name
panh99 Sep 18, 2024
0646b5c
fix comments
panh99 Sep 18, 2024
cfbd1ec
Merge branch 'main' into run-mgmt-state
panh99 Sep 19, 2024
8d38b65
format
panh99 Sep 19, 2024
8d19b85
Merge branch 'main' into run-mgmt-state
panh99 Sep 23, 2024
ee578b9
rename enum
panh99 Sep 23, 2024
2e75d26
rename dataclass
panh99 Sep 23, 2024
85579df
update utils
panh99 Sep 23, 2024
68b3154
update states
panh99 Sep 23, 2024
e219662
fix sqlite
panh99 Sep 23, 2024
3f5a116
Merge branch 'main' into run-mgmt-state
danieljanes Sep 23, 2024
3d6c3bb
Merge branch 'main' into run-mgmt-state
panh99 Sep 23, 2024
aeac5b1
init sqlite_state_utils
panh99 Sep 23, 2024
0eb8ad1
format
panh99 Sep 23, 2024
8765100
add timestamps for in memory state
panh99 Sep 23, 2024
4e16416
keep utils in the same file
panh99 Sep 24, 2024
b7e931c
Merge branch 'main' into run-mgmt-state
panh99 Sep 27, 2024
30e4649
rename variables and improve logging
panh99 Sep 27, 2024
521b38f
Merge remote-tracking branch 'origin/main' into run-mgmt-state
panh99 Oct 21, 2024
9c79cb5
format
panh99 Oct 21, 2024
9474f61
Merge branch 'main' into run-mgmt-state
panh99 Oct 22, 2024
3f74076
correct wording
panh99 Oct 22, 2024
71bc27c
restore changes
panh99 Oct 22, 2024
cbab64e
Update src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
panh99 Oct 23, 2024
a55811b
add pending
panh99 Oct 23, 2024
c3bac06
apply set
panh99 Oct 23, 2024
55f740f
Merge branch 'main' into run-mgmt-state
panh99 Oct 23, 2024
9b49a20
Merge branch 'main' into run-mgmt-state
panh99 Oct 23, 2024
a6aed95
apply suggestion
panh99 Oct 23, 2024
48828ec
init
jafermarq Oct 23, 2024
8012662
fix
jafermarq Oct 23, 2024
639ec38
fix
jafermarq Oct 23, 2024
ae9fd4d
+ Pan`s suggestion
jafermarq Oct 23, 2024
88ae666
w/ tests
jafermarq Oct 23, 2024
9d4c908
tweak
jafermarq Oct 23, 2024
0a26fca
Merge branch 'main' into get-pending-run-id-linkstate
jafermarq Oct 23, 2024
a9c40b2
fix for py3.9
jafermarq Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,28 @@ class ErrorCode:
def __new__(cls) -> ErrorCode:
"""Prevent instantiation."""
raise TypeError(f"{cls.__name__} cannot be instantiated.")


class Status:
"""Run status."""

PENDING = "pending"
STARTING = "starting"
RUNNING = "running"
FINISHED = "finished"

def __new__(cls) -> Status:
"""Prevent instantiation."""
raise TypeError(f"{cls.__name__} cannot be instantiated.")


class SubStatus:
"""Run sub-status."""

COMPLETED = "completed"
FAILED = "failed"
STOPPED = "stopped"

def __new__(cls) -> SubStatus:
"""Prevent instantiation."""
raise TypeError(f"{cls.__name__} cannot be instantiated.")
9 changes: 9 additions & 0 deletions src/py/flwr/common/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,15 @@ class Run:
override_config: UserConfig


@dataclass
class RunStatus:
"""Run status information."""

status: str
sub_status: str
details: str


@dataclass
class Fab:
"""Fab file representation."""
Expand Down
23 changes: 16 additions & 7 deletions src/py/flwr/server/superlink/fleet/vce/vce_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,17 @@
RecordSet,
Scalar,
)
from flwr.common.constant import Status
from flwr.common.recordset_compat import getpropertiesins_to_recordset
from flwr.common.serde import message_from_taskres, message_to_taskins
from flwr.common.typing import Run
from flwr.common.typing import Run, RunStatus
from flwr.server.superlink.fleet.vce.vce_api import (
NodeToPartitionMapping,
_register_nodes,
start_vce,
)
from flwr.server.superlink.linkstate import InMemoryLinkState, LinkStateFactory
from flwr.server.superlink.linkstate.in_memory_linkstate import RunRecord


class DummyClient(NumPyClient):
Expand Down Expand Up @@ -113,12 +115,19 @@ def register_messages_into_state(
) -> dict[UUID, float]:
"""Register `num_messages` into the state factory."""
state: InMemoryLinkState = state_factory.state() # type: ignore
state.run_ids[run_id] = Run(
run_id=run_id,
fab_id="Mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={},
state.run_ids[run_id] = RunRecord(
Run(
run_id=run_id,
fab_id="Mock/mock",
fab_version="v1.0.0",
fab_hash="hash",
override_config={},
),
RunStatus(
status=Status.PENDING,
sub_status="",
details="",
),
)
# Artificially add TaskIns to state so they can be processed
# by the Simulation Engine logic
Expand Down
112 changes: 101 additions & 11 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import threading
import time
from dataclasses import dataclass
from logging import ERROR, WARNING
from typing import Optional
from uuid import UUID, uuid4
Expand All @@ -26,13 +27,31 @@
MESSAGE_TTL_TOLERANCE,
NODE_ID_NUM_BYTES,
RUN_ID_NUM_BYTES,
Status,
)
from flwr.common.typing import Run, UserConfig
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.linkstate.linkstate import LinkState
from flwr.server.utils import validate_task_ins_or_res

from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
from .utils import (
generate_rand_int_from_bytes,
has_valid_sub_status,
is_valid_transition,
make_node_unavailable_taskres,
)


@dataclass
class RunRecord:
"""The record of a specific run, including its status and timestamps."""

run: Run
status: RunStatus
pending_at: str = ""
starting_at: str = ""
running_at: str = ""
finished_at: str = ""


class InMemoryLinkState(LinkState): # pylint: disable=R0902,R0904
Expand All @@ -44,8 +63,8 @@ def __init__(self) -> None:
self.node_ids: dict[int, tuple[float, float]] = {}
self.public_key_to_node_id: dict[bytes, int] = {}

# Map run_id to (fab_id, fab_version)
self.run_ids: dict[int, Run] = {}
# Map run_id to RunRecord
self.run_ids: dict[int, RunRecord] = {}
self.task_ins_store: dict[UUID, TaskIns] = {}
self.task_res_store: dict[UUID, TaskRes] = {}

Expand Down Expand Up @@ -351,13 +370,22 @@ def create_run(
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

if run_id not in self.run_ids:
self.run_ids[run_id] = Run(
run_id=run_id,
fab_id=fab_id if fab_id else "",
fab_version=fab_version if fab_version else "",
fab_hash=fab_hash if fab_hash else "",
override_config=override_config,
run_record = RunRecord(
run=Run(
run_id=run_id,
fab_id=fab_id if fab_id else "",
fab_version=fab_version if fab_version else "",
fab_hash=fab_hash if fab_hash else "",
override_config=override_config,
),
status=RunStatus(
status=Status.PENDING,
sub_status="",
details="",
),
pending_at=now().isoformat(),
)
self.run_ids[run_id] = run_record
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -401,7 +429,69 @@ def get_run(self, run_id: int) -> Optional[Run]:
if run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
return None
return self.run_ids[run_id]
return self.run_ids[run_id].run

def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Retrieve the statuses for the specified runs."""
with self.lock:
return {
run_id: self.run_ids[run_id].status
for run_id in set(run_ids)
if run_id in self.run_ids
}

def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`."""
with self.lock:
# Check if the run_id exists
if run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
return False

# Check if the status transition is valid
current_status = self.run_ids[run_id].status
if not is_valid_transition(current_status, new_status):
log(
ERROR,
'Invalid status transition: from "%s" to "%s"',
current_status.status,
new_status.status,
)
return False

# Check if the sub-status is valid
if not has_valid_sub_status(current_status):
log(
ERROR,
'Invalid sub-status "%s" for status "%s"',
current_status.sub_status,
current_status.status,
)
return False

# Update the status
run_record = self.run_ids[run_id]
if new_status.status == Status.STARTING:
run_record.starting_at = now().isoformat()
elif new_status.status == Status.RUNNING:
run_record.running_at = now().isoformat()
elif new_status.status == Status.FINISHED:
run_record.finished_at = now().isoformat()
run_record.status = new_status
return True

def get_pending_run_id(self) -> int | None:
"""Get the `run_id` of a run with `Status.PENDING` status, if any."""
pending_run_id = None

# Loop through all registered runs
for run_id, run_rec in self.run_ids.items():
# Break once a pending run is found
if run_rec.status.status == Status.PENDING:
pending_run_id = run_id
break

return pending_run_id

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
"""Acknowledge a ping received from a node, serving as a heartbeat."""
Expand Down
50 changes: 49 additions & 1 deletion src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from typing import Optional
from uuid import UUID

from flwr.common.typing import Run, UserConfig
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611


Expand Down Expand Up @@ -178,6 +178,54 @@ def get_run(self, run_id: int) -> Optional[Run]:
- `fab_version`: The version of the FAB used in the specified run.
"""

@abc.abstractmethod
def get_run_status(self, run_ids: set[int]) -> dict[int, RunStatus]:
"""Retrieve the statuses for the specified runs.

Parameters
----------
run_ids : set[int]
A set of run identifiers for which to retrieve statuses.

Returns
-------
dict[int, RunStatus]
A dictionary mapping each valid run ID to its corresponding status.

Notes
-----
Only valid run IDs that exist in the State will be included in the returned
dictionary. If a run ID is not found, it will be omitted from the result.
"""

@abc.abstractmethod
def update_run_status(self, run_id: int, new_status: RunStatus) -> bool:
"""Update the status of the run with the specified `run_id`.

Parameters
----------
run_id : int
The identifier of the run.
new_status : RunStatus
The new status to be assigned to the run.

Returns
-------
bool
True if the status update is successful; False otherwise.
"""

@abc.abstractmethod
def get_pending_run_id(self) -> Optional[int]:
"""Get the `run_id` of a run with `Status.PENDING` status.

Returns
-------
Optional[int]
The `run_id` of a `Run` that is pending to be started; None if
there is no Run pending.
"""

@abc.abstractmethod
def store_server_private_public_key(
self, private_key: bytes, public_key: bytes
Expand Down
Loading
Loading