Skip to content

Commit

Permalink
v5 initial push
Browse files Browse the repository at this point in the history
  • Loading branch information
adamnsandle committed Jun 27, 2024
1 parent 8145ed9 commit fd1f1a6
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 38 deletions.
Binary file modified files/silero_vad.jit
Binary file not shown.
Binary file modified files/silero_vad.onnx
Binary file not shown.
44 changes: 29 additions & 15 deletions silero-vad.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"USE_ONNX = False # change this to True if you want to test onnx model\n",
"if USE_ONNX:\n",
" !pip install -q onnxruntime\n",
" \n",
"\n",
"model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',\n",
" model='silero_vad',\n",
" force_reload=True,\n",
Expand All @@ -65,16 +65,7 @@
"id": "fXbbaUO3jsrw"
},
"source": [
"## Full Audio"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RAfJPb_a-Auj"
},
"source": [
"**Speech timestapms from full audio**"
"## Speech timestapms from full audio"
]
},
{
Expand All @@ -101,10 +92,33 @@
"source": [
"# merge all speech chunks to one audio\n",
"save_audio('only_speech.wav',\n",
" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE) \n",
" collect_chunks(speech_timestamps, wav), sampling_rate=SAMPLING_RATE)\n",
"Audio('only_speech.wav')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zeO1xCqxUC6w"
},
"source": [
"## Entire audio inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "LjZBcsaTT7Mk"
},
"outputs": [],
"source": [
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"# audio is being splitted into 31.25 ms long pieces\n",
"# so output length equals ceil(input_length * 31.25 / SAMPLING_RATE)\n",
"predicts = model.audio_forward(wav, sr=SAMPLING_RATE)"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand All @@ -124,10 +138,10 @@
"source": [
"## using VADIterator class\n",
"\n",
"vad_iterator = VADIterator(model)\n",
"vad_iterator = VADIterator(model, sampling_rate=SAMPLING_RATE)\n",
"wav = read_audio(f'en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"\n",
"window_size_samples = 1536 # number of samples in a single audio chunk\n",
"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
"for i in range(0, len(wav), window_size_samples):\n",
" chunk = wav[i: i+ window_size_samples]\n",
" if len(chunk) < window_size_samples:\n",
Expand All @@ -150,7 +164,7 @@
"\n",
"wav = read_audio('en_example.wav', sampling_rate=SAMPLING_RATE)\n",
"speech_probs = []\n",
"window_size_samples = 1536\n",
"window_size_samples = 512 if SAMPLING_RATE == 16000 else 256\n",
"for i in range(0, len(wav), window_size_samples):\n",
" chunk = wav[i: i+ window_size_samples]\n",
" if len(chunk) < window_size_samples:\n",
Expand Down
55 changes: 32 additions & 23 deletions utils_vad.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
import torchaudio
from typing import Callable, List
import torch.nn.functional as F
import warnings

languages = ['ru', 'en', 'de', 'es']
Expand Down Expand Up @@ -39,22 +38,27 @@ def _validate_input(self, x, sr: int):

if sr not in self.sample_rates:
raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)")

if sr / x.shape[1] > 31.25:
raise ValueError("Input audio chunk is too short")

return x, sr

def reset_states(self, batch_size=1):
self._h = np.zeros((2, batch_size, 64)).astype('float32')
self._c = np.zeros((2, batch_size, 64)).astype('float32')
self._state = torch.zeros((2, batch_size, 128)).float()
self._context = torch.zeros(0)
self._last_sr = 0
self._last_batch_size = 0

def __call__(self, x, sr: int):

x, sr = self._validate_input(x, sr)
num_samples = 512 if sr == 16000 else 256

if x.shape[-1] != num_samples:
raise ValueError(f"Provided number of samples is {x.shape[-1]} (Supported values: 256 for 8000 sample rate, 512 for 16000)")

batch_size = x.shape[0]
context_size = 64 if sr == 16000 else 32

if not self._last_batch_size:
self.reset_states(batch_size)
Expand All @@ -63,28 +67,35 @@ def __call__(self, x, sr: int):
if (self._last_batch_size) and (self._last_batch_size != batch_size):
self.reset_states(batch_size)

if not len(self._context):
self._context = torch.zeros(batch_size, context_size)

x = torch.cat([self._context, x], dim=1)
if sr in [8000, 16000]:
ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')}
ort_inputs = {'input': x.numpy(), 'state': self._state.numpy(), 'sr': np.array(sr)}
ort_outs = self.session.run(None, ort_inputs)
out, self._h, self._c = ort_outs
out, state = ort_outs
self._state = torch.from_numpy(state)
else:
raise ValueError()

self._context = x[..., -context_size:]
self._last_sr = sr
self._last_batch_size = batch_size

out = torch.tensor(out)
out = torch.from_numpy(out)
return out

def audio_forward(self, x, sr: int, num_samples: int = 512):
def audio_forward(self, x, sr: int):
outs = []
x, sr = self._validate_input(x, sr)
self.reset_states()
num_samples = 512 if sr == 16000 else 256

if x.shape[1] % num_samples:
pad_num = num_samples - (x.shape[1] % num_samples)
x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0)

self.reset_states(x.shape[0])
for i in range(0, x.shape[1], num_samples):
wavs_batch = x[:, i:i+num_samples]
out_chunk = self.__call__(wavs_batch, sr)
Expand Down Expand Up @@ -179,11 +190,11 @@ def get_speech_timestamps(audio: torch.Tensor,
min_speech_duration_ms: int = 250,
max_speech_duration_s: float = float('inf'),
min_silence_duration_ms: int = 100,
window_size_samples: int = 512,
speech_pad_ms: int = 30,
return_seconds: bool = False,
visualize_probs: bool = False,
progress_tracking_callback: Callable[[float], None] = None):
progress_tracking_callback: Callable[[float], None] = None,
window_size_samples: int = 512,):

