diff --git a/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py b/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py index 9e3a41c200..3888b4516b 100644 --- a/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py +++ b/configs/wgan-gp/wgangp_GN_1xb64-160kiters_celeba-cropped-128x128.py @@ -26,7 +26,7 @@ loss_config=loss_config) # `batch_size` and `data_root` need to be set. -batch_size = 4 +batch_size = 64 data_root = './data/celeba-cropped/cropped_images_aligned_png/' train_dataloader = dict( batch_size=batch_size, dataset=dict(data_root=data_root)) @@ -47,7 +47,7 @@ custom_hooks = [ dict( type='GenVisualizationHook', - interval=1000, + interval=5000, fixed_input=True, vis_kwargs_list=dict(type='GAN', name='fake_img')) ] @@ -65,5 +65,8 @@ image_shape=(3, 128, 128)) ] +# save multi best checkpoints +default_hooks = dict(checkpoint=dict(save_best='swd/avg')) + val_evaluator = dict(metrics=metrics) test_evaluator = dict(metrics=metrics)