Skip to content

Commit

Permalink
Update VAD docstring and check for input shape length (#1513)
Browse files Browse the repository at this point in the history
* Update VAD docstring and check for input shape length

* Update docstring in forward for transform

* Address review feedback: merge tests, update wording
  • Loading branch information
Artyom Astafurov authored May 21, 2021
1 parent 22fe802 commit 08f2bde
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
25 changes: 25 additions & 0 deletions test/torchaudio_unittest/transforms/sox_compatibility_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import warnings

import torch
import torchaudio.transforms as T
from parameterized import parameterized

Expand Down Expand Up @@ -61,3 +64,25 @@ def test_vad(self, filename):
data, sample_rate = load_wav(path)
result = T.Vad(sample_rate)(data)
self.assert_sox_effect(result, path, ['vad'])

def test_vad_warning(self):
    """Check that Vad warns only for inputs with more than two dimensions.

    A 3-D input must produce exactly one warning; 2-D and 1-D inputs
    must produce none.
    """
    # NOTE(review): 41100 may be a typo for 44100 — confirm; value kept as-is.
    sample_rate = 41100

    # (input shape, number of warnings expected from a single Vad call)
    cases = (
        ((5, 5, sample_rate), 1),
        ((5, sample_rate), 0),
        ((sample_rate,), 0),
    )
    for shape, expected_warnings in cases:
        waveform = torch.rand(*shape)
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning, even ones already emitted earlier.
            warnings.simplefilter("always")
            T.Vad(sample_rate)(waveform)
        assert len(caught) == expected_warnings
15 changes: 14 additions & 1 deletion torchaudio/functional/filtering.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import warnings
from typing import Optional

import torch
Expand Down Expand Up @@ -1374,7 +1375,10 @@ def vad(
so in order to trim from the back, the reverse effect must also be used.
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`
waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
Tensor of shape `(channels, time)` is treated as a multi-channel recording
of the same event and the resulting output will be trimmed to the earliest
voice activity in any channel.
sample_rate (int): Sample rate of audio signal.
trigger_level (float, optional): The measurement level used to trigger activity detection.
This may need to be changed depending on the noise level, signal level,
Expand Down Expand Up @@ -1420,6 +1424,15 @@ def vad(
http://sox.sourceforge.net/sox.html
"""

if waveform.ndim > 2:
warnings.warn(
"Expected input tensor dimension of 1 for single channel"
f" or 2 for multi-channel. Got {waveform.ndim} instead. "
"Batch semantics is not supported. "
"Please refer to https://github.com/pytorch/audio/issues/1348"
" and https://github.com/pytorch/audio/issues/1468."
)

measure_duration: float = (
2.0 / measure_freq if measure_duration is None else measure_duration
)
Expand Down
5 changes: 4 additions & 1 deletion torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,10 @@ def __init__(self,
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`
waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
Tensor of shape `(channels, time)` is treated as a multi-channel recording
of the same event and the resulting output will be trimmed to the earliest
voice activity in any channel.
"""
return F.vad(
waveform=waveform,
Expand Down

0 comments on commit 08f2bde

Please sign in to comment.