diff --git a/docs/api/optimizer_wrappers.rst b/docs/api/optimizer_wrappers.rst index 2e534f193..7749fbd6c 100644 --- a/docs/api/optimizer_wrappers.rst +++ b/docs/api/optimizer_wrappers.rst @@ -12,8 +12,6 @@ Optimizer Wrappers LookaheadState masked MaskedState - maybe_update - MaybeUpdateState MultiSteps MultiStepsState ShouldSkipUpdateFunction diff --git a/optax/__init__.py b/optax/__init__.py index 0840b1d5a..c5f1d5494 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -140,8 +140,6 @@ from optax._src.utils import multi_normal from optax._src.utils import scale_gradient from optax._src.utils import value_and_grad_from_state -from optax._src.wrappers import maybe_update -from optax._src.wrappers import MaybeUpdateState # TODO(mtthss): remove tree_utils aliases after updates. adaptive_grad_clip = transforms.adaptive_grad_clip @@ -372,8 +370,6 @@ "MaskOrFn", "MaskedState", "matrix_inverse_pth_root", - "maybe_update", - "MaybeUpdateState", "multi_normal", "multi_transform", "MultiSteps", diff --git a/optax/_src/wrappers.py b/optax/_src/wrappers.py index d94334c23..d4c92cbc1 100644 --- a/optax/_src/wrappers.py +++ b/optax/_src/wrappers.py @@ -14,12 +14,6 @@ # ============================================================================== """Transformation wrappers.""" -from collections.abc import Callable -import functools - -import chex -import jax.numpy as jnp -from optax._src import base from optax.transforms import _accumulation from optax.transforms import _conditionality from optax.transforms import _layouts @@ -42,19 +36,3 @@ ShouldSkipUpdateFunction = _accumulation.ShouldSkipUpdateFunction skip_not_finite = _accumulation.skip_not_finite skip_large_updates = _accumulation.skip_large_updates - - -@functools.partial( - chex.warn_deprecated_function, - replacement='optax.transforms.maybe_transform', -) -def maybe_update( - inner: base.GradientTransformation, - should_update_fn: Callable[[jnp.ndarray], jnp.ndarray], -) -> base.GradientTransformationExtraArgs: - return conditionally_transform( - inner=inner, should_transform_fn=should_update_fn - ) - - -MaybeUpdateState = ConditionallyTransformState