Skip to content

Commit 588f364

Browse files
committed
gitignore updated
1 parent 87c7d75 commit 588f364

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ __pycache__/
44
*$py.class
55
.DS_Store
66

7+
# Pretrained Models
8+
/checkpoints
9+
710
# C extensions
811
*.so
912

File renamed without changes.

net/pspnet.py Net/pspnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torch import nn
33
from torch.nn import functional as F
44

5-
import extractors
5+
import Net.extractors as extractors
66

77

88
class PSPModule(nn.Module):

train.py

+33-2
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,36 @@
1111
from Datasets.lip import LIP
1212
from Datasets.lip import LIPWithClass
1313
from matplotlib import pyplot as plt
14+
from Net.pspnet import PSPNet
1415

1516

1617
# Models
def _make_psp(backend, psp_size, deep_features_size):
    """Return a zero-arg factory that builds a PSPNet for the given backbone.

    Construction is deferred so that listing the registry does not
    instantiate every network.
    """
    return lambda: PSPNet(
        sizes=(1, 2, 3, 6),
        psp_size=psp_size,
        deep_features_size=deep_features_size,
        backend=backend,
    )


# Registry mapping backbone name -> deferred PSPNet constructor.
models = {
    'squeezenet': _make_psp('squeezenet', 512, 256),
    'densenet': _make_psp('densenet', 1024, 512),
    'resnet18': _make_psp('resnet18', 512, 256),
    'resnet34': _make_psp('resnet34', 512, 256),
    'resnet50': _make_psp('resnet50', 2048, 1024),
    'resnet101': _make_psp('resnet101', 2048, 1024),
    'resnet152': _make_psp('resnet152', 2048, 1024),
}
2027

2128

29+
def build_network(snapshot, backend, gpu=False):
    """Build a DataParallel-wrapped PSPNet, optionally restoring a snapshot.

    Args:
        snapshot: Path to a saved ``state_dict`` named ``<prefix>_<epoch>``,
            or ``None`` to start from scratch.
        backend: Backbone key into the module-level ``models`` registry
            (case-insensitive).
        gpu: When True, move the network to CUDA after loading.

    Returns:
        Tuple ``(net, epoch)`` where ``epoch`` is the epoch parsed from the
        snapshot filename, or 0 when no snapshot was given.
    """
    epoch = 0
    backend = backend.lower()
    net = models[backend]()
    net = nn.DataParallel(net)
    if snapshot is not None:
        # rsplit keeps this working for prefixes that themselves contain
        # underscores (e.g. "psp_densenet_12"); the original split('_')
        # required exactly one underscore in the filename.
        epoch = int(os.path.basename(snapshot).rsplit('_', 1)[1])
        # map_location='cpu' lets checkpoints saved on GPU load on
        # CPU-only hosts; weights are moved to CUDA below when requested.
        net.load_state_dict(torch.load(snapshot, map_location='cpu'))
        # Lazy %-style args avoid formatting when INFO is disabled.
        logging.info("Snapshot for epoch %s loaded from %s", epoch, snapshot)
    if gpu:
        net = net.cuda()
    return net, epoch
2244
def get_transform():
2345
transform_image_list = [
2446
transforms.Resize((256, 256), 3),
@@ -46,6 +68,9 @@ def parse_arguments():
4668
parser.add_argument('-b', '--batch-size', help='Set size of the batch', default=32, type=int)
4769
parser.add_argument('-d', '--data-path', help='Set path of dataset', default='.', type=str)
4870
parser.add_argument('-n', '--num-class', help='Set number of segmentation classes', default=20, type=int)
71+
parser.add_argument('-be', '--backend', help='Set Feature extractor', default='densenet', type=str)
72+
parser.add_argument('-s', '--snapshot', help='Set path to pre-trained weights', default=None, type=str)
73+
parser.add_argument('-g', '--gpu', help='Set gpu [True / False]', default=False, type=bool)
4974

5075
# Mutually Exclusive Group 1 (Train / Eval)
5176
train_eval_parser = parser.add_mutually_exclusive_group(required=False)
@@ -89,14 +114,20 @@ def run_trained_model(model_ft, train_loader):
89114

90115

91116
if __name__ == '__main__':
117+
# Parse Arguments
92118
args = parse_arguments()
119+
120+
# Make directory to store trained weights
121+
models_path = os.path.join('./checkpoints', args.backend)
122+
os.makedirs(models_path, exist_ok=True)
123+
93124
train_loader = get_dataloader(args.data_path, train=args.train, batch_size=args.batch_size, num_class=args.num_class)
94125

95126
# Debug
96127
for data in train_loader:
97128
x, y, cls = data
98129
break
99130

100-
run_trained_model(models['torch_resnet50'], train_loader)
131+
# run_trained_model(models['torch_resnet50'], train_loader)
101132

102133

0 commit comments

Comments
 (0)