-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
100 lines (73 loc) · 3.15 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from os import path
from datetime import datetime
import sys
import json
import torch
class Names:
    """Container for the folder and file names used by this module.

    Every attribute is a plain string set once in ``__init__``.
    """

    def __init__(self):
        # Sub-folders found inside each dataset split (see get_dataset_dirs).
        self.images = 'images'
        self.ground_truths = 'ground_truth'
        self.densities = 'densities'

        # Split folders found inside each dataset directory.
        self.train_data = 'train_data'
        self.test_data = 'test_data'

        # Presumably the names of the supported dataset directories -- confirm with callers.
        self.shanghaitech_A = 'shanghaitech_A'
        self.shanghaitech_B = 'shanghaitech_B'

        # Folder, next to this module, that holds all datasets.
        self.datasets = 'datasets'

        # Output artifact names; only args_files is used in this module (log_args).
        self.output_folder = 'output'
        self.log_file = 'log'
        self.model_file = 'model.pth'
        self.args_files = 'args.json'
# Shared module-level instance; get_dataset_dirs and log_args read their names from it.
names = Names()
def get_dataset_dirs(dataset_name):
    """Build the directory layout of a dataset stored next to this module.

    Args:
        dataset_name: Name of the dataset folder inside the datasets folder.

    Returns:
        A dict keyed by the train and test split names. Each value is a list
        ``[images_dir, ground_truth_dir, densities_dir]`` for that split.
    """
    dataset_root = path.join(path.dirname(__file__), names.datasets, dataset_name)
    # The order of the sub-folders fixes the order of the returned lists.
    subfolders = (names.images, names.ground_truths, names.densities)
    return {
        split: [path.join(dataset_root, split, subfolder) for subfolder in subfolders]
        for split in (names.train_data, names.test_data)
    }
class Logger(object):
    """Tee-like stream that mirrors everything written to it into a log file.

    Args:
        path: Path of the log file; it is opened in append mode with line
            buffering.
        out_type: The stream to mirror (redirect_output passes sys.stdout
            and sys.stderr).
    """

    def __init__(self, path, out_type):
        self.terminal = out_type
        self.log = open(path, "a", buffering=1)

    def write(self, message):
        """Write ``message`` to the wrapped stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)
        # Flush immediately so the log on disk stays current even if the
        # process is killed.
        self.log.flush()

    def flush(self):
        """Flush the wrapped stream and the log file.

        Needed so that ``print(..., flush=True)`` and ``sys.stdout.flush()``
        actually reach the underlying stream instead of being dropped.
        """
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; the wrapped stream is not owned and stays open."""
        if not self.log.closed:
            self.log.close()

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Forward stream attributes
        # such as isatty/encoding/fileno to the wrapped stream so code that
        # probes sys.stdout keeps working after redirection.
        if name in ("terminal", "log"):
            # Avoid infinite recursion when the instance is not initialised.
            raise AttributeError(name)
        return getattr(self.terminal, name)
def redirect_output(path):
    """Replace sys.stdout and sys.stderr with Loggers that also append to ``path``.

    Args:
        path: Path of the log file shared by both streams.
    """
    # stdout is wrapped first, then stderr; each Logger keeps the stream it replaces.
    for stream_name in ('stdout', 'stderr'):
        current_stream = getattr(sys, stream_name)
        setattr(sys, stream_name, Logger(path, current_stream))
def get_root_path():
    """Return the directory containing this module, with symlinks resolved."""
    module_file = path.realpath(__file__)
    return path.dirname(module_file)
def get_time():
    """Return the current local time as ``DD_MM_YYYY__HH_MM_SS_mmm``."""
    timestamp = datetime.today().strftime('%d_%m_%Y__%H_%M_%S_%f')
    # %f yields six microsecond digits; dropping the last three leaves milliseconds.
    return timestamp[:-3]
def get_shanghai_image_name(sample):
    """Return the image file name ``IMG_<sample>.jpg`` for the given sample."""
    name_template = 'IMG_{}.jpg'
    return name_template.format(sample)
def get_shanghai_gt_name(sample_idx):
    """Return the ground-truth file name ``GT_IMG_<sample_idx>.mat``."""
    name_template = 'GT_IMG_{}.mat'
    return name_template.format(sample_idx)
def get_density_name(sample_idx):
    """Return the density-map file name ``IMG_<sample_idx>.npy``."""
    name_template = 'IMG_{}.npy'
    return name_template.format(sample_idx)
def log_args(output_dir, args):
    """Dump the attributes of ``args`` as indented JSON into ``output_dir``.

    Args:
        output_dir: Directory in which the arguments file is written.
        args: Object whose ``__dict__`` is serialised; presumably an
            argparse namespace -- confirm with callers.
    """
    args_path = path.join(output_dir, names.args_files)
    with open(args_path, 'w') as args_file:
        json.dump(vars(args), args_file, indent=4, separators=(',', ': '))
def load_checkpoint(model, optimizer, filename='checkpoint.pth.tar', map_location=None):
    """Restore model and optimizer state from a checkpoint file, if it exists.

    The checkpoint must be a dict with the keys ``'epoch'``,
    ``'state_dict'`` and ``'optimizer'``.

    Args:
        model: Object whose ``load_state_dict`` receives ``checkpoint['state_dict']``.
        optimizer: Object whose ``load_state_dict`` receives ``checkpoint['optimizer']``.
        filename: Path of the checkpoint file.
        map_location: Forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
            restore a checkpoint saved on a GPU on a machine without one.
            The default ``None`` keeps torch's default device mapping.

    Returns:
        Tuple ``(model, optimizer, start_epoch)``. ``start_epoch`` is the
        epoch stored in the checkpoint, or 0 when no checkpoint file exists
        (in which case model and optimizer are returned untouched).
    """
    if not path.isfile(filename):
        print("=> no checkpoint found at '{}'".format(filename))
        return model, optimizer, 0

    print("=> loading checkpoint '{}'".format(filename))
    # NOTE: torch.load unpickles the file, so only load checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(filename, map_location=map_location)
    start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("=> loaded checkpoint '{}' (epoch {})"
          .format(filename, start_epoch))
    return model, optimizer, start_epoch