
Commit d8bf04e

audio preprocessing tutorial.
1 parent d13664e

File tree: 2 files changed, +273 −4 lines

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
"""
Torchaudio Tutorial
===================

PyTorch is an open source deep learning platform that provides a
seamless path from research prototyping to production deployment with
GPU support.

Significant effort in solving machine learning problems goes into data
preparation. Torchaudio leverages PyTorch’s GPU support, and provides
many tools to make data loading easy and more readable. In this
tutorial, we will see how to load and preprocess data from a simple
dataset.

For this tutorial, please make sure the ``matplotlib`` package is
installed for easier visualization.

"""

import torch
import torchaudio
import matplotlib.pyplot as plt

######################################################################
# Opening a dataset
# -----------------
#


######################################################################
# Torchaudio supports loading sound files in the wav and mp3 formats.
#

filename = "assets/steam-train-whistle-daniel_simon-converted-from-mp3.wav"
waveform, frequency = torchaudio.load(filename)

print("Shape of waveform: {}".format(waveform.size()))
print("Frequency of waveform: {}".format(frequency))

plt.figure()
plt.plot(waveform.transpose(0,1).numpy())

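######################################################################
# Saving works in the other direction. As a minimal sketch (assuming
# ``torchaudio.save`` takes a file path, the waveform tensor, and the
# sample rate; the output path below is made up for illustration):
#

# Round-trip the tensor back to disk (commented out to avoid side effects).
# torchaudio.save("output.wav", waveform, frequency)
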
######################################################################
# Transformations
# ---------------
#
# Torchaudio supports a growing list of
# `transformations <https://pytorch.org/audio/transforms.html>`_.
#
# - **Scale**: Scale audio tensor from a 16-bit integer (represented as a
#   FloatTensor) to a floating point number between -1.0 and 1.0. Note
#   the 16-bit number is called the “bit depth” or “precision”, not to be
#   confused with “bit rate”.
# - **PadTrim**: Pad or trim a 2d-Tensor to a given length.
# - **Downmix**: Downmix any stereo signals to mono.
# - **LC2CL**: Permute a 2d tensor from samples (n x c) to (c x n).
# - **Resample**: Resample the signal to a different frequency.
# - **Spectrogram**: Create a spectrogram from a raw audio signal.
# - **MelScale**: Turn a normal STFT into a mel-frequency STFT, using a
#   conversion matrix with triangular filter banks.
# - **SpectrogramToDB**: Turn a spectrogram from the power/amplitude
#   scale to the decibel scale.
# - **MFCC**: Create the Mel-frequency cepstral coefficients from an
#   audio signal.
# - **MelSpectrogram**: Create Mel spectrograms from a raw audio signal
#   using the STFT function in PyTorch.
# - **BLC2CBL**: Permute a 3d tensor from Bands x Sample length x
#   Channels to Channels x Bands x Sample length.
# - **MuLawEncoding**: Encode signal based on mu-law companding.
# - **MuLawExpanding**: Decode a mu-law encoded signal.
#
# Since all transforms are nn.Modules or jit.ScriptModules, they can be
# used as part of a neural network at any point, as sketched below.
#

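######################################################################
# As a minimal sketch of that idea (an illustration, not part of the
# original pipeline): ``MelSpectrogram`` and ``SpectrogramToDB`` chain
# with ``torch.nn.Sequential`` just like layers, assuming
# ``SpectrogramToDB`` accepts the power spectrogram that
# ``MelSpectrogram`` produces.
#

import torch.nn as nn

# Mel spectrogram followed by conversion to the decibel scale, as one module.
mel_db = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(),
    torchaudio.transforms.SpectrogramToDB(),
)
print("Shape of mel-dB features: {}".format(mel_db(waveform).size()))
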
######################################################################
# To start, we can look at the spectrogram on a log scale.
#

specgram = torchaudio.transforms.Spectrogram()(waveform)

print("Shape of spectrogram: {}".format(specgram.size()))

plt.figure()
plt.imshow(specgram.log2().transpose(1,2)[0,:,:].numpy(), cmap='gray')

######################################################################
# Or we can look at the Mel spectrogram on a log scale.
#

specgram = torchaudio.transforms.MelSpectrogram()(waveform)

print("Shape of spectrogram: {}".format(specgram.size()))

plt.figure()
plt.imshow(specgram.log2().transpose(1,2)[0,:,:].detach().numpy(), cmap='gray')

######################################################################
# We can resample the signal, one channel at a time.
#

new_frequency = frequency / 10

# Since Resample applies to a single channel, we resample the first channel here
channel = 0
transformed = torchaudio.transforms.Resample(frequency, new_frequency)(waveform[channel,:].view(1,-1))

print("Shape of transformed waveform: {}".format(transformed.size()))

plt.figure()
plt.plot(transformed[0,:].numpy())

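######################################################################
# A quick sanity check (illustrative only): resampling scales the number
# of samples by the ratio of the two rates, so the duration in seconds
# is preserved.
#

print("Duration before: {:.2f} s, after: {:.2f} s".format(
    waveform.size(1) / frequency, transformed.size(1) / new_frequency))
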
######################################################################
# Or we can first convert the stereo to mono, and resample, using
# composition. The LC2CL transforms switch the tensor layout before and
# after DownmixMono, which operates on a samples x channels tensor.
#

transformed = torchaudio.transforms.Compose([
    torchaudio.transforms.LC2CL(),
    torchaudio.transforms.DownmixMono(),
    torchaudio.transforms.LC2CL(),
    torchaudio.transforms.Resample(frequency, new_frequency)
])(waveform)

print("Shape of transformed waveform: {}".format(transformed.size()))

plt.figure()
plt.plot(transformed[0,:].numpy())

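######################################################################
# The same mono signal can be built by hand with plain tensor ops (a
# sketch, assuming DownmixMono averages the two channels):
#

# Average the channels, keeping a channel dimension of size 1.
mono = waveform.mean(0, keepdim=True)
transformed_manual = torchaudio.transforms.Resample(frequency, new_frequency)(mono)
print("Shape of manually downmixed waveform: {}".format(transformed_manual.size()))
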
######################################################################
# As another example of transformations, we can encode the signal based
# on Mu-Law companding. But to do so, we need the signal to be between
# -1 and 1. Since the tensor is just a regular PyTorch tensor, we can
# apply standard operators on it.
#

# Let's check if the tensor is in the interval [-1,1]
print("Min of waveform: {}\nMax of waveform: {}\nMean of waveform: {}".format(waveform.min(), waveform.max(), waveform.mean()))

######################################################################
# Since the waveform is already between -1 and 1, we do not need to
# normalize it. If it were not, we could use a helper like the
# following.
#

def normalize(tensor):
    # Subtract the mean, and scale to the interval [-1,1]
    tensor_minusmean = tensor - tensor.mean()
    return tensor_minusmean / tensor_minusmean.abs().max()

# Let's normalize to the full interval [-1,1]
# waveform = normalize(waveform)

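######################################################################
# A quick check (illustrative only) that the helper does map the extreme
# values into [-1, 1]:
#

normalized = normalize(waveform)
print("Min of normalized: {}\nMax of normalized: {}".format(normalized.min(), normalized.max()))
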
######################################################################
# Let's apply the encoding to the waveform.
#

transformed = torchaudio.transforms.MuLawEncoding()(waveform)

print("Shape of transformed waveform: {}".format(transformed.size()))

plt.figure()
plt.plot(transformed[0,:].numpy())

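######################################################################
# The encoded signal is quantized; assuming the default of 256
# quantization channels, the values should lie in [0, 255] (a quick
# illustrative check):
#

print("Min of encoded: {}\nMax of encoded: {}".format(transformed.min(), transformed.max()))
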
######################################################################
# And now decode.
#

reconstructed = torchaudio.transforms.MuLawExpanding()(transformed)

print("Shape of recovered waveform: {}".format(reconstructed.size()))

plt.figure()
plt.plot(reconstructed[0,:].numpy())

######################################################################
# We can finally compare the original waveform with its reconstructed
# version. We take the median rather than the mean because the relative
# difference blows up wherever the original signal is close to zero.
#

# Compute median relative difference
err = ((waveform - reconstructed).abs() / waveform.abs()).median()

print("Median relative difference between original and MuLaw reconstructed signals: {:.2%}".format(err))

######################################################################
# Migrating to Torchaudio from Kaldi
# ----------------------------------
#
# Users may be familiar with
# `Kaldi <http://github.com/kaldi-asr/kaldi>`_, a toolkit for speech
# recognition. Torchaudio offers compatibility with it in
# ``torchaudio.kaldi_io``. It can read from Kaldi scp and ark files or
# streams with:
#
# - read_vec_int_ark
# - read_vec_flt_scp
# - read_vec_flt_arkfile/stream
# - read_mat_scp
# - read_mat_ark
#
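######################################################################
# As a sketch of how these readers are used (the file name ``feats.ark``
# is made up for illustration, and we assume ``read_mat_ark`` yields
# key/matrix pairs as in the kaldi_io package it wraps):
#

# for key, mat in torchaudio.kaldi_io.read_mat_ark("feats.ark"):
#     print(key, mat.size())
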
######################################################################
# Torchaudio provides Kaldi-compatible transforms for ``spectrogram`` and
# ``fbank`` with the benefit of GPU support, see
# `here <compliance.kaldi.html>`__ for more information.
#

n_fft = 400.0
frame_length = n_fft / frequency * 1000.0  # window length in milliseconds
frame_shift = frame_length / 2.0

params = {
    "channel": 0,
    "dither": 0.0,
    "window_type": "hanning",
    "frame_length": frame_length,
    "frame_shift": frame_shift,
    "remove_dc_offset": False,
    "round_to_power_of_two": False,
    "sample_frequency": frequency,
}

specgram = torchaudio.compliance.kaldi.spectrogram(waveform, **params)

print("Shape of spectrogram: {}".format(specgram.size()))

plt.figure()
plt.imshow(specgram.transpose(0,1).numpy(), cmap='gray')

######################################################################
# We also support computing the filterbank features from a raw audio
# signal, matching Kaldi’s implementation.
#

fbank = torchaudio.compliance.kaldi.fbank(waveform, **params)

print("Shape of fbank: {}".format(fbank.size()))

plt.figure()
plt.imshow(fbank.transpose(0,1).numpy(), cmap='gray')

######################################################################
# Conclusion
# ----------
#
# We used an example sound signal to illustrate how to open an audio file
# using Torchaudio, and how to pre-process and transform an audio
# signal. Given that Torchaudio is built on PyTorch, these techniques can
# be used as building blocks for more advanced audio applications, such
# as speech recognition, while leveraging GPUs.
#

index.rst

Lines changed: 10 additions & 4 deletions
@@ -109,11 +109,15 @@ Image
     <div style='clear:both'></div>


-.. Audio
-.. ----------------------
+Audio
+----------------------
+
+.. customgalleryitem::
+   :figure: /_static/img/cat.jpg
+   :tooltip: Preprocessing with Torchaudio Tutorial
+   :description: :doc:`beginner/audio_preprocessing_tutorial`

-.. Uncomment below when adding content
-.. .. raw:: html
+.. raw:: html

     <div style='clear:both'></div>

@@ -285,6 +289,8 @@ PyTorch in Other Languages
    :hidden:
    :caption: Audio

+   beginner/audio_preprocessing_tutorial
+
 .. toctree::
    :maxdepth: 2
    :includehidden: