Skip to content

Commit

Permalink
Added optimizer for lm_alpha + lm_beta
Browse files Browse the repository at this point in the history
  • Loading branch information
kdavis-mozilla committed Feb 25, 2020
1 parent 47272a7 commit cf5bb9b
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
59 changes: 59 additions & 0 deletions optimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import sys

import optuna
import absl.app
from ds_ctcdecoder import Scorer
import tensorflow.compat.v1 as tfv1

from DeepSpeech import create_model
from evaluate import evaluate
from util.config import Config, initialize_globals
from util.flags import create_flags, FLAGS
from util.logging import log_error
from util.evaluate_tools import wer_cer_batch


def character_based():
is_character_based = False
if FLAGS.scorer_path:
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
is_character_based = scorer.is_utf8_mode()
return is_character_based

def objective(trial):
FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)

tfv1.reset_default_graph()
samples = evaluate(FLAGS.test_files.split(','), create_model)

is_character_based = trial.study.user_attrs['is_character_based']

wer, cer = wer_cer_batch(samples)
return cer if is_character_based else wer

def main(_):
initialize_globals()

if not FLAGS.test_files:
log_error('You need to specify what files to use for evaluation via '
'the --test_files flag.')
sys.exit(1)

is_character_based = character_based()

study = optuna.create_study()
study.set_user_attr("is_character_based", is_character_based)
study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
print('Best params: lm_alpha={} and lm_beta={} with WER={}'.format(study.best_params['lm_alpha'],
study.best_params['lm_beta'],
study.best_value))


if __name__ == '__main__':
create_flags()
absl.app.run(main)
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ bs4
requests
librosa
soundfile

# Requirements for optimizer
optuna
6 changes: 6 additions & 0 deletions util/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,12 @@ def create_flags():

f.DEFINE_string('one_shot_infer', '', 'one-shot inference mode: specify a wav file and the script will load the checkpoint and perform inference on it.')

# Optimizer mode

f.DEFINE_float('lm_alpha_max', 5, 'the maximum of the alpha hyperparameter of the CTC decoder explored during hyperparameter optimization. Language Model weight.')
f.DEFINE_float('lm_beta_max', 5, 'the maximum beta hyperparameter of the CTC decoder explored during hyperparameter optimization. Word insertion weight.')
f.DEFINE_integer('n_trials', 2400, 'the number of trials to run during hyperparameter optimization.')

# Register validators for paths which require a file to be specified

f.register_validator('alphabet_config_path',
Expand Down

0 comments on commit cf5bb9b

Please sign in to comment.