train.py
from omegaconf import DictConfig, OmegaConf
import hydra
import jax
from jax import random, numpy as np, value_and_grad, jit, tree_util
from optax import (
    chain,
    clip_by_global_norm,
    scale_by_adam,
    scale,
    apply_updates,
    add_decayed_weights,
    masked,
)

from clap.models import CLAP

# data
from clap.datasets import PairTextSpectrogramTFRecords


@hydra.main(config_path="configs")
def train(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))

    # rng
    rng_key = random.PRNGKey(cfg.training.seed)

    # data
    training_data_path = hydra.utils.get_original_cwd() + "/" + cfg.training.data_folder
    dataloader = PairTextSpectrogramTFRecords(
        training_data_path,
        cfg.training.batch_size,
    )

    # model
    model = CLAP(
        text_config=cfg.model.text,
        audio_config=cfg.model.audio,
    )

    # optimizer
    # mask weight decay so it applies only to non-1-D parameters (i.e. skips biases)
    exclude_bias = lambda params: tree_util.tree_map(lambda x: x.ndim != 1, params)
    optim = chain(
        clip_by_global_norm(cfg.optimizer.max_norm),
        scale_by_adam(eps=1e-4),
        add_decayed_weights(cfg.optimizer.weight_decay, exclude_bias),
        scale(-cfg.optimizer.learning_rate),
    )

    # init: trace the model once on a real batch to create the parameters
    batch = next(iter(dataloader))
    text = batch["text"]
    audio = batch["audio"]

    params = model.init(rng_key, text, audio)
    optim_state = optim.init(params)

    # loss function, for use with value_and_grad
    @jit
    @value_and_grad
    def loss_fn(params, text, audio):
        return model.apply(
            params,
            text,
            audio,
            return_loss=True,
            is_training=True,
        )

    # train loop
    for _ in range(cfg.training.epochs):
        for batch in dataloader:
            text = batch["text"]
            audio = batch["audio"]

            loss, grads = loss_fn(params, text, audio)
            updates, optim_state = optim.update(grads, optim_state, params)
            params = apply_updates(params, updates)

            print(f"loss: {loss}")

    # finished


if __name__ == "__main__":
    train()
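

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of train.py: the shape of the config that
# train() reads. Hydra composes the real settings from YAML files under
# configs/, which are not shown on this page; every value below is a made-up
# placeholder, and the model sub-configs are left empty because their fields
# are not visible from this file.
from omegaconf import OmegaConf

example_cfg = OmegaConf.create(
    {
        "training": {
            "seed": 0,              # placeholder
            "data_folder": "data",  # placeholder path, relative to the original cwd
            "batch_size": 32,       # placeholder
            "epochs": 100,          # placeholder
        },
        "model": {
            "text": {},   # text-encoder settings (unknown from this file)
            "audio": {},  # audio-encoder settings (unknown from this file)
        },
        "optimizer": {
            "max_norm": 1.0,        # placeholder
            "weight_decay": 1e-2,   # placeholder
            "learning_rate": 3e-4,  # placeholder
        },
    }
)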
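

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of train.py: a common JAX/optax variant of the
# loop above that folds the gradient computation and the optimizer update into
# a single jitted step, so the whole update is compiled together. It assumes
# `model`, `optim`, `params` and `optim_state` are built exactly as in train();
# the name `train_step` is hypothetical and the function is not called here.
@jit
def train_step(params, optim_state, text, audio):
    # compute the loss and its gradients with respect to the parameters
    loss, grads = value_and_grad(
        lambda p: model.apply(p, text, audio, return_loss=True, is_training=True)
    )(params)
    # transform the gradients and apply them to the parameters
    updates, optim_state = optim.update(grads, optim_state, params)
    params = apply_updates(params, updates)
    return params, optim_state, loss

# usage inside the batch loop would then be:
# params, optim_state, loss = train_step(params, optim_state, text, audio)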