BasicVSR fp16
ckkelvinchan committed May 16, 2021
1 parent 10f1936 commit 8489c93
Showing 8 changed files with 28 additions and 15 deletions.
2 changes: 1 addition & 1 deletion mmedit/apis/train.py
@@ -11,7 +11,7 @@

from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers
from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
- from mmedit.core.runners.apex_amp_utils import apex_amp_initialize
+ from mmedit.core.runner.apex_amp_utils import apex_amp_initialize
from mmedit.datasets.builder import build_dataloader, build_dataset
from mmedit.utils import get_root_logger

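The corrected import path points at the module that presumably wraps NVIDIA Apex's AMP entry point. A minimal sketch of Apex's own API (not the mmedit wrapper), assuming the optional `apex` package is installed; the `opt_level` value is illustrative:

```python
import torch
from apex import amp

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Patch the model and optimizer for mixed-precision training
# ('O1' = conservative mixed precision).
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
```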
3 changes: 2 additions & 1 deletion mmedit/core/__init__.py
@@ -3,10 +3,11 @@
from .hooks import VisualizationHook
from .misc import tensor2img
from .optimizer import build_optimizers
+ from .runner import IterBasedFP16Runner
from .scheduler import LinearLrUpdaterHook

__all__ = [
'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook',
'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook',
- 'VisualizationHook', 'L1Evaluation'
+ 'VisualizationHook', 'L1Evaluation', 'IterBasedFP16Runner'
]
3 changes: 3 additions & 0 deletions mmedit/core/runner/__init__.py
@@ -0,0 +1,3 @@
+ from .iter_based_fp16_runner import IterBasedFP16Runner

+ __all__ = ['IterBasedFP16Runner']
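Because `mmedit/core/__init__.py` re-exports the class (see the hunk above), both import paths below should resolve to the same object. A quick sanity check, assuming mmedit is installed from this commit:

```python
from mmedit.core import IterBasedFP16Runner as RunnerFromCore
from mmedit.core.runner import IterBasedFP16Runner as RunnerFromPackage

# The re-export makes these the same class object.
assert RunnerFromCore is RunnerFromPackage
```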
File renamed without changes.
File renamed without changes.
@@ -126,7 +126,7 @@ def __init__(self, *args, fp16_loss_scaler, use_apex_amp=False, **kwargs):

# add fp16 grad scaler, using pytorch official GradScaler
self.loss_scaler = GradScaler(**fp16_loss_scaler)
- mmcv.print_log('Use FP16 grad scaler in Training', 'edit')
+ mmcv.print_log('Use FP16 grad scaler in Training', 'mmedit')

# flag to use amp in apex (NVIDIA)
self.use_apex_amp = use_apex_amp
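The runner unpacks its `fp16_loss_scaler` dict directly into `torch.cuda.amp.GradScaler`, so the dict's keys must be valid `GradScaler` arguments. A standalone sketch of the scale/step/update cycle such a runner would drive each iteration; the `init_scale` value and the toy model are illustrative:

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Keys of fp16_loss_scaler must be GradScaler arguments.
fp16_loss_scaler = dict(init_scale=512.0)
loss_scaler = GradScaler(**fp16_loss_scaler)

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
inputs = torch.randn(4, 8, device='cuda')

with autocast():
    loss = model(inputs).pow(2).mean()

optimizer.zero_grad()
loss_scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
loss_scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
loss_scaler.update()                # adapt the scale for the next iteration
```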
8 changes: 4 additions & 4 deletions mmedit/models/backbones/sr_backbones/basicvsr_net.py
@@ -250,10 +250,10 @@ def __init__(self, pretrained):

self.register_buffer(
'mean',
- torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).half())
self.register_buffer(
'std',
- torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+ torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).half())

def compute_flow(self, ref, supp):
"""Compute flow from ref to supp.
@@ -348,8 +348,8 @@ def forward(self, ref, supp):
align_corners=False)

# adjust the flow values
- flow[:, 0, :, :] *= float(w) / float(w_up)
- flow[:, 1, :, :] *= float(h) / float(h_up)
+ flow[:, 0, :, :] *= w / w_up
+ flow[:, 1, :, :] *= h / h_up

return flow

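When a flow field is bilinearly resized, its values (pixel displacements) must be multiplied by the ratio of new to old size; the dropped `float()` casts are harmless since Python 3's `/` already returns a float for integer operands. A sketch of the general pattern, with a hypothetical helper name:

```python
import torch.nn.functional as F

def resize_flow(flow, new_h, new_w):
    """Hypothetical helper: resize a (n, 2, h, w) flow field."""
    h, w = flow.shape[2:]
    flow = F.interpolate(
        flow, size=(new_h, new_w), mode='bilinear', align_corners=False)
    # Displacements are measured in pixels, so rescale with the size ratio.
    flow[:, 0, :, :] *= new_w / w  # horizontal component
    flow[:, 1, :, :] *= new_h / h  # vertical component
    return flow
```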
25 changes: 17 additions & 8 deletions mmedit/models/restorers/basicvsr.py
@@ -4,7 +4,7 @@
import mmcv
import numpy as np
import torch
- import torch.cuda.amp.autocast
+ import torch.cuda.amp

from mmedit.core import tensor2img
from ..registry import MODELS
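The old line raises `ModuleNotFoundError`: a bare `import` statement can only name modules, and `autocast` is a class inside `torch.cuda.amp`, not a submodule. Working alternatives:

```python
import torch.cuda.amp                # then call torch.cuda.amp.autocast(...)
from torch.cuda.amp import autocast  # or bind the name directly
```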
@@ -34,17 +34,21 @@ def __init__(self,
pixel_loss,
train_cfg=None,
test_cfg=None,
- pretrained=None):
+ pretrained=None,
+ fp16_enabled=False):
super().__init__(generator, pixel_loss, train_cfg, test_cfg,
pretrained)

# fix pre-trained networks
self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0
- self.generator.find_unused_parameters = False
+ self.is_weight_fixed = False

# count training steps
self.register_buffer('step_counter', torch.zeros(1))

+ # fp16 settings
+ self.fp16_enabled = fp16_enabled

def check_if_mirror_extended(self, lrs):
"""Check whether the input is a mirror-extended sequence.
@@ -80,18 +84,23 @@ def train_step(self,

# fix SPyNet and EDVR at the beginning
if self.step_counter < self.fix_iter:
- if not self.generator.find_unused_parameters:
-     self.generator.find_unused_parameters = True
+ if not self.is_weight_fixed:
+     self.is_weight_fixed = True
for k, v in self.generator.named_parameters():
if 'spynet' in k or 'edvr' in k:
v.requires_grad_(False)
elif self.step_counter == self.fix_iter:
# train all the parameters
- self.generator.find_unused_parameters = False
+ self.generator.requires_grad_(True)

- fp16_enabled = True if loss_scaler is not None else False
- with torch.cuda.amp.autocast(enabled=fp16_enabled):
+ if loss_scaler is None and self.fp16_enabled:
+     raise AssertionError('When loss_scaler is None, fp16_enabled '
+                          'must be False. But got True.')
+ elif loss_scaler is not None and not self.fp16_enabled:
+     raise AssertionError('When loss_scaler is not None, fp16_enabled '
+                          'must be True. But got False.')
+
+ with torch.cuda.amp.autocast(enabled=self.fp16_enabled):
outputs = self(**data_batch, test_mode=False)
loss, log_vars = self.parse_losses(outputs.pop('losses'))

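A condensed sketch of the contract `train_step` now enforces: autocast is driven by the model's own `fp16_enabled` flag, which must agree with whether the runner supplied a `GradScaler`. The function name is hypothetical:

```python
import torch

def checked_forward(model, data_batch, loss_scaler):
    # fp16_enabled must be True exactly when a loss_scaler is supplied.
    if (loss_scaler is None) == model.fp16_enabled:
        raise AssertionError(
            'fp16_enabled must be True iff a loss_scaler is given.')
    with torch.cuda.amp.autocast(enabled=model.fp16_enabled):
        return model(**data_batch, test_mode=False)
```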
