-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmain.py
256 lines (230 loc) · 13 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import configparser
import copy
import os
from datetime import datetime
from pathlib import Path
import fire
import numpy as np
import pandas as pd
# import ray
from tqdm import tqdm
import wandb
from aggregators.base import FederatedAveraging
from loaders.utils import get_confusion_matrix_plot
from models.ut_har import *
from models.utils import load_model
from strategies.base_fl import basic_fedavg
from trainers.distributed_base import BaseTrainer
from trainers.ultralytics_distributed import UltralyticsYoloTrainer
from utils import WarmupScheduler, read_system_variable, get_default_yolo_hyperparameters, set_seed, load_dataset, \
get_partition, plot_data_distributions, add_label_noise, plot_noise_distribution
# Ask wandb to use its 'thread' start method.
os.environ['WANDB_START_METHOD'] = 'thread'

# Both files are parsed with configparser, so they must use INI syntax despite the .yml suffix.
# Note: ConfigParser.read silently skips files that do not exist.
system_config = configparser.ConfigParser()
run_config = configparser.ConfigParser()
run_config.read('config.yml')  # per-run defaults consumed by Experiment.main
system_config.read('system.yml')  # hardware layout and seed

num_gpus, num_trainers_per_gpu, seed = read_system_variable(system_config)
set_seed(seed)
print(f'Seed is {seed}')

YOLO_HYPERPARAMETERS = get_default_yolo_hyperparameters()

