This issue was moved to a discussion.

Can an API for dealing with composite operations be defined? #474

Closed
rgommers opened this issue Sep 5, 2022 · 3 comments

Comments

@rgommers
Member

rgommers commented Sep 5, 2022

We've had discussions several times about whether some function that is a composite of other functions already present in the standard can be added. The most recent example is gh-460, which proposed adding abs2 (= abs(x) ** 2). Typically this isn't worth it, unless there are very compelling reasons. Fusing function calls is something that multiple libraries have compilers for, so there may not be a gain for those libraries to add a function like abs2, just extra API surface. And performance-wise, the gain must be quite large for it to be justified to add a function.

The discussion then turned to "is it possible to write for example a standardized way for array API consumers to element-wise apply arbitrary functions to arrays"? Something along the lines of np.vectorize.

Another direction could be to try and write a portable JIT-able set of functions. Something like:

# Here `xp` and `compiler` can be injected

@compiler.jit
def abs2(x):
    """Squared absolute value"""
    return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

There are multiple options for compilers here, and they work with different array libraries. This is perhaps a more future-proof direction (np.vectorize wasn't quite a success, and numexpr-type string expressions are a bit of a hack as well ...).
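As a minimal sketch of what "injected" could mean here: the array namespace and the JIT decorator can be passed into a factory function. The names `make_abs2` and `scalar_xp` are hypothetical, and the stand-in namespace below works on Python complex scalars only so that the example runs without any third-party package.

```python
from types import SimpleNamespace


def make_abs2(xp, jit=lambda f: f):
    """Build abs2 for a given array namespace and an optional JIT decorator."""

    @jit
    def abs2(x):
        """Squared absolute value"""
        return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)

    return abs2


# Hypothetical stand-in namespace for Python complex scalars (not a real
# array library); it only provides the two functions abs2 needs.
scalar_xp = SimpleNamespace(real=lambda z: z.real, imag=lambda z: z.imag)

abs2 = make_abs2(scalar_xp)
print(abs2(3 + 4j))  # 25.0
```

The intent is that one could pass, for example, `jax.numpy` and `jax.jit` instead of the stand-in. Whether a given compiler accepts a function that closes over `xp` like this would need checking per library.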

This is not a worked out proposal, just opening it here as a follow-up to the discussion in gh-460 and as a tracker issue to collect more ideas and serve as a reference for when other composite functions are proposed to be added to the standard.

@rgommers rgommers added the RFC Request for comments. Feature requests and proposed changes. label Sep 5, 2022
@leofang
Contributor

leofang commented Sep 6, 2022

> np.vectorize wasn't quite a success, and numexpr type string expressions are a bit of a hack as well ...

I'd love to know why they didn't work.

@rgommers
Member Author

rgommers commented Sep 6, 2022

@leofang for np.vectorize, from the docstring: "The vectorize function is provided primarily for convenience, not for performance. The implementation is essentially a for loop." If you look at the implementation (https://github.com/numpy/numpy/blob/2a6daf39cc4fd895ab803edf018907cb8044f821/numpy/lib/function_base.py#L2118), you see that no C compiler is invoked. It's just pure Python code that happens to be a bit faster than actual for-loops, but not by all that much, and there's no fusing of operations.
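To illustrate what "essentially a for loop" with "no fusing" means, here is a deliberately simplified sketch. It is not NumPy's implementation; `naive_vectorize` is a hypothetical helper that works on plain Python lists.

```python
def naive_vectorize(func):
    """Apply a scalar function to each element with a plain Python loop."""

    def wrapper(values):
        return [func(v) for v in values]

    return wrapper


vec_abs = naive_vectorize(abs)
vec_square = naive_vectorize(lambda v: v * v)

data = [3 + 4j, 1j, -2.0]
# Two separate passes: the intermediate list produced by vec_abs is fully
# materialized before squaring starts, so nothing is fused.
result = vec_square(vec_abs(data))
print(result)  # [25.0, 1.0, 4.0]
```

A fusing compiler would instead evaluate the whole expression per element in a single pass, without the intermediate container.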

For numexpr, it does do fusing and is fast. It's more a design and usability issue:

  1. You need to have a working C++ compiler installed. That's a nonstarter for end users, so you can only use it in a (numpy-using) library and ship compiled extensions. So unlike something that also works with JIT compilers and pure Python code, it's much harder to write portable code (à la array API standard compliant code).
  2. The UX is strings: ne.evaluate("2*x +y + 1"). This works, but it's 2022 and that just doesn't seem like a healthy way of doing things.

@rgommers
Member Author

rgommers commented Sep 7, 2022

import numpy as np
import numba

import jax
import jax.numpy as jnp


# Here `xp` and `compiler` can be injected
compiler = jax  # numba
xp = jnp  # np

@compiler.jit
def abs2(x):
    """Squared absolute value"""
    return xp.real(x) * xp.real(x) + xp.imag(x) * xp.imag(x)


x = xp.arange(3000, dtype=xp.complex64)
# %timeit abs2(x)
# %timeit xp.abs(x)**2
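
Outside IPython, the `%timeit` lines could be approximated with the standard library. This is a rough sketch: `bench` and its `sync` parameter are hypothetical names, it reports the best run rather than the mean ± std. dev. that `%timeit` prints, and the workload is a pure-Python stand-in so that the example runs without NumPy, Numba or JAX.

```python
import timeit


def bench(fn, arg, number=1000, repeat=5, sync=None):
    """Return the best per-call time in seconds for fn(arg)."""

    def call():
        out = fn(arg)
        if sync is not None:
            sync(out)  # wait for asynchronous results before the clock stops

    call()  # warm-up call, so one-off JIT compilation is not timed
    timings = timeit.repeat(call, number=number, repeat=repeat)
    return min(timings) / number


# Stand-in workload: squared absolute value over a list of complex numbers.
values = [complex(i) for i in range(100)]
per_call = bench(lambda v: [abs(z) ** 2 for z in v], values)
print(f"{per_call * 1e6:.2f} µs per call")
```

For JAX, which dispatches work asynchronously, `sync` could be something like `lambda out: out.block_until_ready()`, so that the measurement covers the computation and not only the dispatch.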

For NumPy + Numba:

>>> %timeit abs2(x)
2.33 µs ± 10.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> %timeit xp.abs(x)**2
3.82 µs ± 56.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

For JAX:

>>> %timeit abs2(x)
2.58 µs ± 17.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> %timeit xp.abs(x)**2
86.1 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

For PyTorch, I tried with torch.jit.script (note that it's not modulename.jit, so it needed a tweak). It does not yield any speedup even after adding type annotations:

>>> %timeit abs2(x)
9.71 µs ± 36.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
>>> %timeit xp.abs(x)**2
9.86 µs ± 74.1 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

There are other PyTorch compilers though, and this really should work given the right one (e.g., TorchDynamo: https://github.com/pytorch/torchdynamo#usage-example).
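
The tweak needed for `torch.jit.script` (versus `modulename.jit`) could be hidden behind a small helper. This is only a sketch: `resolve_jit` is a hypothetical name, and the two `SimpleNamespace` objects are stand-ins that mimic the two layouts, so no real compiler is involved.

```python
from types import SimpleNamespace


def resolve_jit(compiler):
    """Return a JIT decorator from a module-like object."""
    jit = getattr(compiler, "jit")
    if callable(jit):
        return jit  # layout where `jit` is itself the decorator
    return jit.script  # layout where `jit` is a submodule with `script`


def passthrough(f):
    """Stand-in decorator that returns the function unchanged."""
    return f


# Hypothetical stand-ins for the two layouts.
function_style = SimpleNamespace(jit=passthrough)
module_style = SimpleNamespace(jit=SimpleNamespace(script=passthrough))

assert resolve_jit(function_style) is passthrough
assert resolve_jit(module_style) is passthrough
```

With this, `@resolve_jit(compiler)` could replace `@compiler.jit` in the snippet above, assuming the compiler fits one of the two layouts.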

There's also Transonic (see https://transonic.readthedocs.io/en/latest/backends/pythran.html), which has @jit and @boost (for AOT compilation) decorators, and supports Pythran, Numba and Cython as backends. It is basically a worked-out version of this basic hand-wavy "here's how composite ops could be written".

CuPy will work with Numba too: https://docs.cupy.dev/en/stable/user_guide/interoperability.html#numba.

Triton has a jit decorator, but has a lower-level programming model so pure Python + type annotations is not enough (see for example https://triton-lang.org/master/getting-started/tutorials/01-vector-add.html).

I think the point of this issue is not that we should add an API for this, or that someone should write a separate package for these (that is possible, but I'm not sure it's high-value). It's more that this is the right way of doing it in principle, so every function that can easily be written this way probably should be, rather than expanding the API surface of the standard.

@kgryte kgryte removed the RFC Request for comments. Feature requests and proposed changes. label Apr 4, 2024
@data-apis data-apis locked and limited conversation to collaborators Apr 4, 2024
@kgryte kgryte converted this issue into discussion #781 Apr 4, 2024
