Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate scipy.special.lpmn & lpmn_values #25675

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
Expand Down
26 changes: 24 additions & 2 deletions jax/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
log_softmax as log_softmax,
logit as logit,
logsumexp as logsumexp,
lpmn as lpmn,
lpmn_values as lpmn_values,
lpmn as _deprecated_lpmn,
lpmn_values as _deprecated_lpmn_values,
multigammaln as multigammaln,
ndtr as ndtr,
ndtri as ndtri,
Expand All @@ -65,3 +65,25 @@
from jax._src.third_party.scipy.special import (
fresnel as fresnel,
)

_deprecations = {
# Added Nov 20 2024
"lpmn": (
"jax.scipy.special.lpmn is deprecated; no replacement is planned.",
_deprecated_lpmn,
),
"XlaRuntimeError": (
"jax.scipy.special.lpmn_values is deprecated; no replacement is planned.",
_deprecated_lpmn_values,
),
}

import typing as _typing
if _typing.TYPE_CHECKING:
lpmn = _deprecated_lpmn
lpmn_values = _deprecated_lpmn_values
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
2 changes: 2 additions & 0 deletions tests/lax_scipy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ def scipy_fun(z):
shape=[(5,), (10,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testLpmn(self, l_max, shape, dtype):
if jtu.is_device_tpu(6, "e"):
self.skipTest("TODO(b/364258243): fails on TPU v6e")
Expand All @@ -354,6 +355,7 @@ def scipy_fun(z, m=l_max, n=l_max):
shape=[(2,), (3,), (4,), (64,)],
dtype=float_dtypes,
)
@jtu.ignore_warning(category=DeprecationWarning, message=".*scipy.special.lpmn.*")
def testNormalizedLpmnValues(self, l_max, shape, dtype):
rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9)
args_maker = lambda: [rng(shape, dtype)]
Expand Down
Loading