
Commit

Merge pull request #1810 from mozilla/update-new-hparams
Update decoder hyperparameters
reuben authored Jan 2, 2019
2 parents ce551f5 + 7c1315b commit fc46f43
Showing 13 changed files with 46 additions and 51 deletions.
9 changes: 4 additions & 5 deletions examples/ffmpeg_vad_streaming/index.js
@@ -11,11 +11,10 @@ const util = require('util');
 const BEAM_WIDTH = 1024;
 
 // The alpha hyperparameter of the CTC decoder. Language Model weight
-const LM_WEIGHT = 1.50;
+const LM_ALPHA = 0.75;
 
-// Valid word insertion weight. This is used to lessen the word insertion penalty
-// when the inserted word is part of the vocabulary
-const VALID_WORD_COUNT_WEIGHT = 2.25;
+// The beta hyperparameter of the CTC decoder. Word insertion bonus.
+const LM_BETA = 1.85;
 
 // These constants are tied to the shape of the graph used (changing them changes
 // the geometry of the first layer), so make sure you use the same constants that
@@ -63,7 +62,7 @@ if (args['lm'] && args['trie']) {
   console.error('Loading language model from files %s %s', args['lm'], args['trie']);
   const lm_load_start = process.hrtime();
   model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'],
-                            LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
+                            LM_ALPHA, LM_BETA);
   const lm_load_end = process.hrtime(lm_load_start);
   console.error('Loaded language model in %ds.', totalTime(lm_load_end));
 }
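Note on the renamed constants: in CTC beam search with an external language model, alpha and beta typically enter the score of a candidate transcription y for audio x as

Q(y) = \log P_{\mathrm{CTC}}(y \mid x) + \alpha \, \log P_{\mathrm{LM}}(y) + \beta \, \mathrm{wordcount}(y)

so LM_ALPHA weights the language model and LM_BETA acts as a word insertion bonus. This is a sketch of the common formulation; the exact scoring used by DeepSpeech's native decoder may differ in detail.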
13 changes: 6 additions & 7 deletions examples/mic_vad_streaming/README.md
@@ -26,7 +26,7 @@ brew install portaudio
 usage: mic_vad_streaming.py [-h] [-v VAD_AGGRESSIVENESS] [--nospinner]
                             [-w SAVEWAV] -m MODEL [-a ALPHABET] [-l LM]
                             [-t TRIE] [-nf N_FEATURES] [-nc N_CONTEXT]
-                            [-lw LM_WEIGHT] [-vwcw VALID_WORD_COUNT_WEIGHT]
+                            [-la LM_ALPHA] [-lb LM_BETA]
                             [-bw BEAM_WIDTH]
 
 Stream from microphone to DeepSpeech using VAD
@@ -56,13 +56,12 @@ optional arguments:
   -nc N_CONTEXT, --n_context N_CONTEXT
                         Size of the context window used for producing
                         timesteps in the input vector. Default: 9
-  -lw LM_WEIGHT, --lm_weight LM_WEIGHT
+  -la LM_ALPHA, --lm_alpha LM_ALPHA
                         The alpha hyperparameter of the CTC decoder. Language
-                        Model weight. Default: 1.5
-  -vwcw VALID_WORD_COUNT_WEIGHT, --valid_word_count_weight VALID_WORD_COUNT_WEIGHT
-                        Valid word insertion weight. This is used to lessen
-                        the word insertion penalty when the inserted word is
-                        part of the vocabulary. Default: 2.1
+                        Model weight. Default: 0.75
+  -lb LM_BETA, --lm_beta LM_BETA
+                        The beta hyperparameter of the CTC decoder. Word insertion
+                        bonus. Default: 1.85
   -bw BEAM_WIDTH, --beam_width BEAM_WIDTH
                         Beam width used in the CTC decoder when building
                         candidate transcriptions. Default: 500
14 changes: 7 additions & 7 deletions examples/mic_vad_streaming/mic_vad_streaming.py
@@ -118,7 +118,7 @@ def main(ARGS):
     if ARGS.lm and ARGS.trie:
         logging.info("ARGS.lm: %s", ARGS.lm)
         logging.info("ARGS.trie: %s", ARGS.trie)
-        model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_weight, ARGS.valid_word_count_weight)
+        model.enableDecoderWithLM(ARGS.alphabet, ARGS.lm, ARGS.trie, ARGS.lm_alpha, ARGS.lm_beta)
 
     # Start audio with VAD
     vad_audio = VADAudio(aggressiveness=ARGS.vad_aggressiveness)
@@ -148,8 +148,8 @@ def main(ARGS):
 
 if __name__ == '__main__':
     BEAM_WIDTH = 500
-    LM_WEIGHT = 1.50
-    VALID_WORD_COUNT_WEIGHT = 2.10
+    LM_ALPHA = 0.75
+    LM_BETA = 1.85
     N_FEATURES = 26
     N_CONTEXT = 9
 
@@ -175,10 +175,10 @@ def main(ARGS):
                         help=f"Number of MFCC features to use. Default: {N_FEATURES}")
     parser.add_argument('-nc', '--n_context', type=int, default=N_CONTEXT,
                         help=f"Size of the context window used for producing timesteps in the input vector. Default: {N_CONTEXT}")
-    parser.add_argument('-lw', '--lm_weight', type=float, default=LM_WEIGHT,
-                        help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_WEIGHT}")
-    parser.add_argument('-vwcw', '--valid_word_count_weight', type=float, default=VALID_WORD_COUNT_WEIGHT,
-                        help=f"Valid word insertion weight. This is used to lessen the word insertion penalty when the inserted word is part of the vocabulary. Default: {VALID_WORD_COUNT_WEIGHT}")
+    parser.add_argument('-la', '--lm_alpha', type=float, default=LM_ALPHA,
+                        help=f"The alpha hyperparameter of the CTC decoder. Language Model weight. Default: {LM_ALPHA}")
+    parser.add_argument('-lb', '--lm_beta', type=float, default=LM_BETA,
+                        help=f"The beta hyperparameter of the CTC decoder. Word insertion bonus. Default: {LM_BETA}")
     parser.add_argument('-bw', '--beam_width', type=int, default=BEAM_WIDTH,
                         help=f"Beam width used in the CTC decoder when building candidate transcriptions. Default: {BEAM_WIDTH}")
 
@@ -38,8 +38,8 @@ static void Main(string[] args)
             const uint N_CEP = 26;
             const uint N_CONTEXT = 9;
             const uint BEAM_WIDTH = 200;
-            const float LM_WEIGHT = 1.50f;
-            const float VALID_WORD_COUNT_WEIGHT = 2.10f;
+            const float LM_ALPHA = 0.75f;
+            const float LM_BETA = 1.85f;
 
             Stopwatch stopwatch = new Stopwatch();
 
@@ -76,7 +76,7 @@ static void Main(string[] args)
                         alphabet ?? "alphabet.txt",
                         lm ?? "lm.binary",
                         trie ?? "trie",
-                        LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
+                        LM_ALPHA, LM_BETA);
                 }
                 catch (IOException ex)
                 {
@@ -24,8 +24,8 @@ public partial class MainWindow : Window
         private const uint N_CEP = 26;
         private const uint N_CONTEXT = 9;
         private const uint BEAM_WIDTH = 500;
-        private const float LM_WEIGHT = 1.50f;
-        private const float VALID_WORD_COUNT_WEIGHT = 2.10f;
+        private const float LM_ALPHA = 0.75f;
+        private const float LM_BETA = 1.85f;
 
 
 
@@ -160,7 +160,7 @@ await Task.Run(() =>
             {
                 try
                 {
-                    if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_WEIGHT, VALID_WORD_COUNT_WEIGHT) != 0)
+                    if (_sttClient.EnableDecoderWithLM("alphabet.txt", "lm.binary", "trie", LM_ALPHA, LM_BETA) != 0)
                     {
                         MessageBox.Show("Error loading LM.");
                         Dispatcher.Invoke(() => btnEnableLM.IsEnabled = true);
6 changes: 3 additions & 3 deletions examples/vad_transcriber/wavTranscriber.py
@@ -19,16 +19,16 @@ def load_model(models, alphabet, lm, trie):
     N_FEATURES = 26
     N_CONTEXT = 9
     BEAM_WIDTH = 500
-    LM_WEIGHT = 1.50
-    VALID_WORD_COUNT_WEIGHT = 2.10
+    LM_ALPHA = 0.75
+    LM_BETA = 1.85
 
     model_load_start = timer()
     ds = Model(models, N_FEATURES, N_CONTEXT, alphabet, BEAM_WIDTH)
     model_load_end = timer() - model_load_start
     logging.debug("Loaded model in %0.3fs." % (model_load_end))
 
     lm_load_start = timer()
-    ds.enableDecoderWithLM(alphabet, lm, trie, LM_WEIGHT, VALID_WORD_COUNT_WEIGHT)
+    ds.enableDecoderWithLM(alphabet, lm, trie, LM_ALPHA, LM_BETA)
     lm_load_end = timer() - lm_load_start
     logging.debug('Loaded language model in %0.3fs.' % (lm_load_end))
 
8 changes: 4 additions & 4 deletions native_client/client.cc
@@ -24,8 +24,8 @@
 #define N_CEP 26
 #define N_CONTEXT 9
 #define BEAM_WIDTH 500
-#define LM_WEIGHT 1.50f
-#define VALID_WORD_COUNT_WEIGHT 2.10f
+#define LM_ALPHA 0.75f
+#define LM_BETA 1.85f
 
 typedef struct {
   const char* string;
@@ -253,8 +253,8 @@ main(int argc, char **argv)
                                        alphabet,
                                        lm,
                                        trie,
-                                       LM_WEIGHT,
-                                       VALID_WORD_COUNT_WEIGHT);
+                                       LM_ALPHA,
+                                       LM_BETA);
   if (status != 0) {
     fprintf(stderr, "Could not enable CTC decoder with LM.\n");
     return 1;
@@ -38,8 +38,8 @@ public class DeepSpeechActivity extends AppCompatActivity {
     final int N_CEP = 26;
     final int N_CONTEXT = 9;
     final int BEAM_WIDTH = 50;
-    final float LM_WEIGHT = 1.50f;
-    final float VALID_WORD_COUNT_WEIGHT = 2.10f;
+    final float LM_ALPHA = 0.75f;
+    final float LM_BETA = 1.85f;
 
     private char readLEChar(RandomAccessFile f) throws IOException {
         byte b1 = f.readByte();
@@ -16,8 +16,8 @@ public void destroyModel() {
         impl.DestroyModel(this._msp);
     }
 
-    public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_weight, float valid_word_count_weight) {
-        impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_weight, valid_word_count_weight);
+    public void enableDecoderWihLM(String alphabet, String lm, String trie, float lm_alpha, float lm_beta) {
+        impl.EnableDecoderWithLM(this._msp, alphabet, lm, trie, lm_alpha, lm_beta);
     }
 
     public String stt(short[] buffer, int buffer_size, int sample_rate) {
9 changes: 4 additions & 5 deletions native_client/javascript/client.js
@@ -15,11 +15,10 @@ const util = require('util');
 const BEAM_WIDTH = 500;
 
 // The alpha hyperparameter of the CTC decoder. Language Model weight
-const LM_WEIGHT = 1.50;
+const LM_ALPHA = 0.75;
 
-// Valid word insertion weight. This is used to lessen the word insertion penalty
-// when the inserted word is part of the vocabulary
-const VALID_WORD_COUNT_WEIGHT = 2.10;
+// The beta hyperparameter of the CTC decoder. Word insertion bonus.
+const LM_BETA = 1.85;
 
 
 // These constants are tied to the shape of the graph used (changing them changes
@@ -102,7 +101,7 @@ audioStream.on('finish', () => {
   console.error('Loading language model from files %s %s', args['lm'], args['trie']);
   const lm_load_start = process.hrtime();
   model.enableDecoderWithLM(args['alphabet'], args['lm'], args['trie'],
-                            LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
+                            LM_ALPHA, LM_BETA);
   const lm_load_end = process.hrtime(lm_load_start);
   console.error('Loaded language model in %ds.', totalTime(lm_load_end));
 }
10 changes: 4 additions & 6 deletions native_client/python/client.py
@@ -23,11 +23,10 @@
 BEAM_WIDTH = 500
 
 # The alpha hyperparameter of the CTC decoder. Language Model weight
-LM_WEIGHT = 1.50
+LM_ALPHA = 0.75
 
-# Valid word insertion weight. This is used to lessen the word insertion penalty
-# when the inserted word is part of the vocabulary
-VALID_WORD_COUNT_WEIGHT = 2.10
+# The beta hyperparameter of the CTC decoder. Word insertion bonus.
+LM_BETA = 1.85
 
 
 # These constants are tied to the shape of the graph used (changing them changes
@@ -85,8 +84,7 @@ def main():
     if args.lm and args.trie:
         print('Loading language model from files {} {}'.format(args.lm, args.trie), file=sys.stderr)
         lm_load_start = timer()
-        ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_WEIGHT,
-                               VALID_WORD_COUNT_WEIGHT)
+        ds.enableDecoderWithLM(args.alphabet, args.lm, args.trie, LM_ALPHA, LM_BETA)
         lm_load_end = timer() - lm_load_start
         print('Loaded language model in {:.3}s.'.format(lm_load_end), file=sys.stderr)
 
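For downstream users of the Python package, the change amounts to passing the new values into enableDecoderWithLM. A minimal usage sketch, assuming the deepspeech Python package of this era; the model, alphabet, LM, and trie paths below are placeholders:

# Minimal sketch: enabling the LM decoder with the updated hyperparameters.
# Assumes the deepspeech Python package; all file paths are placeholders.
from deepspeech import Model

BEAM_WIDTH = 500
LM_ALPHA = 0.75   # language model weight (previously LM_WEIGHT = 1.50)
LM_BETA = 1.85    # word insertion bonus (previously VALID_WORD_COUNT_WEIGHT = 2.10)
N_FEATURES = 26
N_CONTEXT = 9

ds = Model('output_graph.pb', N_FEATURES, N_CONTEXT, 'alphabet.txt', BEAM_WIDTH)
ds.enableDecoderWithLM('alphabet.txt', 'lm.binary', 'trie', LM_ALPHA, LM_BETA)
# Transcription would then follow via ds.stt(...) on 16-bit, 16 kHz mono audio.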
4 changes: 2 additions & 2 deletions tc-tests-utils.sh
@@ -151,12 +151,12 @@ assert_correct_multi_ldc93s1()
 
 assert_correct_ldc93s1_prodmodel()
 {
-  assert_correct_inference "$1" "she had a due in greasy wash water year"
+  assert_correct_inference "$1" "she had a due and greasy wash water year"
 }
 
 assert_correct_ldc93s1_prodmodel_stereo_44k()
 {
-  assert_correct_inference "$1" "she had a due in greasy wash water year"
+  assert_correct_inference "$1" "she had a due and greasy wash water year"
 }
 
 assert_correct_warning_upsampling()
4 changes: 2 additions & 2 deletions util/flags.py
@@ -131,8 +131,8 @@ def create_flags():
     tf.app.flags.DEFINE_string ('lm_binary_path', 'data/lm/lm.binary', 'path to the language model binary file created with KenLM')
     tf.app.flags.DEFINE_string ('lm_trie_path', 'data/lm/trie', 'path to the language model trie file created with native_client/generate_trie')
     tf.app.flags.DEFINE_integer ('beam_width', 1024, 'beam width used in the CTC decoder when building candidate transcriptions')
-    tf.app.flags.DEFINE_float ('lm_alpha', 1.50, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
-    tf.app.flags.DEFINE_float ('lm_beta', 2.10, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
+    tf.app.flags.DEFINE_float ('lm_alpha', 0.75, 'the alpha hyperparameter of the CTC decoder. Language Model weight.')
+    tf.app.flags.DEFINE_float ('lm_beta', 1.85, 'the beta hyperparameter of the CTC decoder. Word insertion weight.')
 
     # Inference mode
 
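Defaults like 0.75 / 1.85 are normally chosen by sweeping alpha and beta against a held-out set and keeping the pair with the lowest word error rate. A rough sketch of such a sweep using the Python client API shown above; word_error_rate and dev_set are placeholders, not part of DeepSpeech, and the stt(audio, sample_rate) call reflects the Python binding of this era (an assumption):

# Rough sketch of an alpha/beta sweep with the deepspeech Python API shown above.
# word_error_rate() and dev_set are placeholders; they are not part of DeepSpeech.
import itertools
from deepspeech import Model

def sweep_decoder_hparams(model_path, alphabet, lm, trie, dev_set, word_error_rate):
    # dev_set: list of (audio, reference_transcript) pairs, audio as 16-bit 16 kHz samples
    best = None
    for alpha, beta in itertools.product([0.5, 0.75, 1.0, 1.5], [1.0, 1.85, 2.1]):
        ds = Model(model_path, 26, 9, alphabet, 500)
        ds.enableDecoderWithLM(alphabet, lm, trie, alpha, beta)
        wer = sum(word_error_rate(ds.stt(audio, 16000), ref)
                  for audio, ref in dev_set) / len(dev_set)
        if best is None or wer < best[0]:
            best = (wer, alpha, beta)
    return best  # (best_wer, best_alpha, best_beta)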
