
Implemented Eye Op in PyTorch #877

Merged (6 commits) on Jul 7, 2024

Conversation

twaclaw
Contributor

@twaclaw twaclaw commented Jul 2, 2024

Description

  • Implemented the Eye Op with diagonal offset in PyTorch
    - torch.eye doesn't directly support a diagonal offset parameter k. Implementing the offset required composing two torch functions.
  • Added tests for different dimensions and diagonal offsets, including corner cases.
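One way to compose torch functions into an eye with a diagonal offset is to compare broadcast row and column index grids, which handles positive and negative k and non-square shapes uniformly. This is only a sketch of the idea; the actual composition used in the PR may differ:

```python
import torch

def eye_with_offset(N, M, k, dtype=torch.float32):
    """Return an N x M matrix with ones on the k-th diagonal, like np.eye(N, M, k)."""
    rows = torch.arange(N).unsqueeze(1)  # shape (N, 1)
    cols = torch.arange(M).unsqueeze(0)  # shape (1, M)
    # Entry (i, j) is 1 exactly when j - i == k, i.e. on the k-th diagonal.
    return (cols - rows == k).to(dtype)
```

For k = 0 this reduces to the ordinary identity that torch.eye produces; positive k shifts the ones above the main diagonal, negative k below it.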

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):


codecov bot commented Jul 3, 2024

Codecov Report

Attention: Patch coverage is 76.92308% with 3 lines in your changes missing coverage. Please review.

Project coverage is 81.05%. Comparing base (ca10298) to head (08d23ab).
Report is 120 commits behind head on main.

Files with missing lines                  | Patch %  | Lines
pytensor/link/pytorch/dispatch/basic.py   | 76.92%   | 2 Missing and 1 partial ⚠️

@@            Coverage Diff             @@
##             main     #877      +/-   ##
==========================================
- Coverage   81.05%   81.05%   -0.01%     
==========================================
  Files         170      170              
  Lines       46980    46992      +12     
  Branches    11509    11510       +1     
==========================================
+ Hits        38079    38088       +9     
- Misses       6695     6697       +2     
- Partials     2206     2207       +1     
Files with missing lines                  | Coverage Δ
pytensor/link/pytorch/dispatch/basic.py   | 83.82% <76.92%> (-1.90%) ⬇️

@twaclaw
Contributor Author

twaclaw commented Jul 3, 2024

@ricardoV94, is there something odd with the coverage reports?

@ricardoV94
Member

> @ricardoV94, is there something odd with the coverage reports?

It's flaky, don't worry

@ricardoV94 ricardoV94 added enhancement New feature or request torch PyTorch backend labels Jul 4, 2024
Comment on lines 247 to 254
trange = range(1, 6)
for _N in trange:
for _M in trange:
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
compare_pytorch_and_py(
FunctionGraph([N, M, k], [out]),
[np.array(_N), np.array(_M), np.array(_k)],
)
Member

In this case let's compile a single pytensor function and call it repeatedly with the different inputs, instead of creating a function for every combination of test values. Something like:

fn = pytensor.function([N, M, k], out, mode=torch_mode)  # Whatever mode is used internally in `compare_pytorch_and_py`
for _N in trange:
    for _M in trange:
        for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
            np.testing.assert_allclose(fn(_N, _M, _k), np.eye(_N, _M, _k))

M = scalar("M", dtype="int64")
k = scalar("k", dtype="int64")

out = eye(N, M, k, dtype="float32")
Member

Do we need to specify dtype here? It should be handled automatically by the config.floatX setting

Member

If dtype is something the user can control directly we should probably test with default and a user-specified non-default?

@jessegrabowski
Member

Left some comments; some are style choices (feel free to push back on those), but some need to be addressed (pytorch_mode), so I marked it as changes needed. The branch also needs to be rebased.

@twaclaw
Contributor Author

twaclaw commented Jul 6, 2024 via email

@ricardoV94
Member

ricardoV94 commented Jul 6, 2024

I guess just test dtype=config.floatX vs dtype="int64". Because that requires a different function, I would use parametrize for that.
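A dtype-parametrized test along those lines might look like the following sketch. Here `backend_eye` is a hypothetical stand-in for the compiled function; in the actual test suite it would be built with pytensor.function using the torch mode, and the reference is np.eye:

```python
import numpy as np
import pytest

def backend_eye(N, M, k, dtype):
    # Hypothetical stand-in for the compiled PyTorch-backend eye function;
    # the real test would compile a pytensor graph instead.
    rows = np.arange(N)[:, None]
    cols = np.arange(M)[None, :]
    return (cols - rows == k).astype(dtype)

# Each dtype needs a separately compiled function, hence parametrize.
@pytest.mark.parametrize("dtype", ["float64", "int64"])
def test_eye_dtype(dtype):
    result = backend_eye(4, 5, -2, dtype)
    assert result.dtype == np.dtype(dtype)
    np.testing.assert_array_equal(result, np.eye(4, 5, -2, dtype=dtype))
```

Parametrizing over the dtype keeps each compiled function in its own test case, so a failure report points directly at the offending dtype.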

@jessegrabowski
Member

Not sure why codecov is whining, but this looks good to go on my end.

@ricardoV94
Member

I think codecov cannot see line hits given the way pytorch uses the functions

@ricardoV94 ricardoV94 merged commit 4ea96b2 into pymc-devs:main Jul 7, 2024
56 of 57 checks passed
@ricardoV94
Member

Cool stuff @twaclaw!

@twiecki
Member

twiecki commented Aug 9, 2024

@twaclaw I've been trying to contact you through LI; maybe send me an email?
