Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose components in sub-package #978

Merged
merged 1 commit into from
May 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions optax/__init__.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
from optax import projections
from optax import schedules
from optax import second_order
from optax import transforms
from optax import tree_utils
from optax._src.alias import adabelief
from optax._src.alias import adadelta
3 changes: 3 additions & 0 deletions optax/optax_test.py
Original file line number Diff line number Diff line change
@@ -15,14 +15,17 @@
"""Tests for optax."""

from absl.testing import absltest

import optax
from optax import transforms


class OptaxTest(absltest.TestCase):
"""Test optax can be imported correctly."""

def test_import(self):
self.assertTrue(hasattr(optax, 'GradientTransformation'))
self.assertTrue(hasattr(transforms, 'partition'))


if __name__ == '__main__':
56 changes: 56 additions & 0 deletions optax/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -23,14 +23,70 @@
from optax.transforms._accumulation import skip_not_finite
from optax.transforms._accumulation import trace
from optax.transforms._accumulation import TraceState
from optax.transforms._adding import add_decayed_weights
from optax.transforms._adding import add_noise
from optax.transforms._adding import AddNoiseState
from optax.transforms._clipping import adaptive_grad_clip
from optax.transforms._clipping import clip
from optax.transforms._clipping import clip_by_block_rms
from optax.transforms._clipping import clip_by_global_norm
from optax.transforms._clipping import per_example_global_norm_clip
from optax.transforms._clipping import per_example_layer_norm_clip
from optax.transforms._clipping import unitwise_clip
from optax.transforms._clipping import unitwise_norm
from optax.transforms._combining import chain
from optax.transforms._combining import named_chain
from optax.transforms._combining import partition
from optax.transforms._combining import PartitionState
from optax.transforms._conditionality import apply_if_finite
from optax.transforms._conditionality import ApplyIfFiniteState
from optax.transforms._conditionality import conditionally_mask
from optax.transforms._conditionality import conditionally_transform
from optax.transforms._conditionality import ConditionallyMaskState
from optax.transforms._conditionality import ConditionallyTransformState
from optax.transforms._conditionality import ConditionFn
from optax.transforms._constraining import keep_params_nonnegative
from optax.transforms._constraining import NonNegativeParamsState
from optax.transforms._constraining import zero_nans
from optax.transforms._constraining import ZeroNansState
from optax.transforms._layouts import flatten
from optax.transforms._masking import masked
from optax.transforms._masking import MaskedNode
from optax.transforms._masking import MaskedState


__all__ = (
"adaptive_grad_clip",
"add_decayed_weights",
"add_noise",
"AddNoiseState",
"apply_if_finite",
"ApplyIfFiniteState",
"chain",
"clip_by_block_rms",
"clip_by_global_norm",
"clip",
"conditionally_mask",
"ConditionallyMaskState",
"conditionally_transform",
"ConditionallyTransformState",
"ema",
"EmaState",
"flatten",
"keep_params_nonnegative",
"masked",
"MaskedState",
"MultiSteps",
"MultiStepsState",
"named_chain",
"NonNegativeParamsState",
"partition",
"PartitionState",
"ShouldSkipUpdateFunction",
"skip_large_updates",
"skip_not_finite",
"trace",
"TraceState",
"zero_nans",
"ZeroNansState",
)