diff --git a/docs/api/combining_optimizers.rst b/docs/api/combining_optimizers.rst
index 1e940530..43c7f281 100644
--- a/docs/api/combining_optimizers.rst
+++ b/docs/api/combining_optimizers.rst
@@ -6,14 +6,14 @@ Combining Optimizers
 .. autosummary::
     chain
     named_chain
-    multi_transform
+    partition
 
 Chain
 ~~~~~
 .. autofunction:: chain
 .. autofunction:: named_chain
 
-Multi-transform
-~~~~~~~~~~~~~~~
-.. autofunction:: multi_transform
-.. autoclass:: MultiTransformState
+Partition
+~~~~~~~~~
+.. autofunction:: partition
+.. autoclass:: PartitionState
diff --git a/optax/__init__.py b/optax/__init__.py
index 246038bf..d88b698e 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -166,8 +166,10 @@
 zero_nans = transforms.zero_nans
 ZeroNansState = transforms.ZeroNansState
 chain = transforms.chain
-multi_transform = transforms.partition
-MultiTransformState = transforms.PartitionState
+partition = transforms.partition
+PartitionState = transforms.PartitionState
+multi_transform = transforms.partition  # for backwards compatibility
+MultiTransformState = transforms.PartitionState  # for backwards compatibility
 named_chain = transforms.named_chain
 trace = transforms.trace
 TraceState = transforms.TraceState
@@ -373,10 +375,10 @@
     "MaskedState",
     "matrix_inverse_pth_root",
     "multi_normal",
-    "multi_transform",
+    "multi_transform",  # for backwards compatibility
     "MultiSteps",
     "MultiStepsState",
-    "MultiTransformState",
+    "MultiTransformState",  # for backwards compatibility
     "nadam",
     "nadamw",
     "nnls",
@@ -386,6 +388,8 @@
     "ntxent",
     "OptState",
     "Params",
+    "partition",
+    "PartitionState",
     "periodic_update",
     "per_example_global_norm_clip",
     "per_example_layer_norm_clip",
diff --git a/optax/_src/combine.py b/optax/_src/combine.py
index a7709b46..a89f2f95 100644
--- a/optax/_src/combine.py
+++ b/optax/_src/combine.py
@@ -18,5 +18,6 @@
 chain = _combining.chain
 named_chain = _combining.named_chain
-multi_transform = _combining.partition
+partition = _combining.partition
+multi_transform = _combining.partition  # for backwards compatibility
 MultiTransformState = _combining.PartitionState
diff --git a/optax/_src/combine_test.py b/optax/_src/combine_test.py
index 09c25851..c24a69da 100644
--- a/optax/_src/combine_test.py
+++ b/optax/_src/combine_test.py
@@ -151,12 +151,12 @@ def update_fn_kwargs(updates, state, params=None, **extra_args):
     opt.update(params, state, params=params, ignored_kwarg='hi')
 
 
-class MultiTransformTest(chex.TestCase):
-  """Tests for the multi_transform wrapper."""
+class PartitionTest(chex.TestCase):
+  """Tests for the partition wrapper."""
 
   @chex.all_variants
   @parameterized.parameters(True, False)
-  def test_multi_transform(self, use_fn):
+  def test_partition(self, use_fn):
     params = {'a1': 1.0, 'b1': 2.0, 'z1': {'a2': 3.0, 'z2': {'c1': 4.0}}}
     params = jax.tree.map(jnp.asarray, params)
     input_updates = jax.tree.map(lambda x: x / 10.0, params)
@@ -168,7 +168,7 @@
     param_labels = _map_keys_fn(lambda k, _: k[0])
     if not use_fn:
       param_labels = param_labels(params)
-    tx = combine.multi_transform(tx_dict, param_labels)
+    tx = combine.partition(tx_dict, param_labels)
     update_fn = self.variant(tx.update)
     state = self.variant(tx.init)(params)
@@ -206,7 +206,7 @@ def update_without_arg(updates, state, params=None):
     opt_no_arg = base.GradientTransformation(init, update_without_arg)
     opt_extra_arg = base.GradientTransformationExtraArgs(init, update_with_arg)
-    opt = combine.multi_transform(
+    opt = combine.partition(
         {
             'a': opt_no_arg,
            'b': opt_extra_arg,
@@ -225,7 +225,7 @@ def update_without_arg(updates, state, params=None):
 
   @parameterized.parameters(list, tuple, dict)
   def test_empty(self, container):
-    init_fn, update_fn = combine.multi_transform(
+    init_fn, update_fn = combine.partition(
         {0: alias.sgd(1.0)}, lambda _: 0
     )
     updates, _ = update_fn(container(), init_fn(container()))
@@ -249,7 +249,7 @@ def test_labels_mismatch(self, use_extra_label, use_fn):
         1: alias.adam(1.0, b1=0.0, b2=0.0),
         2: transform.trace(1.0),
     }
-    init_fn, update_fn = combine.multi_transform(
+    init_fn, update_fn = combine.partition(
         transforms, (lambda _: label_tree) if use_fn else label_tree
     )
diff --git a/optax/contrib/_muon.py b/optax/contrib/_muon.py
index d5c6a05d..ef3cf947 100644
--- a/optax/contrib/_muon.py
+++ b/optax/contrib/_muon.py
@@ -250,7 +250,7 @@ def muon(
     Bernstein et al., `Old Optimizer, New Norm: An Anthology `_, 2024
   """
-  return combine.multi_transform(
+  return combine.partition(
       transforms={
           'muon': combine.chain(
              scale_by_muon(
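
For reference, a minimal usage sketch of the renamed API after this change. The parameter names, shapes, and learning rates below are illustrative and not taken from the diff; the old `multi_transform` name keeps working through the backwards-compatibility aliases added in `optax/__init__.py`.

# Minimal sketch: apply a different transform to each labelled group of params.
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros((2,))}
param_labels = {'w': 'weights', 'b': 'biases'}  # one label per leaf of params

tx = optax.partition(
    {'weights': optax.adam(1e-3), 'biases': optax.sgd(1e-2)}, param_labels
)
state = tx.init(params)

grads = {'w': jnp.full((3, 2), 0.1), 'b': jnp.full((2,), 0.1)}
updates, state = tx.update(grads, state, params)
new_params = optax.apply_updates(params, updates)

# The old name remains available as an alias of the new one.
assert optax.multi_transform is optax.partition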