Skip to content

Commit

Permalink
[Enhancement] Set real_feat to cpu in inception_utils (#1415)
Browse files Browse the repository at this point in the history
fix s2 configs
  • Loading branch information
plyfager authored Nov 7, 2022
1 parent b70695d commit 3349986
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions configs/styleganv2/stylegan2_c2_8xb4_lsun-car-384x512.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,9 @@
metrics = [
dict(
type='FrechetInceptionDistance',
prefix='FID-Full-50k',
prefix='FID-50k',
fake_nums=50000,
real_nums=50000,
inception_style='StyleGAN',
sample_model='ema'),
dict(type='PrecisionAndRecall', fake_nums=50000, prefix='PR-50K'),
Expand All @@ -110,7 +111,7 @@
# checkpoint=dict(
# save_best=['FID-Full-50k/fid', 'IS-50k/is'],
# rule=['less', 'greater']))
default_hooks = dict(checkpoint=dict(save_best='FID-Full-50k/fid'))
default_hooks = dict(checkpoint=dict(save_best='FID-50k/fid'))

val_evaluator = dict(metrics=metrics)
test_evaluator = dict(metrics=metrics)
4 changes: 2 additions & 2 deletions mmedit/evaluation/functional/inception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def prepare_inception_feat(dataloader: DataLoader,
f'same time. But receive \'{mean}\' and \'{std}\' '
'respectively.')

real_feat_ = metric.forward_inception(img)
real_feat_ = metric.forward_inception(img).cpu()
real_feat.append(real_feat_)

if is_main_process():
Expand All @@ -443,7 +443,7 @@ def prepare_inception_feat(dataloader: DataLoader,
if is_main_process():
inception_state = dict(**args)
if capture_mean_cov:
real_feat = torch.cat(real_feat, dim=0)[:num_items].cpu().numpy()
real_feat = torch.cat(real_feat, dim=0)[:num_items].numpy()
real_mean = np.mean(real_feat, 0)
real_cov = np.cov(real_feat, rowvar=False)
inception_state['real_mean'] = real_mean
Expand Down

0 comments on commit 3349986

Please sign in to comment.