Skip to content

Commit

Permalink
Expose components in sub-package
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 638188465
  • Loading branch information
mtthss authored and OptaxDev committed May 29, 2024
1 parent 36ee9f4 commit 99587dc
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from optax import projections
from optax import schedules
from optax import second_order
from optax import transforms
from optax import tree_utils
from optax._src.alias import adabelief
from optax._src.alias import adadelta
Expand Down
3 changes: 3 additions & 0 deletions optax/optax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,17 @@
"""Tests for optax."""

from absl.testing import absltest

import optax
from optax import transforms


class OptaxTest(absltest.TestCase):
"""Test optax can be imported correctly."""

def test_import(self):
self.assertTrue(hasattr(optax, 'GradientTransformation'))
self.assertTrue(hasattr(transforms, 'partition'))


if __name__ == '__main__':
Expand Down
56 changes: 56 additions & 0 deletions optax/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,70 @@
from optax.transforms._accumulation import skip_not_finite
from optax.transforms._accumulation import trace
from optax.transforms._accumulation import TraceState
from optax.transforms._adding import add_decayed_weights
from optax.transforms._adding import add_noise
from optax.transforms._adding import AddNoiseState
from optax.transforms._clipping import adaptive_grad_clip
from optax.transforms._clipping import clip
from optax.transforms._clipping import clip_by_block_rms
from optax.transforms._clipping import clip_by_global_norm
from optax.transforms._clipping import per_example_global_norm_clip
from optax.transforms._clipping import per_example_layer_norm_clip
from optax.transforms._clipping import unitwise_clip
from optax.transforms._clipping import unitwise_norm
from optax.transforms._combining import chain
from optax.transforms._combining import named_chain
from optax.transforms._combining import partition
from optax.transforms._combining import PartitionState
from optax.transforms._conditionality import apply_if_finite
from optax.transforms._conditionality import ApplyIfFiniteState
from optax.transforms._conditionality import conditionally_mask
from optax.transforms._conditionality import conditionally_transform
from optax.transforms._conditionality import ConditionallyMaskState
from optax.transforms._conditionality import ConditionallyTransformState
from optax.transforms._conditionality import ConditionFn
from optax.transforms._constraining import keep_params_nonnegative
from optax.transforms._constraining import NonNegativeParamsState
from optax.transforms._constraining import zero_nans
from optax.transforms._constraining import ZeroNansState
from optax.transforms._layouts import flatten
from optax.transforms._masking import masked
from optax.transforms._masking import MaskedNode
from optax.transforms._masking import MaskedState


__all__ = (
"adaptive_grad_clip",
"add_decayed_weights",
"add_noise",
"AddNoiseState",
"apply_if_finite",
"ApplyIfFiniteState",
"chain",
"clip_by_block_rms",
"clip_by_global_norm",
"clip",
"conditionally_mask",
"ConditionallyMaskState",
"conditionally_transform",
"ConditionallyTransformState",
"ema",
"EmaState",
"flatten",
"keep_params_nonnegative",
"masked",
"MaskedState",
"MultiSteps",
"MultiStepsState",
"named_chain",
"NonNegativeParamsState",
"partition",
"PartitionState",
"ShouldSkipUpdateFunction",
"skip_large_updates",
"skip_not_finite",
"trace",
"TraceState",
"zero_nans",
"ZeroNansState",
)

0 comments on commit 99587dc

Please sign in to comment.