Showing 135 changed files with 5,275 additions and 0 deletions.
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

data/
.idea/
Empty file.
Empty file.
@@ -0,0 +1,163 @@
import random
import json
from collections import defaultdict
from pathlib import Path

import numpy as np
import torchvision.transforms as transforms
import torch.utils.data
from torchvision.datasets import MNIST, CIFAR10, CIFAR100


def get_datasets(data_name, dataroot, normalize=True, val_size=10000):

    if data_name == 'mnist':
        normalization = transforms.Normalize(
            (0.1307,), (0.3081,)
        )
        data_obj = MNIST
    elif data_name == 'cifar10':
        normalization = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        data_obj = CIFAR10
    elif data_name == 'cifar100':
        normalization = transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762))
        data_obj = CIFAR100
    else:
        raise ValueError("choose data_name from ['mnist', 'cifar10', 'cifar100']")

    trans = [transforms.ToTensor()]

    if normalize:
        trans.append(normalization)

    transform = transforms.Compose(trans)

    dataset = data_obj(
        dataroot,
        train=True,
        download=True,
        transform=transform
    )

    test_set = data_obj(
        dataroot,
        train=False,
        download=True,
        transform=transform
    )

    train_size = len(dataset) - val_size
    train_set, val_set = torch.utils.data.random_split(dataset, [train_size, val_size])

    return train_set, val_set, test_set


def get_num_classes_samples(dataset):
    # ---------------#
    # Extract labels #
    # ---------------#
    if isinstance(dataset, torch.utils.data.Subset):
        if isinstance(dataset.dataset.targets, list):
            data_labels_list = np.array(dataset.dataset.targets)[dataset.indices]
        else:
            data_labels_list = dataset.dataset.targets[dataset.indices]
    else:
        if isinstance(dataset.targets, list):
            data_labels_list = np.array(dataset.targets)
        else:
            data_labels_list = dataset.targets
    classes, num_samples = np.unique(data_labels_list, return_counts=True)
    num_classes = len(classes)
    return num_classes, num_samples, data_labels_list


def classes_per_node_dirichlet(dataset, num_users, num_gen_users, alpha):
    num_classes, _, _ = get_num_classes_samples(dataset)

    # create distribution for each client
    # (the first num_gen_users clients get a smaller concentration, 0.1 * alpha,
    # so their label distributions come out more skewed)
    prob_array = []
    alpha_list = [[alpha if i >= num_gen_users else 0.1 * alpha for _ in range(num_classes)] for i in range(num_users)]
    for i in range(num_users):
        prob_array.append(np.random.dirichlet(alpha_list[i], 1).reshape(-1))

    # normalizing across clients, so each class's probabilities sum to 1
    prob_array = np.array(prob_array)
    prob_array /= prob_array.sum(axis=0)

    class_partitions = defaultdict(list)
    cls_list = [i for i in range(num_classes)]
    for i in range(num_users):
        class_partitions['class'].append(cls_list)
        class_partitions['prob'].append(prob_array[i, :])

    return class_partitions


def gen_classes_per_node(dataset, num_users, classes_per_user=2, high_prob=0.6, low_prob=0.4):

    num_classes, num_samples, _ = get_num_classes_samples(dataset)

    # -------------------------------------------#
    # Divide classes + num samples for each user #
    # -------------------------------------------#
    # assert (classes_per_user * num_users) % num_classes == 0, "equal classes appearance is needed"
    count_per_class = (classes_per_user * num_users) // num_classes + 1
    class_dict = {}
    for i in range(num_classes):
        probs = np.random.uniform(low_prob, high_prob, size=count_per_class)
        probs_norm = (probs / probs.sum()).tolist()
        class_dict[i] = {'count': count_per_class, 'prob': probs_norm}
    class_partitions = defaultdict(list)
    for i in range(num_users):
        c = []
        for _ in range(classes_per_user):
            # greedily pick among the classes with the most remaining assignment slots
            class_counts = [class_dict[i]['count'] for i in range(num_classes)]
            max_class_counts = np.where(np.array(class_counts) == max(class_counts))[0]
            c.append(np.random.choice(max_class_counts))
            class_dict[c[-1]]['count'] -= 1
        class_partitions['class'].append(c)
        class_partitions['prob'].append([class_dict[i]['prob'].pop() for i in c])
    return class_partitions


def gen_data_split(dataset, num_users, class_partitions):

    num_classes, num_samples, data_labels_list = get_num_classes_samples(dataset)

    # -------------------------- #
    # Create class index mapping #
    # -------------------------- #
    data_class_idx = {i: np.where(data_labels_list == i)[0] for i in range(num_classes)}

    # --------- #
    # Shuffling #
    # --------- #
    for data_idx in data_class_idx.values():
        random.shuffle(data_idx)

    # ------------------------------ #
    # Assigning samples to each user #
    # ------------------------------ #
    user_data_idx = [[] for i in range(num_users)]
    for usr_i in range(num_users):
        for c, p in zip(class_partitions['class'][usr_i], class_partitions['prob'][usr_i]):
            end_idx = int(num_samples[c] * p)
            user_data_idx[usr_i].extend(data_class_idx[c][:end_idx])
            data_class_idx[c] = data_class_idx[c][end_idx:]

    return user_data_idx


def gen_random_loaders(data_name, data_path, num_users, bz, classes_per_user):
    loader_params = {"batch_size": bz, "shuffle": False, "pin_memory": True, "num_workers": 0}
    dataloaders = []
    datasets = get_datasets(data_name, data_path, normalize=True)
    for i, d in enumerate(datasets):
        if i == 0:
            # class partitions are drawn once, on the train split, and reused for val/test
            cls_partitions = gen_classes_per_node(d, num_users, classes_per_user)
            loader_params['shuffle'] = True
        usr_subset_idx = gen_data_split(d, num_users, cls_partitions)
        subsets = list(map(lambda x: torch.utils.data.Subset(d, x), usr_subset_idx))
        dataloaders.append(list(map(lambda x: torch.utils.data.DataLoader(x, **loader_params), subsets)))

    return dataloaders
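
A minimal usage sketch (illustrative, assuming CIFAR-10 can be downloaded to ./data): gen_random_loaders returns three lists of DataLoaders, one per split (train, validation, test), each indexed by client.

train_loaders, val_loaders, test_loaders = gen_random_loaders(
    data_name='cifar10',
    data_path='./data',
    num_users=50,
    bz=64,
    classes_per_user=2,
)
images, labels = next(iter(train_loaders[0]))  # one batch from client 0
print(images.shape)  # torch.Size([64, 3, 32, 32]), assuming client 0 holds >= 64 samples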
@@ -0,0 +1,158 @@
import random
from typing import List
from abc import abstractmethod

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils import spectral_norm


class DNNHyper(nn.Module):
    # Hypernetwork: maps a client embedding to the weights of a
    # 784 -> target_hidden -> out_dim fully connected target network.
    def __init__(self, n_nodes, embedding_dim, hidden_dim=100, target_hidden=100, out_dim=10, spec_norm=False):
        super().__init__()

        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)
        self.hidden_dim = hidden_dim
        self.target_hidden = target_hidden
        self.out_dim = out_dim

        self.mlp = nn.Sequential(
            spectral_norm(nn.Linear(embedding_dim, hidden_dim)) if spec_norm else nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(inplace=True),
            spectral_norm(nn.Linear(hidden_dim, hidden_dim)) if spec_norm else nn.Linear(hidden_dim, hidden_dim),
        )

        self.l1_weights = nn.Linear(self.hidden_dim, self.target_hidden * 784)
        self.l1_bias = nn.Linear(self.hidden_dim, self.target_hidden)
        self.l2_weights = nn.Linear(self.hidden_dim, self.target_hidden * self.out_dim)
        self.l2_bias = nn.Linear(self.hidden_dim, self.out_dim)
        if spec_norm:
            self.l1_weights = spectral_norm(self.l1_weights)
            self.l1_bias = spectral_norm(self.l1_bias)
            self.l2_weights = spectral_norm(self.l2_weights)
            self.l2_bias = spectral_norm(self.l2_bias)

    def forward(self, idx):
        emd = self.embeddings(idx)
        features = self.mlp(emd)

        weights = {
            "fc1.weight": self.l1_weights(features).view(self.target_hidden, 784),
            "fc1.bias": self.l1_bias(features).view(-1),
            "fc2.weight": self.l2_weights(features).view(self.out_dim, self.target_hidden),
            "fc2.bias": self.l2_bias(features).view(-1)
        }
        return weights


class DNNTarget(nn.Module):
    # Target network evaluated with externally supplied weights (see DNNHyper);
    # it holds no parameters of its own.
    def __init__(self, hidden_dim=100, output_dim=10):
        super(DNNTarget, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

    def forward(self, x, weights):
        x = x.view(x.shape[0], -1)
        x = F.linear(x, weights['fc1.weight'], weights['fc1.bias'])
        output = F.linear(x, weights['fc2.weight'], weights['fc2.bias'])
        return output


class DNNTargetLook(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=100, output_dim=10):
        super(DNNTargetLook, self).__init__()
        # define network layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class CNNHyper(nn.Module):
    # Hypernetwork for a LeNet-style CNN target: two 5x5 conv layers followed
    # by three fully connected layers (120 -> 84 -> out_dim).
    def __init__(
            self, n_nodes, embedding_dim, in_channels=3, out_dim=10, n_kernels=16, hidden_dim=100,
            spec_norm=False, n_hidden=1):
        super().__init__()

        self.in_channels = in_channels
        self.out_dim = out_dim
        self.n_kernels = n_kernels
        self.embeddings = nn.Embedding(num_embeddings=n_nodes, embedding_dim=embedding_dim)

        layers = [
            spectral_norm(nn.Linear(embedding_dim, hidden_dim)) if spec_norm else nn.Linear(embedding_dim, hidden_dim),
        ]
        for _ in range(n_hidden):
            layers.append(nn.ReLU(inplace=True))
            layers.append(
                spectral_norm(nn.Linear(hidden_dim, hidden_dim)) if spec_norm else nn.Linear(hidden_dim, hidden_dim),
            )

        self.mlp = nn.Sequential(*layers)

        self.c1_weights = nn.Linear(hidden_dim, self.n_kernels * self.in_channels * 5 * 5)
        self.c1_bias = nn.Linear(hidden_dim, self.n_kernels)
        self.c2_weights = nn.Linear(hidden_dim, 2 * self.n_kernels * self.n_kernels * 5 * 5)
        self.c2_bias = nn.Linear(hidden_dim, 2 * self.n_kernels)
        self.l1_weights = nn.Linear(hidden_dim, 120 * 2 * self.n_kernels * 5 * 5)
        self.l1_bias = nn.Linear(hidden_dim, 120)
        self.l2_weights = nn.Linear(hidden_dim, 84 * 120)
        self.l2_bias = nn.Linear(hidden_dim, 84)
        self.l3_weights = nn.Linear(hidden_dim, self.out_dim * 84)
        self.l3_bias = nn.Linear(hidden_dim, self.out_dim)

        if spec_norm:
            self.c1_weights = spectral_norm(self.c1_weights)
            self.c1_bias = spectral_norm(self.c1_bias)
            self.c2_weights = spectral_norm(self.c2_weights)
            self.c2_bias = spectral_norm(self.c2_bias)
            self.l1_weights = spectral_norm(self.l1_weights)
            self.l1_bias = spectral_norm(self.l1_bias)
            self.l2_weights = spectral_norm(self.l2_weights)
            self.l2_bias = spectral_norm(self.l2_bias)
            self.l3_weights = spectral_norm(self.l3_weights)
            self.l3_bias = spectral_norm(self.l3_bias)

    def forward(self, idx):
        emd = self.embeddings(idx)
        features = self.mlp(emd)

        weights = {
            "conv1.weight": self.c1_weights(features).view(self.n_kernels, self.in_channels, 5, 5),
            "conv1.bias": self.c1_bias(features).view(-1),
            "conv2.weight": self.c2_weights(features).view(2 * self.n_kernels, self.n_kernels, 5, 5),
            "conv2.bias": self.c2_bias(features).view(-1),
            "fc1.weight": self.l1_weights(features).view(120, 2 * self.n_kernels * 5 * 5),
            "fc1.bias": self.l1_bias(features).view(-1),
            "fc2.weight": self.l2_weights(features).view(84, 120),
            "fc2.bias": self.l2_bias(features).view(-1),
            "fc3.weight": self.l3_weights(features).view(self.out_dim, 84),
            "fc3.bias": self.l3_bias(features).view(-1),
        }
        return weights


class CNNTargetLook(nn.Module):
    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super(CNNTargetLook, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, 2 * n_kernels, 5)
        self.fc1 = nn.Linear(2 * n_kernels * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, out_dim)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
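
An illustrative pairing of the hypernetwork with its functional target (the embedding_dim and client index below are arbitrary): DNNHyper maps a client index to a dict of weights, which DNNTarget consumes through F.linear without any parameters of its own.

hnet = DNNHyper(n_nodes=50, embedding_dim=13)
tnet = DNNTarget(hidden_dim=100, output_dim=10)

node_id = torch.tensor([7])     # client index into the embedding table
weights = hnet(node_id)         # dict of fc1/fc2 weight and bias tensors
x = torch.randn(32, 1, 28, 28)  # an MNIST-shaped batch
logits = tnet(x, weights)       # forward pass with the generated weights
print(logits.shape)             # torch.Size([32, 10])

Because the generated weights stay in the autograd graph, a loss on logits backpropagates into both the shared MLP of the hypernetwork and client 7's embedding.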
@@ -0,0 +1,49 @@
project: fhn
entity: ax2
method: grid
parameters:
  data-name:
    values:
      - cifar10
  data-path:
    values:
      - /cortex/data/images
  embed-dim:
    values:
      - -1
  eval-every:
    values:
      - 25
  gpu:
    values:
      - 0
  hyper-hid:
    values:
      - 100
  inner-lr:
    values:
      - 0.005
  la-steps:
    values:
      - 50
  lr:
    values:
      - 0.01
  n-hidden:
    values:
      - 3
  num-nodes:
    values:
      - 50
  num-steps:
    values:
      - 5000
  seed:
    values:
      - 42
      - 27
      - 13
  wd:
    values:
      - 0.001
program: trainer.py
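
Launching the sweep uses the standard wandb CLI (assuming the config above is saved as sweep.yaml; the sweep ID is printed by the first command):

wandb sweep sweep.yaml
wandb agent ax2/fhn/<SWEEP_ID>

With method: grid and a single value for every parameter except seed, this sweep runs trainer.py three times, once per seed (42, 27, 13).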