-
Notifications
You must be signed in to change notification settings - Fork 4
/
engine.py
316 lines (304 loc) · 11.9 KB
/
engine.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Train and eval functions used in main.py
"""
import copy
import json
import math
import os
import random
import sys
from typing import Iterable
from util import box_ops
from util.utils import to_device
import torch
import torchvision
from copy import deepcopy
import util.misc as utils
from datasets.coco_eval import CocoEvaluator, convert_to_xywh
from lvis import LVISEval, LVISResults
# for eval visualize only
def generate_deterministic_rand(num):
    """Return the first pseudo-random float in [0, 1) produced for seed ``num``.

    A dedicated ``random.Random`` instance is seeded with ``num`` and drawn
    from once, so the same ``num`` always yields the same value and the
    global ``random`` module state is left untouched for other callers.
    """
    seeded_rng = random.Random(num)
    return seeded_rng.random()
def _sample_training_categories(targets, category_list, args):
    """Choose the categories given to the model for one batch and remap labels.

    When ``args.num_label_sampled > 0`` and ``args.dataset_file != "ovcoco"``:

    * ``gt`` is the set of unique labels in the batch (boxes flagged by
      ``pseudo_mask`` are excluded when ``args.pseudo_box`` is set).
    * If ``gt`` has at least ``num_label_sampled`` classes, a random subset of
      that size is kept; otherwise all of ``gt`` is kept and padded with random
      class indices from ``range(len(category_list))`` that are not in ``gt``.
    * Every target is modified in place: boxes whose label is not sampled are
      dropped (pseudo boxes are always kept), labels are re-indexed to their
      position in the sampled subset, and pseudo boxes get label ``-1``.

    Returns:
        The list of sampled class indices in that branch; otherwise a deep
        copy of ``category_list`` (targets are left untouched).
    """
    if not (args.num_label_sampled > 0 and args.dataset_file != "ovcoco"):
        return copy.deepcopy(category_list)

    if args.pseudo_box:
        gt = torch.cat(
            [t["labels"][~t["pseudo_mask"].to(torch.bool)] for t in targets]
        ).unique()
    else:
        gt = torch.cat([t["labels"] for t in targets]).unique()

    if gt.numel() >= args.num_label_sampled:
        sampled = gt[torch.randperm(gt.numel(), device=gt.device)][: args.num_label_sampled]
    else:
        all_class = torch.arange(len(category_list), device=gt.device)
        # classes that do not appear in this batch's ground truth
        neg_class = all_class[~(all_class.unsqueeze(1) == gt.unsqueeze(0)).any(-1)]
        num_sample = args.num_label_sampled - gt.numel()
        negatives = neg_class[torch.randperm(neg_class.numel(), device=gt.device)][:num_sample]
        sampled = torch.cat([gt, negatives])

    for target in targets:
        label = target["labels"]
        keep = (label.unsqueeze(-1) == sampled.unsqueeze(0)).any(-1)
        if args.pseudo_box:
            pseudo = target["pseudo_mask"].to(torch.bool)
            keep[pseudo] = True
        target["boxes"] = target["boxes"][keep]
        label = label[keep]
        # position of each surviving label inside ``sampled``
        new_label = (label.unsqueeze(-1) == sampled.unsqueeze(0)).int().argmax(-1)
        if args.pseudo_box:
            # Filter the pseudo mask with the same keep-mask so it stays aligned
            # with boxes/labels; indexing the shorter ``new_label`` with the
            # unfiltered mask raised IndexError whenever a GT box was dropped.
            target["pseudo_mask"] = target["pseudo_mask"][keep]
            new_label[pseudo[keep]] = -1
        # NOTE(review): other per-box fields (if any) are not filtered here —
        # confirm the criterion only reads boxes/labels/pseudo_mask.
        target["labels"] = new_label
    return sampled.tolist()


def train_one_epoch(
    model: torch.nn.Module,
    criterion: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    max_norm: float = 0,
    lr_scheduler=None,
    args=None,
    ema_m=None,
):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: network called as ``model(samples, categories=..., targets=...)``.
        criterion: loss module returning a dict of losses; ``weight_dict``
            selects and weights the terms that are optimised.
        data_loader: yields ``(samples, targets)``; its ``dataset`` must expose
            ``category_list``.
        optimizer: optimizer stepped once per batch.
        device: device the batch is moved to.
        epoch: current epoch index (used for logging, EMA start, and the
            criterion's optional per-epoch hooks).
        max_norm: gradient-norm clipping threshold; ``0`` disables clipping.
        lr_scheduler: accepted for interface compatibility; not used here.
        args: namespace; reads ``amp``, ``num_label_sampled``, ``dataset_file``,
            ``pseudo_box``, ``use_ema``, ``ema_epoch`` and ``debug``.
        ema_m: model EMA object updated after each step when ``args.use_ema``
            and ``epoch >= args.ema_epoch``.

    Returns:
        Dict of global-average meter values (meters with at least one update),
        plus ``weight_<k>`` entries when the criterion has ``loss_weight_decay``.
    """
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 20 if utils.get_world_size() == 1 else 200

    _cnt = 0
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [
            {k: v if isinstance(v, (list, dict)) else v.to(device) for k, v in t.items()}
            for t in targets
        ]
        used_categories = _sample_training_categories(
            targets, data_loader.dataset.category_list, args
        )

        with torch.cuda.amp.autocast(enabled=args.amp):
            outputs = model(samples, categories=used_categories, targets=targets)
            # The criterion sees all-zero labels; the real ids are kept in ori_labels.
            for target in targets:
                target["ori_labels"] = target["labels"]
                target["labels"] = target["labels"] - target["labels"]
            loss_dict = criterion(outputs, targets)
            weight_dict = criterion.weight_dict
            losses = sum(
                loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict
            )

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
        loss_value = losses_reduced_scaled.item()
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        if args.amp:
            scaler.scale(losses).backward()
            if max_norm > 0:
                # gradients must be unscaled before clipping against max_norm
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        if args.use_ema and epoch >= args.ema_epoch:
            ema_m.update(model)

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        # Drop this step's references before the next batch is loaded.
        del samples, targets, outputs, loss_dict, loss_dict_reduced
        del weight_dict, losses, losses_reduced_scaled

        _cnt += 1
        if args.debug and _cnt % 15 == 0:
            print("BREAK!" * 5)
            break

    if getattr(criterion, "loss_weight_decay", False):
        criterion.loss_weight_decay(epoch=epoch)
    if getattr(criterion, "tuning_matching", False):
        criterion.tuning_matching(epoch)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    resstat = {
        k: meter.global_avg for k, meter in metric_logger.meters.items() if meter.count > 0
    }
    if getattr(criterion, "loss_weight_decay", False):
        resstat.update({f"weight_{k}": v for k, v in criterion.weight_dict.items()})
    return resstat
