Add testing for backwards passes #191
Open
Summary:
Here we add correctness tests for backwards passes of ops.
This PR does the following things:
Figures out which ops not to test (explained in depth at the top of BackendBench/backwards_utils.py, plus avoiding in-place ops). For simplicity we are not testing a) in-place ops, as we cannot just pass in the test args and would need special casing, b) ops that require special handling of their args, c) one-off corner cases. Every other
To do backwards passes (the tensors in our suites don't require grad by default), we currently mark all tensors in args and kwargs as requiring grad. This logic, plus the check for whether we should run a backwards pass at all, lives in the suite, since it can be handled at a per-test level. For example, in a follow-up PR we can add a backwards pass column to the torchbench dataset.
To validate the backwards pass, we compare gradients and then clear them after use. We use the same allclose function as before. Note that we don't copy tensors/args, as they are sometimes views (at least in opinfo), which makes cloning difficult.
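The requires-grad step described above could look roughly like the sketch below. The helper names `make_args_require_grad` and `should_check_backwards` are hypothetical and are not taken from this PR; the sketch also assumes that only floating-point or complex tensors are marked, since PyTorch does not allow integer tensors to require grad.

```python
import torch


def _enable_grad(value):
    """Recursively mark floating-point/complex tensors as requiring grad (in place)."""
    if isinstance(value, torch.Tensor):
        # Integer and bool tensors cannot require grad, so skip them.
        if value.is_floating_point() or value.is_complex():
            value.requires_grad_(True)
    elif isinstance(value, (list, tuple)):
        for item in value:
            _enable_grad(item)


def make_args_require_grad(args, kwargs):
    """Hypothetical helper: enable grad on tensors in args and kwargs without copying them."""
    for value in args:
        _enable_grad(value)
    for value in kwargs.values():
        _enable_grad(value)


def should_check_backwards(args, kwargs):
    """Hypothetical per-test check: is there any tensor a gradient could flow to?"""
    values = list(args) + list(kwargs.values())
    return any(
        isinstance(v, torch.Tensor) and (v.is_floating_point() or v.is_complex())
        for v in values
    )
```

Because the tensors are modified in place rather than cloned, views stay views, which matches the note above about not copying args.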
There are also a bunch of unit tests added to make sure the gradient checking utils work as expected.
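As an illustration of the kind of gradient check these utilities perform, here is a minimal sketch. It assumes the op returns a single tensor, uses `torch.allclose` as a stand-in for the project's own allclose function, and the name `backward_matches` is hypothetical rather than the one used in the PR.

```python
import torch


def backward_matches(ref_op, impl_op, args, kwargs, atol=1e-2, rtol=1e-2):
    """Run both ops, backprop through each, compare input grads, then clear them."""
    tensors = [
        v for v in list(args) + list(kwargs.values())
        if isinstance(v, torch.Tensor) and v.requires_grad
    ]

    def run(op):
        out = op(*args, **kwargs)
        # Assumption: a single tensor output; backprop a gradient of ones.
        out.backward(torch.ones_like(out))
        grads = [t.grad for t in tensors]
        # Clear gradients so the next run does not accumulate into them.
        for t in tensors:
            t.grad = None
        return grads

    ref_grads = run(ref_op)
    impl_grads = run(impl_op)

    for ref, got in zip(ref_grads, impl_grads):
        if ref is None or got is None:
            if ref is not got:
                return False
            continue
        if not torch.allclose(ref, got, atol=atol, rtol=rtol):
            return False
    return True
```

A deliberately wrong implementation, such as one that scales the output of `torch.mm` by 2, should make this check return `False`, while comparing `torch.mm` against itself should return `True`.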
Test Plan:
With this really slow correctish mm implementation we get
uv run python BackendBench/scripts/main.py --suite torchbench --topn 1 --backend directory --ops "mm" --check-backwards
...
correctness score (mean pass rate over all operators): 1.00
performance score (geomean speedup over all operators): 0.00
perf@p score (rate of correct samples with a speedup greater than p, p=1.0): 0.00
backwards correctness score (mean pass rate over all operators which support backwards): 1.00
With the bad monkey patched implementation we get
The following two commands with aten also work as expected (100% correctness on forwards and backwards)
Todo: