Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support NAFNet model #1369

Merged
merged 39 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6401dbf
NAFNet feature
VongolaWu Oct 10, 2022
af6ccd7
NAFNet and Real-ESRGAN
VongolaWu Oct 10, 2022
b0d4278
NAFNet for denoising cfg
VongolaWu Oct 10, 2022
cd03f2f
Modify NAFNet Gopro Config
VongolaWu Oct 14, 2022
d571266
NAFNet Deblurring Config
VongolaWu Oct 27, 2022
3b80ce5
Debug removal
VongolaWu Oct 27, 2022
e1fdb67
NAFNet Feature
VongolaWu Oct 27, 2022
e9a7706
NAFNet Feature Support
VongolaWu Oct 27, 2022
a8908f9
NAFNet Feature Support
VongolaWu Oct 27, 2022
3ec128b
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
VongolaWu Oct 31, 2022
a93d51b
NAFNet code check finished
VongolaWu Nov 1, 2022
e671456
NAFNet unit tests
VongolaWu Nov 1, 2022
c77355f
NAFNet unit test
VongolaWu Nov 1, 2022
50bbe2a
Docstring for NAFNet
VongolaWu Nov 2, 2022
7bb9978
NAFNet unit tests
VongolaWu Nov 2, 2022
76ba28b
Merge branch 'dev-1.x' into dev-1.x
VongolaWu Nov 2, 2022
5fadab5
NAFNet
VongolaWu Nov 2, 2022
e96b072
Merge branch 'dev-1.x' of https://github.com/VongolaWu/mmediting into…
VongolaWu Nov 2, 2022
84a186b
NAFNet unit tests
VongolaWu Nov 2, 2022
28d2704
NAFNet unit tests
VongolaWu Nov 2, 2022
71e96c0
PSNRLoss
VongolaWu Nov 2, 2022
775cbf5
NAFNet unit tests
VongolaWu Nov 2, 2022
c2e0c73
NAFNet Readme
VongolaWu Nov 4, 2022
2ca02b4
add NAFNet into __init__
VongolaWu Nov 5, 2022
d1745ef
Merge branch 'dev-1.x' into dev-1.x
zengyh1900 Nov 7, 2022
268fb40
NAFNet baseline
VongolaWu Nov 8, 2022
20fed2b
Merge branch 'dev-1.x' into dev-1.x
VongolaWu Nov 8, 2022
9fc2681
NAFNet cfgs for test
VongolaWu Nov 16, 2022
e22def5
NAFNet readme
VongolaWu Nov 16, 2022
c727b75
NAFNet Readme and Configs
VongolaWu Nov 16, 2022
31e65b5
Rename files
VongolaWu Nov 16, 2022
2c4bc1a
BaseModule
VongolaWu Nov 16, 2022
3057de8
clear undesired comments
VongolaWu Nov 16, 2022
7cc5d43
NAFNet Readme
VongolaWu Nov 16, 2022
7d704fe
NAFNet Readme
VongolaWu Nov 17, 2022
8caa034
NAFNet components
VongolaWu Nov 17, 2022
fe8d7fc
Merge branch 'dev-1.x' into dev-1.x
VongolaWu Nov 17, 2022
93256d9
NAFNet UTs
VongolaWu Nov 17, 2022
46d2d11
Merge branch 'dev-1.x' of https://github.com/VongolaWu/mmediting into…
VongolaWu Nov 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions configs/nafnet/nafnet_c64eb11128mb1db1111_lr1e-3_400k_gopro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
_base_ = '../_base_/default_runtime.py'
VongolaWu marked this conversation as resolved.
Show resolved Hide resolved

VongolaWu marked this conversation as resolved.
Show resolved Hide resolved
experiment_name = 'nafnet_c64eb11128mb1db1111_lr1e-3_400k_gopro'
work_dir = f'./work_dirs/{experiment_name}'
save_dir = './work_dirs/'

# DistributedDataParallel
model_wrapper_cfg = dict(type='MMSeparateDistributedDataParallel')

