[Tutorial]NLP Sequence to sequence model for translation #1815
Conversation
@kazum @srkreddy1238 @masahi @PariksheetPinjari909 please have one round of review. Thanks.
@siju-samuel, sorry for my late response. I think I can take a look this weekend.
nnvm/python/nnvm/frontend/keras.py
Outdated
    # In case of RNN dense, input shape will be (1, 1, n)
    if input_dim > 2:
        input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
        if input_dim != 3 and input_shape[0] != input_shape[1] != 1:
I think this check should be

    if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:

or

    if not (input_dim == 3 and input_shape[0] == input_shape[1] == 1):
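The two suggested rewrites are equivalent by De Morgan's laws, and both differ from the original chained comparison. A quick standalone sketch (plain Python values standing in for the layer's shape) checks this over a few sample inputs:

```python
from itertools import product

def original(input_dim, input_shape):
    # Chained comparison: a != b != 1 means (a != b) and (b != 1)
    return input_dim != 3 and input_shape[0] != input_shape[1] != 1

def rewrite_or(input_dim, input_shape):
    return input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1

def rewrite_not(input_dim, input_shape):
    return not (input_dim == 3 and input_shape[0] == input_shape[1] == 1)

# The two rewrites agree everywhere; the original differs, e.g. for
# input_dim=2 with shape (1, 1), where the chained != short-circuits.
for dim, a, b in product((2, 3), (1, 2), (1, 2)):
    assert rewrite_or(dim, (a, b)) == rewrite_not(dim, (a, b))

assert original(2, (1, 1)) != rewrite_or(2, (1, 1))
```

This makes concrete why the reviewer flags the original: `input_shape[0] != input_shape[1] != 1` does not test "both dims are 1", it tests "dim 0 differs from dim 1 and dim 1 differs from 1".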
nnvm/python/nnvm/frontend/keras.py
Outdated
    in_data = _sym.squeeze(in_data, axis=0)
    in_data = _sym.split(in_data, indices_or_sections=time_steps, axis=0)
    for step in range(time_steps):
        ixh1 = _sym.dense(in_data[step], kernel_wt, use_bias=False, units=units)
I think there is no need to use range here:
for step in in_data:
ixh1 = _sym.dense(step, kernel_wt, use_bias=False, units=units)
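The point of the suggestion is that `_sym.split` already returns a sequence of per-step symbols, so iterating it directly is equivalent to indexing by `range`. A standalone sketch with NumPy standing in for the symbol API (hypothetical shapes, for illustration only):

```python
import numpy as np

time_steps = 4
in_data = np.zeros((time_steps, 1, 8))          # (time_steps, batch, features)
steps = np.split(in_data, time_steps, axis=0)   # list of per-step arrays

# Index-based loop, as in the original code:
by_index = [steps[i].sum() for i in range(time_steps)]

# Direct iteration, as the reviewer suggests; same result, less indexing:
by_iter = [step.sum() for step in steps]

assert by_index == by_iter
```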
nnvm/python/nnvm/frontend/keras.py
Outdated
    sym = symtab.get_var(sym_name, must_contain=True)
    insym.append(sym)

# In some models, sym_name may not be available in inbound_nodes
Can you share an example model where sym_name is not available in inbound_nodes?
dmlc/web-data#124
Put the contents of the keras folder in your execution path.
Any encoder-decoder model will have this issue, since Keras treats it as two different networks, but they are not completely independent (the decoder network's input is linked to the encoder output).
This change ignores the encoder outputs in the decoder model and uses zeros as the initial state instead. It looks wrong to me.
The root cause is that you are processing layers which are included in the imported model but are not relevant to the current model. You can skip such layers with the code below:

    if model._node_key(keras_layer, node_idx) not in model._network_nodes:
        continue

Note that model._network_nodes contains the keys of all nodes relevant to the current model.
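To see the filtering idiom in isolation, here is a minimal sketch with a mock standing in for the Keras internals. The mock class, layer names, and the `"<name>_ib-<idx>"` key format are illustrative assumptions, not the real Keras implementation:

```python
class MockModel:
    """Hypothetical stand-in for a Keras Model's node bookkeeping."""
    def __init__(self, network_nodes):
        # In real Keras, _network_nodes is built when the graph is constructed.
        self._network_nodes = set(network_nodes)

    def _node_key(self, layer_name, node_idx):
        return f"{layer_name}_ib-{node_idx}"

model = MockModel({"encoder_lstm_ib-0", "decoder_lstm_ib-0"})

# Layers found in the imported file; one belongs to a different graph.
candidates = [("encoder_lstm", 0), ("decoder_lstm", 0), ("orphan_dense", 0)]

kept = []
for layer_name, node_idx in candidates:
    # Skip nodes that are not relevant to the current model.
    if model._node_key(layer_name, node_idx) not in model._network_nodes:
        continue
    kept.append(layer_name)

assert kept == ["encoder_lstm", "decoder_lstm"]
```

The orphan layer is silently skipped rather than converted with a fabricated zero state, which is the behavior the reviewer is asking for.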
    # Base location for model related files.
    repo_base = 'https://github.com/dmlc/web-data/raw/master/keras/models/s2s_translate/'
    model_url = os.path.join(repo_base, model_file)
    data_url = os.path.join(repo_base, data_file)
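As an aside on the snippet above: `os.path.join` happens to produce a valid URL on POSIX systems, but it uses the platform separator, so `posixpath.join` (or plain concatenation) is the portable choice for URLs. A sketch, with a hypothetical file name since the real one is not shown here:

```python
import posixpath

repo_base = 'https://github.com/dmlc/web-data/raw/master/keras/models/s2s_translate/'
model_file = 's2s_translate.h5'  # hypothetical name, for illustration

# posixpath always joins with '/', so the result is a valid URL on any OS.
model_url = posixpath.join(repo_base, model_file)
assert model_url == repo_base + model_file
```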
The model and data files are not found in the repository yet. Can you share them?
    # Randomly take some text and translate
    for seq_index in range(100):
        # Take one sequence and try to decode.
        index = random.randint(1, num_samples)
If the model is trained on num_samples, I would suggest testing the model on a validation dataset.
    download(data_url, model_file)

    latent_dim = 256   # Latent dimensionality of the encoding space.
    num_samples = 10000  # Number of samples to train on.
In the script, a pretrained model is used; no training is done. Can you update the comment?
Force-pushed from 3732425 to 807fd96.
@kazum @PariksheetPinjari909 could you please review once again. Thanks.
    @@ -131,6 +131,14 @@ def _convert_dense(insym, keras_layer, symtab):
        if keras_layer.use_bias:
            params['use_bias'] = True
            params['bias'] = symtab.new_const(weightList[1])
        input_shape = keras_layer.input_shape
        input_dim = len(input_shape)
        # In case of RNN dense, input shape will be (1, 1, n)
The current version doesn't have a check of input_shape[1] != 1, so the input shape could be (1, m, n)?
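The reviewer's concern can be made concrete with NumPy standing in for the symbol operations (hypothetical shapes): squeezing a 3-D dense input is only safe when the first two dims are both 1.

```python
import numpy as np

n, units = 8, 4
weights = np.random.rand(n, units)

# RNN-style dense input shaped (1, 1, n) can be safely flattened to (1, n).
x = np.random.rand(1, 1, n)
flat = x.reshape(1, n)   # valid only because dims 0 and 1 are both 1
out = flat @ weights
assert out.shape == (1, units)

# With shape (1, m, n), m > 1, the same reshape would silently fold the
# time dimension into the batch, which is why a check on input_shape[1]
# is needed before squeezing.
```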
@kazum @PariksheetPinjari909 Could you conclude on this PR?
The current version almost looks good to me. I'll approve if my comment (#1815 (comment)) is addressed.
I just ran the model for 10 test samples from 10000 to 10050. In the result, the output sequence doesn't match the sequence provided in the dataset. @siju-samuel, can you have a look at this?
It won't match the actual translation shown in the text file exactly. This is only a port of the Keras s2s model to TVM, so it will match the Keras output exactly; TVM won't improve accuracy further.
Should merge dmlc/web-data#124 @tqchen
@merrymercy dmlc/web-data#124 is merged
* [Tutorial]NLP Sequence to sequence model for translation * Review comments * Review comments updated