Skip to content

Commit

Permalink
Review comments updated
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel committed Nov 6, 2018
1 parent 807fd96 commit 1882340
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tutorials/nnvm/nlp/keras_s2s_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
download(data_url, model_file)

latent_dim = 256 # Latent dimensionality of the encoding space.
num_samples = 10000 # Number of samples used for training.
test_samples = 10000 # Number of samples used for testing.

######################################################################
# Process the data file
Expand All @@ -78,10 +78,10 @@
target_characters = set()
with open(data_file, 'r', encoding='utf-8') as f:
lines = f.read().split('\n')
num_samples = min(num_samples, len(lines))
test_samples = min(test_samples, len(lines))
max_encoder_seq_length = 0
max_decoder_seq_length = 0
for line in lines[:num_samples]:
for line in lines[:test_samples]:
input_text, target_text = line.split('\t')
# We use "tab" as the "start sequence" character
# for the targets, and "\n" as "end sequence" character.
Expand Down Expand Up @@ -228,10 +228,10 @@ def generate_input_seq(input_text):
######################################################################
# Run the model
# -------------
# Randonly take some text and translate
# Randonly take some text from test samples and translate
for seq_index in range(100):
# Take one sentence randomly and try to decode.
index = random.randint(1, num_samples)
index = random.randint(1, test_samples)
input_text, _ = lines[index].split('\t')
input_seq = generate_input_seq(input_text)
decoded_sentence = decode_sequence(input_seq)
Expand Down

0 comments on commit 1882340

Please sign in to comment.