Add DP-FTRL to flsim
Summary:
## What?
* Adding DP-FTRL, along with an FTRL server optimizer, to FLSim
* The privacy engine is copied from https://github.com/google-research/DP-FTRL, most notably `CummuNoiseTorch` and `CummuNoiseEffTorch`
* The `SyncFTRLServer` wraps the tree-completion and tree-restart logic, similar to https://github.com/google-research/DP-FTRL/blob/main/main.py#L208 (a wiring sketch follows below)
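
A minimal wiring sketch (not part of this diff) of how the new pieces compose. The `config=` keyword, the toy model, and all numeric values are assumptions for illustration; in FLSim the `SyncFTRLServer` owns this loop and supplies the aggregated client update as the gradient.

```python
import torch
import torch.nn as nn

from flsim.optimizers.server_optimizers import (
    ServerFTRLOptimizer,
    ServerFTRLOptimizerConfig,
)
from flsim.privacy.privacy_engine import CummuNoiseTorch

model = nn.Linear(10, 2)  # stand-in for the FL global model
optimizer = ServerFTRLOptimizer(
    model=model,
    record_last_noise=True,  # keep the last noise around for tree restarts
    config=ServerFTRLOptimizerConfig(lr=1.0, momentum=0.9),  # assumed kwarg
)
noise_gen = CummuNoiseTorch(
    std=1.0,  # per-node noise std; derived from the privacy budget in practice
    shapes=[p.shape for p in model.parameters()],
    device="cpu",
)

for _ in range(4):  # FL rounds
    optimizer.zero_grad()
    # FLSim's aggregator would write the averaged, clipped client delta into
    # p.grad; zeros keep the sketch self-contained.
    for p in model.parameters():
        p.grad = torch.zeros_like(p)
    optimizer.step(noise_gen())  # FTRL step with this round's tree-aggregated noise
```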

### What is tree completion?
See Appendix D.3.1 of https://arxiv.org/pdf/2103.00039.pdf.
TLDR: Run extra "virtual steps" so that the binary tree is complete; the completed tree adds the smallest possible noise, at the expense of a little more privacy loss.
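
A hedged sketch of the completion step, using the `CummuNoiseTorch.proceed_until` semantics from this diff (the helper name `completed_tree_noise` is made up; the real logic lives in `SyncFTRLServer`):

```python
import math

def completed_tree_noise(noise_gen, real_steps):
    """Run virtual steps up to the next power of two and return that noise.

    At a power-of-two step the cumulative noise consists of a single tree
    node, i.e. the lowest-variance point to restart from. Returns None when
    the tree is already complete, so the caller can fall back to the noise
    recorded at the last real step.
    """
    target = 2 ** math.ceil(math.log2(max(real_steps, 1)))
    if target <= real_steps:
        return None
    return noise_gen.proceed_until(target)
```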

### What is tree restart?
See Appendix D.1 of https://arxiv.org/pdf/2103.00039.pdf.
TLDR: For multi-epoch training, the tree is restarted periodically (e.g. once per epoch); each restart begins a fresh tree, and the privacy cost composes across restarts. A sketch of the schedule follows below.
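
Continuing the sketches above (same `model`, `optimizer`, `noise_gen`, and `completed_tree_noise`), restart might be scheduled roughly like this; the period of 6 rounds is arbitrary and only chosen so the completion step actually runs:

```python
rounds_per_tree = 6   # arbitrary restart period for illustration
total_rounds = 18

for round_idx in range(total_rounds):
    optimizer.zero_grad()
    for p in model.parameters():
        p.grad = torch.zeros_like(p)  # placeholder for the aggregated client delta
    optimizer.step(noise_gen())
    if (round_idx + 1) % rounds_per_tree == 0:
        # Complete the current tree, fold its final noise into the gradient
        # sum, then start a fresh tree for the next period.
        optimizer.restart(completed_tree_noise(noise_gen, rounds_per_tree))
        noise_gen = CummuNoiseTorch(1.0, [p.shape for p in model.parameters()], "cpu")
```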

## Why?
We're using this code to evaluate the utility of DP-FTRL compared to DP-FedAvg.

## Notebook to compute epsilon
https://fburl.com/anp/6bms12ap
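
The notebook itself is internal. As a rough, hedged back-of-the-envelope only (not the notebook's method): if each client contributes at most once per tree, its data touches at most `ceil(log2(steps_per_tree)) + 1` tree nodes per tree, so the whole run can be treated as a single Gaussian mechanism and converted from RDP to (eps, delta)-DP with the standard bound. Everything below, including the numbers, is an assumption for illustration.

```python
import math

def approx_dp_ftrl_epsilon(noise_multiplier, steps_per_tree, num_trees, delta):
    """Loose (eps, delta) estimate for tree-aggregated Gaussian noise.

    Assumes each client contributes at most once per tree; the L2 sensitivity
    is then sqrt(num_trees * (ceil(log2(steps_per_tree)) + 1)) clip norms.
    """
    nodes_touched = math.ceil(math.log2(steps_per_tree)) + 1
    sensitivity_sq = num_trees * nodes_touched
    best_eps = float("inf")
    for alpha in (1 + i / 10.0 for i in range(1, 1000)):  # sweep RDP orders
        rdp = alpha * sensitivity_sq / (2 * noise_multiplier**2)
        best_eps = min(best_eps, rdp + math.log(1 / delta) / (alpha - 1))
    return best_eps

# Illustrative numbers only:
print(approx_dp_ftrl_epsilon(noise_multiplier=4.0, steps_per_tree=128, num_trees=5, delta=1e-6))
```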

Reviewed By: pierrestock

Differential Revision: D36358239

fbshipit-source-id: e05f984982478c48111ce01ffe4e76b72c88e0aa
John Nguyen authored and facebook-github-bot committed Jun 30, 2022
1 parent 6033eac commit 1ae9520
Showing 7 changed files with 678 additions and 3 deletions.
8 changes: 8 additions & 0 deletions flsim/optimizers/__init__.py
@@ -30,7 +30,9 @@
FedAvgWithLROptimizerConfig,
FedLAMBOptimizerConfig,
FedLARSOptimizerConfig,
ServerFTRLOptimizerConfig,
)

from .sync_aggregators import (
FedAdamSyncAggregatorConfig,
FedAvgSyncAggregatorConfig,
@@ -178,3 +180,9 @@
node=FedLAMBOptimizerConfig,
group="server_optimizer",
)

ConfigStore.instance().store(
name="base_ftrl_optimizer",
node=ServerFTRLOptimizerConfig,
group="server_optimizer",
)
110 changes: 108 additions & 2 deletions flsim/optimizers/server_optimizers.py
@@ -28,7 +28,7 @@
import torch
import torch.nn as nn
from flsim.optimizers.layerwise_optimizers import LAMB, LARS
from flsim.utils.config_utils import fullclassname, init_self_cfg, is_target
from flsim.utils.config_utils import fullclassname, init_self_cfg
from omegaconf import MISSING


@@ -48,7 +48,7 @@ def _set_defaults_in_cfg(cls, cfg):

@abc.abstractmethod
@torch.no_grad()
def step(self, closure):
def step(self, closure, noise=None):
r"""Performs a single optimization step (parameter update).
Args:
@@ -215,6 +215,105 @@ def zero_grad(self, set_to_none: bool = False):
return LAMB.zero_grad(self, set_to_none)


class ServerFTRLOptimizer(IServerOptimizer, torch.optim.Optimizer):
"""
:param params: parameter groups
:param momentum: if non-zero, use DP-FTRLM
:param record_last_noise: whether to record the last noise, for the tree-completion trick.
"""

def __init__(self, *, model: nn.Module, record_last_noise: bool, **kwargs) -> None:
init_self_cfg(
self,
component_class=__class__,
config_class=ServerFTRLOptimizerConfig,
**kwargs,
)

IServerOptimizer.__init__(self, model=model, **kwargs)
# pyre-ignore[28]
torch.optim.Optimizer.__init__(self, params=model.parameters(), defaults={})
# pyre-ignore[16]
self.momentum = self.cfg.momentum
self.lr = self.cfg.lr
self.record_last_noise = record_last_noise

def __setstate__(self, state):
super(ServerFTRLOptimizer, self).__setstate__(state)

def zero_grad(self, set_to_none: bool = False):
return torch.optim.Optimizer.zero_grad(self, set_to_none)

@torch.no_grad()
def step(self, noise, closure=None):
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p, nz in zip(group["params"], noise):
if p.grad is None:
continue
d_p = p.grad

param_state = self.state[p]

if len(param_state) == 0:
param_state["grad_sum"] = torch.zeros_like(
d_p, memory_format=torch.preserve_format
)
param_state["model_sum"] = p.detach().clone(
memory_format=torch.preserve_format
) # just record the initial model
param_state["momentum"] = torch.zeros_like(
p, memory_format=torch.preserve_format
)
if self.record_last_noise:
param_state["last_noise"] = torch.zeros_like(
p, memory_format=torch.preserve_format
) # record the last noise needed, in order for restarting

gs, ms = param_state["grad_sum"], param_state["model_sum"]
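                # FTRL closed-form update: move from the initial model by
                # -(grad_sum + tree noise) / lr (with optional momentum below).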
if self.momentum == 0:
gs.add_(d_p)
p.copy_(ms + (-gs - nz) / self.lr)
else:
gs.add_(d_p)
param_state["momentum"].mul_(self.momentum).add_(gs + nz)
p.copy_(ms - param_state["momentum"] / self.lr)
if self.record_last_noise:
param_state["last_noise"].copy_(nz)
return loss

@torch.no_grad()
def restart(self, last_noise=None):
"""
Restart the tree.
:param last_noise: the last noise to be added. If none, use the last noise recorded.
"""
assert last_noise is not None or self.record_last_noise
for group in self.param_groups:
if last_noise is None:
for p in group["params"]:
if p.grad is None:
continue
param_state = self.state[p]
if len(param_state) == 0:
continue
param_state["grad_sum"].add_(
param_state["last_noise"]
) # add the last piece of noise to the current gradient sum
else:
for p, nz in zip(group["params"], last_noise):
if p.grad is None:
continue
param_state = self.state[p]
if len(param_state) == 0:
continue
param_state["grad_sum"].add_(nz)


@dataclass
class ServerOptimizerConfig:
_target_: str = MISSING
@@ -259,3 +358,10 @@ class FedLAMBOptimizerConfig(ServerOptimizerConfig):
beta1: float = 0.9
beta2: float = 0.999
eps: float = 1e-8


@dataclass
class ServerFTRLOptimizerConfig(ServerOptimizerConfig):
_target_: str = fullclassname(ServerFTRLOptimizer)
lr: float = 0.001
momentum: float = 0.0
176 changes: 175 additions & 1 deletion flsim/privacy/privacy_engine.py
@@ -12,7 +12,8 @@
import logging
import os
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
from typing import Any, Optional

import torch
from flsim.common.logger import Logger
@@ -21,6 +22,12 @@
from torch import nn


@dataclass
class TreeNode:
height: int
value: Any


class PrivacyEngineNotAttachedException(Exception):
"""
Exception class to be thrown from User Privacy Engine in case
@@ -165,3 +172,170 @@ def get_privacy_spent(self, target_delta: Optional[float] = None):
"the set of alpha orders."
)
return PrivacyBudget(eps, opt_alpha, target_delta)


class CummuNoiseTorch:
@torch.no_grad()
def __init__(self, std, shapes, device, test_mode=False, seed=None):
"""
:param std: standard deviation of the noise
:param shapes: shapes of the noise, which is basically shape of the gradients
:param device: device for pytorch tensor
:param test_mode: if in test mode, noise will be 1 in each node of the tree
"""
seed = (
seed
if seed is not None
else int.from_bytes(os.urandom(8), byteorder="big", signed=True)
)
self.std = std
self.shapes = shapes
self.device = device
self.step = 0
self.binary = [0]
self.noise_sum = [torch.zeros(shape).to(self.device) for shape in shapes]
self.recorded = [[torch.zeros(shape).to(self.device) for shape in shapes]]
torch.cuda.manual_seed_all(seed)
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(seed)
self.test_mode = test_mode

@torch.no_grad()
def __call__(self):
"""
:return: the noise to be added by DP-FTRL
"""
self.step += 1
if self.std <= 0 and not self.test_mode:
return self.noise_sum

idx = 0
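        # Binary-counter update over tree levels: clear each completed level and
        # remove its recorded noise from the running sum before drawing a new node.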
while idx < len(self.binary) and self.binary[idx] == 1:
self.binary[idx] = 0
for ns, re in zip(self.noise_sum, self.recorded[idx]):
ns -= re
idx += 1
if idx >= len(self.binary):
self.binary.append(0)
self.recorded.append(
[torch.zeros(shape).to(self.device) for shape in self.shapes]
)

for shape, ns, re in zip(self.shapes, self.noise_sum, self.recorded[idx]):
if not self.test_mode:
n = torch.normal(
0, self.std, shape, generator=self.generator, device=self.device
)
else:
n = torch.ones(shape).to(self.device)
ns += n
re.copy_(n)

self.binary[idx] = 1
return self.noise_sum

@torch.no_grad()
def proceed_until(self, step_target):
"""
Proceed until the step_target-th step. This is for the binary tree completion trick.
:return: the noise to be added by DP-FTRL
"""
if self.step >= step_target:
raise ValueError(f"Already reached {step_target}.")
while self.step < step_target:
noise_sum = self.__call__()
return noise_sum


class CummuNoiseEffTorch:
"""
The tree aggregation protocol with the trick in Honaker, "Efficient Use of Differentially Private Binary Trees", 2015
"""

@torch.no_grad()
def __init__(self, std, shapes, device, seed, test_mode=False):
"""
:param std: standard deviation of the noise
:param shapes: shapes of the noise, which is basically shape of the gradients
:param device: device for pytorch tensor
"""
seed = (
seed
if seed is not None
else int.from_bytes(os.urandom(8), byteorder="big", signed=True)
)
self.test_mode = test_mode

self.std = std
self.shapes = shapes
self.device = device
torch.cuda.manual_seed_all(seed)
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(seed)
self.step = 0
self.noise_sum = [torch.zeros(shape).to(self.device) for shape in shapes]
self.stack = []

@torch.no_grad()
def get_noise(self):
return [
torch.normal(
0, self.std, shape, generator=self.generator, device=self.device
)
for shape in self.shapes
]

@torch.no_grad()
def push(self, elem):
for i in range(len(self.shapes)):
self.noise_sum[i] += elem.value[i] / (2.0 - 1 / 2**elem.height)
self.stack.append(elem)

@torch.no_grad()
def pop(self):
elem = self.stack.pop()
for i in range(len(self.shapes)):
self.noise_sum[i] -= elem.value[i] / (2.0 - 1 / 2**elem.height)

@torch.no_grad()
def __call__(self):
"""
:return: the noise to be added by DP-FTRL
"""
self.step += 1

# add new element to the stack
self.push(TreeNode(0, self.get_noise()))

# pop the stack
while len(self.stack) >= 2 and self.stack[-1].height == self.stack[-2].height:
# create new element
left_value, right_value = self.stack[-2].value, self.stack[-1].value
new_noise = self.get_noise()
new_elem = TreeNode(
self.stack[-1].height + 1,
[
x + (y + z) / 2
for x, y, z in zip(new_noise, left_value, right_value)
],
)

# pop the stack, update sum
self.pop()
self.pop()

# append to the stack, update sum
self.push(new_elem)
return self.noise_sum

@torch.no_grad()
def proceed_until(self, step_target):
"""
Proceed until the step_target-th step. This is for the binary tree completion trick.
:return: the noise to be added by DP-FTRL
"""
if self.step >= step_target:
raise ValueError(f"Already reached {step_target}.")
while self.step < step_target:
noise_sum = self.__call__()
return noise_sum
7 changes: 7 additions & 0 deletions flsim/servers/__init__.py
@@ -8,6 +8,7 @@
from hydra.core.config_store import ConfigStore # @manual

from .sync_dp_servers import SyncDPSGDServerConfig
from .sync_ftrl_servers import SyncFTRLServerConfig
from .sync_mime_servers import SyncMimeServerConfig
from .sync_mimelite_servers import SyncMimeLiteServerConfig
from .sync_secagg_servers import SyncSecAggServerConfig, SyncSecAggSQServerConfig
@@ -54,3 +55,9 @@
node=SyncSecAggSQServerConfig,
group="server",
)

ConfigStore.instance().store(
name="base_sync_ftrl_server",
node=SyncFTRLServerConfig,
group="server",
)