
Implement CPU and MPS support enhancements for WhisperSpeech
BBC-Esq authored and jpc committed Feb 13, 2024
1 parent e35ee9a commit 5f691c6
Showing 14 changed files with 92 additions and 31 deletions.
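The heart of the change: every hard-coded `.cuda()` call becomes a `.to(compute_device)` against a single helper that probes the available backends. A minimal sketch of the pattern (the helper itself lands in whisperspeech/utils.py, shown in full near the end of the diff):

import torch

def get_compute_device() -> str:
    # Prefer CUDA (torch.version.hip covers ROCm builds), then Apple-silicon MPS, then CPU.
    if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
        return 'cuda'
    elif torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'

compute_device = get_compute_device()
x = torch.zeros(3).to(compute_device)  # instead of torch.zeros(3).cuda()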
10 changes: 7 additions & 3 deletions whisperspeech/a2wav.py
@@ -11,7 +11,11 @@
 # %% ../nbs/6. Quality-boosting vocoder.ipynb 2
 class Vocoder:
     def __init__(self, repo_id="charactr/vocos-encodec-24khz"):
-        self.vocos = Vocos.from_pretrained(repo_id).cuda()
+        if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
+            self.vocos_compute_device = 'cuda'
+        else:
+            self.vocos_compute_device = 'cpu'  # mps does not currently work with vocos, thus only cuda or cpu
+        self.vocos = Vocos.from_pretrained(repo_id).to(self.vocos_compute_device)

     def is_notebook(self):
         try:
@@ -26,9 +30,9 @@ def decode(self, atoks):
             atoks = atoks.permute(1,0,2)
         else:
             q,t = atoks.shape

+        # print(atoks.dtype, atoks.device) # uncomment to check dtype and compute device
         features = self.vocos.codes_to_features(atoks)
-        bandwidth_id = torch.tensor({2:0,4:1,8:2}[q]).cuda()
+        bandwidth_id = torch.tensor({2: 0, 4: 1, 8: 2}[q]).to(self.vocos_compute_device)  # move the tensor to the same device as the model
         return self.vocos.decode(features, bandwidth_id=bandwidth_id)

     def decode_to_file(self, fname, atoks):
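A hedged usage sketch for the reworked Vocoder (the token tensor here is hypothetical; the class resolves its own device at construction time, preferring CUDA and falling back to CPU because Vocos does not yet run under MPS):

voc = Vocoder()  # picks 'cuda' when available, otherwise 'cpu'
atoks = torch.zeros(4, 750, dtype=torch.long)  # stand-in for real EnCodec tokens, shape (quantizers, time)
audio = voc.decode(atoks.to(voc.vocos_compute_device))  # bandwidth_id is inferred from the 4 quantizers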
14 changes: 11 additions & 3 deletions whisperspeech/benchmark.py
@@ -8,14 +8,22 @@
 import torch
 from fastcore.script import call_parse
 from whisperspeech.pipeline import Pipeline
+from whisperspeech.utils import get_compute_device
+
+compute_device = get_compute_device()  # determine compute_device

 # %% ../nbs/C. Benchmark.ipynb 3
 def measure(fun, iterations = 10):
     ts = []
     for x in range(iterations):
         start = time.time()
         fun()
-        torch.cuda.synchronize()
+        if compute_device == 'cuda':
+            torch.cuda.synchronize()
+        elif compute_device == 'mps':
+            torch.mps.synchronize()
+        elif compute_device == 'cpu':
+            torch.cpu.synchronize()  # only kept for device-agnostic formatting; technically a no-op
         ts.append(time.time() - start)
     ts = torch.tensor(ts)
     return ts.mean(), ts.std()
@@ -37,13 +45,13 @@ def benchmark(

     if t2s_ctx_n:
         pipe.t2s.stoks_len = t2s_ctx_n
-        pipe.t2s.decoder.mask = torch.empty(t2s_ctx_n, t2s_ctx_n).fill_(-torch.inf).triu_(1).cuda()
+        pipe.t2s.decoder.mask = torch.empty(t2s_ctx_n, t2s_ctx_n).fill_(-torch.inf).triu_(1).to(compute_device)

     pipe.t2s.optimize(max_batch_size=max_batch_size, torch_compile=not no_torch_compile)

     if s2a_ctx_n:
         pipe.s2a.ctx_n = s2a_ctx_n
-        pipe.s2a.decoder.mask = torch.empty(s2a_ctx_n, s2a_ctx_n).fill_(-torch.inf).triu_(1).cuda()
+        pipe.s2a.decoder.mask = torch.empty(s2a_ctx_n, s2a_ctx_n).fill_(-torch.inf).triu_(1).to(compute_device)

     pipe.s2a.optimize(max_batch_size=max_batch_size, torch_compile=not no_torch_compile)
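The per-device branches above exist because CUDA and MPS kernels are launched asynchronously: without an explicit synchronize, time.time() would clock the kernel launch rather than its completion, while CPU ops finish before the call returns. A sketch of the same logic folded into a helper:

def synchronize(device: str) -> None:
    # Block until all work queued on the accelerator has finished.
    if device == 'cuda':
        torch.cuda.synchronize()
    elif device == 'mps':
        torch.mps.synchronize()
    # CPU execution is synchronous, so there is nothing to wait for.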
7 changes: 5 additions & 2 deletions whisperspeech/extract_acoustic.py
@@ -11,21 +11,24 @@
 from pathlib import Path
 from fastcore.script import *
 from fastprogress import progress_bar, master_bar
+from .utils import get_compute_device
+
+compute_device = get_compute_device()

 # %% ../nbs/1. Acoustic token extraction.ipynb 5
 def load(fname, newsr=24000):
     """Load an audio file to the GPU and resample to `newsr`."""
     x, sr = torchaudio.load(fname)
     _tform = torchaudio.transforms.Resample(sr, newsr)
-    return _tform(x).cuda().unsqueeze(0)
+    return _tform(x).to(compute_device).unsqueeze(0)

 # %% ../nbs/1. Acoustic token extraction.ipynb 6
 def load_model():
     "Load the pretrained EnCodec model"
     from encodec.model import EncodecModel
     model = EncodecModel.encodec_model_24khz()
     model.set_target_bandwidth(1.5)
-    model.cuda().eval();
+    model.to(compute_device).eval();
     return model

 # %% ../nbs/1. Acoustic token extraction.ipynb 7
5 changes: 4 additions & 1 deletion whisperspeech/extract_spk_emb.py
@@ -17,6 +17,9 @@
 from . import vad, utils

 from speechbrain.pretrained import EncoderClassifier
+from .utils import get_compute_device
+
+compute_device = get_compute_device()

 # %% ../nbs/2A. Speaker Embeddings.ipynb 5
 def calc_len(x):
@@ -47,7 +50,7 @@ def process_shard(

     classifier = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
                                                 savedir=f"{os.environ['HOME']}/.cache/speechbrain/",
-                                                run_opts={"device": "cuda"})
+                                                run_opts={"device": compute_device})

     with utils.AtomicTarWriter(utils.derived_name(input, f'spk_emb')) as sink:
         for keys, samples, seconds in progress_bar(dl, total=total):
9 changes: 6 additions & 3 deletions whisperspeech/extract_stoks.py
@@ -22,6 +22,9 @@
 from . import vq_stoks, utils, vad_merge
 import webdataset as wds

+from .utils import get_compute_device
+compute_device = get_compute_device()
+
 # %% ../nbs/3B. Semantic token extraction.ipynb 7
 @call_parse
 def prepare_stoks(
@@ -32,13 +35,13 @@ def prepare_stoks(
     kind:str="maxvad", # could be eqvad to get more uniform chunk lengths
 ):
-    vq_model = vq_stoks.RQBottleneckTransformer.load_model(vq_model).cuda()
+    vq_model = vq_stoks.RQBottleneckTransformer.load_model(vq_model).to(compute_device)
     vq_model.ensure_whisper()
     # vq_model.encode_mel = torch.compile(vq_model.encode_mel, mode="reduce-overhead", fullgraph=True)

     spk_classifier = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
                                                     savedir=f"{os.environ['HOME']}/.cache/speechbrain/",
-                                                    run_opts={"device": "cuda"})
+                                                    run_opts={"device": compute_device})

     total = n_samples//batch_size if n_samples else 'noinfer'
@@ -53,7 +56,7 @@
     with utils.AtomicTarWriter(utils.derived_name(input, f'{kind}-stoks', dir="."), throwaway=n_samples is not None) as sink:
         for keys, rpad_ss, samples16k in progress_bar(dl, total=total):
             with torch.no_grad():
-                samples16k = samples16k.cuda().to(torch.float16)
+                samples16k = samples16k.to(compute_device).to(torch.float16)
                 stoks = vq_model.encode_audio(samples16k).cpu().numpy().astype(np.int16)
                 spk_embs = spk_classifier.encode_batch(
                     samples16k, wav_lens=torch.tensor(30 - rpad_ss, dtype=torch.float)/30)[:,0,:].cpu().numpy()
25 changes: 18 additions & 7 deletions whisperspeech/pipeline.py
@@ -10,6 +10,16 @@
 from whisperspeech.a2wav import Vocoder
 import traceback
 from pathlib import Path
+from whisperspeech.utils import get_compute_device
+
+compute_device = get_compute_device()
+
+if torch.cuda.is_available():
+    encoder_device = 'cuda'
+    vocoder_device = 'cuda'
+else:
+    encoder_device = 'cpu'
+    vocoder_device = 'cpu'

 # %% ../nbs/7. Pipeline.ipynb 2
 class Pipeline:
@@ -45,20 +55,21 @@ def __init__(self, t2s_ref=None, s2a_ref=None, optimize=True, torch_compile=Fals
         try:
             if t2s_ref:
                 args["ref"] = t2s_ref
-            self.t2s = TSARTransformer.load_model(**args).cuda()
+            self.t2s = TSARTransformer.load_model(**args).to(compute_device)
             if optimize: self.t2s.optimize(torch_compile=torch_compile)
         except:
             print("Failed to load the T2S model:")
             print(traceback.format_exc())
         try:
             if s2a_ref:
                 args["ref"] = s2a_ref
-            self.s2a = SADelARTransformer.load_model(**args).cuda()
+            self.s2a = SADelARTransformer.load_model(**args).to(compute_device)
             if optimize: self.s2a.optimize(torch_compile=torch_compile)
         except:
             print("Failed to load the S2A model:")
             print(traceback.format_exc())
-        self.vocoder = Vocoder()
+
+        self.vocoder = Vocoder()  # the Vocoder resolves its own device (cuda or cpu) internally
         self.encoder = None
@@ -69,20 +80,20 @@ def extract_spk_emb(self, fname):
             from speechbrain.pretrained import EncoderClassifier
             self.encoder = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
                                                           savedir="~/.cache/speechbrain/",
-                                                          run_opts={"device": "cuda"})
+                                                          run_opts={"device": encoder_device})
         audio_info = torchaudio.info(fname)
         actual_sample_rate = audio_info.sample_rate
-        num_frames = actual_sample_rate * 30 # specify 30 seconds worth of frames
+        num_frames = actual_sample_rate * 30  # specify 30 seconds worth of frames
         samples, sr = torchaudio.load(fname, num_frames=num_frames)
         samples = samples[:, :num_frames]
         samples = self.encoder.audio_normalizer(samples[0], sr)
         spk_emb = self.encoder.encode_batch(samples.unsqueeze(0))

         return spk_emb[0,0]

     def generate_atoks(self, text, speaker=None, lang='en', cps=15, step_callback=None):
         if speaker is None: speaker = self.default_speaker
-        elif isinstance(speaker, (str, Path)): speaker = self.extract_spk_emb(speaker)
+        elif isinstance(speaker, (str, Path)): speaker = self.extract_spk_emb(speaker).to(compute_device)
         text = text.replace("\n", " ")
         stoks = self.t2s.generate(text, cps=cps, lang=lang, step=step_callback)[0]
         atoks = self.s2a.generate(stoks, speaker.unsqueeze(0), step=step_callback)
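With the models, attention masks, and speaker embeddings all routed through compute_device, the pipeline can run end to end on a CUDA box, an Apple-silicon Mac, or a CPU-only machine. A hedged usage sketch, using the default model references that appear in the load_model signatures later in this diff:

pipe = Pipeline(t2s_ref='collabora/whisperspeech:t2s-small-en+pl.model',
                s2a_ref='collabora/whisperspeech:s2a-q4-small-en+pl.model')
atoks = pipe.generate_atoks("Hello world!", lang='en', cps=15)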
7 changes: 5 additions & 2 deletions whisperspeech/prepare_s2a_atoks.py
@@ -18,6 +18,9 @@

 from . import utils, vad_merge, extract_acoustic
 import webdataset as wds
+from .utils import get_compute_device
+
+compute_device = get_compute_device()

 # %% ../nbs/3C. S2A acoustic tokens preparation.ipynb 4
 @call_parse
@@ -27,7 +30,7 @@ def prepare_atoks(
     batch_size:int=1, # process several segments at once
     bandwidth:float=3,
 ):
-    amodel = extract_acoustic.load_model()
+    amodel = extract_acoustic.load_model().to(compute_device)  # move the model to the compute device
     amodel.set_target_bandwidth(bandwidth)

     total = n_samples//batch_size if n_samples else 'noinfer'
@@ -43,7 +46,7 @@

     with utils.AtomicTarWriter(utils.derived_name(input, f'atoks-{bandwidth}kbps', dir="."), throwaway=n_samples is not None) as sink:
         for keys, rpad_ss, samples in progress_bar(dl, total=total):
-            csamples = samples.cuda().unsqueeze(1)
+            csamples = samples.to(compute_device).unsqueeze(1)  # move the tensors to the compute device
             atokss = amodel.encode(csamples)[0][0]
             atokss = atokss.cpu().numpy().astype(np.int16)
             for key, rpad_s, atoks in zip(keys, rpad_ss, atokss):
7 changes: 5 additions & 2 deletions whisperspeech/prepare_t2s_txts.py
@@ -20,14 +20,17 @@
 from . import utils, vad_merge
 import webdataset as wds

+from .utils import get_compute_device
+compute_device = get_compute_device()
+
 # %% ../nbs/3A. T2S transcripts preparation.ipynb 4
 class Transcriber:
     """
     A helper class to transcribe a batch of 30 second audio chunks.
     """
     def __init__(self, model_size, lang=False):
         self.model = whisperx.asr.load_model(
-            model_size, "cuda", compute_type="float16", language=lang,
+            model_size, compute_device, compute_type="float16", language=lang,
             asr_options=dict(repetition_penalty=1, no_repeat_ngram_size=0, prompt_reset_on_temperature=0.5))
         # without calling vad_model at least once the rest segfaults for some reason...
         self.model.vad_model({"waveform": torch.zeros(1, 16000), "sample_rate": 16000})
@@ -81,7 +84,7 @@ def prepare_txt(

     with utils.AtomicTarWriter(utils.derived_name(input, f'{transcription_model}-txt', dir="."), throwaway=n_samples is not None) as sink:
         for keys, rpads, samples in progress_bar(dl, total=total):
-            csamples = samples.cuda()
+            csamples = samples.to(compute_device)
             txts = transcriber.transcribe(csamples)
             # with torch.no_grad():
             #     embs = whmodel.encoder(whisper.log_mel_spectrogram(csamples))
9 changes: 6 additions & 3 deletions whisperspeech/s2a_delar_mup_wds_mlang.py
@@ -27,6 +27,9 @@
 # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 4
 from .modules import *

+from .utils import get_compute_device
+compute_device = get_compute_device()
+
 # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 8
 def rand(start, end):
     return random.random() * (end - start) + start
@@ -266,8 +269,8 @@ def __init__(self, depth=3, ctx_n=2250,
         for l in self.decoder.layers:
             l.cross_attn.key_subsampling = 3

-        self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
-        self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
+        self.register_buffer('val_true', torch.zeros(self.quantizers).to(compute_device))
+        self.register_buffer('val_total', torch.zeros(self.quantizers).to(compute_device))
         self.apply(self.init_transformer)

     def setup(self, device):
@@ -399,7 +402,7 @@ def load_model(cls, ref="collabora/whisperspeech:s2a-q4-small-en+pl.model",
             local_filename = ref
         if not local_filename:
             local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
-        spec = torch.load(local_filename)
+        spec = torch.load(local_filename, map_location=compute_device)
         if '_extra_state' not in spec['state_dict'] and 'speaker_map' in spec['config']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
         model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
         model.load_state_dict(spec['state_dict'])
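The map_location argument is what makes checkpoint loading work away from CUDA: a checkpoint saved from GPU tensors otherwise fails to deserialize on a CPU- or MPS-only machine with "RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False". A sketch (the filename is illustrative), with the same pattern repeated for the T2S model below:

spec = torch.load('s2a-q4-small.model', map_location=compute_device)  # remaps all storages onto the chosen device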
5 changes: 4 additions & 1 deletion whisperspeech/t2s_up_wds_mlang_enclm.py
@@ -26,6 +26,9 @@
 # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 6
 import re

+from .utils import get_compute_device
+compute_device = get_compute_device()
+
 class CharTokenizer:
     """Trivial tokenizer – just use UTF-8 bytes"""
     eot = 0
@@ -333,7 +336,7 @@ def load_model(cls, ref="collabora/whisperspeech:t2s-small-en+pl.model",
             local_filename = ref
         if not local_filename:
             local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
-        spec = torch.load(local_filename)
+        spec = torch.load(local_filename, map_location=compute_device)
         model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
         model.load_state_dict(spec['state_dict'])
         model.eval()
8 changes: 8 additions & 0 deletions whisperspeech/utils.py
@@ -190,3 +190,11 @@ def AtomicTarWriter(name, throwaway=False):
 def readlines(fname):
     with open(fname) as file:
         return [line.rstrip() for line in file]
+
+def get_compute_device():
+    if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
+        return 'cuda'
+    elif torch.backends.mps.is_available():
+        return 'mps'
+    else:
+        return 'cpu'
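One caveat the helper cannot address by itself: some operators still lack MPS kernels. A common companion setting, not part of this commit, is PyTorch's CPU-fallback flag, which has to be in the environment before torch initializes MPS:

import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # must be set before importing torch
import torch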
5 changes: 4 additions & 1 deletion whisperspeech/vad.py
@@ -18,6 +18,9 @@

 import whisperx

+from .utils import get_compute_device
+compute_device = get_compute_device()
+
 # %% ../nbs/1B. Voice activity detection.ipynb 5
 # some of the original file names have a dot in their name
 # webdataset does not like it so let's patch it
@@ -60,7 +63,7 @@ def process_shard(

     ds = load_dataset(input)
     dl = torch.utils.data.DataLoader(ds, num_workers=2, batch_size=None)
-    vad_model = whisperx.vad.load_vad_model('cuda')
+    vad_model = whisperx.vad.load_vad_model(compute_device)

     tmp = output+".tmp"
     with wds.TarWriter(tmp) as sink:
7 changes: 5 additions & 2 deletions whisperspeech/vq_stoks.py
@@ -35,6 +35,9 @@

 from fastcore.script import *

+from .utils import get_compute_device
+compute_device = get_compute_device()
+
 # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 9
 def merge_in(dataset_fun):
     """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
@@ -270,8 +273,8 @@ def __init__(self, vq_codes=512, q_depth=12, depth=1, n_head=2, head_width=64, f
         self.whmodel = None

         self.apply(self.init_transformer)
-        self.register_buffer('val_true', torch.zeros(1).cuda())
-        self.register_buffer('val_total', torch.zeros(1).cuda())
+        self.register_buffer('val_true', torch.zeros(1).to(compute_device))
+        self.register_buffer('val_total', torch.zeros(1).to(compute_device))

     def setup(self, device):
         self.ensure_whisper(device)
5 changes: 4 additions & 1 deletion whisperspeech/wh_transcribe.py
@@ -28,6 +28,9 @@
 from . import vad, utils
 import webdataset as wds

+from .utils import get_compute_device
+compute_device = get_compute_device()
+
 # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9
 # let's make it a bit more conservative
 # with full 30 second chunks it sometimes misses a small part of the transcript
@@ -135,7 +138,7 @@ def process_shard(
     with wds.TarWriter(tmp) as sink:
         for keys, samples in progress_bar(dl, total=n_samples):
             with torch.no_grad():
-                embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
+                embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).to(compute_device))
             decs = whmodel.decode(embs, decoding_options)
             for key, dec in zip(keys, decs):
                 sink.write({
