-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
66 lines (53 loc) · 2.49 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import tqdm
from torch.utils.data import DataLoader
from dataset.Dataset import TestDataset
from config.config import DefaultConfig
import torch.nn.functional as F
import os
import imageio
from model.idea1 import CBAMUnet
def generate_model(args):
    """Build the network selected by ``args.net_work`` and load its checkpoint.

    Args:
        args: Configuration object. Reads ``net_work`` (key into the model
            registry), ``num_classes`` (forwarded to the network as
            ``out_planes``) and ``pretrained_model_path`` (checkpoint file
            whose ``'state_dict'`` entry holds the weights).

    Returns:
        The network wrapped in ``torch.nn.DataParallel`` with the checkpoint
        weights loaded.

    Raises:
        KeyError: If ``args.net_work`` is not a registered model name or the
            checkpoint has no ``'state_dict'`` entry.
    """
    # Map names to constructors so that only the requested network is built.
    model_builders = {
        'BaseNet': lambda: CBAMUnet(out_planes=args.num_classes),
    }
    model = model_builders[args.net_work]()

    # The state dict is loaded into the DataParallel wrapper, so the wrapper
    # has to exist before loading (its parameter names carry the wrapper's
    # prefix; presumably the checkpoint was saved from a wrapped model).
    model = torch.nn.DataParallel(model)

    print("=> loading pretrained model '{}'".format(args.pretrained_model_path))
    # map_location='cpu' lets a checkpoint saved from GPU tensors be read on a
    # machine without CUDA; load_state_dict then copies the values into the
    # model's own parameters, whatever device those are on.
    checkpoint = torch.load(args.pretrained_model_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    return model
def test():
    """Run the pretrained model over every test set and save its predictions.

    For each dataset name in ``args.testdataset`` the samples under
    ``<args.data>/<dataset>`` are fed through the model one at a time. Each
    prediction is resized to the ground-truth resolution, passed through a
    sigmoid, min-max normalised and written as an 8-bit image to
    ``<args.data>/<dataset>/output/<name>``.
    """
    print('loading test data......')
    args = DefaultConfig()
    model = generate_model(args)

    use_cuda = torch.cuda.is_available() and args.use_gpu
    if use_cuda:
        # Inputs are moved to the GPU below, so the weights must be there too.
        model = model.cuda()
    elif isinstance(model, torch.nn.DataParallel):
        # DataParallel rejects CPU-resident weights whenever CUDA devices are
        # visible, so run the wrapped module directly for CPU inference.
        model = model.module
    model.eval()

    bar_format = '{desc:<30}{percentage:3.0f}%|{bar:50}{r_bar}'
    for dataset in tqdm.tqdm(args.testdataset, desc='Total TestSet',
                             total=len(args.testdataset), position=0,
                             bar_format=bar_format):
        dataset_path = os.path.join(args.data, dataset)
        dataset_test = TestDataset(
            dataset_path,
            scale=(args.crop_height, args.crop_width),
            mode='val',
        )
        dataloader_test = DataLoader(
            dataset_test,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            # Pinned host memory only helps when batches are copied to a GPU.
            pin_memory=use_cuda,
            drop_last=False,
        )

        # Predictions are stored next to the dataset, in an "output" folder.
        save_path = os.path.join(dataset_path, 'output')
        os.makedirs(save_path, exist_ok=True)

        with torch.no_grad():
            for img, gt, name in tqdm.tqdm(dataloader_test,
                                           desc=dataset + ' - Test',
                                           total=len(dataloader_test),
                                           position=1, leave=False,
                                           bar_format=bar_format):
                if use_cuda:
                    img = img.cuda()
                output = model(img)

                # Only the spatial size of the ground truth is needed, so gt
                # itself never has to leave the CPU.
                out = F.interpolate(output, size=gt.shape[2:],
                                    mode='bilinear', align_corners=False)
                out = out.sigmoid().cpu().numpy().squeeze()

                # Stretch to [0, 1]; the epsilon guards against a constant map.
                out = (out - out.min()) / (out.max() - out.min() + 1e-8)
                # Convert to 8-bit explicitly instead of relying on the image
                # writer's implicit handling of float arrays.
                out = (out * 255).round().astype('uint8')

                # With batch_size=1 `name` is expected to be a one-element
                # batch of file names; joining it yields that single name.
                imageio.imwrite(os.path.join(save_path, ''.join(name)), out)
if __name__ == "__main__":
    # Script entry point: run inference over all test sets, then report
    # completion on stdout.
    test()
    print("Done")