Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In flatname space import from the subpackage not from _src. #1147

Merged
merged 1 commit into from
Dec 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 43 additions & 42 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,23 +73,6 @@
from optax._src.base import TransformUpdateFn
from optax._src.base import Updates
from optax._src.base import with_extra_args_support
from optax._src.clipping import adaptive_grad_clip
from optax._src.clipping import AdaptiveGradClipState
from optax._src.clipping import clip
from optax._src.clipping import clip_by_block_rms
from optax._src.clipping import clip_by_global_norm
from optax._src.clipping import ClipByGlobalNormState
from optax._src.clipping import ClipState
from optax._src.clipping import per_example_global_norm_clip
from optax._src.clipping import per_example_layer_norm_clip
from optax._src.combine import chain
from optax._src.combine import multi_transform
from optax._src.combine import MultiTransformState
from optax._src.combine import named_chain
from optax._src.constrain import keep_params_nonnegative
from optax._src.constrain import NonNegativeParamsState
from optax._src.constrain import zero_nans
from optax._src.constrain import ZeroNansState
from optax._src.factorized import FactoredState
from optax._src.factorized import scale_by_factored_rms
from optax._src.linear_algebra import global_norm
Expand All @@ -107,15 +90,9 @@
from optax._src.numerics import safe_int32_increment
from optax._src.numerics import safe_norm
from optax._src.numerics import safe_root_mean_squares
from optax._src.transform import add_decayed_weights
from optax._src.transform import add_noise
from optax._src.transform import AddDecayedWeightsState
from optax._src.transform import AddNoiseState
from optax._src.transform import apply_every
from optax._src.transform import ApplyEvery
from optax._src.transform import centralize
from optax._src.transform import ema
from optax._src.transform import EmaState
from optax._src.transform import normalize_by_update_norm
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
Expand Down Expand Up @@ -157,34 +134,58 @@
from optax._src.transform import ScaleByRStdDevState
from optax._src.transform import ScaleByScheduleState
from optax._src.transform import ScaleBySM3State
from optax._src.transform import ScaleByTrustRatioState
from optax._src.transform import ScaleState
from optax._src.transform import trace
from optax._src.transform import TraceState
from optax._src.update import apply_updates
from optax._src.update import incremental_update
from optax._src.update import periodic_update
from optax._src.utils import multi_normal
from optax._src.utils import scale_gradient
from optax._src.utils import value_and_grad_from_state
from optax._src.wrappers import apply_if_finite
from optax._src.wrappers import ApplyIfFiniteState
from optax._src.wrappers import conditionally_mask
from optax._src.wrappers import conditionally_transform
from optax._src.wrappers import ConditionallyMaskState
from optax._src.wrappers import ConditionallyTransformState
from optax._src.wrappers import flatten
from optax._src.wrappers import masked
from optax._src.wrappers import MaskedNode
from optax._src.wrappers import MaskedState
from optax._src.wrappers import maybe_update
from optax._src.wrappers import MaybeUpdateState
from optax._src.wrappers import MultiSteps
from optax._src.wrappers import MultiStepsState
from optax._src.wrappers import ShouldSkipUpdateFunction
from optax._src.wrappers import skip_large_updates
from optax._src.wrappers import skip_not_finite

# TODO(mtthss): remove tree_utils aliases after updates.
adaptive_grad_clip = transforms.adaptive_grad_clip
AdaptiveGradClipState = EmptyState
clip = transforms.clip
clip_by_block_rms = transforms.clip_by_block_rms
clip_by_global_norm = transforms.clip_by_global_norm
ClipByGlobalNormState = EmptyState
ClipState = EmptyState
per_example_global_norm_clip = transforms.per_example_global_norm_clip
per_example_layer_norm_clip = transforms.per_example_layer_norm_clip
keep_params_nonnegative = transforms.keep_params_nonnegative
NonNegativeParamsState = transforms.NonNegativeParamsState
zero_nans = transforms.zero_nans
ZeroNansState = transforms.ZeroNansState
chain = transforms.chain
multi_transform = transforms.partition
MultiTransformState = transforms.PartitionState
named_chain = transforms.named_chain
trace = transforms.trace
TraceState = transforms.TraceState
ema = transforms.ema
EmaState = transforms.EmaState
add_noise = transforms.add_noise
AddNoiseState = transforms.AddNoiseState
add_decayed_weights = transforms.add_decayed_weights
AddDecayedWeightsState = EmptyState
ScaleByTrustRatioState = EmptyState
ScaleState = EmptyState
apply_if_finite = transforms.apply_if_finite
ApplyIfFiniteState = transforms.ApplyIfFiniteState
conditionally_mask = transforms.conditionally_mask
conditionally_transform = transforms.conditionally_transform
ConditionallyMaskState = transforms.ConditionallyMaskState
ConditionallyTransformState = transforms.ConditionallyTransformState
flatten = transforms.flatten
masked = transforms.masked
MaskedNode = transforms.MaskedNode
MaskedState = transforms.MaskedState
MultiSteps = transforms.MultiSteps
MultiStepsState = transforms.MultiStepsState
ShouldSkipUpdateFunction = transforms.ShouldSkipUpdateFunction
skip_large_updates = transforms.skip_large_updates
skip_not_finite = transforms.skip_not_finite

# TODO(mtthss): remove tree_utils aliases after updates.
tree_map_params = tree_utils.tree_map_params
Expand Down
Loading