-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlaunch.py
58 lines (50 loc) · 1.87 KB
/
launch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import warnings
import argparse
import sys
from pointrix.utils.config import load_config
from pointrix.engine.default_trainer import DefaultTrainer
from pointrix.logger.writer import logproject, Logger
from controller.gf import GFDensificationController
from model.model import GaussianFlow
from data.data import NerfiesDataset
from data.dnerf_data import DNeRFDataset
from data.custom_data import CustomDataset
from model.point import GaussianFlowPointCloud
from model.renderer import GaussianFlowRender
from model.camera import TimeCameraModel
from gui import GaussianFlowGUI
import taichi as ti
def main(args, extras) -> None:
    """Run one experiment: train, checkpoint and test, or test only.

    Args:
        args: Parsed command-line namespace. Only ``args.config`` is read
            here.
        extras: Command-line tokens the parser did not recognise; forwarded
            to ``load_config`` as ``cli_args``.
    """
    # Taichi is initialised on the CUDA backend before anything else runs.
    ti.init(arch=ti.cuda)
    # Silence every Python warning for the rest of the process.
    warnings.filterwarnings("ignore")

    cfg = load_config(args.config, cli_args=extras)

    # Hand the directory containing this script to logproject together with
    # a 'project_file' folder under the experiment directory.
    # NOTE(review): presumably this snapshots the project's .py/.yaml
    # sources for reproducibility -- confirm against pointrix's logproject.
    launcher_file = os.path.abspath(__file__)
    project_root = os.path.dirname(launcher_file)
    snapshot_dir = os.path.join(cfg.exp_dir, 'project_file')
    logproject(project_root, snapshot_dir, ['py', 'yaml'])

    # NOTE: a try/except wrapper that reported failures through
    # Logger.print_exception and called hook.exception() on each trainer
    # hook used to surround the code below. It is currently disabled, so
    # any exception propagates to the caller unchanged.
    trainer = DefaultTrainer(cfg.trainer, cfg.exp_dir, cfg.name)

    if not cfg.trainer.training:
        # Evaluation-only run against the checkpoint named in the config.
        trainer.test(cfg.trainer.test_model_path)
    else:
        trainer.train_loop()
        # The checkpoint file name embeds the trainer's global step.
        checkpoint_file = os.path.join(
            cfg.exp_dir,
            f"chkpnt{trainer.global_step!s}.pth",
        )
        trainer.save_model(path=checkpoint_file)
        trainer.test()

    # Emitted on both the training and the evaluation-only path.
    Logger.print("\nTraining complete.")
if __name__ == "__main__":
    # Script entry point: parse the options this launcher knows about and
    # pass everything else through to main() as config overrides.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        required=True,
        help="path to config file",
    )
    # Accepted on the command line; main() as written here does not read it.
    cli.add_argument(
        "--smc_file",
        default=None,
        type=str,
    )
    # parse_known_args keeps unrecognised tokens instead of rejecting them.
    args, extras = cli.parse_known_args()
    main(args, extras)