Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support NAFNet model #1369

Merged
merged 39 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6401dbf
NAFNet feature
VongolaWu Oct 10, 2022
af6ccd7
NAFNet and Real-ESRGAN
VongolaWu Oct 10, 2022
b0d4278
NAFNet for denoising cfg
VongolaWu Oct 10, 2022
cd03f2f
Modify NAFNet Gopro Config
VongolaWu Oct 14, 2022
d571266
NAFNet Deblurring Config
VongolaWu Oct 27, 2022
3b80ce5
Debug removal
VongolaWu Oct 27, 2022
e1fdb67
NAFNet Feature
VongolaWu Oct 27, 2022
e9a7706
NAFNet Feature Support
VongolaWu Oct 27, 2022
a8908f9
NAFNet Feature Support
VongolaWu Oct 27, 2022
3ec128b
Merge branch 'open-mmlab:dev-1.x' into dev-1.x
VongolaWu Oct 31, 2022
a93d51b
NAFNet code check finished
VongolaWu Nov 1, 2022
e671456
NAFNet unit tests
VongolaWu Nov 1, 2022
c77355f
NAFNet unit test
VongolaWu Nov 1, 2022
50bbe2a
Docstring for NAFNet
VongolaWu Nov 2, 2022
7bb9978
NAFNet unit tests
VongolaWu Nov 2, 2022
76ba28b
Merge branch 'dev-1.x' into dev-1.x
VongolaWu Nov 2, 2022
5fadab5
NAFNet
VongolaWu Nov 2, 2022
e96b072
Merge branch 'dev-1.x' of https://github.com/VongolaWu/mmediting into…
VongolaWu Nov 2, 2022
84a186b
NAFNet unit tests
VongolaWu Nov 2, 2022
28d2704
NAFNet unit tests
VongolaWu Nov 2, 2022
71e96c0
PSNRLoss
VongolaWu Nov 2, 2022
775cbf5
NAFNet unit tests
VongolaWu Nov 2, 2022
c2e0c73
NAFNet Readme
VongolaWu Nov 4, 2022
2ca02b4
add NAFNet into __init__
VongolaWu Nov 5, 2022
d1745ef
Merge branch 'dev-1.x' into dev-1.x
zengyh1900 Nov 7, 2022
268fb40
NAFNet baseline
VongolaWu Nov 8, 2022
20fed2b
Merge branch 'dev-1.x' into dev-1.x
VongolaWu Nov 8, 2022
9fc2681
NAFNet cfgs for test
VongolaWu Nov 16, 2022
e22def5
NAFNet readme
VongolaWu Nov 16, 2022
c727b75
NAFNet Readme and Configs
VongolaWu Nov 16, 2022
31e65b5
Rename files
VongolaWu Nov 16, 2022
2c4bc1a
BaseModule
VongolaWu Nov 16, 2022
3057de8
clear undesired comments
VongolaWu Nov 16, 2022
7cc5d43
NAFNet Readme
VongolaWu Nov 16, 2022
7d704fe
NAFNet Readme
VongolaWu Nov 17, 2022
8caa034
NAFNet components
VongolaWu Nov 17, 2022
fe8d7fc
Merge branch 'dev-1.x' into dev-1.x
VongolaWu Nov 17, 2022
93256d9
NAFNet UTs
VongolaWu Nov 17, 2022
46d2d11
Merge branch 'dev-1.x' of https://github.com/VongolaWu/mmediting into…
VongolaWu Nov 17, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions configs/nafnet/nafnet_c64eb11128mb1db1111_lr1e-3_400k_gopro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
_base_ = '../_base_/default_runtime.py'
VongolaWu marked this conversation as resolved.
Show resolved Hide resolved

VongolaWu marked this conversation as resolved.
Show resolved Hide resolved
# Run identity: the experiment name is interpolated into the work directory.
experiment_name = 'nafnet_c64eb11128mb1db1111_lr1e-3_400k_gopro'
work_dir = f'./work_dirs/{experiment_name}'
save_dir = './work_dirs/'

# Wrapper applied to the model for distributed training.
model_wrapper_cfg = dict(type='MMSeparateDistributedDataParallel')

