
[relay][frontend] aten::copy_ support for pytorch #15502

Merged · 13 commits merged into apache:main · Oct 19, 2023

Conversation

jhlee525 (Contributor) commented Aug 7, 2023

Although #9375 was rejected, I tried a different way to support the aten::copy_ op.

aten::copy_ behaves differently from other in-place ops: it is in-place in a "pure" sense. Other in-place nodes still relay their output to downstream users in the output graph (torch.Graph), so a DAG can be constructed from them. aten::copy_, however, writes into its first input and returns it, and nothing in the graph consumes that output, so all of its mutations dangle.

For example, a torch module like

class Test(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        x[:5, :5] = x[:5, :5] + 1
        return x

generates the graph

graph(%self : __torch__.Test,
      %x : Float(10, 10, strides=[10, 1], requires_grad=0, device=cpu)):
  %4 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %5 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %6 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %7 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %8 : Float(5, 10, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%x, %4, %5, %6, %7)
  %9 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %10 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %11 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %12 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %13 : Float(5, 5, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%8, %9, %10, %11, %12)
  %14 : Long(requires_grad=0, device=cpu) = prim::Constant[value={1}]()
  %15 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %16 : Float(5, 5, strides=[5, 1], requires_grad=0, device=cpu) = aten::add(%13, %14, %15)
  %17 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %18 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %19 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %20 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %21 : Float(5, 10, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%x, %17, %18, %19, %20)
  %22 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %23 : int = prim::Constant[value=0]() # /home/jhlee/tvm/test.py:10:0
  %24 : int = prim::Constant[value=5]() # /home/jhlee/tvm/test.py:10:0
  %25 : int = prim::Constant[value=1]() # /home/jhlee/tvm/test.py:10:0
  %26 : Float(5, 5, strides=[10, 1], requires_grad=0, device=cpu) = aten::slice(%21, %22, %23, %24, %25)
  %27 : bool = prim::Constant[value=0]()
  %28 : Float(5, 5, strides=[10, 1], requires_grad=0, device=cpu) = aten::copy_(%26, %16, %27)
  return (%x)

which returns %x itself. Note that %26 is a strided view into %x (produced by the two aten::slice ops at %21 and %26), so the aten::copy_ at %28 mutates %x, yet %28 is never consumed.

My approach to handling this problem is:

  1. In from_pytorch, after _run_jit_passes is called, insert a torch-level (torch.Graph) pass, _redirect_inplace_output, that redirects the output of aten::copy_ so the mutation becomes a visible edge in the DAG (see the sketch after this list).
  2. When converting an aten::copy_ node, walk up through its parents to collect the aten::select and aten::slice nodes, and generate the indices of the destination from them. I referenced the behavior of the torch -> ONNX conversion in the PyTorch repository for this.
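
To make both steps concrete, here is a minimal sketch of the idea (not the exact code in this PR; collect_slice_chain is a hypothetical helper name, and the sketch assumes torch._C exposes Value.replaceAllUsesAfterNodeWith, as its C++ IR does):

import torch

def redirect_inplace_output(graph: torch._C.Graph):
    # Step 1 (sketch): for every aten::copy_, rewire later uses of the
    # mutated destination to the copy_ output, so the write becomes an
    # edge in the DAG instead of a dangling node.
    for node in graph.nodes():
        if node.kind() != "aten::copy_":
            continue
        dst = list(node.inputs())[0]  # destination view being written into
        dst.replaceAllUsesAfterNodeWith(node, node.output())

def collect_slice_chain(copy_node):
    # Step 2 (sketch): walk the destination's producer chain upward and
    # gather the aten::slice / aten::select nodes that locate the region
    # of the base tensor that the copy_ writes into.
    chain = []
    producer = list(copy_node.inputs())[0].node()
    while producer.kind() in ("aten::slice", "aten::select"):
        chain.append(producer)
        producer = list(producer.inputs())[0].node()
    return list(reversed(chain))  # outermost (base-tensor) slice first

In the example graph above, collect_slice_chain applied to the copy_ at %28 would return the slice nodes at %21 and %26, from which the converter can recover that the write targets x[:5, :5].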

This is my first PR to this repository, so please let me know if you have any feedback or questions.

tvm-bot (Collaborator) commented Aug 7, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

jhlee525 changed the title from "aten::copy_ support in pytorch frontend" to "[relay][frontend]aten::copy_ support in pytorch frontend" on Aug 7, 2023
jhlee525 changed the title from "[relay][frontend]aten::copy_ support in pytorch frontend" to "[relay][frontend] aten::copy_ support for pytorch" on Aug 7, 2023
jhlee525 (Contributor, Author) commented:

  1. Can this PR be merged to main, with respect to the design of the Relay architecture?
  2. The test fails on the CI server, but it is hard to tell why. Locally, my test case passes without any failure. I tried to figure out what happened on the CI server, but the log doesn't show enough detail; it is probably a segmentation fault. Could anybody give me advice on how to debug this?

rebel-jhlee (Contributor) commented:

@masahi

masahi (Member) commented Aug 17, 2023

You've got a segfault from your test.

rebel-jhlee (Contributor) commented:

@masahi It's ready for review. Sorry for the late response.

@@ -4470,6 +4558,26 @@ def _run_jit_passes(graph, enable_lower_all_tuples=True):
torch._C._jit_pass_lower_all_tuples(graph)


def _redirect_inplace_output(graph):
Review comment (Member):

Please give an example of what this pass does, by documenting IR before / after this pass.

Reply (Contributor Author):

OK, an example has been added.

return x

inputs = torch.randn(10, 10)
verify_model(InplaceCopy(), [inputs])
Review comment (Member):

Please add more tests, using various tricky examples to make sure that the conversion works.

Reply (Contributor Author):

I added more tests, trying to cover various cases. Please let me know if you have any suggestions.
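
For reference, the tricky cases worth covering look like the following (hand-written illustrations in the style of the test above, using the same verify_model helper; the exact tests added in this PR may differ):

class InplaceCopyRow(torch.nn.Module):
    def forward(self, x):
        x[0] = x[1] * 2  # aten::select followed by aten::copy_
        return x

class InplaceCopyChained(torch.nn.Module):
    def forward(self, x):
        x[2:8][1:3] = x[:2] + 1  # chained views before the copy_
        return x

verify_model(InplaceCopyRow(), [torch.randn(10, 10)])
verify_model(InplaceCopyChained(), [torch.randn(10, 10)])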

Reply (Member):

thanks

jhlee525 requested a review from masahi on October 18, 2023
masahi (Member) left a review comment:

Ok, let's try this approach for our first copy_ support.

masahi merged commit b7aada1 into apache:main on Oct 19, 2023