Skip to content

Commit

Permalink
feat: support url for weight path and dist train
Browse files Browse the repository at this point in the history
add support to load url path for pretrained models

 add example option
file for distributed training
  • Loading branch information
chaofengc committed Oct 7, 2022
1 parent af0b2b8 commit 497d3ee
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
6 changes: 5 additions & 1 deletion basicsr/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from basicsr.models import lr_scheduler as lr_scheduler
from basicsr.utils import get_root_logger
from basicsr.utils.dist_util import master_only

from basicsr.utils.download_util import load_file_from_url

class BaseModel():
"""Base model."""
Expand Down Expand Up @@ -301,6 +301,10 @@ def load_network(self, net, load_path, strict=True, param_key='params'):
None, use the root 'path'.
Default: 'params'.
"""
if load_path.startswith('https://'):
pretrain_model_dir = os.path.join(self.opt['root_path'], 'experiments/pretrained_models')
load_path = load_file_from_url(load_path, model_dir=pretrain_model_dir)

logger = get_root_logger()
net = self.get_bare_model(net)
load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
Expand Down
23 changes: 13 additions & 10 deletions options/train_FeMaSR_LQ_stage.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# general settings
name: 014_FeMaSR_LQ_stage
name: 015_FeMaSR_LQ_stage
# name: debug_FeMaSR
model_type: FeMaSRModel
scale: &upscale 4
Expand All @@ -10,11 +10,11 @@ manual_seed: 0
datasets:
train:
name: General_Image_Train
type: BSRGANTrainDataset
dataroot_gt: ../datasets/HQ_sub
# type: PairedImageDataset
# type: BSRGANTrainDataset
# dataroot_gt: ../datasets/HQ_sub
# dataroot_lq: ../datasets/LQ_sub_X4
type: PairedImageDataset
dataroot_gt: ../datasets/HQ_sub
dataroot_lq: ../datasets/LQ_sub_X4
io_backend:
type: disk

Expand Down Expand Up @@ -60,9 +60,9 @@ network_d:

# path
path:
pretrain_network_hq: ./experiments/008_FeMaSR_HQ_stage/models/net_g_best_.pth
pretrain_network_hq: https://github.com/chaofengc/FeMaSR/releases/download/v0.1-pretrain_models/FeMaSR_HRP_model_g.pth
pretrain_network_g: ~
pretrain_network_d: ./experiments/008_FeMaSR_HQ_stage/models/net_d_best_.pth
pretrain_network_d: https://github.com/chaofengc/FeMaSR/releases/download/v0.1-pretrain_models/FeMaSR_HRP_model_d.pth
strict_load: false
# resume_state: ~

Expand Down Expand Up @@ -147,6 +147,9 @@ logger:
# resume_id: ~

# dist training settings
# dist_params:
# backend: nccl
# port: 16500 #29500
dist_params:
backend: nccl
port: 16500 #29500

find_unused_parameters: true

2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,5 @@ torch>=1.7
torchvision
tqdm
yapf
pyiqa
pyiqa>=0.1.4
einops

0 comments on commit 497d3ee

Please sign in to comment.