# Model: a NAFNetLocal generator optimised with PSNRLoss.
model = dict(
    type='BaseEditModel',
    generator=dict(
        type='NAFNetLocal',
        img_channel=3,
        mid_channels=64,
        # Block counts per encoder stage, in the middle, and per decoder stage.
        enc_blk_nums=[1, 1, 1, 28],
        middle_blk_num=1,
        dec_blk_nums=[1, 1, 1, 1],
    ),
    pixel_loss=dict(type='PSNRLoss'),
    train_cfg=dict(),
    test_cfg=dict(),
    # Zero mean and a std of 255 per channel.
    data_preprocessor=dict(
        type='EditDataPreprocessor',
        mean=[0.0, 0.0, 0.0],
        std=[255.0, 255.0, 255.0],
    ),
)

# Training pipeline: load the blurry/sharp pair, apply paired flips and a
# random transpose, then crop a 256-pixel patch and pack the sample.
train_pipeline = [
    dict(type='LoadImageFromFile', key='img'),
    dict(type='LoadImageFromFile', key='gt'),
    dict(type='SetValues', dictionary=dict(scale=1)),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='horizontal',
    ),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='vertical',
    ),
    dict(
        type='RandomTransposeHW',
        keys=['img', 'gt'],
        transpose_ratio=0.5,
    ),
    dict(type='PairedRandomCrop', gt_patch_size=256),
    dict(type='PackEditInputs'),
]

# Validation pipeline: load the pair and pack it, with no augmentation.
val_pipeline = [
    dict(type='LoadImageFromFile', key='img'),
    dict(type='LoadImageFromFile', key='gt'),
    dict(type='PackEditInputs'),
]

# Dataset settings.
dataset_type = 'BasicImageDataset'

train_dataloader = dict(
    num_workers=8,
    batch_size=8,  # gpus 4
    persistent_workers=False,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='gopro', task_name='deblur'),
        data_root='../datasets/gopro/train',
        data_prefix=dict(gt='sharp', img='blur'),
        ann_file='meta_info_gopro_train.txt',
        pipeline=train_pipeline,
    ),
)

val_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='gopro', task_name='deblur'),
        data_root='../datasets/gopro/test',
        ann_file='meta_info_gopro_test.txt',
        data_prefix=dict(gt='sharp', img='blur'),
        pipeline=val_pipeline,
    ),
)

# Testing reuses the validation dataloader object.
test_dataloader = val_dataloader

# Metrics; testing reuses the validation evaluator list.
val_evaluator = [
    dict(type='MAE'),
    dict(type='PSNR'),
    dict(type='SSIM'),
]
test_evaluator = val_evaluator

# Iteration-based training for 400k iterations, validating every 20k.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=400_000,
    val_interval=20_000,
)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Optimizer: one AdamW wrapper built by the default constructor. The model
# has a single generator and no discriminator, so no per-module optimizer
# dict is needed.
optim_wrapper = dict(
    constructor='DefaultOptimWrapperConstructor',
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=1e-3,
        weight_decay=1e-3,
        betas=(0.9, 0.9),
    ),
)

# Learning policy: cosine annealing stepped per iteration (by_epoch=False),
# with T_max equal to max_iters above.
param_scheduler = dict(
    type='CosineAnnealingLR',
    by_epoch=False,
    T_max=400_000,
    eta_min=1e-7,
)

# Runtime hooks. Checkpoints (including optimizer state) are written every
# 5000 iterations under ``save_dir``; logs are emitted every 100 iterations.
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,
        save_optimizer=True,
        by_epoch=False,
        out_dir=save_dir,
    ),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
)

# Visualizer override: no BGR-to-RGB conversion.
visualizer = dict(bgr2rgb=False)

# Fixed seed, with a different seed on each rank.
randomness = dict(seed=10, diff_rank_seed=True)
127 changes: 127 additions & 0 deletions configs/nafnet/nafnet_c64eb2248mb12db2222_lr1e-3_400k_sidd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Inherit the shared runtime defaults.
_base_ = '../_base_/default_runtime.py'

# Run identity: the experiment name is interpolated into the work directory.
experiment_name = 'nafnet_c64eb2248mb12db2222_lr1e-3_400k_sidd'
work_dir = f'./work_dirs/{experiment_name}'
save_dir = './work_dirs/'

