forked from jchengai/pluto
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_training.py
109 lines (92 loc) · 3.56 KB
/
run_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging
from typing import Optional
import hydra
import numpy
import pytorch_lightning as pl
from nuplan.planning.script.builders.folder_builder import (
build_training_experiment_folder,
)
from nuplan.planning.script.builders.logging_builder import build_logger
from nuplan.planning.script.builders.worker_pool_builder import build_worker
from nuplan.planning.script.profiler_context_manager import ProfilerContextManager
from nuplan.planning.script.utils import set_default_path
from nuplan.planning.training.experiments.caching import cache_data
from omegaconf import DictConfig
from src.custom_training import (
TrainingEngine,
build_training_engine,
update_config_for_training,
)
# Silence numba's verbose DEBUG/INFO output; keep warnings and above.
logging.getLogger("numba").setLevel(logging.WARNING)

logger = logging.getLogger(__name__)

# If set, use the env. variable to overwrite the default dataset and experiment paths
set_default_path()

# If set, use the env. variable to overwrite the Hydra config
# (relative to this file; resolved by hydra.main below)
CONFIG_PATH = "./config"
CONFIG_NAME = "default_training"
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def main(cfg: DictConfig) -> Optional[TrainingEngine]:
    """
    Main entrypoint for training/validation/testing/caching experiments.

    Dispatches on ``cfg.py_func``:
      - ``"train"``:    build the training engine and fit the model.
      - ``"validate"``: build the training engine and run the validation loop.
      - ``"test"``:     build the training engine and run the test loop.
      - ``"cache"``:    precompute and cache all features (no engine built).

    :param cfg: omegaconf dictionary
    :return: the TrainingEngine used for train/validate/test, or None for cache
    :raises NameError: if ``cfg.py_func`` is not one of the modes above
    """
    pl.seed_everything(cfg.seed, workers=True)

    # Configure logger
    build_logger(cfg)

    # Override configs based on setup, and print config
    update_config_for_training(cfg)

    # Create output storage folder
    build_training_experiment_folder(cfg=cfg)

    # Build worker
    worker = build_worker(cfg)

    if cfg.py_func == "train":
        # Build training engine
        with ProfilerContextManager(
            cfg.output_dir, cfg.enable_profiling, "build_training_engine"
        ):
            engine = build_training_engine(cfg, worker)

        # Run training
        logger.info("Starting training...")
        with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "training"):
            engine.trainer.fit(
                model=engine.model,
                datamodule=engine.datamodule,
                ckpt_path=cfg.checkpoint,
            )
        return engine
    elif cfg.py_func == "validate":
        # Build training engine
        with ProfilerContextManager(
            cfg.output_dir, cfg.enable_profiling, "build_training_engine"
        ):
            engine = build_training_engine(cfg, worker)

        # Run validation (message previously said "training" — copy-paste fix)
        logger.info("Starting validation...")
        with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "validate"):
            engine.trainer.validate(
                model=engine.model,
                datamodule=engine.datamodule,
                ckpt_path=cfg.checkpoint,
            )
        return engine
    elif cfg.py_func == "test":
        # Build training engine
        with ProfilerContextManager(
            cfg.output_dir, cfg.enable_profiling, "build_training_engine"
        ):
            engine = build_training_engine(cfg, worker)

        # Test model
        logger.info("Starting testing...")
        with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "testing"):
            engine.trainer.test(model=engine.model, datamodule=engine.datamodule)
        return engine
    elif cfg.py_func == "cache":
        # Precompute and cache all features
        logger.info("Starting caching...")
        with ProfilerContextManager(cfg.output_dir, cfg.enable_profiling, "caching"):
            cache_data(cfg=cfg, worker=worker)
        return None
    else:
        raise NameError(f"Function {cfg.py_func} does not exist")
# Script entry point: hydra parses CLI overrides and injects the composed cfg.
if __name__ == "__main__":
    main()