-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jax devicearray prod #19724
Jax devicearray prod #19724
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @VictorOdede
Thanks for the PR, all looks great just a few changes are needed
init_tree="jax.numpy.array", | ||
method_name="prod", | ||
dtype_x_axis=helpers.dtype_values_axis( | ||
available_dtypes=["int64"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can't fully understand why this is the only supported dtype
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @hirwa-nshuti , thanks for the review. I've also tried testing with other valid dtypes but the ground-truth is always returning int64
causing the tests to fail with an AssertionError
since returned dtype is different from ground truth. I haven't encountered this problem before, is there a way to solve this issue?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then if there is a special promotion being done should be handled in the frontend function not on the testing side to ensure that we cover all the corner cases
Hey @hirwa-nshuti , I've added type promotion for all |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey
Just a few suggested changes🙂
Keep up the good work😊
@@ -123,6 +123,22 @@ def nonzero(self, *, size=None, fill_value=None): | |||
fill_value=fill_value, | |||
) | |||
|
|||
def prod(self, axis=None): | |||
arr = self |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
From my understandings is that the promotion being done here should be done on the jax.numpy.product
function as we can promote all unsigned integers
to uint64
by a simple ivy.is_uint_dtype
and ivy.is_int_dtype
functions.
Here we will simply need to call the frontend function once we ensured that it is working correctly🙂
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the great pointers. Will look into this asap.
Hey @hirwa-nshuti , the tests are passing for the most part but sometimes it fails for higher precision dtypes like this: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For precision errors handling you can use the torelance values mentioned in the helper function
Thanks! Will look into this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey,
Great work so far I have added the suggestions in the comments below
Thanks
@@ -695,6 +695,13 @@ def product( | |||
promote_integers=True, | |||
out=None, | |||
): | |||
if ivy.is_uint_dtype(a.dtype): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are only supported when the promote_integers
is True
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can redefine prod function as
@to_ivy_arrays_and_back
def product(
a,
*,
axis=None,
dtype=None,
keepdims=False,
initial=None,
where=None,
promote_integers=True,
out=None,
):
if ivy.is_array(where):
a = ivy.where(where, a, ivy.default(out, ivy.ones_like(a)), out=out)
if promote_integers:
if ivy.is_uint_dtype(a.dtype):
dtype = "uint64"
elif ivy.is_int_dtype(a.dtype):
dtype = "int64"
if initial is not None:
if axis is not None:
s = ivy.to_list(ivy.shape(a, as_array=True))
s[axis] = 1
header = ivy.full(ivy.Shape(tuple(s)), initial)
a = ivy.concat([header, a], axis=axis)
else:
a[0] *= initial
return ivy.prod(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
```
Thanks, will work on implementing this |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Going to merge
Co-authored-by: hirwa-nshuti <hirwanshutiflx@gmail.com>
Close #19417