Updating FedOpt and subclasses to the latest version of the paper (#895)
* Modified the m_t equation of the FedAdam algorithm.

The original code accumulated the momentum term in delta_t when it should have been kept in a separate m_t; as a result, the second moment v_t was computed from the momentum-smoothed value instead of the raw pseudo-gradient. Before the correction, accuracy and loss collapsed after the second round; after it, the server update behaves correctly.

* Modified the m_t equation of the FedYogi algorithm.

Same issue and same fix as for FedAdam: the momentum term now lives in m_t, so v_t is again computed from the raw pseudo-gradient.

* Updated FedOpt and its subclasses to the latest version of the paper (see the per-round update summarized below).

Co-authored-by: Kuihao <56499195+kuihao@users.noreply.github.com>
pedropgusmao and kuihao authored Dec 22, 2021
1 parent 8b90548 commit e573e6e
Showing 4 changed files with 47 additions and 48 deletions.
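After this commit, the three concrete strategies share one per-round server update. Summarized in the code's own names (a plain-text digest of the diffs below; delta_t is the pseudo-gradient, i.e. the FedAvg aggregate minus the current server weights):

    delta_t = fedavg_aggregate - current_weights
    m_t     = beta_1 * m_t + (1 - beta_1) * delta_t
    v_t     = FedAdagrad: v_t + delta_t^2
              FedAdam:    beta_2 * v_t + (1 - beta_2) * delta_t^2
              FedYogi:    v_t - (1 - beta_2) * delta_t^2 * sign(v_t - delta_t^2)
    new_weights = current_weights + eta * m_t / (sqrt(v_t) + tau)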
src/py/flwr/server/strategy/fedadagrad.py (17 additions, 15 deletions)
@@ -65,7 +65,7 @@ def __init__(
     ) -> None:
         """Federated learning strategy using Adagrad on server-side.
-        Implementation based on https://arxiv.org/abs/2003.00295
+        Implementation based on https://arxiv.org/abs/2003.00295v5
         Args:
             fraction_fit (float, optional): Fraction of clients used during
@@ -109,7 +109,6 @@ def __init__(
             beta_2=0.0,
             tau=tau,
         )
-        self.v_t: Optional[Weights] = None

     def __repr__(self) -> str:
         rep = f"FedAdagrad(accept_failures={self.accept_failures})"
@@ -129,26 +128,29 @@ def aggregate_fit(
             return None, {}

         fedavg_weights_aggregate = parameters_to_weights(fedavg_parameters_aggregated)
-        aggregated_updates = [
-            subset_weights - self.current_weights[idx]
-            for idx, subset_weights in enumerate(fedavg_weights_aggregate)
-        ]

         # Adagrad
-        delta_t = aggregated_updates
-        if not self.v_t:
-            self.v_t = [np.zeros_like(subset_weights) for subset_weights in delta_t]
+        delta_t = [
+            x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
+        ]
+

-        self.v_t = [
-            self.v_t[idx] + np.multiply(subset_weights, subset_weights)
-            for idx, subset_weights in enumerate(delta_t)
+        # m_t
+        if not self.m_t:
+            self.m_t = [np.zeros_like(x) for x in delta_t]
+        self.m_t = [
+            self.beta_1 * x + (1 - self.beta_1) * y for x, y in zip(self.m_t, delta_t)
         ]

+        # v_t
+        if not self.v_t:
+            self.v_t = [np.zeros_like(x) for x in delta_t]
+        self.v_t = [x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t)]
+

         new_weights = [
-            self.current_weights[idx]
-            + self.eta * delta_t[idx] / (np.sqrt(self.v_t[idx]) + self.tau)
-            for idx in range(len(delta_t))
+            x + self.eta * y / (np.sqrt(z) + self.tau)
+            for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
         ]

         self.current_weights = new_weights

         return weights_to_parameters(self.current_weights), metrics_aggregated
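For readers outside the diff view, a minimal numpy sketch of this Adagrad step for a single weight array. The function name and signature are illustrative, not the flwr API; FedAdagrad appears to configure the base class with beta_1 = 0.0 (matching the visible beta_2 = 0.0), so m_t reduces to the raw delta_t and is inlined here:

import numpy as np

def fedadagrad_step(current, fedavg_aggregate, v_t, eta=1e-1, tau=1e-9):
    """One server round of the Adagrad variant (illustrative names)."""
    delta_t = fedavg_aggregate - current    # pseudo-gradient
    v_t = v_t + delta_t * delta_t           # running sum of squared updates
    new_weights = current + eta * delta_t / (np.sqrt(v_t) + tau)
    return new_weights, v_t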
src/py/flwr/server/strategy/fedadam.py (14 additions, 16 deletions)
@@ -67,7 +67,7 @@ def __init__(
     ) -> None:
         """Federated learning strategy using Adagrad on server-side.
-        Implementation based on https://arxiv.org/abs/2003.00295
+        Implementation based on https://arxiv.org/abs/2003.00295v5
         Args:
             fraction_fit (float, optional): Fraction of clients used during
@@ -113,8 +113,6 @@ def __init__(
             beta_2=beta_2,
             tau=tau,
         )
-        self.delta_t: Optional[Weights] = None
-        self.v_t: Optional[Weights] = None

     def __repr__(self) -> str:
         rep = f"FedAdam(accept_failures={self.accept_failures})"
@@ -134,30 +132,30 @@ def aggregate_fit(
             return None, {}

         fedavg_weights_aggregate = parameters_to_weights(fedavg_parameters_aggregated)
-        aggregated_updates = [
-            x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
-        ]

         # Adam
-        if not self.delta_t:
-            self.delta_t = [np.zeros_like(x) for x in self.current_weights]
+        delta_t = [
+            x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
+        ]

-        self.delta_t = [
-            self.beta_1 * x + (1.0 - self.beta_1) * y
-            for x, y in zip(self.delta_t, aggregated_updates)
+        # m_t
+        if not self.m_t:
+            self.m_t = [np.zeros_like(x) for x in delta_t]
+        self.m_t = [
+            self.beta_1 * x + (1 - self.beta_1) * y for x, y in zip(self.m_t, delta_t)
         ]

+        # v_t
         if not self.v_t:
-            self.v_t = [np.zeros_like(x) for x in self.delta_t]
-
+            self.v_t = [np.zeros_like(x) for x in delta_t]
         self.v_t = [
-            self.beta_2 * x + (1.0 - self.beta_2) * np.multiply(y, y)
-            for x, y in zip(self.v_t, self.delta_t)
+            self.beta_2 * x + (1 - self.beta_2) * np.multiply(y, y)
+            for x, y in zip(self.v_t, delta_t)
         ]

         new_weights = [
             x + self.eta * y / (np.sqrt(z) + self.tau)
-            for x, y, z in zip(self.current_weights, self.delta_t, self.v_t)
+            for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
         ]

         self.current_weights = new_weights
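A minimal numpy sketch of the corrected Adam step for a single weight array (function name, signature, and the default hyperparameters are illustrative, not the flwr API). The substance of the fix: v_t is driven by the raw pseudo-gradient delta_t, while the momentum lives in a separate m_t instead of overwriting delta_t:

import numpy as np

def fedadam_step(current, fedavg_aggregate, m_t, v_t,
                 eta=1e-1, beta_1=0.9, beta_2=0.99, tau=1e-9):
    """One server round of the corrected Adam variant (illustrative)."""
    delta_t = fedavg_aggregate - current              # raw pseudo-gradient
    m_t = beta_1 * m_t + (1 - beta_1) * delta_t       # first moment
    v_t = beta_2 * v_t + (1 - beta_2) * delta_t**2    # second moment from delta_t, not m_t
    new_weights = current + eta * m_t / (np.sqrt(v_t) + tau)
    return new_weights, m_t, v_t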
src/py/flwr/server/strategy/fedopt.py (3 additions, 1 deletion)
@@ -53,7 +53,7 @@ def __init__(
     ) -> None:
         """Federated Optim strategy interface.
-        Implementation based on https://arxiv.org/abs/2003.00295
+        Implementation based on https://arxiv.org/abs/2003.00295v5
         Args:
             fraction_fit (float, optional): Fraction of clients used during
@@ -100,6 +100,8 @@ def __init__(
         self.tau = tau
         self.beta_1 = beta_1
         self.beta_2 = beta_2
+        self.m_t: Optional[Weights] = None
+        self.v_t: Optional[Weights] = None

     def __repr__(self) -> str:
         rep = f"FedOpt(accept_failures={self.accept_failures})"
src/py/flwr/server/strategy/fedyogi.py (13 additions, 16 deletions)
@@ -67,7 +67,7 @@ def __init__(
     ) -> None:
         """Federated learning strategy using Yogi on server-side.
-        Implementation based on https://arxiv.org/abs/2003.00295
+        Implementation based on https://arxiv.org/abs/2003.00295v5
         Args:
             fraction_fit (float, optional): Fraction of clients used during
@@ -113,8 +113,6 @@ def __init__(
             beta_2=beta_2,
             tau=tau,
         )
-        self.delta_t: Optional[Weights] = None
-        self.v_t: Optional[Weights] = None

     def __repr__(self) -> str:
         rep = f"FedYogi(accept_failures={self.accept_failures})"
@@ -134,31 +132,30 @@ def aggregate_fit(
             return None, {}

         fedavg_weights_aggregate = parameters_to_weights(fedavg_parameters_aggregated)
-        aggregated_updates = [
-            x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
-        ]

         # Yogi
+        delta_t = [
+            x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights)
+        ]

-        if not self.delta_t:
-            self.delta_t = [np.zeros_like(x) for x in self.current_weights]
-
-        self.delta_t = [
-            self.beta_1 * x + (1.0 - self.beta_1) * y
-            for x, y in zip(self.delta_t, aggregated_updates)
+        # m_t
+        if not self.m_t:
+            self.m_t = [np.zeros_like(x) for x in delta_t]
+        self.m_t = [
+            self.beta_1 * x + (1 - self.beta_1) * y for x, y in zip(self.m_t, delta_t)
         ]

+        # v_t
         if not self.v_t:
-            self.v_t = [np.zeros_like(x) for x in self.delta_t]
-
+            self.v_t = [np.zeros_like(x) for x in delta_t]
         self.v_t = [
             x - (1.0 - self.beta_2) * np.multiply(y, y) * np.sign(x - np.multiply(y, y))
-            for x, y in zip(self.v_t, self.delta_t)
+            for x, y in zip(self.v_t, delta_t)
         ]

         new_weights = [
             x + self.eta * y / (np.sqrt(z) + self.tau)
-            for x, y, z in zip(self.current_weights, self.delta_t, self.v_t)
+            for x, y, z in zip(self.current_weights, self.m_t, self.v_t)
         ]

         self.current_weights = new_weights
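A minimal numpy sketch of Yogi's second-moment rule in isolation (illustrative names): unlike Adam's exponential moving average, v_t moves additively toward delta_t^2, which caps how quickly the adaptive denominator can change per round:

import numpy as np

def yogi_v_update(v_t, delta_t, beta_2=0.99):
    """Move v_t additively toward delta_t**2 (sign-controlled step)."""
    sq = delta_t * delta_t
    # sign(v_t - sq) steers v_t toward sq: down if v_t > sq, up if v_t < sq
    return v_t - (1 - beta_2) * sq * np.sign(v_t - sq)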