Skip to content

Commit

Permalink
Merge pull request #1791 from mozilla/fix-alphabet-handling
Browse files Browse the repository at this point in the history
Fix handling of Alphabet around evaluate.py
  • Loading branch information
reuben authored Dec 26, 2018
2 parents e549f89 + dd6938f commit ce551f5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 10 deletions.
3 changes: 1 addition & 2 deletions DeepSpeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,7 @@ def test():
hdf5_cache_path=FLAGS.test_cached_features_path)

graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)

evaluate.evaluate(test_data, graph, Config.alphabet)
evaluate.evaluate(test_data, graph)


def create_inference_graph(batch_size=1, n_steps=16, tflite=False):
Expand Down
13 changes: 5 additions & 8 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def calculate_report(labels, decodings, distances, losses):
return samples_wer, samples


def evaluate(test_data, inference_graph, alphabet):
def evaluate(test_data, inference_graph):
scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta,
FLAGS.lm_binary_path, FLAGS.lm_trie_path,
Config.alphabet)
Expand Down Expand Up @@ -175,10 +175,10 @@ def create_windows(features):
# Second pass, decode logits and compute WER and edit distance metrics
for logits, batch in bar(zip(logitses, split_data(test_data, FLAGS.test_batch_size))):
seq_lengths = batch['features_len'].values.astype(np.int32)
decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, alphabet, FLAGS.beam_width,
decoded = ctc_beam_search_decoder_batch(logits, seq_lengths, Config.alphabet, FLAGS.beam_width,
num_processes=num_processes, scorer=scorer)

ground_truths.extend(alphabet.decode(l) for l in batch['transcript'])
ground_truths.extend(Config.alphabet.decode(l) for l in batch['transcript'])
predictions.extend(d[0][1] for d in decoded)

distances = [levenshtein(a, b) for a, b in zip(ground_truths, predictions)]
Expand Down Expand Up @@ -211,14 +211,11 @@ def main(_):
'the --test_files flag.')
exit(1)

global alphabet
alphabet = Alphabet(FLAGS.alphabet_config_path)

# sort examples by length, improves packing of batches and timesteps
test_data = preprocess(
FLAGS.test_files.split(','),
FLAGS.test_batch_size,
alphabet=alphabet,
alphabet=Config.alphabet,
numcep=Config.n_input,
numcontext=Config.n_context,
hdf5_cache_path=FLAGS.hdf5_test_set).sort_values(
Expand All @@ -228,7 +225,7 @@ def main(_):
from DeepSpeech import create_inference_graph
graph = create_inference_graph(batch_size=FLAGS.test_batch_size, n_steps=-1)

samples = evaluate(test_data, graph, alphabet)
samples = evaluate(test_data, graph)

if FLAGS.test_output_file:
# Save decoded tuples as JSON, converting NumPy floats to Python floats
Expand Down

0 comments on commit ce551f5

Please sign in to comment.