Added a bundle of 3 trained A2A codecs to enhance the sound quality (NFY)
jpc committed Apr 25, 2023
1 parent ed97e2a commit 82902a6
Showing 5 changed files with 3,665 additions and 0 deletions.
3,505 changes: 3,505 additions & 0 deletions nbs/6. EnCodec quality enchancement.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions nbs/a2a-2t4.ckpt
Git LFS file not shown
3 changes: 3 additions & 0 deletions nbs/a2a-4t6.ckpt
Git LFS file not shown
3 changes: 3 additions & 0 deletions nbs/a2a-6t8.ckpt
Git LFS file not shown
151 changes: 151 additions & 0 deletions spear_tts_pytorch/a2a.py
@@ -0,0 +1,151 @@
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/6. EnCodec quality enchancement.ipynb.

# %% auto 0
__all__ = ['load_data', 'load_datasets', 'LayerNorm', 'init_transformer', 'PureEncoder', 'AAARTransformer', 'make_model']

# %% ../nbs/6. EnCodec quality enchancement.ipynb 1
import io
import time
import random

# %% ../nbs/6. EnCodec quality enchancement.ipynb 2
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity, schedule

# %% ../nbs/6. EnCodec quality enchancement.ipynb 3
from pathlib import Path
import json
from fastprogress import progress_bar, master_bar
import pandas as pd

# %% ../nbs/6. EnCodec quality enchancement.ipynb 4
from spear_tts_pytorch.train import *
from spear_tts_pytorch.modules import *

# %% ../nbs/6. EnCodec quality enchancement.ipynb 12
def load_data(path):
    atoks = []
    for name in Path(path).glob('*.encodec'):
        atoks.append(name)
    return pd.DataFrame(dict(atoks=atoks))

# %% ../nbs/6. EnCodec quality enchancement.ipynb 14
import torch.nn.functional as F

class SADataset(torch.utils.data.Dataset):
def __init__(self, data, fromq=2, toq=4):
self.fromq = fromq
self.toq = toq
self.n_ctx = 192
self.data = data
self.samples = [(i,j) for i,name in enumerate(data['atoks']) for j in range(torch.load(name).shape[-1] // self.n_ctx)]

def __len__(self):
return len(self.samples)

def A_tokens(self):
return len(self)*self.n_ctx*4

def hours(self):
return len(self)*self.n_ctx/2250*30/3600

def __repr__(self):
return f"Dataset<{len(self)} samples, ({self.fromq}->{self.toq}), {self.A_tokens()} Atokens, {self.hours():.1f} hours>"

def __getitem__(self, idx):
i,j = self.samples[idx]
row = self.data.iloc[i]
jA = j * self.n_ctx
Atoks = torch.load(row['atoks'], map_location='cpu')[0,:self.toq,jA:jA+self.n_ctx].T
outtoks = Atoks.reshape(-1).clone()
Atoks[:,self.fromq:] = 1024 # mask token
intoks = Atoks.reshape(-1).clone()
return intoks, outtoks
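
# Minimal usage sketch (the directory name below is a placeholder): each item
# flattens one n_ctx-frame window into n_ctx*toq tokens, with quantizer levels
# >= fromq replaced by the mask token 1024 on the input side only.
#
#   ds = SADataset(load_data('path/to/encodec-tokens'), fromq=2, toq=4)
#   intoks, outtoks = ds[0]
#   # intoks.shape == outtoks.shape == torch.Size([192*4])
#   # intoks and outtoks differ exactly at the masked quantizer levels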

# %% ../nbs/6. EnCodec quality enchancement.ipynb 20
def load_datasets(
    path:Path,          # encodec files path
    subsample:float=1,  # use a fraction of the files
    fromq:int=2,        # input quantizers
    toq:int=8           # output quantizers
):
    data = load_data(path)

    val_data, train_data = data[:12], data[12:int(len(data)*subsample)]

    return SADataset(train_data, fromq=fromq, toq=toq), SADataset(val_data, fromq=fromq, toq=toq)
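
# Hypothetical split example (placeholder path; the first 12 files are held
# out for validation and the rest used for training):
#
#   train_ds, val_ds = load_datasets(Path('path/to/encodec-tokens'), fromq=2, toq=4)
#   print(train_ds)  # Dataset<... samples, (2->4), ... Atokens, ... hours>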

# %% ../nbs/6. EnCodec quality enchancement.ipynb 23
class LayerNorm(nn.LayerNorm):
    def forward(self, x):
        return super().forward(x.float()).type(x.dtype)

# based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
def init_transformer(m):
    if isinstance(m, (nn.Linear, nn.Embedding)):
        torch.nn.init.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        torch.nn.init.constant_(m.bias, 0)
        torch.nn.init.constant_(m.weight, 1.0)

class PureEncoder(nn.Module):
    def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, pos_embs=None):
        super().__init__()
        self.length = length
        self.codes = codes

        # embed the (partially masked) acoustic tokens; index `codes` is the mask token
        self.embedding = nn.Embedding(codes+1, width)
        if pos_embs is None: pos_embs = sinusoids(length, width)
        self.register_buffer("positional_embedding", pos_embs)

        self.layers = nn.ModuleList([
            ResidualAttentionBlock(width, n_head) for _ in range(depth)
        ])
        self.ln_post = LayerNorm(width)

        self.apply(init_transformer)

    def forward(self, Stoks):
        Sembs = self.embedding(Stoks)

        xin = (Sembs + self.positional_embedding[:Sembs.shape[1]])

        x = xin
        for l in self.layers: x = l(x, causal=False)

        x = self.ln_post(x)

        # output projection is tied to the input embedding matrix
        logits = (x @ self.embedding.weight.to(x.dtype).T).float()
        return logits

# %% ../nbs/6. EnCodec quality enchancement.ipynb 24
class AAARTransformer(nn.Module):
    def __init__(self, width=384, depth=4, ctx_n=250, n_head=6, fromq=2, toq=4):
        super().__init__()

        pos_embs = sinusoids(ctx_n * toq, width)

        self.encoder = PureEncoder(pos_embs=pos_embs, length=ctx_n*toq, width=width, n_head=n_head, depth=depth)

    def forward(self, intoks, outtoks, loss=True):
        with record_function("encoder"):
            logits = self.encoder(intoks)
        if loss is not None:
            with record_function("loss"):
                loss = F.cross_entropy(logits.reshape(-1,logits.shape[-1]), outtoks.view(-1))
        return logits, loss

# %% ../nbs/6. EnCodec quality enchancement.ipynb 25
def make_model(size:str, dataset:torch.utils.data.Dataset=None):
    assert(dataset is not None)
    kwargs = dict(fromq = dataset.fromq, toq = dataset.toq, ctx_n = dataset.n_ctx)
    if size == 'tiny':
        return AAARTransformer(depth=4, **kwargs)
    elif size == 'base':
        return AAARTransformer(depth=6, width=512, n_head=8, **kwargs)
    elif size == 'small':
        return AAARTransformer(depth=12, width=768, n_head=12, **kwargs)
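
# Hedged end-to-end sketch: the bundled nbs/a2a-2t4.ckpt, a2a-4t6.ckpt and
# a2a-6t8.ckpt appear, judging by their names, to be the 2->4, 4->6 and 6->8
# quantizer models. The model size and checkpoint format are not shown in this
# file, so the loading step below is an assumption (paths are placeholders):
#
#   ds = SADataset(load_data('path/to/encodec-tokens'), fromq=2, toq=4)
#   model = make_model('tiny', dataset=ds)
#   state = torch.load('nbs/a2a-2t4.ckpt', map_location='cpu')
#   # the checkpoint may be a plain state_dict or a trainer checkpoint
#   # wrapping one under a 'state_dict' key
#   model.load_state_dict(state.get('state_dict', state))
#   intoks, outtoks = ds[0]
#   logits, loss = model(intoks[None], outtoks[None])
#   refined = logits.argmax(-1)  # predicted EnCodec tokens for all `toq` levels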
