Implement ScalarLoop in torch backend #958
base: main
Conversation
@Ch0ronomato thanks for taking a stab, I left some comments above.
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   81.90%   81.96%   +0.06%
==========================================
  Files         182      182
  Lines       47879    47914      +35
  Branches     8617     8632      +15
==========================================
+ Hits        39214    39272      +58
+ Misses       6492     6474      -18
+ Partials     2173     2168       -5
carry = update(*carry, *constants)  # advance the loop state by one step
return torch.stack(carry)

return torch.compiler.disable(scalar_loop)  # keep the Python loop out of torch.compile
Can you do `recursive=False`?
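For context, `torch.compiler.disable` takes a `recursive` flag (default `True`); a minimal sketch of the suggestion, with a stand-in loop body:

```python
import torch

def scalar_loop(*args):
    # stand-in for the actual loop body in the diff above
    return torch.stack(args)

# recursive=False skips compilation only for this frame; functions called
# from inside scalar_loop remain eligible for torch.compile
wrapped = torch.compiler.disable(scalar_loop, recursive=False)
```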
@ricardoV94 - these failures in the CI look a bit strange; I'll look into them before merging... hopefully they go away after merging main 😓
@ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this PR?
If we can't elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of vmap for this Op.
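A rough sketch of what a manual loop over the broadcast elements might look like in place of `torch.vmap` (`base_fn` and `elemwise_loop` are illustrative names, not from the PR):

```python
import torch

def elemwise_loop(base_fn, *inputs):
    # broadcast the inputs to a common shape, then apply the scalar
    # function one element at a time instead of relying on torch.vmap
    broadcast = torch.broadcast_tensors(*inputs)
    shape = broadcast[0].shape
    flat = [t.reshape(-1) for t in broadcast]
    results = [base_fn(*elems) for elems in zip(*flat)]
    return torch.stack(results).reshape(shape)
```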
I suspect it's in the right direction, but I need a bit more help to understand the new code if you can provide it :)
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")

@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
I set this up so we can try different shapes, but I stuck with this one to get started. If you think we should add more, let me know.
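If more cases are wanted, additional broadcast-compatible shape sets could be appended to the parametrization; a hypothetical sanity check (the extra case is made up, and `np.broadcast_shapes` assumes NumPy ≥ 1.20):

```python
import numpy as np
import pytest

@pytest.mark.parametrize(
    "input_shapes",
    [
        [(5, 1, 1, 8), (3, 1, 1), (8,)],  # the case from the PR
        [(2, 1, 4), (3, 1), (4,)],        # hypothetical extra case
    ],
)
def test_input_shapes_broadcast(input_shapes):
    # np.broadcast_shapes raises ValueError if the shapes are incompatible
    np.broadcast_shapes(*input_shapes)
```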
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))

expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
I'm bullish on itertools, but I think I saw mention earlier that list comprehensions are preferred. I can refactor it if so.
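For reference, the list-comprehension equivalent of the `starmap`/`repeat` pair, reusing the names from the diff above:

```python
from unittest.mock import call

# same sequence of expected calls, one per inner-function invocation
expected_calls = [
    call(*expected_args) for _ in range(mock_inner_func.f.call_count)
]
```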
from torch import is_tensor

if is_tensor(out):
    # move tensors back to the CPU so NumPy-based callers can read them
    return out.cpu()
FYI, this will probably create a conflict when one of my other PRs gets merged.
final_inputs[i] = list(layer)

# make sure we still have the same number of inputs we started with
assert len(final_inputs) == len(shaped_inputs)
I can put these into the unit test if that's preferred now.
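A hypothetical version of that check lifted into the test, using the mock's recorded calls (names follow the test diff above; this is a sketch, not the PR's code):

```python
# every recorded call to the inner function should receive one argument
# per input tensor; hypothetical unit-test form of the internal assert
for recorded in mock_inner_func.f.call_args_list:
    assert len(recorded.args) == len(input_shapes)
```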
    torch.zeros(*input_shapes[-1]),
]
mock_inner_func = MagicMock()
ret_value = torch.rand(2, 2).unbind(0)
Maybe rename to `expected`.
mock_inner_func.f.return_value = ret_value
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
result = elemwise_fn(*args)
for actual, expected in zip(ret_value, result):
These are backwards, FYI: `ret_value` is the expected output and `result` is the actual one.
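i.e., something like the following, with the loop variables swapped (the assertion body is hypothetical, since the original test body is cut off above):

```python
# ret_value holds the mocked (expected) outputs; result is the actual output
for expected, actual in zip(ret_value, result):
    torch.testing.assert_close(actual, expected)
```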
Description
Adds ScalarLoop for PyTorch. I implemented it as a loop as opposed to trying to vectorize it... let me know if I should go that approach or not.

Related Issue
Checklist
Type of change