-
Notifications
You must be signed in to change notification settings - Fork 9
/
main.py
100 lines (77 loc) · 2.85 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import ast
import os.path as osp
import traceback

import numpy as np
import pyrallis

from configs import TrainConfig
from configs.prompts import get_avatar_list
from core.trainer import Trainer
def update_path(path, dirname):
    """Substitute *dirname* for every ``'@'`` placeholder in *path*.

    ``None`` is passed through unchanged so optional paths can be forwarded
    without a check at the call site; a path without ``'@'`` is returned as-is.
    """
    if path is None or '@' not in path:
        return path
    return path.replace('@', dirname)
def parse_indices(num_prompts, opts):
    """Resolve which prompts (1-based) to run from an optional selector.

    Args:
        num_prompts: Number of prompts available in the selected set.
        opts: List holding zero or one selector string (the part of
            ``text_set`` after the first comma). Accepted selectors:
            ``"3"`` (single index), ``"2-5"`` (inclusive range),
            ``"1,4,7"`` or ``"[1, 4, 7]"`` (explicit list). An empty list
            or a blank selector selects every prompt.

    Returns:
        list[int]: 1-based prompt indices, in the order given.

    Raises:
        ValueError: If the selector is malformed or any index falls outside
            ``1..num_prompts``.
    """
    selector = opts[0].strip() if opts else ''
    if not selector:
        return list(range(1, num_prompts + 1))

    if '-' in selector:
        start_i, end_i = map(int, selector.split('-'))
        prompt_indices = list(range(start_i, end_i + 1))
    else:
        # literal_eval only accepts Python literals (ints, lists, tuples),
        # unlike eval(), so selector text cannot execute arbitrary code.
        try:
            parsed = ast.literal_eval(selector)
        except (ValueError, SyntaxError) as err:
            raise ValueError('invalid prompt selector: {!r}'.format(selector)) from err
        if isinstance(parsed, int):
            prompt_indices = [parsed]
        elif isinstance(parsed, (list, tuple)):
            prompt_indices = list(parsed)
        else:
            raise ValueError('invalid prompt selector: {!r}'.format(selector))

    if not all(isinstance(i, int) for i in prompt_indices):
        raise ValueError('prompt indices must be integers: {!r}'.format(selector))
    # The caller indexes prompts[i - 1]: 0 or negatives would silently wrap
    # around to the end of the list, and too-large values would only fail
    # after earlier prompts had already been run.
    out_of_range = [i for i in prompt_indices if not 1 <= i <= num_prompts]
    if out_of_range:
        raise ValueError('prompt indices {} out of range 1..{}'.format(out_of_range, num_prompts))
    return prompt_indices
def run(cfg: TrainConfig):
    """Build a ``Trainer`` for *cfg* and run the stage selected by ``cfg.log``.

    Flags are checked in priority order: ``eval_only`` -> ``full_eval()``,
    ``pretrain_only`` -> ``pretrain()``, ``nerf2gs`` -> ``pretrain_nerf2gs()``;
    when none is set, ``train()`` runs.
    """
    trainer = Trainer(cfg)
    log_opts = cfg.log
    if log_opts.eval_only:
        stage = trainer.full_eval
    elif log_opts.pretrain_only:
        stage = trainer.pretrain
    elif log_opts.nerf2gs:
        stage = trainer.pretrain_nerf2gs
    else:
        stage = trainer.train
    stage()
def _resolve_prompt(metadata, default_age, default_gender):
    """Return ``(text, smpl_age, smpl_gender)`` for one prompt-set entry.

    A ``str`` entry is the prompt text itself and uses the defaults. A
    ``dict`` entry must contain ``'text_prompt'`` and may override
    ``'smpl_age'`` / ``'smpl_gender'``.

    Raises:
        KeyError: If a dict entry lacks ``'text_prompt'``.
        TypeError: If the entry is neither a ``str`` nor a ``dict``.
    """
    if isinstance(metadata, str):
        return metadata, default_age, default_gender
    if isinstance(metadata, dict):
        return (
            metadata['text_prompt'],
            metadata.get('smpl_age', default_age),
            metadata.get('smpl_gender', default_gender),
        )
    raise TypeError(
        'prompt entries must be str or dict, got {}'.format(type(metadata).__name__))


