From f05346511d4b279005295a1d49abad42d0762568 Mon Sep 17 00:00:00 2001
From: Siju
Date: Thu, 29 Nov 2018 06:52:17 +0530
Subject: [PATCH] [Tutorial]NLP Sequence to sequence model for translation
 (#1815)

* [Tutorial]NLP Sequence to sequence model for translation

* Review comments

* Review comments updated
---
 nnvm/python/nnvm/frontend/keras.py            |  45 +++-
 .../python/frontend/keras/test_forward.py     |  20 +-
 tutorials/nnvm/nlp/keras_s2s_translate.py     | 238 ++++++++++++++++++
 3 files changed, 286 insertions(+), 17 deletions(-)
 create mode 100644 tutorials/nnvm/nlp/keras_s2s_translate.py

diff --git a/nnvm/python/nnvm/frontend/keras.py b/nnvm/python/nnvm/frontend/keras.py
index a1e089b210c5f..9dabebc14b90e 100644
--- a/nnvm/python/nnvm/frontend/keras.py
+++ b/nnvm/python/nnvm/frontend/keras.py
@@ -131,6 +131,14 @@ def _convert_dense(insym, keras_layer, symtab):
     if keras_layer.use_bias:
         params['use_bias'] = True
         params['bias'] = symtab.new_const(weightList[1])
+    input_shape = keras_layer.input_shape
+    input_dim = len(input_shape)
+    # In case of RNN dense, input shape will be (1, 1, n)
+    if input_dim > 2:
+        input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
+        if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
+            raise ValueError("Cannot flatten the inputs with shape.", input_shape, " for dense.")
+        insym = _sym.squeeze(insym, axis=0)
     out = _sym.dense(data=insym, **params)
     # defuse activation
     if sys.version_info.major < 3:
@@ -139,6 +147,8 @@ def _convert_dense(insym, keras_layer, symtab):
         act_type = keras_layer.activation.__name__
     if act_type != 'linear':
         out = _convert_activation(out, act_type, symtab)
+    if input_dim > 2:
+        out = _sym.expand_dims(out, axis=0)
     return out


@@ -408,10 +418,11 @@ def _convert_lstm(insym, keras_layer, symtab):
         insym = [insym, h_sym, c_sym]

     in_data = insym[0]
-    in_state_h = insym[1]
-    in_state_c = insym[2]
+    next_h = insym[1]
+    next_c = insym[2]

     weightList = keras_layer.get_weights()
+    inp_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.input_shape)[0])
     kernel_wt = symtab.new_const(weightList[0].transpose([1, 0]))
     recurrent_wt = symtab.new_const(weightList[1].transpose([1, 0]))
@@ -419,16 +430,20 @@ def _convert_lstm(insym, keras_layer, symtab):

     units = list(weightList[0].shape)[1]

-    in_data = _sym.flatten(in_data)
-    ixh1 = _sym.dense(in_data, kernel_wt, use_bias=False, units=units)
-    ixh2 = _sym.dense(in_state_h, recurrent_wt, in_bias, use_bias=True, units=units)
-    gate = ixh1 + ixh2
-    gates = _sym.split(gate, indices_or_sections=4, axis=1)
-    in_gate = _convert_recurrent_activation(gates[0], keras_layer)
-    in_transform = _convert_recurrent_activation(gates[1], keras_layer)
-    next_c = in_transform * in_state_c + in_gate * _convert_activation(gates[2], keras_layer, None)
-    out_gate = _convert_recurrent_activation(gates[3], keras_layer)
-    next_h = out_gate * _convert_activation(next_c, keras_layer, None)
+    time_steps = inp_shape[1]
+    in_data = _sym.squeeze(in_data, axis=0)
+    in_data = _sym.split(in_data, indices_or_sections=time_steps, axis=0)
+    # Loop over the number of time steps
+    for data in in_data:
+        ixh1 = _sym.dense(data, kernel_wt, use_bias=False, units=units)
+        ixh2 = _sym.dense(next_h, recurrent_wt, in_bias, use_bias=True, units=units)
+        gate = ixh1 + ixh2
+        gates = _sym.split(gate, indices_or_sections=4, axis=1)
+        in_gate = _convert_recurrent_activation(gates[0], keras_layer)
+        in_transform = _convert_recurrent_activation(gates[1], keras_layer)
+        next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None)
+        out_gate = _convert_recurrent_activation(gates[3], keras_layer)
+        next_h = out_gate * _convert_activation(next_c, keras_layer, None)

     out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
     out = _sym.reshape(next_h, shape=out_shape)
@@ -656,6 +671,12 @@ def from_keras(model):
             raise TypeError("Unknown layer type or unsupported Keras version : {}"
                             .format(keras_layer))
         for node_idx, node in enumerate(inbound_nodes):
+            # If some nodes in the imported model are not relevant to the current
+            # model, skip them. model._network_nodes contains the keys of all nodes
+            # relevant to the current model.
+            if model._node_key(keras_layer, node_idx) not in model._network_nodes:
+                continue
+
             insym = []

             # Since Keras allows creating multiple layers from the same name instance,
diff --git a/nnvm/tests/python/frontend/keras/test_forward.py b/nnvm/tests/python/frontend/keras/test_forward.py
index 96c51a94ff694..618af3b2e4176 100644
--- a/nnvm/tests/python/frontend/keras/test_forward.py
+++ b/nnvm/tests/python/frontend/keras/test_forward.py
@@ -74,7 +74,7 @@ def test_forward_elemwise_add():
     verify_keras_frontend(keras_model)


-def test_forward_dense():
+def _test_forward_dense():
     data = keras.layers.Input(shape=(32,32,1))
     x = keras.layers.Flatten()(data)
     x = keras.layers.Dropout(0.5)(x)
@@ -82,6 +82,15 @@ def test_forward_dense():
     keras_model = keras.models.Model(data, x)
     verify_keras_frontend(keras_model)

+def _test_forward_dense_with_3d_inp():
+    data = keras.layers.Input(shape=(1, 20))
+    x = keras.layers.Dense(10, activation='relu', kernel_initializer='uniform')(data)
+    keras_model = keras.models.Model(data, x)
+    verify_keras_frontend(keras_model, need_transpose=False)
+
+def test_forward_dense():
+    _test_forward_dense()
+    _test_forward_dense_with_3d_inp()

 def test_forward_pool():
     data = keras.layers.Input(shape=(32,32,1))
@@ -226,8 +235,8 @@ def test_forward_reuse_layers():
     keras_model = keras.models.Model(data, z)
     verify_keras_frontend(keras_model)

-def _test_LSTM(inputs, hidden, return_state=True):
-    data = keras.layers.Input(shape=(1, inputs))
+def _test_LSTM(time_steps, inputs, hidden, return_state=True):
+    data = keras.layers.Input(shape=(time_steps, inputs))
     lstm_out = keras.layers.LSTM(hidden,
                                  return_state=return_state,
                                  recurrent_activation='sigmoid',
@@ -250,8 +259,9 @@ def _test_LSTM_MultiLayer(inputs, hidden):


 def test_forward_LSTM():
-    _test_LSTM(8, 8, return_state=True)
-    _test_LSTM(4, 4, return_state=False)
+    _test_LSTM(1, 8, 8, return_state=True)
+    _test_LSTM(1, 4, 4, return_state=False)
+    _test_LSTM(20, 16, 256, return_state=False)
     _test_LSTM_MultiLayer(4, 4)

 def _test_RNN(inputs, units):
diff --git a/tutorials/nnvm/nlp/keras_s2s_translate.py b/tutorials/nnvm/nlp/keras_s2s_translate.py
new file mode 100644
index 0000000000000..77c7f23902f4c
--- /dev/null
+++ b/tutorials/nnvm/nlp/keras_s2s_translate.py
@@ -0,0 +1,238 @@
"""
Keras LSTM Sequence to Sequence Model for Translation
======================================================
**Author**: `Siju Samuel `_

This script demonstrates how to implement a basic character-level sequence-to-sequence model.
We apply it to translating short English sentences into short French sentences,
character-by-character.

# Summary of the algorithm

- We start with input sequences from a domain (e.g. English sentences)
  and corresponding target sequences from another domain
  (e.g. French sentences).
- An encoder LSTM turns the input sequences into 2 state vectors
  (we keep the last LSTM state and discard the outputs).
- A decoder LSTM is trained to turn the target sequences into
  the same sequence but offset by one timestep in the future,
  a training process called "teacher forcing" in this context.
  It uses as initial state the state vectors from the encoder.
  Effectively, the decoder learns to generate `targets[t+1...]`
  given `targets[...t]`, conditioned on the input sequence.

This script loads the s2s_translate.h5 model saved by the training script
https://github.com/dmlc/web-data/raw/master/keras/models/s2s_translate/lstm_seq2seq.py
and generates sequences from it. It assumes that no changes have been made (for example:
latent_dim is unchanged, and the input data and model architecture are unchanged).

# References

- Sequence to Sequence Learning with Neural Networks
  https://arxiv.org/abs/1409.3215
- Learning Phrase Representations using
  RNN Encoder-Decoder for Statistical Machine Translation
  https://arxiv.org/abs/1406.1078

See lstm_seq2seq.py for more details on the model architecture and how it is trained.
"""

from keras.models import Model, load_model
from keras.layers import Input
import random
import os
import numpy as np
import keras
import tvm
import nnvm

######################################################################
# Download required files
# -----------------------
# Download the files listed below from the dmlc web-data repo.
model_file = "s2s_translate.h5"
data_file = "fra-eng.txt"

# Base location for model related files.
repo_base = 'https://github.com/dmlc/web-data/raw/master/keras/models/s2s_translate/'
model_url = os.path.join(repo_base, model_file)
data_url = os.path.join(repo_base, data_file)

# Download the model and the data file.
from mxnet.gluon.utils import download
download(model_url, model_file)
download(data_url, data_file)

latent_dim = 256 # Latent dimensionality of the encoding space.
test_samples = 10000 # Number of samples used for testing.

######################################################################
# Process the data file
# ---------------------
# Vectorize the data. We use the same approach as the training script.
# NOTE: the data must be identical, in order for the character -> integer
# mappings to be consistent.
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
with open(data_file, 'r', encoding='utf-8') as f:
    lines = f.read().split('\n')
test_samples = min(test_samples, len(lines))
max_encoder_seq_length = 0
max_decoder_seq_length = 0
for line in lines[:test_samples]:
    input_text, target_text = line.split('\t')
    # We use "tab" as the "start sequence" character
    # for the targets, and "\n" as "end sequence" character.
    target_text = '\t' + target_text + '\n'
    max_encoder_seq_length = max(max_encoder_seq_length, len(input_text))
    max_decoder_seq_length = max(max_decoder_seq_length, len(target_text))
    for char in input_text:
        if char not in input_characters:
            input_characters.add(char)
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char)

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
input_token_index = dict(
    [(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict(
    [(char, i) for i, char in enumerate(target_characters)])

# Reverse-lookup token index to decode sequences back to something readable.
reverse_target_char_index = dict(
    (i, char) for char, i in target_token_index.items())

######################################################################
# Load Keras Model
# ----------------
# Restore the model and construct the encoder and decoder.
model = load_model(model_file)
encoder_inputs = model.input[0] # input_1

encoder_outputs, state_h_enc, state_c_enc = model.layers[2].output # lstm_1
encoder_states = [state_h_enc, state_c_enc]
encoder_model = Model(encoder_inputs, encoder_states)

decoder_inputs = model.input[1] # input_2
decoder_state_input_h = Input(shape=(latent_dim,), name='input_3')
decoder_state_input_c = Input(shape=(latent_dim,), name='input_4')
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_lstm = model.layers[3]
decoder_outputs, state_h_dec, state_c_dec = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h_dec, state_c_dec]
decoder_dense = model.layers[4]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs,
    [decoder_outputs] + decoder_states)

######################################################################
# Compile both encoder and decoder models with NNVM
# -------------------------------------------------
# Create the NNVM graph definitions from the Keras models.
from tvm.contrib import graph_runtime
target = 'llvm'
ctx = tvm.cpu(0)

# Parse Encoder model
sym, params = nnvm.frontend.from_keras(encoder_model)
inp_enc_shape = (1, max_encoder_seq_length, num_encoder_tokens)
shape_dict = {'input_1': inp_enc_shape}

# Build Encoder model
with nnvm.compiler.build_config(opt_level=2):
    enc_graph, enc_lib, enc_params = nnvm.compiler.build(sym, target, shape_dict, params=params)
print("Encoder build ok.")

# Create graph runtime for encoder model
tvm_enc = graph_runtime.create(enc_graph, enc_lib, ctx)
tvm_enc.set_input(**enc_params)

# Parse Decoder model
inp_dec_shape = (1, 1, num_decoder_tokens)
shape_dict = {'input_2': inp_dec_shape,
              'input_3': (1, latent_dim),
              'input_4': (1, latent_dim)}

# Build Decoder model
sym, params = nnvm.frontend.from_keras(decoder_model)
with nnvm.compiler.build_config(opt_level=2):
    dec_graph, dec_lib, dec_params = nnvm.compiler.build(sym, target, shape_dict, params=params)
print("Decoder build ok.")

# Create graph runtime for decoder model
tvm_dec = graph_runtime.create(dec_graph, dec_lib, ctx)
tvm_dec.set_input(**dec_params)

# Decodes an input sequence.
def decode_sequence(input_seq):
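    # The input sentence is encoded once to obtain the initial decoder states;
    # the decoder is then run one character at a time, feeding the sampled
    # character and the updated states back in until the stop character '\n'
    # is produced.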
    # Set the input for encoder model.
    tvm_enc.set_input('input_1', input_seq)

    # Run encoder model
    tvm_enc.run()

    # Get states from encoder network
    h = tvm_enc.get_output(0).asnumpy()
    c = tvm_enc.get_output(1).asnumpy()

    # Populate the first character of target sequence with the start character.
    sampled_token_index = target_token_index['\t']

    # Sampling loop for a batch of sequences (here, a batch of size 1)
    decoded_sentence = ''
    while True:
        # Generate empty target sequence of length 1.
        target_seq = np.zeros((1, 1, num_decoder_tokens), dtype='float32')
        # Update the target sequence (of length 1).
        target_seq[0, 0, sampled_token_index] = 1.

        # Set the input and states for decoder model.
        tvm_dec.set_input('input_2', target_seq)
        tvm_dec.set_input('input_3', h)
        tvm_dec.set_input('input_4', c)
        # Run decoder model
        tvm_dec.run()

        output_tokens = tvm_dec.get_output(0).asnumpy()
        h = tvm_dec.get_output(1).asnumpy()
        c = tvm_dec.get_output(2).asnumpy()

        # Sample a token
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = reverse_target_char_index[sampled_token_index]

        # Exit condition: either hit max length or find stop character.
        if sampled_char == '\n':
            break

        # Update the sentence
        decoded_sentence += sampled_char
        if len(decoded_sentence) > max_decoder_seq_length:
            break
    return decoded_sentence

def generate_input_seq(input_text):
    input_seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens), dtype='float32')
    for t, char in enumerate(input_text):
        input_seq[0, t, input_token_index[char]] = 1.
    return input_seq

######################################################################
# Run the model
# -------------
# Randomly take some text from the test samples and translate it.
for seq_index in range(100):
    # Take one sentence randomly and try to decode.
    index = random.randint(0, test_samples - 1)
    input_text, _ = lines[index].split('\t')
    input_seq = generate_input_seq(input_text)
    decoded_sentence = decode_sequence(input_seq)
    print((seq_index + 1), ": ", input_text, "==>", decoded_sentence)
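
######################################################################
# Try a sentence of your own
# --------------------------
# The helpers above are not tied to the test set. As a purely illustrative
# check, any short text can be translated as long as every character appears
# in input_token_index and the text fits within max_encoder_seq_length.
# The sample sentence below is only an example input.
custom_text = "Run."
if all(ch in input_token_index for ch in custom_text) and len(custom_text) <= max_encoder_seq_length:
    print(custom_text, "==>", decode_sequence(generate_input_seq(custom_text)))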