Commit ffbeebd ("update")
Koukyosyumei committed Oct 29, 2022 (1 parent: b75cce4)
Showing 5 changed files with 110 additions and 52 deletions.

This commit adds a server_side_update switch to FedAvg: when it is disabled, the server distributes the aggregated gradients instead of an updated state_dict, and each client applies that global gradient locally with its own FL optimizer. The Paillier wrapper is reworked to match, so clients upload encrypted gradients and decrypt only the aggregated global gradient they download.
src/aijack/collaborative/fedavg/client.py (43 additions & 3 deletions)

@@ -4,18 +4,52 @@
 
 from ..core import BaseClient
 from ..core.utils import GRADIENTS_TAG, PARAMETERS_TAG
+from ..optimizer import AdamFLOptimizer, SGDFLOptimizer
 
 
 class FedAvgClient(BaseClient):
-    def __init__(self, model, user_id=0, lr=0.1, send_gradient=True):
+    def __init__(
+        self,
+        model,
+        user_id=0,
+        lr=0.1,
+        send_gradient=True,
+        optimizer_type_for_global_grad="sgd",
+        server_side_update=True,
+        optimizer_kwargs_for_global_grad={},
+    ):
         super(FedAvgClient, self).__init__(model, user_id=user_id)
         self.lr = lr
         self.send_gradient = send_gradient
+        self.server_side_update = server_side_update
 
+        if not self.server_side_update:
+            self._setup_optimizer_for_global_grad(
+                optimizer_type_for_global_grad, **optimizer_kwargs_for_global_grad
+            )
+
         self.prev_parameters = []
         for param in self.model.parameters():
             self.prev_parameters.append(copy.deepcopy(param))
 
+        self.initialized = False
+
+    def _setup_optimizer_for_global_grad(self, optimizer_type, **kwargs):
+        if optimizer_type == "sgd":
+            self.optimizer_for_global_grad = SGDFLOptimizer(
+                self.model.parameters(), lr=self.lr, **kwargs
+            )
+        elif optimizer_type == "adam":
+            self.optimizer_for_global_grad = AdamFLOptimizer(
+                self.model.parameters(), lr=self.lr, **kwargs
+            )
+        elif optimizer_type == "none":
+            self.optimizer_for_global_grad = None
+        else:
+            raise NotImplementedError(
+                f"{optimizer_type} is not supported. You can specify `sgd`, `adam`, or `none`."
+            )
+
     def upload(self):
         if self.send_gradient:
             return self.upload_gradients()
@@ -31,8 +65,14 @@ def upload_gradients(self):
             gradients.append((prev_param - param) / self.lr)
         return gradients
 
-    def download(self, model_parameters):
-        self.model.load_state_dict(model_parameters)
+    def download(self, new_global_model):
+        if self.server_side_update or (not self.initialized):
+            self.model.load_state_dict(new_global_model)
+        else:
+            self.optimizer_for_global_grad.step(new_global_model)
+
+        if not self.initialized:
+            self.initialized = True
 
         self.prev_parameters = []
         for param in self.model.parameters():
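With this change, a client constructed with server_side_update=False applies the aggregated global gradient itself via an FL optimizer instead of loading a full state_dict. A minimal sketch of constructing such a client; Net is a placeholder torch.nn.Module, the import path is inferred from the file layout, and the kwargs are illustrative:

from aijack.collaborative.fedavg import FedAvgClient  # path assumed from src/ layout

client = FedAvgClient(
    Net(),                                  # any torch.nn.Module (placeholder)
    user_id=0,
    lr=0.1,
    send_gradient=True,                     # upload (prev - current) / lr
    server_side_update=False,               # expect aggregated gradients on download
    optimizer_type_for_global_grad="adam",  # "sgd", "adam", or "none"
    optimizer_kwargs_for_global_grad={},    # forwarded to AdamFLOptimizer
)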
src/aijack/collaborative/fedavg/server.py (17 additions & 7 deletions)

@@ -17,12 +17,14 @@ def __init__(
         server_id=0,
         lr=0.1,
         optimizer_type="sgd",
+        server_side_update=True,
         optimizer_kwargs={},
     ):
         super(FedAvgServer, self).__init__(clients, global_model, server_id=server_id)
         self.lr = lr
         self._setup_optimizer(optimizer_type, **optimizer_kwargs)
-        self.distribute()
+        self.server_side_update = server_side_update
+        self.distribute(force_send_model_state_dict=True)
 
     def _setup_optimizer(self, optimizer_type, **kwargs):
         if optimizer_type == "sgd":
@@ -66,17 +68,22 @@ def receive_local_parameters(self):
     def update_from_gradients(self, weight=None):
         if weight is None:
             weight = np.ones(self.num_clients) / self.num_clients
+        weight = weight.tolist()
 
-        aggregated_gradients = [
+        self.aggregated_gradients = [
             torch.zeros_like(params) for params in self.server_model.parameters()
         ]
-        len_gradients = len(aggregated_gradients)
+        len_gradients = len(self.aggregated_gradients)
 
         for i, gradients in enumerate(self.uploaded_gradients):
             for gradient_id in range(len_gradients):
-                aggregated_gradients[gradient_id] += weight[i] * gradients[gradient_id]
+                self.aggregated_gradients[gradient_id] = (
+                    gradients[gradient_id] * weight[i]
+                    + self.aggregated_gradients[gradient_id]
+                )
 
-        self.optimizer.step(aggregated_gradients)
+        if self.server_side_update:
+            self.optimizer.step(self.aggregated_gradients)
 
     def update_from_parameters(self, weight=None):
         if weight is None:
@@ -95,9 +102,12 @@ def update_from_parameters(self, weight=None):
 
         self.server_model.load_state_dict(averaged_params)
 
-    def distribute(self):
+    def distribute(self, force_send_model_state_dict=False):
         for client in self.clients:
-            client.download(self.server_model.state_dict())
+            if self.server_side_update or force_send_model_state_dict:
+                client.download(self.server_model.state_dict())
+            else:
+                client.download(self.aggregated_gradients)
 
 
 class MPIFedAVGServer(BaseServer):
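update_from_gradients now keeps the weighted sum in self.aggregated_gradients so distribute can forward it when the server does not step its own optimizer; the operand order gradients[gradient_id] * weight[i] also keeps the (possibly encrypted) gradient on the left so its own operator overloads apply, and weight.tolist() presumably ensures plain Python floats for that multiplication. A self-contained toy of the aggregation arithmetic, using plain tensors independent of AIJack:

import torch

# two clients, two parameter tensors each
uploaded_gradients = [
    [torch.tensor([1.0, 2.0]), torch.tensor([[0.5]])],  # client 0
    [torch.tensor([3.0, 4.0]), torch.tensor([[1.5]])],  # client 1
]
weight = [0.5, 0.5]  # uniform weights, as when weight is None

aggregated = [torch.zeros_like(g) for g in uploaded_gradients[0]]
for w, gradients in zip(weight, uploaded_gradients):
    for i, g in enumerate(gradients):
        aggregated[i] = g * w + aggregated[i]  # gradient stays on the left

print(aggregated)  # [tensor([2., 3.]), tensor([[1.]])]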
src/aijack/defense/paillier/fed_wrapper.py (12 additions & 10 deletions)

@@ -11,22 +11,24 @@ def __init__(self, *args, **kwargs):
 
         def upload_gradients(self):
             pt_grads = super().upload_gradients()
-            print("do enc")
             return [
                 PaillierTensor(
-                    np.vectorize(lambda x: pk.encrypt(x.detach().numpy()))(grad)
+                    np.vectorize(lambda x: pk.encrypt(x))(grad.detach().numpy())
                 )
                 for grad in pt_grads
             ]
 
-        def download(self, model_parameters):
-            decrypted_params = {}
-            for key, param in model_parameters.items():
-                if type(param) == PaillierTensor:
-                    decrypted_params[key] = param.decrypt2float(sk)
-                else:
-                    decrypted_params[key] = param
-            return super().download(decrypted_params)
+        def download(self, global_grad):
+            if not self.initialized:
+                return super().download(global_grad)
+            else:
+                decrypted_global_grad = []
+                for grad in global_grad:
+                    if type(grad) == PaillierTensor:
+                        decrypted_global_grad.append(grad.decrypt(sk))
+                    else:
+                        decrypted_global_grad.append(grad)
+                return super().download(decrypted_global_grad)
 
     return PaillierClientWrapper

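The wrapper relies on Paillier's additive homomorphism: the server can sum ciphertexts without the secret key, and only the clients decrypt the aggregate. A sketch of that round trip; pk.encrypt, ciphertext addition, and sk.decrypt2float all appear above, while PaillierKeyGenerator and its import path are assumptions:

import numpy as np
from aijack.defense import PaillierKeyGenerator  # assumed import path

pk, sk = PaillierKeyGenerator(512).generate_keypair()  # assumed API

g1 = np.array([0.1, -0.2])  # client 0's gradient
g2 = np.array([0.3, 0.4])   # client 1's gradient
ct1 = np.vectorize(lambda x: pk.encrypt(x))(g1)
ct2 = np.vectorize(lambda x: pk.encrypt(x))(g2)

ct_sum = ct1 + ct2  # server side: elementwise ciphertext addition, no plaintexts

agg = np.vectorize(lambda c: sk.decrypt2float(c))(ct_sum)  # client side
print(agg)  # approximately [0.4, 0.2]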
src/aijack/defense/paillier/torch_wrapper.py (24 additions & 4 deletions)

@@ -29,7 +29,7 @@ def __init__(self, paillier_array):
     def __repr__(self):
         return "PaillierTensor"
 
-    def decypt(self, sk):
+    def decrypt(self, sk):
         return torch.Tensor(
             np.vectorize(lambda x: sk.decrypt2float(x))(self._paillier_np_array)
         )
@@ -43,6 +43,9 @@ def tensor(self, sk=None):
     def numpy(self):
         return self._paillier_np_array
 
+    def detach(self):
+        return self
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
@@ -58,7 +61,7 @@ def add(input, other):
         if type(other) in [int, float]:
             return PaillierTensor(input._paillier_np_array + other)
         elif type(other) in [torch.Tensor, PaillierTensor]:
-            return PaillierTensor(input._paillier_np_array + other.numpy())
+            return PaillierTensor(input._paillier_np_array + other.detach().numpy())
         else:
             raise NotImplementedError(f"{type(other)} is not supported.")
 
@@ -67,7 +70,9 @@ def sub(input, other):
         if type(other) in [int, float]:
             return PaillierTensor(input._paillier_np_array + (-1) * other)
         elif type(other) in [torch.Tensor, PaillierTensor]:
-            return PaillierTensor(input._paillier_np_array + (-1) * other.numpy())
+            return PaillierTensor(
+                input._paillier_np_array + (-1) * other.detach().numpy()
+            )
         else:
             raise NotImplementedError(f"{type(other)} is not supported.")
 
@@ -76,15 +81,30 @@ def mul(input, other):
         if type(other) in [int, float]:
             return PaillierTensor(input._paillier_np_array * other)
         elif type(other) in [torch.Tensor, PaillierTensor]:
-            return PaillierTensor(input._paillier_np_array * other.numpy())
+            return PaillierTensor(input._paillier_np_array * other.detach().numpy())
         else:
             raise NotImplementedError(f"{type(other)} is not supported.")
 
     def __add__(self, other):
         return torch.add(self, other)
 
+    def __iadd__(self, other):
+        # return the new tensor so that `a += b` rebinds `a`; a bare
+        # `self = torch.add(self, other)` would make `a += b` yield None
+        return torch.add(self, other)
+
+    def __radd__(self, other):
+        return self.__add__(other)
+
     def __sub__(self, other):
         return torch.sub(self, other)
 
+    def __isub__(self, other):
+        # same as __iadd__: return the result instead of rebinding `self`
+        return torch.sub(self, other)
+
+    def __rsub__(self, other):
+        # other - self == -(self - other); negate via plaintext multiplication
+        return self.__sub__(other) * (-1)
+
     def __mul__(self, other):
         return torch.mul(self, other)
 
+    def __rmul__(self, other):
+        return self.__mul__(other)
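PaillierTensor is not a torch.Tensor subclass; it hooks into PyTorch through the __torch_function__ protocol, so calls such as torch.add(self, other) from the dunder methods above are routed to the registered handlers instead of PyTorch's kernels. A self-contained toy (not the actual PaillierTensor) showing that dispatch:

import torch

class LoggedTensor:
    # toy wrapper that intercepts torch functions, as PaillierTensor does
    def __init__(self, t):
        self._t = t

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"dispatched {func.__name__}")
        # unwrap the wrapper, run the real op, wrap the result again
        unwrapped = [a._t if isinstance(a, LoggedTensor) else a for a in args]
        return LoggedTensor(func(*unwrapped, **kwargs))

    def __add__(self, other):
        return torch.add(self, other)  # triggers __torch_function__

out = LoggedTensor(torch.ones(2)) + torch.ones(2)  # prints "dispatched add"
print(out._t)  # tensor([2., 2.])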
test/defense/paillier/test_paillier.py (14 additions & 28 deletions)

@@ -36,28 +36,28 @@ def test_paillier_torch():
     ct_3 = ct_1 + ct_2
 
     pt_1 = PaillierTensor([ct_1, ct_2, ct_3])
-    torch.testing.assert_allclose(
-        pt_1.decypt(sk), torch.Tensor([13, 0.5, 13.5]), atol=1e-5, rtol=1
+    torch.testing.assert_close(
+        pt_1.decrypt(sk), torch.Tensor([13, 0.5, 13.5]), atol=1e-5, rtol=1
     )
 
     pt_2 = pt_1 + torch.Tensor([0.4, 0.1, 0.2])
-    torch.testing.assert_allclose(
-        pt_2.decypt(sk), torch.Tensor([13.4, 0.6, 13.7]), atol=1e-5, rtol=1
+    torch.testing.assert_close(
+        pt_2.decrypt(sk), torch.Tensor([13.4, 0.6, 13.7]), atol=1e-5, rtol=1
     )
 
     pt_3 = pt_1 * torch.Tensor([1, 2.5, 0.5])
-    torch.testing.assert_allclose(
-        pt_3.decypt(sk), torch.Tensor([13, 1.25, 6.75]), atol=1e-5, rtol=1
+    torch.testing.assert_close(
+        pt_3.decrypt(sk), torch.Tensor([13, 1.25, 6.75]), atol=1e-5, rtol=1
     )
 
     pt_4 = pt_1 - torch.Tensor([0.7, 0.3, 0.6])
-    torch.testing.assert_allclose(
-        pt_4.decypt(sk), torch.Tensor([12.3, 0.2, 12.9]), atol=1e-5, rtol=1
+    torch.testing.assert_close(
+        pt_4.decrypt(sk), torch.Tensor([12.3, 0.2, 12.9]), atol=1e-5, rtol=1
     )
 
     pt_5 = pt_1 * 2
-    torch.testing.assert_allclose(
-        pt_5.decypt(sk), torch.Tensor([26, 1, 27]), atol=1e-5, rtol=1
+    torch.testing.assert_close(
+        pt_5.decrypt(sk), torch.Tensor([26, 1, 27]), atol=1e-5, rtol=1
     )
 
 
@@ -78,20 +78,10 @@ def test_paillier_fedavg():
     class Net(nn.Module):
         def __init__(self):
             super(Net, self).__init__()
-            self.conv = nn.Sequential(
-                nn.Conv2d(1, 32, 5),
-                nn.Sigmoid(),
-                nn.MaxPool2d(3, 3, 1),
-                nn.Conv2d(32, 64, 5),
-                nn.Sigmoid(),
-                nn.MaxPool2d(3, 3, 1),
-            )
-
-            self.lin = nn.Sequential(nn.Linear(256, 10))
+            self.lin = nn.Sequential(nn.Linear(28 * 28, 10))
 
         def forward(self, x):
-            x = self.conv(x)
-            x = x.reshape((-1, 256))
+            x = x.reshape((-1, 28 * 28))
             x = self.lin(x)
             return x
 
@@ -106,17 +96,13 @@ def forward(self, x):
     PaillierGradFedAvgClient = manager.attach(FedAvgClient)
 
     clients = [
-        PaillierGradFedAvgClient(
-            Net(),
-            user_id=i,
-            lr=lr,
-        )
+        PaillierGradFedAvgClient(Net(), user_id=i, lr=lr, server_side_update=False)
         for i in range(client_num)
     ]
     optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]
 
     global_model = Net()
-    server = FedAvgServer(clients, global_model, lr=lr)
+    server = FedAvgServer(clients, global_model, lr=lr, server_side_update=False)
 
     criterion = nn.CrossEntropyLoss()
 
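The remainder of the test (elided above) would run the actual federated rounds. A hedged sketch of such a loop under the new flow; trainloaders is hypothetical, and receive_local_gradients is assumed by symmetry with receive_local_parameters in server.py:

for _ in range(2):  # communication rounds
    for client, optimizer, loader in zip(clients, optimizers, trainloaders):
        for x, y in loader:  # local training on each client's data
            optimizer.zero_grad()
            loss = criterion(client(x), y)
            loss.backward()
            optimizer.step()
    server.receive_local_gradients()  # gather Paillier-encrypted gradients
    server.update_from_gradients()    # ciphertext-only aggregation; no server step
    server.distribute()               # clients decrypt and apply the global gradient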
