###
# Author: Kai Li
# Date: 2024-01-22 01:16:22
# Email: lk21@mails.tsinghua.edu.cn
# LastEditTime: 2024-01-24 00:05:10
###
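# Training entry point: builds the datamodule, the generator/discriminator
# models, paired optimizers/schedulers/losses, and a Lightning Trainer from an
# OmegaConf/Hydra-style config, then trains and exports the best checkpoint.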
import json
import os
import argparse
import warnings
from typing import List, Optional

import torch

# Request full-precision float32 matmul kernels (disables TF32 on Ampere+ GPUs).
torch.set_float32_matmul_precision("highest")

import hydra
import pytorch_lightning as pl
from omegaconf import OmegaConf, DictConfig
from pytorch_lightning import Callback, LightningDataModule, LightningModule, Trainer
from pytorch_lightning.strategies.ddp import DDPStrategy

import look2hear.system
import look2hear.datas
import look2hear.losses
from look2hear.utils import print_only

warnings.filterwarnings("ignore")
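
# The config is expected to provide the following keys (sketch only; key names
# are taken from the code below, all values are placeholders):
#
#   exp: {dir: ..., name: ...}
#   seed: ...                          # optional
#   datas:          {_target_: ...}
#   model:          {_target_: ...}
#   discriminator:  {_target_: ...}
#   optimizer_g:    {_target_: ...}
#   optimizer_d:    {_target_: ...}
#   scheduler_g:    {_target_: ...}
#   scheduler_d:    {_target_: ...}
#   loss_g:         {_target_: ...}
#   loss_d:         {_target_: ...}
#   metrics:        {_target_: ...}
#   system:         {_target_: ...}
#   early_stopping: {_target_: ...}    # optional
#   checkpoint:     {_target_: ...}    # optional
#   logger:         {_target_: ...}
#   trainer:        {_target_: ...}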


def train(cfg: DictConfig) -> None:
    if cfg.get("seed"):
        pl.seed_everything(cfg.seed, workers=True)

    # Instantiate the datamodule.
    print_only(f"Instantiating datamodule <{cfg.datas._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.datas)

    # Instantiate the generator (audio model) and the discriminator.
    print_only(f"Instantiating AudioNet <{cfg.model._target_}>")
    model: torch.nn.Module = hydra.utils.instantiate(cfg.model)
    print_only(f"Instantiating Discriminator <{cfg.discriminator._target_}>")
    discriminator: torch.nn.Module = hydra.utils.instantiate(cfg.discriminator)

    # Instantiate one optimizer per network.
    print_only(f"Instantiating optimizer <{cfg.optimizer_g._target_}>")
    optimizer_g: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer_g, params=model.parameters())
    optimizer_d: torch.optim.Optimizer = hydra.utils.instantiate(cfg.optimizer_d, params=discriminator.parameters())

    # Instantiate one LR scheduler per optimizer.
    print_only(f"Instantiating scheduler <{cfg.scheduler_g._target_}>")
    scheduler_g: torch.optim.lr_scheduler.LRScheduler = hydra.utils.instantiate(cfg.scheduler_g, optimizer=optimizer_g)
    scheduler_d: torch.optim.lr_scheduler.LRScheduler = hydra.utils.instantiate(cfg.scheduler_d, optimizer=optimizer_d)

    # Instantiate the generator and discriminator losses.
    print_only(f"Instantiating loss <{cfg.loss_g._target_}>")
    loss_g: torch.nn.Module = hydra.utils.instantiate(cfg.loss_g)
    loss_d: torch.nn.Module = hydra.utils.instantiate(cfg.loss_d)
    losses = {"g": loss_g, "d": loss_d}

    # Instantiate the validation metrics.
    print_only(f"Instantiating metrics <{cfg.metrics._target_}>")
    metrics: torch.nn.Module = hydra.utils.instantiate(cfg.metrics)

    # Instantiate the LightningModule that wires everything together and runs
    # the alternating generator/discriminator optimization.
    print_only(f"Instantiating system <{cfg.system._target_}>")
    system: LightningModule = hydra.utils.instantiate(
        cfg.system,
        model=model,
        discriminator=discriminator,
        loss_func=losses,
        metrics=metrics,
        optimizer=[optimizer_g, optimizer_d],
        scheduler=[scheduler_g, scheduler_d],
    )
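
    # Note: the class referenced by cfg.system._target_ must accept the keyword
    # arguments above; hydra.utils.instantiate forwards extra keyword arguments
    # to the target's constructor.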

    # Instantiate callbacks.
    callbacks: List[Callback] = []
    checkpoint: Optional[pl.callbacks.ModelCheckpoint] = None
    if cfg.get("early_stopping"):
        print_only(f"Instantiating early_stopping <{cfg.early_stopping._target_}>")
        callbacks.append(hydra.utils.instantiate(cfg.early_stopping))
    if cfg.get("checkpoint"):
        print_only(f"Instantiating checkpoint <{cfg.checkpoint._target_}>")
        checkpoint = hydra.utils.instantiate(cfg.checkpoint)
        callbacks.append(checkpoint)

    # Instantiate the experiment logger.
    print_only(f"Instantiating logger <{cfg.logger._target_}>")
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name, "logs"), exist_ok=True)
    logger = hydra.utils.instantiate(cfg.logger)

    # Instantiate the trainer. find_unused_parameters=True is required with DDP
    # here because the generator and discriminator are not both exercised in
    # every optimization step.
    print_only(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        cfg.trainer,
        callbacks=callbacks,
        logger=logger,
        strategy=DDPStrategy(find_unused_parameters=True),
    )

    trainer.fit(system, datamodule=datamodule)
    print_only("Training finished!")

    if checkpoint is not None:
        # Dump the top-k checkpoint paths together with their monitored scores.
        best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()}
        with open(os.path.join(cfg.exp.dir, cfg.exp.name, "best_k_models.json"), "w") as f:
            json.dump(best_k, f, indent=0)

        # Reload the best checkpoint and export a standalone serialized model.
        state_dict = torch.load(checkpoint.best_model_path)
        system.load_state_dict(state_dict=state_dict["state_dict"])
        system.cpu()
        to_save = system.audio_model.serialize()
        torch.save(to_save, os.path.join(cfg.exp.dir, cfg.exp.name, "best_model.pth"))

    # Close wandb if the configured logger started a run.
    import wandb

    if wandb.run:
        print_only("Closing wandb!")
        wandb.finish()
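
# Artifacts written under <exp.dir>/<exp.name>/ after a run:
#   config.yaml         - copy of the training config
#   logs/               - logger output
#   best_k_models.json  - top-k checkpoint paths mapped to their monitored scores
#   best_model.pth      - serialized weights of the best checkpoint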


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--conf_dir",
        default="local/conf.yml",
        help="Full path to the training configuration file",
    )
    args = parser.parse_args()

    cfg = OmegaConf.load(args.conf_dir)
    os.makedirs(os.path.join(cfg.exp.dir, cfg.exp.name), exist_ok=True)
    # Save a copy of the config to the experiment directory.
    OmegaConf.save(cfg, os.path.join(cfg.exp.dir, cfg.exp.name, "config.yaml"))
    train(cfg)
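
# Example usage (the config path below is illustrative; point --conf_dir at
# your own YAML config):
#   python train.py --conf_dir configs/train.yml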