ddp_test_main.py
import argparse
import os

from utils.argfunc_utils import *
from ddp_builder import *

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

from ddp_engine_test import test_one_epoch
from my_model.ngswin_model.ngswin import NGswin
from my_model.swinirng import SwinIRNG
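# NOTE: the two wildcard imports above are expected to provide the helpers that
# are used below but not defined in this file: str2bool, str2tuple, str2nn_module,
# print_complextity, record_args_before_build, and record_args_after_build.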


def get_args_parser():
    parser = argparse.ArgumentParser(
        'N-Gram in Swin Transformers for Efficient Lightweight Image Super-Resolution',
        description='N-Gram in Swin Transformers for Efficient Lightweight Image Super-Resolution',
        add_help=True)
    # ddp
    parser.add_argument('--total_nodes', default=1, type=int, metavar='N')
    parser.add_argument('--gpus_per_node', default=4, type=int, help='number of gpus per node')
    parser.add_argument('--node_rank', default=0, type=int, help='rank of this node among all nodes')
    parser.add_argument('--ip_address', type=str, required=True, help='ip address of the host node')
    parser.add_argument('--backend', default='nccl', type=str, help='nccl or gloo')
    # etc
    parser.add_argument('--model_time', type=str, help='set automatically when the model is built, or manually when load_model is True')
    parser.add_argument('--task', type=str, default='lightweight_sr')
    parser.add_argument('--load_model', type=str2bool, default=False, help='whether to resume from a checkpoint epoch')
    parser.add_argument('--checkpoint_epoch', type=int, default=0, help='checkpoint epoch from which to restart training')
    # model type
    parser.add_argument('--model_name', type=str, default='NGswin', help='NGswin or SwinIR-NG')
    parser.add_argument('--target_mode', type=str, default='light_x2', help='light_x2 light_x3 light_x4')
    parser.add_argument('--scale', type=int, help="upscale factor corresponding to 'target_mode'; set automatically")
    parser.add_argument('--window_size', type=int, default=8, help='window size of (shifted) window attention')
    parser.add_argument('--training_patch_size', type=int, default=64, help='LQ image patch size; model input patch size used only during training')
    # model ngram & in-channels spec -> ignored for SwinIR-NG
    parser.add_argument('--ngrams', type=str2tuple, default=(2,2,2,2), help='ngram size around each patch or window, in embed-enc1-enc2-enc3-enc4-dec order')
    parser.add_argument('--in_chans', type=int, default=3, help='number of input image channels')
    # model encoder spec -> ignored for SwinIR-NG
    parser.add_argument('--embed_dim', type=int, default=64, help='base dimension of the model encoder')
    parser.add_argument('--depths', type=str2tuple, default=(6,4,4), help='number of transformer blocks on the encoder')
    parser.add_argument('--num_heads', type=str2tuple, default=(6,4,4), help='number of self-attention heads on the encoder')
    parser.add_argument('--head_dim', type=int, help='dimension per attention head on the encoder')
    # model decoder spec -> ignored for SwinIR-NG
    parser.add_argument('--dec_dim', type=int, default=64, help='base dimension of the model decoder')
    parser.add_argument('--dec_depths', type=int, default=6, help='number of transformer blocks on the decoder')
    parser.add_argument('--dec_num_heads', type=int, default=6, help='number of self-attention heads on the decoder')
    parser.add_argument('--dec_head_dim', type=int, help='dimension per attention head on the decoder')
    # model etc param spec -> ignored for SwinIR-NG
    parser.add_argument('--mlp_ratio', type=float, default=2.0, help='ratio of the FFN hidden dimension to the transformer dimension, on both encoder and decoder')
    parser.add_argument('--qkv_bias', type=str2bool, default=True, help='whether to use a bias for the self-attention qkv projections, on both encoder and decoder')
    # model dropout spec -> ignored for SwinIR-NG
    parser.add_argument('--drop_rate', type=float, default=0.0, help='dropout rate outside attention layers')
    parser.add_argument('--attn_drop_rate', type=float, default=0.0, help='dropout rate inside attention layers')
    parser.add_argument('--drop_path_rate', type=float, default=0.0, help='stochastic depth rate of attention and FFN layers, on both encoder and decoder')
    # model activation / norm / position-embedding spec -> ignored for SwinIR-NG
    parser.add_argument('--act_layer', type=str2nn_module, default=nn.GELU, help='activation layer')
    parser.add_argument('--norm_layer', type=str2nn_module, default=nn.LayerNorm, help='normalization layer')
    # train / test spec
    parser.add_argument('--test_only', type=str2bool, default=False, help='only evaluate the model; do not train')
    parser.add_argument('--sr_image_save', type=str2bool, default=True, help='save reconstructed images at test time')
    # dataset / dataloader spec
    parser.add_argument('--img_norm', type=str2bool, default=True, help='normalize images before they are fed to the model')  # -> ignored for SwinIR-NG
    return parser
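

# Illustrative single-node launch (values are examples, not prescribed settings;
# only --ip_address is required, everything else has a default):
#   python ddp_test_main.py --ip_address 127.0.0.1 --gpus_per_node 4 \
#       --model_name NGswin --target_mode light_x2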


def main(gpu, args):
    # global rank of this process = node offset + local gpu index
    rank = args.node_rank * args.gpus_per_node + gpu
    # rendezvous via environment variables (MASTER_ADDR / MASTER_PORT set in __main__)
    dist.init_process_group(
        backend=args.backend,
        init_method='env://',
        world_size=args.world_size,
        rank=rank
    )

    assert args.model_name in ['NGswin', 'SwinIR-NG'], "'model_name' should be NGswin or SwinIR-NG"
    if args.model_name == 'NGswin':
        model = NGswin(training_img_size=args.training_patch_size,
                       ngrams=args.ngrams,
                       in_chans=args.in_chans,
                       embed_dim=args.embed_dim,
                       depths=args.depths,
                       num_heads=args.num_heads,
                       head_dim=args.head_dim,
                       dec_dim=args.dec_dim,
                       dec_depths=args.dec_depths,
                       dec_num_heads=args.dec_num_heads,
                       dec_head_dim=args.dec_head_dim,
                       target_mode=args.target_mode,
                       window_size=args.window_size,
                       mlp_ratio=args.mlp_ratio,
                       qkv_bias=args.qkv_bias,
                       img_norm=args.img_norm,
                       drop_rate=args.drop_rate,
                       attn_drop_rate=args.attn_drop_rate,
                       drop_path_rate=args.drop_path_rate,
                       act_layer=args.act_layer,
                       norm_layer=args.norm_layer)
    else:
        model = SwinIRNG(upscale=args.scale, img_size=args.training_patch_size)
    print_complextity(model, args)

    # load pretrained weights; only attention buffers may legitimately be missing
    sd = torch.load(f'pretrain/{args.model_name}_x{args.scale}.pth', map_location='cpu')
    missings, _ = model.load_state_dict(sd, strict=False)
    for xx in missings:
        assert 'relative_position_index' in xx or 'attn_mask' in xx, f'essential key {xx} is dropped!'
    print('<All keys matched successfully>')

    model = model.to(gpu)
    model = DDP(model, device_ids=[gpu])

    if rank == 0:
        args = record_args_after_build(args)
    if args.model_time is None:
        # non-zero ranks wait for rank 0 to write the args record,
        # then recover model_time from the newest file name in 'args/'
        import time
        time.sleep(10)
        args.model_time = sorted(os.listdir('args'))[-1][5:-4]

    test_one_epoch(rank, gpu, model, 0, args)


if __name__ == '__main__':
    os.makedirs('./args/', exist_ok=True)
    os.makedirs('./logs/', exist_ok=True)

    parser = get_args_parser()
    args = parser.parse_args()
    # derive the upscale factor from target_mode (e.g. 'light_x2' -> 2)
    args.scale = int(args.target_mode[-1]) if args.target_mode[-1].isdigit() else 1
    args.world_size = args.gpus_per_node * args.total_nodes

    # rendezvous address/port for init_method='env://' (fixed port per backend)
    os.environ['MASTER_ADDR'] = args.ip_address
    os.environ['MASTER_PORT'] = '8888' if args.backend == 'nccl' else '8989'

    if args.node_rank == 0:
        args = record_args_before_build(args)

    # launch one process per local GPU; each process gets its local gpu index as the first argument of main
    mp.spawn(main, nprocs=args.gpus_per_node, args=(args,))
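

# Illustrative two-node launch (IP address and GPU counts are assumptions for the
# example; here world_size becomes total_nodes * gpus_per_node = 8):
#   node 0:  python ddp_test_main.py --total_nodes 2 --node_rank 0 --gpus_per_node 4 --ip_address 10.0.0.1
#   node 1:  python ddp_test_main.py --total_nodes 2 --node_rank 1 --gpus_per_node 4 --ip_address 10.0.0.1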