From c2ba31143102069b6d16c06170425dc276341993 Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Tue, 3 Apr 2018 01:14:45 -0700 Subject: [PATCH 1/2] add command line option to control the num_inter_threads and num_intra_threads for inference session --- nmt/inference.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nmt/inference.py b/nmt/inference.py index 6f589337a..cf7924b5d 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ -131,7 +131,10 @@ def single_worker_inference(infer_model, infer_data = load_data(inference_input_file, hparams) with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads + )) as sess: loaded_infer_model = model_helper.load_model( infer_model.model, ckpt, sess, "infer") sess.run( @@ -190,7 +193,10 @@ def multi_worker_inference(infer_model, infer_data = infer_data[start_position:end_position] with tf.Session( - graph=infer_model.graph, config=utils.get_config_proto()) as sess: + graph=infer_model.graph, config=utils.get_config_proto( + num_intra_threads=hparams.num_intra_threads, + num_inter_threads=hparams.num_inter_threads + )) as sess: loaded_infer_model = model_helper.load_model( infer_model.model, ckpt, sess, "infer") sess.run(infer_model.iterator.initializer, From ff4dbe4f21ef465dce6474fba13ce03cb6fce6ce Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Tue, 14 Aug 2018 10:26:58 -0700 Subject: [PATCH 2/2] Added sorting algorithm for the input sentences based on the length, which could double the performance for large inference batch size --- nmt/inference.py | 17 +++++++++++++++-- nmt/utils/nmt_utils.py | 15 +++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/nmt/inference.py b/nmt/inference.py index cf7924b5d..e1bd0ff62 100644 --- a/nmt/inference.py +++ b/nmt/inference.py @@ 
-129,6 +129,18 @@ def single_worker_inference(infer_model, # Read data infer_data = load_data(inference_input_file, hparams) + infer_data_feed = infer_data + + # Sort the input sentences by length when hparams.inference_indices is not defined + index_pair = {} + new_input = [] + if hparams.inference_indices is None: + input_length = [(len(line.split()), i) for i, line in enumerate(infer_data)] + sorted_input_bylens = sorted(input_length) + for ni, (_, oi) in enumerate(sorted_input_bylens): + new_input.append(infer_data[oi]) + index_pair[oi] = ni + infer_data_feed = new_input with tf.Session( graph=infer_model.graph, config=utils.get_config_proto( @@ -140,7 +152,7 @@ def single_worker_inference(infer_model, sess.run( infer_model.iterator.initializer, feed_dict={ - infer_model.src_placeholder: infer_data, + infer_model.src_placeholder: infer_data_feed, infer_model.batch_size_placeholder: hparams.infer_batch_size }) # Decode @@ -165,7 +177,8 @@ def single_worker_inference(infer_model, subword_option=hparams.subword_option, beam_width=hparams.beam_width, tgt_eos=hparams.eos, - num_translations_per_input=hparams.num_translations_per_input) + num_translations_per_input=hparams.num_translations_per_input, + index_pair=index_pair) def multi_worker_inference(infer_model, diff --git a/nmt/utils/nmt_utils.py b/nmt/utils/nmt_utils.py index 72f71b5c2..ee9976947 100644 --- a/nmt/utils/nmt_utils.py +++ b/nmt/utils/nmt_utils.py @@ -37,7 +37,8 @@ def decode_and_evaluate(name, beam_width, tgt_eos, num_translations_per_input=1, - decode=True): + decode=True, + index_pair=[]): """Decode a test set and compute a score according to the evaluation task.""" # Decode if decode: @@ -51,6 +52,7 @@ def decode_and_evaluate(name, num_translations_per_input = max( min(num_translations_per_input, beam_width), 1) + translation = [] while True: try: nmt_outputs, _ = model.decode(sess) @@ -62,17 +64,22 @@ def decode_and_evaluate(name, for sent_id in range(batch_size): for beam_id in range(num_translations_per_input): - 
translation = get_translation( + translation.append(get_translation( nmt_outputs[beam_id], sent_id, tgt_eos=tgt_eos, - subword_option=subword_option) - trans_f.write((translation + b"\n").decode("utf-8")) + subword_option=subword_option)) except tf.errors.OutOfRangeError: utils.print_time( " done, num sentences %d, num translations per input %d" % (num_sentences, num_translations_per_input), start_time) break + if len(index_pair) == 0: + for sentence in translation: + trans_f.write((sentence + b"\n").decode("utf-8")) + else: + for i in sorted(index_pair): + trans_f.write((translation[index_pair[i]] + b"\n").decode("utf-8")) # Evaluation evaluation_scores = {}