From 10f193653db088305e3ea584398c77580120a8bd Mon Sep 17 00:00:00 2001
From: ckkelvinchan
Date: Wed, 12 May 2021 20:42:12 +0800
Subject: [PATCH 1/2] Support of fp16

---
 mmedit/apis/train.py                          |  37 ++-
 mmedit/core/runners/apex_amp_utils.py         |  33 +++
 mmedit/core/runners/checkpoints.py            |  91 ++++++
 mmedit/core/runners/iter_based_fp16_runner.py | 270 ++++++++++++++++++
 mmedit/models/restorers/basicvsr.py           |  33 ++-
 5 files changed, 452 insertions(+), 12 deletions(-)
 create mode 100644 mmedit/core/runners/apex_amp_utils.py
 create mode 100644 mmedit/core/runners/checkpoints.py
 create mode 100644 mmedit/core/runners/iter_based_fp16_runner.py

diff --git a/mmedit/apis/train.py b/mmedit/apis/train.py
index d577e6357d..1a388636e1 100644
--- a/mmedit/apis/train.py
+++ b/mmedit/apis/train.py
@@ -7,10 +7,11 @@
 import numpy as np
 import torch
 from mmcv.parallel import MMDataParallel
-from mmcv.runner import HOOKS, IterBasedRunner
+from mmcv.runner import HOOKS, IterBasedRunner, build_runner
 
 from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers
 from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
+from mmedit.core.runners.apex_amp_utils import apex_amp_initialize
 from mmedit.datasets.builder import build_dataloader, build_dataset
 from mmedit.utils import get_root_logger
 
@@ -133,12 +134,33 @@ def _dist_train(model,
 
     # build runner
     optimizer = build_optimizers(model, cfg.optimizers)
-    runner = IterBasedRunner(
-        model,
-        optimizer=optimizer,
-        work_dir=cfg.work_dir,
-        logger=logger,
-        meta=meta)
+
+    # use apex amp
+    _use_apex_amp = False
+    if cfg.get('apex_amp', None):
+        model, optimizer = apex_amp_initialize(model, optimizer,
+                                               **cfg.apex_amp)
+        _use_apex_amp = True
+
+    # allow users to define the runner
+    if cfg.get('runner', None):
+        runner = build_runner(
+            cfg.runner,
+            dict(
+                model=model,
+                optimizer=optimizer,
+                work_dir=cfg.work_dir,
+                logger=logger,
+                use_apex_amp=_use_apex_amp,
+                meta=meta))
+    else:
+        runner = IterBasedRunner(
+            model,
+            optimizer=optimizer,
+            work_dir=cfg.work_dir,
+            logger=logger,
+            meta=meta)
+
     # an ugly walkaround to make the .log and .log.json filenames the same
     runner.timestamp = timestamp
 
@@ -189,6 +211,7 @@ def _dist_train(model,
     runner.run(data_loaders, cfg.workflow, cfg.total_iters)
 
 
+# TODO: Support fp16 for non-distributed training
 def _non_dist_train(model,
                     dataset,
                     cfg,
diff --git a/mmedit/core/runners/apex_amp_utils.py b/mmedit/core/runners/apex_amp_utils.py
new file mode 100644
index 0000000000..d166f6dc17
--- /dev/null
+++ b/mmedit/core/runners/apex_amp_utils.py
@@ -0,0 +1,33 @@
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+def apex_amp_initialize(models, optimizers, init_args=None, mode='gan'):
+    """Initialize apex.amp for mixed-precision training.
+
+    Args:
+        models (nn.Module | list[Module]): Modules to be wrapped with
+            apex.amp.
+        optimizers (dict): Optimizers to be wrapped for mixed-precision
+            training.
+        init_args (dict | None, optional): Config for amp initialization.
+            Defaults to None.
+        mode (str, optional): The mode used to initialize apex.amp.
+            Different modes lead to different wrapping modes for models and
+            optimizers. Defaults to 'gan'.
+
+    Returns:
+        Module, dict: Wrapped models and optimizers.
+ """ + init_args = init_args or dict() + + if mode == 'gan': + _optmizers = [optimizers['generator'], optimizers['discriminator']] + + models, _optmizers = amp.initialize(models, _optmizers, **init_args) + optimizers['generator'] = _optmizers[0] + optimizers['discriminator'] = _optmizers[1] + + return models, optimizers + + else: + raise NotImplementedError( + f'Cannot initialize apex.amp with mode {mode}') diff --git a/mmedit/core/runners/checkpoints.py b/mmedit/core/runners/checkpoints.py new file mode 100644 index 0000000000..967cb2953f --- /dev/null +++ b/mmedit/core/runners/checkpoints.py @@ -0,0 +1,91 @@ +import os.path as osp +import time +from tempfile import TemporaryDirectory + +import mmcv +import torch +from mmcv.parallel import is_module_wrapper +from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu +from torch.optim import Optimizer + + +def save_checkpoint(model, + filename, + optimizer=None, + loss_scaler=None, + save_apex_amp=False, + meta=None): + """Save checkpoint to file. + The checkpoint will have 3 or more fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + In mixed-precision training, ``loss_scaler`` or ``amp.state_dict`` will be + saved in checkpoint. + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + loss_scaler (Object, optional): Loss scaler used for FP16 training. + save_apex_amp (bool, optional): Whether to save apex.amp + ``state_dict``. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + # save loss scaler for mixed-precision (FP16) training + if loss_scaler is not None: + checkpoint['loss_scaler'] = loss_scaler.state_dict() + + # save state_dict from apex.amp + if save_apex_amp: + from apex import amp + checkpoint['amp'] = amp.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() diff --git 
a/mmedit/core/runners/iter_based_fp16_runner.py b/mmedit/core/runners/iter_based_fp16_runner.py
new file mode 100644
index 0000000000..9856b88a84
--- /dev/null
+++ b/mmedit/core/runners/iter_based_fp16_runner.py
@@ -0,0 +1,270 @@
+import os.path as osp
+import platform
+import shutil
+from functools import partial
+
+import mmcv
+import torch
+import torch.distributed as dist
+from mmcv.parallel import collate
+from mmcv.runner import RUNNERS, IterBasedRunner
+from torch.optim import Optimizer
+from torch.utils.data import DataLoader
+
+from .checkpoint import save_checkpoint
+
+try:
+    # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be
+    # imported and used; otherwise, auto fp16 will adopt mmcv's
+    # implementation.
+    from torch.cuda.amp import GradScaler
+except ImportError:
+    pass
+
+
+class IterLoader:
+    """Iteration-based dataloader.
+
+    This wrapper adapts a dataloader to the iter-based training procedure.
+
+    Args:
+        dataloader (object): Dataloader in PyTorch.
+        runner (object): ``mmcv.Runner``.
+    """
+
+    def __init__(self, dataloader, runner):
+        self._dataloader = dataloader
+        self.runner = runner
+        self.iter_loader = iter(self._dataloader)
+        self._epoch = 0
+
+    @property
+    def epoch(self):
+        """The current epoch number.
+
+        Returns:
+            int: Epoch number.
+        """
+        return self._epoch
+
+    def update_dataloader(self, curr_scale):
+        """Update dataloader.
+
+        Update the dataloader according to ``curr_scale``. This is helpful
+        when training progressive growing GANs, where the dataloader should
+        be updated according to the current scale of the models.
+
+        Args:
+            curr_scale (int): The scale in current stage.
+        """
+        # update dataset, sampler, and samples per gpu in dataloader
+        if hasattr(self._dataloader.dataset, 'update_annotations'):
+            update_flag = self._dataloader.dataset.update_annotations(
+                curr_scale)
+        else:
+            update_flag = False
+        if update_flag:
+            # the sampler should be updated with the modified dataset
+            assert hasattr(self._dataloader.sampler, 'update_sampler')
+            samples_per_gpu = None if not hasattr(
+                self._dataloader.dataset, 'samples_per_gpu'
+            ) else self._dataloader.dataset.samples_per_gpu
+            self._dataloader.sampler.update_sampler(self._dataloader.dataset,
+                                                    samples_per_gpu)
+            # update samples per gpu
+            if samples_per_gpu is not None:
+                if dist.is_initialized():
+                    # samples = samples_per_gpu
+                    # self._dataloader.collate_fn = partial(
+                    #     collate, samples_per_gpu=samples)
+                    self._dataloader = DataLoader(
+                        self._dataloader.dataset,
+                        batch_size=samples_per_gpu,
+                        sampler=self._dataloader.sampler,
+                        num_workers=self._dataloader.num_workers,
+                        collate_fn=partial(
+                            collate, samples_per_gpu=samples_per_gpu),
+                        shuffle=False,
+                        worker_init_fn=self._dataloader.worker_init_fn)
+
+                    self.iter_loader = iter(self._dataloader)
+                else:
+                    raise NotImplementedError(
+                        'Currently, we only support dynamic batch size in'
+                        ' ddp, because the number of gpus in DataParallel '
+                        'cannot be obtained easily.')
+
+    def __next__(self):
+        try:
+            data = next(self.iter_loader)
+        except StopIteration:
+            self._epoch += 1
+            if hasattr(self._dataloader.sampler, 'set_epoch'):
+                self._dataloader.sampler.set_epoch(self._epoch)
+            self.iter_loader = iter(self._dataloader)
+            data = next(self.iter_loader)
+
+        return data
+
+    def __len__(self):
+        return len(self._dataloader)
+
+
+@RUNNERS.register_module()
+class IterBasedFP16Runner(IterBasedRunner):
+    """IterBasedRunner for FP16.
+
+    In this IterBasedFP16Runner, training proceeds as in the conventional
+    IterBasedRunner; the only difference is that FP16 is used.
+
+    Args:
+        fp16_loss_scaler (dict): Config for the fp16 GradScaler from
+            ``torch.cuda.amp``.
+        use_apex_amp (bool, optional): Whether to use apex.amp to start mixed
+            precision training. Defaults to False.
+    """
+
+    def __init__(self, *args, fp16_loss_scaler, use_apex_amp=False, **kwargs):
+
+        super().__init__(*args, **kwargs)
+
+        # add fp16 grad scaler, using pytorch official GradScaler
+        self.loss_scaler = GradScaler(**fp16_loss_scaler)
+        mmcv.print_log('Use FP16 grad scaler in Training', 'edit')
+
+        # flag to use amp in apex (NVIDIA)
+        self.use_apex_amp = use_apex_amp
+
+    def call_hook(self, fn_name):
+        """Call all hooks.
+
+        Args:
+            fn_name (str): The function name in each hook to be called, such
+                as "before_train_epoch".
+        """
+        for hook in self._hooks:
+            if hasattr(hook, fn_name):
+                getattr(hook, fn_name)(self)
+
+    def train(self, data_loader, **kwargs):
+        self.model.train()
+        self.mode = 'train'
+        self.data_loader = data_loader
+        self._epoch = data_loader.epoch
+        self.call_hook('before_fetch_train_data')
+        data_batch = next(self.data_loader)
+        self.call_hook('before_train_iter')
+
+        kwargs.update(dict(loss_scaler=self.loss_scaler))
+
+        if self.use_apex_amp:
+            kwargs.update(dict(use_apex_amp=True))
+
+        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
+
+        # the loss scaler should be updated after ``train_step``
+        self.loss_scaler.update()
+
+        if not isinstance(outputs, dict):
+            raise TypeError('model.train_step() must return a dict')
+        if 'log_vars' in outputs:
+            self.log_buffer.update(outputs['log_vars'],
+                                   outputs['num_samples'])
+        self.outputs = outputs
+        self.call_hook('after_train_iter')
+        self._inner_iter += 1
+        self._iter += 1
+
+    def resume(self,
+               checkpoint,
+               resume_optimizer=True,
+               resume_loss_scaler=True,
+               map_location='default'):
+        """Resume model from checkpoint.
+
+        Args:
+            checkpoint (str): Checkpoint to resume from.
+            resume_optimizer (bool, optional): Whether to resume the
+                optimizer(s) if the checkpoint file includes optimizer(s).
+                Defaults to True.
+            resume_loss_scaler (bool, optional): Whether to resume the loss
+                scaler (GradScaler) from ``torch.cuda.amp``. Defaults to True.
+            map_location (str, optional): Same as :func:`torch.load`.
+                Defaults to 'default'.
+ """ + + if map_location == 'default': + device_id = torch.cuda.current_device() + checkpoint = self.load_checkpoint( + checkpoint, + map_location=lambda storage, loc: storage.cuda(device_id)) + else: + checkpoint = self.load_checkpoint( + checkpoint, map_location=map_location) + + self._epoch = checkpoint['meta']['epoch'] + self._iter = checkpoint['meta']['iter'] + self._inner_iter = checkpoint['meta']['iter'] + if 'optimizer' in checkpoint and resume_optimizer: + if isinstance(self.optimizer, Optimizer): + self.optimizer.load_state_dict(checkpoint['optimizer']) + elif isinstance(self.optimizer, dict): + for k in self.optimizer.keys(): + self.optimizer[k].load_state_dict( + checkpoint['optimizer'][k]) + else: + raise TypeError( + 'Optimizer should be dict or torch.optim.Optimizer ' + f'but got {type(self.optimizer)}') + + if 'loss_scaler' in checkpoint and resume_loss_scaler: + self.loss_scaler.load_state_dict(checkpoint['loss_scaler']) + + if self.use_apex_amp: + from apex import amp + amp.load_state_dict(checkpoint['amp']) + + self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}') + + def save_checkpoint(self, + out_dir, + filename_tmpl='iter_{}.pth', + meta=None, + save_optimizer=True, + create_symlink=True): + """Save checkpoint to file. + + Args: + out_dir (str): Directory to save checkpoint files. + filename_tmpl (str, optional): Checkpoint file template. + Defaults to 'iter_{}.pth'. + meta (dict, optional): Metadata to be saved in checkpoint. + Defaults to None. + save_optimizer (bool, optional): Whether save optimizer. + Defaults to True. + create_symlink (bool, optional): Whether create symlink to the + latest checkpoint file. Defaults to True. + """ + + if meta is None: + meta = dict(iter=self.iter + 1, epoch=self.epoch + 1) + elif isinstance(meta, dict): + meta.update(iter=self.iter + 1, epoch=self.epoch + 1) + else: + raise TypeError( + f'meta should be a dict or None, but got {type(meta)}') + if self.meta is not None: + meta.update(self.meta) + + filename = filename_tmpl.format(self.iter + 1) + filepath = osp.join(out_dir, filename) + optimizer = self.optimizer if save_optimizer else None + _loss_scaler = self.loss_scaler + save_checkpoint( + self.model, + filepath, + optimizer=optimizer, + loss_scaler=_loss_scaler, + save_apex_amp=self.use_apex_amp, + meta=meta) + # in some environments, `os.symlink` is not supported, you may need to + # set `create_symlink` to False + if create_symlink: + dst_file = osp.join(out_dir, 'latest.pth') + if platform.system() != 'Windows': + mmcv.symlink(filename, dst_file) + else: + shutil.copy(filepath, dst_file) diff --git a/mmedit/models/restorers/basicvsr.py b/mmedit/models/restorers/basicvsr.py index 3d504727ba..60fede7c04 100644 --- a/mmedit/models/restorers/basicvsr.py +++ b/mmedit/models/restorers/basicvsr.py @@ -4,6 +4,7 @@ import mmcv import numpy as np import torch +import torch.cuda.amp.autocast from mmedit.core import tensor2img from ..registry import MODELS @@ -62,7 +63,11 @@ def check_if_mirror_extended(self, lrs): return is_mirror_extended - def train_step(self, data_batch, optimizer): + def train_step(self, + data_batch, + optimizer, + loss_scaler=None, + use_apex_amp=False): """Train step. Args: @@ -72,6 +77,7 @@ def train_step(self, data_batch, optimizer): Returns: dict: Returned output. 
""" + # fix SPyNet and EDVR at the beginning if self.step_counter < self.fix_iter: if not self.generator.find_unused_parameters: @@ -84,13 +90,30 @@ def train_step(self, data_batch, optimizer): self.generator.find_unused_parameters = False self.generator.requires_grad_(True) - outputs = self(**data_batch, test_mode=False) - loss, log_vars = self.parse_losses(outputs.pop('losses')) + fp16_enabled = True if loss_scaler is not None else False + with torch.cuda.amp.autocast(enabled=fp16_enabled): + outputs = self(**data_batch, test_mode=False) + loss, log_vars = self.parse_losses(outputs.pop('losses')) # optimize optimizer['generator'].zero_grad() - loss.backward() - optimizer['generator'].step() + if loss_scaler: + loss_scaler.scale(loss).backward() + elif use_apex_amp: + from apex import amp + with amp.scale_loss( + loss, optimizer['generator'], loss_id=1) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + if loss_scaler: + loss_scaler.unscale_(optimizer['generator']) + # note that we do not contain clip_grad procedure + loss_scaler.step(optimizer['generator']) + # loss_scaler.update will be called in runner.train() + else: + optimizer['generator'].step() self.step_counter += 1 From 8489c930b98dd4e3ed8d5a3b22f0ee15c7182ae0 Mon Sep 17 00:00:00 2001 From: ckkelvinchan Date: Sun, 16 May 2021 14:35:01 +0800 Subject: [PATCH 2/2] BasicVSR fp16 --- mmedit/apis/train.py | 2 +- mmedit/core/__init__.py | 3 ++- mmedit/core/runner/__init__.py | 3 +++ .../{runners => runner}/apex_amp_utils.py | 0 .../checkpoints.py => runner/checkpoint.py} | 0 .../iter_based_fp16_runner.py | 2 +- .../backbones/sr_backbones/basicvsr_net.py | 8 +++--- mmedit/models/restorers/basicvsr.py | 25 +++++++++++++------ 8 files changed, 28 insertions(+), 15 deletions(-) create mode 100644 mmedit/core/runner/__init__.py rename mmedit/core/{runners => runner}/apex_amp_utils.py (100%) rename mmedit/core/{runners/checkpoints.py => runner/checkpoint.py} (100%) rename mmedit/core/{runners => runner}/iter_based_fp16_runner.py (99%) diff --git a/mmedit/apis/train.py b/mmedit/apis/train.py index 1a388636e1..8b34307712 100644 --- a/mmedit/apis/train.py +++ b/mmedit/apis/train.py @@ -11,7 +11,7 @@ from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper -from mmedit.core.runners.apex_amp_utils import apex_amp_initialize +from mmedit.core.runner.apex_amp_utils import apex_amp_initialize from mmedit.datasets.builder import build_dataloader, build_dataset from mmedit.utils import get_root_logger diff --git a/mmedit/core/__init__.py b/mmedit/core/__init__.py index 674c65ab4f..e065657ffa 100644 --- a/mmedit/core/__init__.py +++ b/mmedit/core/__init__.py @@ -3,10 +3,11 @@ from .hooks import VisualizationHook from .misc import tensor2img from .optimizer import build_optimizers +from .runner import IterBasedFP16Runner from .scheduler import LinearLrUpdaterHook __all__ = [ 'build_optimizers', 'tensor2img', 'EvalIterHook', 'DistEvalIterHook', 'mse', 'psnr', 'reorder_image', 'sad', 'ssim', 'LinearLrUpdaterHook', - 'VisualizationHook', 'L1Evaluation' + 'VisualizationHook', 'L1Evaluation', 'IterBasedFP16Runner' ] diff --git a/mmedit/core/runner/__init__.py b/mmedit/core/runner/__init__.py new file mode 100644 index 0000000000..3f91d3c446 --- /dev/null +++ b/mmedit/core/runner/__init__.py @@ -0,0 +1,3 @@ +from .iter_based_fp16_runner import IterBasedFP16Runner + +__all__ = ['IterBasedFP16Runner'] diff --git a/mmedit/core/runners/apex_amp_utils.py 
b/mmedit/core/runner/apex_amp_utils.py similarity index 100% rename from mmedit/core/runners/apex_amp_utils.py rename to mmedit/core/runner/apex_amp_utils.py diff --git a/mmedit/core/runners/checkpoints.py b/mmedit/core/runner/checkpoint.py similarity index 100% rename from mmedit/core/runners/checkpoints.py rename to mmedit/core/runner/checkpoint.py diff --git a/mmedit/core/runners/iter_based_fp16_runner.py b/mmedit/core/runner/iter_based_fp16_runner.py similarity index 99% rename from mmedit/core/runners/iter_based_fp16_runner.py rename to mmedit/core/runner/iter_based_fp16_runner.py index 9856b88a84..0da1ae0920 100644 --- a/mmedit/core/runners/iter_based_fp16_runner.py +++ b/mmedit/core/runner/iter_based_fp16_runner.py @@ -126,7 +126,7 @@ def __init__(self, *args, fp16_loss_scaler, use_apex_amp=False, **kwargs): # add fp16 grad scaler, using pytorch official GradScaler self.loss_scaler = GradScaler(**fp16_loss_scaler) - mmcv.print_log('Use FP16 grad scaler in Training', 'edit') + mmcv.print_log('Use FP16 grad scaler in Training', 'mmedit') # flag to use amp in apex (NVIDIA) self.use_apex_amp = use_apex_amp diff --git a/mmedit/models/backbones/sr_backbones/basicvsr_net.py b/mmedit/models/backbones/sr_backbones/basicvsr_net.py index 21dbb18143..bea65a1c04 100644 --- a/mmedit/models/backbones/sr_backbones/basicvsr_net.py +++ b/mmedit/models/backbones/sr_backbones/basicvsr_net.py @@ -250,10 +250,10 @@ def __init__(self, pretrained): self.register_buffer( 'mean', - torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).half()) self.register_buffer( 'std', - torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).half()) def compute_flow(self, ref, supp): """Compute flow from ref to supp. @@ -348,8 +348,8 @@ def forward(self, ref, supp): align_corners=False) # adjust the flow values - flow[:, 0, :, :] *= float(w) / float(w_up) - flow[:, 1, :, :] *= float(h) / float(h_up) + flow[:, 0, :, :] *= w / w_up + flow[:, 1, :, :] *= h / h_up return flow diff --git a/mmedit/models/restorers/basicvsr.py b/mmedit/models/restorers/basicvsr.py index 60fede7c04..0169991aa1 100644 --- a/mmedit/models/restorers/basicvsr.py +++ b/mmedit/models/restorers/basicvsr.py @@ -4,7 +4,7 @@ import mmcv import numpy as np import torch -import torch.cuda.amp.autocast +import torch.cuda.amp from mmedit.core import tensor2img from ..registry import MODELS @@ -34,17 +34,21 @@ def __init__(self, pixel_loss, train_cfg=None, test_cfg=None, - pretrained=None): + pretrained=None, + fp16_enabled=False): super().__init__(generator, pixel_loss, train_cfg, test_cfg, pretrained) # fix pre-trained networks self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0 - self.generator.find_unused_parameters = False + self.is_weight_fixed = False # count training steps self.register_buffer('step_counter', torch.zeros(1)) + # fp16 settings + self.fp16_enabled = fp16_enabled + def check_if_mirror_extended(self, lrs): """Check whether the input is a mirror-extended sequence. 
@@ -80,18 +84,23 @@ def train_step(self,
 
         # fix SPyNet and EDVR at the beginning
         if self.step_counter < self.fix_iter:
-            if not self.generator.find_unused_parameters:
-                self.generator.find_unused_parameters = True
+            if not self.is_weight_fixed:
+                self.is_weight_fixed = True
                 for k, v in self.generator.named_parameters():
                     if 'spynet' in k or 'edvr' in k:
                         v.requires_grad_(False)
         elif self.step_counter == self.fix_iter:
             # train all the parameters
-            self.generator.find_unused_parameters = False
             self.generator.requires_grad_(True)
 
-        fp16_enabled = True if loss_scaler is not None else False
-        with torch.cuda.amp.autocast(enabled=fp16_enabled):
+        if loss_scaler is None and self.fp16_enabled:
+            raise AssertionError('fp16_enabled is True, but no loss_scaler '
+                                 'is provided for FP16 training.')
+        elif loss_scaler is not None and not self.fp16_enabled:
+            raise AssertionError('A loss_scaler is provided, but '
+                                 'fp16_enabled is False.')
+
+        with torch.cuda.amp.autocast(enabled=self.fp16_enabled):
             outputs = self(**data_batch, test_mode=False)
             loss, log_vars = self.parse_losses(outputs.pop('losses'))
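
Usage sketch (illustrative; the generator and pixel-loss settings below are
assumptions modelled on existing BasicVSR configs, and init_scale=512 is an
example value). The keys `fp16_enabled`, `runner`, `fp16_loss_scaler` and
`apex_amp` are the ones actually read by this patch; `fp16_loss_scaler` is
forwarded verbatim to `torch.cuda.amp.GradScaler`:

    # training config enabling the native-AMP path added by this patch
    model = dict(
        type='BasicVSR',
        generator=dict(type='BasicVSRNet', mid_channels=64, num_blocks=30),
        pixel_loss=dict(type='CharbonnierLoss', loss_weight=1.0),
        fp16_enabled=True)  # checked in BasicVSR.train_step (patch 2)

    runner = dict(
        type='IterBasedFP16Runner',              # registered in RUNNERS
        fp16_loss_scaler=dict(init_scale=512.))  # kwargs of GradScaler

    # Alternatively, `apex_amp = dict(mode='gan', init_args=dict(opt_level='O1'))`
    # would route through apex_amp_initialize(); note that only the 'gan' mode
    # (generator + discriminator optimizers) is implemented there, so it does
    # not apply to BasicVSR as-is.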