From b5c745d6a1ed77ae858a3253f7d1c3deb8511c09 Mon Sep 17 00:00:00 2001 From: "Kelvin C.K. Chan" Date: Mon, 24 May 2021 19:28:29 +0800 Subject: [PATCH] Fix unused_parameters in pt18 (#290) --- configs/restorers/basicvsr/basicvsr_reds4.py | 1 + configs/restorers/basicvsr/basicvsr_vimeo90k_bd.py | 1 + configs/restorers/basicvsr/basicvsr_vimeo90k_bi.py | 1 + configs/restorers/iconvsr/iconvsr_reds4.py | 1 + configs/restorers/iconvsr/iconvsr_vimeo90k_bd.py | 1 + configs/restorers/iconvsr/iconvsr_vimeo90k_bi.py | 1 + mmedit/models/restorers/basicvsr.py | 7 +++---- 7 files changed, 9 insertions(+), 4 deletions(-) diff --git a/configs/restorers/basicvsr/basicvsr_reds4.py b/configs/restorers/basicvsr/basicvsr_reds4.py index 5a05ee33f0..02ab0f172e 100644 --- a/configs/restorers/basicvsr/basicvsr_reds4.py +++ b/configs/restorers/basicvsr/basicvsr_reds4.py @@ -138,3 +138,4 @@ load_from = None resume_from = None workflow = [('train', 1)] +find_unused_parameters = True diff --git a/configs/restorers/basicvsr/basicvsr_vimeo90k_bd.py b/configs/restorers/basicvsr/basicvsr_vimeo90k_bd.py index f09980b512..e61dcecb20 100644 --- a/configs/restorers/basicvsr/basicvsr_vimeo90k_bd.py +++ b/configs/restorers/basicvsr/basicvsr_vimeo90k_bd.py @@ -155,3 +155,4 @@ load_from = None resume_from = None workflow = [('train', 1)] +find_unused_parameters = True diff --git a/configs/restorers/basicvsr/basicvsr_vimeo90k_bi.py b/configs/restorers/basicvsr/basicvsr_vimeo90k_bi.py index e58369eb42..0900398362 100644 --- a/configs/restorers/basicvsr/basicvsr_vimeo90k_bi.py +++ b/configs/restorers/basicvsr/basicvsr_vimeo90k_bi.py @@ -155,3 +155,4 @@ load_from = None resume_from = None workflow = [('train', 1)] +find_unused_parameters = True diff --git a/configs/restorers/iconvsr/iconvsr_reds4.py b/configs/restorers/iconvsr/iconvsr_reds4.py index e99336e24c..6c74f781d6 100644 --- a/configs/restorers/iconvsr/iconvsr_reds4.py +++ b/configs/restorers/iconvsr/iconvsr_reds4.py @@ -140,3 +140,4 @@ load_from = None resume_from = None workflow = [('train', 1)] +find_unused_parameters = True diff --git a/configs/restorers/iconvsr/iconvsr_vimeo90k_bd.py b/configs/restorers/iconvsr/iconvsr_vimeo90k_bd.py index e0f41d1c79..3594138a3a 100644 --- a/configs/restorers/iconvsr/iconvsr_vimeo90k_bd.py +++ b/configs/restorers/iconvsr/iconvsr_vimeo90k_bd.py @@ -159,3 +159,4 @@ load_from = None resume_from = None workflow = [('train', 1)] +find_unused_parameters = True diff --git a/configs/restorers/iconvsr/iconvsr_vimeo90k_bi.py b/configs/restorers/iconvsr/iconvsr_vimeo90k_bi.py index 738a14b2ba..97f33d0f93 100644 --- a/configs/restorers/iconvsr/iconvsr_vimeo90k_bi.py +++ b/configs/restorers/iconvsr/iconvsr_vimeo90k_bi.py @@ -159,3 +159,4 @@ load_from = None resume_from = None workflow = [('train', 1)] +find_unused_parameters = True diff --git a/mmedit/models/restorers/basicvsr.py b/mmedit/models/restorers/basicvsr.py index 3d504727ba..6f69ca0adc 100644 --- a/mmedit/models/restorers/basicvsr.py +++ b/mmedit/models/restorers/basicvsr.py @@ -39,7 +39,7 @@ def __init__(self, # fix pre-trained networks self.fix_iter = train_cfg.get('fix_iter', 0) if train_cfg else 0 - self.generator.find_unused_parameters = False + self.is_weight_fixed = False # count training steps self.register_buffer('step_counter', torch.zeros(1)) @@ -74,14 +74,13 @@ def train_step(self, data_batch, optimizer): """ # fix SPyNet and EDVR at the beginning if self.step_counter < self.fix_iter: - if not self.generator.find_unused_parameters: - self.generator.find_unused_parameters = True + if not self.is_weight_fixed: + self.is_weight_fixed = True for k, v in self.generator.named_parameters(): if 'spynet' in k or 'edvr' in k: v.requires_grad_(False) elif self.step_counter == self.fix_iter: # train all the parameters - self.generator.find_unused_parameters = False self.generator.requires_grad_(True) outputs = self(**data_batch, test_mode=False)