From e573e6e22b416d656ea0022df3aa1b130fd90ca4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pedro=20Porto=20Buarque=20de=20Gusm=C3=A3o?= Date: Wed, 22 Dec 2021 11:11:30 +0000 Subject: [PATCH] Updating FedOpt and subclasses to latest version of the paper (#895) * Modified the m_t equation of FEDADAM algorithm. The original delta_t calculation was wrong and should have been changed to m_t, otherwise the result of v_t would have been affected. Before the correction, the accuracy and loss value would have collapsed after the second round; after the correction, the gradient could be updated correctly. * Modified the m_t equation of FEDYOGI algorithm. The original delta_t calculation was wrong and should have been changed to m_t, otherwise the result of v_t would have been affected. Before the correction, the accuracy and loss value would have collapsed after the second round; after the correction, the gradient could be updated correctly. * Updating FedOpt and subclasses to latest version of the paper Co-authored-by: Kuihao <56499195+kuihao@users.noreply.github.com> --- src/py/flwr/server/strategy/fedadagrad.py | 32 ++++++++++++----------- src/py/flwr/server/strategy/fedadam.py | 30 ++++++++++----------- src/py/flwr/server/strategy/fedopt.py | 4 ++- src/py/flwr/server/strategy/fedyogi.py | 29 +++++++++----------- 4 files changed, 47 insertions(+), 48 deletions(-) diff --git a/src/py/flwr/server/strategy/fedadagrad.py b/src/py/flwr/server/strategy/fedadagrad.py index 0d59faff7761..e42df3ae4e3b 100644 --- a/src/py/flwr/server/strategy/fedadagrad.py +++ b/src/py/flwr/server/strategy/fedadagrad.py @@ -65,7 +65,7 @@ def __init__( ) -> None: """Federated learning strategy using Adagrad on server-side. - Implementation based on https://arxiv.org/abs/2003.00295 + Implementation based on https://arxiv.org/abs/2003.00295v5 Args: fraction_fit (float, optional): Fraction of clients used during @@ -109,7 +109,6 @@ def __init__( beta_2=0.0, tau=tau, ) - self.v_t: Optional[Weights] = None def __repr__(self) -> str: rep = f"FedAdagrad(accept_failures={self.accept_failures})" @@ -129,26 +128,29 @@ def aggregate_fit( return None, {} fedavg_weights_aggregate = parameters_to_weights(fedavg_parameters_aggregated) - aggregated_updates = [ - subset_weights - self.current_weights[idx] - for idx, subset_weights in enumerate(fedavg_weights_aggregate) - ] # Adagrad - delta_t = aggregated_updates - if not self.v_t: - self.v_t = [np.zeros_like(subset_weights) for subset_weights in delta_t] + delta_t = [ + x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights) + ] - self.v_t = [ - self.v_t[idx] + np.multiply(subset_weights, subset_weights) - for idx, subset_weights in enumerate(delta_t) + # m_t + if not self.m_t: + self.m_t = [np.zeros_like(x) for x in delta_t] + self.m_t = [ + self.beta_1 * x + (1 - self.beta_1) * y for x, y in zip(self.m_t, delta_t) ] + # v_t + if not self.v_t: + self.v_t = [np.zeros_like(x) for x in delta_t] + self.v_t = [x + np.multiply(y, y) for x, y in zip(self.v_t, delta_t)] + new_weights = [ - self.current_weights[idx] - + self.eta * delta_t[idx] / (np.sqrt(self.v_t[idx]) + self.tau) - for idx in range(len(delta_t)) + x + self.eta * y / (np.sqrt(z) + self.tau) + for x, y, z in zip(self.current_weights, self.m_t, self.v_t) ] + self.current_weights = new_weights return weights_to_parameters(self.current_weights), metrics_aggregated diff --git a/src/py/flwr/server/strategy/fedadam.py b/src/py/flwr/server/strategy/fedadam.py index 755b4d8220eb..f5db9a9c50c2 100644 --- a/src/py/flwr/server/strategy/fedadam.py +++ b/src/py/flwr/server/strategy/fedadam.py @@ -67,7 +67,7 @@ def __init__( ) -> None: """Federated learning strategy using Adagrad on server-side. - Implementation based on https://arxiv.org/abs/2003.00295 + Implementation based on https://arxiv.org/abs/2003.00295v5 Args: fraction_fit (float, optional): Fraction of clients used during @@ -113,8 +113,6 @@ def __init__( beta_2=beta_2, tau=tau, ) - self.delta_t: Optional[Weights] = None - self.v_t: Optional[Weights] = None def __repr__(self) -> str: rep = f"FedAdam(accept_failures={self.accept_failures})" @@ -134,30 +132,30 @@ def aggregate_fit( return None, {} fedavg_weights_aggregate = parameters_to_weights(fedavg_parameters_aggregated) - aggregated_updates = [ - x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights) - ] # Adam - if not self.delta_t: - self.delta_t = [np.zeros_like(x) for x in self.current_weights] + delta_t = [ + x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights) + ] - self.delta_t = [ - self.beta_1 * x + (1.0 - self.beta_1) * y - for x, y in zip(self.delta_t, aggregated_updates) + # m_t + if not self.m_t: + self.m_t = [np.zeros_like(x) for x in delta_t] + self.m_t = [ + self.beta_1 * x + (1 - self.beta_1) * y for x, y in zip(self.m_t, delta_t) ] + # v_t if not self.v_t: - self.v_t = [np.zeros_like(x) for x in self.delta_t] - + self.v_t = [np.zeros_like(x) for x in delta_t] self.v_t = [ - self.beta_2 * x + (1.0 - self.beta_2) * np.multiply(y, y) - for x, y in zip(self.v_t, self.delta_t) + self.beta_2 * x + (1 - self.beta_2) * np.multiply(y, y) + for x, y in zip(self.v_t, delta_t) ] new_weights = [ x + self.eta * y / (np.sqrt(z) + self.tau) - for x, y, z in zip(self.current_weights, self.delta_t, self.v_t) + for x, y, z in zip(self.current_weights, self.m_t, self.v_t) ] self.current_weights = new_weights diff --git a/src/py/flwr/server/strategy/fedopt.py b/src/py/flwr/server/strategy/fedopt.py index 8c6f7bae86e1..f3293b51f2ce 100644 --- a/src/py/flwr/server/strategy/fedopt.py +++ b/src/py/flwr/server/strategy/fedopt.py @@ -53,7 +53,7 @@ def __init__( ) -> None: """Federated Optim strategy interface. - Implementation based on https://arxiv.org/abs/2003.00295 + Implementation based on https://arxiv.org/abs/2003.00295v5 Args: fraction_fit (float, optional): Fraction of clients used during @@ -100,6 +100,8 @@ def __init__( self.tau = tau self.beta_1 = beta_1 self.beta_2 = beta_2 + self.m_t: Optional[Weights] = None + self.v_t: Optional[Weights] = None def __repr__(self) -> str: rep = f"FedOpt(accept_failures={self.accept_failures})" diff --git a/src/py/flwr/server/strategy/fedyogi.py b/src/py/flwr/server/strategy/fedyogi.py index 0d9b8cd4bd49..565280920081 100644 --- a/src/py/flwr/server/strategy/fedyogi.py +++ b/src/py/flwr/server/strategy/fedyogi.py @@ -67,7 +67,7 @@ def __init__( ) -> None: """Federated learning strategy using Yogi on server-side. - Implementation based on https://arxiv.org/abs/2003.00295 + Implementation based on https://arxiv.org/abs/2003.00295v5 Args: fraction_fit (float, optional): Fraction of clients used during @@ -113,8 +113,6 @@ def __init__( beta_2=beta_2, tau=tau, ) - self.delta_t: Optional[Weights] = None - self.v_t: Optional[Weights] = None def __repr__(self) -> str: rep = f"FedYogi(accept_failures={self.accept_failures})" @@ -134,31 +132,30 @@ def aggregate_fit( return None, {} fedavg_weights_aggregate = parameters_to_weights(fedavg_parameters_aggregated) - aggregated_updates = [ - x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights) - ] # Yogi + delta_t = [ + x - y for x, y in zip(fedavg_weights_aggregate, self.current_weights) + ] - if not self.delta_t: - self.delta_t = [np.zeros_like(x) for x in self.current_weights] - - self.delta_t = [ - self.beta_1 * x + (1.0 - self.beta_1) * y - for x, y in zip(self.delta_t, aggregated_updates) + # m_t + if not self.m_t: + self.m_t = [np.zeros_like(x) for x in delta_t] + self.m_t = [ + self.beta_1 * x + (1 - self.beta_1) * y for x, y in zip(self.m_t, delta_t) ] + # v_t if not self.v_t: - self.v_t = [np.zeros_like(x) for x in self.delta_t] - + self.v_t = [np.zeros_like(x) for x in delta_t] self.v_t = [ x - (1.0 - self.beta_2) * np.multiply(y, y) * np.sign(x - np.multiply(y, y)) - for x, y in zip(self.v_t, self.delta_t) + for x, y in zip(self.v_t, delta_t) ] new_weights = [ x + self.eta * y / (np.sqrt(z) + self.tau) - for x, y, z in zip(self.current_weights, self.delta_t, self.v_t) + for x, y, z in zip(self.current_weights, self.m_t, self.v_t) ] self.current_weights = new_weights