Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement np.intersect1d #3726

Merged
merged 4 commits into from
Jul 13, 2020
Merged

Conversation

aldragan0
Copy link
Contributor

This is an implementation of np.intersect1d from #70 (and #2078 )

@aldragan0 aldragan0 marked this pull request as ready for review July 12, 2020 20:36
@jekbradbury jekbradbury merged commit 0d81e98 into jax-ml:master Jul 13, 2020
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks! One comment about a potential performance improvement.


if return_indices:
indices = argsort(ar)
aux = ar[indices]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's possible to use lax.sort directly to get the sorted array and the indices in a single pass (if I recall correctly, argsort uses this and then discards the sorted array).

@jekbradbury
Copy link
Contributor

I think Jake and I were reviewing at the same time! His suggestion sounds good; maybe you can make another PR to incorporate it?

bchetioui pushed a commit to bchetioui/jax that referenced this pull request Jul 13, 2020
* Implement np.intersect1d

* Add jitable helper to function

* Fix argsort failing tests

* Fix linter errors
@aldragan0 aldragan0 deleted the numpy-intersect1d branch July 13, 2020 20:07
@aldragan0
Copy link
Contributor Author

Sure, I'll add it and link to this PR

NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 14, 2020
* Implement np.intersect1d

* Add jitable helper to function

* Fix argsort failing tests

* Fix linter errors
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 14, 2020
* Implement np.intersect1d

* Add jitable helper to function

* Fix argsort failing tests

* Fix linter errors
NeilGirdhar pushed a commit to NeilGirdhar/jax that referenced this pull request Jul 24, 2020
* Implement np.intersect1d

* Add jitable helper to function

* Fix argsort failing tests

* Fix linter errors
@zhaopku
Copy link

zhaopku commented Jun 21, 2023

It seems that this function returns arrays of dynamic shapes, and is thus not compatible with jit?

@jakevdp
Copy link
Collaborator

jakevdp commented Jun 21, 2023

Correct, the semantics of numpy's intersect1d are incompatible with JAX transformations like jit.

#3614 tracks the effort to add size hints to functions that share this issue – many are done but intersect1d is not (it's been a while, but I recall one reason I didn't tackle it is that it wasn't clear to me what a natural fill value should be in the case that size is larger than the true intersection).

Many of these kinds of features are implemented on an as-needed basis. If this is important to your use-case, you might chime-in on that issue with ideas of what kind of shape-hint behavior would be useful to you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants