-
Notifications
You must be signed in to change notification settings - Fork 58
/
supervised.py
181 lines (133 loc) · 6.76 KB
/
supervised.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import argparse
import logging
import os
import pprint
import torch
import numpy as np
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.optim import SGD
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import yaml
from dataset.acdc import ACDCDataset
from model.unet import UNet
from util.classes import CLASSES
from util.utils import AverageMeter, count_params, init_log, DiceLoss
from util.dist_helper import setup_distributed
# Command-line interface of the supervised training script.
parser = argparse.ArgumentParser(
    description='Revisiting Weak-to-Strong Consistency in Semi-Supervised Semantic Segmentation'
)

# (flag, add_argument keyword options), registered in exactly this order.
#   --config             file that main() opens and parses as YAML
#   --labeled-id-path    forwarded by main() to the training ACDCDataset
#   --unlabeled-id-path  accepted but not read by main() in this file
#   --save-path          directory for TensorBoard logs and checkpoints
#   --local_rank         accepted but not read by main() in this file
#   --port               forwarded by main() to setup_distributed()
_CLI_OPTIONS = (
    ('--config', {'type': str, 'required': True}),
    ('--labeled-id-path', {'type': str, 'required': True}),
    ('--unlabeled-id-path', {'type': str, 'default': None}),
    ('--save-path', {'type': str, 'required': True}),
    ('--local_rank', {'default': 0, 'type': int}),
    ('--port', {'default': None, 'type': int}),
)
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)
def main():
    """Train a UNet with full supervision on the labeled split and validate every epoch.

    Steps, in order:
      1. Parse the command line and read the YAML config given by ``--config``.
      2. Initialise distributed training and wrap the model in
         ``DistributedDataParallel`` (with ``SyncBatchNorm``).
      3. Resume from ``<save_path>/latest.pth`` when that file exists.
      4. For each remaining epoch: train with the mean of cross-entropy and Dice
         loss under a polynomial learning-rate decay, then compute per-class
         Dice on the validation loader.
      5. On rank 0, log to TensorBoard and write ``latest.pth`` every epoch and
         ``best.pth`` whenever the mean Dice improves.

    Config keys read: ``nclass``, ``lr``, ``dataset``, ``data_root``,
    ``crop_size``, ``batch_size``, ``epochs``.

    Raises:
        KeyError: if ``LOCAL_RANK`` is not set in the environment or a required
            config key is missing.
    """
    args = parser.parse_args()

    # NOTE(security): yaml.Loader can construct arbitrary Python objects. That is
    # acceptable only for trusted config files; prefer yaml.safe_load if configs
    # may come from an untrusted source.
    with open(args.config, "r") as cfg_file:
        cfg = yaml.load(cfg_file, Loader=yaml.Loader)

    logger = init_log('global', logging.INFO)
    logger.propagate = 0

    rank, world_size = setup_distributed(port=args.port)

    if rank == 0:
        all_args = {**cfg, **vars(args), 'ngpus': world_size}
        logger.info('{}\n'.format(pprint.pformat(all_args)))

        os.makedirs(args.save_path, exist_ok=True)
        # The writer exists on rank 0 only; every use below is guarded by rank == 0.
        writer = SummaryWriter(args.save_path)

    cudnn.enabled = True
    cudnn.benchmark = True

    model = UNet(in_chns=1, class_num=cfg['nclass'])
    if rank == 0:
        logger.info('Total params: {:.1f}M\n'.format(count_params(model)))

    optimizer = SGD(model.parameters(), cfg['lr'], momentum=0.9, weight_decay=0.0001)

    local_rank = int(os.environ["LOCAL_RANK"])
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model.cuda(local_rank)

    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], broadcast_buffers=False,
                                                      output_device=local_rank, find_unused_parameters=False)

    criterion_ce = nn.CrossEntropyLoss()
    criterion_dice = DiceLoss(n_classes=cfg['nclass'])

    trainset = ACDCDataset(cfg['dataset'], cfg['data_root'], 'train_l', cfg['crop_size'], args.labeled_id_path)
    valset = ACDCDataset(cfg['dataset'], cfg['data_root'], 'val')

    trainsampler = torch.utils.data.distributed.DistributedSampler(trainset)
    trainloader = DataLoader(trainset, batch_size=cfg['batch_size'],
                             pin_memory=True, num_workers=1, drop_last=True, sampler=trainsampler)
    valsampler = torch.utils.data.distributed.DistributedSampler(valset)
    valloader = DataLoader(valset, batch_size=1, pin_memory=True, num_workers=1,
                           drop_last=False, sampler=valsampler)

    iters = 0
    total_iters = len(trainloader) * cfg['epochs']
    previous_best = 0.0
    epoch = -1

    latest_path = os.path.join(args.save_path, 'latest.pth')
    best_path = os.path.join(args.save_path, 'best.pth')

    if os.path.exists(latest_path):
        # Load onto the CPU first: without map_location every rank would restore
        # the tensors onto the GPU they were saved from. load_state_dict then
        # copies the values into this rank's own parameters.
        checkpoint = torch.load(latest_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = checkpoint['epoch']
        previous_best = checkpoint['previous_best']

        if rank == 0:
            logger.info('************ Load from checkpoint at epoch %i\n' % epoch)

    # Number of foreground classes; class 0 is skipped by the Dice loop below.
    num_fg_classes = cfg['nclass'] - 1

    for epoch in range(epoch + 1, cfg['epochs']):
        if rank == 0:
            logger.info('===========> Epoch: {:}, LR: {:.5f}, Previous best: {:.2f}'.format(
                epoch, optimizer.param_groups[0]['lr'], previous_best))

        model.train()
        total_loss = AverageMeter()

        # Re-seed the sampler so each epoch uses a different shuffle.
        trainsampler.set_epoch(epoch)

        log_interval = max(2, len(trainloader) // 8)

        for i, (img, mask) in enumerate(trainloader):
            img, mask = img.cuda(), mask.cuda()

            pred = model(img)

            # Supervised objective: mean of cross-entropy and Dice loss.
            loss = (criterion_ce(pred, mask) + criterion_dice(pred.softmax(dim=1), mask.unsqueeze(1).float())) / 2.0

            torch.distributed.barrier()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the scalar once: every .item() call forces a device sync.
            loss_value = loss.item()
            total_loss.update(loss_value)

            # Polynomial ("poly") decay with power 0.9 over the whole run.
            iters = epoch * len(trainloader) + i
            lr = cfg['lr'] * (1 - iters / total_iters) ** 0.9
            optimizer.param_groups[0]["lr"] = lr

            if rank == 0:
                writer.add_scalar('train/loss_all', loss_value, iters)
                writer.add_scalar('train/loss_x', loss_value, iters)

            if (i % log_interval == 0) and (rank == 0):
                logger.info('Iters: {:}, Total loss: {:.3f}'.format(i, total_loss.avg))

        model.eval()
        # One accumulator per foreground class, sized from the config rather
        # than hard-coded, so it always matches the class loop below.
        dice_class = [0.0] * num_fg_classes

        # NOTE(review): valloader is sharded by DistributedSampler and the Dice
        # sums are not reduced across ranks, so with more than one process each
        # rank scores only its own shard (rank 0's shard decides best.pth).
        # Confirm whether that is intended.
        with torch.no_grad():
            for img, mask in valloader:
                img, mask = img.cuda(), mask.cuda()

                h, w = img.shape[-2:]
                img = F.interpolate(img, (cfg['crop_size'], cfg['crop_size']), mode='bilinear', align_corners=False)

                # Swap the first two axes. Presumably a validation sample is a
                # (1, slices, H, W) volume, so slices become the batch - TODO confirm.
                img = img.permute(1, 0, 2, 3)

                pred = model(img)

                # Score at the original resolution.
                pred = F.interpolate(pred, (h, w), mode='bilinear', align_corners=False)
                pred = pred.argmax(dim=1).unsqueeze(0)

                for cls in range(1, cfg['nclass']):
                    pred_cls = pred == cls
                    mask_cls = mask == cls
                    inter = (pred_cls * mask_cls).sum().item()
                    union = pred_cls.sum().item() + mask_cls.sum().item()
                    # A class absent from both prediction and ground truth is
                    # scored as perfect agreement instead of dividing by zero.
                    dice_class[cls - 1] += (2.0 * inter / union) if union > 0 else 1.0

        dice_class = [dice * 100.0 / len(valloader) for dice in dice_class]
        mean_dice = sum(dice_class) / len(dice_class)

        if rank == 0:
            for (cls_idx, dice) in enumerate(dice_class):
                logger.info('***** Evaluation ***** >>>> Class [{:} {:}] Dice: '
                            '{:.2f}'.format(cls_idx, CLASSES[cfg['dataset']][cls_idx], dice))
            logger.info('***** Evaluation ***** >>>> MeanDice: {:.2f}\n'.format(mean_dice))

            writer.add_scalar('eval/MeanDice', mean_dice, epoch)
            for i, dice in enumerate(dice_class):
                writer.add_scalar('eval/%s_dice' % (CLASSES[cfg['dataset']][i]), dice, epoch)

        is_best = mean_dice > previous_best
        previous_best = max(mean_dice, previous_best)
        if rank == 0:
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'previous_best': previous_best,
            }
            torch.save(checkpoint, latest_path)
            if is_best:
                torch.save(checkpoint, best_path)

    if rank == 0:
        # Flush pending TensorBoard events and release the event file.
        writer.close()
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()