-
Notifications
You must be signed in to change notification settings - Fork 16
/
monodepth2.py
63 lines (49 loc) · 1.98 KB
/
monodepth2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from __future__ import division
import yaml
import os
import argparse
import numpy as np
import logging
from utils.std_capturing import *
from model.monodepth2_learner import MonoDepth2Learner
def _cli_train(config, output_dir, args):
    """Handle the ``train`` sub-command.

    Saves a copy of the configuration being used to
    ``<output_dir>/config.yml``, then builds a MonoDepth2Learner from the
    config's top-level keys and calls its ``train`` method with
    ``output_dir``.

    Args:
        config: parsed YAML configuration; every top-level key is forwarded
            as a keyword argument to MonoDepth2Learner.
        output_dir: run directory for this checkpoint name.
        args: parsed CLI namespace. It is unused here and only present so
            all sub-command handlers share the
            ``func(config, output_dir, args)`` call signature.
    """
    snapshot_path = os.path.join(output_dir, 'config.yml')
    with open(snapshot_path, 'w') as snapshot_file:
        yaml.dump(config, snapshot_file, default_flow_style=False)

    learner = MonoDepth2Learner(**config)
    learner.train(output_dir)
    print('Monodepth2 Training Done ...')
def _cli_test(config, output_dir, args):
    """Handle the ``test`` sub-command.

    Builds a MonoDepth2Learner from the config's top-level keys and calls
    its ``test`` method with ``output_dir``.

    Args:
        config: parsed YAML configuration; every top-level key is forwarded
            as a keyword argument to MonoDepth2Learner.
        output_dir: run directory for this checkpoint name.
        args: parsed CLI namespace. It is unused here and only present for
            the shared sub-command handler signature.
    """
    MonoDepth2Learner(**config).test(output_dir)
    print('Monodepth2 Test Done ...')
def _cli_eval(config, ckpt_name, args):
    """Handle the ``eval`` sub-command.

    Builds a MonoDepth2Learner from the config's top-level keys and calls
    its ``eval`` method with ``ckpt_name`` and the ``eval_type`` taken from
    the CLI arguments. The result of ``eval`` is discarded.

    Args:
        config: parsed YAML configuration; every top-level key is forwarded
            as a keyword argument to MonoDepth2Learner.
        ckpt_name: second positional value from the dispatcher.
            NOTE(review): the ``__main__`` block passes
            ``os.path.join(root_dir, ckpt_name)`` (the run directory) here,
            not the bare checkpoint name. Confirm that this is what
            MonoDepth2Learner.eval expects.
        args: parsed CLI namespace; only ``args.eval_type`` is read.
    """
    learner = MonoDepth2Learner(**config)
    # The learner is constructed before eval_type is read, matching the
    # original evaluation order.
    learner.eval(ckpt_name, args.eval_type)
if __name__ == '__main__':
    # Command-line entry point. It parses a sub-command (train / test / eval),
    # loads the YAML config, makes sure the run directory
    # <model.root_dir>/<ckpt_name> exists, and dispatches to the matching
    # _cli_* handler inside capture_outputs(<run dir>/log).
    # NOTE(review): capture_outputs comes from utils.std_capturing; it
    # presumably tees stdout/stderr into that log file, but confirm there.
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest='command')
    # Sub-commands are optional by default in Python 3. Without this, running
    # the script with no command leaves args.config / args.func unset and
    # crashes with AttributeError instead of printing a usage error.
    subparsers.required = True

    # Training command
    p_train = subparsers.add_parser('train')
    p_train.add_argument('config', type=str)
    p_train.add_argument('ckpt_name', type=str)
    p_train.set_defaults(func=_cli_train)

    # Test command
    p_test = subparsers.add_parser('test')
    p_test.add_argument('config', type=str)
    p_test.add_argument('ckpt_name', type=str)
    p_test.set_defaults(func=_cli_test)

    # Evaluate command
    p_eval = subparsers.add_parser('eval')
    p_eval.add_argument('config', type=str)
    p_eval.add_argument('ckpt_name', type=str)
    # nargs='?' makes the positional optional so that default='depth'
    # actually takes effect. A plain positional argument is always required
    # and silently ignores its default.
    p_eval.add_argument('eval_type', type=str, nargs='?', default='depth',
                        help='pose or depth (default: depth)')
    p_eval.set_defaults(func=_cli_eval)

    args = parser.parse_args()

    # safe_load instead of bare yaml.load: calling load() without a Loader is
    # deprecated since PyYAML 5.1 and a TypeError since PyYAML 6, and the
    # full loader can construct arbitrary Python objects from the file.
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)
    # An empty file loads as None. Reject anything that is not a mapping with
    # a usage error instead of a TypeError on config['model'] below.
    if not isinstance(config, dict):
        parser.error('config file {} must contain a YAML mapping'.format(
            args.config))

    output_dir = os.path.join(config['model']['root_dir'], args.ckpt_name)
    # makedirs also creates a missing root_dir; os.mkdir would fail on it.
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    with capture_outputs(os.path.join(output_dir, 'log')):
        logging.info('Running command {}'.format(args.command.upper()))
        args.func(config, output_dir, args)