Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax devicearray prod #19724

Merged
merged 16 commits into from
Aug 16, 2023
Merged

Jax devicearray prod #19724

merged 16 commits into from
Aug 16, 2023

Conversation

VictorOdede
Copy link
Contributor

Close #19417

@ivy-leaves ivy-leaves added the JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist label Jul 20, 2023
Copy link
Contributor

@fnhirwa fnhirwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @VictorOdede

Thanks for the PR, all looks great just a few changes are needed

init_tree="jax.numpy.array",
method_name="prod",
dtype_x_axis=helpers.dtype_values_axis(
available_dtypes=["int64"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't fully understand why this is the only supported dtype

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @hirwa-nshuti , thanks for the review. I've also tried testing with other valid dtypes but the ground-truth is always returning int64 causing the tests to fail with an AssertionError since returned dtype is different from ground truth. I haven't encountered this problem before, is there a way to solve this issue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then if there is a special promotion being done should be handled in the frontend function not on the testing side to ensure that we cover all the corner cases

@VictorOdede
Copy link
Contributor Author

Hey @hirwa-nshuti , I've added type promotion for all int and uint dtypes. Now all tests are passing except for Paddle which sometimes fails because it doesn't support uint64

@VictorOdede VictorOdede requested a review from fnhirwa July 31, 2023 08:35
Copy link
Contributor

@fnhirwa fnhirwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey
Just a few suggested changes🙂
Keep up the good work😊

@@ -123,6 +123,22 @@ def nonzero(self, *, size=None, fill_value=None):
fill_value=fill_value,
)

def prod(self, axis=None):
arr = self
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understandings is that the promotion being done here should be done on the jax.numpy.product function as we can promote all unsigned integers to uint64 by a simple ivy.is_uint_dtype and ivy.is_int_dtype functions.

Here we will simply need to call the frontend function once we ensured that it is working correctly🙂

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great pointers. Will look into this asap.

@VictorOdede VictorOdede requested a review from fnhirwa August 4, 2023 14:32
@VictorOdede
Copy link
Contributor Author

Hey @hirwa-nshuti , the tests are passing for the most part but sometimes it fails for higher precision dtypes like this:
jax_error
Is this the expected behaviour?

Copy link
Contributor

@fnhirwa fnhirwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VictorOdede
Copy link
Contributor Author

For precision errors handling you can use the torelance values mentioned in the helper function

https://github.com/unifyai/ivy/blob/356e6e5654970ee4b6f7cb9b41b3b572436e1262/ivy_tests/test_ivy/helpers/function_testing.py#L485C4-L485C4

Thanks! Will look into this.

@VictorOdede VictorOdede requested a review from fnhirwa August 10, 2023 08:59
Copy link
Contributor

@fnhirwa fnhirwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey,
Great work so far I have added the suggestions in the comments below
Thanks

@@ -695,6 +695,13 @@ def product(
promote_integers=True,
out=None,
):
if ivy.is_uint_dtype(a.dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are only supported when the promote_integers is True

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can redefine prod function as

@to_ivy_arrays_and_back
def product(
    a,
    *,
    axis=None,
    dtype=None,
    keepdims=False,
    initial=None,
    where=None,
    promote_integers=True,
    out=None,
):
    if ivy.is_array(where):
        a = ivy.where(where, a, ivy.default(out, ivy.ones_like(a)), out=out)
    if promote_integers:
        if ivy.is_uint_dtype(a.dtype):
            dtype = "uint64"
        elif ivy.is_int_dtype(a.dtype):
            dtype = "int64"
    if initial is not None:
        if axis is not None:
            s = ivy.to_list(ivy.shape(a, as_array=True))
            s[axis] = 1
            header = ivy.full(ivy.Shape(tuple(s)), initial)
            a = ivy.concat([header, a], axis=axis)
        else:
            a[0] *= initial
    return ivy.prod(a, axis=axis, dtype=dtype, keepdims=keepdims, out=out)
    ```

@VictorOdede
Copy link
Contributor Author

Hey, Great work so far I have added the suggestions in the comments below Thanks

Thanks, will work on implementing this

@VictorOdede VictorOdede requested a review from fnhirwa August 11, 2023 13:34
Copy link
Contributor

@fnhirwa fnhirwa left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!
Going to merge

@fnhirwa fnhirwa merged commit aeac992 into ivy-llc:main Aug 16, 2023
sushmanthreddy pushed a commit to sushmanthreddy/ivy that referenced this pull request Aug 17, 2023
Co-authored-by: hirwa-nshuti <hirwanshutiflx@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
JAX Frontend Developing the JAX Frontend, checklist triggered by commenting add_frontend_checklist
Projects
None yet
Development

Successfully merging this pull request may close these issues.

prod
3 participants