run_cat_ethl.py
import argparse
import os

from torch.utils.data import ConcatDataset

from cat_net.options import Options
from cat_net.models import CATModel
from cat_net.datasets import tum_rgbd
from cat_net import experiment


### COMMAND LINE ARGUMENTS ###
parser = argparse.ArgumentParser()
parser.add_argument('stage', type=str, choices=['train', 'test', 'both'])
parser.add_argument('--resume', action='store_true')
args = parser.parse_args()

resume_from_epoch = 'latest' if args.resume else None
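
# Example invocations (a sketch of typical usage based on the arguments
# defined above; the dataset/results paths in the CONFIGURATION block below
# are machine-specific and will likely need to be edited first):
#
#   python run_cat_ethl.py train            # train only
#   python run_cat_ethl.py test             # test only, using saved checkpoints
#   python run_cat_ethl.py both --resume    # resume training from 'latest', then test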

### CONFIGURATION ###
opts = Options()
# Machine-specific paths: point these at the local copy of the ETHL dataset
# and the desired output directory.
opts.data_dir = '/media/m2-drive/datasets/ethl_dataset/raw'
opts.results_dir = '/media/raid5-array/experiments/cat-net/ethl_dataset'

opts.down_levels = 7
opts.innermost_kernel_size = (3, 4)

### SET TRAINING, VALIDATION AND TEST SETS ###
syn_seqs = ['ethl1', 'ethl2']
syn_conds = ['static', 'global', 'local', 'loc_glo', 'flash']

real_seqs = ['real_sync']
real_conds = ['global', 'local', 'flash']

# The 'static' condition is used as the canonical (target) appearance.
canonical = syn_conds[0]

### ETHL/SYN to ETHL/SYN ###
# Leave-one-out over the synthetic sequences: train on one sequence under all
# conditions, validate and test on the held-out sequence.
for test_seq in syn_seqs:
    train_seqs = syn_seqs.copy()
    train_seqs.remove(test_seq)

    val_seqs = [test_seq]
    val_conds = [syn_conds[1]]

    train_data = []
    for seq in train_seqs:
        for cond in syn_conds:
            print('Train {}: {} --> {}'.format(seq, cond, canonical))
            data = tum_rgbd.TorchDataset(
                opts, seq, cond, canonical, opts.random_crop)
            train_data.append(data)
    train_data = ConcatDataset(train_data)

    val_data = []
    for seq in val_seqs:
        for cond in val_conds:
            print('Val {}: {} --> {}'.format(seq, cond, canonical))
            data = tum_rgbd.TorchDataset(
                opts, seq, cond, canonical, False)
            val_data.append(data)
    val_data = ConcatDataset(val_data)

    ### TRAIN / TEST ###
    opts.experiment_name = '{}-test'.format(test_seq)
    model = CATModel(opts)

    if args.stage == 'train' or args.stage == 'both':
        print(opts)
        opts.save_txt()
        experiment.train(opts, model, train_data, val_data,
                         opts.train_epochs, resume_from_epoch=resume_from_epoch)

    if args.stage == 'test' or args.stage == 'both':
        for cond in syn_conds:
            print('Test {}: {} --> {}'.format(test_seq, cond, canonical))
            expdir = os.path.join(opts.experiment_name, '{}-test'.format(cond))
            test_data = tum_rgbd.TorchDataset(
                opts, test_seq, cond, canonical, False)
            experiment.test(opts, model, test_data, expdir=expdir,
                            save_loss=True, save_images=True)

### ETHL/SYN TO ETHL/REAL ###
# Train on both synthetic sequences and evaluate on the real sequence.
for test_seq in real_seqs:
    train_seqs = syn_seqs

    val_seqs = [syn_seqs[0]]
    val_conds = [syn_conds[1]]

    train_data = []
    for train_seq in train_seqs:
        for cond in syn_conds:
            print('Train {}: {} --> {}'.format(train_seq, cond, canonical))
            data = tum_rgbd.TorchDataset(
                opts, train_seq, cond, canonical, opts.random_crop)
            train_data.append(data)
    train_data = ConcatDataset(train_data)

    val_data = []
    for val_seq in val_seqs:
        for cond in val_conds:
            print('Val {}: {} --> {}'.format(val_seq, cond, canonical))
            data = tum_rgbd.TorchDataset(
                opts, val_seq, cond, canonical, False)
            val_data.append(data)
    val_data = ConcatDataset(val_data)

    ### TRAIN / TEST ###
    opts.experiment_name = '{}-test'.format(test_seq)
    model = CATModel(opts)

    if args.stage == 'train' or args.stage == 'both':
        print(opts)
        opts.save_txt()
        experiment.train(opts, model, train_data, val_data,
                         opts.train_epochs, resume_from_epoch=resume_from_epoch)

    if args.stage == 'test' or args.stage == 'both':
        for cond in real_conds:
            print('Test {}: {} --> syn/{}'.format(test_seq, cond, canonical))
            expdir = os.path.join(opts.experiment_name, '{}-test'.format(cond))
            # Pass cond as both source and target: there is no 'static'
            # canonical condition for the real sequence.
            test_data = tum_rgbd.TorchDataset(
                opts, test_seq, cond, cond, False)
            experiment.test(opts, model, test_data, expdir=expdir,
                            save_loss=True, save_images=True)
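

# --------------------------------------------------------------------------
# Optional refactoring sketch (not used above): the train/validation dataset
# construction follows the same nested-loop pattern in both experiment blocks,
# so it could be collected into a small helper. 'make_concat_dataset' is a
# hypothetical name; it relies only on the tum_rgbd.TorchDataset call
# signature already used in this script.
def make_concat_dataset(opts, seqs, conds, canonical, random_crop, label):
    """Build a ConcatDataset over every (sequence, condition) pair."""
    datasets = []
    for seq in seqs:
        for cond in conds:
            print('{} {}: {} --> {}'.format(label, seq, cond, canonical))
            datasets.append(tum_rgbd.TorchDataset(
                opts, seq, cond, canonical, random_crop))
    return ConcatDataset(datasets)

# Example (equivalent to the training-set construction above):
#   train_data = make_concat_dataset(
#       opts, train_seqs, syn_conds, canonical, opts.random_crop, 'Train')
# --------------------------------------------------------------------------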