train_qmst.py
from copy import deepcopy

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

from models.qmst import argparser
from models.qmst.config import GPUS, ACCELERATOR
from models.qmst.data_module import DataModule
from models.qmst.model import Model

# Parse command-line arguments (dev, epoch, from_checkpoint, run_test).
args = argparser.get_args()

if __name__ == "__main__":
    # Trainer: keeps the checkpoint with the lowest dev_loss and also saves
    # the last epoch; logs and checkpoints go under .log_qmst.
    trainer = pl.Trainer(
        gpus=GPUS,
        accelerator=ACCELERATOR,
        fast_dev_run=args.dev,
        precision=32,
        default_root_dir='.log_qmst',
        max_epochs=args.epoch,
        callbacks=[
            ModelCheckpoint(monitor='dev_loss',
                            filename='{epoch}-{dev_loss:.2f}',
                            save_last=True),
        ],
    )
    # Data
    dm = DataModule()

    # Build a fresh model, or restore one from a checkpoint if a path was given.
    if args.from_checkpoint is None:
        model = Model()
    else:
        print('load from checkpoint')
        model = Model.load_from_checkpoint(args.from_checkpoint)

    # Train (skipped when run_test is set, i.e. when only evaluating).
    if not args.run_test:
        # Scale the batch size on a copy of the trainer so tuning does not
        # disturb the trainer used for the real run.
        tuner = pl.tuner.tuning.Tuner(deepcopy(trainer))
        new_batch_size = tuner.scale_batch_size(model, datamodule=dm)
        del tuner
        # Apply the tuned batch size before training.
        model.hparams.batch_size = new_batch_size
        trainer.fit(model, datamodule=dm)

    # Evaluate on the test set.
    trainer.test(model, datamodule=dm)
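
For reference, a minimal sketch of what models/qmst/argparser.py might look like, inferred only from the attributes this script reads (args.dev, args.epoch, args.from_checkpoint, args.run_test); the actual flag names, types, and defaults are assumptions, not the module's real definition.

import argparse

def get_args():
    # Hypothetical CLI matching the attributes used in train_qmst.py;
    # the real models/qmst/argparser.py may define these differently.
    parser = argparse.ArgumentParser(description='Train/evaluate the QMST model')
    parser.add_argument('--dev', action='store_true',
                        help='enable fast_dev_run for a quick smoke test')
    parser.add_argument('--epoch', type=int, default=10,
                        help='maximum number of training epochs')
    parser.add_argument('--from_checkpoint', type=str, default=None,
                        help='path to a checkpoint to load the model from')
    parser.add_argument('--run_test', action='store_true',
                        help='skip training and only run trainer.test')
    return parser.parse_args()

With a CLI like this, training would be launched as, e.g., python train_qmst.py --epoch 20, and evaluation of a saved model as python train_qmst.py --run_test --from_checkpoint <path>. Likewise, models/qmst/config.py presumably just defines the two constants the trainer consumes, e.g. (values assumed) GPUS = 1 and ACCELERATOR = None.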