Conversation
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
@@ -630,7 +630,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         return next_h, [next_h]


-class SequentialRNNCell(RecurrentCell):
+class SequentialRNNCell(HybridRecurrentCell):
This might break existing code where the member cells are not HybridRecurrentCells. Maybe have a separate class like HybridSequentialRNNCell?
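To make the compatibility concern concrete, here is a minimal pure-Python sketch. `Cell`, `HybridCell`, `SequentialCell`, and `HybridSequentialCell` are hypothetical stand-ins, not the MXNet classes; the sketch only assumes that a hybrid container rejects children that are not themselves hybrid.

```python
class Cell:
    """Stand-in for a plain (imperative-only) recurrent cell."""


class HybridCell(Cell):
    """Stand-in for a cell whose single-step graph can be hybridized."""


class SequentialCell(Cell):
    """Existing container: accepts any cell, so old code keeps working."""

    def __init__(self):
        self.children = []

    def add(self, cell):
        self.children.append(cell)


class HybridSequentialCell(HybridCell):
    """Separate hybrid container: only hybrid children are accepted."""

    def __init__(self):
        self.children = []

    def add(self, cell):
        # Assumed rule for this sketch: reject non-hybrid children up front.
        if not isinstance(cell, HybridCell):
            raise ValueError(
                "children of a hybrid container must be hybrid, got %s"
                % type(cell).__name__)
        self.children.append(cell)
```

Keeping the two containers separate means code that stacks non-hybrid cells in the existing container is unaffected, while the stricter check lives only in the new class.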
Could you add tests?
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
        super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)

    def hybrid_forward(self, F, x, *args, **kwargs):
        raise NotImplementedError
HybridSequentialRNNCell should have hybrid_forward implemented similarly to SequentialRNNCell.
Sorry for the confusion. The hybrid_forward method should have forward logic similar to what SequentialRNNCell has in its __call__.
@piiswrong do you remember why SequentialRNNCell has hybrid_forward?
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
-    def hybrid_forward(self, F, x, *args, **kwargs):
-        raise NotImplementedError
+    def hybrid_forward(self, F, *args, **kwargs):
+        super(HybridSequentialRNNCell, self).hybrid_forward(args, kwargs)
this class is an abstract base class and should not implement hybrid_forward.
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
@@ -79,9 +79,10 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
     assert length is None or len(inputs) == length
     if isinstance(inputs[0], symbol.Symbol):
         F = symbol
+        # TODO: batch_size cannot got here
symbol should be able to infer this just fine
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
        return _cells_begin_state(self._children.values(), **kwargs)

    def __call__(self, inputs, states):
        raise NotImplementedError("HybridSequentialRNN cannot be stepped. Please use unroll")
This should allow forward unless a bidirectional cell is registered, similar to SequentialRNNCell.
A hybrid cell means that the graph inside the cell (i.e. for a single step) can be hybridized.
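As a rough illustration of what a single step through a stacked cell involves, the sketch below threads one input through each child and hands every child its own slice of the flat state list. `ToyCell`, `state_count`, `step`, and `sequential_step` are hypothetical names for this example and are not part of the MXNet API.

```python
class ToyCell:
    """Hypothetical one-state cell: output = input + state, new state = output."""

    def state_count(self):
        return 1

    def step(self, inputs, states):
        out = inputs + states[0]
        return out, [out]


def sequential_step(cells, inputs, states):
    """Run one time step through a stack of cells.

    `states` is a flat list; each cell consumes the next `state_count()` entries.
    """
    next_states = []
    p = 0  # offset of the current cell's states in the flat list
    for cell in cells:
        n = cell.state_count()
        inputs, new_states = cell.step(inputs, states[p:p + n])
        p += n
        next_states.extend(new_states)
    return inputs, next_states


output, new_states = sequential_step([ToyCell(), ToyCell()], 1, [10, 100])
print(output, new_states)  # 111 [11, 111]
```

Because each step only touches one time index, nothing here depends on knowing the full sequence, which is why stepping is possible for stacks that contain no bidirectional cell.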
Ping @chinakook
@chinakook pinging again. It would be great if we could get this in. Would you address the review comments? Thanks.
@szha Is that all?
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
        next_states = []
        p = 0
        for cell in self._children.values():
            assert not isinstance(cell, BidirectionalCell)
Move the assertion outside the loop to fail fast: assert all(not isinstance(cell, BidirectionalCell) for cell in ...)
also, I'd suggest including a similar error message to what SequentialRNNCell step has.
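A minimal sketch of the two suggestions combined: the check runs once over the whole stack before any cell is stepped, and it carries an explanatory message. `BidirectionalCell` and `StepCell` below are hypothetical stand-ins defined only for this example, and the message wording is illustrative rather than the text used in MXNet.

```python
class BidirectionalCell:
    """Hypothetical stand-in for a cell that can only be unrolled."""


class StepCell:
    """Hypothetical stand-in for a cell that supports single-step calls."""


def check_steppable(cells):
    """Validate the whole stack up front instead of inside the step loop."""
    assert all(not isinstance(cell, BidirectionalCell) for cell in cells), \
        "Bidirectional cells cannot be stepped. Please use unroll."


check_steppable([StepCell(), StepCell()])  # passes silently
```

Checking before the loop avoids running some children and then failing partway through the stack.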
@chinakook otherwise looks good to me. Thanks.
@chinakook would you mind doing a rebase? We have a couple of flaky tests that are already fixed upstream.
Rebasing like this? I'm not very familiar with rebasing. Should I close this and open another PR?
suppose you have a git remote called "upstream" (which you can get by doing |
@@ -171,6 +171,54 @@ def test_stack():
     assert outs == [(10, 100), (10, 100), (10, 100)]


+def test_hybridstack():
Do we have a unit test for the NC layout bug you fixed?
        def hybrid_forward(self, F, x):
            return self.rnncell.unroll(3, x, layout="NTC", merge_outputs=True)

    x = mx.nd.random.uniform(shape=(10, 3, 100))
@chinakook could you split this into a list to verify your fix?
@chinakook thanks for the contribution. @eric-haibin-lin I will add the test.