train.py
import os, sys
from opt import get_opts
import torch
# system
from system import NeRFSystem, NeRF3DSystem, NeRF3DSystem_ib, EG3DSystem
# pytorch-lightning
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TestTubeLogger
from pytorch_lightning.plugins import DDPPlugin
import torch.utils.tensorboard as tensorboard
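# Entry point: choose the training system from hparams.mode, set up checkpointing
# and logging, then fit with a pytorch-lightning Trainer.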
if __name__ == '__main__':
    hparams = get_opts()

    if hparams.mode == 'd3':
        print("Use NeRF_3D")
        system = NeRF3DSystem(hparams)
    elif hparams.mode == "d3_ib":
        print("Use NeRF_3D Img Batch")
        system = NeRF3DSystem_ib(hparams)
    elif hparams.mode == 'eg3d':
        system = EG3DSystem(hparams)
    else:
        system = NeRFSystem(hparams)
    # Save checkpoints under ckpts/<exp_name>, one file per epoch, keeping the
    # 100 best by validation loss (the epoch pattern belongs in filename, not dirpath).
    checkpoint_callback = ModelCheckpoint(dirpath=os.path.join('ckpts', hparams.exp_name),
                                          filename='{epoch:d}',
                                          monitor='val/loss',
                                          mode='min',
                                          save_top_k=100)
    logger = TestTubeLogger(
        save_dir="logs",
        name=hparams.exp_name,
        debug=False,
        create_git_tag=False
    )
    # All Trainer options are shared; only precision=16 is added for mixed-precision runs.
    trainer_kwargs = dict(max_epochs=hparams.num_epochs,
                          checkpoint_callback=checkpoint_callback,
                          resume_from_checkpoint=hparams.ckpt_path,
                          logger=logger,
                          weights_summary=None,
                          progress_bar_refresh_rate=1,
                          gpus=hparams.num_gpus,
                          distributed_backend='ddp' if hparams.num_gpus > 1 else None,
                          plugins=DDPPlugin(find_unused_parameters=True),
                          num_sanity_val_steps=1,
                          benchmark=True,
                          profiler=hparams.num_gpus == 1)
    if hparams.is_use_mixed_precision:
        trainer_kwargs['precision'] = 16
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(system)