# Wrapper applied to the model for distributed training.
model_wrapper_cfg = dict(type='MMSeparateDistributedDataParallel')

# Model: a NAFNet generator optimised with PSNRLoss.
model = dict(
    type='BaseEditModel',
    generator=dict(
        type='NAFNet',
        img_channel=3,
        mid_channels=64,
        # Block counts per encoder stage, in the middle, and per decoder stage.
        enc_blk_nums=[2, 2, 4, 8],
        middle_blk_num=12,
        dec_blk_nums=[2, 2, 2, 2],
    ),
    pixel_loss=dict(type='PSNRLoss'),
    train_cfg=dict(),
    test_cfg=dict(),
    # Zero mean and a std of 255 per channel.
    data_preprocessor=dict(
        type='EditDataPreprocessor',
        mean=[0.0, 0.0, 0.0],
        std=[255.0, 255.0, 255.0],
    ),
)

# Training pipeline: load the noisy/clean pair, apply paired flips and a
# random transpose, then crop a 256-pixel patch and pack the sample.
train_pipeline = [
    dict(type='LoadImageFromFile', key='img'),
    dict(type='LoadImageFromFile', key='gt'),
    dict(type='SetValues', dictionary=dict(scale=1)),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='horizontal',
    ),
    dict(
        type='Flip',
        keys=['img', 'gt'],
        flip_ratio=0.5,
        direction='vertical',
    ),
    dict(
        type='RandomTransposeHW',
        keys=['img', 'gt'],
        transpose_ratio=0.5,
    ),
    dict(type='PairedRandomCrop', gt_patch_size=256),
    dict(type='PackEditInputs'),
]

# Validation pipeline: load the pair and pack it, with no augmentation.
val_pipeline = [
    dict(type='LoadImageFromFile', key='img'),
    dict(type='LoadImageFromFile', key='gt'),
    dict(type='PackEditInputs'),
]

# Dataset settings.
dataset_type = 'BasicImageDataset'

train_dataloader = dict(
    num_workers=8,
    batch_size=8,  # gpus 4
    persistent_workers=False,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='sidd', task_name='denoising'),
        data_root='../datasets/SIDD/train',
        data_prefix=dict(gt='gt', img='noisy'),
        # File-name templates for the noisy input and its ground truth.
        filename_tmpl=dict(img='{}_NOISY', gt='{}_GT'),
        pipeline=train_pipeline,
    ),
)

val_dataloader = dict(
    num_workers=4,
    persistent_workers=False,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        metainfo=dict(dataset_type='sidd', task_name='denoising'),
        data_root='../datasets/SIDD/val/',
        data_prefix=dict(gt='gt', img='noisy'),
        filename_tmpl=dict(gt='{}_GT', img='{}_NOISY'),
        pipeline=val_pipeline,
    ),
)

# Testing reuses the validation dataloader object.
test_dataloader = val_dataloader

# Metrics; testing reuses the validation evaluator list.
val_evaluator = [
    dict(type='MAE'),
    dict(type='PSNR'),
    dict(type='SSIM'),
]
test_evaluator = val_evaluator

# Iteration-based training for 400k iterations, validating every 20k.
train_cfg = dict(
    type='IterBasedTrainLoop',
    max_iters=400_000,
    val_interval=20_000,
)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Optimizer: one AdamW wrapper built by the default constructor. The model
# has a single generator and no discriminator, so a per-module optimizer
# dict (MultiOptimWrapperConstructor with a ``generator`` entry) is not
# needed here.
# NOTE(review): weight_decay is not set, so the optimizer's own default
# applies — confirm this matches the intended SIDD recipe.
optim_wrapper = dict(
    constructor='DefaultOptimWrapperConstructor',
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=1e-3, betas=(0.9, 0.9)),
)

# Learning policy: cosine annealing stepped per iteration (by_epoch=False),
# with T_max equal to max_iters above.
param_scheduler = dict(
    type='CosineAnnealingLR',
    by_epoch=False,
    T_max=400_000,
    eta_min=1e-7,
)

# Runtime hooks. Checkpoints (including optimizer state) are written every
# 5000 iterations under ``save_dir``; logs are emitted every 100 iterations.
default_hooks = dict(
    checkpoint=dict(
        type='CheckpointHook',
        interval=5000,
        save_optimizer=True,
        by_epoch=False,
        out_dir=save_dir,
    ),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100),
    param_scheduler=dict(type='ParamSchedulerHook'),
    sampler_seed=dict(type='DistSamplerSeedHook'),
)

