Fix MSE forward, use decomposition for MSE backward #860

Merged: 5 commits into main from fix_mse_backward on Jun 9, 2022

Conversation

samdow (Contributor) commented Jun 7, 2022

Because the MSE backward batching rule was redispatching, it ended up with shape errors when it redispatched twice. This PR switches it to use the decomposition from core (see the sketch after this list) so that we don't have to guess, from the size of the input, whether a call came from a redispatch or an original call. Specifically:

  • when reduction != none, the original output would be of size [N]. We would pad the grad_output with 1s if needed and then redispatch (with reduction == none). This led to issues because reduction == none assumes that grad_output and self have the same number of elements
  • when reduction == none, the grad_output would often be flattened (it's a little unclear why this was happening). This caused reshape errors. If we tried to fix it by just reshaping grad_output to match self, the redispatch would then error because grad_output wouldn't have enough elements

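For context, the core decomposition is just a few elementwise ops, so vmap can batch it without a hand-written rule (which is also why dropping the batching rule should cost little perf). A minimal sketch of what the decomposition computes (an illustration of the standard reduction semantics, not the exact core source):

```python
import torch

def mse_loss_backward_sketch(grad_output, input, target, reduction="mean"):
    # d/d(input) of (input - target)^2 is 2 * (input - target); for
    # reduction == "mean" the loss also divides by input.numel().
    norm = 2.0 / input.numel() if reduction == "mean" else 2.0
    return norm * (input - target) * grad_output
```
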
There were also issues in the forward rule: if reduction was none, the output would be flattened when it shouldn't be. Specifically, if you had

input: [B, in_0, in_1, ...]
target: [B, in_0, in_1, ...]
output (should be): [B, in_0, in_1, ...]
output (produced): [B, in_0 * in_1 * ...]

This was not caught because all OpInfo tests with a none reduction only had 1D inputs. This PR adds new tests and fixes the forward and backward formulas.

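A minimal sketch of the expected reduction == none behavior under vmap (shapes as above; a repro-style illustration, not one of the new OpInfo tests):

```python
import torch
import torch.nn.functional as F
from functorch import vmap

B, in_0, in_1 = 4, 3, 5
input = torch.randn(B, in_0, in_1)
target = torch.randn(B, in_0, in_1)

# Each per-example output keeps its [in_0, in_1] shape, so the batched
# result should be [B, in_0, in_1], not [B, in_0 * in_1].
out = vmap(lambda x, t: F.mse_loss(x, t, reduction="none"))(input, target)
assert out.shape == (B, in_0, in_1)
```
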
Fixes #858

Chillee (Contributor) commented Jun 8, 2022

Can we just directly use the decomposition instead of writing a batching rule for it? It doesn't look like the batching rule is doing any fusion of the ops, so I don't think the perf will be significantly different.

samdow (Contributor, Author) commented Jun 8, 2022

> Can we just directly use the decomposition instead of writing a batching rule for it?

Sure. The issue right now is that the decomposition in core is wrong in the case where reduction=none because of some shaping issues; it will end up with the same errors theseus was seeing. So I'll need to add a test for that case in core, fix it, and then set up the infra here to use it. Because we're coming up on the release, I'm thinking I'll merge this so that we have a fix that will definitely be in the release, and then update the decomposition afterwards.

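For reference, the failing pattern from #858 looked roughly like this (a sketch reconstructed from the issue title; the exact repro may differ):

```python
import torch
import torch.nn.functional as F
from functorch import jacrev, vjp

x = torch.randn(3, 5)
t = torch.randn(3, 5)

def loss(inp):
    return F.mse_loss(inp, t, reduction="none")

# Differentiating through the backward a second time made the old
# batching rule redispatch twice, triggering the shape errors.
_, vjp_fn = vjp(loss, x)
jac = jacrev(lambda v: vjp_fn(v)[0])(torch.randn(3, 5))
```
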
samdow force-pushed the fix_mse_backward branch 2 times, most recently from ed21e1f to fe8dd85 on June 9, 2022 00:03
samdow changed the title from "Use decomposition for MSE backward" to "Fix MSE Foward, use decomposition for MSE backward" on Jun 9, 2022
samdow changed the title from "Fix MSE Foward, use decomposition for MSE backward" to "Fix MSE foward, use decomposition for MSE backward" on Jun 9, 2022
samdow merged commit 26d4cfc into main on Jun 9, 2022
samdow mentioned this pull request on Jun 9, 2022
zou3519 pushed a commit that referenced this pull request Jun 13, 2022
* use decomposition for mse backward

* only reshape if there was no reduction

* add tests, fix shape of mse loss forward

* remove mse xfail

* simplify backwards rule
zou3519 added a commit that referenced this pull request Jun 15, 2022
* Fix MSE forward, use decomposition for MSE backward (#860)

* use decomposition for mse backward

* only reshape if there was no reduction

* add tests, fix shape of mse loss forward

* remove mse xfail

* simplify backwards rule

* [release] fixup previous commit

* Remove test/test_functorch_lagging_op_db.py (#845)

These tests are expected to fail, but we didn't communicate that very
well and:
1. we have gotten multiple questions about them
2. we need to special case it in our CI
3. we don't even use the test anymore!

So we are deleting it.

Related: #835

* Disable calling Tensor.requires_grad_() inside a functorch transform (#849)

Fixes #847

We do not allow users to call requires_grad_() inside a functorch
transform. This is because the user is effectively saying
"hey, I want another layer of autograd if I call requires_grad_()", but
that doesn't actually work because to set up a layer of autograd we need
to do some work (e.g. push autograd onto the DynamicLayerStack).

Instead, when a user calls requires_grad_() (and similarly retain_grad),
we raise a nice error message.

This has the intended consequence of causing
torch.autograd.functional.{jvp, vjp, jacobian} to error out when called
inside a functorch transform. Users should use the functorch
equivalents.

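A minimal sketch of the now-disallowed pattern (the error type is assumed here to be a RuntimeError; the exact message is functorch's):

```python
import torch
from functorch import grad

def f(x):
    y = x * 2
    # Inside a transform this now raises instead of silently failing
    # to set up a new layer of autograd.
    y.requires_grad_()
    return y.sum()

grad(f)(torch.randn(3))  # raises, pointing users at the functorch APIs
```
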
Test Plan:
- added tests

* "set_inplace_requires_grad_allowed" should be a context manager (#870)

Test Plan:
- run existing tests; code reading
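
A generic sketch of why a context manager is the right shape for this guard (the flag and names below are hypothetical stand-ins; the real guard lives in functorch's internals):

```python
import contextlib

# Hypothetical module-level flag standing in for the real guard.
_inplace_requires_grad_allowed = False

@contextlib.contextmanager
def set_inplace_requires_grad_allowed(allowed):
    # Save, set, and always restore the flag, even if the body raises;
    # a bare set/unset pair would leak the new value on an exception.
    global _inplace_requires_grad_allowed
    prev = _inplace_requires_grad_allowed
    _inplace_requires_grad_allowed = allowed
    try:
        yield
    finally:
        _inplace_requires_grad_allowed = prev
```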

* Fix index.Tensor, index_put batching rules (#862)

Fixes #859

Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]`
in the code for details. This PR rewrites the index.Tensor and index_put
batching rules.

The TL;DR is:
- advanced indexing behaves differently depending on whether the "advanced
indices are adjacent" (illustrated in the sketch after this list):
https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
- we have to take this into account in our batching rules, because
index.Tensor and index_put handle these internally.

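A small sketch of the adjacency rule the TL;DR refers to (this is standard NumPy/PyTorch indexing behavior, not functorch-specific code):

```python
import torch

x = torch.randn(2, 3, 4, 5)
i = torch.tensor([0, 1])  # advanced index, shape [2]
j = torch.tensor([1, 2])  # advanced index, shape [2]

# Adjacent advanced indices: the broadcast index dims stay in place.
print(x[:, i, j].shape)  # torch.Size([2, 2, 5])

# Non-adjacent advanced indices (separated by a slice): the broadcast
# index dims move to the front of the result.
print(x[i, :, j].shape)  # torch.Size([2, 3, 5])
```
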
Test Plan:
- I added new test cases for getitem and aten.ops.index_put via OpInfo
testing.

Future
- primtorch should have a sane decomposition that we can use
- We haven't fixed the index_put_ batching rule yet. TODO later...
- Upstream our test cases (see next section) into pytorch/pytorch

Co-authored-by: Samantha Andow <samdow@fb.com>
zou3519 pushed a commit that referenced this pull request Jun 15, 2022
zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
Fix MSE forward, use decomposition for MSE backward (pytorch/functorch#860)

bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
Fix MSE forward, use decomposition for MSE backward (pytorch/functorch#860)

Successfully merging this pull request may close these issues: jacrev(vjp(mse_loss)) raises an error (#858)