 # Models
 models = {
-    'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet'),
-    'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet'),
-    'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'),
-    'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'),
-    'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'),
-    'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'),
-    'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152')
+    'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet', n_classes=20),
+    'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet', n_classes=20),
+    'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18', n_classes=20),
+    'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34', n_classes=20),
+    'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50', n_classes=20),
+    'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101', n_classes=20),
+    'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152', n_classes=20)
 }
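The training code added further down calls a build_network(snapshot, backend) helper that this diff does not show. A rough sketch only, assuming the helper looks the backend up in the models dict above and that snapshot filenames follow the 'PSPNet_<epoch>' pattern used by the torch.save calls at the end of this commit:

    def build_network(snapshot, backend):
        # Instantiate the PSPNet variant registered for this backend
        net = models[backend]()
        starting_epoch = 0
        if snapshot is not None:
            # Checkpoints below are saved as 'PSPNet_<epoch>', so the
            # starting epoch is assumed recoverable from the filename
            starting_epoch = int(os.path.basename(snapshot).split('_')[-1])
            net.load_state_dict(torch.load(snapshot))
        return net, starting_epoch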
@@ -64,13 +64,16 @@ def parse_arguments():
     parser = argparse.ArgumentParser(description='Human Parsing')
 
     # Add more arguments based on requirements later
-    parser.add_argument('-e', '--epochs', help='Set number of train epochs', default=100, type=int)
+    parser.add_argument('-e', '--epochs', help='Set number of train epochs', default=30, type=int)
     parser.add_argument('-b', '--batch-size', help='Set size of the batch', default=32, type=int)
     parser.add_argument('-d', '--data-path', help='Set path of dataset', default='.', type=str)
     parser.add_argument('-n', '--num-class', help='Set number of segmentation classes', default=20, type=int)
     parser.add_argument('-be', '--backend', help='Set Feature extractor', default='densenet', type=str)
     parser.add_argument('-s', '--snapshot', help='Set path to pre-trained weights', default=None, type=str)
-    parser.add_argument('-g', '--gpu', help='Set gpu [True / False]', default=False, type=bool)
+    parser.add_argument('-g', '--gpu', help='Run on GPU', default=False, action='store_true')
+    parser.add_argument('-lr', '--start-lr', help='Set starting learning rate', default=0.001, type=float)
+    parser.add_argument('-a', '--alpha', help='Set coefficient for classification loss term', default=1.0, type=float)
+    parser.add_argument('-m', '--milestones', help='Comma-separated epoch milestones for LR decay', default='10,20,30', type=str)
 
     # Mutually Exclusive Group 1 (Train / Eval)
     train_eval_parser = parser.add_mutually_exclusive_group(required=False)
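With these flags, a training run could be launched roughly as follows (the script name train.py and the --train switch from the mutually exclusive group are assumptions, since neither appears verbatim in this hunk):

    python train.py --train -d ./data -be resnet50 -b 16 -e 30 -lr 0.001 -m 10,20,30 -g

The milestones string feeds the MultiStepLR scheduler set up in the next hunk; MultiStepLR multiplies the learning rate by its gamma factor (0.1 by default) at each listed epoch, so the default '10,20,30' takes the starting rate of 0.001 down to 0.0001 at epoch 10 and 0.00001 at epoch 20.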
@@ -121,13 +124,44 @@ def run_trained_model(model_ft, train_loader):
     models_path = os.path.join('./checkpoints', args.backend)
     os.makedirs(models_path, exist_ok=True)
 
-    train_loader = get_dataloader(args.data_path, train=args.train, batch_size=args.batch_size, num_class=args.num_class)
-
-    # Debug
-    for data in train_loader:
-        x, y, cls = data
-        break
-
-    # run_trained_model(models['torch_resnet50'], train_loader)
+    net, starting_epoch = build_network(args.snapshot, args.backend)
+    optimizer = optim.Adam(net.parameters(), lr=args.start_lr)
+    scheduler = MultiStepLR(optimizer, milestones=[int(x) for x in args.milestones.split(',')])
 
+    train_loader = get_dataloader(args.data_path, train=args.train, batch_size=args.batch_size, num_class=args.num_class)
 
+    for epoch in range(1 + starting_epoch, 1 + starting_epoch + args.epochs):
+        seg_criterion = nn.NLLLoss(weight=None)
+        cls_criterion = nn.BCEWithLogitsLoss(weight=None)
+        epoch_losses = []
+        net.train()
+
+        for count, (img, gt, gt_cls) in enumerate(train_loader):
+            # Move the batch to the GPU when requested
+            if args.gpu:
+                img, gt, gt_cls = img.cuda(), gt.cuda(), gt_cls.cuda()
+
+            gt, gt_cls = gt.long(), gt_cls.float()
+
+            # Forward pass: per-pixel outputs and image-level class logits
+            out, out_cls = net(img)
+            seg_loss, cls_loss = seg_criterion(out, gt), cls_criterion(out_cls, gt_cls)
+            loss = seg_loss + args.alpha * cls_loss
+
+            # Backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # Log
+            epoch_losses.append(loss.item())
+            status = '[{0}] step = {1}/{2}, loss = {3:0.4f}, avg = {4:0.4f}, LR = {5:0.7f}'.format(
+                epoch, count, len(train_loader),
+                loss.item(), np.mean(epoch_losses), scheduler.get_lr()[0])
+            print(status)
+
+        scheduler.step()
+        if epoch % 10 == 0:
+            torch.save(net.state_dict(), os.path.join(models_path, '_'.join(['PSPNet', str(epoch)])))
+
+    torch.save(net.state_dict(), os.path.join(models_path, '_'.join(['PSPNet', 'last'])))
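Since only the state_dict is saved, restoring the final checkpoint for evaluation would look roughly like this (a sketch; the load site is not part of this diff):

    # Rebuild the same architecture, then load the trained weights
    net = models[args.backend]()
    net.load_state_dict(torch.load(os.path.join('./checkpoints', args.backend, 'PSPNet_last')))
    net.eval()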