Skip to content

Commit

Permalink
Stop using whitenoise.wav, mp3 and torchaudio.load in sox effect test
Browse files Browse the repository at this point in the history
Part of #764

 - Replace `whitenoise.wav` with on-the-fly data generation
 - Replace `torchaudio.load` with `common_utils.load_wav`
 - Replace `steam-train-whistle-daniel_simon.mp3` with `.wav`
  • Loading branch information
engineerchuan authored Jul 14, 2020
1 parent 4b3e905 commit d11ad6b
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 120 deletions.
22 changes: 19 additions & 3 deletions test/common_utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ def get_whitenoise(
seed: int = 0,
dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu",
channels_first=True,
scale_factor: float = 1,
):
"""Generate pseudo audio data with whitenoise
Args:
sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds.
Expand All @@ -32,19 +33,34 @@ def get_whitenoise(
Note that this function does not modify global random generator state.
dtype: Torch dtype
device: device
channels_first: whether first dimension is n_channels
scale_factor: scale the Tensor before clamping and quantization
Returns:
Tensor: shape of (n_channels, sample_rate * duration)
"""
if isinstance(dtype, str):
dtype = getattr(torch, dtype)
shape = [n_channels, sample_rate * duration]
if dtype not in [torch.float32, torch.int32, torch.int16, torch.uint8]:
raise NotImplementedError(f'dtype {dtype} is not supported.')
# According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices,
# so we only folk on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]):
torch.random.manual_seed(seed)
tensor = torch.randn(shape, dtype=dtype, device='cpu')
tensor = torch.randn([sample_rate * duration], dtype=torch.float32, device='cpu')
tensor /= 2.0
tensor *= scale_factor
tensor.clamp_(-1.0, 1.0)
if dtype == torch.int32:
tensor *= (tensor > 0) * 2147483647 + (tensor < 0) * 2147483648
if dtype == torch.int16:
tensor *= (tensor > 0) * 32767 + (tensor < 0) * 32768
if dtype == torch.uint8:
tensor *= (tensor > 0) * 127 + (tensor < 0) * 128
tensor += 128
tensor = tensor.to(dtype)
tensor = tensor.repeat([n_channels, 1])
if not channels_first:
tensor = tensor.t()
return tensor.to(device=device)


Expand Down
Loading

0 comments on commit d11ad6b

Please sign in to comment.