
[JIT] Optimize before inlining #35562

Closed
wants to merge 11 commits

Conversation

eellison
Contributor

Resubmit of #35424, only this time I run optimizations in the right order so the PR description is actually true.

This speeds up the inlining pass of FairSeq model from 180s -> 13s, and MaskRCNN model from 5s -> 1.5s.
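A minimal, purely illustrative sketch of the idea (a toy list-based IR; none of these names are the actual JIT passes) showing why optimizing a callee once before inlining is cheaper than optimizing every inlined copy afterwards:

```python
def optimize(graph):
    # stand-in for peephole/dead-code-elimination: drop dead "nop" nodes
    return [node for node in graph if node != "nop"]

def inline(num_call_sites, callee):
    # inlining copies the callee body into every call site
    return [node for _ in range(num_call_sites) for node in callee]

callee = ["mul", "nop", "add", "nop"]

# Optimize before inlining: cleanup runs once, and only the 2 surviving
# nodes are copied into each of the 1000 call sites.
fast = inline(1000, optimize(callee))

# Inline before optimizing: 4 nodes are copied per site, then 2000 dead
# nodes must be visited and removed after the fact.
slow = optimize(inline(1000, callee))

assert fast == slow        # same final graph either way
assert len(fast) == 2000
```

The same asymmetry, at real-graph scale, is a plausible account of the speedups quoted above: the cost of copying and then cleaning redundant nodes grows with the number of call sites.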

@eellison eellison requested a review from ZolotukhinM March 27, 2020 19:43
@eellison eellison requested a review from apaszke as a code owner March 27, 2020 19:43
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Mar 27, 2020
@dr-ci

dr-ci bot commented Mar 27, 2020

💊 CircleCI build failures summary and remediations

As of commit c51bfc1 (more details on the Dr. CI page):


None of the build failures appear to be your fault 💚


  • 4/4 broken upstream at merge base f421cf3 on Apr 06 from 2:27pm to 6:52pm PDT (15 commits; 8ef82fc - 4c14005)

Please rebase on the viable/strict branch.

    If your commit is newer than viable/strict, you can try basing on an older, stable commit:

```shell
git fetch https://github.com/pytorch/pytorch viable/strict
git rebase --onto FETCH_HEAD $(git merge-base origin/master HEAD)
```

    If your commit is older than viable/strict:

```shell
git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD
```

    Check out the recency history of this "viable master" tracking branch.


🚧 4 upstream failures:

These were probably caused by upstream breakages:


This comment was automatically generated by Dr. CI.

This comment has been revised 88 times.


@ZolotukhinM ZolotukhinM left a comment


Looks good!

Contributor

@facebook-github-bot facebook-github-bot left a comment


@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


@BowenBao
Collaborator

BowenBao commented Mar 31, 2020

Edit:
Same error as #35401. @houseroad

```
    def test_dim(self):
        class DimModel(torch.jit.ScriptModule):
            @torch.jit.script_method
            def forward(self, input):
                out = input * 2
                out *= out.dim()
                return out
        empty_input = torch.randn(0, requires_grad=True)
        multi_dim_input = torch.randn(1, 2, 3, requires_grad=True)
        self.run_test(DimModel(), empty_input)
>       self.run_test(DimModel(), multi_dim_input)

test/onnx/test_pytorch_onnx_onnxruntime.py:2684:
test/onnx/test_pytorch_onnx_onnxruntime.py:126: in run_test
    _run_test(model)
test/onnx/test_pytorch_onnx_onnxruntime.py:122: in _run_test
    fixed_batch_size=fixed_batch_size)
test/onnx/test_pytorch_onnx_onnxruntime.py:86: in run_model_test
    ort_test_with_input(ort_sess, input_copy, output, rtol, atol)
test/onnx/test_pytorch_onnx_onnxruntime.py:47: in ort_test_with_input
    [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]

.0 = <zip object at 0x7f3d38393d48>

>   [np.testing.assert_allclose(out, ort_out, rtol=rtol, atol=atol) for out, ort_out in zip(outputs, ort_outs)]
E   AssertionError:
E   Not equal to tolerance rtol=0.001, atol=1e-07
E
E   Mismatched elements: 6 / 6 (100%)
E   Max absolute difference: 8.715158
E   Max relative difference: 0.6666667
E    x: array([[[ 3.081992, -0.586858, -4.357579],
E           [ 1.136863, -2.169045, -2.797191]]], dtype=float32)
E    y: array([[[  9.245976,  -1.760573, -13.072737],
E           [  3.410588,  -6.507134,  -8.391573]]], dtype=float32)
```

@eellison
Contributor Author

@BowenBao how do you recommend I proceed? As far as I can tell, the test is not failing on master.

@BowenBao
Collaborator

@eellison yeah, I checked a few other PRs and they don't seem to fail on this test, so I edited my previous comment. One thing that bothers me is that you mentioned in the other PR that you could not repro this error locally. Let me change the randn in the failing test_dim to ones; this should help debugging.

@eellison
Contributor Author

@BowenBao could you check out the PR ?

@eellison eellison force-pushed the optimize_inlining branch from af6ceaf to c7aabfb on April 3, 2020 16:49
Contributor

@facebook-github-bot facebook-github-bot left a comment


@eellison has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@eellison
Contributor Author

eellison commented Apr 3, 2020

@sanekmelnikov @J0Nreynolds I'm getting a TensorBoard test failure here. I'm having trouble running the test for unrelated reasons.


@facebook-github-bot
Contributor

@eellison merged this pull request in 6bc8ffe.

ashishfarmer pushed a commit to ashishfarmer/pytorch that referenced this pull request Apr 13, 2020
Summary:
Resubmit of pytorch#35424, only this time I run optimizations in the right order so the PR description is actually true.

This speeds up the inlining pass of FairSeq model from 180s -> 13s, and MaskRCNN model from 5s -> 1.5s.
Pull Request resolved: pytorch#35562

Differential Revision: D20738922

Pulled By: eellison

fbshipit-source-id: 1439cf9d1f0bc780e2d64a744694f8b3b7ba4b70
facebook-github-bot pushed a commit that referenced this pull request Apr 16, 2020
Summary:
With #35562, we are running peephole optimization on inlining to reduce the number of nodes that are copied.

The tracer encodes the sizes in the graph like:
```
graph(%0 : Double(7)):
  %1 : Function = prim::Constant[name="tensor_size"]()
  %2 : Tensor = prim::CallFunction(%1, %0)
  return (%2)
```

However, people would like to reuse the graph with different shapes, so running size optimizations would bake the traced shapes in and break that reuse. Long term it might be better for the tracer to not include shape information, but there are downstream users of it.

Separates out FuseAddMM from peephole so that there is now a single `disable_size_optimizations` parameter, and ONNX explicitly invokes FuseAddMM.
Pull Request resolved: #36404

Differential Revision: D20968974

Pulled By: eellison

fbshipit-source-id: 56f8f1699e3b0adeeccdfd5a67bb975fd41a2913
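The `disable_size_optimizations` behavior described in the commit above can be sketched with a toy peephole pass (purely illustrative; the IR and function below are hypothetical stand-ins, not the actual C++ pass):

```python
# Toy IR: a graph is a list of (op, arg) pairs.
def peephole(nodes, shapes, disable_size_optimizations=False):
    out = []
    for op, arg in nodes:
        if op == "size" and not disable_size_optimizations:
            # Fold x.size() to a constant. Fine for scripted graphs, but
            # it bakes the traced shape into the graph.
            out.append(("const", shapes[arg]))
        else:
            out.append((op, arg))
    return out

graph = [("size", "x"), ("add", "x")]

# With size optimizations on, the size call becomes a constant:
assert peephole(graph, {"x": (7,)}) == [("const", (7,)), ("add", "x")]

# With them disabled (e.g. for traced graphs meant to be reused with
# different input shapes), the graph is left untouched:
assert peephole(graph, {"x": (7,)}, disable_size_optimizations=True) == graph
```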
Labels
Merged oncall: jit Add this issue/PR to JIT oncall triage queue
5 participants