-
Notifications
You must be signed in to change notification settings - Fork 207
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a bundle of 3 trained A2A codecs to enhance the sound quality (…
…NFY)
- Loading branch information
Showing
5 changed files
with
3,665 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Git LFS file not shown
Git LFS file not shown
Git LFS file not shown
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/6. EnCodec quality enchancement.ipynb. | ||
|
||
# %% auto 0 | ||
__all__ = ['load_data', 'load_datasets', 'LayerNorm', 'init_transformer', 'PureEncoder', 'AAARTransformer', 'make_model'] | ||
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 1 | ||
import io | ||
import time | ||
import random | ||
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 2 | ||
import torch | ||
import torch.nn as nn | ||
from torch.profiler import profile, record_function, ProfilerActivity, schedule | ||
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 3 | ||
from pathlib import Path | ||
import json | ||
from fastprogress import progress_bar, master_bar | ||
import pandas as pd | ||
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 4 | ||
from spear_tts_pytorch.train import * | ||
from spear_tts_pytorch.modules import * | ||
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 12 | ||
def load_data(path):
    """Collect every ``*.encodec`` token file under *path*.

    Returns a DataFrame with a single ``atoks`` column of ``Path`` objects,
    one row per file (glob order, i.e. filesystem order).
    """
    files = list(Path(path).glob('*.encodec'))
    return pd.DataFrame(dict(atoks=files))
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 14 | ||
import torch.nn.functional as F | ||
|
||
class SADataset(torch.utils.data.Dataset):
    """Fixed-length training windows cut from EnCodec acoustic-token files.

    Each item is an ``(intoks, outtoks)`` pair of flattened token sequences:
    the target keeps all ``toq`` quantizer levels while the input has levels
    ``fromq`` and above replaced by the mask token 1024.
    """
    def __init__(self, data, fromq=2, toq=4):
        self.fromq = fromq
        self.toq = toq
        self.n_ctx = 192
        self.data = data
        # precompute (file index, window index) for every full window
        self.samples = []
        for file_idx, name in enumerate(data['atoks']):
            windows = torch.load(name).shape[-1] // self.n_ctx
            self.samples.extend((file_idx, w) for w in range(windows))

    def __len__(self):
        return len(self.samples)

    def A_tokens(self):
        # 4 quantizer levels counted per context position
        return 4 * self.n_ctx * len(self)

    def hours(self):
        # 2250 acoustic frames ≈ 30 s of audio — presumably EnCodec's 75 Hz rate; verify against caller
        return len(self) * self.n_ctx / 2250 * 30 / 3600

    def __repr__(self):
        return f"Dataset<{len(self)} samples, ({self.fromq}->{self.toq}), {self.A_tokens()} Atokens, {self.hours():.1f} hours>"

    def __getitem__(self, idx):
        file_idx, win_idx = self.samples[idx]
        row = self.data.iloc[file_idx]
        start = win_idx * self.n_ctx
        # (n_ctx, toq) window of tokens, time-major after the transpose
        window = torch.load(row['atoks'], map_location='cpu')[0, :self.toq, start:start + self.n_ctx].T
        outtoks = window.reshape(-1).clone()
        window[:, self.fromq:] = 1024  # mask token
        intoks = window.reshape(-1).clone()
        return intoks, outtoks
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 20 | ||
def load_datasets(
    path:Path, # encodec files path
    subsample:float=1, # use a fraction of the files
    fromq:int=2, # input quantizers
    toq:int=8 # output quantizers
):
    """Build (train, validation) ``SADataset`` pairs from *path*.

    The first 12 files are the validation set; training uses the rest,
    optionally truncated to a ``subsample`` fraction of all files.
    """
    data = load_data(path)
    cutoff = int(len(data) * subsample)
    val_data = data[:12]
    train_data = data[12:cutoff]
    return SADataset(train_data, fromq=fromq, toq=toq), SADataset(val_data, fromq=fromq, toq=toq)
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 23 | ||
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes in float32, then casts back to the input dtype."""
    def forward(self, x):
        in_dtype = x.dtype
        y = super().forward(x.float())
        return y.type(in_dtype)
|
||
# based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
def init_transformer(m):
    """Per-module weight init for `Module.apply`: LayerNorms to identity,
    Linear/Embedding weights to truncated normal (std 0.02), biases to zero."""
    if isinstance(m, nn.LayerNorm):
        torch.nn.init.constant_(m.weight, 1.0)
        torch.nn.init.constant_(m.bias, 0)
        return
    if isinstance(m, (nn.Linear, nn.Embedding)):
        torch.nn.init.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
|
||
class PureEncoder(nn.Module):
    """Non-causal transformer encoder over a token sequence.

    Embeds tokens (codebook of `codes` entries plus one mask token), adds
    fixed positional embeddings, runs `depth` attention blocks, and produces
    logits by tying the output projection to the input embedding matrix.
    """
    def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, pos_embs=None):
        super().__init__()
        self.length = length
        self.codes = codes

        # embed semantic tokens; +1 row for the mask token
        self.embedding = nn.Embedding(codes + 1, width)
        if pos_embs is None:
            pos_embs = sinusoids(length, width)
        self.register_buffer("positional_embedding", pos_embs)

        self.layers = nn.ModuleList(
            [ResidualAttentionBlock(width, n_head) for _ in range(depth)]
        )
        self.ln_post = LayerNorm(width)

        self.apply(init_transformer)

    def forward(self, Stoks):
        embs = self.embedding(Stoks)
        # positions are truncated to the actual sequence length
        x = embs + self.positional_embedding[:embs.shape[1]]
        for block in self.layers:
            x = block(x, causal=False)
        x = self.ln_post(x)
        # weight-tied output projection, returned in float32
        logits = (x @ self.embedding.weight.to(x.dtype).T).float()
        return logits
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 24 | ||
class AAARTransformer(nn.Module):
    """Acoustic-to-acoustic token enhancer.

    Wraps a non-causal `PureEncoder` over flattened EnCodec token windows
    (`ctx_n` positions × `toq` quantizer levels) and trains it with
    cross-entropy against the unmasked targets.
    """
    def __init__(self, width=384, depth=4, ctx_n=250, n_head=6, fromq=2, toq=4):
        super().__init__()
        # `fromq` is accepted for signature compatibility with make_model but
        # only affects the dataset-side masking, not the model itself.

        # fixed positions for the flattened ctx_n * toq token sequence
        pos_embs = sinusoids(ctx_n * toq, width)

        self.encoder = PureEncoder(pos_embs=pos_embs, length=ctx_n*toq, width=width, n_head=n_head, depth=depth)

    def forward(self, intoks, outtoks, loss=True):
        """Return ``(logits, loss)``.

        Pass a falsy ``loss`` (``False`` or ``None``) to skip the loss
        computation; ``None`` is then returned in its place.
        Fix: the original guard was ``loss is not None``, which made
        ``loss=False`` compute the loss anyway despite the boolean default.
        """
        with record_function("decoder"):
            logits = self.encoder(intoks)
        if loss:
            with record_function("loss"):
                loss = F.cross_entropy(logits.reshape(-1, logits.shape[-1]), outtoks.view(-1))
        else:
            loss = None
        return logits, loss
|
||
# %% ../nbs/6. EnCodec quality enchancement.ipynb 25 | ||
def make_model(size:str, dataset:torch.utils.data.Dataset=None):
    """Build an `AAARTransformer` sized by `size` ('tiny' | 'base' | 'small').

    The dataset supplies `fromq`, `toq` and `n_ctx` so the model matches the
    data it will be trained on.

    Raises:
        ValueError: if `dataset` is missing or `size` is unknown.
        (The original used a bare `assert`, which is stripped under `-O`,
        and silently returned None for an unrecognized size.)
    """
    if dataset is None:
        raise ValueError("a dataset is required to size the model")
    kwargs = dict(fromq = dataset.fromq, toq = dataset.toq, ctx_n = dataset.n_ctx)
    if size == 'tiny':
        return AAARTransformer(depth=4, **kwargs)
    if size == 'base':
        return AAARTransformer(depth=6, width=512, n_head=8, **kwargs)
    if size == 'small':
        return AAARTransformer(depth=12, width=768, n_head=12, **kwargs)
    raise ValueError(f"unknown model size: {size!r}")