11 | 11 | from Datasets.lip import LIP
12 | 12 | from Datasets.lip import LIPWithClass
13 | 13 | from matplotlib import pyplot as plt
   | 14 | +from Net.pspnet import PSPNet
14 | 15 |
15 | 16 |
16 | 17 | # Models
17 | 18 | models = {
18 |    | -    'torch_resnet50': models.segmentation.fcn_resnet50(pretrained=True, progress=True, num_classes=21).eval()
   | 19 | +    'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet'),
   | 20 | +    'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet'),
   | 21 | +    'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'),
   | 22 | +    'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'),
   | 23 | +    'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'),
   | 24 | +    'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'),
   | 25 | +    'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152')
19 | 26 | }
20 | 27 |
21 | 28 |
   | 29 | +def build_network(snapshot, backend, gpu=False):
   | 30 | +    epoch = 0
   | 31 | +    backend = backend.lower()
   | 32 | +    net = models[backend]()
   | 33 | +    net = nn.DataParallel(net)
   | 34 | +    if snapshot is not None:
   | 35 | +        _, epoch = os.path.basename(snapshot).split('_')
   | 36 | +        epoch = int(epoch)
   | 37 | +        net.load_state_dict(torch.load(snapshot))
   | 38 | +        logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot))
   | 39 | +    if gpu:
   | 40 | +        net = net.cuda()
   | 41 | +    return net, epoch
   | 42 | +
   | 43 | +
22 | 44 | def get_transform():
23 | 45 |     transform_image_list = [
24 | 46 |         transforms.Resize((256, 256), 3),
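A minimal usage sketch for the new build_network helper (not part of the diff): it assumes torch, torch.nn as nn, os, and logging are already imported in the unchanged top of the file, and that snapshot filenames follow the "<name>_<epoch>" pattern implied by the split('_') above. The checkpoint path below is a hypothetical example.

    # Hypothetical call; loads a snapshot named 'PSPNet_5' and resumes at epoch 5.
    net, starting_epoch = build_network(snapshot='./checkpoints/densenet/PSPNet_5',
                                        backend='densenet',
                                        gpu=True)
    net.train()  # the returned model is already wrapped in nn.DataParallel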
@@ -46,6 +68,9 @@ def parse_arguments():
46 | 68 |     parser.add_argument('-b', '--batch-size', help='Set size of the batch', default=32, type=int)
47 | 69 |     parser.add_argument('-d', '--data-path', help='Set path of dataset', default='.', type=str)
48 | 70 |     parser.add_argument('-n', '--num-class', help='Set number of segmentation classes', default=20, type=int)
   | 71 | +    parser.add_argument('-be', '--backend', help='Set feature extractor', default='densenet', type=str)
   | 72 | +    parser.add_argument('-s', '--snapshot', help='Set path to pre-trained weights', default=None, type=str)
   | 73 | +    parser.add_argument('-g', '--gpu', help='Use GPU if available', action='store_true')
49 | 74 |
50 | 75 |     # Mutually Exclusive Group 1 (Train / Eval)
51 | 76 |     train_eval_parser = parser.add_mutually_exclusive_group(required=False)
@@ -89,14 +114,20 @@ def run_trained_model(model_ft, train_loader):
89 | 114 |
90 | 115 |
91 | 116 | if __name__ == '__main__':
   | 117 | +    # Parse Arguments
92 | 118 |     args = parse_arguments()
   | 119 | +
   | 120 | +    # Make directory to store trained weights
   | 121 | +    models_path = os.path.join('./checkpoints', args.backend)
   | 122 | +    os.makedirs(models_path, exist_ok=True)
   | 123 | +
93 | 124 |     train_loader = get_dataloader(args.data_path, train=args.train, batch_size=args.batch_size, num_class=args.num_class)
94 | 125 |
95 | 126 |     # Debug
96 | 127 |     for data in train_loader:
97 | 128 |         x, y, cls = data
98 | 129 |         break
99 | 130 |
100 |     | -    run_trained_model(models['torch_resnet50'], train_loader)
    | 131 | +    # run_trained_model(models['torch_resnet50'], train_loader)
101 | 132 |
102 | 133 |
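Not shown in this diff: a sketch (an assumption, not the author's code) of how a checkpoint compatible with build_network's "<name>_<epoch>" parsing could be written into the new models_path directory. The names net and epoch are placeholders, and torch is assumed to be imported earlier in the file.

    # Hypothetical save step; build_network() parses the epoch after the '_'.
    torch.save(net.state_dict(), os.path.join(models_path, 'PSPNet_{}'.format(epoch)))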