"""Train/test entry point for a PaddlePaddle implementation of Monodepth."""
import argparse
import os
import time
from functools import partial

import numpy as np
import paddle
from tqdm import tqdm

from model import *
from data import *
from utils import *
from loss import get_monodepth_loss


def train(params):
    """Training loop."""
    # Build the training and validation datasets/loaders.
    train_dataset = MonodepthDataset(params.data_path, params.filenames_train, params, params.dataset, params.mode, params.use_aug)
    train_loader = paddle.io.DataLoader(train_dataset, batch_size=params.batch_size, shuffle=True, num_workers=params.num_threads)
    val_dataset = MonodepthDataset(params.data_path, params.filenames_val, params, params.dataset, params.mode, False)
    val_loader = paddle.io.DataLoader(val_dataset, batch_size=params.batch_size, shuffle=False, num_workers=params.num_threads)

    model = MonodepthModel(params.encoder, params.do_stereo, params.use_deconv)
    # Optionally warm-start from a converted TensorFlow checkpoint (.h5) or a Paddle one (.pdparams).
    if params.checkpoint_path.endswith('.h5'):
        load_tensorflow_weight(model, params.checkpoint_path)
    elif params.checkpoint_path.endswith('.pdparams'):
        model.load_dict(paddle.load(params.checkpoint_path))

    num_training_samples = count_text_lines(params.filenames_train)
    num_validation_samples = count_text_lines(params.filenames_val)
    steps_per_epoch = int(np.ceil(num_training_samples / params.batch_size))
    num_total_steps = params.num_epochs * steps_per_epoch

    # Halve the learning rate at 3/5 and 4/5 of training, following the Monodepth schedule.
    lr_scheduler = paddle.optimizer.lr.MultiStepDecay(
        params.learning_rate,
        milestones=[int(3 * num_total_steps / 5), int(4 * num_total_steps / 5)],
        gamma=0.5)
    optim = paddle.optimizer.Adam(learning_rate=lr_scheduler, parameters=model.parameters())
    loss_fn = partial(get_monodepth_loss,
                      alpha_image_loss=params.alpha_image_loss,
                      disp_gradient_loss_weight=params.disp_gradient_loss_weight,
                      lr_loss_weight=params.lr_loss_weight)

    print("total number of training samples: {}".format(num_training_samples))
    print("total number of training steps: {}".format(num_total_steps))
    print("total number of validation samples: {}".format(num_validation_samples))

    total_num_parameters = 0
    for variable in model.parameters():
        total_num_parameters += np.prod(variable.shape)
    print("number of trainable parameters: {}".format(total_num_parameters))

    start_time = time.time()
    step = 0
    best_validation_loss = float('inf')
    for e in range(params.num_epochs):
        model.train()
        losses = []
        for left, right in train_loader:
            before_op_time = time.time()
            step += 1
            disp_left_est, disp_right_est = model(left, right)
            loss = loss_fn(disp_left_est, disp_right_est, left, right)
            optim.clear_grad()
            loss.backward()
            optim.step()
            lr_scheduler.step()
            duration = time.time() - before_op_time
            losses.append(float(loss))
            if step % 100 == 0:
                # Report throughput and the loss averaged over the last 100 steps.
                examples_per_sec = params.batch_size / duration
                time_sofar = (time.time() - start_time) / 3600
                training_time_left = (num_total_steps / step - 1.0) * time_sofar
                print_string = 'batch {:>6} | examples/s: {:4.2f} | loss: {:.5f} | time elapsed: {:.2f}h | time left: {:.2f}h'
                print(print_string.format(step, examples_per_sec, sum(losses) / len(losses), time_sofar, training_time_left))
                losses = []
        if (e + 1) % params.save_epochs == 0:
            paddle.save(model.state_dict(), os.path.join(params.log_directory, params.model_name, f'weight_epoch_{e}.pdparams'))
        # Validate after every epoch and keep the best checkpoint. Switching to
        # eval mode freezes batch-norm statistics during validation.
        model.eval()
        with paddle.no_grad():
            val_losses = []
            val_start_time = time.time()
            for left, right in val_loader:
                disp_left_est, disp_right_est = model(left, right)
                loss = loss_fn(disp_left_est, disp_right_est, left, right)
                val_losses.append(float(loss))
            val_cost_time = time.time() - val_start_time
            val_loss = sum(val_losses) / len(val_losses)
            print('epoch {:>3} | val loss: {:.5f} | time cost: {:.2f} s |'.format(e, val_loss, val_cost_time))
            if val_loss < best_validation_loss:
                best_validation_loss = val_loss
                paddle.save(model.state_dict(), os.path.join(params.log_directory, params.model_name, 'best_val_weight.pdparams'))


def test(params):
    """Test loop: predict disparities and save them to .npy files."""
    test_dataset = MonodepthDataset(params.data_path, params.filenames_test, params, params.dataset, params.mode, False)
    test_loader = paddle.io.DataLoader(test_dataset, batch_size=params.batch_size, shuffle=False)

    model = MonodepthModel(params.encoder, params.do_stereo, params.use_deconv)
    if params.checkpoint_path.endswith('.h5'):
        load_tensorflow_weight(model, params.checkpoint_path)
    elif params.checkpoint_path.endswith('.pdparams'):
        model.load_dict(paddle.load(params.checkpoint_path))
    model.eval()

    num_test_samples = count_text_lines(params.filenames_test)
    print('now testing {} files'.format(num_test_samples))

    disparities = []
    disparities_pp = []
    with paddle.no_grad():
        for left in tqdm(test_loader):
            # Each batch has shape (B, 2, C, H, W): the image plus its
            # horizontally flipped copy, used for disparity post-processing.
            B, _, C, H, W = left.shape
            left = left.reshape((B * 2, C, H, W))
            disp, _ = model(left)
            # Keep the finest-scale left disparity and split the predictions
            # back into (original, flipped) along the leading axis.
            disp = disp[0][:, 0]
            disp = disp.reshape((B, 2, H, W)).transpose((1, 0, 2, 3))
            disparities.append(disp[0].numpy())
            disparities_pp.append(post_process_disparity(disp.numpy()))
    disparities = np.concatenate(disparities, axis=0)
    disparities_pp = np.concatenate(disparities_pp, axis=0)
    print('done.')

    print('writing disparities.')
    if params.output_directory == '':
        output_directory = os.path.dirname(params.checkpoint_path)
    else:
        output_directory = params.output_directory
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    np.save(os.path.join(output_directory, 'disparities.npy'), disparities)
    np.save(os.path.join(output_directory, 'disparities_pp.npy'), disparities_pp)
    print('done.')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, help='random seed', default=42)
    parser.add_argument('--mode', type=str, help='train or test', default='train')
    parser.add_argument('--model_name', type=str, help='model name', default='monodepth')
    parser.add_argument('--encoder', type=str, help='type of encoder, vgg or resnet50', default='resnet')
    parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or cityscapes', default='kitti')
    parser.add_argument('--data_path', type=str, help='path to the data', default='eigen/')
    parser.add_argument('--filenames_train', type=str, help='path to the train filenames text file', default='filenames/eigen_train_files.txt')
    parser.add_argument('--filenames_val', type=str, help='path to the val filenames text file', default='filenames/eigen_val_files.txt')
    parser.add_argument('--filenames_test', type=str, help='path to the test filenames text file', default='filenames/eigen_test_files.txt')
    parser.add_argument('--use_aug', type=int, help='whether to use augmentation in data loading', default=1)
    parser.add_argument('--height', type=int, help='input height', default=256)
    parser.add_argument('--width', type=int, help='input width', default=512)
    parser.add_argument('--batch_size', type=int, help='batch size', default=8)
    parser.add_argument('--num_epochs', type=int, help='number of epochs', default=50)
    parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=1e-4)
    parser.add_argument('--lr_loss_weight', type=float, help='left-right consistency weight', default=1.0)
    parser.add_argument('--alpha_image_loss', type=float, help='weight between SSIM and L1 in the image loss', default=0.85)
    parser.add_argument('--disp_gradient_loss_weight', type=float, help='disparity smoothness weight', default=0.1)
    parser.add_argument('--do_stereo', help='if set, will train the stereo model', action='store_true')
    parser.add_argument('--wrap_mode', type=str, help='bilinear sampler wrap mode, edge or border', default='border')
    parser.add_argument('--use_deconv', help='if set, will use transposed convolutions', action='store_true')
    parser.add_argument('--num_gpus', type=int, help='number of GPUs to use for training', default=1)
    parser.add_argument('--num_threads', type=int, help='number of threads to use for data loading', default=4)
    parser.add_argument('--output_directory', type=str, help='output directory for test disparities, if empty outputs to checkpoint folder', default='')
    parser.add_argument('--log_directory', type=str, help='directory to save checkpoints and summaries', default='logs')
    parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
    parser.add_argument('--retrain', help='if used with checkpoint_path, will restart training from step zero', action='store_true')
    parser.add_argument('--full_summary', help='if set, will keep more data for each summary. Warning: the file can become very large', action='store_true')
    parser.add_argument('--save_epochs', type=int, help='save a checkpoint every this many epochs', default=5)
    params = parser.parse_args()

    setup_seed(params.seed)
    if params.mode == 'train':
        train(params)
    elif params.mode == 'test':
        test(params)


if __name__ == '__main__':
    main()
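
# Example invocations (a sketch: the eigen/ data directory and the filename
# lists referenced here follow the argparse defaults above and are assumptions
# about the local data layout, not part of the script itself):
#
#   python main.py --mode train --data_path eigen/ --batch_size 8
#   python main.py --mode test --checkpoint_path logs/monodepth/best_val_weight.pdparams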