
[nnx] remove flaglib #3733

Merged
merged 1 commit on Mar 13, 2024
1 change: 0 additions & 1 deletion flax/experimental/nnx/__init__.py
@@ -23,7 +23,6 @@
from .nnx.errors import TraceContextError as TraceContextError
from .nnx.filterlib import All as All
from .nnx.filterlib import Not as Not
- from .nnx.flaglib import flags as flags
from .nnx.graph_utils import GraphDef as GraphDef
from .nnx.helpers import Dict as Dict
from .nnx.helpers import Sequence as Sequence
16 changes: 11 additions & 5 deletions flax/experimental/nnx/examples/lm1b/models.py
@@ -33,7 +33,6 @@
from jax import lax

from flax.experimental import nnx
- from flax.experimental.nnx.nnx import flaglib
from flax.experimental.nnx.examples.lm1b.configs import default

Shape = tuple[int, ...]
@@ -126,9 +125,11 @@ def __init__(
self,
config: TransformerConfig,
*,
+ decode: bool = False,
rngs: nnx.Rngs,
):
self.config = config
+ self.decode = decode
self.pos_emb_shape = (1, config.max_len, config.emb_dim)

if config.posemb_init is not None:
@@ -168,7 +169,7 @@ def __call__(self, inputs: jax.Array, inputs_positions=None):
pos_embedding = self.pos_embedding.value

# We use a cache position index for tracking decoding position.
- if flaglib.flags.get('decode', False):
+ if self.decode:
_, _, df = pos_embedding.shape
# equivalent to pos_embedding[:, i:i+1] but traceable
pos_embedding = lax.dynamic_slice(
@@ -336,9 +337,11 @@ def __init__(
config: TransformerConfig,
shared_embedding: nnx.Embed | None = None,
*,
+ decode: bool = False,
rngs: nnx.Rngs,
):
self.config = config
+ self.decode = decode
self.shared_embedding = shared_embedding

# Target Embedding
@@ -413,7 +416,7 @@ def __call__(
assert inputs.ndim == 2 # (batch, len)

y = inputs.astype('int32')
- if not flaglib.flags.get('decode', False):
+ if not self.decode:
y = shift_inputs(y, segment_ids=inputs_segmentation)
y = self.output_embed(y)
y = self.posembed_output(y, inputs_positions=inputs_positions)
@@ -450,8 +453,11 @@ class TransformerLM(nnx.Module):
config: TransformerConfig dataclass containing hyperparameters.
"""

- def __init__(self, config: TransformerConfig, *, rngs: nnx.Rngs):
+ def __init__(
+ self, config: TransformerConfig, *, decode: bool = False, rngs: nnx.Rngs
+ ):
self.config = config
+ self.decode = decode
self.decoder = Decoder(config=config, shared_embedding=None, rngs=rngs)

def __call__(
@@ -475,7 +481,7 @@ def __call__(
config = self.config

# Make padding attention masks.
- if flaglib.flags.get('decode', False):
+ if self.decode:
# for fast autoregressive decoding we use no decoder mask
decoder_mask = None
else:
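
The change above replaces the global `flaglib.flags.get('decode', False)` lookup with a plain `decode` attribute set in `__init__` and toggled with `set_attributes`. A minimal sketch of the resulting pattern, using a hypothetical `TinyDecoder` module (not part of this PR) purely for illustration:

```python
import jax.numpy as jnp
from flax.experimental import nnx

class TinyDecoder(nnx.Module):  # hypothetical module, for illustration only
  def __init__(self, *, decode: bool = False, rngs: nnx.Rngs):
    self.decode = decode
    self.linear = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    # Branch on the attribute instead of a global flag lookup.
    scale = 0.5 if self.decode else 1.0
    return self.linear(x) * scale

model = TinyDecoder(rngs=nnx.Rngs(0))
y_train = model(jnp.ones((1, 4)))   # decode=False path
model.set_attributes(decode=True)   # replaces `with nnx.flags(decode=True): ...`
y_decode = model(jnp.ones((1, 4)))  # decode=True path
```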
8 changes: 4 additions & 4 deletions flax/experimental/nnx/examples/lm1b/models_test.py
@@ -217,8 +217,8 @@ def test_forward_eval(self):
self.transfer_params(config, params_nnx, params_linen)
model_nnx.update(params_nnx)

- with nnx.flags(deterministic=True, decode=False):
- output_nnx = model_nnx(sample_inputs)
+ model_nnx.set_attributes(deterministic=True, decode=False)
+ output_nnx = model_nnx(sample_inputs)

output_linen: jax.Array = model_linen.apply(
{'params': params_linen}, sample_inputs
@@ -263,13 +263,13 @@ def test_forward_decode(self):
self.transfer_params(config, params_nnx, params_linen)
self.transfer_cache(config, cache_nnx, cache_linen)
model_nnx.update(params_nnx, cache_nnx)
+ model_nnx.set_attributes(deterministic=True, decode=True)

outputs_nnx = []
outputs_linen = []

for inputs in ar_decode_inputs:
- with nnx.flags(deterministic=True, decode=True):
- output_nnx = model_nnx(inputs)
+ output_nnx = model_nnx(inputs)
outputs_nnx.append(output_nnx)

output_linen: jax.Array
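
Note that the removed `nnx.flags(...)` context manager was scoped to its `with` block, whereas `set_attributes` mutates the modules in place and persists across calls, which is why the decode test now calls it once before the loop. A small sketch of that behaviour, assuming the standard `nnx.Dropout` semantics where the deterministic path returns its input unchanged:

```python
import jax.numpy as jnp
from flax.experimental import nnx

dropout = nnx.Dropout(0.5, deterministic=False)
dropout.set_attributes(deterministic=True)   # persists until changed again
y = dropout(jnp.ones((2, 3)))                # deterministic path: input passes through
dropout.set_attributes(deterministic=False)  # switch back explicitly when needed
```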
24 changes: 12 additions & 12 deletions flax/experimental/nnx/examples/lm1b/train.py
@@ -194,13 +194,13 @@ def train_step(
def loss_fn(params):
"""loss function used for training."""
module = state.graphdef.merge(params)
- with nnx.flags(deterministic=False, decode=False):
- logits = module(
- inputs,
- inputs_positions=inputs_positions,
- inputs_segmentation=inputs_segmentation,
- rngs=nnx.Rngs(dropout=dropout_rng),
- )
+ module.set_attributes(deterministic=False, decode=False)
+ logits = module(
+ inputs,
+ inputs_positions=inputs_positions,
+ inputs_segmentation=inputs_segmentation,
+ rngs=nnx.Rngs(dropout=dropout_rng),
+ )

loss, weight_sum = compute_weighted_cross_entropy(
logits, inputs, weights, label_smoothing
@@ -229,8 +229,8 @@ def eval_step(
inputs = batch['inputs']
weights = jnp.where(inputs > 0, 1.0, 0.0)
module = static.merge(params)
- with nnx.flags(deterministic=True, decode=False):
- logits = module(inputs)
+ module.set_attributes(deterministic=True, decode=False)
+ logits = module(inputs)

return compute_metrics(logits, inputs, weights, label_smoothing)

@@ -261,8 +261,8 @@ def tokens_ids_to_logits(flat_ids, cache: nnx.State):
"""Token slice to logits from decoder model."""
# --> [batch * beam, 1, vocab]
module = static.merge(params, cache)
- with nnx.flags(deterministic=True, decode=True):
- logits = module(flat_ids)
+ module.set_attributes(deterministic=True, decode=True)
+ logits = module(flat_ids)
cache = module.extract(nnx.Cache)
# Remove singleton sequence-length dimension:
# [batch, 1, vocab] --> [batch, vocab]
@@ -538,7 +538,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
predict_step,
in_axes=(
0,
- jax.tree_map(lambda x: None, state.params),
+ jax.tree_util.tree_map(lambda x: None, state.params),
0,
None,
None,
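
Besides the flag removal, the last hunk also swaps the deprecated `jax.tree_map` alias for `jax.tree_util.tree_map`. Both map a function over the leaves of a pytree, as in this small sketch with a made-up params template:

```python
import jax

params_template = {'w': 1.0, 'b': 2.0}
# Build an in_axes-style pytree with None at every leaf, mirroring the structure of the params.
axes = jax.tree_util.tree_map(lambda x: None, params_template)
print(axes)  # {'b': None, 'w': None}
```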
@@ -80,8 +80,8 @@ def scan_fn(
model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0))

x = jnp.ones((3, 10))
- with nnx.flags(deterministic=False):
- y = model(x, rngs=nnx.Rngs(dropout=1))
+ model.set_attributes(deterministic=False)
+ y = model(x, rngs=nnx.Rngs(dropout=1))

print(jax.tree_map(jnp.shape, model.get_state()))
print(y.shape)
79 changes: 0 additions & 79 deletions flax/experimental/nnx/nnx/flaglib.py

This file was deleted.

59 changes: 59 additions & 0 deletions flax/experimental/nnx/nnx/module.py
@@ -452,6 +452,65 @@ def modules(self) -> tp.Iterator[tuple[Path, Module]]:
if isinstance(value, Module):
yield path, value

def set_attributes(
self,
*filters: filterlib.Filter,
raise_if_not_found: bool = True,
**attributes: tp.Any,
) -> None:
"""Sets the attributes of nested Modules including the current Module.
If the attribute is not found in the Module, it is ignored.
Review comment (Collaborator): would you consider a flag that would raise an error if the attribute isn't found?
Reply (Collaborator Author): done

Example::

>>> from flax.experimental import nnx
...
>>> class Block(nnx.Module):
... def __init__(self, din, dout, *, rngs: nnx.Rngs):
... self.linear = nnx.Linear(din, dout, rngs=rngs)
... self.dropout = nnx.Dropout(0.5, deterministic=False)
... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
...
>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(False, False)
>>> block.set_attributes(deterministic=True, use_running_average=True)
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, True)

``Filter``s can be used to set the attributes of specific Modules::

>>> block = Block(2, 5, rngs=nnx.Rngs(0))
>>> block.set_attributes(nnx.Dropout, deterministic=True, use_running_average=True)
>>> # Only the dropout will be modified
>>> block.dropout.deterministic, block.batch_norm.use_running_average
(True, False)

Args:
*filters: Filters to select the Modules to set the attributes of.
raise_if_not_found: If True (default), raises a ValueError if at least one attribute
instance is not found in one of the selected Modules.
**attributes: The attributes to set.
"""
remaining_attributes = set(attributes.keys())
if not filters:
filters = (True,)
predicates = tuple(map(filterlib.to_predicate, filters))
for path, module in self.modules():
for predicate in predicates:
if predicate(path, module):
for name, value in attributes.items():
if hasattr(module, name):
if name in remaining_attributes:
remaining_attributes.remove(name)
setattr(module, name, value)
break

if remaining_attributes and raise_if_not_found:
raise ValueError(
f'Could not find at least one instance of the following attributes: {remaining_attributes}'
)

def __init_subclass__(cls, experimental_pytree: bool = False) -> None:
super().__init_subclass__()

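
A short sketch of the `raise_if_not_found` flag added in response to the review comment above, using `nnx.Linear` (which defines neither attribute) purely for illustration:

```python
from flax.experimental import nnx

linear = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

# Default behaviour: an attribute that matches no (sub)module raises.
try:
  linear.set_attributes(use_running_average=True)
except ValueError as e:
  print(e)  # Could not find at least one instance of ... {'use_running_average'}

# Opting out turns unmatched attributes into a no-op.
linear.set_attributes(use_running_average=True, raise_if_not_found=False)
```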
3 changes: 0 additions & 3 deletions flax/experimental/nnx/nnx/nn/attention.py
@@ -25,7 +25,6 @@

from flax.experimental import nnx
from flax.experimental.nnx.nnx import rnglib
- from flax.experimental.nnx.nnx import flaglib
from flax.experimental.nnx.nnx.module import Module, first_from
from flax.experimental.nnx.nnx.nn import initializers
from flax.experimental.nnx.nnx.nn.dtypes import promote_dtype
@@ -510,7 +509,6 @@ def __call__(
decode = first_from(
decode,
self.decode,
- flaglib.flags.get('decode'),
error_msg="""No `decode` argument was provided to MultiHeadAttention
as either a __call__ argument, class attribute, or nnx.flag.""",
)
@@ -557,7 +555,6 @@ def __call__(
deterministic = first_from(
deterministic,
self.deterministic,
- flaglib.flags.get('deterministic'),
error_msg="""No `deterministic` argument was provided to MultiHeadAttention
as either a __call__ argument, class attribute, or nnx.flag.""",
)
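
With the flag fallback removed, `first_from` resolves `decode` and `deterministic` from just two sources: the `__call__` argument, then the class attribute, returning the first value that is not `None` and raising a ValueError with `error_msg` if all are `None`. A minimal sketch, assuming that behaviour and the import path shown in the diff above:

```python
from flax.experimental.nnx.nnx.module import first_from

decode_arg = None    # nothing was passed to __call__
decode_attr = False  # self.decode, set in __init__ or via set_attributes

decode = first_from(
  decode_arg,
  decode_attr,
  error_msg='No `decode` argument was provided to MultiHeadAttention '
            'as either a __call__ argument or class attribute.',
)
assert decode is False  # the attribute wins because the argument is None
```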
3 changes: 1 addition & 2 deletions flax/experimental/nnx/nnx/nn/normalization.py
@@ -19,7 +19,7 @@
from jax import lax

from flax.experimental import nnx
- from flax.experimental.nnx.nnx import flaglib, rnglib
+ from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx.module import Module, first_from
from flax.experimental.nnx.nnx.nn import dtypes, initializers
from flax.typing import (
@@ -283,7 +283,6 @@ def __call__(
use_running_average = first_from(
use_running_average,
self.use_running_average,
- flaglib.flags.get('use_running_average'),
error_msg="""No `use_running_average` argument was provided to BatchNorm
as either a __call__ argument, class attribute, or nnx.flag.""",
)
3 changes: 1 addition & 2 deletions flax/experimental/nnx/nnx/nn/stochastic.py
@@ -17,7 +17,7 @@
import jax.numpy as jnp
from jax import lax, random

- from flax.experimental.nnx.nnx import flaglib, rnglib
+ from flax.experimental.nnx.nnx import rnglib
from flax.experimental.nnx.nnx.module import Module, first_from
import dataclasses

@@ -61,7 +61,6 @@ def __call__(
deterministic = first_from(
deterministic,
self.deterministic,
- flaglib.flags.get('deterministic'),
error_msg="""No `deterministic` argument was provided to Dropout
as either a __call__ argument, class attribute, or nnx.flag.""",
)
7 changes: 3 additions & 4 deletions flax/experimental/nnx/tests/nn/test_attention.py
@@ -65,9 +65,9 @@ def __call__(self, x, sow_weights=False):
),
rng,
)
+ module.set_attributes(decode=False)

- with nnx.flags(decode=False):
- _ = module(x, True)
+ _ = module(x, True)
intermediates = module.pop(nnx.Intermediate)
assert intermediates['attention_layers/0/attention_weights'].raw_value[
0
@@ -77,8 +77,7 @@ def __call__(self, x, sow_weights=False):
0
].shape == (4, 8, 6, 6)

- with nnx.flags(decode=False):
- _ = module(x)
+ _ = module(x)
intermediates = module.pop(nnx.Intermediate)
assert not intermediates # empty