# Ray-based parallelism is currently disabled (it was initialised with
# num_cpus=num_gpus * num_trainers_per_gpu + 5 and num_gpus=num_gpus).
print("success")
class Experiment:
    """Command-line entry point (exposed through python-fire) for federated-learning experiments.

    Every argument of :meth:`main` defaults to the matching key in the ``DEFAULT`` section of
    ``config.yml`` (``run_config``), falling back to the literal shown when the key is absent.
    """

    def main(self,
             # NOTE: the list-index trick below also prints the DEFAULT section once, when the class is defined.
             model: str =
             [run_config['DEFAULT'].get('model', 'models/resnet_group_norm.pt'), print(run_config['DEFAULT'])][0],
             dataset_name: str = run_config['DEFAULT'].get('dataset', 'cifar10'),
             data_dir: str = run_config['DEFAULT'].get('data_dir', '../data/'),
             client_num_in_total: int = run_config['DEFAULT'].getint('client_num_in_total', 2118),
             client_num_per_round: int = run_config['DEFAULT'].getint('client_num_per_round', 10),
             batch_size: int = run_config['DEFAULT'].getint('batch_size', 16),
             client_optimizer: str = run_config['DEFAULT'].get('client_optimizer', 'sgd'),
             lr: float = run_config['DEFAULT'].getfloat('lr', 0.1e-2),
             wd: float = run_config['DEFAULT'].getfloat('wd', 0.001),
             epochs: int = run_config['DEFAULT'].getint('epochs', 1),
             fl_algorithm: str = run_config['DEFAULT'].get('fl_algorithm', 'FedAvgSeq'),
             comm_round: int = run_config['DEFAULT'].getint('comm_round', 30),
             test_frequency: int = run_config['DEFAULT'].getint('test_frequency', 2),
             server_optimizer: str = run_config['DEFAULT'].get('server_optimizer', 'adam'),
             server_lr: float = run_config['DEFAULT'].getfloat('server_lr', 1e-1),
             alpha: float = run_config['DEFAULT'].getfloat('alpha', 0.1),
             partition_type: str = run_config['DEFAULT'].get('partition_type', 'dirichlet'),
             amp: bool = run_config['DEFAULT'].getboolean('amp', False),
             analysis: str = run_config['DEFAULT'].get('analysis', 'baseline'),
             trainer: str = run_config['DEFAULT'].get('trainer', 'BaseTrainer'),
             class_mixup: float = run_config['DEFAULT'].getfloat('class_mixup', 1),
             precision: str = run_config['DEFAULT'].get('precision', 'float32'),
             watch_metric: str = run_config['DEFAULT'].get('watch_metric', 'f1_score'),
             milestones: list[int] = None,
             resume: str = ""
             ):
        """Run a complete federated-learning experiment and log it to Weights & Biases.

        :param model: model name/path passed to ``load_model`` and to the client trainers
        :param dataset_name: dataset used for training (passed to ``load_dataset``)
        :param data_dir: data directory (only recorded in the wandb config here)
        :param client_num_in_total: total number of clients; ``get_partition`` may return an adjusted value
        :param client_num_per_round: number of clients sampled per round; may also be adjusted by ``get_partition``
        :param batch_size: input batch size for training and evaluation
        :param client_optimizer: client optimizer name, e.g. 'sgd' or 'adam'
        :param lr: client learning rate
        :param wd: weight decay parameter (only recorded in the wandb config here)
        :param epochs: number of local epochs per round
        :param fl_algorithm: algorithm label used in the run name, e.g. FedAvg, FedOPT, FedProx, FedAvgSeq
        :param comm_round: number of communication rounds
        :param test_frequency: evaluate the global model every ``test_frequency`` rounds, and once more after
            the last round; values <= 0 disable the periodic evaluation
        :param server_optimizer: server optimizer used by ``FederatedAveraging``
        :param server_lr: server learning rate
        :param alpha: concentration parameter of the Dirichlet partition
        :param partition_type: partition type: user, dirichlet, central
        :param amp: flag for using mixed precision in the client trainers
        :param analysis: analysis tag appended to the run name; values containing 'label_noise' inject label
            noise for the supported datasets
        :param trainer: client trainer, 'BaseTrainer' or 'ultralytics'
        :param class_mixup: forwarded to ``BaseTrainer``
        :param precision: forwarded to ``basic_fedavg``
        :param watch_metric: key of the evaluation metrics used to select the best checkpoint
            (a larger value is treated as better)
        :param milestones: milestones of the ``MultiStepLR`` schedule (gamma 0.1); used with 'BaseTrainer' only
        :param resume: name of a previous run; if ``weights/<resume>/best_model.pt`` exists its weights
            initialise the global model
        :raises ValueError: if ``trainer`` is not a known trainer type
        """
        if milestones is None:
            milestones = []
        print('Starting...')
        # Snapshot the call arguments for the wandb config before any of them are reassigned below.
        args = copy.deepcopy(locals())
        args.pop('self')
        device = run_config['DEFAULT']['device']
        dataset, num_classes = load_dataset(dataset_name)
        partition, client_num_in_total, client_num_per_round = get_partition(partition_type,
                                                                             dataset_name,
                                                                             num_classes,
                                                                             client_num_in_total,
                                                                             client_num_per_round,
                                                                             alpha,
                                                                             dataset)
        run = wandb.init(
            # mode='disabled',
            project=run_config['DEFAULT']['project'],
            entity=run_config['DEFAULT']['entity'],
            # Every field is '_'-separated; the run name also names the checkpoint directory below.
            name=f'{fl_algorithm}_{dataset_name}_{partition_type}_{client_num_per_round}_{client_num_in_total}_{client_optimizer}_{lr}'
                 f'_{server_optimizer}_{model}_{analysis}'
                 f'_{server_lr}_{alpha}_{datetime.now().strftime("%Y%m%d_%H%M%S")}',
            config=args,
        )
        wandb.config['num_samples'] = len(dataset['train'])
        client_datasets = partition(dataset['train'])
        wandb.config['seed'] = seed  # module-level seed read from system.yml
        plot_data_distributions(dataset, dataset_name, client_datasets, num_classes)
        if 'label_noise' in analysis and dataset_name in {'wisdm_phone', 'wisdm_watch', 'widar', 'ut_har', 'casas',
                                                          'epic_sounds', 'emognition'}:
            client_datasets, noise_percentages = add_label_noise(analysis, dataset_name, client_datasets, num_classes)
            plot_noise_distribution(noise_percentages)
        # Clients receive their datasets directly (the Ray object store is no longer used).
        client_dataset_refs = list(client_datasets)

        global_model = load_model(model_name=model, trainer=trainer, dataset_name=dataset_name)
        if resume:
            checkpoint = Path(f'weights/{resume}/best_model.pt')
            if checkpoint.exists():
                # The model is moved to CPU right after, so load the tensors there regardless of origin device.
                global_model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
            else:
                print(f'WARNING: checkpoint {checkpoint} not found; keeping the weights returned by load_model')
        global_model = global_model.cpu()

        if trainer == 'BaseTrainer':
            from scorers.classification_evaluator import evaluate
            if dataset_name in {'energy'}:
                from scorers.regression_evaluator import evaluate
                criterion = nn.MSELoss()
                wandb.config['loss'] = 'MSE'
            elif dataset_name in {'ego4d'}:
                from scorers.localization_evaluator import evaluate
                criterion = nn.CrossEntropyLoss()
                wandb.config['loss'] = 'CrossEntropyLoss'
            else:
                criterion = nn.CrossEntropyLoss()
                wandb.config['loss'] = 'CrossEntropyLoss'
            # The SGD optimizer only carries the lr for the scheduler; it never updates weights.
            scheduler = torch.optim.lr_scheduler.MultiStepLR(torch.optim.SGD(global_model.parameters(), lr=lr),
                                                             milestones=milestones,
                                                             gamma=0.1)
            client_trainers = [BaseTrainer(model_name=model,
                                           dataset_name=dataset_name,
                                           state_dict=global_model.state_dict(),
                                           criterion=criterion,
                                           batch_size=batch_size,
                                           optimizer_name=client_optimizer,
                                           epochs=epochs, scheduler='multisteplr',
                                           class_mixup=class_mixup,
                                           amp=amp,
                                           **{'lr': lr, 'milestones': milestones, 'gamma': 0.1})
                               for _ in range(min(client_num_per_round, num_gpus * num_trainers_per_gpu))]
        elif trainer == 'ultralytics':
            base_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                torch.optim.SGD(global_model.parameters(), lr=lr), T_0=10, T_mult=2,
                eta_min=1e-6)
            optimizer = torch.optim.SGD(global_model.parameters(),
                                        lr=lr)  # dummy optimizer meant for scheduler. do not confuse for actual optimizer
            scheduler = WarmupScheduler(optimizer, warmup_epochs=3, scheduler=base_scheduler)
            global_model.args = YOLO_HYPERPARAMETERS
            from scorers.ultralytics_yolo_evaluator import evaluate
            client_trainers = [UltralyticsYoloTrainer(
                model_path=model,
                state_dict=global_model.state_dict(),
                optimizer_name=client_optimizer,
                epochs=epochs,
                args=YOLO_HYPERPARAMETERS,
                batch_size=batch_size,
                amp=amp,
                device=device) for _ in range(client_num_per_round)]
        else:
            raise ValueError(f'Client trainer of type {trainer} not found')

        aggregator = FederatedAveraging(global_model=global_model,
                                        server_optimizer=server_optimizer,
                                        server_lr=server_lr,
                                        server_momentum=0.9,
                                        eps=1e-3)
        best_metric = -np.inf
        for round_idx in tqdm(range(0, comm_round)):
            # Evaluation happens before this round's training, i.e. on the model produced by earlier rounds.
            if test_frequency > 0 and round_idx > 0 and round_idx % test_frequency == 0:
                best_metric = self._evaluate_and_checkpoint(evaluate, global_model, dataset['test'], device,
                                                            num_classes, batch_size, watch_metric, best_metric,
                                                            step=round_idx)
            local_metrics_avg, global_model, scheduler = basic_fedavg(aggregator,
                                                                      client_trainers,
                                                                      client_dataset_refs,
                                                                      client_num_per_round,
                                                                      global_model,
                                                                      round_idx,
                                                                      scheduler,
                                                                      device,
                                                                      precision)
            print(local_metrics_avg)
            wandb.log(local_metrics_avg, step=round_idx)
        if comm_round > 0:
            # Because the loop evaluates before training, the model from the last round(s) would otherwise
            # never be evaluated nor considered for the best checkpoint.
            self._evaluate_and_checkpoint(evaluate, global_model, dataset['test'], device, num_classes,
                                          batch_size, watch_metric, best_metric, step=comm_round)
        run.finish()

    @staticmethod
    def _evaluate_and_checkpoint(evaluate, global_model, test_dataset, device, num_classes, batch_size,
                                 watch_metric, best_metric, step):
        """Evaluate ``global_model``, log the results to wandb and checkpoint it when ``watch_metric`` improves.

        :param evaluate: evaluator selected in ``main``; called as
            ``evaluate(model, dataset, device=..., num_classes=..., batch_size=...)`` and returning a dict
        :param global_model: model to evaluate; it is moved to CPU in place when a checkpoint is written
        :param test_dataset: evaluation split
        :param device: device passed through to ``evaluate``
        :param num_classes: number of classes passed through to ``evaluate``
        :param batch_size: evaluation batch size
        :param watch_metric: metrics key compared against ``best_metric`` (larger is better)
        :param best_metric: best value of ``watch_metric`` seen so far
        :param step: wandb step at which everything is logged
        :return: the updated best value of ``watch_metric``
        """
        metrics = evaluate(global_model, test_dataset, device=device, num_classes=num_classes,
                           batch_size=batch_size)
        v = metrics.get(watch_metric)
        if isinstance(v, torch.Tensor):
            v = v.detach().cpu().numpy()
        confusion_metric = None
        if 'confusion' in metrics:
            # The confusion matrix is logged separately (HTML chart + table), not with the scalar metrics.
            confusion_metric = metrics.pop('confusion').detach().cpu().numpy()
        # NOTE(review): a loss-like watch_metric (e.g. for regression) would need the opposite comparison.
        if v is not None and v > best_metric:
            best_metric = v
            path = f'weights/{wandb.run.name}'
            Path(path).mkdir(parents=True, exist_ok=True)
            # .cpu() moves the global model in place, matching the CPU-resident model expected afterwards.
            torch.save(global_model.cpu().state_dict(), f'{path}/best_model.pt')
            if confusion_metric is not None:
                chart = get_confusion_matrix_plot(confusion_metric)
                wandb.log({'confusion_matrix': wandb.Html(chart)}, step=step)
                np.save(f'{path}/confusion_matrix.npy', confusion_metric)
                wandb.log(
                    {'confusion_matrix_chart':
                         wandb.Table(dataframe=pd.DataFrame(confusion_metric,
                                                            columns=list(range(confusion_metric.shape[0]))))},
                    step=step)
        wandb.log(metrics, step=step)
        print(f'round {step}: {watch_metric} = {v}')
        return best_metric
if __name__ == '__main__':
    # Hand the Experiment class to python-fire so its public methods (e.g. `main`)
    # become CLI subcommands whose flags map to the method parameters.
    fire.Fire(component=Experiment)