76f2fdb to a5034a8
Job PR-175/6 is complete.
gluonnlp/model/beam_search.py
Outdated
@@ -92,30 +92,56 @@ def _expand_to_beam_size(data, beam_size, batch_size):
        Each NDArray should have shape (batch_size * beam_size, ...)
    """
    if isinstance(data, list):
        return [_expand_to_beam_size(ele, beam_size, batch_size) for ele in data]
        assert not state_info or isinstance(state_info, list), \
I think we can move this check outside the if/else by using:
assert not state_info or isinstance(state_info, type(data))
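A minimal sketch of how a single shared check like this could look. The helper name `check_structure` and the plain-Python stand-ins for NDArrays are assumptions for illustration, not the PR's code. The guard on container types is also an addition: for a leaf array the info is a plain dict such as {'__layout__': 'NC'}, so comparing it against `type(data)` only makes sense for lists and dicts.

```python
def check_structure(data, state_info):
    """Assert that state_info, when given, mirrors the nesting of data."""
    if not state_info or not isinstance(data, (list, dict)):
        return  # nothing to check for a leaf or when no info is supplied
    # One check covers both container types, as suggested above.
    assert isinstance(state_info, type(data)), \
        'data and state_info must share the same container type'
    if isinstance(data, list):
        for ele, info in zip(data, state_info):
            check_structure(ele, info)
    else:
        for key, ele in data.items():
            check_structure(ele, state_info.get(key))


# Hypothetical nested states: integers stand in for arrays here.
check_structure([1, {'h': 2}],
                [{'__layout__': 'NC'}, {'h': {'__layout__': 'TNC'}}])
```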
gluonnlp/model/beam_search.py
Outdated
@@ -124,14 +150,37 @@ def _choose_states(F, states, indices):
        Each NDArray/Symbol should have shape (N, ...).
    """
    if isinstance(states, list):
        return [_choose_states(F, ele, indices) for ele in states]
        assert not state_info or isinstance(state_info, list), \
Same as the previous comment.
gluonnlp/model/beam_search.py
Outdated
    else:
        raise NotImplementedError


def _choose_states(F, states, indices):
def _choose_states(F, states, state_info, indices):
    """

    Parameters
    ----------
    F : ndarray or symbol
    states : Object contains NDArrays/Symbols
        Each NDArray/Symbol should have shape (N, ...).
Fix the docstring: N may no longer be in the first dim.
    states = F.take(states, indices)
    if batch_axis != 0:
        states = states.swapaxes(0, batch_axis)
    return states
So the returned states always have the batch in the first dim, even if that is not the case for the input? Would that cause some inconsistency?
Because of the two swaps, the batch dimension should be where it used to be.
Yes, you are right.
We can directly use take(axis=batch_axis) once apache/mxnet#11326 is merged.
Also, a test case with different layouts is needed.
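As a rough sketch of what such a layout test could check, the snippet below uses NumPy as a stand-in for MXNet NDArrays (an assumption; `np.take` already accepts an `axis` argument). The helper name `swap_take_swap` and the shapes are hypothetical. It compares the swap, take, swap sequence discussed above with a direct take along the batch axis, which also confirms that the batch dimension ends up where it started.

```python
import numpy as np


def swap_take_swap(states, indices, batch_axis):
    """Select along batch_axis by moving it to the front and back again."""
    states = states.swapaxes(0, batch_axis)    # bring the batch axis to the front
    states = np.take(states, indices, axis=0)  # select along the leading axis
    return states.swapaxes(0, batch_axis)      # restore the original layout


rng = np.random.default_rng(0)
indices = np.array([3, 1, 1, 0])
# Hypothetical layouts and shapes; the batch size (N) is 4 in each case.
for layout, shape in [('NC', (4, 5)), ('TNC', (6, 4, 5)), ('LNC', (2, 4, 5))]:
    batch_axis = layout.find('N')
    states = rng.standard_normal(shape)
    expected = np.take(states, indices, axis=batch_axis)
    result = swap_take_swap(states, indices, batch_axis)
    assert result.shape == expected.shape
    assert np.array_equal(result, expected)
```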
e627b17 to 246f1aa
gluonnlp/model/beam_search.py
Outdated
        .reshape((batch_size * beam_size,) + data.shape[1:])
    if not state_info:
        state_info = {'__layout__': 'NC'}
    batch_axis = state_info['__layout__'].find('N')
I think the following would be better:
if not state_info:
    batch_axis = 0
else:
    batch_axis = state_info['__layout__'].find('N')
tests/unittest/test_beam_search.py
Outdated
    return mx.nd.stack(*updated_states, axis=0)
    if not state_info:
        state_info = {'__layout__': 'NC'}
    batch_axis = state_info['__layout__'].find('N')
Same here:
if not state_info:
    batch_axis = 0
else:
    ...
Since it's in the test, the change is not necessary.
Job PR-175/13 is complete.
c6fcd77 to d79403b
Job PR-175/18 is complete.
* fix beam search script * add tests * address comment * fix * update test
Description
The following command was broken:
python beam_search_generator.py --bos I think --lm standard_lstm_lm_200
This PR fixes the issue by adding state_info to the language model while preserving existing behavior.
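To illustrate why layout information matters here, a minimal sketch using NumPy in place of MXNet NDArrays. The function name `expand_to_beam`, the shapes, and the use of `np.repeat` are assumptions for illustration, not the PR's implementation. It assumes an LSTM-style state with layout 'LNC' (layers, batch, hidden), where the batch is not the first axis.

```python
import numpy as np


def expand_to_beam(data, beam_size, state_info=None):
    """Repeat every batch entry beam_size times along the batch axis."""
    # Without layout info, fall back to the old behavior: batch on axis 0.
    batch_axis = state_info['__layout__'].find('N') if state_info else 0
    return np.repeat(data, beam_size, axis=batch_axis)


# Hypothetical LSTM state: 2 layers, batch size 3, hidden size 4.
state = np.zeros((2, 3, 4))
print(expand_to_beam(state, 5, {'__layout__': 'LNC'}).shape)  # (2, 15, 4)
print(expand_to_beam(state, 5).shape)  # (10, 3, 4): the layer axis is expanded instead
```

With the layout supplied, the batch axis grows from 3 to 15 as beam search expects. Without it, the layer axis is expanded by mistake, which is the kind of mismatch that layout-aware state handling avoids.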
Checklist
Essentials
Changes