Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add E2E test for WorkloadState and NodeState #2696

Merged
merged 4 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions e2e/bare/client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from datetime import datetime

import flwr as fl
import numpy as np

SUBSET_SIZE = 1000
STATE_VAR = 'timestamp'


model_params = np.array([1])
Expand All @@ -12,16 +15,29 @@ class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return model_params

def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]

def fit(self, parameters, config):
model_params = parameters
model_params = [param * (objective/np.mean(param)) for param in model_params]
return model_params, 1, {}
self._record_timestamp_to_state()
return model_params, 1, {STATE_VAR: self._retrieve_timestamp_from_state()}

def evaluate(self, parameters, config):
model_params = parameters
loss = min(np.abs(1 - np.mean(model_params)/objective), 1)
accuracy = 1 - loss
return loss, 1, {"accuracy": accuracy}
self._record_timestamp_to_state()
return loss, 1, {"accuracy": accuracy, STATE_VAR: self._retrieve_timestamp_from_state()}

def client_fn(cid):
return FlowerClient().to_client()
Expand Down
30 changes: 30 additions & 0 deletions e2e/bare/simulation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,41 @@
from typing import List, Tuple
import numpy as np

import flwr as fl
from flwr.common import Metrics

from client import client_fn
STATE_VAR = 'timestamp'


# Define metric aggregation function
def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Ensure that timestamps are monotonically increasing."""
states = []
for _, m in metrics:
# split string and covert timestamps to float
states.append([float(tt) for tt in m[STATE_VAR].split(',')])

for client_state in states:
if len(client_state) == 1:
continue
deltas = np.diff(client_state)
assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}"

return {STATE_VAR: states}


strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics)

hist = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98

# The checks in record_state_metrics don't do anythinng if client's state has a single entry
state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1]
assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds"
19 changes: 17 additions & 2 deletions e2e/pytorch/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from collections import OrderedDict
from datetime import datetime

import torch
import torch.nn as nn
Expand All @@ -18,6 +19,7 @@
warnings.filterwarnings("ignore", category=UserWarning)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
SUBSET_SIZE = 1000
STATE_VAR = 'timestamp'


class Net(nn.Module):
Expand Down Expand Up @@ -89,16 +91,29 @@ def load_data():
class FlowerClient(fl.client.NumPyClient):
def get_parameters(self, config):
return [val.cpu().numpy() for _, val in net.state_dict().items()]

def _record_timestamp_to_state(self):
"""Record timestamp to client's state."""
t_stamp = datetime.now().timestamp()
if STATE_VAR in self.state.state:
self.state.state[STATE_VAR] += f",{t_stamp}"
else:
self.state.state[STATE_VAR] = str(t_stamp)

def _retrieve_timestamp_from_state(self):
return self.state.state[STATE_VAR]

def fit(self, parameters, config):
set_parameters(net, parameters)
train(net, trainloader, epochs=1)
return self.get_parameters(config={}), len(trainloader.dataset), {}
self._record_timestamp_to_state()
return self.get_parameters(config={}), len(trainloader.dataset), {STATE_VAR: self._retrieve_timestamp_from_state()}

def evaluate(self, parameters, config):
set_parameters(net, parameters)
loss, accuracy = test(net, testloader)
return loss, len(testloader.dataset), {"accuracy": accuracy}
self._record_timestamp_to_state()
return loss, len(testloader.dataset), {"accuracy": accuracy, STATE_VAR: self._retrieve_timestamp_from_state()}

def set_parameters(model, parameters):
params_dict = zip(model.state_dict().keys(), parameters)
Expand Down
31 changes: 31 additions & 0 deletions e2e/pytorch/simulation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,42 @@
from typing import List, Tuple
import numpy as np

import flwr as fl
from flwr.common import Metrics


from client import client_fn
STATE_VAR = 'timestamp'


# Define metric aggregation function
def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Ensure that timestamps are monotonically increasing."""
states = []
for _, m in metrics:
# split string and covert timestamps to float
states.append([float(tt) for tt in m[STATE_VAR].split(',')])

for client_state in states:
if len(client_state) == 1:
continue
deltas = np.diff(client_state)
assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}"

return {STATE_VAR: states}


strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics)

hist = fl.simulation.start_simulation(
client_fn=client_fn,
num_clients=2,
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98

# The checks in record_state_metrics don't do anythinng if client's state has a single entry
state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1]
assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds"
39 changes: 39 additions & 0 deletions e2e/server.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,47 @@
from typing import List, Tuple
import numpy as np


import flwr as fl
from flwr.common import Metrics
STATE_VAR = 'timestamp'


# Define metric aggregation function
def record_state_metrics(metrics: List[Tuple[int, Metrics]]) -> Metrics:
"""Ensure that timestamps are monotonically increasing."""
if not metrics:
return {}

if STATE_VAR not in metrics[0][1]:
# Do nothing if keyword is not present
return {}

states = []
for _, m in metrics:
# split string and covert timestamps to float
states.append([float(tt) for tt in m[STATE_VAR].split(',')])

for client_state in states:
if len(client_state) == 1:
continue
deltas = np.diff(client_state)
assert np.all(deltas > 0), f"Timestamps are not monotonically increasing: {client_state}"

return {STATE_VAR: states}


strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=record_state_metrics)

hist = fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3),
strategy=strategy,
)

assert hist.losses_distributed[-1][1] == 0 or (hist.losses_distributed[0][1] / hist.losses_distributed[-1][1]) >= 0.98

if STATE_VAR in hist.metrics_distributed:
# The checks in record_state_metrics don't do anythinng if client's state has a single entry
state_metrics_last_round = hist.metrics_distributed[STATE_VAR][-1]
assert len(state_metrics_last_round[1][0]) == 2*state_metrics_last_round[0], f"There should be twice as many entries in the client state as rounds"