RFC: add logsumexp #596

Open
steff456 opened this issue Feb 14, 2023 · 2 comments
Labels
API extension: Adds new functions or objects to the API.
RFC: Request for comments. Feature requests and proposed changes.

Comments

@steff456
Member

steff456 commented Feb 14, 2023

This RFC proposes adding a function to the array API specification that computes the logarithm of a sum of exponentials.

Overview

The array API specification currently includes logaddexp, which performs an element-wise operation on two input arrays, but it does not include the corresponding reduction, logsumexp. Accelerator libraries commonly implement logsumexp because deep learning applications depend on it for numerically stable computations.
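As a rough illustration of how the two functions relate, the sketch below uses NumPy as a stand-in array library (NumPy is an assumption here, not part of the proposal). logaddexp(a, b) computes log(exp(a) + exp(b)) element-wise, so a logsumexp over an axis can be obtained by folding logaddexp along that axis. NumPy happens to expose this fold as the ufunc method np.logaddexp.reduce, which has no counterpart in the array API standard.

```python
import numpy as np

x = np.array([[0.0, 1.0, 2.0],
              [-1.0, 5.0, 3.0]])

# Element-wise binary operation (already in the standard):
# log(exp(a) + exp(b)) for each pair of corresponding elements.
pairwise = np.logaddexp(x[:, 0], x[:, 1])

# Reduction built by folding logaddexp along axis 1.
# np.logaddexp.reduce is a NumPy-specific ufunc method.
folded = np.logaddexp.reduce(x, axis=1)

# Each logaddexp step is itself stable, so the fold also handles large inputs.
big = np.logaddexp.reduce(np.array([1000.0, 1000.0]))

print(pairwise.shape, folded.shape)  # (2,) (2,)
```

Folding evaluates one logaddexp (an exp and a log) per element, whereas a dedicated logsumexp reduction can shift by the maximum once and take a single log at the end.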

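A naive implementation computes log(sum(exp(x))); however, such an implementation is not numerically stable: exp(x) overflows to infinity for large inputs and underflows to zero for very negative inputs, so the result becomes inf or -inf even when the true value is finite.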
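A common workaround relies on the identity log(sum(exp(x))) = m + log(sum(exp(x - m))), which holds for any finite m. Choosing m = max(x) makes every shifted exponent non-positive, so exp cannot overflow, and the largest term is exactly exp(0) = 1, so the sum cannot underflow to zero. A minimal sketch, assuming NumPy as the array library and using hypothetical helper names:

```python
import numpy as np


def naive_logsumexp(x):
    # Direct translation of log(sum(exp(x))); overflow/divide warnings are silenced.
    with np.errstate(over="ignore", divide="ignore"):
        return np.log(np.sum(np.exp(x)))


def shifted_logsumexp(x):
    # Shift by the maximum so that all exponents are <= 0 (assumes x is finite).
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))


large = np.array([1000.0, 1000.0, 999.0])
small = np.array([-1000.0, -1000.0])

print(naive_logsumexp(large))    # inf   (exp(1000.0) overflows float64)
print(naive_logsumexp(small))    # -inf  (exp(-1000.0) underflows to 0.0)
print(shifted_logsumexp(large))  # finite, approximately 1000.86
print(shifted_logsumexp(small))  # finite, approximately -999.31
```

This version assumes finite inputs: if max(x) is infinite, x - m can produce NaN (inf - inf), so a production implementation needs an extra guard for that case.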

Prior art

Proposal:

def logsumexp(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype: Optional[dtype] = None, keepdims: bool = False) -> array
  • The dtype keyword argument is included for consistency with sum and related reductions.
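To make the proposed signature concrete, here is a minimal reference sketch written against NumPy (an assumption; any array library with max, exp, sum, log, where, isfinite and squeeze would do). It is illustrative only, not the specification's or any library's actual implementation. Following sum, it casts the input to dtype when one is given, and it shifts slices whose maximum is infinite or NaN by 0 instead, so that inf - inf does not introduce a spurious NaN.

```python
from typing import Optional, Tuple, Union

import numpy as np
import numpy.typing as npt


def logsumexp(
    x: npt.ArrayLike,
    /,
    *,
    axis: Optional[Union[int, Tuple[int, ...]]] = None,
    dtype: npt.DTypeLike = None,
    keepdims: bool = False,
) -> np.ndarray:
    """Hypothetical sketch of the proposed logsumexp reduction (NumPy-backed)."""
    # Like `sum`, cast the input to `dtype` first when one is given.
    x = np.asarray(x, dtype=dtype)

    # Per-slice maximum with the reduced axes kept as size 1, so it broadcasts against x.
    m = np.max(x, axis=axis, keepdims=True)

    # Shift by 0 where the maximum is +/-inf or NaN, so that x - m cannot
    # turn inf - inf into NaN; such slices still yield inf, -inf or NaN.
    m = np.where(np.isfinite(m), m, np.zeros_like(m))

    # For finite m every shifted exponent is <= 0, so exp() cannot overflow.
    s = np.sum(np.exp(x - m), axis=axis, keepdims=True)

    # log(0) -> -inf for all -inf slices; silence the divide-by-zero warning.
    with np.errstate(divide="ignore"):
        out = np.log(s) + m

    if not keepdims:
        # Remove the reduced axes (every axis when axis is None).
        out = np.squeeze(out, axis=axis)
    return out
```

For example, logsumexp(np.array([[0.0, 1.0], [2.0, 3.0]]), axis=1) returns an array of shape (2,), while passing keepdims=True returns shape (2, 1). Empty reductions are not handled: np.max raises on a zero-size slice, whereas the mathematically natural result would be -inf.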

Related

cc @kgryte

steff456 added the API extension label Feb 14, 2023
kgryte added this to the v2023 milestone Jun 29, 2023
@kgryte
Contributor

kgryte commented Jan 11, 2024

logsumexp was also mentioned as a candidate for inclusion in the special functions extension (#725). Accordingly, before moving forward with this proposal, we should first determine whether it belongs in the main namespace or in that extension.

kgryte removed this from the v2023 milestone Jan 11, 2024
@rgommers
Member

I updated the issue description:

  • JAX exposes logsumexp in two places: jax.nn.logsumexp and jax.scipy.special.logsumexp.
  • PyTorch also added torch.special.logsumexp as an alias of torch.logsumexp.

So the path of least resistance is probably to add it to the special functions extension. However, it would be strange to have logaddexp in the main namespace and logsumexp in an optional extension: logsumexp is more commonly used than logaddexp, and the two are closely related.

kgryte added the RFC label Apr 4, 2024
kgryte changed the title from "Add logsumexp function to the standard" to "RFC: add logsumexp" Apr 4, 2024
github-project-automation bot moved this to Stage 1 in Proposals Aug 29, 2024
Projects
Status: Stage 1
Development

No branches or pull requests

3 participants