
Implement ScalarLoop in torch backend #958

Open · Ch0ronomato wants to merge 18 commits into main
Conversation

Ch0ronomato (Contributor):

Description

Adds ScalarLoop for PyTorch. I implemented it as a loop as opposed to trying to vectorize it; lmk if I should go the vectorized route instead.
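A minimal sketch of what the loop-based dispatch could look like; `pytorch_funcify_scalar_loop`, `update`, and `n_update` are illustrative names, not the PR's actual code:

```python
import torch


def pytorch_funcify_scalar_loop(update, n_update):
    # update: the compiled inner scalar function; the first n_update
    # runtime arguments are the carried values, the rest are constants.
    def scalar_loop(steps, *start_and_constants):
        carry = list(start_and_constants[:n_update])
        constants = start_and_constants[n_update:]
        for _ in range(int(steps)):
            carry = update(*carry, *constants)
        return torch.stack(carry)

    return scalar_loop
```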

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

(5 resolved review threads on pytensor/link/pytorch/dispatch/scalar.py, now outdated)
ricardoV94 added the enhancement (New feature or request) and torch (PyTorch backend) labels on Aug 3, 2024
ricardoV94 (Member):

@Ch0ronomato thanks for taking a stab, I left some comments above

codecov bot commented on Aug 11, 2024:

Codecov Report

Attention: Patch coverage is 88.46154% with 6 lines in your changes missing coverage. Please review.

Project coverage is 81.96%. Comparing base (a377c22) to head (920f5a4).
Report is 21 commits behind head on main.

Files with missing lines                      Patch %   Lines
pytensor/link/pytorch/dispatch/scalar.py      84.00%    2 Missing and 2 partials ⚠️
pytensor/link/pytorch/linker.py               50.00%    1 Missing and 1 partial ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main     #958      +/-   ##
==========================================
+ Coverage   81.90%   81.96%   +0.06%     
==========================================
  Files         182      182              
  Lines       47879    47914      +35     
  Branches     8617     8632      +15     
==========================================
+ Hits        39214    39272      +58     
+ Misses       6492     6474      -18     
+ Partials     2173     2168       -5     
Files with missing lines                      Coverage Δ
pytensor/link/pytorch/dispatch/elemwise.py    74.13% <100.00%> (+5.38%) ⬆️
pytensor/link/pytorch/linker.py               91.66% <50.00%> (-8.34%) ⬇️
pytensor/link/pytorch/dispatch/scalar.py      72.91% <84.00%> (+12.04%) ⬆️

... and 17 files with indirect coverage changes

(2 resolved review threads on pytensor/link/pytorch/dispatch/scalar.py, now outdated)
ricardoV94 changed the title from "Add torch scalar loop" to "Implement ScalarLoop in torch backend" on Sep 1, 2024
```python
            carry = update(*carry, *constants)
        return torch.stack(carry)

    return torch.compiler.disable(scalar_loop)
```
ricardoV94 (Member) commented on the diff above:

Can you do recursive=False?
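Presumably this refers to the `recursive` flag of `torch.compiler.disable` (default `True`); with `recursive=False` only the wrapped function itself is excluded from compilation, while the functions it calls remain eligible:

```python
# recursive=False skips compiling scalar_loop itself, but its callees
# (e.g. the inner update function) can still be compiled.
return torch.compiler.disable(scalar_loop, recursive=False)
```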

Ch0ronomato (Contributor, Author):

@ricardoV94 - these failures in the CI look a bit strange; I'll look into them before merging... hopefully they go away with merging main 😓

Ch0ronomato (Contributor, Author):

@ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this PR?

ricardoV94 (Member):

> @ricardoV94 #1031 is blocking the elemwise test - how do you want to proceed with this PR?

If we can't elemwise it, there's not much point to the ScalarLoop. Maybe we need to loop manually instead of vmap for this Op.
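A minimal sketch of a manual loop for a single-output Op; the function name and the broadcast/flatten/stack strategy here are my assumptions, not the PR's implementation:

```python
import torch


def elemwise_scalar_loop_manual(scalar_fn, *inputs):
    # Broadcast the inputs to a common shape, then apply the scalar
    # function one element at a time instead of going through torch.vmap.
    broadcasted = torch.broadcast_tensors(*inputs)
    shape = broadcasted[0].shape
    flat = [t.reshape(-1) for t in broadcasted]
    outputs = [scalar_fn(*elems) for elems in zip(*flat)]
    return torch.stack(outputs).reshape(shape)
```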

ricardoV94 (Member) left a review:

I suspect it's in the right direction, but I need a bit more help to understand the new code, if you can provide it :)

(3 resolved review threads on pytensor/link/pytorch/dispatch/elemwise.py and tests/link/pytorch/test_basic.py)
```python
torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise")


@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]])
```
Ch0ronomato (Contributor, Author) commented on the diff above:

I set this up so we can try different shapes, but I stuck with this one to get started. If you think we should add more, lmk.
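If more cases are wanted later, extra broadcastable shape combinations could simply be appended to the decorator above; the second entry and the test name below are hypothetical:

```python
@pytest.mark.parametrize(
    "input_shapes",
    [
        [(5, 1, 1, 8), (3, 1, 1), (8,)],
        [(2, 1, 4), (3, 1), (4,)],  # hypothetical extra case
    ],
)
def test_elemwise_scalar_loop(input_shapes):
    ...
```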

```python
np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0]))

expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0)
expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count))
```
Ch0ronomato (Contributor, Author) commented on the diff above:

I'm bullish on itertools, but I think I saw mention earlier that list comprehensions are preferred. I can refactor it if so.
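For reference, the `starmap`/`repeat` pair in the excerpt above is equivalent to this comprehension (same names as in the excerpt):

```python
expected_calls = [call(*expected_args) for _ in range(mock_inner_func.f.call_count)]
```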

```python
from torch import is_tensor

if is_tensor(out):
    return out.cpu()
```
Ch0ronomato (Contributor, Author) commented on the diff above:

FYI, this will probably create a conflict when one of my other PRs gets merged.
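For context, the `.cpu()` move matters because a tensor on an accelerator can't be converted to NumPy directly; a rough illustration:

```python
import torch

out = torch.ones(3)          # may live on an accelerator in practice
out_np = out.cpu().numpy()   # .cpu() is a no-op for CPU tensors
```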

```python
    final_inputs[i] = list(layer)

# make sure we still have the same number of things
assert len(final_inputs) == len(shaped_inputs)
```
Ch0ronomato (Contributor, Author) commented on the diff above:

I can put these into the unit test if that's preferred now.

```python
    torch.zeros(*input_shapes[-1])
]
mock_inner_func = MagicMock()
ret_value = torch.rand(2, 2).unbind(0)
```
Ch0ronomato (Contributor, Author) commented on the diff above:

Maybe rename to `expected`.

```python
mock_inner_func.f.return_value = ret_value
elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None)
result = elemwise_fn(*args)
for actual, expected in zip(ret_value, result):
```
Ch0ronomato (Contributor, Author) commented on the diff above:

These are backwards, FYI: `ret_value` is the expected and `result` is the actual.
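A hedged version with the names swapped to match their roles (the loop body is elided in the excerpt above, so `...` stands in for it):

```python
# ret_value is what the mocked inner function returns (expected);
# result is what elemwise_fn actually produced (actual).
for expected, actual in zip(ret_value, result):
    ...
```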

Labels: enhancement (New feature or request), torch (PyTorch backend)

2 participants