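"""Training entry point: Hydra-configured PyTorch Lightning.

Example run (hypothetical overrides; assumes a `config.yaml` next to this
file, as the `@hydra.main` decorator below expects):

    python train.py seed=0 trainer.max_epochs=10
"""
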
import os
import signal

import dotenv
import hydra
import pytorch_lightning as pl
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.plugins.environments import SLURMEnvironment

from main import utils

# Load environment variables from `.env`.
dotenv.load_dotenv(override=True)

log = utils.get_logger(__name__)

@hydra.main(config_path=".", config_name="config.yaml", version_base=None)
def main(config: DictConfig) -> None:
    OmegaConf.register_new_resolver("eval", eval)
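    # The `eval` resolver lets config values compute Python expressions,
    # e.g. in config.yaml (hypothetical entry): hidden_dim: ${eval:'64 * 4'}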

    # Log config tree
    utils.extras(config)

    # Apply seed for reproducibility
    pl.seed_everything(config.seed)

    # Initialize datamodule
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>.")
    datamodule = hydra.utils.instantiate(config.datamodule, _convert_="partial")

    # Initialize model
    log.info(f"Instantiating model <{config.model._target_}>.")
    model = hydra.utils.instantiate(config.model, _convert_="partial")

    # Initialize all callbacks (e.g. checkpoints, early stopping)
    callbacks = []

    # If `save` is provided, add a callback that saves the model and stops;
    # meant to be used together with `+ckpt`
    if "save" in config:
        # Ignore loggers and other callbacks
        with open_dict(config):
            config.pop("loggers")
            config.pop("callbacks")
            config.trainer.num_sanity_val_steps = 0
        attribute, path = config.get("save"), config.get("ckpt_dir")
        filename = os.path.join(path, f"{attribute}.pt")
        callbacks += [utils.SavePytorchModelAndStopCallback(filename, attribute)]
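        # Example export run (hypothetical values):
        #   python train.py +save=model +ckpt_dir=checkpoints +ckpt=logs/last.ckpt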
if "callbacks" in config:
for _, cb_conf in config["callbacks"].items():
if "_target_" in cb_conf:
log.info(f"Instantiating callback <{cb_conf._target_}>.")
callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
# Add learning rate monitor
lr_monitor = LearningRateMonitor(logging_interval='step')
callbacks.append(lr_monitor)

    # Initialize loggers (e.g. wandb)
    loggers = []
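    # Logger entries are instantiated from config, e.g. a hypothetical
    # config.yaml snippet:
    #   loggers:
    #     wandb:
    #       _target_: pytorch_lightning.loggers.WandbLogger
    #       project: my-project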
if "loggers" in config:
for _, lg_conf in config["loggers"].items():
if "_target_" in lg_conf:
log.info(f"Instantiating logger <{lg_conf._target_}>.")
# Sometimes wandb throws error if slow connection...
logger = utils.retry_if_error(
lambda: hydra.utils.instantiate(lg_conf, _convert_="partial")
)
loggers.append(logger)

    # Initialize trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>.")
    if "SLURM_JOB_ID" in os.environ:
        trainer = hydra.utils.instantiate(
            config.trainer,
            callbacks=callbacks,
            logger=loggers,
            _convert_="partial",
            plugins=[SLURMEnvironment(requeue_signal=signal.SIGUSR1)],
        )
    else:
        trainer = hydra.utils.instantiate(
            config.trainer,
            callbacks=callbacks,
            logger=loggers,
            _convert_="partial",
        )

    # Send some parameters from config to all lightning loggers
    log.info("Logging hyperparameters!")
    utils.log_hyperparameters(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=loggers,
    )

    # Resume from a checkpoint if one is given, otherwise train from scratch
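    # Resume example (hypothetical path):
    #   python train.py +ckpt=logs/last.ckpt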
if "ckpt" in config:
ckpt = config.get("ckpt")
log.info(f"Starting training from {ckpt}")
trainer.fit(model=model, datamodule=datamodule, ckpt_path=ckpt)
else:
log.info("Starting training.")
trainer.fit(model=model, datamodule=datamodule)

    # Make sure everything closed properly
    log.info("Finalizing!")
    utils.finish(
        config=config,
        model=model,
        datamodule=datamodule,
        trainer=trainer,
        callbacks=callbacks,
        logger=loggers,
    )

    # Print path to best checkpoint
    if (
        not config.trainer.get("fast_dev_run")
        and config.get("train")
        and not config.get("save")
    ):
        log.info(f"Best model ckpt at {trainer.checkpoint_callback.best_model_path}")


if __name__ == "__main__":
    main()