Implemented Repeat and Unique Ops in PyTorch #890
Conversation
return_index=return_index,
return_inverse=return_inverse,
return_counts=return_counts,
axis=axis,
We don't need these optional kwargs. I know the JAX implementations were doing it, but I don't see why; these functions are never called with them changed.
It's enough to rely on the enclosing scope to access them.
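For illustration, here is a minimal sketch of what relying on the enclosing scope looks like, assuming the usual pytorch_funcify registration pattern; the body is an assumption for illustration, not the PR's exact code:

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.extra_ops import Unique


@pytorch_funcify.register(Unique)
def pytorch_funcify_Unique(op, **kwargs):
    # Read the Op's settings once; the inner function closes over them,
    # so they never need to be re-passed as default keyword arguments.
    return_inverse = op.return_inverse
    return_counts = op.return_counts
    axis = op.axis

    if op.return_index:
        # torch.unique has no return_index equivalent (hence the xfail below).
        raise NotImplementedError("return_index is not supported by torch.unique")

    def unique(x):
        return torch.unique(
            x,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=axis,
        )

    return unique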
Codecov Report: All modified and coverable lines are covered by tests ✅

@@ Coverage Diff @@
##             main     #890   +/-  ##
=======================================
  Coverage   81.36%   81.37%
=======================================
  Files         171      171
  Lines       46811    46828    +17
  Branches    11420    11421     +1
=======================================
+ Hits        38088    38105    +17
  Misses       6539     6539
  Partials     2184     2184
This looks good to me. I'd definitely parameterize these tests to get coverage over the whole power set of options (and to shorten the tests), but I won't make it a blocker.
I think you are right. Should I parametrize the tests?
Yeah, if you agree, definitely go ahead.
a50c103 to d561a1d
tests/link/pytorch/test_extra_ops.py (Outdated)
@pytest.mark.parametrize(
    "return_index, return_inverse, return_counts",
    [
        (False, True, False),
        (False, True, True),
        pytest.param(
            True, False, False, marks=pytest.mark.xfail(raises=NotImplementedError)
        ),
    ],
)
Should these be distinct parametrizes to test all combinations?
Suggested change:

@pytest.mark.parametrize("return_index", (False, pytest.param(True, marks=...)))
@pytest.mark.parametrize("return_inverse", (False, True))
@pytest.mark.parametrize("return_counts", (False, True))
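For reference, stacking the three parametrize decorators generates the full 2 × 2 × 2 product of cases; the marks=... placeholder would presumably be filled with the same xfail marker used in the explicit list above, i.e.:

@pytest.mark.parametrize(
    "return_index",
    (False, pytest.param(True, marks=pytest.mark.xfail(raises=NotImplementedError))),
)
@pytest.mark.parametrize("return_inverse", (False, True))
@pytest.mark.parametrize("return_counts", (False, True))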
I didn't want to test all the combinations. return_index=True should always fail, so I want to test that once instead of four times. The combination return_inverse=False and return_counts=False is tested somewhere else, etc.
Any reason not to test all the combinations? How fast is the test running with 3 vs the 9?
tests/link/pytorch/test_extra_ops.py (Outdated)
test_value = np.arange(6, dtype="float64").reshape((3, 2))

out = pt.repeat(a, (1, 2, 3) if axis == 0 else (3, 3), axis=axis)
Same here: do we allow axis=None like NumPy? It may fail because we don't yet have Reshape implemented, but we should test it.
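A sketch of the extra case being requested (hypothetical, mirroring NumPy's behavior of flattening when axis is None):

# NumPy flattens first when axis=None, e.g.
# np.repeat([[0.0, 1.0], [2.0, 3.0]], 2) -> [0., 0., 1., 1., 2., 2., 3., 3.]
out = pt.repeat(a, 2, axis=None)  # may currently fail if Reshape is missing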
[[1.0, 1.0, 2.0], [1.0, 1.0, 2.0], [3.0, 3.0, 0.0]], dtype="float64"
)

out = pt.unique(a, axis=axis)
Does PyTensor only allow an integer axis at the moment? No None or partial multiple axes? If we support those, we should test them.
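If axis=None is supported, the extra case would be a one-liner (a sketch; NumPy's semantics are "flatten, then take unique values"):

out = pt.unique(a, axis=None)  # unique values of the flattened array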
All good with this one?
I didn't check the code, but we need to see if multiple (but not all) axes are supported by PyTensor, and if so, test that.
Do you mean multiple axes, like in the case of ArgMax and Max? I don't think so:

axis : int, optional
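For comparison, NumPy itself accepts only a single integer (or None) for np.unique's axis, which a quick check confirms:

import numpy as np

x = np.array([[1, 1], [2, 3], [1, 1]])

print(np.unique(x, axis=0))  # unique rows: [[1 1] [2 3]]
print(np.unique(x))          # axis=None: flattened -> [1 2 3]

try:
    np.unique(x, axis=(0, 1))  # a tuple of axes is rejected
except TypeError as exc:
    print("multi-axis rejected:", exc)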
a3f478c to 62e453d
Looks good, thanks!
Description
Implements Repeat and Unique Ops in PyTorch.

Related Issue
Checklist
Type of change
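For context, a minimal sketch of how the Repeat Op can map onto PyTorch, assuming torch.repeat_interleave (PyTorch's analogue of np.repeat) under the same pytorch_funcify dispatch pattern; illustrative only, not the PR's exact code:

import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.extra_ops import Repeat


@pytorch_funcify.register(Repeat)
def pytorch_funcify_Repeat(op, **kwargs):
    axis = op.axis  # captured from the Op, as discussed in the review above

    def repeat(x, repeats):
        # torch.repeat_interleave matches np.repeat: it repeats elements
        # along dim, flattening the input when dim is None.
        return torch.repeat_interleave(x, repeats, dim=axis)

    return repeat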