[Torch] Support Python list, more realistic recurrent networks #5306
Conversation
LGTM. Is a TorchScript list immutable, or mutable like Python's list?
Yes, it is mutable. Even though variables that are updated inside a loop are supposed to be passed to the loop explicitly, TorchScript lets you append to a list declared outside the loop:

```python
import torch
from torch import nn, Tensor
from typing import List

class ListAppend(nn.Module):
    def forward(self, input):
        # type: (Tensor) -> List[Tensor]
        outputs = []
        for i in range(input.size(0)):
            outputs.append(input)
        return outputs
```
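For reference, a minimal check (assuming the class above) that TorchScript accepts this pattern:

```python
scripted = torch.jit.script(ListAppend())
out = scripted(torch.randn(3, 4))
assert len(out) == 3  # one element appended per row of the input
```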
To work around the difficulty of list append, I use list concat to append one element at the tail of a list. The original LSTM models in the PyTorch repo do not use list append either; they use concat instead, probably for the same reason.
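A minimal sketch of the concat idiom (`ListConcat` is a hypothetical name, not from this PR): instead of mutating the list in place, each iteration rebinds it to the concatenation of the old list and a one-element list, which maps naturally onto the prelude concat.

```python
import torch
from torch import nn, Tensor
from typing import List

class ListConcat(nn.Module):  # hypothetical illustration of the concat idiom
    def forward(self, input):
        # type: (Tensor) -> List[Tensor]
        outputs = torch.jit.annotate(List[Tensor], [])
        for i in range(input.size(0)):
            outputs = outputs + [input]  # concat instead of append
        return outputs
```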
From an outsider's perspective, it seems like the more principled approach is to translate a list to a Reference of List. We could then write passes to remove the Reference where possible.
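A rough sketch of what that Reference of List idea could look like when building Relay by hand, assuming the Prelude exposes its `nil`/`cons` constructors as attributes (true for the TVM version around this PR; the exact API may differ): a mutable list becomes a reference cell holding an immutable list, and an update becomes read-cons-write, which a later pass could then try to eliminate.

```python
import tvm
from tvm import relay
from tvm.relay.prelude import Prelude

mod = tvm.IRModule()
p = Prelude(mod)  # brings the List ADT (nil/cons) into the module

x = relay.var("x")       # element to add
r = relay.var("r")       # reference cell holding the current list
unused = relay.var("_")

# update as read-cons-write: r := cons(x, !r), then return the new list.
# cons prepends at the head for brevity; a faithful append would
# concat at the tail instead.
update_x = relay.Let(
    r, relay.RefCreate(p.nil()),
    relay.Let(
        unused, relay.RefWrite(r, p.cons(x, relay.RefRead(r))),
        relay.RefRead(r),
    ),
)
```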
Force-pushed from f1dedd4 to 8d56f33
Overall LGTM. Some minor comments regarding registering static tensor array ops.
Force-pushed from 8d56f33 to 3074c9a
@kevinthesun Thanks for the review! Please have a look at the last commit.
LGTM
Thanks @masahi @MarisaKirisame
…e#5306)

* use funcs from prelude, pass around convert_map
* get relay input type from user ishape
* handle tuple unpack
* experimenting with static tensor array
* use prelude concat instead of cons + rev
* minor clean up
* fix layer norm conversion bug, unwrap tensor array
* add infer shape on tensor array
* pass around prelude for now
* compile worked but runtime error
* fix tensor array wrapping
* begin list dynamic test
* is_list_dynamic first version
* finish dynamic list test
* a few fix
* use shape_of function if Any is found
* improve size conversion
* working on adding free vars to loop block
* fixed inlined inner loop issue
* clean up free var handling
* add support for tensor array concat
* adding ta concat on last axis
* fix concat, but got runtime error
* disable concat on axis -1 for now
* add lstm tests
* revert unrelated change
* fix stacked bidir test
* minor fix to test
* relax tol a bit, revert dnnl change to avoid conflict
* simplify infer type, use input tensor shape rather than concat shape
* more shape fix
This PR builds on the control flow support added in #4964 and aims to support more realistic recurrent networks than the simple one there. Specifically, the goal is to enable translating the LSTM models in the PyTorch repo (https://github.com/pytorch/pytorch/tree/master/benchmarks/fastrnns) described in their blog post.
Translating these models requires handling dynamic lists and dynamic tensor shapes. I added the necessary support to the Torch frontend using the Prelude List ADT, the static tensor array from #5103, and Any. Previously we could only translate "tensors in, tensors out" models, but now we can ingest more complex inputs such as a list of tuples of tensors.
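As a hedged usage sketch (the toy model and shapes below are illustrative, not from the PR's test suite), conversion goes through the usual frontend entry point:

```python
import torch
from torch import nn, Tensor
from typing import List
import tvm
from tvm import relay

class ToyRNN(nn.Module):  # hypothetical stand-in for the PR's LSTM tests
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, input):
        # type: (Tensor) -> List[Tensor]
        outputs = torch.jit.annotate(List[Tensor], [])
        h = torch.zeros(1, 16)
        for i in range(input.size(0)):          # becomes a Relay while loop
            h = torch.tanh(self.linear(input[i]) + h)
            outputs = outputs + [h]             # list output, concat-based
        return outputs

model = torch.jit.script(ToyRNN())
inp = torch.randn(5, 1, 16)  # (seq_len, batch, feature)
mod, params = relay.frontend.from_pytorch(model, [("input", inp.shape)])
```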
See the new test cases for the kinds of models we can support now. I added several LSTM variants.
The result of translating a three-layer stacked bidirectional LSTM is dumped here. Even though this model has three nested loops, the two outer loops are unrolled by TorchScript, so the Relay IR dump contains 3 (layers) × 2 (directions) while loops.
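A hypothetical illustration (not the PR's test code) of why the outer loops disappear: TorchScript unrolls iteration over an `nn.ModuleList`, since each layer is a distinct submodule, so only the loop inside each layer survives as a real loop in the IR.

```python
import torch
from torch import nn, Tensor

class StackedNet(nn.Module):  # hypothetical example
    def __init__(self, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(num_layers))

    def forward(self, x):
        # type: (Tensor) -> Tensor
        for layer in self.layers:  # unrolled at script time, one call per layer
            x = torch.tanh(layer(x))
        return x

scripted = torch.jit.script(StackedNet(3))
# scripted.graph shows three inlined layer applications, no layer loop
```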
Please review @kevinthesun @zhiics @MarisaKirisame @icemelon9 @jwfromm @wweic @alexwong
cc @tqchen @jroesch @ajtulloch @junrushao1994