diff --git a/mmedit/apis/train.py b/mmedit/apis/train.py
index 1a388636e1..8b34307712 100644
--- a/mmedit/apis/train.py
+++ b/mmedit/apis/train.py
@@ -11,7 +11,7 @@
 from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers
 from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
-from mmedit.core.runners.apex_amp_utils import apex_amp_initialize
+from mmedit.core.runner.apex_amp_utils import apex_amp_initialize
 from mmedit.datasets.builder import build_dataloader, build_dataset
 from mmedit.utils import get_root_logger
diff --git a/mmedit/core/__init__.py b/mmedit/core/__init__.py
index 674c65ab4f..e065657ffa 100644
--- a/mmedit/core/__init__.py
+++ b/mmedit/core/__init__.py
@@ -3,10 +3,11 @@
 from .hooks import VisualizationHook
 from .misc import tensor2img
 from .optimizer import build_optimizers
+from .runner import IterBasedFP16Runner
 from .scheduler import LinearLrUpdaterHook
 
 __all__ = [
     'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook',
     'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook',
-    'VisualizationHook', 'L1Evaluation'
+    'VisualizationHook', 'L1Evaluation', 'IterBasedFP16Runner'
 ]
diff --git a/mmedit/core/runner/__init__.py b/mmedit/core/runner/__init__.py
new file mode 100644
index 0000000000..3f91d3c446
--- /dev/null
+++ b/mmedit/core/runner/__init__.py
@@ -0,0 +1,3 @@
+from .iter_based_fp16_runner import IterBasedFP16Runner
+
+__all__ = ['IterBasedFP16Runner']
diff --git a/mmedit/core/runners/apex_amp_utils.py b/mmedit/core/runner/apex_amp_utils.py
similarity index 100%
rename from mmedit/core/runners/apex_amp_utils.py
rename to mmedit/core/runner/apex_amp_utils.py
diff --git a/mmedit/core/runners/checkpoints.py b/mmedit/core/runner/checkpoint.py
similarity index 100%
rename from mmedit/core/runners/checkpoints.py
rename to mmedit/core/runner/checkpoint.py
diff --git a/mmedit/core/runners/iter_based_fp16_runner.py b/mmedit/core/runner/iter_based_fp16_runner.py
similarity index 99%
rename from mmedit/core/runners/iter_based_fp16_runner.py
rename to mmedit/core/runner/iter_based_fp16_runner.py
index 9856b88a84..0da1ae0920 100644
--- a/mmedit/core/runners/iter_based_fp16_runner.py
+++ b/mmedit/core/runner/iter_based_fp16_runner.py
@@ -126,7 +126,7 @@ def __init__(self, *args, fp16_loss_scaler, use_apex_amp=False, **kwargs):
         # add fp16 grad scaler, using pytorch official GradScaler
         self.loss_scaler = GradScaler(**fp16_loss_scaler)
-        mmcv.print_log('Use FP16 grad scaler in Training', 'edit')
+        mmcv.print_log('Use FP16 grad scaler in Training', 'mmedit')
 
         # flag to use amp in apex (NVIDIA)
         self.use_apex_amp = use_apex_amp
diff --git a/mmedit/models/backbones/sr_backbones/basicvsr_net.py b/mmedit/models/backbones/sr_backbones/basicvsr_net.py
index 21dbb18143..bea65a1c04 100644
--- a/mmedit/models/backbones/sr_backbones/basicvsr_net.py
+++ b/mmedit/models/backbones/sr_backbones/basicvsr_net.py
@@ -250,10 +250,10 @@ def __init__(self, pretrained):
         self.register_buffer(
             'mean',
-            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).half())
         self.register_buffer(
             'std',
-            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).half())
 
     def compute_flow(self, ref, supp):
         """Compute flow from ref to supp.
@@ -348,8 +348,8 @@ def forward(self, ref, supp):
             align_corners=False)
 
         # adjust the flow values
-        flow[:, 0, :, :] *= float(w) / float(w_up)
-        flow[:, 1, :, :] *= float(h) / float(h_up)
+        flow[:, 0, :, :] *= w / w_up
+        flow[:, 1, :, :] *= h / h_up
 
         return flow
 
diff --git a/mmedit/models/restorers/basicvsr.py b/mmedit/models/restorers/basicvsr.py
index 60fede7c04..0169991aa1 100644
--- a/mmedit/models/restorers/basicvsr.py
+++ b/mmedit/models/restorers/basicvsr.py
@@ -4,7 +4,7 @@
 import mmcv
 import numpy as np
 import torch
-import torch.cuda.amp.autocast
+import torch.cuda.amp
 
 from mmedit.core import tensor2img
 from ..registry import MODELS
@@ -34,17 +34,21 @@ def __init__(self,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
-                 pretrained=None):
+                 pretrained=None,
+                 fp16_enabled=False):
         super().__init__(generator, pixel_loss, train_cfg, test_cfg,
                          pretrained)
 
         # fix pre-trained networks
         self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0
-        self.generator.find_unused_parameters = False
+        self.is_weight_fixed = False
 
         # count training steps
         self.register_buffer('step_counter', torch.zeros(1))
 
+        # fp16 settings
+        self.fp16_enabled = fp16_enabled
+
     def check_if_mirror_extended(self, lrs):
         """Check whether the input is a mirror-extended sequence.
 
@@ -80,18 +84,23 @@ def train_step(self,
         # fix SPyNet and EDVR at the beginning
         if self.step_counter < self.fix_iter:
-            if not self.generator.find_unused_parameters:
-                self.generator.find_unused_parameters = True
+            if not self.is_weight_fixed:
+                self.is_weight_fixed = True
                 for k, v in self.generator.named_parameters():
                     if 'spynet' in k or 'edvr' in k:
                         v.requires_grad_(False)
         elif self.step_counter == self.fix_iter:
             # train all the parameters
-            self.generator.find_unused_parameters = False
             self.generator.requires_grad_(True)
 
-        fp16_enabled = True if loss_scaler is not None else False
-        with torch.cuda.amp.autocast(enabled=fp16_enabled):
+        if loss_scaler is None and self.fp16_enabled:
+            raise AssertionError('When loss_scaler is None, fp16_enabled '
+                                 'must be False. But got True.')
+        elif loss_scaler is not None and not self.fp16_enabled:
+            raise AssertionError('When loss_scaler is not None, fp16_enabled '
+                                 'must be True. But got False.')
+
+        with torch.cuda.amp.autocast(enabled=self.fp16_enabled):
             outputs = self(**data_batch, test_mode=False)
             loss, log_vars = self.parse_losses(outputs.pop('losses'))
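
Note (not part of the patch): a rough sketch of how the new pieces might be wired together in a training config, assuming the runner and model are built from config keys in the usual MMCV/MMEditing way. Only IterBasedFP16Runner, fp16_loss_scaler, use_apex_amp and fp16_enabled come from this diff; every other key below (generator fields, loss settings, init_scale, the exact runner dict) is illustrative.

# Hypothetical config fragment -- illustrative only, not part of this patch.
# IterBasedFP16Runner, fp16_loss_scaler, use_apex_amp and fp16_enabled are
# taken from the diff above; all other keys are assumptions.
model = dict(
    type='BasicVSR',
    generator=dict(type='BasicVSRNet', mid_channels=64, num_blocks=30),
    pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0, reduction='mean'),
    # BasicVSR.train_step now requires fp16_enabled to agree with whether a
    # loss scaler is passed in, so this must match the runner below.
    fp16_enabled=True)

runner = dict(
    type='IterBasedFP16Runner',
    # kwargs forwarded verbatim to torch.cuda.amp.GradScaler
    fp16_loss_scaler=dict(init_scale=512.),
    # set True to use NVIDIA apex amp instead of the native grad scaler
    use_apex_amp=False)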