nmt/model_helper.py: 10 changes (5 additions, 5 deletions)
@@ -72,8 +72,8 @@ def create_train_model(
   src_vocab_table, tgt_vocab_table = vocab_utils.create_vocab_tables(
       src_vocab_file, tgt_vocab_file, hparams.share_vocab)
 
-  src_dataset = tf.contrib.data.TextLineDataset(src_file)
-  tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file)
+  src_dataset = tf.data.TextLineDataset(src_file)
+  tgt_dataset = tf.data.TextLineDataset(tgt_file)
   skip_count_placeholder = tf.placeholder(shape=(), dtype=tf.int64)
 
   iterator = iterator_utils.get_iterator(
@@ -132,8 +132,8 @@ def create_eval_model(model_creator, hparams, scope=None, extra_args=None):
       src_vocab_file, tgt_vocab_file, hparams.share_vocab)
   src_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
   tgt_file_placeholder = tf.placeholder(shape=(), dtype=tf.string)
-  src_dataset = tf.contrib.data.TextLineDataset(src_file_placeholder)
-  tgt_dataset = tf.contrib.data.TextLineDataset(tgt_file_placeholder)
+  src_dataset = tf.data.TextLineDataset(src_file_placeholder)
+  tgt_dataset = tf.data.TextLineDataset(tgt_file_placeholder)
   iterator = iterator_utils.get_iterator(
       src_dataset,
       tgt_dataset,
@@ -185,7 +185,7 @@ def create_infer_model(model_creator, hparams, scope=None, extra_args=None):
   src_placeholder = tf.placeholder(shape=[None], dtype=tf.string)
   batch_size_placeholder = tf.placeholder(shape=[], dtype=tf.int64)
 
-  src_dataset = tf.contrib.data.Dataset.from_tensor_slices(
+  src_dataset = tf.data.Dataset.from_tensor_slices(
       src_placeholder)
   iterator = iterator_utils.get_infer_iterator(
       src_dataset,
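The model_helper.py hunks above swap the dataset constructors from the deprecated tf.contrib.data namespace to the core tf.data API (available since TensorFlow 1.4); the call signatures are unchanged, so it is a drop-in namespace switch. A minimal sketch of the same constructors in isolation, assuming a TF 1.x graph-mode setup; the file names, sample sentences, and session loop are illustrative assumptions, not code from the repository:

# Minimal sketch, not part of the diff: the same constructors via core tf.data.
# File names and sample sentences below are hypothetical.
import tensorflow as tf

src_file = "train.src"  # hypothetical path
tgt_file = "train.tgt"  # hypothetical path

# One text line per example, as in the updated create_train_model.
src_dataset = tf.data.TextLineDataset(src_file)
tgt_dataset = tf.data.TextLineDataset(tgt_file)

# For inference, in-memory sentences are sliced the same way as in the
# updated create_infer_model.
sentences = tf.constant(["hello world .", "how are you ?"])
infer_dataset = tf.data.Dataset.from_tensor_slices(sentences)

iterator = infer_dataset.make_initializable_iterator()
next_sentence = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  print(sess.run(next_sentence))  # b'hello world .'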
nmt/utils/iterator_utils.py: 18 changes (7 additions, 11 deletions)
@@ -101,7 +101,7 @@ def get_iterator(src_dataset,
   tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
   tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)
 
-  src_tgt_dataset = tf.contrib.data.Dataset.zip((src_dataset, tgt_dataset))
+  src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
 
   src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
   if skip_count is not None:
@@ -111,8 +111,7 @@
 
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt: (tf.string_split([src]).values, tf.string_split([tgt]).values),
-      num_threads=num_threads,
-      output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
 
   # Filter zero length input sequences.
   src_tgt_dataset = src_tgt_dataset.filter(
@@ -121,13 +120,11 @@
   if src_max_len:
     src_tgt_dataset = src_tgt_dataset.map(
         lambda src, tgt: (src[:src_max_len], tgt),
-        num_threads=num_threads,
-        output_buffer_size=output_buffer_size)
+        num_parallel_calls=num_threads)
   if tgt_max_len:
     src_tgt_dataset = src_tgt_dataset.map(
         lambda src, tgt: (src, tgt[:tgt_max_len]),
-        num_threads=num_threads,
-        output_buffer_size=output_buffer_size)
+        num_parallel_calls=num_threads)
   if source_reverse:
     src_tgt_dataset = src_tgt_dataset.map(
         lambda src, tgt: (tf.reverse(src, axis=[0]), tgt),
@@ -138,18 +135,17 @@
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
                         tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
-      num_threads=num_threads, output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
   # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt: (src,
                         tf.concat(([tgt_sos_id], tgt), 0),
                         tf.concat((tgt, [tgt_eos_id]), 0)),
-      num_threads=num_threads, output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
   # Add in sequence lengths.
   src_tgt_dataset = src_tgt_dataset.map(
       lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
-      num_threads=num_threads,
-      output_buffer_size=output_buffer_size)
+      num_parallel_calls=num_threads)
 
   # Bucket by source sequence length (buckets for lengths 0-9, 10-19, ...)
   def batching_func(x):
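The iterator_utils.py hunks track the second API difference: the core tf.data.Dataset.map() takes num_parallel_calls where the contrib-era map() took the num_threads and output_buffer_size pair. In the core API, output buffering is expressed as a separate prefetch() transformation if it is still wanted; these hunks simply drop the buffer argument. A minimal sketch of one of the rewritten map calls under those assumptions; the toy datasets, the num_threads value, and the prefetch buffer size are illustrative, not values from the repository:

# Minimal sketch, not part of the diff: the new map() signature in isolation.
# Toy datasets, num_threads, and the prefetch buffer size are assumed values.
import tensorflow as tf

num_threads = 4  # stand-in for the hparam the repo forwards as num_threads

src_dataset = tf.data.Dataset.from_tensor_slices(
    tf.constant(["a b c", "d e"]))
tgt_dataset = tf.data.Dataset.from_tensor_slices(
    tf.constant(["x y", "z"]))
src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

# Core API: parallelism is the num_parallel_calls argument of map() ...
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (tf.string_split([src]).values,
                      tf.string_split([tgt]).values),
    num_parallel_calls=num_threads)

# ... while the buffering that output_buffer_size used to provide is a
# separate transformation, if it is still wanted (this diff drops it).
src_tgt_dataset = src_tgt_dataset.prefetch(1000)

Decoupling parallelism (num_parallel_calls) from buffering (prefetch) is the main design difference between the two map signatures.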