Skip to content

Commit

Permalink
Remove longtime deprecated functions.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707025703
  • Loading branch information
mtthss authored and OptaxDev committed Dec 17, 2024
1 parent 78fe20e commit 73b716d
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 28 deletions.
2 changes: 0 additions & 2 deletions docs/api/optimizer_wrappers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ Optimizer Wrappers
LookaheadState
masked
MaskedState
maybe_update
MaybeUpdateState
MultiSteps
MultiStepsState
ShouldSkipUpdateFunction
Expand Down
4 changes: 0 additions & 4 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@
from optax._src.utils import multi_normal
from optax._src.utils import scale_gradient
from optax._src.utils import value_and_grad_from_state
from optax._src.wrappers import maybe_update
from optax._src.wrappers import MaybeUpdateState

# TODO(mtthss): remove tree_utils aliases after updates.
adaptive_grad_clip = transforms.adaptive_grad_clip
Expand Down Expand Up @@ -372,8 +370,6 @@
"MaskOrFn",
"MaskedState",
"matrix_inverse_pth_root",
"maybe_update",
"MaybeUpdateState",
"multi_normal",
"multi_transform",
"MultiSteps",
Expand Down
22 changes: 0 additions & 22 deletions optax/_src/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,6 @@
# ==============================================================================
"""Transformation wrappers."""

from collections.abc import Callable
import functools

import chex
import jax.numpy as jnp
from optax._src import base
from optax.transforms import _accumulation
from optax.transforms import _conditionality
from optax.transforms import _layouts
Expand All @@ -42,19 +36,3 @@
ShouldSkipUpdateFunction = _accumulation.ShouldSkipUpdateFunction
skip_not_finite = _accumulation.skip_not_finite
skip_large_updates = _accumulation.skip_large_updates


@functools.partial(
chex.warn_deprecated_function,
replacement='optax.transforms.maybe_transform',
)
def maybe_update(
inner: base.GradientTransformation,
should_update_fn: Callable[[jnp.ndarray], jnp.ndarray],
) -> base.GradientTransformationExtraArgs:
return conditionally_transform(
inner=inner, should_transform_fn=should_update_fn
)


MaybeUpdateState = ConditionallyTransformState

0 comments on commit 73b716d

Please sign in to comment.