def _to_lvis_records(targets, results, label2cat=None):
    """Flatten post-processed detections into LVIS result dicts.

    Each detection becomes ``{"image_id", "score", "category_id", "bbox"}``
    with the box converted by ``convert_to_xywh``. When ``label2cat`` is
    given, the predicted label is mapped through it to a category id.
    Each tensor is converted with one ``.tolist()`` rather than a per-element
    ``.item()``; the values are the same, without one device sync per box.
    """
    records = []
    for target, output in zip(targets, results):
        image_id = target["image_id"].item()
        boxes = convert_to_xywh(output["boxes"]).tolist()
        scores = output["scores"].tolist()
        labels = output["labels"].tolist()
        for box, score, label in zip(boxes, scores, labels):
            records.append(
                {
                    "image_id": image_id,
                    "score": score,
                    "category_id": label2cat[label] if label2cat is not None else label,
                    "bbox": box,
                }
            )
    return records


@torch.no_grad()
def evaluate(
    model,
    criterion,
    postprocessors,
    data_loader,
    base_ds,
    device,
    output_dir,
    args=None,
    epoch=None,
):
    """Run inference over ``data_loader``, logging losses and detection metrics.

    ``args.dataset_file`` selects the evaluation path:

    * ``"ovlvis"``: detections are collected as LVIS-format dicts and saved
      per rank as ``pred_{rank}.pth``. If ``epoch`` is not None (in-training
      evaluation), files go to ``output_dir/epoch_{epoch}`` and no metric is
      computed. Otherwise rank 0 merges all ranks' files, builds
      ``LVISResults`` (``max_dets=300``), runs ``LVISEval`` for each IoU type
      and adds its results to the returned stats.
    * ``"ovcoco"``: detections are accumulated in a ``CocoEvaluator`` over
      the IoU types in ``("segm", "bbox")`` that have a postprocessor.

    Returns:
        ``(stats, evaluator)``: ``stats`` maps meter names (and metric names)
        to values; ``evaluator`` is the ``CocoEvaluator`` for ``"ovcoco"`` and
        ``None`` for ``"ovlvis"``.

    Raises:
        ValueError: if ``args.dataset_file`` is neither supported value.
    """
    model.eval()
    criterion.eval()
    is_lvis = args.dataset_file == "ovlvis"
    if is_lvis and epoch is not None:
        # Predictions for this epoch are saved into this directory below, so it
        # must exist for epoch 0 too and before any rank saves; exist_ok makes
        # creating it on every rank safe.
        os.makedirs(os.path.join(output_dir, f"epoch_{epoch}"), exist_ok=True)

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = "Test:"
    if is_lvis:
        label2cat = {v: k for k, v in data_loader.dataset.cat2label.items()}
        lvis_results = []
        label_map = args.label_map
        iou_types = ["bbox"]
    elif args.dataset_file == "ovcoco":
        iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys())
        coco_evaluator = CocoEvaluator(
            base_ds, iou_types, label2cat=data_loader.dataset.label2catid
        )
    else:
        raise ValueError(f"Unsupported args.dataset_file: {args.dataset_file!r}")

    print_freq = 10 if args.debug or utils.get_world_size() == 1 else 100
    _cnt = 0
    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [
            {k: v if isinstance(v, (list, dict)) else v.to(device) for k, v in t.items()}
            for t in targets
        ]
        outputs = model(
            samples,
            categories=data_loader.dataset.category_list,
            targets=targets,
        )

        # For the loss only: shallow copies whose labels are zeroed (real ids
        # kept in ori_labels), so ``targets`` itself is left unchanged.
        training_target = []
        for target in targets:
            new_target = target.copy()
            new_target["ori_labels"] = target["labels"]
            new_target["labels"] = target["labels"] - target["labels"]
            training_target.append(new_target)
        loss_dict = criterion(outputs, training_target)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {
            k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict
        }
        metric_logger.update(
            loss=sum(loss_dict_reduced_scaled.values()),
            **loss_dict_reduced_scaled,
        )
        metric_logger.update(class_error=loss_dict_reduced.get("class_error", 0.0))

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors["bbox"](outputs, orig_target_sizes)
        if is_lvis:
            lvis_results.extend(
                _to_lvis_records(targets, results, label2cat if label_map else None)
            )
        else:
            coco_evaluator.update(
                {target["image_id"].item(): output for target, output in zip(targets, results)}
            )

        _cnt += 1
        if args.debug and _cnt % (15 * 5) == 0:
            print("BREAK!" * 5)
            break

    metric_logger.synchronize_between_processes()

    if is_lvis:
        stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
        rank = utils.get_rank()
        save_dir = output_dir if epoch is None else os.path.join(output_dir, f"epoch_{epoch}")
        torch.save(lvis_results, os.path.join(save_dir, f"pred_{rank}.pth"))
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        if epoch is not None:
            # During training, predictions are only saved, not evaluated.
            return stats, None
        if rank == 0:
            for i in range(1, utils.get_world_size()):
                lvis_results += torch.load(os.path.join(output_dir, f"pred_{i}.pth"))
            lvis_results = LVISResults(base_ds, lvis_results, max_dets=300)
            for iou_type in iou_types:
                lvis_eval = LVISEval(base_ds, lvis_results, iou_type)
                lvis_eval.run()
                lvis_eval.print_results()
                stats.update(lvis_eval.get_results())
        return stats, None

    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()
    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if "bbox" in postprocessors.keys():
        stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist()
    if "segm" in postprocessors.keys():
        stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist()
    return stats, coco_evaluator