From f1911b122724ee2a4538116f4410a2c5689363a5 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 18 Jul 2017 10:14:37 +0800 Subject: [PATCH 1/3] add mfcc feature for DS2 --- deep_speech_2/README.md | 6 ++- deep_speech_2/compute_mean_std.py | 8 +++- .../data_utils/featurizer/audio_featurizer.py | 48 +++++++++++++++++-- .../featurizer/speech_featurizer.py | 15 +++--- deep_speech_2/data_utils/normalizer.py | 2 +- deep_speech_2/requirements.txt | 1 + deep_speech_2/train.py | 7 +++ 7 files changed, 74 insertions(+), 13 deletions(-) diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md index 3b20bf4944..a92b671cb5 100644 --- a/deep_speech_2/README.md +++ b/deep_speech_2/README.md @@ -38,7 +38,11 @@ python datasets/librispeech/librispeech.py --help python compute_mean_std.py ``` -`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. +`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, currently the mfcc feature is also supported. To train and infer based on mfcc feature, you can regenerate this file by + +``` +python compute_mean_std.py --specgram_type mfcc +``` More help for arguments: diff --git a/deep_speech_2/compute_mean_std.py b/deep_speech_2/compute_mean_std.py index 9c301c93f6..0cc84e7302 100644 --- a/deep_speech_2/compute_mean_std.py +++ b/deep_speech_2/compute_mean_std.py @@ -10,6 +10,12 @@ parser = argparse.ArgumentParser( description='Computing mean and stddev for feature normalizer.') +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--manifest_path", default='datasets/manifest.train', @@ -39,7 +45,7 @@ def main(): augmentation_pipeline = AugmentationPipeline(args.augmentation_config) - audio_featurizer = AudioFeaturizer() + audio_featurizer = AudioFeaturizer(specgram_type=args.specgram_type) def augment_and_featurize(audio_segment): augmentation_pipeline.transform_audio(audio_segment) diff --git a/deep_speech_2/data_utils/featurizer/audio_featurizer.py b/deep_speech_2/data_utils/featurizer/audio_featurizer.py index 4b4d02c60f..271e535b6a 100644 --- a/deep_speech_2/data_utils/featurizer/audio_featurizer.py +++ b/deep_speech_2/data_utils/featurizer/audio_featurizer.py @@ -6,13 +6,15 @@ import numpy as np from data_utils import utils from data_utils.audio import AudioSegment +from python_speech_features import mfcc +from python_speech_features import delta class AudioFeaturizer(object): """Audio featurizer, for extracting features from audio contents of AudioSegment or SpeechSegment. - Currently, it only supports feature type of linear spectrogram. + Currently, it supports feature types of linear spectrogram and mfcc. :param specgram_type: Specgram feature type. Options: 'linear'. :type specgram_type: str @@ -20,9 +22,10 @@ class AudioFeaturizer(object): :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_freq: Used when specgram_type is 'linear', only FFT bins + :param max_freq: When specgram_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned. + returned; when specgram_type is 'mfcc', max_feq is the + highest band edge of mel filters. :types max_freq: None|float :param target_sample_rate: Audio are resampled (if upsampling or downsampling is allowed) to this before @@ -91,6 +94,9 @@ def _compute_specgram(self, samples, sample_rate): return self._compute_linear_specgram( samples, sample_rate, self._stride_ms, self._window_ms, self._max_freq) + elif self._specgram_type == 'mfcc': + return self._compute_mfcc(samples, sample_rate, self._stride_ms, + self._window_ms, self._max_freq) else: raise ValueError("Unknown specgram_type %s. " "Supported values: linear." % self._specgram_type) @@ -142,3 +148,39 @@ def _specgram_real(self, samples, window_size, stride_size, sample_rate): # prepare fft frequency list freqs = float(sample_rate) / window_size * np.arange(fft.shape[0]) return fft, freqs + + def _compute_mfcc(self, + samples, + sample_rate, + stride_ms=10.0, + window_ms=20.0, + max_freq=None): + """Compute mfcc from samples.""" + if max_freq is None: + max_freq = sample_rate / 2 + if max_freq > sample_rate / 2: + raise ValueError("max_freq must be greater than half of " + "sample rate.") + if stride_ms > window_ms: + raise ValueError("Stride size must not be greater than " + "window size.") + # compute 13 cepstral coefficients, and the first one is replaced + # by log(frame energy) + mfcc_feat = mfcc( + signal=samples, + samplerate=sample_rate, + winlen=0.001 * window_ms, + winstep=0.001 * stride_ms, + highfreq=max_freq) + # Deltas + d_mfcc_feat = delta(mfcc_feat, 2) + # Deltas-Deltas + dd_mfcc_feat = delta(d_mfcc_feat, 2) + # concat above three features + concat_mfcc_feat = [ + np.concatenate((mfcc_feat[i], d_mfcc_feat[i], dd_mfcc_feat[i])) + for i in xrange(len(mfcc_feat)) + ] + # transpose to be consistent with the linear specgram situation + concat_mfcc_feat = np.transpose(concat_mfcc_feat) + return concat_mfcc_feat diff --git a/deep_speech_2/data_utils/featurizer/speech_featurizer.py b/deep_speech_2/data_utils/featurizer/speech_featurizer.py index 26283892e8..a947588db4 100644 --- a/deep_speech_2/data_utils/featurizer/speech_featurizer.py +++ b/deep_speech_2/data_utils/featurizer/speech_featurizer.py @@ -11,23 +11,24 @@ class SpeechFeaturizer(object): """Speech featurizer, for extracting features from both audio and transcript contents of SpeechSegment. - Currently, for audio parts, it only supports feature type of linear - spectrogram; for transcript parts, it only supports char-level tokenizing - and conversion into a list of token indices. Note that the token indexing - order follows the given vocabulary file. + Currently, for audio parts, it supports feature types of linear + spectrogram and mfcc; for transcript parts, it only supports char-level + tokenizing and conversion into a list of token indices. Note that the + token indexing order follows the given vocabulary file. :param vocab_filepath: Filepath to load vocabulary for token indices conversion. :type specgram_type: basestring - :param specgram_type: Specgram feature type. Options: 'linear'. + :param specgram_type: Specgram feature type. Options: 'linear', 'mfcc'. :type specgram_type: str :param stride_ms: Striding size (in milliseconds) for generating frames. :type stride_ms: float :param window_ms: Window size (in milliseconds) for generating frames. :type window_ms: float - :param max_freq: Used when specgram_type is 'linear', only FFT bins + :param max_freq: When specgram_type is 'linear', only FFT bins corresponding to frequencies between [0, max_freq] are - returned. + returned; when specgram_type is 'mfcc', max_freq is the + highest band edge of mel filters. :types max_freq: None|float :param target_sample_rate: Speech are resampled (if upsampling or downsampling is allowed) to this before diff --git a/deep_speech_2/data_utils/normalizer.py b/deep_speech_2/data_utils/normalizer.py index c123d25d20..1f4aae9a09 100644 --- a/deep_speech_2/data_utils/normalizer.py +++ b/deep_speech_2/data_utils/normalizer.py @@ -16,7 +16,7 @@ class FeatureNormalizer(object): if mean_std_filepath is provided (not None), the normalizer will directly initilize from the file. Otherwise, both manifest_path and featurize_func should be given for on-the-fly mean and stddev computing. - + :param mean_std_filepath: File containing the pre-computed mean and stddev. :type mean_std_filepath: None|basestring :param manifest_path: Manifest of instances for computing mean and stddev. diff --git a/deep_speech_2/requirements.txt b/deep_speech_2/requirements.txt index 2ae7d0895a..721fa28110 100755 --- a/deep_speech_2/requirements.txt +++ b/deep_speech_2/requirements.txt @@ -2,3 +2,4 @@ wget==3.2 scipy==0.13.1 resampy==0.1.5 https://github.com/kpu/kenlm/archive/master.zip +python_speech_features diff --git a/deep_speech_2/train.py b/deep_speech_2/train.py index 3a2d0cad9e..6481074c6e 100644 --- a/deep_speech_2/train.py +++ b/deep_speech_2/train.py @@ -53,6 +53,12 @@ default=True, type=distutils.util.strtobool, help="Use sortagrad or not. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--max_duration", default=27.0, @@ -130,6 +136,7 @@ def data_generator(): augmentation_config=args.augmentation_config, max_duration=args.max_duration, min_duration=args.min_duration, + specgram_type=args.specgram_type, num_threads=args.num_threads_data) train_generator = data_generator() From fa50fac4f8792d08f7d367bbcccab207f626d654 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Wed, 19 Jul 2017 22:40:01 +0800 Subject: [PATCH 2/3] update several scripts to support mfcc --- deep_speech_2/README.md | 2 ++ deep_speech_2/evaluate.py | 7 +++++++ deep_speech_2/infer.py | 7 +++++++ deep_speech_2/tune.py | 7 +++++++ 4 files changed, 23 insertions(+) diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md index a92b671cb5..24f0b3c3fd 100644 --- a/deep_speech_2/README.md +++ b/deep_speech_2/README.md @@ -44,6 +44,8 @@ python compute_mean_std.py python compute_mean_std.py --specgram_type mfcc ``` +and specify the ```specgram_type``` to ```mfcc``` in each step, including training, inference etc. + More help for arguments: ``` diff --git a/deep_speech_2/evaluate.py b/deep_speech_2/evaluate.py index 00516dcbf0..19eabf4e5a 100644 --- a/deep_speech_2/evaluate.py +++ b/deep_speech_2/evaluate.py @@ -86,6 +86,12 @@ default=500, type=int, help="Width for beam search decoding. (default: %(default)d)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--decode_manifest_path", default='datasets/manifest.test', @@ -111,6 +117,7 @@ def evaluate(): vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', + specgram_type=args.specgram_type, num_threads=args.num_threads_data) # create network config diff --git a/deep_speech_2/infer.py b/deep_speech_2/infer.py index bb81feac16..8175263027 100644 --- a/deep_speech_2/infer.py +++ b/deep_speech_2/infer.py @@ -51,6 +51,12 @@ default=multiprocessing.cpu_count(), type=int, help="Number of cpu processes for beam search. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -118,6 +124,7 @@ def infer(): vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', + specgram_type=args.specgram_type, num_threads=args.num_threads_data) # create network config diff --git a/deep_speech_2/tune.py b/deep_speech_2/tune.py index 19a2d55951..2fcca48628 100644 --- a/deep_speech_2/tune.py +++ b/deep_speech_2/tune.py @@ -50,6 +50,12 @@ default=multiprocessing.cpu_count(), type=int, help="Number of cpu processes for beam search. (default: %(default)s)") +parser.add_argument( + "--specgram_type", + default='linear', + type=str, + help="Feature type of audio data: 'linear' (power spectrum)" + " or 'mfcc'. (default: %(default)s)") parser.add_argument( "--mean_std_filepath", default='mean_std.npz', @@ -133,6 +139,7 @@ def tune(): vocab_filepath=args.vocab_filepath, mean_std_filepath=args.mean_std_filepath, augmentation_config='{}', + specgram_type=args.specgram_type, num_threads=args.num_threads_data) # create network config From 653d59fa29aecba83a07368adfb1e3b66a601ec8 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 20 Jul 2017 11:47:46 +0800 Subject: [PATCH 3/3] follow comments to modify README.md --- deep_speech_2/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deep_speech_2/README.md b/deep_speech_2/README.md index 24f0b3c3fd..3010c0e536 100644 --- a/deep_speech_2/README.md +++ b/deep_speech_2/README.md @@ -38,13 +38,13 @@ python datasets/librispeech/librispeech.py --help python compute_mean_std.py ``` -`python compute_mean_std.py` computes mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, currently the mfcc feature is also supported. To train and infer based on mfcc feature, you can regenerate this file by +It will compute mean and stdandard deviation for audio features, and save them to a file with a default name `./mean_std.npz`. This file will be used in both training and inferencing. The default feature of audio data is power spectrum, and the mfcc feature is also supported. To train and infer based on mfcc feature, please generate this file by ``` python compute_mean_std.py --specgram_type mfcc ``` -and specify the ```specgram_type``` to ```mfcc``` in each step, including training, inference etc. +and specify ```--specgram_type mfcc``` when running train.py, infer.py, evaluator.py or tune.py. More help for arguments: