Skip to content

Commit 53fdb11

Browse files
committed
Models Integrated
1 parent 9c5b2e7 commit 53fdb11

File tree

3 files changed

+68
-29
lines changed

3 files changed

+68
-29
lines changed

Net/extractors.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,13 @@ def forward(self, x):
162162
class _DenseLayer(nn.Sequential):
163163
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
164164
super(_DenseLayer, self).__init__()
165-
self.add_module('norm.1', nn.BatchNorm2d(num_input_features)),
166-
self.add_module('relu.1', nn.ReLU(inplace=True)),
167-
self.add_module('conv.1', nn.Conv2d(num_input_features, bn_size *
165+
self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
166+
self.add_module('relu1', nn.ReLU(inplace=True)),
167+
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
168168
growth_rate, kernel_size=1, stride=1, bias=False)),
169-
self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
170-
self.add_module('relu.2', nn.ReLU(inplace=True)),
171-
self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate,
169+
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
170+
self.add_module('relu2', nn.ReLU(inplace=True)),
171+
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
172172
kernel_size=3, stride=1, padding=1, bias=False)),
173173
self.drop_rate = drop_rate
174174

@@ -332,33 +332,38 @@ def densenet(pretrained=True):
332332
def resnet18(pretrained=True):
333333
model = ResNet(BasicBlock, [2, 2, 2, 2])
334334
if pretrained:
335-
load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))
335+
pass
336+
# load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']))
336337
return model
337338

338339

339340
def resnet34(pretrained=True):
340341
model = ResNet(BasicBlock, [3, 4, 6, 3])
341342
if pretrained:
342-
load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34']))
343+
pass
344+
# load_weights_sequential(model, model_zoo.load_url(model_urls['resnet34']))
343345
return model
344346

345347

346348
def resnet50(pretrained=True):
347349
model = ResNet(Bottleneck, [3, 4, 6, 3])
348350
if pretrained:
349-
load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))
351+
pass
352+
# load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))
350353
return model
351354

352355

353356
def resnet101(pretrained=True):
354357
model = ResNet(Bottleneck, [3, 4, 23, 3])
355358
if pretrained:
356-
load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101']))
359+
pass
360+
# load_weights_sequential(model, model_zoo.load_url(model_urls['resnet101']))
357361
return model
358362

359363

360364
def resnet152(pretrained=True):
361365
model = ResNet(Bottleneck, [3, 8, 36, 3])
362366
if pretrained:
363-
load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152']))
367+
pass
368+
# load_weights_sequential(model, model_zoo.load_url(model_urls['resnet152']))
364369
return model

Net/pspnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def forward(self, x):
4040

4141

4242
class PSPNet(nn.Module):
43-
def __init__(self, n_classes=18, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet34',
43+
def __init__(self, n_classes=20, sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet34',
4444
pretrained=True):
4545
super().__init__()
4646
self.feats = getattr(extractors, backend)(pretrained)

train.py

+51-17
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616

1717
# Models
1818
models = {
19-
'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet'),
20-
'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet'),
21-
'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'),
22-
'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'),
23-
'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'),
24-
'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'),
25-
'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152')
19+
'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet', n_classes=20),
20+
'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet', n_classes=20),
21+
'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18', n_classes=20),
22+
'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34', n_classes=20),
23+
'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50', n_classes=20),
24+
'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101', n_classes=20),
25+
'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152', n_classes=20)
2626
}
2727

2828

@@ -64,13 +64,16 @@ def parse_arguments():
6464
parser = argparse.ArgumentParser(description='Human Parsing')
6565

6666
# Add more arguments based on requirements later
67-
parser.add_argument('-e', '--epochs', help='Set number of train epochs', default=100, type=int)
67+
parser.add_argument('-e', '--epochs', help='Set number of train epochs', default=30, type=int)
6868
parser.add_argument('-b', '--batch-size', help='Set size of the batch', default=32, type=int)
6969
parser.add_argument('-d', '--data-path', help='Set path of dataset', default='.', type=str)
7070
parser.add_argument('-n', '--num-class', help='Set number of segmentation classes', default=20, type=int)
7171
parser.add_argument('-be', '--backend', help='Set Feature extractor', default='densenet', type=str)
7272
parser.add_argument('-s', '--snapshot', help='Set path to pre-trained weights', default=None, type=str)
73-
parser.add_argument('-g', '--gpu', help='Set gpu [True / False]', default=False, type=bool)
73+
parser.add_argument('-g', '--gpu', help='Set gpu [True / False]', default=False, action='store_true')
74+
parser.add_argument('-lr', '--start-lr', help='Set starting learning rate', default=0.001, type=float)
75+
parser.add_argument('-a', '--alpha', help='Set coefficient for classification loss term', default=1.0, type=float)
76+
parser.add_argument('-m', '--milestones', type=str, default='10,20,30', help='Milestones for LR decreasing')
7477

7578
# Mutually Exclusive Group 1 (Train / Eval)
7679
train_eval_parser = parser.add_mutually_exclusive_group(required=False)
@@ -121,13 +124,44 @@ def run_trained_model(model_ft, train_loader):
121124
models_path = os.path.join('./checkpoints', args.backend)
122125
os.makedirs(models_path, exist_ok=True)
123126

124-
train_loader = get_dataloader(args.data_path, train=args.train, batch_size=args.batch_size, num_class=args.num_class)
125-
126-
# Debug
127-
for data in train_loader:
128-
x, y, cls = data
129-
break
130-
131-
# run_trained_model(models['torch_resnet50'], train_loader)
127+
net, starting_epoch = build_network(args.snapshot, args.backend)
128+
optimizer = optim.Adam(net.parameters(), lr=args.start_lr)
129+
scheduler = MultiStepLR(optimizer, milestones=[int(x) for x in args.milestones.split(',')])
132130

131+
train_loader = get_dataloader(args.data_path, train=args.train, batch_size=args.batch_size, num_class=args.num_class)
133132

133+
for epoch in range(1+starting_epoch, 1+starting_epoch+args.epochs):
134+
seg_criterion = nn.NLLLoss(weight=None)
135+
cls_criterion = nn.BCEWithLogitsLoss(weight=None)
136+
epoch_losses = []
137+
net.train()
138+
139+
for count, (img, gt, gt_cls) in enumerate(train_loader):
140+
# Input data
141+
if args.gpu:
142+
img, gt, gt_cls = img.cuda(), gt.cuda(), gt_cls.cuda()
143+
144+
img, gt, gt_cls = img, gt.long(), gt_cls.float()
145+
146+
# Forward pass
147+
out, out_cls = net(img)
148+
seg_loss, cls_loss = seg_criterion(out, gt), cls_criterion(out_cls, gt_cls)
149+
loss = seg_loss + args.alpha * cls_loss
150+
151+
# Backward
152+
optimizer.zero_grad()
153+
loss.backward()
154+
optimizer.step()
155+
156+
# Log
157+
epoch_losses.append(loss.item())
158+
status = '[{0}] step = {1}/{2}, loss = {3:0.4f} avg = {4:0.4f}, LR = {5:0.7f}'.format(
159+
epoch, count, len(train_loader),
160+
loss.item(), np.mean(epoch_losses), scheduler.get_lr()[0])
161+
print(status)
162+
163+
scheduler.step()
164+
if epoch % 10 == 0:
165+
torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", str(epoch)])))
166+
167+
torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", 'last'])))

0 commit comments

Comments (0)