def run_multiple(cfg: TrainConfig):
    """Call ``run(cfg)`` once per prompt selected from a named prompt set.

    ``cfg.guide.text_set`` is ``"<set_name>"`` or ``"<set_name>,<selector>"``.
    The optional selector is parsed by ``parse_indices`` as 1-based indices
    into ``get_avatar_list(set_name)``. For each selected prompt, ``cfg`` is
    mutated in place. The prompt text and SMPL age/gender are set. Every
    ``'@'`` in ``log.exp_name``, ``render.from_nerf``, ``optim.ckpt`` and
    ``optim.ckpt_extra`` is replaced by the directory name
    ``"<index:04d>_<text, spaces -> '_', first 50 chars>"``.

    An exception raised while running one prompt is reported with its
    traceback, and the loop continues with the next prompt.

    Raises:
        ValueError: If ``cfg.log.exp_name`` has no ``'@'`` placeholder.
        TypeError: If a selected prompt entry is neither a str nor a dict.
    """
    set_name, *opts = cfg.guide.text_set.split(',', maxsplit=1)
    prompts = get_avatar_list(set_name)
    prompt_indices = parse_indices(num_prompts=len(prompts), opts=opts)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if '@' not in cfg.log.exp_name:
        raise ValueError('exp_name must contain "@" for inserting text prompts')

    # Snapshot the '@' templates and SMPL defaults before the loop below
    # overwrites these cfg fields for each prompt.
    default_cfgs = {
        'exp_name': cfg.log.exp_name,
        'from_nerf': cfg.render.from_nerf,
        'ckpt': cfg.optim.ckpt,
        'ckpt_extra': cfg.optim.ckpt_extra,
        'smpl_age': cfg.prompt.smpl_age,
        'smpl_gender': cfg.prompt.smpl_gender,
    }

    # Resolve every selected entry up front so a malformed entry fails before
    # any job starts, instead of silently reusing the previous prompt's text.
    jobs = []
    for index in prompt_indices:
        entry = prompts[index - 1]  # selector indices are 1-based
        resolved = _resolve_prompt(entry, default_cfgs['smpl_age'], default_cfgs['smpl_gender'])
        jobs.append((index,) + resolved)

    for index, text, smpl_age, smpl_gender in jobs:
        cfg.guide.text = text
        cfg.prompt.smpl_age = smpl_age
        cfg.prompt.smpl_gender = smpl_gender

        dirname = '{:04d}_{}'.format(index, text.replace(' ', '_')[:50])
        cfg.log.exp_name = update_path(default_cfgs['exp_name'], dirname)
        cfg.render.from_nerf = update_path(default_cfgs['from_nerf'], dirname)
        cfg.optim.ckpt = update_path(default_cfgs['ckpt'], dirname)
        cfg.optim.ckpt_extra = update_path(default_cfgs['ckpt_extra'], dirname)
        try:
            run(cfg)
        except Exception as e:
            # Best-effort batch: report which prompt failed, with the full
            # traceback, then continue with the remaining prompts.
            print('Prompt {} ({}) failed: {!r}'.format(index, dirname, e))
            traceback.print_exc()
@pyrallis.wrap()
def main(cfg: TrainConfig):
    """Script entry point; ``cfg`` is supplied by the ``pyrallis.wrap()`` decorator.

    Without ``cfg.guide.text_set``, a single run uses ``cfg`` as given.
    Otherwise every selected prompt of the named set is run via
    ``run_multiple``.
    """
    dispatch = run if cfg.guide.text_set is None else run_multiple
    dispatch(cfg)
if __name__ == '__main__':
    # Called without arguments: the @pyrallis.wrap() decorator on main()
    # is responsible for producing the cfg argument.
    main()