Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve testing in state #1675

Merged
merged 13 commits into from
Feb 20, 2023
26 changes: 23 additions & 3 deletions src/py/flwr/server/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@


from datetime import datetime, timedelta, timezone
from logging import ERROR
from typing import Dict, List, Optional, Set
from uuid import UUID, uuid4

from flwr.common.logger import log
from flwr.proto.task_pb2 import TaskIns, TaskRes

from .state import State
from flwr.server.state.state import State
from flwr.server.utils.validator import validate_task_ins_or_res


class InMemoryState(State):
Expand All @@ -35,6 +37,12 @@ def __init__(self) -> None:
def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
"""Store one TaskIns."""

# Validate task
errors = validate_task_ins_or_res(task_ins)
if any(errors):
log(ERROR, errors)
return None

# Create and set task_id
task_id = uuid4()
task_ins.task_id = str(task_id)
Expand Down Expand Up @@ -89,11 +97,17 @@ def get_task_ins(
def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
"""Store one TaskRes."""

# Validate task
errors = validate_task_ins_or_res(task_res)
if any(errors):
log(ERROR, errors)
return None

# Create and set task_id
task_id = uuid4()
task_res.task_id = str(task_id)

# Set created_at
# Set created_at and ttl
created_at: datetime = _now()
ttl: datetime = created_at + timedelta(hours=24)

Expand Down Expand Up @@ -152,6 +166,12 @@ def delete_tasks(self, task_ids: Set[UUID]) -> None:
for task_id in task_res_to_be_deleted:
del self.task_res_store[task_id]

def num_task_ins(self) -> int:
return len(self.task_ins_store)

def num_task_res(self) -> int:
return len(self.task_res_store)

def register_node(self, node_id: int) -> None:
"""Register a client node."""
if node_id in self.node_ids:
Expand Down
227 changes: 0 additions & 227 deletions src/py/flwr/server/state/in_memory_state_test.py

This file was deleted.

14 changes: 14 additions & 0 deletions src/py/flwr/server/state/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe
available. If `limit` is set, it has to be greater zero.
"""

@abc.abstractmethod
def num_task_ins(self) -> int:
"""Number of task_ins in store.

This includes delivered but not yet deleted task_ins.
"""

@abc.abstractmethod
def num_task_res(self) -> int:
"""Number of task_res in store.

This includes delivered but not yet deleted task_res.
"""

@abc.abstractmethod
def delete_tasks(self, task_ids: Set[UUID]) -> None:
"""Delete all delivered TaskIns/TaskRes pairs."""
Expand Down
Loading