forked from facebookresearch/audiocraft
-
Notifications
You must be signed in to change notification settings - Fork 0
/
musicgen.py
705 lines (639 loc) · 34.3 KB
/
musicgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from pathlib import Path
import time
import typing as tp
import warnings
import flashy
import math
import omegaconf
import torch
from torch.nn import functional as F
from . import base, builders
from .compression import CompressionSolver
from .. import metrics as eval_metrics
from .. import models
from ..data.audio_dataset import AudioDataset
from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
from ..data.audio_utils import normalize_audio
from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
from ..utils.cache import CachedBatchWriter, CachedBatchLoader
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once
class MusicGenSolver(base.StandardSolver):
    """Solver for MusicGen training task.

    Used in: https://arxiv.org/abs/2306.05284
    """
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC

    def __init__(self, cfg: omegaconf.DictConfig):
        """Build the solver and optionally plug a token cache writer or loader.

        Args:
            cfg (omegaconf.DictConfig): Solver configuration. When `cfg.cache.path` is set,
                `cfg.cache.write` selects between writing a cache of audio tokens or reading
                training batches from an existing cache.
        """
        super().__init__(cfg)
        # easier access to sampling parameters
        self.generation_params = {
            'use_sampling': self.cfg.generate.lm.use_sampling,
            'temp': self.cfg.generate.lm.temp,
            'top_k': self.cfg.generate.lm.top_k,
            'top_p': self.cfg.generate.lm.top_p,
        }
        self._best_metric_name: tp.Optional[str] = 'ce'

        self._cached_batch_writer = None
        self._cached_batch_loader = None
        if cfg.cache.path:
            if cfg.cache.write:
                self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
                if self.cfg.cache.write_num_shards:
                    self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
                    self._best_metric_name = None
            else:
                self._cached_batch_loader = CachedBatchLoader(
                    Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
                    min_length=self.cfg.optim.updates_per_epoch or 1)
                # Keep the regular loader reachable: its dataset is still needed to
                # re-apply the paraphraser on cached batches.
                self.dataloaders['original_train'] = self.dataloaders['train']
                self.dataloaders['train'] = self._cached_batch_loader  # type: ignore

    @staticmethod
    def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
                                 device: tp.Optional[str] = None, autocast: bool = True,
                                 batch_size: tp.Optional[int] = None,
                                 override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
                                 **kwargs):
        """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
        populating all the proper param, deactivating EMA, FSDP, loading the best state,
        basically all you need to get a solver ready to "play" with in single GPU mode
        and with minimal memory overhead.

        Args:
            sig (str): signature to load.
            dtype (str or None): potential dtype, as a string, i.e. 'float16'.
            device (str or None): potential device, as a string, i.e. 'cuda'.
            autocast (bool): value stored under the `autocast` config key.
            batch_size (int or None): if provided, overrides `dataset.batch_size`.
            override_cfg (dict or omegaconf.DictConfig or None): extra config overrides.
                They are merged first, so the overrides set by this function
                (EMA off, autocast, dtype, device, batch size) take precedence on conflicting keys.
            **kwargs: forwarded as is to `train.get_solver_from_sig`.
        Returns:
            The loaded solver, with `solver.model` switched to eval mode.
        """
        from audiocraft import train
        our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
        our_override_cfg['autocast'] = autocast
        if dtype is not None:
            our_override_cfg['dtype'] = dtype
        if device is not None:
            our_override_cfg['device'] = device
        if batch_size is not None:
            our_override_cfg['dataset'] = {'batch_size': batch_size}
        if override_cfg is None:
            override_cfg = {}
        override_cfg = omegaconf.OmegaConf.merge(
            omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
        solver = train.get_solver_from_sig(
            sig, override_cfg=override_cfg,
            load_best=True, disable_fsdp=True,
            ignore_state_keys=['optimizer', 'ema'], **kwargs)
        solver.model.eval()
        return solver

    def get_formatter(self, stage_name: str) -> flashy.Formatter:
        """Return the metric formatter; per-codebook `ce_q*` / `ppl_q*` keys are excluded."""
        return flashy.Formatter({
            'lr': '.2E',
            'ce': '.3f',
            'ppl': '.3f',
            'grad_norm': '.3E',
        }, exclude_keys=['ce_q*', 'ppl_q*'])

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Name of the metric used to pick the best state ('ce', or None when writing a cache)."""
        return self._best_metric_name

    def build_model(self) -> None:
        """Instantiate models and optimizer."""
        # we can potentially not use all quantizers with which the EnCodec model was trained
        # (e.g. we trained the model with quantizers dropout)
        self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
            self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
        assert self.compression_model.sample_rate == self.cfg.sample_rate, (
            f"Compression model sample rate is {self.compression_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )
        # ensure we have matching configuration between LM and compression model
        # (messages are single concatenated strings so the assertion error is readable)
        assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
            "Cardinalities of the LM and compression model don't match: "
            f"LM cardinality is {self.cfg.transformer_lm.card} vs "
            f"compression model cardinality is {self.compression_model.cardinality}"
        )
        assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
            "Numbers of codebooks of the LM and compression models don't match: "
            f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs "
            f"compression model number of codebooks is {self.compression_model.num_codebooks}"
        )
        self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
                         self.compression_model.num_codebooks, self.compression_model.cardinality,
                         self.compression_model.frame_rate)
        # instantiate LM model
        self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
        if self.cfg.fsdp.use:
            assert not self.cfg.autocast, "Cannot use autocast with fsdp"
            self.model = self.wrap_with_fsdp(self.model)
        self.register_ema('model')
        # initialize optimization
        self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
        self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
        self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
        self.register_best_state('model')
        self.autocast_dtype = {
            'float16': torch.float16, 'bfloat16': torch.bfloat16
        }[self.cfg.autocast_dtype]
        self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
        # A gradient scaler is only set up for float16 (FSDP param dtype or autocast dtype).
        if self.cfg.fsdp.use:
            need_scaler = self.cfg.fsdp.param_dtype == 'float16'
        else:
            need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
        if need_scaler:
            if self.cfg.fsdp.use:
                from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
                self.scaler = ShardedGradScaler()  # type: ignore
            else:
                self.scaler = torch.cuda.amp.GradScaler()
            self.register_stateful('scaler')

    def build_dataloaders(self) -> None:
        """Instantiate audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)

    def show(self) -> None:
        """Show the compression model and LM model."""
        self.logger.info("Compression model:")
        self.log_model_summary(self.compression_model)
        self.logger.info("LM model:")
        self.log_model_summary(self.model)

    def load_state_dict(self, state: dict) -> None:
        """Load a solver state, folding a separate 'condition_provider' entry into the model state.

        When `state` has a 'condition_provider' key, that entry is removed from `state` and
        each of its items is stored in `state['model']` under the 'condition_provider.' prefix
        before delegating to the parent implementation. `state` is modified in place.
        """
        if 'condition_provider' in state:
            model_state = state['model']
            condition_provider_state = state.pop('condition_provider')
            prefix = 'condition_provider.'
            for key, value in condition_provider_state.items():
                key = prefix + key
                assert key not in model_state
                model_state[key] = value
        super().load_state_dict(state)

    def load_from_pretrained(self, name: str):
        """Return a state dict holding only the best model state of the pretrained LM `name`."""
        # TODO: support native HF versions of MusicGen.
        lm_pkg = models.loaders.load_lm_model_ckpt(name)
        state: dict = {
            'best_state': {
                'model': lm_pkg['best_state'],
            },
        }
        return state

    def _compute_cross_entropy(
        self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
    ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
        """Compute cross entropy between multi-codebook targets and model's logits.
        The cross entropy is computed per codebook to provide codebook-level cross entropy.
        Valid timesteps for each of the codebook are pulled from the mask, where invalid
        timesteps are set to 0.

        Args:
            logits (torch.Tensor): Model's logits of shape [B, K, T, card].
            targets (torch.Tensor): Target codes, of shape [B, K, T].
            mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
        Returns:
            ce (torch.Tensor): Cross entropy averaged over the codebooks
            ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
        """
        assert logits.shape[:-1] == targets.shape
        assert mask.shape == targets.shape
        K = targets.shape[1]
        ce = torch.zeros([], device=targets.device)
        ce_per_codebook: tp.List[torch.Tensor] = []
        for k in range(K):
            logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))  # [B x T, card]
            targets_k = targets[:, k, ...].contiguous().view(-1)  # [B x T]
            mask_k = mask[:, k, ...].contiguous().view(-1)  # [B x T]
            # only the positions flagged as valid contribute to the loss
            ce_targets = targets_k[mask_k]
            ce_logits = logits_k[mask_k]
            q_ce = F.cross_entropy(ce_logits, ce_targets)
            ce += q_ce
            ce_per_codebook.append(q_ce.detach())
        # average cross entropy across codebooks
        ce = ce / K
        return ce, ce_per_codebook

    def _prepare_tokens_and_attributes(
        self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False
    ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
        """Prepare input batchs for language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
                When training from the token cache, the batch is instead a 1-tuple holding the metadata only.
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            Condition tensors (dict[str, any]): Preprocessed condition attributes.
            Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
                with B the batch size, K the number of codebooks, T_s the token timesteps.
            Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
        """
        if self.model.training:
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
                "Really sorry about that.")
        if self._cached_batch_loader is None or self.current_stage != "train":
            audio, infos = batch
            audio = audio.to(self.device)
            audio_tokens = None
            assert audio.size(0) == len(infos), (
                f"Mismatch between number of items in audio batch ({audio.size(0)})"
                f" and in metadata ({len(infos)})"
            )
        else:
            audio = None
            # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
            infos, = batch  # type: ignore
            assert all([isinstance(info, AudioInfo) for info in infos])
            assert all([info.audio_tokens is not None for info in infos])  # type: ignore
            audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device)  # type: ignore
            audio_tokens = audio_tokens.long()
            for info in infos:
                if isinstance(info, MusicInfo):
                    # Careful here, if you want to use this condition_wav (e.b. chroma conditioning),
                    # then you must be using the chroma cache! otherwise the code will try
                    # to use this segment and fail (by that I mean you will see NaN everywhere).
                    info.self_wav = WavCondition(
                        torch.full([1, info.channels, info.total_frames], float('NaN')),
                        length=torch.tensor([info.n_frames]),
                        sample_rate=[info.sample_rate],
                        path=[info.meta.path],
                        seek_time=[info.seek_time])
                    dataset = get_dataset_from_loader(self.dataloaders['original_train'])
                    assert isinstance(dataset, MusicDataset), type(dataset)
                    if dataset.paraphraser is not None and info.description is not None:
                        # Hackingly reapplying paraphraser when using cache.
                        info.description = dataset.paraphraser.sample_paraphrase(
                            info.meta.path, info.description)
        # prepare attributes
        attributes = [info.to_condition_attributes() for info in infos]
        attributes = self.model.cfg_dropout(attributes)
        attributes = self.model.att_dropout(attributes)
        tokenized = self.model.condition_provider.tokenize(attributes)

        # Now we should be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        if audio_tokens is None:
            with torch.no_grad():
                audio_tokens, scale = self.compression_model.encode(audio)
                assert scale is None, "Scaled compression model not supported with LM."

        with self.autocast:
            condition_tensors = self.model.condition_provider(tokenized)

        # create a padding mask to hold valid vs invalid positions
        padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
        # replace encodec tokens from padded audio with special_token_id
        if self.cfg.tokens.padding_with_special_token:
            audio_tokens = audio_tokens.clone()
            padding_mask = padding_mask.clone()
            token_sample_rate = self.compression_model.frame_rate
            B, K, T_s = audio_tokens.shape
            for i in range(B):
                n_samples = infos[i].n_frames
                audio_sample_rate = infos[i].sample_rate
                # take the last token generated from actual audio frames (non-padded audio)
                valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
                audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
                padding_mask[i, :, valid_tokens:] = 0

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")

        if self._cached_batch_writer is not None and self.current_stage == 'train':
            assert self._cached_batch_loader is None
            assert audio_tokens is not None
            for info, one_audio_tokens in zip(infos, audio_tokens):
                assert isinstance(info, AudioInfo)
                if isinstance(info, MusicInfo):
                    assert not info.joint_embed, "joint_embed and cache not supported yet."
                    info.self_wav = None
                # tokens are stored as int16 in the cache, hence the range check
                assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
                info.audio_tokens = one_audio_tokens.short().cpu()
            self._cached_batch_writer.save(infos)

        return condition_tensors, audio_tokens, padding_mask

    def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
        """Perform one training or valid step on a given batch.

        Args:
            idx (int): Index of the step; synchronization points are checked on step 1 when on 'cuda'.
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Batch forwarded to
                `_prepare_tokens_and_attributes`.
            metrics (dict): Dictionary updated in place with 'ce', 'ppl', the per-codebook
                'ce_q{k}' / 'ppl_q{k}' and, when training, 'lr', 'grad_norm' and 'grad_scale'.
        Returns:
            dict: The updated `metrics`.
        Raises:
            RuntimeError: When training and the loss is not finite.
        """
        check_synchronization_points = idx == 1 and self.device == 'cuda'

        condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points)

        self.deadlock_detect.update('tokens_and_conditions')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('warn')

        with self.autocast:
            model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)  # type: ignore
            logits = model_output.logits
            # a position counts only if valid for both the padding mask and the model mask
            mask = padding_mask & model_output.mask
            ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
            loss = ce
        self.deadlock_detect.update('loss')

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode('default')

        if self.is_training:
            metrics['lr'] = self.optimizer.param_groups[0]['lr']
            if self.scaler is not None:
                loss = self.scaler.scale(loss)
            self.deadlock_detect.update('scale')
            if self.cfg.fsdp.use:
                loss.backward()
                flashy.distrib.average_tensors(self.model.buffers())
            elif self.cfg.optim.eager_sync:
                with flashy.distrib.eager_sync_model(self.model):
                    loss.backward()
            else:
                # this should always be slower but can be useful
                # for weird use cases like multiple backwards.
                loss.backward()
                flashy.distrib.sync_model(self.model)
            self.deadlock_detect.update('backward')

            if self.scaler is not None:
                self.scaler.unscale_(self.optimizer)
            if self.cfg.optim.max_norm:
                if self.cfg.fsdp.use:
                    metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                else:
                    metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.cfg.optim.max_norm
                    )
            if self.scaler is None:
                self.optimizer.step()
            else:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            if self.lr_scheduler:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            self.deadlock_detect.update('optim')
            if self.scaler is not None:
                scale = self.scaler.get_scale()
                metrics['grad_scale'] = scale
            if not loss.isfinite().all():
                raise RuntimeError("Model probably diverged.")

        metrics['ce'] = ce
        metrics['ppl'] = torch.exp(ce)
        for k, ce_q in enumerate(ce_per_codebook):
            metrics[f'ce_q{k + 1}'] = ce_q
            metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

        return metrics

    @torch.no_grad()
    def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
                          gen_duration: float, prompt_duration: tp.Optional[float] = None,
                          remove_prompt: bool = False,
                          **generation_params) -> dict:
        """Run generate step on a batch of optional audio tensor and corresponding attributes.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Reference audio and its metadata.
            gen_duration (float): Target audio duration for the generation.
            prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
                When None (or when the resulting prompt is empty), generation is unprompted.
            remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
                NOTE(review): accepted for interface compatibility but not used by this implementation.
            generation_params: Additional generation parameters; they take precedence over
                the solver-level defaults in `self.generation_params`.
        Returns:
            gen_outputs (dict): Generation outputs, consisting in audio, audio tokens from both the generation
                and the prompt along with additional information.
        """
        bench_start = time.time()
        audio, meta = batch
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]
        # TODO: Add dropout for chroma?

        # prepare audio prompt
        if prompt_duration is None:
            prompt_audio = None
        else:
            assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
            prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
            prompt_audio = audio[..., :prompt_audio_frames]

        # get audio tokens from compression model
        if prompt_audio is None or prompt_audio.nelement() == 0:
            num_samples = len(attributes)
            prompt_tokens = None
        else:
            num_samples = None
            prompt_audio = prompt_audio.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt_audio)
            assert scale is None, "Compression model in MusicGen should not require rescaling."

        # Explicit per-call parameters override the solver defaults
        # (previously the keyword arguments were silently ignored).
        sampling_params = {**self.generation_params, **generation_params}

        # generate by sampling from the LM
        with self.autocast:
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            gen_tokens = self.model.generate(
                prompt_tokens, attributes, max_gen_len=total_gen_len,
                num_samples=num_samples, **sampling_params)

        # generate audio from tokens
        assert gen_tokens.dim() == 3
        gen_audio = self.compression_model.decode(gen_tokens, None)

        bench_end = time.time()
        gen_outputs = {
            'rtf': (bench_end - bench_start) / gen_duration,
            'ref_audio': audio,
            'gen_audio': gen_audio,
            'gen_tokens': gen_tokens,
            'prompt_audio': prompt_audio,
            'prompt_tokens': prompt_tokens,
        }
        return gen_outputs

    def generate_audio(self) -> dict:
        """Audio generation stage.

        Generates unprompted and/or prompted samples (as enabled in `cfg.generate.lm`) for
        every batch of the 'generate' dataloader and registers them with the sample manager.

        Returns:
            dict: Averaged metrics; holds 'rtf' when at least one kind of sample was produced.
        """
        generate_stage_name = f'{self.current_stage}'
        sample_manager = SampleManager(self.xp)
        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        dataset = get_dataset_from_loader(loader)
        assert isinstance(dataset, AudioDataset)
        dataset_duration = dataset.segment_duration
        assert dataset_duration is not None
        target_duration = self.cfg.generate.lm.gen_duration
        prompt_duration = self.cfg.generate.lm.prompt_duration
        if target_duration is None:
            target_duration = dataset_duration
        if prompt_duration is None:
            prompt_duration = dataset_duration / 4
        assert prompt_duration < dataset_duration, (
            f"Specified prompt duration ({prompt_duration}s) is longer"
            f" than reference audio duration ({dataset_duration}s)"
        )

        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
            """Build, per sample, a loggable dict of the conditions known to the condition provider."""
            hydrated_conditions = []
            for sample in [x.to_condition_attributes() for x in meta]:
                cond_dict = {}
                for cond_type in sample.__annotations__.keys():
                    for cond_key, cond_val in getattr(sample, cond_type).items():
                        if cond_key not in self.model.condition_provider.conditioners.keys():
                            continue
                        if is_jsonable(cond_val):
                            cond_dict[cond_key] = cond_val
                        elif isinstance(cond_val, WavCondition):
                            cond_dict[cond_key] = cond_val.path
                        elif isinstance(cond_val, JointEmbedCondition):
                            cond_dict[cond_key] = cond_val.text  # only support text at inference for now
                        else:
                            # if we reached this point, it is not clear how to log the condition
                            # so we just log the type.
                            cond_dict[cond_key] = str(type(cond_val))
                hydrated_conditions.append(cond_dict)
            return hydrated_conditions

        # generation arguments stored alongside the samples; identical for every batch
        sample_generation_params = {
            **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
            **self.generation_params
        }

        metrics: dict = {}
        average = flashy.averager()
        for batch in lp:
            audio, meta = batch
            # metadata for sample manager
            hydrated_conditions = get_hydrated_conditions(meta)
            # real time factor for this batch; stays None when no sample is generated
            rtf: tp.Optional[float] = None
            if self.cfg.generate.lm.unprompted_samples:
                if self.cfg.generate.lm.gen_gt_samples:
                    # get the ground truth instead of generation
                    self.logger.warning(
                        "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
                    gen_unprompted_audio = audio
                    rtf = 1.
                else:
                    gen_unprompted_outputs = self.run_generate_step(
                        batch, gen_duration=target_duration, prompt_duration=None,
                        **self.generation_params)
                    gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
                    rtf = gen_unprompted_outputs['rtf']
                sample_manager.add_samples(
                    gen_unprompted_audio, self.epoch, hydrated_conditions,
                    ground_truth_wavs=audio, generation_args=sample_generation_params)

            if self.cfg.generate.lm.prompted_samples:
                gen_outputs = self.run_generate_step(
                    batch, gen_duration=target_duration, prompt_duration=prompt_duration,
                    **self.generation_params)
                gen_audio = gen_outputs['gen_audio'].cpu()
                prompt_audio = gen_outputs['prompt_audio'].cpu()
                sample_manager.add_samples(
                    gen_audio, self.epoch, hydrated_conditions,
                    prompt_wavs=prompt_audio, ground_truth_wavs=audio,
                    generation_args=sample_generation_params)
                if rtf is None:
                    # Only prompted samples are enabled: report their real time factor
                    # instead of failing on an unbound variable.
                    rtf = gen_outputs['rtf']

            if rtf is not None:
                metrics['rtf'] = rtf
                metrics = average(metrics)

        flashy.distrib.barrier()
        return metrics

    def generate(self) -> dict:
        """Generate stage."""
        self.model.eval()
        with torch.no_grad():
            return self.generate_audio()

    def run_epoch(self):
        """Run one epoch; when writing a sharded cache, only run the epochs assigned to this shard."""
        if self.cfg.cache.write:
            num_shards = self.cfg.cache.write_num_shards
            # A falsy shard count disables sharding (and avoids a modulo by zero).
            if num_shards and ((self.epoch - 1) % num_shards) != self.cfg.cache.write_shard:
                return
        super().run_epoch()

    def train(self):
        """Train stage.

        Notifies the cache writer / loader (when used) of the current epoch, or sets
        `current_epoch` on the training dataset otherwise, then runs the parent train stage.
        """
        if self._cached_batch_writer is not None:
            self._cached_batch_writer.start_epoch(self.epoch)
        if self._cached_batch_loader is None:
            dataset = get_dataset_from_loader(self.dataloaders['train'])
            assert isinstance(dataset, AudioDataset)
            dataset.current_epoch = self.epoch
        else:
            self._cached_batch_loader.start_epoch(self.epoch)
        return super().train()

    def evaluate_audio_generation(self) -> dict:
        """Evaluate audio generation with off-the-shelf metrics.

        The metrics enabled under `cfg.evaluate.metrics` (fad, kld, text_consistency,
        chroma_cosine) are updated on every batch of the 'evaluate' dataloader. For each
        metric, its own `use_gt` flag selects whether it scores the generated audio or a
        ground-truth based reference; the flag of one metric does not affect the others.

        Returns:
            dict: Averaged metrics, empty when no generation metric is enabled.
        """
        evaluate_stage_name = f'{self.current_stage}_generation'
        # instantiate evaluation metrics, if at least one metric is defined, run audio generation evaluation
        fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
        kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
        text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
        chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
        should_run_eval = False
        eval_chroma_wavs: tp.Optional[torch.Tensor] = None
        if self.cfg.evaluate.metrics.fad:
            fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.kld:
            kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.text_consistency:
            text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
            should_run_eval = True
        if self.cfg.evaluate.metrics.chroma_cosine:
            chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
            # if we have predefined wavs for chroma we should purge them for computing the cosine metric
            has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
                self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
            if has_predefined_eval_chromas:
                warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
                                       'Resetting eval chromas to None for evaluation.')
                eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
                self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
            should_run_eval = True

        def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
            """Encode then decode `audio` with the compression model, trimmed to the input length."""
            audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
            compressed_audio = self.compression_model.decode(audio_tokens, scale)
            return compressed_audio[..., :audio.shape[-1]]

        metrics: dict = {}
        if should_run_eval:
            loader = self.dataloaders['evaluate']
            updates = len(loader)
            lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
            average = flashy.averager()
            dataset = get_dataset_from_loader(loader)
            assert isinstance(dataset, AudioDataset)
            self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")

            try:
                for batch in lp:
                    audio, meta = batch
                    assert all([self.cfg.sample_rate == m.sample_rate for m in meta])

                    target_duration = audio.shape[-1] / self.cfg.sample_rate
                    if self.cfg.evaluate.fixed_generation_duration:
                        target_duration = self.cfg.evaluate.fixed_generation_duration

                    gen_outputs = self.run_generate_step(
                        batch, gen_duration=target_duration,
                        **self.generation_params
                    )
                    y_pred = gen_outputs['gen_audio'].detach()
                    y_pred = y_pred[..., :audio.shape[-1]]

                    normalize_kwargs = dict(self.cfg.generate.audio)
                    normalize_kwargs.pop('format', None)
                    y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
                    y = audio.cpu()  # should already be on CPU but just in case
                    sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
                    sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
                    audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]

                    # Each metric selects its own prediction so that a `use_gt` flag on one
                    # metric does not leak ground truth audio into the metrics that follow.
                    if fad is not None:
                        fad_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.fad.use_gt else y_pred
                        fad.update(fad_pred, y, sizes, sample_rates, audio_stems)
                    if kldiv is not None:
                        kld_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.kld.use_gt else y_pred
                        kldiv.update(kld_pred, y, sizes, sample_rates)
                    if text_consistency is not None:
                        texts = [m.description for m in meta]
                        text_pred = y if self.cfg.metrics.text_consistency.use_gt else y_pred
                        text_consistency.update(text_pred, texts, sizes, sample_rates)
                    if chroma_cosine is not None:
                        chroma_pred = get_compressed_audio(y).cpu() if self.cfg.metrics.chroma_cosine.use_gt \
                            else y_pred
                        chroma_cosine.update(chroma_pred, y, sizes, sample_rates)
            finally:
                # restore chroma conditioner's eval chroma wavs, only once every batch has been
                # generated without them (and even if generation failed midway)
                if eval_chroma_wavs is not None:
                    self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)

            flashy.distrib.barrier()
            if fad is not None:
                metrics['fad'] = fad.compute()
            if kldiv is not None:
                kld_metrics = kldiv.compute()
                metrics.update(kld_metrics)
            if text_consistency is not None:
                metrics['text_consistency'] = text_consistency.compute()
            if chroma_cosine is not None:
                metrics['chroma_cosine'] = chroma_cosine.compute()
            metrics = average(metrics)
            metrics = flashy.distrib.average_metrics(metrics, len(loader))

        return metrics

    def evaluate(self) -> dict:
        """Evaluate stage.

        Runs the base train/valid metrics on the 'evaluate' split when
        `cfg.evaluate.metrics.base` is set, then the audio generation metrics,
        and returns both merged (generation metrics win on key conflicts).
        """
        self.model.eval()
        with torch.no_grad():
            metrics: dict = {}
            if self.cfg.evaluate.metrics.base:
                metrics.update(self.common_train_valid('evaluate'))
            gen_metrics = self.evaluate_audio_generation()
            return {**metrics, **gen_metrics}