This issue was moved to a discussion.
Can an API for dealing with composite operations be defined? #474
Comments
I'd love to know why they didn't work.
@leofang for For
    import numpy as np
    import numba
    import jax
    import jax.numpy as jnp

    # Here `xp` and `compiler` can be injected
    compiler = jax  # or: numba
    xp = jnp  # or: np

    @compiler.jit
    def abs2(x):
        """Squared absolute value"""
        return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

    x = xp.arange(3000, dtype=xp.complex64)
    # %timeit abs2(x)
    # %timeit xp.abs(x)**2

For NumPy + Numba:

    >>> %timeit abs2(x)
    2.33 µs ± 10.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    >>> %timeit xp.abs(x)**2
    3.82 µs ± 56.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

For JAX:

    >>> %timeit abs2(x)
    2.58 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
    >>> %timeit xp.abs(x)**2
    86.1 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

For PyTorch, I tried with
There are other PyTorch compilers though, and this really should work given the right one (e.g., TorchDynamo: https://github.com/pytorch/torchdynamo#usage-example). There's also Transonic (see https://transonic.readthedocs.io/en/latest/backends/pythran.html), which has

CuPy will work with Numba too: https://docs.cupy.dev/en/stable/user_guide/interoperability.html#numba. Triton has a

I think the point of this issue is not that we should add an API for this, or that someone should write a separate package for these (that is possible, but I'm not sure it's high-value). It's more that this is the right way of doing it in principle, so every function that can easily be written this way probably should be, rather than expanding the API surface of the standard.
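To make the "inject `xp` and a compiler" idea a bit more concrete, here is a minimal sketch, not a proposal. The factory name `make_abs2` and its `jit` parameter are hypothetical and exist only for illustration. It assumes only that the namespace provides `real` and `imag`, and it defaults to a no-op "compiler" so it runs with plain NumPy.

```python
import numpy as np


def make_abs2(xp, jit=lambda f: f):
    """Build `abs2` for array namespace `xp`, optionally wrapped by `jit`.

    `make_abs2` and `jit` are hypothetical names for this sketch; `jit`
    defaults to the identity, i.e. no compilation at all.
    """
    def abs2(x):
        # Squared absolute value, written only in terms of real/imag
        return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

    return jit(abs2)


# Plain NumPy, no compiler injected
abs2_np = make_abs2(np)

x = np.arange(4, dtype=np.complex64) * (1 + 1j)  # 0, 1+1j, 2+2j, 3+3j
result = abs2_np(x)
reference = np.abs(x) ** 2  # the composite expression from gh-460

print(result)
print(np.allclose(result, reference))
```

A consumer could pass a compiler's `jit` function together with the matching namespace (for instance `jax.jit` with `jax.numpy`, as in the snippet above) instead of the no-op default. Whether a given compiler accepts a closure over `xp` like this is something to check per compiler.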
We've had discussions several times about whether some function that is a composite of other functions already present in the standard can be added. The most recent example is gh-460, which proposed adding `abs2` (= `abs(x) ** 2`). Typically this isn't worth it, unless there are very compelling reasons. Fusing function calls is something that multiple libraries have compilers for, so there may not be a gain for those libraries to add a function like `abs2`, just extra API surface. And performance-wise, the gain must be quite large for it to be justified to add a function.

The discussion then turned to "is it possible to write, for example, a standardized way for array API consumers to element-wise apply arbitrary functions to arrays?" Something along the lines of `np.vectorize`.

Another direction could be to try and write a portable JIT-able set of functions. Something like:

There are multiple options for compilers here, and they work with different array libraries. Perhaps a more future-proof direction (`np.vectorize` wasn't quite a success, and `numexpr`-type string expressions are a bit of a hack as well ...).

This is not a worked-out proposal, just opening it here as a follow-up to the discussion in gh-460 and as a tracker issue to collect more ideas and serve as a reference for when other composite functions are proposed to be added to the standard.
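For reference, the `np.vectorize` style of "apply an arbitrary scalar function element-wise" mentioned above looks roughly like the sketch below. The helper name `abs2_scalar` is made up for illustration. Note that NumPy documents `np.vectorize` as a convenience wrapper that is essentially a Python-level loop, so this shows the API shape rather than a performance gain.

```python
import numpy as np


def abs2_scalar(z):
    """Squared absolute value of a single complex scalar (hypothetical helper)."""
    return z.real * z.real + z.imag * z.imag


# `otypes` fixes the output dtype instead of inferring it from the first call
abs2_vec = np.vectorize(abs2_scalar, otypes=[np.float64])

x = np.array([1 + 2j, 3 - 4j, 0j])
out = abs2_vec(x)

print(out)
```

This is NumPy-specific: nothing here is part of the array API standard, which is the portability gap the discussion is about.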