This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

fix beam search script #175

Merged
merged 5 commits into from
Jun 30, 2018

Conversation

@szha szha commented Jun 27, 2018

Description

The following run was broken.
python beam_search_generator.py --bos I think --lm standard_lstm_lm_200

This PR fixes the issue by adding state_info to the language models while preserving existing behavior.

Checklist

Essentials

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

Changes

  • add state_info to models
  • use state_info to decide state shape in beam search
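
As a rough illustration of the second change, the sketch below reads the batch axis from a `'__layout__'` string and repeats each batch entry `beam_size` times along that axis. It is a minimal sketch, not the PR's code: NumPy stands in for MXNet NDArray, `expand_to_beam_size` is a hypothetical simplified helper, and the `'LNC'` layout (layers, batch, channels) is assumed here only as an example of a state whose batch axis is not 0.

```python
import numpy as np


def expand_to_beam_size(data, beam_size, state_info=None):
    """Repeat every batch entry beam_size times along the batch axis.

    Hypothetical helper for illustration only. The batch axis is the
    position of 'N' in the layout string; with no state_info it is 0.
    """
    batch_axis = state_info['__layout__'].find('N') if state_info else 0
    return np.repeat(data, beam_size, axis=batch_axis)


# Assumed example state: 2 layers, batch of 3, 4 hidden units (layout 'LNC').
state = np.arange(2 * 3 * 4).reshape(2, 3, 4)
expanded = expand_to_beam_size(state, beam_size=5,
                               state_info={'__layout__': 'LNC'})
print(expanded.shape)  # (2, 15, 4)
```

With no `state_info`, the same call expands along axis 0, which corresponds to the behavior the description says is preserved.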

@szha szha requested a review from sxjscience June 27, 2018 05:20

szha commented Jun 27, 2018

@sxjscience @hhexiy

@szha szha force-pushed the fix_beam branch 4 times, most recently from 76f2fdb to a5034a8 Compare June 27, 2018 15:53

mli commented Jun 27, 2018

Job PR-175/6 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-175/6/index.html

@@ -92,30 +92,56 @@ def _expand_to_beam_size(data, beam_size, batch_size):
Each NDArray should have shape (batch_size * beam_size, ...)
"""
if isinstance(data, list):
    return [_expand_to_beam_size(ele, beam_size, batch_size) for ele in data]
    assert not state_info or isinstance(state_info, list), \
Member

I think we can move this part outside the if/else by using:
assert not state_info or isinstance(state_info, type(data))

@@ -124,14 +150,37 @@ def _choose_states(F, states, indices):
Each NDArray/Symbol should have shape (N, ...).
"""
if isinstance(states, list):
    return [_choose_states(F, ele, indices) for ele in states]
    assert not state_info or isinstance(state_info, list), \
Member

Same as previous comment

else:
    raise NotImplementedError


def _choose_states(F, states, indices):
def _choose_states(F, states, state_info, indices):
"""

Parameters
----------
F : ndarray or symbol
states : Object contains NDArrays/Symbols
Each NDArray/Symbol should have shape (N, ...).
Member

Fix docstring. Now N may not be in the first dim.

states = F.take(states, indices)
if batch_axis != 0:
    states = states.swapaxes(0, batch_axis)
return states
Member

So the returned states always have the batch in the first dim even if that is not the case for the input? Would it cause some inconsistency?

Member Author

Because of the two swaps, the batch dimension should be where it used to be.

Member

Yes, you are right.

Member

We can directly use take(axis=batch_axis) once apache/mxnet#11326 is merged.
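
The two-swap argument in this thread can be checked with a small sketch. It is only an illustration under stated assumptions: NumPy stands in for MXNet NDArray, and `choose_states` is a hypothetical simplified version of the helper under review, not the merged code. It moves the batch axis to the front, selects rows, and swaps back, then compares the result with a direct `take` along the batch axis (which NumPy's `np.take` already supports through its `axis` argument).

```python
import numpy as np


def choose_states(states, indices, batch_axis):
    """Select batch entries by index, leaving the layout unchanged.

    Hypothetical helper for illustration only.
    """
    if batch_axis != 0:
        # First swap: bring the batch axis to the front.
        states = states.swapaxes(0, batch_axis)
    states = np.take(states, indices, axis=0)
    if batch_axis != 0:
        # Second swap: put the batch axis back where it was.
        states = states.swapaxes(0, batch_axis)
    return states


states = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # batch on axis 1
indices = np.array([2, 0, 0])

chosen = choose_states(states, indices, batch_axis=1)
direct = np.take(states, indices, axis=1)

print(chosen.shape)                    # (2, 3, 4)
print(np.array_equal(chosen, direct))  # True
```

The output keeps the input's layout, and the swap-take-swap sequence matches a single take along the batch axis.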

@szhengac

Also, a test case with different layouts is needed.

@szha szha force-pushed the fix_beam branch 2 times, most recently from e627b17 to 246f1aa Compare June 28, 2018 00:51
.reshape((batch_size * beam_size,) + data.shape[1:])
if not state_info:
    state_info = {'__layout__': 'NC'}
batch_axis = state_info['__layout__'].find('N')
Member

I think the following will be better:

if not state_info:
    batch_axis = 0
else:
    batch_axis = state_info['__layout__'].find('N')
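
A self-contained sketch of the suggested branch follows. The function name `get_batch_axis` and the error check are assumptions added for illustration and are not part of the PR. The check is there because `str.find` returns -1 when the character is absent, so a layout string without `'N'` would otherwise silently yield an invalid axis.

```python
def get_batch_axis(state_info=None):
    """Return the batch axis for a state, defaulting to 0.

    Hypothetical helper: the name and the ValueError guard are
    illustrative additions, not code from the PR.
    """
    if not state_info:
        return 0
    batch_axis = state_info['__layout__'].find('N')
    if batch_axis < 0:
        # str.find returns -1 when 'N' does not occur in the layout.
        raise ValueError("layout %r has no batch axis 'N'"
                         % state_info['__layout__'])
    return batch_axis


print(get_batch_axis())                         # 0
print(get_batch_axis({'__layout__': 'NC'}))     # 0
print(get_batch_axis({'__layout__': 'LNC'}))    # 1
```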

return mx.nd.stack(*updated_states, axis=0)
if not state_info:
    state_info = {'__layout__': 'NC'}
batch_axis = state_info['__layout__'].find('N')
Member

Same here

if not state_info:
    batch_axis = 0
else:
    ...

Member

Since it's in the test, the change is not necessary.

mli commented Jun 30, 2018

Job PR-175/13 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-175/13/index.html

@szha szha force-pushed the fix_beam branch 2 times, most recently from c6fcd77 to d79403b Compare June 30, 2018 17:49

mli commented Jun 30, 2018

Job PR-175/18 is complete.
Docs are uploaded to http://gluon-nlp-staging.s3-accelerate.dualstack.amazonaws.com/PR-175/18/index.html

@szha szha merged commit b926873 into dmlc:master Jun 30, 2018
@szha szha deleted the fix_beam branch June 30, 2018 18:37
leezu pushed a commit to leezu/gluon-nlp that referenced this pull request Jul 11, 2018
* fix beam search script

* add tests

* address comment

* fix

* update test
paperplanet pushed a commit to paperplanet/gluon-nlp that referenced this pull request Jun 9, 2019
* fix beam search script

* add tests

* address comment

* fix

* update test
4 participants