# Visualizer override: convert BGR to RGB.
visualizer = dict(bgr2rgb=True)
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,4 @@
interval=1,
interp_cfg=dict(momentum=0.999),
)
]
]
57 changes: 57 additions & 0 deletions mmedit/evaluation/metrics/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,60 @@ def ssim(img1,
ssims.append(_ssim(img1[..., i], img2[..., i]))

return np.array(ssims).mean()

# Components for SSIM with 3D kernels
# Not used for now
import torch

def _3d_gaussian_calculator(img, conv3d):
out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
return out

def _generate_3d_gaussian_kernel():
    """Build a 3D convolution whose weight is a Gaussian window.

    The 11x11x11 weight is formed from the outer product of an 11-tap
    Gaussian (sigma 1.5) with itself, scaled along the first axis by the
    same 11-tap Gaussian.

    Returns:
        torch.nn.Conv3d: Single-channel, bias-free convolution with stride 1
            and replicate padding of 5 on every axis. Its weight has
            ``requires_grad`` set to False.
    """
    taps = cv2.getGaussianKernel(11, 1.5)
    window_2d = np.outer(taps, taps.transpose())
    # Same parameters as ``taps``; weights the 2D window along the depth axis.
    depth_taps = cv2.getGaussianKernel(11, 1.5)
    volume = window_2d[np.newaxis, :, :] * depth_taps[:, :, np.newaxis]

    conv3d = torch.nn.Conv3d(
        in_channels=1,
        out_channels=1,
        kernel_size=(11, 11, 11),
        stride=1,
        padding=(5, 5, 5),
        bias=False,
        padding_mode='replicate')
    # Freeze the weight first so that it can be overwritten in place.
    conv3d.weight.requires_grad = False
    conv3d.weight[0, 0] = torch.tensor(volume)
    return conv3d

def _ssim_3d(img1, img2, max_value):
    """Calculate SSIM (structural similarity) with a 3D Gaussian window.

    Local statistics are computed by filtering each whole 3-dimensional
    input with the convolution from ``_generate_3d_gaussian_kernel``. The
    computation runs on CUDA when it is available and on the CPU otherwise.

    Args:
        img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
            Must have exactly three dimensions.
        img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
            Must have exactly three dimensions.
        max_value (float): Value used to scale the stabilising constants
            ``C1 = (0.01 * max_value)**2`` and ``C2 = (0.03 * max_value)**2``.

    Returns:
        float: Mean of the SSIM map.
    """
    assert len(img1.shape) == 3 and len(img2.shape) == 3

    C1 = (0.01 * max_value)**2
    C2 = (0.03 * max_value)**2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)

    # Fall back to the CPU instead of failing on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    kernel = _generate_3d_gaussian_kernel().to(device)

    img1 = torch.tensor(img1).float().to(device)
    img2 = torch.tensor(img2).float().to(device)

    # Local means.
    mu1 = _3d_gaussian_calculator(img1, kernel)
    mu2 = _3d_gaussian_calculator(img2, kernel)

    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    # Local variances and covariance: E[x*y] - E[x]*E[y].
    sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq
    sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq
    sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) *
                (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                       (sigma1_sq + sigma2_sq + C2))
    return float(ssim_map.mean())
1 change: 1 addition & 0 deletions mmedit/models/editors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from .tof import TOFlowVFINet, TOFlowVSRNet, ToFResBlock
from .ttsr import LTE, TTSR, SearchTransformer, TTSRDiscriminator, TTSRNet
from .wgan_gp import WGANGP
from .nafnet import NAFNet, NAFNetLocal, Baseline, BaselineLocal

__all__ = [
'AOTEncoderDecoder', 'AOTBlockNeck', 'AOTInpaintor',
Expand Down
9 changes: 9 additions & 0 deletions mmedit/models/editors/nafnet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .nafnet_net import NAFNet, NAFNetLocal
from .baseline_net import Baseline, BaselineLocal

# Names exported by this package.
__all__ = ['NAFNet', 'NAFNetLocal', 'Baseline', 'BaselineLocal']
Loading