s1_train.py
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/train_t2s.py
import argparse
import logging
import os
from pathlib import Path

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy

from AR.data.data_module import Text2SemanticDataModule
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from AR.utils import get_newest_ckpt
from AR.utils.io import load_yaml_config

logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
torch.set_float32_matmul_precision('high')

def main(args):
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    ckpt_dir = output_dir / 'ckpt'
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    config = load_yaml_config(args.config_file)
    seed_everything(config["train"]["seed"], workers=True)

    # Save a checkpoint every N epochs and keep all of them (save_top_k=-1),
    # so the newest one can be picked up when training is resumed.
    ckpt_callback: ModelCheckpoint = ModelCheckpoint(
        save_top_k=-1,
        save_on_train_epoch_end=False,
        every_n_epochs=config["train"]["save_every_n_epoch"],
        dirpath=ckpt_dir)
    logger = WandbLogger(
        project="ar_s1",
        name=output_dir.stem,
        save_dir=output_dir,
        # resume the loss curve
        resume=True,
        # id='k19kvsq8'
    )
    trainer: Trainer = Trainer(
        max_epochs=config["train"]["epochs"],
        accelerator='gpu',
        devices=-1,
        benchmark=False,
        fast_dev_run=False,
        strategy=DDPStrategy(),
        precision=config["train"]["precision"],
        logger=logger,
        callbacks=[ckpt_callback])
    model: Text2SemanticLightningModule = Text2SemanticLightningModule(
        config, output_dir)
    data_module: Text2SemanticDataModule = Text2SemanticDataModule(
        config,
        train_semantic_path=args.train_semantic_path,
        train_phoneme_path=args.train_phoneme_path,
        dev_semantic_path=args.dev_semantic_path,
        dev_phoneme_path=args.dev_phoneme_path)

    try:
        # Match the numeric part of each checkpoint filename with a regular
        # expression and sort numerically to find the newest checkpoint.
        newest_ckpt_name = get_newest_ckpt(os.listdir(ckpt_dir))
        ckpt_path = ckpt_dir / newest_ckpt_name
    except Exception:
        # No usable checkpoint yet (e.g. first run): start training from scratch.
        ckpt_path = None
    print("ckpt_path:", ckpt_path)
    trainer.fit(model, data_module, ckpt_path=ckpt_path)
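
# For reference, a minimal sketch of what AR.utils.get_newest_ckpt plausibly
# does, based on the comment in main() above: match the numeric part of each
# checkpoint filename with a regex and sort numerically. This is an assumption
# for illustration, not the actual implementation (which is imported from
# AR.utils); the name _get_newest_ckpt_sketch is hypothetical.
def _get_newest_ckpt_sketch(filenames):
    import re
    # Pair each numbered filename with its number, e.g. 'epoch=12.ckpt' -> 12.
    numbered = [(int(re.search(r'\d+', name).group()), name)
                for name in filenames if re.search(r'\d+', name)]
    if not numbered:
        raise ValueError('no numbered checkpoint found')
    # The largest number corresponds to the newest checkpoint.
    return max(numbered)[1]
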
# srun --gpus-per-node=1 --ntasks-per-node=1 python train.py --path-to-configuration configurations/default.yaml
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-c',
        '--config_file',
        type=str,
        default='configs/s1.yaml',
        help='path of config file')
    # args for dataset
    parser.add_argument(
        '--train_semantic_path',
        type=str,
        default='dump/semantic_train.tsv')
    parser.add_argument(
        '--train_phoneme_path', type=str, default='dump/phoneme_train.npy')
    parser.add_argument(
        '--dev_semantic_path', type=str, default='dump/semantic_dev.tsv')
    parser.add_argument(
        '--dev_phoneme_path', type=str, default='dump/phoneme_dev.npy')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='logs/s1',
        help='directory to save the results')
    args = parser.parse_args()
    logging.info(str(args))
    main(args)
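
# Example invocation using the argparse defaults above (a sketch; the dump/
# paths are the defaults and must already exist from the data-preparation
# stage):
#   python s1_train.py -c configs/s1.yaml --output_dir logs/s1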