diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py index f58b5bf..08356dd 100755 --- a/basicsr/models/base_model.py +++ b/basicsr/models/base_model.py @@ -8,7 +8,7 @@ from basicsr.models import lr_scheduler as lr_scheduler from basicsr.utils import get_root_logger from basicsr.utils.dist_util import master_only - +from basicsr.utils.download_util import load_file_from_url class BaseModel(): """Base model.""" @@ -301,6 +301,10 @@ def load_network(self, net, load_path, strict=True, param_key='params'): None, use the root 'path'. Default: 'params'. """ + if load_path.startswith('https://'): + pretrain_model_dir = os.path.join(self.opt['root_path'], 'experiments/pretrained_models') + load_path = load_file_from_url(load_path, model_dir=pretrain_model_dir) + logger = get_root_logger() net = self.get_bare_model(net) load_net = torch.load(load_path, map_location=lambda storage, loc: storage) diff --git a/options/train_FeMaSR_LQ_stage.yml b/options/train_FeMaSR_LQ_stage.yml index 45e1c5a..abc36bf 100755 --- a/options/train_FeMaSR_LQ_stage.yml +++ b/options/train_FeMaSR_LQ_stage.yml @@ -1,5 +1,5 @@ # general settings -name: 014_FeMaSR_LQ_stage +name: 015_FeMaSR_LQ_stage # name: debug_FeMaSR model_type: FeMaSRModel scale: &upscale 4 @@ -10,11 +10,11 @@ manual_seed: 0 datasets: train: name: General_Image_Train - type: BSRGANTrainDataset - dataroot_gt: ../datasets/HQ_sub - # type: PairedImageDataset + # type: BSRGANTrainDataset # dataroot_gt: ../datasets/HQ_sub - # dataroot_lq: ../datasets/LQ_sub_X4 + type: PairedImageDataset + dataroot_gt: ../datasets/HQ_sub + dataroot_lq: ../datasets/LQ_sub_X4 io_backend: type: disk @@ -60,9 +60,9 @@ network_d: # path path: - pretrain_network_hq: ./experiments/008_FeMaSR_HQ_stage/models/net_g_best_.pth + pretrain_network_hq: https://github.com/chaofengc/FeMaSR/releases/download/v0.1-pretrain_models/FeMaSR_HRP_model_g.pth pretrain_network_g: ~ - pretrain_network_d: ./experiments/008_FeMaSR_HQ_stage/models/net_d_best_.pth + pretrain_network_d: https://github.com/chaofengc/FeMaSR/releases/download/v0.1-pretrain_models/FeMaSR_HRP_model_d.pth strict_load: false # resume_state: ~ @@ -147,6 +147,9 @@ logger: # resume_id: ~ # dist training settings -# dist_params: - # backend: nccl - # port: 16500 #29500 +dist_params: + backend: nccl + port: 16500 #29500 + +find_unused_parameters: true + diff --git a/requirements.txt b/requirements.txt index c18a82c..825701e 100755 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,5 @@ torch>=1.7 torchvision tqdm yapf -pyiqa +pyiqa>=0.1.4 einops