Fix MSE forward, use decomposition for MSE backward #860
Conversation
Can we just directly use the decomposition instead of writing a batching rule for it? It doesn't look like the batching rule is doing any fusion of the ops, so I don't think the perf will be significantly different.
Sure. The issue right now is that the decomposition in core is wrong in the case where reduction=none because of some shaping issues; it ends up with the same errors theseus was seeing. So I'll need to add a test for that case in core, fix it, and then set up the infrastructure here to use it. Because we're coming up on the release, I'm thinking I'll merge this so that we have a fix that will definitely be in the release, and then update the decomposition afterwards.
* use decomposition for mse backward
* only reshape if there was no reduction
* add tests, fix shape of mse loss forward
* remove mse xfail
* simplify backwards rule
* Fix MSE forward, use decomposition for MSE backward (#860)
  * use decomposition for mse backward
  * only reshape if there was no reduction
  * add tests, fix shape of mse loss forward
  * remove mse xfail
  * simplify backwards rule
* [release] fixup previous commit
* Remove test/test_functorch_lagging_op_db.py (#845)
  These tests are expected to fail, but we didn't communicate that very well and:
  1. we have gotten multiple questions about them
  2. we need to special-case it in our CI
  3. we don't even use the test anymore!
  So we are deleting it. Related: #835
* Disable calling Tensor.requires_grad_() inside a functorch transform (#849)
  Fixes #847. We do not allow users to call requires_grad_() inside a functorch transform. This is because the user is effectively saying "hey, I want another layer of autograd if I call requires_grad_()", but that doesn't actually work: to set up a layer of autograd we need to do some work (e.g. push autograd onto the DynamicLayerStack). Instead, when a user calls requires_grad_() (and similarly retain_grad), we raise a nice error message. This has the intended consequence of causing torch.autograd.functional.{jvp, vjp, jacobian} to error out when called inside of a functorch transform. Users should use the functorch equivalent.
  Test Plan:
  - added tests
* "set_inplace_requires_grad_allowed" should be a context manager (#870)
  Test Plan:
  - run existing tests; code reading
* Fix index.Tensor, index_put batching rules (#862)
  Fixes #859. Start reading at `NOTE: [advanced indexing (index.Tensor) batch rule]` in the code for details. This PR rewrites the index.Tensor and index_put batching rules. The TL;DR is:
  - advanced indexing has different behavior depending on whether the "advanced indices are adjacent": https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
  - we have to take this into account in our batching rules, because index.Tensor and index_put handle these internally.
  Test Plan:
  - added new test cases for getitem and aten.ops.index_put via OpInfo testing
  Future:
  - primtorch should have a sane decomposition that we can use
  - we haven't fixed the index_put_ batching rule yet. TODO later...
  - upstream our test cases (see next section) into pytorch/pytorch

Co-authored-by: Samantha Andow <samdow@fb.com>
Because MSE backward was redispatching, it ended up with shape errors if it redispatched twice. This switches it to use the decomposition from core, so we no longer have to guess from the size of the input whether the call came from a redispatch or an original call.
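For reference, this is roughly what the decomposition computes. A minimal sketch assuming the standard MSE gradient formula; the function name and exact handling are illustrative, not the code in core:

```python
import torch

def mse_loss_backward_decomp(grad_output, input, target, reduction="mean"):
    # d/d(input) of (input - target)^2 is 2 * (input - target).
    grad = 2.0 * (input - target)
    if reduction == "mean":
        grad = grad / input.numel()
    # With reduction="none", grad_output already has the same shape as input,
    # so the multiply broadcasts exactly and there is nothing to guess about
    # whether the call came from a redispatch or an original call.
    return grad * grad_output
```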
There were also some issues in the forward rule: if reduction was none, the output would be flattened where it shouldn't be. Specifically, this showed up when the input had more than one dimension.
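As an illustrative repro (the shapes and vmap usage here are assumptions, not the exact case from the original report): with reduction='none' the loss is elementwise, so under vmap each per-sample output must keep the input's shape rather than coming back flattened.

```python
import torch
import torch.nn.functional as F
from functorch import vmap

x = torch.randn(3, 4, 5)   # a batch of 3 inputs, each of shape (4, 5)
t = torch.randn(3, 4, 5)

# reduction='none' means no reduction is applied, so each per-sample output
# should have shape (4, 5); the buggy forward rule flattened it instead.
out = vmap(lambda a, b: F.mse_loss(a, b, reduction='none'))(x, t)
assert out.shape == (3, 4, 5)
```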
This was not caught because all of the OpInfo tests with a none reduction only had 1D inputs. This PR adds new tests and fixes the forward and backward formulas.
Fixes #858