diff --git a/nmt/model_helper.py b/nmt/model_helper.py
index 3b855e81f..b6bb3cba9 100644
--- a/nmt/model_helper.py
+++ b/nmt/model_helper.py
@@ -72,8 +72,8 @@ def create_train_model(
     src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
         src_vocab_file, tgt_vocab_file, hparams.share_vocab)
 
-    src_dataset = tf.contrib.data.TextLineDataset(src_file)
-    tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
+    src_dataset = tf.data.TextLineDataset(src_file)
+    tgt_dataset = tf.data.TextLineDataset(tgt_file)
     skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
 
     iterator = iterator_utils.get_iterator(
@@ -132,8 +132,8 @@ def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
         src_vocab_file, tgt_vocab_file, hparams.share_vocab)
     src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
     tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
-    src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
-    tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
+    src_dataset = tf.data.TextLineDataset(src_file_placeholder)
+    tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
     iterator = iterator_utils.get_iterator(
         src_dataset,
         tgt_dataset,
@@ -185,7 +185,7 @@ def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
     src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
     batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
 
-    src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
+    src_dataset = tf.data.Dataset.from_tensor_slices(
         src_placeholder)
     iterator = iterator_utils.get_infer_iterator(
         src_dataset,
diff --git a/nmt/utils/iterator_utils.py b/nmt/utils/iterator_utils.py
index 035cdc047..5f976c3ca 100644
--- a/nmt/utils/iterator_utils.py
+++ b/nmt/utils/iterator_utils.py
@@ -101,7 +101,7 @@ def get_iterator(src_dataset,
   tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
   tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)
 
-  src_tgt_dataset = tf.contrib.data.Dataset.zip((src_dataset, tgt_dataset))
+  src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
 
   src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
   if skip_count is not None:
@@ -111,8 +111,7 @@ def get_iterator(src_dataset,
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt: (tf.string_split([src]).values,
                         tf.string_split([tgt]).values),
-      num_threads=num_threads,
-      output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
 
   # Filter zero length input sequences.
   src_tgt_dataset = src_tgt_dataset.filter(
@@ -121,13 +120,11 @@ def get_iterator(src_dataset,
   if src_max_len:
     src_tgt_dataset = src_tgt_dataset.map(
         lambda src, tgt: (src[:src_max_len], tgt),
-        num_threads=num_threads,
-        output_buffer_size=output_buffer_size)
+        num_parallel_calls=num_threads)
   if tgt_max_len:
     src_tgt_dataset = src_tgt_dataset.map(
         lambda src, tgt: (src, tgt[:tgt_max_len]),
-        num_threads=num_threads,
-        output_buffer_size=output_buffer_size)
+        num_parallel_calls=num_threads)
   if source_reverse:
     src_tgt_dataset = src_tgt_dataset.map(
         lambda src, tgt: (tf.reverse(src, axis=[0]), tgt),
@@ -138,18 +135,17 @@ def get_iterator(src_dataset,
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
                         tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
-      num_threads=num_threads, output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
   # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt: (src,
                         tf.concat(([tgt_sos_id], tgt), 0),
                         tf.concat((tgt, [tgt_eos_id]), 0)),
-      num_threads=num_threads, output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
   # Add in sequence lengths.
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt_in, tgt_out: (
           src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
-      num_threads=num_threads,
-      output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
   # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...)
   def batching_func(x):
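
Note (not part of the patch): under TF 1.4, the output_buffer_size argument
removed from map() has no direct keyword replacement; the tf.data idiom is to
chain .prefetch(buffer_size) after map(), which this change does not add. A
minimal, hypothetical sketch of the post-migration pipeline shape follows; the
file names and parameter values are illustrative, not taken from the repo.

    import tensorflow as tf

    # Hypothetical inputs; the real pipeline builds these from hparams
    # and placeholders as shown in model_helper.py above.
    src_dataset = tf.data.TextLineDataset("train.src")
    tgt_dataset = tf.data.TextLineDataset("train.tgt")
    num_threads = 4
    output_buffer_size = 1000

    # Pair source/target lines, tokenize on whitespace in parallel, and
    # recover map()'s old buffering behavior by chaining prefetch().
    src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (tf.string_split([src]).values,
                          tf.string_split([tgt]).values),
        num_parallel_calls=num_threads).prefetch(output_buffer_size)

    iterator = src_tgt_dataset.make_initializable_iterator()
    src, tgt = iterator.get_next()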