Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Koukyosyumei committed Oct 29, 2022
1 parent 316fcc0 commit b75cce4
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/aijack/collaborative/optimizer/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ def step(self, grads):
grads (List[torch.Tensor]): list of gradients
"""
for param, grad in zip(self.parameters, grads):
param.data -= self.lr * (grad + self.weight_decay * param.data)
if self.weight_decay == 0.0:
param.data -= self.lr * grad
else:
param.data -= self.lr * (grad + self.weight_decay * param.data)
self.t += 1
1 change: 1 addition & 0 deletions src/aijack/defense/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
from .ckks import CKKSEncoder, CKKSEncrypter, CKKSPlaintext # noqa: F401
from .dp import GeneralMomentAccountant, PrivacyManager # noqa: F401
from .mid import VIB, KL_between_normals, mib_loss # noqa:F401
from .paillier import PaillierGradientClientManager, PaillierKeyGenerator # noqa: F401
from .purifier import Purifier_Cifar10, PurifierLoss # noqa: F401
from .soteria import SoteriaManager, attach_soteria_to_client # noqa: F401
1 change: 1 addition & 0 deletions src/aijack/defense/paillier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
PaillierSecretKey,
)

from .fed_wrapper import PaillierGradientClientManager # noqa: F401
from .torch_wrapper import PaillierTensor # noqa: F401
42 changes: 42 additions & 0 deletions src/aijack/defense/paillier/fed_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import numpy as np

from ...manager import BaseManager
from .torch_wrapper import PaillierTensor


def attach_paillier_to_client_for_encrypted_grad(cls, pk, sk):
    """Wrap a federated-learning client class with Paillier encryption.

    Gradients returned by ``upload_gradients`` are encrypted element-wise
    with the public key ``pk`` before upload, and encrypted parameters
    received in ``download`` are decrypted with the secret key ``sk``.

    Args:
        cls: client class to extend (e.g. ``FedAvgClient``).
        pk: Paillier public key exposing ``encrypt``.
        sk: Paillier secret key accepted by ``PaillierTensor.decrypt2float``.

    Returns:
        A subclass of ``cls`` with encrypted upload / decrypted download.
    """

    class PaillierClientWrapper(cls):
        def __init__(self, *args, **kwargs):
            super(PaillierClientWrapper, self).__init__(*args, **kwargs)

        def upload_gradients(self):
            """Encrypt each plaintext gradient element-wise before upload."""
            pt_grads = super().upload_gradients()
            # NOTE(review): each element handed to the lambda is assumed to
            # expose ``detach`` (i.e. a 0-dim torch tensor) — confirm against
            # the return shape of the base client's ``upload_gradients``.
            return [
                PaillierTensor(
                    np.vectorize(lambda x: pk.encrypt(x.detach().numpy()))(grad)
                )
                for grad in pt_grads
            ]

        def download(self, model_parameters):
            """Decrypt any encrypted parameters, then delegate to ``cls``."""
            decrypted_params = {}
            for key, param in model_parameters.items():
                # isinstance (not ``type(...) ==``) so subclasses of
                # PaillierTensor are also decrypted.
                if isinstance(param, PaillierTensor):
                    decrypted_params[key] = param.decrypt2float(sk)
                else:
                    decrypted_params[key] = param
            return super().download(decrypted_params)

    return PaillierClientWrapper


class PaillierGradientClientManager(BaseManager):
    """Manager that attaches Paillier gradient encryption to client classes.

    The constructor arguments (typically the public/secret key pair) are
    stored verbatim and forwarded when ``attach`` wraps a client class.
    """

    def __init__(self, *args, **kwargs):
        # Remember the construction arguments for later forwarding.
        self.args, self.kwargs = args, kwargs

    def attach(self, cls):
        """Return ``cls`` extended with encrypted upload / decrypted download."""
        wrapped_cls = attach_paillier_to_client_for_encrypted_grad(
            cls, *self.args, **self.kwargs
        )
        return wrapped_cls
4 changes: 2 additions & 2 deletions src/aijack/defense/paillier/torch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def add(input, other):
@implements(torch.sub)
def sub(input, other):
if type(other) in [int, float]:
return PaillierTensor(input._paillier_np_array + (-1 * other))
return PaillierTensor(input._paillier_np_array + (-1) * other)
elif type(other) in [torch.Tensor, PaillierTensor]:
return PaillierTensor(input._paillier_np_array + (-1 * other.numpy()))
return PaillierTensor(input._paillier_np_array + (-1) * other.numpy())
else:
raise NotImplementedError(f"{type(other)} is not supported.")

Expand Down
83 changes: 83 additions & 0 deletions test/defense/paillier/test_paillier.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,86 @@ def test_paillier_torch():
torch.testing.assert_allclose(
pt_5.decypt(sk), torch.Tensor([26, 1, 27]), atol=1e-5, rtol=1
)


def test_pailier_fedavg():
    """End-to-end check that FedAvg training converges with Paillier-encrypted
    gradient uploads.

    Two clients train a small CNN on a demo MNIST batch for two epochs;
    the per-epoch loss averaged over clients must decrease.
    """
    import torch
    import torch.nn as nn
    import torch.optim as optim

    from aijack.collaborative import FedAvgClient, FedAvgServer
    from aijack.defense import PaillierGradientClientManager, PaillierKeyGenerator

    torch.manual_seed(0)

    lr = 0.01
    epochs = 2
    client_num = 2

    class Net(nn.Module):
        """Small conv net sized so the flattened features are 256-dim."""

        def __init__(self):
            super(Net, self).__init__()
            self.conv = nn.Sequential(
                nn.Conv2d(1, 32, 5),
                nn.Sigmoid(),
                nn.MaxPool2d(3, 3, 1),
                nn.Conv2d(32, 64, 5),
                nn.Sigmoid(),
                nn.MaxPool2d(3, 3, 1),
            )

            self.lin = nn.Sequential(nn.Linear(256, 10))

        def forward(self, x):
            x = self.conv(x)
            x = x.reshape((-1, 256))
            x = self.lin(x)
            return x

    # 64-bit keys keep the (slow) Paillier arithmetic fast enough for a test.
    keygenerator = PaillierKeyGenerator(64)
    pk, sk = keygenerator.generate_keypair()

    x = torch.load("test/demodata/demo_mnist_x.pt")
    x.requires_grad = True
    y = torch.load("test/demodata/demo_mnist_y.pt")

    # Wrap the plain FedAvg client so uploaded gradients are encrypted.
    manager = PaillierGradientClientManager(pk, sk)
    PaillierGradFedAvgClient = manager.attach(FedAvgClient)

    clients = [
        PaillierGradFedAvgClient(
            Net(),
            user_id=i,
            lr=lr,
        )
        for i in range(client_num)
    ]
    optimizers = [optim.SGD(client.parameters(), lr=lr) for client in clients]

    global_model = Net()
    server = FedAvgServer(clients, global_model, lr=lr)

    criterion = nn.CrossEntropyLoss()

    loss_log = []
    for _ in range(epochs):
        temp_loss = 0
        for client_idx in range(client_num):
            client = clients[client_idx]
            optimizer = optimizers[client_idx]

            optimizer.zero_grad()
            client.zero_grad()

            outputs = client(x)
            loss = criterion(outputs, y.to(torch.int64))
            client.backward(loss)
            # Accumulate each client's contribution to the mean epoch loss
            # (a plain assignment here would drop all but the last client).
            temp_loss += loss.item() / client_num

            optimizer.step()

        loss_log.append(temp_loss)

        # Aggregate the encrypted gradients into the global model each epoch.
        server.action(use_gradients=True)

    assert loss_log[0] > loss_log[1]

0 comments on commit b75cce4

Please sign in to comment.