Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Torch] Support Python list, more realistic recurrent networks #5306

Merged
merged 31 commits into from
Apr 13, 2020

Conversation

masahi
Copy link
Member

@masahi masahi commented Apr 11, 2020

This PR builds on the control flow support added in #4964 and aims to support more realistic recurrent networks than the simple one in #4964. Specifically the goal is to enable translating LSTM models in the PyTorch repo https://github.com/pytorch/pytorch/tree/master/benchmarks/fastrnns described in their blog post.

Translating these models requires taking care of dynamic lists and tensor shape. I added necessary support in the Torch frontend using prelude List ADT, static Tensor array #5103, and Any. Previously we can only translate "tensors in, tensors out" type of models, but now we can ingest more complex inputs such as list of tuple of tensors.

See the new test cases for the kinds of models we can support now. I added some variants of LSTMs:

  • LSTM with layer normalization
  • Bidirectional
  • Stacked
  • Stacked and Bidirectional

The result of translating three layer stacked bidirectional LSTMs is dumped here. Even though this model has three loop nest,

for layer in num_layers:
    for dir in ["forward", "backward"]:
        seq_len = input.size(0)
        for i in seq_len:
              ...

the two outer loops are unrolled by Torchscript. So in the Relay IR dump, there are 3 (layers) x 2 (direction) while loops.

please review @kevinthesun @zhiics @MarisaKirisame @icemelon9 @jwfromm @wweic @alexwong
cc @tqchen @jroesch @ajtulloch @junrushao1994

Copy link
Contributor

@MarisaKirisame MarisaKirisame left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Is torchscript list immutable or mutable like python's list?

@masahi
Copy link
Member Author

masahi commented Apr 11, 2020

LGTM. Is torchscript list immutable or mutable like python's list?

Yes it is mutable. List append is mapped to aten::append in Torchscript and it is entirely side-effecting operation. See below for a simple module that only does list append and its Torchscript representation.

Even though variables that are updated in the loop are supposed to be passed to prim::Loop op to become loop variables, this does not apply to side effecting operations like list append. Translating this module to Relay is complicated because also in Relay loop variables are the only ones that can be updated between iteration. If we naively translate it, the list outputs.1 below appears free in the loop and can not be updated.

class ListAppend(nn.Module):
    def forward(self, input):
        # type: (Tensor) -> List[Tensor]
        outputs = []
        for i in range(input.size(0)):
            outputs.append(input)
        return outputs
graph(%self : __torch__.ListAppend,
      %input.1 : Tensor):
  %8 : bool = prim::Constant[value=1]() # rnn_test.py:142:8
  %4 : int = prim::Constant[value=0]() # rnn_test.py:142:34
  %outputs.1 : Tensor[] = prim::ListConstruct()
  %5 : int = aten::size(%input.1, %4) # rnn_test.py:142:23
   = prim::Loop(%5, %8) # rnn_test.py:142:8
    block0(%i : int):
      %12 : Tensor[] = aten::append(%outputs.1, %input.1) # rnn_test.py:143:12
      -> (%8)
  return (%outputs.1)

To workaround the difficulty of list append, I use list concat to append one element at the tail of a list. The original LSTM models in Pytorch repo do not use list append either and use concat instead, probably for the same reason.

@MarisaKirisame
Copy link
Contributor

From an outsider, it seems like the more principled approach is to translate list to Reference of List. We could then write passses to remove Reference if possible.

Copy link
Contributor

@kevinthesun kevinthesun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall lgtm. Some minor comments regarding to registering static tensor array ops.

python/tvm/relay/frontend/pytorch.py Show resolved Hide resolved
python/tvm/relay/frontend/pytorch.py Outdated Show resolved Hide resolved
python/tvm/relay/frontend/pytorch.py Show resolved Hide resolved
@masahi
Copy link
Member Author

masahi commented Apr 13, 2020

@kevinthesun Thanks for the review! Please have a look at the last commit.

Copy link
Contributor

@kevinthesun kevinthesun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kevinthesun kevinthesun merged commit 0145cd5 into apache:master Apr 13, 2020
@kevinthesun
Copy link
Contributor

Thanks @masahi @MarisaKirisame

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
…e#5306)

* use funcs from prelude, pass around convert_map

* get relay input type from user ishape

* handle tuple unpack

* experimenting with static tensor array

* use prelude concat instead of cons + rev

* minor clean up

* fix layer norm conversion bug, unwrap tensor array

* add infer shape on tensor array

* pass around prelude for now

* compile worked but runtime error

* fix tensor array wrapping

* begin list dynamic test

* is_list_dynamic first version

* finish dynamic list test

* a few fix

* use shape_of function if Any is found

* improve size conversion

* working on adding free vars to loop block

* fixed inlined inner loop issue

* clean up free var handling

* add support for tensor array concat

* adding ta concat on last axis

* fix concat, but got runtime error

* disable concat on axis -1 for now

* add lstm tests

* revert unrelated change

* fix stacked bidir test

* minor fix to test

* relax tol a bit, revert dnnl change to avoid conflict

* simplify infer type, use input tensor shape rather than concat shape

* more shape fix
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
…e#5306)

* use funcs from prelude, pass around convert_map

* get relay input type from user ishape

* handle tuple unpack

* experimenting with static tensor array

* use prelude concat instead of cons + rev

* minor clean up

* fix layer norm conversion bug, unwrap tensor array

* add infer shape on tensor array

* pass around prelude for now

* compile worked but runtime error

* fix tensor array wrapping

* begin list dynamic test

* is_list_dynamic first version

* finish dynamic list test

* a few fix

* use shape_of function if Any is found

* improve size conversion

* working on adding free vars to loop block

* fixed inlined inner loop issue

* clean up free var handling

* add support for tensor array concat

* adding ta concat on last axis

* fix concat, but got runtime error

* disable concat on axis -1 for now

* add lstm tests

* revert unrelated change

* fix stacked bidir test

* minor fix to test

* relax tol a bit, revert dnnl change to avoid conflict

* simplify infer type, use input tensor shape rather than concat shape

* more shape fix
dpankratz pushed a commit to dpankratz/incubator-tvm that referenced this pull request Apr 24, 2020
…e#5306)

* use funcs from prelude, pass around convert_map

* get relay input type from user ishape

* handle tuple unpack

* experimenting with static tensor array

* use prelude concat instead of cons + rev

* minor clean up

* fix layer norm conversion bug, unwrap tensor array

* add infer shape on tensor array

* pass around prelude for now

* compile worked but runtime error

* fix tensor array wrapping

* begin list dynamic test

* is_list_dynamic first version

* finish dynamic list test

* a few fix

* use shape_of function if Any is found

* improve size conversion

* working on adding free vars to loop block

* fixed inlined inner loop issue

* clean up free var handling

* add support for tensor array concat

* adding ta concat on last axis

* fix concat, but got runtime error

* disable concat on axis -1 for now

* add lstm tests

* revert unrelated change

* fix stacked bidir test

* minor fix to test

* relax tol a bit, revert dnnl change to avoid conflict

* simplify infer type, use input tensor shape rather than concat shape

* more shape fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants