Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate multi_transform in favor of partition #1216

Merged
merged 1 commit into from
Mar 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/api/combining_optimizers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ Combining Optimizers
.. autosummary::
chain
named_chain
multi_transform
partition

Chain
~~~~~
.. autofunction:: chain
.. autofunction:: named_chain

Multi-transform
~~~~~~~~~~~~~~~
.. autofunction:: multi_transform
.. autoclass:: MultiTransformState
Partition
~~~~~~~~~
.. autofunction:: partition
.. autoclass:: PartitionmState
12 changes: 8 additions & 4 deletions optax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,10 @@
zero_nans = transforms.zero_nans
ZeroNansState = transforms.ZeroNansState
chain = transforms.chain
multi_transform = transforms.partition
MultiTransformState = transforms.PartitionState
partition = transforms.partition
PartitionState = transforms.PartitionState
multi_transform = transforms.partition # for backwards compatibility
MultiTransformState = transforms.PartitionState # for backwards compatibility
named_chain = transforms.named_chain
trace = transforms.trace
TraceState = transforms.TraceState
Expand Down Expand Up @@ -373,10 +375,10 @@
"MaskedState",
"matrix_inverse_pth_root",
"multi_normal",
"multi_transform",
"multi_transform", # for backwards compatibility
"MultiSteps",
"MultiStepsState",
"MultiTransformState",
"MultiTransformState", # for backwards compatibility
"nadam",
"nadamw",
"nnls",
Expand All @@ -386,6 +388,8 @@
"ntxent",
"OptState",
"Params",
"partition",
"PartitionState",
"periodic_update",
"per_example_global_norm_clip",
"per_example_layer_norm_clip",
Expand Down
3 changes: 2 additions & 1 deletion optax/_src/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@

chain = _combining.chain
named_chain = _combining.named_chain
multi_transform = _combining.partition
partition = _combining.partition
multi_transform = _combining.partition # for backwards compatibility
MultiTransformState = _combining.PartitionState
14 changes: 7 additions & 7 deletions optax/_src/combine_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ def update_fn_kwargs(updates, state, params=None, **extra_args):
opt.update(params, state, params=params, ignored_kwarg='hi')


class MultiTransformTest(chex.TestCase):
"""Tests for the multi_transform wrapper."""
class PartitionTest(chex.TestCase):
"""Tests for the partition wrapper."""

@chex.all_variants
@parameterized.parameters(True, False)
def test_multi_transform(self, use_fn):
def test_partition(self, use_fn):
params = {'a1': 1.0, 'b1': 2.0, 'z1': {'a2': 3.0, 'z2': {'c1': 4.0}}}
params = jax.tree.map(jnp.asarray, params)
input_updates = jax.tree.map(lambda x: x / 10.0, params)
Expand All @@ -168,7 +168,7 @@ def test_multi_transform(self, use_fn):
param_labels = _map_keys_fn(lambda k, _: k[0])
if not use_fn:
param_labels = param_labels(params)
tx = combine.multi_transform(tx_dict, param_labels)
tx = combine.partition(tx_dict, param_labels)
update_fn = self.variant(tx.update)
state = self.variant(tx.init)(params)

Expand Down Expand Up @@ -206,7 +206,7 @@ def update_without_arg(updates, state, params=None):
opt_no_arg = base.GradientTransformation(init, update_without_arg)
opt_extra_arg = base.GradientTransformationExtraArgs(init, update_with_arg)

opt = combine.multi_transform(
opt = combine.partition(
{
'a': opt_no_arg,
'b': opt_extra_arg,
Expand All @@ -225,7 +225,7 @@ def update_without_arg(updates, state, params=None):

@parameterized.parameters(list, tuple, dict)
def test_empty(self, container):
init_fn, update_fn = combine.multi_transform(
init_fn, update_fn = combine.partition(
{0: alias.sgd(1.0)}, lambda _: 0
)
updates, _ = update_fn(container(), init_fn(container()))
Expand All @@ -249,7 +249,7 @@ def test_labels_mismatch(self, use_extra_label, use_fn):
1: alias.adam(1.0, b1=0.0, b2=0.0),
2: transform.trace(1.0),
}
init_fn, update_fn = combine.multi_transform(
init_fn, update_fn = combine.partition(
transforms, (lambda _: label_tree) if use_fn else label_tree
)

Expand Down
2 changes: 1 addition & 1 deletion optax/contrib/_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def muon(
Bernstein et al., `Old Optimizer, New Norm: An Anthology
<https://arxiv.org/abs/2409.20325>`_, 2024
"""
return combine.multi_transform(
return combine.partition(
transforms={
'muon': combine.chain(
scale_by_muon(
Expand Down