diff --git a/optax/__init__.py b/optax/__init__.py index 1f3c5456e..0840b1d5a 100644 --- a/optax/__init__.py +++ b/optax/__init__.py @@ -73,23 +73,6 @@ from optax._src.base import TransformUpdateFn from optax._src.base import Updates from optax._src.base import with_extra_args_support -from optax._src.clipping import adaptive_grad_clip -from optax._src.clipping import AdaptiveGradClipState -from optax._src.clipping import clip -from optax._src.clipping import clip_by_block_rms -from optax._src.clipping import clip_by_global_norm -from optax._src.clipping import ClipByGlobalNormState -from optax._src.clipping import ClipState -from optax._src.clipping import per_example_global_norm_clip -from optax._src.clipping import per_example_layer_norm_clip -from optax._src.combine import chain -from optax._src.combine import multi_transform -from optax._src.combine import MultiTransformState -from optax._src.combine import named_chain -from optax._src.constrain import keep_params_nonnegative -from optax._src.constrain import NonNegativeParamsState -from optax._src.constrain import zero_nans -from optax._src.constrain import ZeroNansState from optax._src.factorized import FactoredState from optax._src.factorized import scale_by_factored_rms from optax._src.linear_algebra import global_norm @@ -107,15 +90,9 @@ from optax._src.numerics import safe_int32_increment from optax._src.numerics import safe_norm from optax._src.numerics import safe_root_mean_squares -from optax._src.transform import add_decayed_weights -from optax._src.transform import add_noise -from optax._src.transform import AddDecayedWeightsState -from optax._src.transform import AddNoiseState from optax._src.transform import apply_every from optax._src.transform import ApplyEvery from optax._src.transform import centralize -from optax._src.transform import ema -from optax._src.transform import EmaState from optax._src.transform import normalize_by_update_norm from optax._src.transform import scale from optax._src.transform import scale_by_adadelta @@ -157,34 +134,58 @@ from optax._src.transform import ScaleByRStdDevState from optax._src.transform import ScaleByScheduleState from optax._src.transform import ScaleBySM3State -from optax._src.transform import ScaleByTrustRatioState -from optax._src.transform import ScaleState -from optax._src.transform import trace -from optax._src.transform import TraceState from optax._src.update import apply_updates from optax._src.update import incremental_update from optax._src.update import periodic_update from optax._src.utils import multi_normal from optax._src.utils import scale_gradient from optax._src.utils import value_and_grad_from_state -from optax._src.wrappers import apply_if_finite -from optax._src.wrappers import ApplyIfFiniteState -from optax._src.wrappers import conditionally_mask -from optax._src.wrappers import conditionally_transform -from optax._src.wrappers import ConditionallyMaskState -from optax._src.wrappers import ConditionallyTransformState -from optax._src.wrappers import flatten -from optax._src.wrappers import masked -from optax._src.wrappers import MaskedNode -from optax._src.wrappers import MaskedState from optax._src.wrappers import maybe_update from optax._src.wrappers import MaybeUpdateState -from optax._src.wrappers import MultiSteps -from optax._src.wrappers import MultiStepsState -from optax._src.wrappers import ShouldSkipUpdateFunction -from optax._src.wrappers import skip_large_updates -from optax._src.wrappers import skip_not_finite +# TODO(mtthss): remove tree_utils aliases after updates. +adaptive_grad_clip = transforms.adaptive_grad_clip +AdaptiveGradClipState = EmptyState +clip = transforms.clip +clip_by_block_rms = transforms.clip_by_block_rms +clip_by_global_norm = transforms.clip_by_global_norm +ClipByGlobalNormState = EmptyState +ClipState = EmptyState +per_example_global_norm_clip = transforms.per_example_global_norm_clip +per_example_layer_norm_clip = transforms.per_example_layer_norm_clip +keep_params_nonnegative = transforms.keep_params_nonnegative +NonNegativeParamsState = transforms.NonNegativeParamsState +zero_nans = transforms.zero_nans +ZeroNansState = transforms.ZeroNansState +chain = transforms.chain +multi_transform = transforms.partition +MultiTransformState = transforms.PartitionState +named_chain = transforms.named_chain +trace = transforms.trace +TraceState = transforms.TraceState +ema = transforms.ema +EmaState = transforms.EmaState +add_noise = transforms.add_noise +AddNoiseState = transforms.AddNoiseState +add_decayed_weights = transforms.add_decayed_weights +AddDecayedWeightsState = EmptyState +ScaleByTrustRatioState = EmptyState +ScaleState = EmptyState +apply_if_finite = transforms.apply_if_finite +ApplyIfFiniteState = transforms.ApplyIfFiniteState +conditionally_mask = transforms.conditionally_mask +conditionally_transform = transforms.conditionally_transform +ConditionallyMaskState = transforms.ConditionallyMaskState +ConditionallyTransformState = transforms.ConditionallyTransformState +flatten = transforms.flatten +masked = transforms.masked +MaskedNode = transforms.MaskedNode +MaskedState = transforms.MaskedState +MultiSteps = transforms.MultiSteps +MultiStepsState = transforms.MultiStepsState +ShouldSkipUpdateFunction = transforms.ShouldSkipUpdateFunction +skip_large_updates = transforms.skip_large_updates +skip_not_finite = transforms.skip_not_finite # TODO(mtthss): remove tree_utils aliases after updates. tree_map_params = tree_utils.tree_map_params