"""
This method is used for splitting long audios into speech chunks using silero VAD
Expand All @@ -193,14 +204,14 @@ def get_speech_timestamps(audio: torch.Tensor,
audio: torch.Tensor, one dimensional
One dimensional float torch.Tensor, other types are casted to torch if possible
model: preloaded .jit silero VAD model
model: preloaded .jit/.onnx silero VAD model
threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
It is better to tune this parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
sampling_rate: int (default - 16000)
Currently silero VAD models support 8000 and 16000 sample rates
Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates
min_speech_duration_ms: int (default - 250 milliseconds)
Final speech chunks shorter min_speech_duration_ms are thrown out
Expand All @@ -213,11 +224,6 @@ def get_speech_timestamps(audio: torch.Tensor,
min_silence_duration_ms: int (default - 100 milliseconds)
In the end of each speech chunk wait for min_silence_duration_ms before separating it
window_size_samples: int (default - 1536 samples)
Audio chunks of window_size_samples size are fed to the silero VAD model.
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate and 256, 512, 768 samples for 8000 sample rate.
Values other than these may affect model perfomance!!
speech_pad_ms: int (default - 30 milliseconds)
Final speech chunks are padded by speech_pad_ms each side
Expand All @@ -230,6 +236,9 @@ def get_speech_timestamps(audio: torch.Tensor,
progress_tracking_callback: Callable[[float], None] (default - None)
callback function taking progress in percents as an argument
window_size_samples: int (default - 512 samples)
!!! DEPRECATED, DOES NOTHING !!!
Returns
----------
speeches: list of dicts
Expand All @@ -256,10 +265,10 @@ def get_speech_timestamps(audio: torch.Tensor,
else:
step = 1

if sampling_rate == 8000 and window_size_samples > 768:
warnings.warn('window_size_samples is too big for 8000 sampling_rate! Better set window_size_samples to 256, 512 or 768 for 8000 sample rate!')
if window_size_samples not in [256, 512, 768, 1024, 1536]:
warnings.warn('Unusual window_size_samples! Supported window_size_samples:\n - [512, 1024, 1536] for 16000 sampling_rate\n - [256, 512, 768] for 8000 sampling_rate')
if sampling_rate not in [8000, 16000]:
raise ValueError("Currently silero VAD models support 8000 and 16000 (or multiply of 16000) sample rates")

window_size_samples = 512 if sampling_rate == 16000 else 256

model.reset_states()
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
Expand Down Expand Up @@ -450,7 +459,7 @@ def __init__(self,
Parameters
----------
model: preloaded .jit silero VAD model
model: preloaded .jit/.onnx silero VAD model
threshold: float (default - 0.5)
Speech threshold. Silero VAD outputs speech probabilities for each audio chunk, probabilities ABOVE this value are considered as SPEECH.
Expand Down

0 comments on commit fd1f1a6

Please sign in to comment.