Skip to content

Commit

Permalink
In flatname space import from the subpackage not from _src.
Browse files Browse the repository at this point in the history
This is step 1 to then remove the code in _src once no active ckpts depend on it.

PiperOrigin-RevId: 702292661
  • Loading branch information
mtthss authored and OptaxDev committed Dec 3, 2024
1 parent 02a1bd7 commit 5fd744a
Showing 1 changed file with 43 additions and 42 deletions.
85 changes: 43 additions & 42 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,6 @@
from optax._src.base import TransformUpdateFn
from optax._src.base import Updates
from optax._src.base import with_extra_args_support
from optax._src.clipping import adaptive_grad_clip
from optax._src.clipping import AdaptiveGradClipState
from optax._src.clipping import clip
from optax._src.clipping import clip_by_block_rms
from optax._src.clipping import clip_by_global_norm
from optax._src.clipping import ClipByGlobalNormState
from optax._src.clipping import ClipState
from optax._src.clipping import per_example_global_norm_clip
from optax._src.clipping import per_example_layer_norm_clip
from optax._src.combine import chain
from optax._src.combine import multi_transform
from optax._src.combine import MultiTransformState
from optax._src.combine import named_chain
from optax._src.constrain import keep_params_nonnegative
from optax._src.constrain import NonNegativeParamsState
from optax._src.constrain import zero_nans
from optax._src.constrain import ZeroNansState
from optax._src.factorized import FactoredState
from optax._src.factorized import scale_by_factored_rms
from optax._src.linear_algebra import global_norm
Expand All @@ -107,15 +90,9 @@
from optax._src.numerics import safe_int32_increment
from optax._src.numerics import safe_norm
from optax._src.numerics import safe_root_mean_squares
from optax._src.transform import add_decayed_weights
from optax._src.transform import add_noise
from optax._src.transform import AddDecayedWeightsState
from optax._src.transform import AddNoiseState
from optax._src.transform import apply_every
from optax._src.transform import ApplyEvery
from optax._src.transform import centralize
from optax._src.transform import ema
from optax._src.transform import EmaState
from optax._src.transform import normalize_by_update_norm
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
Expand Down Expand Up @@ -157,34 +134,58 @@
from optax._src.transform import ScaleByRStdDevState
from optax._src.transform import ScaleByScheduleState
from optax._src.transform import ScaleBySM3State
from optax._src.transform import ScaleByTrustRatioState
from optax._src.transform import ScaleState
from optax._src.transform import trace
from optax._src.transform import TraceState
from optax._src.update import apply_updates
from optax._src.update import incremental_update
from optax._src.update import periodic_update
from optax._src.utils import multi_normal
from optax._src.utils import scale_gradient
from optax._src.utils import value_and_grad_from_state
from optax._src.wrappers import apply_if_finite
from optax._src.wrappers import ApplyIfFiniteState
from optax._src.wrappers import conditionally_mask
from optax._src.wrappers import conditionally_transform
from optax._src.wrappers import ConditionallyMaskState
from optax._src.wrappers import ConditionallyTransformState
from optax._src.wrappers import flatten
from optax._src.wrappers import masked
from optax._src.wrappers import MaskedNode
from optax._src.wrappers import MaskedState
from optax._src.wrappers import maybe_update
from optax._src.wrappers import MaybeUpdateState
from optax._src.wrappers import MultiSteps
from optax._src.wrappers import MultiStepsState
from optax._src.wrappers import ShouldSkipUpdateFunction
from optax._src.wrappers import skip_large_updates
from optax._src.wrappers import skip_not_finite

# TODO(mtthss): remove tree_utils aliases after updates.
adaptive_grad_clip = transforms.adaptive_grad_clip
AdaptiveGradClipState = EmptyState
clip = transforms.clip
clip_by_block_rms = transforms.clip_by_block_rms
clip_by_global_norm = transforms.clip_by_global_norm
ClipByGlobalNormState = EmptyState
ClipState = EmptyState
per_example_global_norm_clip = transforms.per_example_global_norm_clip
per_example_layer_norm_clip = transforms.per_example_layer_norm_clip
keep_params_nonnegative = transforms.keep_params_nonnegative
NonNegativeParamsState = transforms.NonNegativeParamsState
zero_nans = transforms.zero_nans
ZeroNansState = transforms.ZeroNansState
chain = transforms.chain
multi_transform = transforms.partition
MultiTransformState = transforms.PartitionState
named_chain = transforms.named_chain
trace = transforms.trace
TraceState = transforms.TraceState
ema = transforms.ema
EmaState = transforms.EmaState
add_noise = transforms.add_noise
AddNoiseState = transforms.AddNoiseState
add_decayed_weights = transforms.add_decayed_weights
AddDecayedWeightsState = EmptyState
ScaleByTrustRatioState = EmptyState
ScaleState = EmptyState
apply_if_finite = transforms.apply_if_finite
ApplyIfFiniteState = transforms.ApplyIfFiniteState
conditionally_mask = transforms.conditionally_mask
conditionally_transform = transforms.conditionally_transform
ConditionallyMaskState = transforms.ConditionallyMaskState
ConditionallyTransformState = transforms.ConditionallyTransformState
flatten = transforms.flatten
masked = transforms.masked
MaskedNode = transforms.MaskedNode
MaskedState = transforms.MaskedState
MultiSteps = transforms.MultiSteps
MultiStepsState = transforms.MultiStepsState
ShouldSkipUpdateFunction = transforms.ShouldSkipUpdateFunction
skip_large_updates = transforms.skip_large_updates
skip_not_finite = transforms.skip_not_finite

# TODO(mtthss): remove tree_utils aliases after updates.
tree_map_params = tree_utils.tree_map_params
Expand Down

0 comments on commit 5fd744a

Please sign in to comment.