# model settings
model = dict(
type='BaseEditModel',
generator=dict(
type='NAFNetLocal',
img_channel=3,
mid_channels=64,
enc_blk_nums=[1, 1, 1, 28],
middle_blk_num=1,
dec_blk_nums=[1, 1, 1, 1],
),
pixel_loss=dict(type='PSNRLoss'),
train_cfg=dict(),
test_cfg=dict(),
data_preprocessor=dict(
type='EditDataPreprocessor',
mean=[0., 0., 0.],
std=[255., 255., 255.],
))

train_pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='LoadImageFromFile', key='gt'),
dict(type='SetValues', dictionary=dict(scale=1)),
dict(
type='Flip',
keys=['img', 'gt'],
flip_ratio=0.5,
direction='horizontal'),
dict(
type='Flip', keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'),
dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
dict(type='PairedRandomCrop', gt_patch_size=256),
dict(type='PackEditInputs')
]

val_pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='LoadImageFromFile', key='gt'),
dict(type='PackEditInputs')
]

# dataset settings
dataset_type = 'BasicImageDataset'

train_dataloader = dict(
num_workers=8,
batch_size=8, # gpus 4
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
metainfo=dict(dataset_type='gopro', task_name='deblur'),
data_root='../datasets/gopro/train',
data_prefix=dict(gt='sharp', img='blur'),
ann_file='meta_info_gopro_train.txt',
pipeline=train_pipeline))

val_dataloader = dict(
num_workers=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
metainfo=dict(dataset_type='gopro', task_name='deblur'),
data_root='../datasets/gopro/test',
ann_file='meta_info_gopro_test.txt',
data_prefix=dict(gt='sharp', img='blur'),
pipeline=val_pipeline))

test_dataloader = val_dataloader

val_evaluator = [
dict(type='MAE'),
dict(type='PSNR'),
dict(type='SSIM'),
]
test_evaluator = val_evaluator

train_cfg = dict(
type='IterBasedTrainLoop', max_iters=400_000, val_interval=20000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# optimizer
# optim_wrapper = dict(
VongolaWu marked this conversation as resolved.
Show resolved Hide resolved
# constructor='MultiOptimWrapperConstructor',
# generator=dict(
# type='OptimWrapper',
# optimizer=dict(type='Adam', lr=1e-3, betas=(0.9, 0.9))))
optim_wrapper = dict(
constructor='DefaultOptimWrapperConstructor',
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=1e-3, weight_decay=1e-3, betas=(0.9, 0.9)))

# learning policy
param_scheduler = dict(
type='CosineAnnealingLR', by_epoch=False, T_max=400_000, eta_min=1e-7)

default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=5000,
save_optimizer=True,
by_epoch=False,
out_dir=save_dir,
),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100),
param_scheduler=dict(type='ParamSchedulerHook'),
sampler_seed=dict(type='DistSamplerSeedHook'),
)

visualizer = dict(bgr2rgb=False)

randomness = dict(seed=10, diff_rank_seed=True)
120 changes: 120 additions & 0 deletions configs/nafnet/nafnet_c64eb2248mb12db2222_lr1e-3_400k_sidd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
_base_ = '../_base_/default_runtime.py'

experiment_name = 'nafnet_c64eb2248mb12db2222_lr1e-3_400k_sidd'
work_dir = f'./work_dirs/{experiment_name}'
save_dir = './work_dirs/'

# DistributedDataParallel
model_wrapper_cfg = dict(type='MMSeparateDistributedDataParallel')

# model settings
model = dict(
type='BaseEditModel',
generator=dict(
type='NAFNet',
img_channel=3,
mid_channels=64,
enc_blk_nums=[2, 2, 4, 8],
middle_blk_num=12,
dec_blk_nums=[2, 2, 2, 2],
),
pixel_loss=dict(type='PSNRLoss'),
train_cfg=dict(),
test_cfg=dict(),
data_preprocessor=dict(
type='EditDataPreprocessor',
mean=[0.0, 0.0, 0.0],
std=[255.0, 255.0, 255.0],
))

