From e5a2b4e7c35a5c78107cea86c3cebf02a80b52d6 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Wed, 13 Mar 2024 15:47:53 +0000 Subject: [PATCH] [nnx] remove flagslib --- flax/experimental/nnx/__init__.py | 1 - flax/experimental/nnx/examples/lm1b/models.py | 16 ++-- .../nnx/examples/lm1b/models_test.py | 8 +- flax/experimental/nnx/examples/lm1b/train.py | 24 +++--- .../toy_examples/06_scan_over_layers.py | 4 +- flax/experimental/nnx/nnx/flaglib.py | 79 ------------------- flax/experimental/nnx/nnx/module.py | 58 ++++++++++++++ flax/experimental/nnx/nnx/nn/attention.py | 3 - flax/experimental/nnx/nnx/nn/normalization.py | 3 +- flax/experimental/nnx/nnx/nn/stochastic.py | 3 +- .../nnx/tests/nn/test_attention.py | 7 +- .../nnx/tests/test_integration.py | 20 +++-- flax/experimental/nnx/tests/test_module.py | 28 ++++++- .../experimental/nnx/tests/test_transforms.py | 12 +-- 14 files changed, 137 insertions(+), 129 deletions(-) delete mode 100644 flax/experimental/nnx/nnx/flaglib.py diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 2c07fba260..b9a475c0f4 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -23,7 +23,6 @@ from .nnx.errors import TraceContextError as TraceContextError from .nnx.filterlib import All as All from .nnx.filterlib import Not as Not -from .nnx.flaglib import flags as flags from .nnx.graph_utils import GraphDef as GraphDef from .nnx.helpers import Dict as Dict from .nnx.helpers import Sequence as Sequence diff --git a/flax/experimental/nnx/examples/lm1b/models.py b/flax/experimental/nnx/examples/lm1b/models.py index 0d262f6107..58274de4eb 100644 --- a/flax/experimental/nnx/examples/lm1b/models.py +++ b/flax/experimental/nnx/examples/lm1b/models.py @@ -33,7 +33,6 @@ from jax import lax from flax.experimental import nnx -from flax.experimental.nnx.nnx import flaglib from flax.experimental.nnx.examples.lm1b.configs import default Shape = tuple[int, ...] @@ -126,9 +125,11 @@ def __init__( self, config: TransformerConfig, *, + decode: bool = False, rngs: nnx.Rngs, ): self.config = config + self.decode = decode self.pos_emb_shape = (1, config.max_len, config.emb_dim) if config.posemb_init is not None: @@ -168,7 +169,7 @@ def __call__(self, inputs: jax.Array, inputs_positions=None): pos_embedding = self.pos_embedding.value # We use a cache position index for tracking decoding position. - if flaglib.flags.get('decode', False): + if self.decode: _, _, df = pos_embedding.shape # equivalent to pos_embedding[:, i:i+1] but traceable pos_embedding = lax.dynamic_slice( @@ -336,9 +337,11 @@ def __init__( config: TransformerConfig, shared_embedding: nnx.Embed | None = None, *, + decode: bool = False, rngs: nnx.Rngs, ): self.config = config + self.decode = decode self.shared_embedding = shared_embedding # Target Embedding @@ -413,7 +416,7 @@ def __call__( assert inputs.ndim == 2 # (batch, len) y = inputs.astype('int32') - if not flaglib.flags.get('decode', False): + if not self.decode: y = shift_inputs(y, segment_ids=inputs_segmentation) y = self.output_embed(y) y = self.posembed_output(y, inputs_positions=inputs_positions) @@ -450,8 +453,11 @@ class TransformerLM(nnx.Module): config: TransformerConfig dataclass containing hyperparameters. 
""" - def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs): + def __init__( + self, config: TransformerConfig, *, decode: bool = False, rngs: nnx.Rngs + ): self.config = config + self.decode = decode self.decoder = Decoder(config=config, shared_embedding=None, rngs=rngs) def __call__( @@ -475,7 +481,7 @@ def __call__( config = self.config # Make padding attention masks. - if flaglib.flags.get('decode', False): + if self.decode: # for fast autoregressive decoding we use no decoder mask decoder_mask = None else: diff --git a/flax/experimental/nnx/examples/lm1b/models_test.py b/flax/experimental/nnx/examples/lm1b/models_test.py index 007c6da206..d8beda34ec 100644 --- a/flax/experimental/nnx/examples/lm1b/models_test.py +++ b/flax/experimental/nnx/examples/lm1b/models_test.py @@ -217,8 +217,8 @@ def test_forward_eval(self): self.transfer_params(config, params_nnx, params_linen) model_nnx.update(params_nnx) - with nnx.flags(deterministic=True, decode=False): - output_nnx = model_nnx(sample_inputs) + model_nnx.set_attributes(deterministic=True, decode=False) + output_nnx = model_nnx(sample_inputs) output_linen: jax.Array = model_linen.apply( {'params': params_linen}, sample_inputs @@ -263,13 +263,13 @@ def test_forward_decode(self): self.transfer_params(config, params_nnx, params_linen) self.transfer_cache(config, cache_nnx, cache_linen) model_nnx.update(params_nnx, cache_nnx) + model_nnx.set_attributes(deterministic=True, decode=True) outputs_nnx = [] outputs_linen = [] for inputs in ar_decode_inputs: - with nnx.flags(deterministic=True, decode=True): - output_nnx = model_nnx(inputs) + output_nnx = model_nnx(inputs) outputs_nnx.append(output_nnx) output_linen: jax.Array diff --git a/flax/experimental/nnx/examples/lm1b/train.py b/flax/experimental/nnx/examples/lm1b/train.py index 13eb172982..a8a2c27487 100644 --- a/flax/experimental/nnx/examples/lm1b/train.py +++ b/flax/experimental/nnx/examples/lm1b/train.py @@ -194,13 +194,13 @@ def train_step( def loss_fn(params): """loss function used for training.""" module = state.graphdef.merge(params) - with nnx.flags(deterministic=False, decode=False): - logits = module( - inputs, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation, - rngs=nnx.Rngs(dropout=dropout_rng), - ) + module.set_attributes(deterministic=False, decode=False) + logits = module( + inputs, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + rngs=nnx.Rngs(dropout=dropout_rng), + ) loss, weight_sum = compute_weighted_cross_entropy( logits, inputs, weights, label_smoothing @@ -229,8 +229,8 @@ def eval_step( inputs = batch['inputs'] weights = jnp.where(inputs > 0, 1.0, 0.0) module = static.merge(params) - with nnx.flags(deterministic=True, decode=False): - logits = module(inputs) + module.set_attributes(deterministic=True, decode=False) + logits = module(inputs) return compute_metrics(logits, inputs, weights, label_smoothing) @@ -261,8 +261,8 @@ def tokens_ids_to_logits(flat_ids, cache: nnx.State): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] module = static.merge(params, cache) - with nnx.flags(deterministic=True, decode=True): - logits = module(flat_ids) + module.set_attributes(deterministic=True, decode=True) + logits = module(flat_ids) cache = module.extract(nnx.Cache) # Remove singleton sequence-length dimension: # [batch, 1, vocab] --> [batch, vocab] @@ -538,7 +538,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): predict_step, in_axes=( 0, - 
jax.tree_map(lambda x: None, state.params), + jax.tree_util.tree_map(lambda x: None, state.params), 0, None, None, diff --git a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py b/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py index e3c9d3f658..8ed581df02 100644 --- a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py +++ b/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py @@ -80,8 +80,8 @@ def scan_fn( model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0)) x = jnp.ones((3, 10)) -with nnx.flags(deterministic=False): - y = model(x, rngs=nnx.Rngs(dropout=1)) +model.set_attributes(deterministic=False) +y = model(x, rngs=nnx.Rngs(dropout=1)) print(jax.tree_map(jnp.shape, model.get_state())) print(y.shape) diff --git a/flax/experimental/nnx/nnx/flaglib.py b/flax/experimental/nnx/nnx/flaglib.py deleted file mode 100644 index 065fc2b613..0000000000 --- a/flax/experimental/nnx/nnx/flaglib.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import dataclasses -import threading -import typing as tp -from contextlib import contextmanager -from types import MappingProxyType - -A = tp.TypeVar('A') - - -@dataclasses.dataclass -class FlagsContext(threading.local): - flags_stack: tp.List[MappingProxyType[str, tp.Any]] = dataclasses.field( - default_factory=lambda: [MappingProxyType({})] - ) - - -FLAGS_CONTEXT = FlagsContext() - - -class Flags(tp.Mapping[str, tp.Any]): - __slots__ = () - - def __getitem__(self, name: str) -> tp.Any: - current_flags = FLAGS_CONTEXT.flags_stack[-1] - if name not in current_flags: - raise ValueError(f'Unknown Flag: {name}') - return current_flags[name] - - __getattr__ = __getitem__ - - def __iter__(self) -> tp.Iterator[str]: - return iter(FLAGS_CONTEXT.flags_stack[-1]) - - def __len__(self) -> int: - return len(FLAGS_CONTEXT.flags_stack[-1]) - - def __contains__(self, name: tp.Any) -> bool: - return name in FLAGS_CONTEXT.flags_stack[-1] - - @contextmanager - def __call__(self, **kwargs: tp.Any): - current_flags = FLAGS_CONTEXT.flags_stack[-1] - FLAGS_CONTEXT.flags_stack.append( - MappingProxyType(dict(current_flags, **kwargs)) - ) - try: - yield - finally: - FLAGS_CONTEXT.flags_stack.pop() - - @tp.overload - def get(self, name: str, /) -> tp.Any: - ... - - @tp.overload - def get(self, name: str, /, default: A) -> A: - ... 
- - def get(self, name: str, /, default: A = None) -> A | None: - return FLAGS_CONTEXT.flags_stack[-1].get(name, default) - - -flags = Flags() diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index a4d69375f1..abb4640c9d 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -452,6 +452,64 @@ def modules(self) -> tp.Iterator[tuple[Path, Module]]: if isinstance(value, Module): yield path, value + def set_attributes( + self, + *filters: filterlib.Filter, + raise_if_not_found: bool = True, + **attributes: tp.Any, + ) -> None: + """Sets the attributes of nested Modules including the current Module. + If the attribute is not found in the Module, it is ignored. + + Example:: + + >>> from flax.experimental import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.5, deterministic=False) + ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs) + ... + >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (False, False) + >>> block.set_attributes(deterministic=True, use_running_average=True) + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (True, True) + + ``Filter``s can be used to set the attributes of specific Modules:: + + >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> block.set_attributes(nnx.Dropout, deterministic=True, use_running_average=True) + >>> # Only the dropout will be modified + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (True, False) + + Args: + *filters: Filters to select the Modules to set the attributes of. + raise_if_not_found: If True (default), raises a ValueError if at least one attribute + instance is not found in one of the selected Modules. + **attributes: The attributes to set. 
+ """ + remaining_attributes = set(attributes.keys()) + if not filters: + filters = (True,) + predicates = tuple(map(filterlib.to_predicate, filters)) + for path, module in self.modules(): + for predicate in predicates: + if predicate(path, module): + for name, value in attributes.items(): + if hasattr(module, name): + remaining_attributes.remove(name) + setattr(module, name, value) + break + + if remaining_attributes and raise_if_not_found: + raise ValueError( + f'Could not find at least one instance of the following attributes: {remaining_attributes}' + ) + def __init_subclass__(cls, experimental_pytree: bool = False) -> None: super().__init_subclass__() diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/experimental/nnx/nnx/nn/attention.py index 86f77044d7..a48c7df90a 100644 --- a/flax/experimental/nnx/nnx/nn/attention.py +++ b/flax/experimental/nnx/nnx/nn/attention.py @@ -25,7 +25,6 @@ from flax.experimental import nnx from flax.experimental.nnx.nnx import rnglib -from flax.experimental.nnx.nnx import flaglib from flax.experimental.nnx.nnx.module import Module, first_from from flax.experimental.nnx.nnx.nn import initializers from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype @@ -510,7 +509,6 @@ def __call__( decode = first_from( decode, self.decode, - flaglib.flags.get('decode'), error_msg="""No `decode` argument was provided to MultiHeadAttention as either a __call__ argument, class attribute, or nnx.flag.""", ) @@ -557,7 +555,6 @@ def __call__( deterministic = first_from( deterministic, self.deterministic, - flaglib.flags.get('deterministic'), error_msg="""No `deterministic` argument was provided to MultiHeadAttention as either a __call__ argument, class attribute, or nnx.flag.""", ) diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/experimental/nnx/nnx/nn/normalization.py index 1d0faad048..d93ebfd4cc 100644 --- a/flax/experimental/nnx/nnx/nn/normalization.py +++ b/flax/experimental/nnx/nnx/nn/normalization.py @@ -19,7 +19,7 @@ from jax import lax from flax.experimental import nnx -from flax.experimental.nnx.nnx import flaglib, rnglib +from flax.experimental.nnx.nnx import rnglib from flax.experimental.nnx.nnx.module import Module, first_from from flax.experimental.nnx.nnx.nn import dtypes, initializers from flax.typing import ( @@ -283,7 +283,6 @@ def __call__( use_running_average = first_from( use_running_average, self.use_running_average, - flaglib.flags.get('use_running_average'), error_msg="""No `use_running_average` argument was provided to BatchNorm as either a __call__ argument, class attribute, or nnx.flag.""", ) diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py index 64cf057f54..5df48aab37 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -17,7 +17,7 @@ import jax.numpy as jnp from jax import lax, random -from flax.experimental.nnx.nnx import flaglib, rnglib +from flax.experimental.nnx.nnx import rnglib from flax.experimental.nnx.nnx.module import Module, first_from import dataclasses @@ -61,7 +61,6 @@ def __call__( deterministic = first_from( deterministic, self.deterministic, - flaglib.flags.get('deterministic'), error_msg="""No `deterministic` argument was provided to Dropout as either a __call__ argument, class attribute, or nnx.flag.""", ) diff --git a/flax/experimental/nnx/tests/nn/test_attention.py b/flax/experimental/nnx/tests/nn/test_attention.py index 776abf46f6..f312db7321 100644 --- 
a/flax/experimental/nnx/tests/nn/test_attention.py +++ b/flax/experimental/nnx/tests/nn/test_attention.py @@ -65,9 +65,9 @@ def __call__(self, x, sow_weights=False): ), rng, ) + module.set_attributes(decode=False) - with nnx.flags(decode=False): - _ = module(x, True) + _ = module(x, True) intermediates = module.pop(nnx.Intermediate) assert intermediates['attention_layers/0/attention_weights'].raw_value[ 0 @@ -77,8 +77,7 @@ def __call__(self, x, sow_weights=False): 0 ].shape == (4, 8, 6, 6) - with nnx.flags(decode=False): - _ = module(x) + _ = module(x) intermediates = module.pop(nnx.Intermediate) assert not intermediates # empty diff --git a/flax/experimental/nnx/tests/test_integration.py b/flax/experimental/nnx/tests/test_integration.py index 8a534608c8..9bdb713c62 100644 --- a/flax/experimental/nnx/tests/test_integration.py +++ b/flax/experimental/nnx/tests/test_integration.py @@ -50,19 +50,21 @@ def __call__(self, x): def train_step(model: Model, x, y): @nnx.grad def loss_fn(model: Model): - with nnx.flags(use_running_average=False): - y_pred = model(x) + y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) model.update( - jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) + jax.tree_util.tree_map( + lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads + ) ) model = Model(rngs=nnx.Rngs(0)) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) + model.set_attributes(use_running_average=False) for _i in range(3): train_step(model, x, y) @@ -96,16 +98,18 @@ def __call__(self, x): @jax.jit def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): model = graphdef.merge(state) + model.set_attributes(use_running_average=False) @nnx.grad def loss_fn(model: Model): - with nnx.flags(use_running_average=False): - y_pred = model(x) + y_pred = model(x) return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) model.update( - jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) + jax.tree_util.tree_map( + lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads + ) ) return model.split() @@ -158,7 +162,9 @@ def loss_fn(model): grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) # SGD update model.update( - jax.tree_map(lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads) + jax.tree_util.tree_map( + lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads + ) ) # execute the training step diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index 8cd1e6fd18..47319a2fba 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -28,8 +28,7 @@ class TestModule: def test_has_module_state(self): - class Foo(nnx.Module): - ... + class Foo(nnx.Module): ... 
foo = Foo() @@ -474,6 +473,31 @@ def __init__(self) -> None: assert m1.c is not m2.c assert m1.self is m1 + def test_set_attributes(self): + class Block(nnx.Module): + def __init__(self, din, dout, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False) + self.batch_norm = nnx.BatchNorm( + 10, use_running_average=False, rngs=rngs + ) + + block = Block(2, 5, rngs=nnx.Rngs(0)) + assert block.dropout.deterministic == False + assert block.batch_norm.use_running_average == False + + block.set_attributes(deterministic=True, use_running_average=True) + assert block.dropout.deterministic == True + assert block.batch_norm.use_running_average == True + + block = Block(2, 5, rngs=nnx.Rngs(0)) + block.set_attributes( + nnx.Dropout, deterministic=True, use_running_average=True + ) + # Only the dropout will be modified + assert block.dropout.deterministic == True + assert block.batch_norm.use_running_average == False + class TestModulePytree: def test_tree_map(self): diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py index f441adf16f..bbe90b42c9 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -391,14 +391,14 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: ) module = MLP(rngs=nnx.Rngs(0)) + module.set_attributes(deterministic=False, use_running_average=False) assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) assert module.scan_module.linear.bias.value.shape == (5, 3) assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) - with nnx.flags(deterministic=False, use_running_average=False): - y = module(x, rngs=nnx.Rngs(1)) + y = module(x, rngs=nnx.Rngs(1)) assert y.shape == (1, 3) @@ -427,14 +427,14 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: ) module = MLP(rngs=nnx.Rngs(0)) + module.set_attributes(deterministic=False, use_running_average=False) assert module.scan_module.linear.kernel.value.shape == (5, 3, 3) assert module.scan_module.linear.bias.value.shape == (5, 3) assert module.scan_module.node.value.shape == (2,) x = jnp.ones((1, 3)) - with nnx.flags(deterministic=False, use_running_average=False): - y = module(x, rngs=nnx.Rngs(1)) + y = module(x, rngs=nnx.Rngs(1)) assert y.shape == (1, 3) @@ -465,6 +465,7 @@ def __call__( return x, None module = Block(rngs=nnx.Rngs(0)) + module.set_attributes(deterministic=False, use_running_average=False) assert module.d == 3 assert module.linear.kernel.value.shape == (5, 3, 3) @@ -472,8 +473,7 @@ def __call__( assert module.node.value.shape == (2,) x = jnp.ones((1, 3)) - with nnx.flags(deterministic=False, use_running_average=False): - y, out = module(x, None, rngs=nnx.Rngs(dropout=1)) + y, out = module(x, None, rngs=nnx.Rngs(dropout=1)) assert y.shape == (1, 3) assert out is None
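
The change above removes `flaglib`: per-call behaviour flags such as `deterministic`, `decode`, and `use_running_average` are no longer pushed onto a thread-local stack by the `nnx.flags` context manager, but stored as plain attributes on each Module and toggled with the new `Module.set_attributes`. A minimal before/after sketch of a call site, assuming the constructor and `__call__` signatures used elsewhere in this patch (the `Block` class and its `rngs` wiring here are illustrative, not part of the commit):

    from flax.experimental import nnx
    import jax.numpy as jnp

    class Block(nnx.Module):
      def __init__(self, din, dout, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(din, dout, rngs=rngs)
        self.dropout = nnx.Dropout(0.5, deterministic=False)

      def __call__(self, x, *, rngs: nnx.Rngs):
        # Dropout reads `self.deterministic` unless overridden per call.
        return self.dropout(self.linear(x), rngs=rngs)

    block = Block(2, 5, rngs=nnx.Rngs(0))
    x = jnp.ones((1, 2))

    # Before this patch: flags lived on a thread-local stack.
    #   with nnx.flags(deterministic=True):
    #     y = block(x, rngs=nnx.Rngs(dropout=1))

    # After this patch: the flag is set as an attribute on every nested
    # Module that defines it (here only `dropout` has `deterministic`).
    block.set_attributes(deterministic=True)
    y = block(x, rngs=nnx.Rngs(dropout=1))

As the `set_attributes` docstring and `test_set_attributes` show, Filters such as `nnx.Dropout` can be passed to restrict which sub-Modules are updated, and `raise_if_not_found=False` suppresses the error when none of the selected Modules define a given attribute.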