-
Notifications
You must be signed in to change notification settings - Fork 14
/
train.py
74 lines (55 loc) · 2.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import pytorch_lightning as pl
import hydra
import torch
import wandb
import yaml
import os
from lib.gdna_model import BaseModel
from lib.dataset.datamodule import DataModule, DataProcessor
import warnings
# Silence UserWarning spam (presumably from torch / PyTorch Lightning internals
# during training — TODO confirm which libraries emit them) to keep logs readable.
warnings.filterwarnings("ignore", category=UserWarning)
@hydra.main(config_path="config", config_name="config")
def main(opt):
    """Train a gDNA model: wire up data, logging, checkpointing, then fit.

    ``opt`` is the Hydra-composed config. Hydra switches the working
    directory to the run directory, so the relative paths used below
    ('.hydra/config.yaml', './checkpoints') resolve inside the run folder.
    """
    print(opt.pretty())

    # Deterministic seeding (incl. dataloader workers) and a CPU-thread cap.
    pl.seed_everything(42, workers=True)
    torch.set_num_threads(10)

    # Data: the datamodule provides loaders/meta info; the processor handles
    # per-batch preparation inside the model.
    datamodule = DataModule(opt.datamodule)
    datamodule.setup(stage='fit')
    meta_info = datamodule.meta_info
    data_processor = DataProcessor(opt.datamodule)

    # Log the fully-resolved Hydra config of this run to Weights & Biases.
    with open('.hydra/config.yaml', 'r') as cfg_file:
        resolved_cfg = yaml.load(cfg_file, Loader=yaml.FullLoader)
    logger = pl.loggers.WandbLogger(
        project='gdna',
        name=opt.expname,
        config=resolved_cfg,
        offline=False,
        resume=True,
        settings=wandb.Settings(start_method='fork'),
    )

    # Checkpoint after every validation epoch; keep all checkpoints
    # (save_top_k=None, no monitored metric) plus a rolling 'last.ckpt'.
    checkpoint_cb = pl.callbacks.ModelCheckpoint(
        save_top_k=None,
        monitor=None,
        dirpath='./checkpoints',
        save_last=True,
        every_n_val_epochs=1,
    )
    callbacks = [checkpoint_cb]

    # Resume from last.ckpt only when it exists AND resuming was requested.
    resume_ckpt = './checkpoints/last.ckpt'
    if not (os.path.exists(resume_ckpt) and opt.resume):
        resume_ckpt = None

    trainer = pl.Trainer(
        logger=logger,
        callbacks=callbacks,
        resume_from_checkpoint=resume_ckpt,
        **opt.trainer,
    )

    model = BaseModel(
        opt=opt.model,
        meta_info=meta_info,
        data_processor=data_processor,
    )

    # Warm-start from a pretrained checkpoint, but only on a fresh run —
    # never when resuming, so resumed weights are not clobbered.
    starting_path = hydra.utils.to_absolute_path(opt.starting_path)
    if os.path.exists(starting_path) and resume_ckpt is None:
        model = model.load_from_checkpoint(
            starting_path,
            strict=False,
            opt=opt.model,
            meta_info=meta_info,
            data_processor=data_processor,
        )

    trainer.fit(model, datamodule=datamodule)
if __name__ == '__main__':
    # Script entry point; Hydra (via the @hydra.main decorator) parses CLI
    # overrides and supplies the composed config to main().
    main()