train_pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='LoadImageFromFile', key='gt'),
dict(type='SetValues', dictionary=dict(scale=1)),
dict(
type='Flip',
keys=['img', 'gt'],
flip_ratio=0.5,
direction='horizontal'),
dict(
type='Flip', keys=['img', 'gt'], flip_ratio=0.5, direction='vertical'),
dict(type='RandomTransposeHW', keys=['img', 'gt'], transpose_ratio=0.5),
dict(type='PairedRandomCrop', gt_patch_size=256),
dict(type='PackEditInputs')
]

val_pipeline = [
dict(type='LoadImageFromFile', key='img'),
dict(type='LoadImageFromFile', key='gt'),
dict(type='PackEditInputs')
]

# dataset settings
dataset_type = 'BasicImageDataset'

train_dataloader = dict(
num_workers=8,
batch_size=8, # gpus 4
persistent_workers=False,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
metainfo=dict(dataset_type='sidd', task_name='denoising'),
data_root='../datasets/SIDD/train',
data_prefix=dict(gt='gt', img='noisy'),
filename_tmpl=dict(img='{}_NOISY', gt='{}_GT'),
pipeline=train_pipeline))

val_dataloader = dict(
num_workers=4,
persistent_workers=False,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
metainfo=dict(dataset_type='sidd', task_name='denoising'),
data_root='../datasets/SIDD/val/',
data_prefix=dict(gt='gt', img='noisy'),
filename_tmpl=dict(gt='{}_GT', img='{}_NOISY'),
pipeline=val_pipeline))

test_dataloader = val_dataloader

val_evaluator = [
dict(type='MAE'),
dict(type='PSNR'),
dict(type='SSIM'),
]
test_evaluator = val_evaluator

train_cfg = dict(
type='IterBasedTrainLoop', max_iters=400_000, val_interval=20000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# optimizer
optim_wrapper = dict(
constructor='MultiOptimWrapperConstructor',
generator=dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=1e-3, betas=(0.9, 0.9))))

# learning policy
param_scheduler = dict(
type='CosineAnnealingLR', by_epoch=False, T_max=400_000, eta_min=1e-7)

default_hooks = dict(
checkpoint=dict(
type='CheckpointHook',
interval=5000,
save_optimizer=True,
by_epoch=False,
out_dir=save_dir,
),
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100),
param_scheduler=dict(type='ParamSchedulerHook'),
sampler_seed=dict(type='DistSamplerSeedHook'),
)

visualizer = dict(bgr2rgb=True)
4 changes: 3 additions & 1 deletion mmedit/models/editors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .liif import LIIF, MLPRefiner
from .lsgan import LSGAN
from .mspie import MSPIEStyleGAN2, PESinGAN
from .nafnet import Baseline, BaselineLocal, NAFNet, NAFNetLocal
from .pconv import (MaskConvModule, PartialConv2d, PConvDecoder, PConvEncoder,
PConvEncoderDecoder, PConvInpaintor)
from .pggan import ProgressiveGrowingGAN
Expand Down Expand Up @@ -73,5 +74,6 @@
'FBADecoder', 'WGANGP', 'CycleGAN', 'SAGAN', 'LSGAN', 'GGAN', 'Pix2Pix',
'StyleGAN1', 'StyleGAN2', 'StyleGAN3', 'BigGAN', 'DCGAN',
'ProgressiveGrowingGAN', 'SinGAN', 'IDLossModel', 'PESinGAN',
'MSPIEStyleGAN2', 'StyleGAN3Generator'
'MSPIEStyleGAN2', 'StyleGAN3Generator', 'Baseline', 'BaselineLocal',
'NAFNet', 'NAFNetLocal'
]
10 changes: 10 additions & 0 deletions mmedit/models/editors/nafnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .baseline_net import Baseline, BaselineLocal
from .nafnet_net import NAFNet, NAFNetLocal

__all__ = [
'NAFNet',
'NAFNetLocal',
'Baseline',
'BaselineLocal',
]
Loading