Skip to content

Conversation

PaliC
Copy link
Contributor

@PaliC PaliC commented Oct 12, 2025

Summary:

Here we add correctness tests for backwards passes of ops.

This PR does the following things

  1. Figures out which ops not to test. (explained in depth at the top of BackendBench/backwards_utils.py + avoiding inplace ops) For simplcity we are not testing a) in place ops as we cannot just pass in the test args, but need special casing b) ops that require special handling with their args, c) one off corner cases. Every other

  2. To do backwards passes (since the tensors naturally don't require grad in our suites), right now we add a gradient to all tensors in args and kwargs. This logic (+ test for if we should even run a backwards pass) is put in the suite as this can be handled on a per test level. For example in a follow up PR for this, we can add a backwards pass column in the torchbench dataset.

  3. We also compare gradients and clear gradients after use to validate the backwards pass. We use the same allclose function as before. Note we don't copy tensors/args as sometimes they are views (at least in opinfo) which makes cloning difficult.

  4. There are also a bunch of unit tests added to make sure the gradient checking utils work as expected.

Test Plan:

With this really slow correctish mm implementation we get

uv run python BackendBench/scripts/main.py --suite torchbench --topn 1  --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 1.00
performance score (geomean speedup over all operators): 0.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 1.00

With the bad monkey patched implementation we get

uv run python BackendBench/scripts/main.py --suite torchbench --topn 1  --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 0.00
performance score (geomean speedup over all operators): 1.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 0.00

The following two commands with aten also work as expected (100% correctness on forwards and backwards)

``uv run python BackendBench/scripts/main.py --suite opinfo --backend aten --check-backwards``
`uv run python BackendBench/scripts/main.py --suite torchbench --topn 2 --backend aten --check-backwards`

Todo:

  • rename is_correct -> correct_output (originally in this pr but added noise for reviewers)
  • performance tests
  • for torchbench suite put backwards checking in dataset
  • Assuming the above support ops which have conditions on their args
  • support inplace ops

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 12, 2025
@PaliC PaliC force-pushed the pr191 branch 14 times, most recently from 0f3f721 to c39c753 Compare October 16, 2025 01:03
@PaliC PaliC marked this pull request as ready for review October 16, 2025 01:04
@PaliC PaliC changed the title [WIP] Add testing for backwards passes Add testing for backwards passes Oct 16, 2025
Summary:

Here we add correctness tests for backwards passes of ops.

This PR does the following things
1) Figures out which ops not to test. (explained in depth at the top of BackendBench/backwards_utils.py + avoiding inplace ops) For simplcity we are not testing a) in place ops as we cannot just pass in the test args, but need special casing b) ops that require special handling with their args, c) one off corner cases. Every other

2) To do backwards passes (since the tensors naturally don't require grad in our suites), right now we add a gradient to all tensors in args and kwargs. This logic (+ test for if we should even run a backwards pass) is put in the suite as this can be handled on a per test level. For example in a follow up PR for this, we can add a backwards pass column in the torchbench dataset.

3) We also compare gradients and clear gradients after use to validate the backwards pass. We use the same allclose function as before. Note we don't copy tensors/args as sometimes they are views (at least in opinfo) which makes cloning difficult.

4) There are also a bunch of unit tests added to make sure the gradient checking utils work as expected.

Test Plan:

With this really slow correctish [mm implementation](https://gist.github.com/PaliC/e62859f0286f6bfa338ccb4140e9e74f) we get
```bash
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1  --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 1.00
performance score (geomean speedup over all operators): 0.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 1.00
```

With the bad monkey patched implementation we get
```
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1  --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 0.00
performance score (geomean speedup over all operators): 1.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 0.00
```

The following two commands with aten also work as expected (100% correctness on forwards and backwards)
```
``uv run python BackendBench/scripts/main.py --suite opinfo --backend aten --check-backwards``
`uv run python BackendBench/scripts/main.py --suite torchbench --topn 2 --backend aten --check-backwards`
```

Todo:
- [ ] rename is_correct -> correct_output (originally in this pr but added noise for reviewers)
- [ ] performance tests
- [ ] for torchbench suite put backwards checking in dataset
- [ ] Assuming the above support ops which have conditions on their args
- [ ] support inplace ops
@PaliC PaliC changed the title Add testing for backwards passes [WIP] Add testing for backwards passes Oct 16, 2025
@PaliC PaliC changed the title [WIP] Add testing for backwards passes Add testing for backwards passes Oct 16, 2025
Copy link

meta-cla bot commented Oct 17, 2025

Hi @PaliC!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant