diff --git a/flsim/optimizers/__init__.py b/flsim/optimizers/__init__.py index 6a4de65c..a5ed91cd 100644 --- a/flsim/optimizers/__init__.py +++ b/flsim/optimizers/__init__.py @@ -30,7 +30,9 @@ FedAvgWithLROptimizerConfig, FedLAMBOptimizerConfig, FedLARSOptimizerConfig, + ServerFTRLOptimizerConfig, ) + from .sync_aggregators import ( FedAdamSyncAggregatorConfig, FedAvgSyncAggregatorConfig, @@ -178,3 +180,9 @@ node=FedLAMBOptimizerConfig, group="server_optimizer", ) + +ConfigStore.instance().store( + name="base_ftrl_optimizer", + node=ServerFTRLOptimizerConfig, + group="server_optimizer", +) diff --git a/flsim/optimizers/server_optimizers.py b/flsim/optimizers/server_optimizers.py index 5b61fa21..5c1b43e3 100644 --- a/flsim/optimizers/server_optimizers.py +++ b/flsim/optimizers/server_optimizers.py @@ -28,7 +28,7 @@ import torch import torch.nn as nn from flsim.optimizers.layerwise_optimizers import LAMB, LARS -from flsim.utils.config_utils import fullclassname, init_self_cfg, is_target +from flsim.utils.config_utils import fullclassname, init_self_cfg from omegaconf import MISSING @@ -48,7 +48,7 @@ def _set_defaults_in_cfg(cls, cfg): @abc.abstractmethod @torch.no_grad() - def step(self, closure): + def step(self, closure, noise=None): r"""Performs a single optimization step (parameter update). Args: @@ -215,6 +215,105 @@ def zero_grad(self, set_to_none: bool = False): return LAMB.zero_grad(self, set_to_none) +class ServerFTRLOptimizer(IServerOptimizer, torch.optim.Optimizer): + """ + :param params: parameter groups + :param momentum: if non-zero, use DP-FTRLM + :param record_last_noise: whether to record the last noise. for the tree completion trick. + """ + + def __init__(self, *, model: nn.Module, record_last_noise: bool, **kwargs) -> None: + init_self_cfg( + self, + component_class=__class__, + config_class=ServerFTRLOptimizerConfig, + **kwargs, + ) + + IServerOptimizer.__init__(self, model=model, **kwargs) + # pyre-ignore[28] + torch.optim.Optimizer.__init__(self, params=model.parameters(), defaults={}) + # pyre-ignore[16] + self.momentum = self.cfg.momentum + self.lr = self.cfg.lr + self.record_last_noise = record_last_noise + + def __setstate__(self, state): + super(ServerFTRLOptimizer, self).__setstate__(state) + + def zero_grad(self, set_to_none: bool = False): + return torch.optim.Optimizer.zero_grad(self, set_to_none) + + @torch.no_grad() + def step(self, noise, closure=None): + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p, nz in zip(group["params"], noise): + if p.grad is None: + continue + d_p = p.grad + + param_state = self.state[p] + + if len(param_state) == 0: + param_state["grad_sum"] = torch.zeros_like( + d_p, memory_format=torch.preserve_format + ) + param_state["model_sum"] = p.detach().clone( + memory_format=torch.preserve_format + ) # just record the initial model + param_state["momentum"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) + if self.record_last_noise: + param_state["last_noise"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) # record the last noise needed, in order for restarting + + gs, ms = param_state["grad_sum"], param_state["model_sum"] + if self.momentum == 0: + gs.add_(d_p) + p.copy_(ms + (-gs - nz) / self.lr) + else: + gs.add_(d_p) + param_state["momentum"].mul_(self.momentum).add_(gs + nz) + p.copy_(ms - param_state["momentum"] / self.lr) + if self.record_last_noise: + param_state["last_noise"].copy_(nz) + return loss + + 
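# --- Editor's note (illustrative sketch, not part of the patch) ------------------------
# The update in `step` above appears to be the DP-FTRL server rule (Kairouz et al.,
# "Practical and Private (Deep) Learning without Sampling or Shuffling", 2021): keep a
# running sum of server gradients and recompute the model from the *initial* weights
# recorded in `model_sum`, adding the tree-aggregated noise to the sum at every step:
#
#     theta_t = theta_0 - (sum_{s<=t} g_s + noise_t) / lr
#
# (note that `self.cfg.lr` enters as a divisor in the code above, not as a multiplier).
# A minimal scalar version of the same recurrence, with hypothetical names:
def _ftrl_step_scalar(theta0, grad_sum, momentum_buf, grad, noise, lr, momentum=0.0):
    grad_sum += grad
    if momentum == 0.0:
        theta = theta0 - (grad_sum + noise) / lr
    else:
        momentum_buf = momentum * momentum_buf + (grad_sum + noise)
        theta = theta0 - momentum_buf / lr
    return theta, grad_sum, momentum_buf
# ----------------------------------------------------------------------------------------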
@torch.no_grad() + def restart(self, last_noise=None): + """ + Restart the tree. + :param last_noise: the last noise to be added. If none, use the last noise recorded. + """ + assert last_noise is not None or self.record_last_noise + for group in self.param_groups: + if last_noise is None: + for p in group["params"]: + if p.grad is None: + continue + param_state = self.state[p] + if len(param_state) == 0: + continue + param_state["grad_sum"].add_( + param_state["last_noise"] + ) # add the last piece of noise to the current gradient sum + else: + for p, nz in zip(group["params"], last_noise): + if p.grad is None: + continue + param_state = self.state[p] + if len(param_state) == 0: + continue + param_state["grad_sum"].add_(nz) + + @dataclass class ServerOptimizerConfig: _target_: str = MISSING @@ -259,3 +358,10 @@ class FedLAMBOptimizerConfig(ServerOptimizerConfig): beta1: float = 0.9 beta2: float = 0.999 eps: float = 1e-8 + + +@dataclass +class ServerFTRLOptimizerConfig(ServerOptimizerConfig): + _target_: str = fullclassname(ServerFTRLOptimizer) + lr: float = 0.001 + momentum: float = 0.0 diff --git a/flsim/privacy/privacy_engine.py b/flsim/privacy/privacy_engine.py index c9ca52cf..d94aa6a2 100644 --- a/flsim/privacy/privacy_engine.py +++ b/flsim/privacy/privacy_engine.py @@ -12,7 +12,8 @@ import logging import os from abc import ABC, abstractmethod -from typing import Optional +from dataclasses import dataclass +from typing import Any, Optional import torch from flsim.common.logger import Logger @@ -21,6 +22,12 @@ from torch import nn +@dataclass +class TreeNode: + height: int + value: Any + + class PrivacyEngineNotAttachedException(Exception): """ Exception class to be thrown from User Privacy Engine in case @@ -165,3 +172,170 @@ def get_privacy_spent(self, target_delta: Optional[float] = None): "the set of alpha orders." 
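# --- Editor's note (illustrative, not part of the patch) -------------------------------
# `restart` above is the bookkeeping half of the "tree completion" trick: before the
# noise tree is rebuilt, the caller (see SyncFTRLServer.init_round in this diff) runs the
# noise engine forward to the next power of two and hands the resulting noise here, where
# it is folded into the running gradient sum, roughly
#
#     state[p]["grad_sum"] += last_noise_for_p
#
# so that restarts line up with complete binary trees, which is what the tree-completion
# analysis in DP-FTRL assumes, and a fresh tree can then start from scratch.
# ----------------------------------------------------------------------------------------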
) return PrivacyBudget(eps, opt_alpha, target_delta) + + +class CummuNoiseTorch: + @torch.no_grad() + def __init__(self, std, shapes, device, test_mode=False, seed=None): + """ + :param std: standard deviation of the noise + :param shapes: shapes of the noise, which is basically shape of the gradients + :param device: device for pytorch tensor + :param test_mode: if in test mode, noise will be 1 in each node of the tree + """ + seed = ( + seed + if seed is not None + else int.from_bytes(os.urandom(8), byteorder="big", signed=True) + ) + self.std = std + self.shapes = shapes + self.device = device + self.step = 0 + self.binary = [0] + self.noise_sum = [torch.zeros(shape).to(self.device) for shape in shapes] + self.recorded = [[torch.zeros(shape).to(self.device) for shape in shapes]] + torch.cuda.manual_seed_all(seed) + self.generator = torch.Generator(device=self.device) + self.generator.manual_seed(seed) + self.test_mode = test_mode + + @torch.no_grad() + def __call__(self): + """ + :return: the noise to be added by DP-FTRL + """ + self.step += 1 + if self.std <= 0 and not self.test_mode: + return self.noise_sum + + idx = 0 + while idx < len(self.binary) and self.binary[idx] == 1: + self.binary[idx] = 0 + for ns, re in zip(self.noise_sum, self.recorded[idx]): + ns -= re + idx += 1 + if idx >= len(self.binary): + self.binary.append(0) + self.recorded.append( + [torch.zeros(shape).to(self.device) for shape in self.shapes] + ) + + for shape, ns, re in zip(self.shapes, self.noise_sum, self.recorded[idx]): + if not self.test_mode: + n = torch.normal( + 0, self.std, shape, generator=self.generator, device=self.device + ) + else: + n = torch.ones(shape).to(self.device) + ns += n + re.copy_(n) + + self.binary[idx] = 1 + return self.noise_sum + + @torch.no_grad() + def proceed_until(self, step_target): + """ + Proceed until the step_target-th step. This is for the binary tree completion trick. 
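# --- Editor's note (illustrative) -------------------------------------------------------
# CummuNoiseTorch above is the standard binary-tree aggregation driven by a binary
# counter: `self.binary` holds the bits of the step counter, clearing a run of trailing
# ones subtracts the noise recorded for the finished subtrees, and setting the next bit
# adds one fresh Gaussian sample. The returned `noise_sum` at step t is therefore the sum
# of one sample per set bit of t, i.e. O(log t) samples per prefix sum. In test_mode each
# node contributes a tensor of ones, so the sum at step t equals popcount(t) coordinate-wise:
def _nodes_contributing_at_step(t: int) -> int:
    # number of tree nodes whose noise is present in noise_sum after t calls
    return bin(t).count("1")
# ----------------------------------------------------------------------------------------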
+ :return: the noise to be added by DP-FTRL + """ + if self.step >= step_target: + raise ValueError(f"Already reached {step_target}.") + while self.step < step_target: + noise_sum = self.__call__() + return noise_sum + + +class CummuNoiseEffTorch: + """ + The tree aggregation protocol with the trick in Honaker, "Efficient Use of Differentially Private Binary Trees", 2015 + """ + + @torch.no_grad() + def __init__(self, std, shapes, device, seed, test_mode=False): + """ + :param std: standard deviation of the noise + :param shapes: shapes of the noise, which is basically shape of the gradients + :param device: device for pytorch tensor + """ + seed = ( + seed + if seed is not None + else int.from_bytes(os.urandom(8), byteorder="big", signed=True) + ) + self.test_mode = test_mode + + self.std = std + self.shapes = shapes + self.device = device + torch.cuda.manual_seed_all(seed) + self.generator = torch.Generator(device=self.device) + self.generator.manual_seed(seed) + self.step = 0 + self.noise_sum = [torch.zeros(shape).to(self.device) for shape in shapes] + self.stack = [] + + @torch.no_grad() + def get_noise(self): + return [ + torch.normal( + 0, self.std, shape, generator=self.generator, device=self.device + ) + for shape in self.shapes + ] + + @torch.no_grad() + def push(self, elem): + for i in range(len(self.shapes)): + self.noise_sum[i] += elem.value[i] / (2.0 - 1 / 2**elem.height) + self.stack.append(elem) + + @torch.no_grad() + def pop(self): + elem = self.stack.pop() + for i in range(len(self.shapes)): + self.noise_sum[i] -= elem.value[i] / (2.0 - 1 / 2**elem.height) + + @torch.no_grad() + def __call__(self): + """ + :return: the noise to be added by DP-FTRL + """ + self.step += 1 + + # add new element to the stack + self.push(TreeNode(0, self.get_noise())) + + # pop the stack + while len(self.stack) >= 2 and self.stack[-1].height == self.stack[-2].height: + # create new element + left_value, right_value = self.stack[-2].value, self.stack[-1].value + new_noise = self.get_noise() + new_elem = TreeNode( + self.stack[-1].height + 1, + [ + x + (y + z) / 2 + for x, y, z in zip(new_noise, left_value, right_value) + ], + ) + + # pop the stack, update sum + self.pop() + self.pop() + + # append to the stack, update sum + self.push(new_elem) + return self.noise_sum + + @torch.no_grad() + def proceed_until(self, step_target): + """ + Proceed until the step_target-th step. This is for the binary tree completion trick. 
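# --- Editor's note (illustrative) -------------------------------------------------------
# CummuNoiseEffTorch above is the "efficient" estimator from Honaker, "Efficient Use of
# Differentially Private Binary Trees" (2015): when two stack entries of equal height
# merge, the parent re-averages the children with a fresh sample (x + (y + z) / 2), and
# each entry contributes to `noise_sum` with weight 1 / (2 - 2**-height). The stack keeps
# at most one node per height (exactly like carrying in binary addition), so memory stays
# O(log t) as in CummuNoiseTorch, while the re-weighting is intended to lower the
# effective noise variance of each prefix sum.
# ----------------------------------------------------------------------------------------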
+ :return: the noise to be added by DP-FTRL + """ + if self.step >= step_target: + raise ValueError(f"Already reached {step_target}.") + while self.step < step_target: + noise_sum = self.__call__() + return noise_sum diff --git a/flsim/servers/__init__.py b/flsim/servers/__init__.py index 84d6e5c9..23db8986 100644 --- a/flsim/servers/__init__.py +++ b/flsim/servers/__init__.py @@ -8,6 +8,7 @@ from hydra.core.config_store import ConfigStore # @manual from .sync_dp_servers import SyncDPSGDServerConfig +from .sync_ftrl_servers import SyncFTRLServerConfig from .sync_mime_servers import SyncMimeServerConfig from .sync_mimelite_servers import SyncMimeLiteServerConfig from .sync_secagg_servers import SyncSecAggServerConfig, SyncSecAggSQServerConfig @@ -54,3 +55,9 @@ node=SyncSecAggSQServerConfig, group="server", ) + +ConfigStore.instance().store( + name="base_sync_ftrl_server", + node=SyncFTRLServerConfig, + group="server", +) diff --git a/flsim/servers/sync_ftrl_servers.py b/flsim/servers/sync_ftrl_servers.py new file mode 100644 index 00000000..28a99fa1 --- /dev/null +++ b/flsim/servers/sync_ftrl_servers.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from flsim.active_user_selectors.simple_user_selector import ( + UniformlyRandomActiveUserSelectorConfig, +) +from flsim.channels.base_channel import IdentityChannel, IFLChannel +from flsim.channels.message import Message +from flsim.data.data_provider import IFLDataProvider +from flsim.interfaces.model import IFLModel +from flsim.optimizers.server_optimizers import ( + ServerFTRLOptimizerConfig, + ServerOptimizerConfig, +) +from flsim.privacy.common import PrivacySetting +from flsim.privacy.privacy_engine import CummuNoiseEffTorch, CummuNoiseTorch +from flsim.privacy.user_update_clip import UserUpdateClipper +from flsim.servers.aggregator import AggregationType, Aggregator +from flsim.servers.sync_servers import SyncServer, SyncServerConfig +from flsim.utils.config_utils import fullclassname, init_self_cfg +from flsim.utils.distributed.fl_distributed import FLDistributedUtils +from flsim.utils.fl.common import FLModelParamUtils +from hydra.utils import instantiate +from omegaconf import OmegaConf + + +class SyncFTRLServer(SyncServer): + def __init__( + self, + *, + global_model: IFLModel, + channel: Optional[IFLChannel] = None, + **kwargs, + ): + init_self_cfg( + self, + component_class=__class__, # pyre-fixme[10]: Name `__class__` is used but not defined. + config_class=SyncFTRLServerConfig, + **kwargs, + ) + assert ( + self.cfg.aggregation_type == AggregationType.AVERAGE # pyre-ignore[16] + ), "As in https://arxiv.org/pdf/1710.06963.pdf, DP training must be done with simple averaging and uniform weights." 
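# --- Editor's note (usage sketch, mirroring the new unit test rather than documenting a
# prescribed API; the config values below are hypothetical) ------------------------------
# With the ConfigStore registration above, the server can also be built directly:
#
#   server = instantiate(
#       SyncFTRLServerConfig(
#           server_optimizer=ServerFTRLOptimizerConfig(lr=1.0, momentum=0.9),
#           privacy_setting=PrivacySetting(noise_multiplier=1.0, clipping_value=1.0),
#           restart_rounds=100,
#           tree_completion=True,
#           efficient=False,
#       ),
#       global_model=global_model,  # an IFLModel wrapping the global nn.Module
#   )
# ----------------------------------------------------------------------------------------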
+ + assert ( + FLDistributedUtils.is_master_worker() + ), "Distributed training is not supported for FTRL" + + self._optimizer = instantiate( + config=self.cfg.server_optimizer, + model=global_model.fl_get_module(), + record_last_noise=True, + ) + self._global_model: IFLModel = global_model + self._aggregator: Aggregator = Aggregator( + module=global_model.fl_get_module(), + aggregation_type=self.cfg.aggregation_type, + only_federated_params=self.cfg.only_federated_params, + ) + self._active_user_selector = instantiate(self.cfg.active_user_selector) + self._channel: IFLChannel = channel or IdentityChannel() + self._restart_rounds = self.cfg.restart_rounds + self._clipping_value = self.cfg.privacy_setting.clipping_value + self._noise_std = self.cfg.privacy_setting.noise_multiplier + self._user_update_clipper = UserUpdateClipper() + self._shapes = [p.shape for p in global_model.fl_get_module().parameters()] + self._device = next(global_model.fl_get_module().parameters()).device + self._privacy_engine = None + + def _create_tree(self, users_per_round): + std = (self._noise_std * self._clipping_value) / users_per_round + if self.cfg.efficient: + self._privacy_engine = CummuNoiseEffTorch( + std=std, + shapes=self._shapes, + device=self._device, + test_mode=False, + seed=self.cfg.privacy_setting.noise_seed, + ) + else: + self._privacy_engine = CummuNoiseTorch( + std=std, + shapes=self._shapes, + device=self._device, + test_mode=False, + seed=self.cfg.privacy_setting.noise_seed, + ) + + @classmethod + def _set_defaults_in_cfg(cls, cfg): + if OmegaConf.is_missing(cfg.active_user_selector, "_target_"): + cfg.active_user_selector = UniformlyRandomActiveUserSelectorConfig() + if OmegaConf.is_missing(cfg.server_optimizer, "_target_"): + cfg.server_optimizer = ServerFTRLOptimizerConfig() + + @property + def global_model(self): + return self._global_model + + def select_clients_for_training( + self, + num_total_users, + users_per_round, + data_provider: Optional[IFLDataProvider] = None, + epoch: Optional[int] = None, + ): + if self._privacy_engine is None: + self._create_tree(users_per_round) + + return self._active_user_selector.get_user_indices( + num_total_users=num_total_users, + users_per_round=users_per_round, + data_provider=data_provider, + global_model=self.global_model, + epoch=epoch, + ) + + def init_round(self): + self._aggregator.zero_weights() + self._optimizer.zero_grad() + + if self.should_restart(): + last_noise = None + if self.cfg.tree_completion: + actual_steps = self._privacy_engine.step * self._restart_rounds + next_pow_2 = 2 ** (actual_steps - 1).bit_length() + if next_pow_2 > actual_steps: + last_noise = self._privacy_engine.proceed_until(next_pow_2) + self._optimizer.restart(last_noise) + + def receive_update_from_client(self, message: Message): + message = self._channel.client_to_server(message) + + self._aggregator.apply_weight_to_update( + delta=message.model.fl_get_module(), weight=message.weight + ) + + self._user_update_clipper.clip( + message.model.fl_get_module(), max_norm=self._clipping_value + ) + self._aggregator.add_update( + delta=message.model.fl_get_module(), weight=message.weight + ) + + def step(self): + aggregated_model = self._aggregator.aggregate() + noise = self._privacy_engine() + FLModelParamUtils.set_gradient( + model=self._global_model.fl_get_module(), + reference_gradient=aggregated_model, + ) + self._optimizer.step(noise) + return noise + + def should_restart(self): + return ( + (self._privacy_engine is not None) + and self._privacy_engine.step != 0 + 
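# --- Editor's note (illustrative) -------------------------------------------------------
# Per-round flow of SyncFTRLServer above: init_round() (optionally restarting the noise
# tree), receive_update_from_client() for each clipped client delta, then step(), which
# averages the deltas, draws the cumulative tree noise and applies the FTRL update. The
# std handed to the tree in _create_tree is (noise_multiplier * clipping_value) /
# users_per_round, i.e. the Gaussian std on the *average* of clipped updates. For tree
# completion, the engine is advanced to the next power of two before restarting:
def _next_pow_2(steps: int) -> int:
    # e.g. 5 -> 8 and 8 -> 8; matches the 2 ** (actual_steps - 1).bit_length() above
    return 2 ** (steps - 1).bit_length()
# ----------------------------------------------------------------------------------------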
and ((self._privacy_engine.step + 1) % self._restart_rounds) == 0 + ) + + +@dataclass +class SyncFTRLServerConfig(SyncServerConfig): + _target_: str = fullclassname(SyncFTRLServer) + aggregation_type: AggregationType = AggregationType.AVERAGE + server_optimizer: ServerOptimizerConfig = ServerFTRLOptimizerConfig() + restart_rounds: int = 10000 + privacy_setting: PrivacySetting = PrivacySetting() + efficient: bool = False + tree_completion: bool = False diff --git a/flsim/servers/tests/test_sync_ftrl_servers.py b/flsim/servers/tests/test_sync_ftrl_servers.py new file mode 100644 index 00000000..bf6659e8 --- /dev/null +++ b/flsim/servers/tests/test_sync_ftrl_servers.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from typing import List + +import numpy as np +import pytest +from flsim.channels.message import Message +from flsim.common.pytest_helper import assertEmpty, assertEqual +from flsim.optimizers.server_optimizers import ServerFTRLOptimizerConfig +from flsim.privacy.common import PrivacySetting +from flsim.privacy.privacy_engine import CummuNoiseEffTorch, CummuNoiseTorch +from flsim.servers.sync_ftrl_servers import SyncFTRLServerConfig +from flsim.utils.fl.common import FLModelParamUtils +from flsim.utils.test_utils import ( + linear_model, + SampleNet, + verify_models_equivalent_after_training, +) +from hydra.utils import instantiate + + +@dataclass +class MockClientUpdate: + deltas: List[float] + weights: List[float] + average: float + + +class TestSyncServer: + def _create_client_updates(self, num_clients) -> MockClientUpdate: + deltas = [float(i + 1) for i in range(num_clients)] + weights = [float(i + 1) for i in range(num_clients)] + average = float(np.average(deltas)) + return MockClientUpdate(deltas, weights, average) + + def _setup( + self, + efficient, + noise_multiplier, + clip_value, + total_users, + tree_completion, + restart_rounds, + ): + fl_model = SampleNet(linear_model(0)) + nonfl_model = SampleNet(linear_model(0)) + + optimizer = instantiate( + config=ServerFTRLOptimizerConfig(lr=1.0), + model=nonfl_model.fl_get_module(), + record_last_noise=True, + ) + + if efficient: + noise_gen = CummuNoiseEffTorch( + std=noise_multiplier * clip_value, + shapes=[p.shape for p in nonfl_model.fl_get_module().parameters()], + device="cpu", + seed=0, + ) + else: + noise_gen = CummuNoiseTorch( + std=noise_multiplier * clip_value, + shapes=[p.shape for p in nonfl_model.fl_get_module().parameters()], + device="cpu", + seed=0, + ) + + server = instantiate( + SyncFTRLServerConfig( + server_optimizer=ServerFTRLOptimizerConfig(lr=1.0), + privacy_setting=PrivacySetting( + noise_multiplier=noise_multiplier, + clipping_value=clip_value, + noise_seed=0, + ), + tree_completion=tree_completion, + efficient=efficient, + restart_rounds=restart_rounds, + ), + global_model=fl_model, + ) + + client_updates = self._create_client_updates(total_users) + return server, nonfl_model, fl_model, client_updates, noise_gen, optimizer + + @pytest.mark.parametrize( + "efficient", + [ + True, + False, + ], + ) + def test_ftrl_same_nonfl_server(self, efficient) -> None: + noise_multiplier = 0 + clip_value = 1000 + total_users = 1 + users_per_round = 1 + + ( + server, + nonfl_model, + fl_model, + client_updates, + noise_gen, + optimizer, + ) = self._setup( + efficient, + 
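# --- Editor's note ----------------------------------------------------------------------
# The equivalence tests here check that SyncFTRLServer matches driving ServerFTRLOptimizer
# plus a CummuNoise generator by hand. With noise_multiplier = 0 the tree noise is
# identically zero, so both paths reduce to plain FTRL on the averaged client delta, and
# the two models must match exactly after every round
# (verify_models_equivalent_after_training returns an empty error message).
# ----------------------------------------------------------------------------------------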
noise_multiplier, + clip_value, + total_users, + tree_completion=False, + restart_rounds=1000, + ) + + for _ in range(10): + server.init_round() + optimizer.zero_grad() + server.select_clients_for_training(total_users, users_per_round) + for delta, weight in zip(client_updates.deltas, client_updates.weights): + server.receive_update_from_client( + Message(model=SampleNet(linear_model(delta)), weight=weight) + ) + + FLModelParamUtils.set_gradient( + model=nonfl_model.fl_get_module(), + reference_gradient=linear_model(client_updates.average), + ) + noise = noise_gen() + # nonfl_model = init - lr * (grad + noise) + # = 0 - (1 + noise) + # = -(1 + noise) + optimizer.step(noise) + fl_noise = server.step() + assertEqual(sum([p.sum() for p in fl_noise]), sum([p.sum() for p in noise])) + error_msg = verify_models_equivalent_after_training(fl_model, nonfl_model) + assertEmpty(error_msg, msg=error_msg) + + @pytest.mark.parametrize( + "efficient", + [ + True, + False, + ], + ) + def test_ftrl_same_nonfl_server_with_restart(self, efficient): + noise_multiplier = 0 + clip_value = 1000 + total_users = 1 + users_per_round = 1 + + ( + server, + nonfl_model, + fl_model, + client_updates, + noise_gen, + optimizer, + ) = self._setup( + efficient, + noise_multiplier, + clip_value, + total_users, + tree_completion=True, + restart_rounds=1, + ) + + for _ in range(10): + server.init_round() + optimizer.zero_grad() + server.select_clients_for_training(total_users, users_per_round) + for delta, weight in zip(client_updates.deltas, client_updates.weights): + server.receive_update_from_client( + Message(model=SampleNet(linear_model(delta)), weight=weight) + ) + + FLModelParamUtils.set_gradient( + model=nonfl_model.fl_get_module(), + reference_gradient=linear_model(client_updates.average), + ) + noise = noise_gen() + # nonfl_model = init - lr * (grad + noise) + # = 0 - (1 + noise) + # = -(1 + noise) + optimizer.step(noise) + fl_noise = server.step() + assertEqual(sum([p.sum() for p in fl_noise]), sum([p.sum() for p in noise])) + error_msg = verify_models_equivalent_after_training(fl_model, nonfl_model) + assertEmpty(error_msg, msg=error_msg) diff --git a/flsim/utils/test_utils.py b/flsim/utils/test_utils.py index c0e5e245..6d6d11bd 100644 --- a/flsim/utils/test_utils.py +++ b/flsim/utils/test_utils.py @@ -565,3 +565,9 @@ def create_model_with_value(value) -> nn.Module: model = TwoFC() model.fill_all(value) return model + + +def linear_model(value) -> nn.Module: + model = Linear() + model.fill_all(value) + return model
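# --- Editor's note (worked example for the tests above; assumes linear_model/fill_all
# simply fill every parameter of the hypothetical Linear module with `value`) ------------
# With total_users = 1, a client delta of 1.0, lr = 1.0 and zero noise, the FTRL
# recurrence gives theta_t = theta_0 - grad_sum_t / lr = 0 - t, so after the 10 rounds in
# each test every parameter of both the FL and the non-FL model should equal -10.
def _expected_param_value(num_rounds: int, delta: float = 1.0, lr: float = 1.0) -> float:
    grad_sum, theta = 0.0, 0.0
    for _ in range(num_rounds):
        grad_sum += delta
        theta = 0.0 - grad_sum / lr  # theta_0 = 0, noise = 0
    return theta  # -> -10.0 for num_rounds = 10
# ----------------------------------------------------------------------------------------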