train.py
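"""Hydra-based training entry point.

Example invocation (hypothetical; the exact Hydra overrides depend on the config groups
defined under `configs/`):

    python src/train.py trainer.max_epochs=10 data.batch_size=64
"""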
import os
from typing import Any, Dict, List, Optional, Tuple

import hydra
import lightning as L
import rootutils
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.pytorch.loggers import Logger
from lightning.pytorch.strategies.strategy import Strategy
from omegaconf import DictConfig

rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# ------------------------------------------------------------------------------------ #
# the setup_root above is equivalent to:
# - adding project root dir to PYTHONPATH
# (so you don't need to force user to install project as a package)
# (necessary before importing any local modules e.g. `from src import utils`)
# - setting up PROJECT_ROOT environment variable
# (which is used as a base for paths in "configs/paths/default.yaml")
# (this way all filepaths are the same no matter where you run the code)
# - loading environment variables from ".env" in root dir
#
# you can remove it if you:
# 1. either install project as a package or move entry files to project root dir
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
#
# more info: https://github.com/ashleve/rootutils
# ------------------------------------------------------------------------------------ #
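# a rough manual equivalent of the setup above (sketch only; rootutils actually walks up
# the directory tree looking for the ".project-root" marker and can also load ".env"):
#
#   import os, sys
#   root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # assumes a src/ layout
#   sys.path.insert(0, root)
#   os.environ["PROJECT_ROOT"] = root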
from src import register_custom_omegaconf_resolvers, resolve_omegaconf_variable
from src.utils import (
    RankedLogger,
    extras,
    get_metric_value,
    instantiate_callbacks,
    instantiate_loggers,
    log_hyperparameters,
    task_wrapper,
)

log = RankedLogger(__name__, rank_zero_only=True)

@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a test set, using the best weights obtained
    during training.

    This method is wrapped in an optional @task_wrapper decorator, which controls the behavior
    during failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))
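    # optionally instantiate a cluster environment plugin (e.g. a SLURM environment)
    # when one is configured; otherwise no extra plugins are passed to the Trainer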
    plugins = None
    if "_target_" in cfg.environment:
        log.info(f"Instantiating environment <{cfg.environment._target_}>")
        plugins: ClusterEnvironment = hydra.utils.instantiate(cfg.environment)
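    # if a strategy is configured explicitly, instantiate it; mixed-precision dtypes given
    # as variable references in the config are resolved to concrete values beforehand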
    strategy = getattr(cfg.trainer, "strategy", None)
    if "_target_" in cfg.strategy:
        log.info(f"Instantiating strategy <{cfg.strategy._target_}>")
        strategy: Strategy = hydra.utils.instantiate(cfg.strategy)
        if "mixed_precision" in strategy.__dict__ and getattr(strategy, "mixed_precision", None) is not None:
            strategy.mixed_precision.param_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.param_dtype)
                if getattr(cfg.strategy.mixed_precision, "param_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.reduce_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.reduce_dtype)
                if getattr(cfg.strategy.mixed_precision, "reduce_dtype", None) is not None
                else None
            )
            strategy.mixed_precision.buffer_dtype = (
                resolve_omegaconf_variable(cfg.strategy.mixed_precision.buffer_dtype)
                if getattr(cfg.strategy.mixed_precision, "buffer_dtype", None) is not None
                else None
            )
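    # pass the strategy to the Trainer only when one was configured or instantiated above;
    # otherwise let the trainer config (or Lightning's default) decide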
    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = (
        hydra.utils.instantiate(
            cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            plugins=plugins,
            strategy=strategy,
        )
        if strategy is not None
        else hydra.utils.instantiate(
            cfg.trainer,
            callbacks=callbacks,
            logger=logger,
            plugins=plugins,
        )
    )
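    # collect everything that was instantiated so hyperparameters and object references
    # can be logged together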
    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
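        # resume from `ckpt_path` only if the file actually exists; otherwise warn and
        # start training from freshly initialized weights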
        ckpt_path = None
        if cfg.get("ckpt_path") and os.path.exists(cfg.get("ckpt_path")):
            ckpt_path = cfg.get("ckpt_path")
        elif cfg.get("ckpt_path"):
            log.warning(
                "`ckpt_path` was given, but the path does not exist. Training with new model weights."
            )
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict

@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Main entry point for training.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Optional[float] with optimized metric value.
    """
    # apply extra utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    # train the model
    metric_dict, _ = train(cfg)

    # safely retrieve metric value for hydra-based hyperparameter optimization
    metric_value = get_metric_value(
        metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")
    )

    # return optimized metric
    return metric_value

if __name__ == "__main__":
    register_custom_omegaconf_resolvers()
    main()