
Commit

updated nnx.training
chiamp committed Jun 11, 2024
1 parent 5b7d671 commit 31adb00
Showing 4 changed files with 234 additions and 26 deletions.
15 changes: 11 additions & 4 deletions docs/api_reference/flax.nnx/training/metrics.rst
@@ -4,12 +4,19 @@ Metrics
.. automodule:: flax.nnx.metrics
.. currentmodule:: flax.nnx.metrics


.. autoclass:: Metric
:members:
:members: __init__, reset, update, compute

.. autoclass:: Average
:members:
:members: __init__, reset, update, compute

.. autoclass:: Accuracy
:members:
:members: update

.. autoclass:: Welford
:members: __init__, reset, update, compute

.. autoclass:: MultiMetric
:members:
:members: __init__, reset, update, compute

2 changes: 1 addition & 1 deletion docs/api_reference/flax.nnx/training/optimizer.rst
@@ -5,4 +5,4 @@ Optimizer
.. currentmodule:: flax.nnx.optimizer

.. autoclass:: Optimizer
:members:
:members: __init__, update
229 changes: 210 additions & 19 deletions flax/nnx/nnx/training/metrics.py
@@ -45,35 +45,81 @@ class MetricState(Variable):


class Metric(Object):
"""Base class for metrics. Any class that subclasses ``Metric`` should
implement a ``compute``, ``reset`` and ``update`` method."""

def __init__(self):
raise NotImplementedError('Must override `__init__()` method.')

def reset(self):
def reset(self) -> None:
"""In-place reset the ``Metric``."""
raise NotImplementedError('Must override `reset()` method.')

def update(self, **kwargs) -> None:
"""In-place update the ``Metric``."""
raise NotImplementedError('Must override `update()` method.')

def compute(self):
"""Compute and return the value of the ``Metric``."""
raise NotImplementedError('Must override `compute()` method.')

def split(self, *filters: filterlib.Filter):
return graph.split(self, *filters)
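
As an aside from this diff, the contract above can be illustrated with a minimal subclass. The ``Sum`` metric below is our own sketch, written as if it lived in this module so that ``Metric`` and ``MetricState`` are in scope; it is not part of the commit:

class Sum(Metric):
  """Hypothetical example metric: running sum of every value passed to ``update``."""

  def __init__(self, argname: str = 'values'):
    self.argname = argname
    self.total = MetricState(jnp.array(0, dtype=jnp.float32))

  def reset(self) -> None:
    # Restore the initial state in place, as required by the base class.
    self.total.value = jnp.array(0, dtype=jnp.float32)

  def update(self, **kwargs) -> None:
    # Pull the new batch out of the keyword arguments and fold it into the state.
    if self.argname not in kwargs:
      raise TypeError(f"Expected keyword argument '{self.argname}'")
    self.total.value += jnp.sum(kwargs[self.argname])

  def compute(self) -> jax.Array:
    return self.total.value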


class Average(Metric):
"""Average metric.
Example usage::
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = nnx.metrics.Average()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Array(2.5, dtype=float32)
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Array(2., dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
"""

def __init__(self, argname: str = 'values'):
"""Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
For example, constructing the metric as ``avg = Average('test')`` would allow you to make updates with
``avg.update(test=new_value)``.
Args:
argname: an optional string denoting the key-word argument that
:func:`update` will use to derive the new value. Defaults to
``'values'``.
"""
self.argname = argname
self.total = MetricState(jnp.array(0, dtype=jnp.float32))
self.count = MetricState(jnp.array(0, dtype=jnp.int32))

def reset(self):
def reset(self) -> None:
"""Reset this ``Metric``."""
self.total.value = jnp.array(0, dtype=jnp.float32)
self.count.value = jnp.array(0, dtype=jnp.int32)

def update(self, **kwargs):
def update(self, **kwargs) -> None:
"""In-place update this ``Metric``. This method will use the value from
``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
defined on construction.
Args:
**kwargs: the keyword arguments that contain a ``self.argname``
entry that maps to the value we want to use to update this metric.
"""
if self.argname not in kwargs:
raise TypeError(f"Expected keyword argument '{self.argname}'")
values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
@@ -82,7 +128,8 @@ def update(self, **kwargs):
)
self.count.value += 1 if isinstance(values, (int, float)) else values.size

def compute(self):
def compute(self) -> jax.Array:
"""Compute and return the average."""
return self.total.value / self.count.value


@@ -94,20 +141,60 @@ class Statistics:


class Welford(Metric):
"""Uses Welford's algorithm to compute the mean and variance of a stream of data."""
"""Uses Welford's algorithm to compute the mean and variance of a stream of data.
Example usage::
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
>>> metrics = nnx.metrics.Welford()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
>>> metrics.update(values=batch_loss)
>>> metrics.compute()
Statistics(mean=Array(2.5, dtype=float32), standard_error_of_mean=Array(0.559017, dtype=float32), standard_deviation=Array(1.118034, dtype=float32))
>>> metrics.update(values=batch_loss2)
>>> metrics.compute()
Statistics(mean=Array(2., dtype=float32), standard_error_of_mean=Array(0.43301272, dtype=float32), standard_deviation=Array(1.2247449, dtype=float32))
>>> metrics.reset()
>>> metrics.compute()
Statistics(mean=Array(0., dtype=float32), standard_error_of_mean=Array(nan, dtype=float32), standard_deviation=Array(nan, dtype=float32))
"""

def __init__(self, argname: str = 'values'):
"""Pass in a string denoting the key-word argument that :func:`update` will use to derive the new value.
For example, constructing the metric as ``wf = Welford('test')`` would allow you to make updates with
``wf.update(test=new_value)``.
Args:
argname: an optional string denoting the key-word argument that
:func:`update` will use to derive the new value. Defaults to
``'values'``.
"""
self.argname = argname
self.count = MetricState(jnp.array(0, dtype=jnp.int32))
self.mean = MetricState(jnp.array(0, dtype=jnp.float32))
self.m2 = MetricState(jnp.array(0, dtype=jnp.float32))

def reset(self):
def reset(self) -> None:
"""Reset this ``Metric``."""
self.count.value = jnp.array(0, dtype=jnp.int32)
self.mean.value = jnp.array(0, dtype=jnp.float32)
self.m2.value = jnp.array(0, dtype=jnp.float32)

def update(self, **kwargs):
def update(self, **kwargs) -> None:
"""In-place update this ``Metric``. This method will use the value from
``kwargs[self.argname]`` to update the metric, where ``self.argname`` is
defined on construction.
Args:
**kwargs: the keyword arguments that contain a ``self.argname``
entry that maps to the value we want to use to update this metric.
"""
if self.argname not in kwargs:
raise TypeError(f"Expected keyword argument '{self.argname}'")
values: tp.Union[int, float, jax.Array] = kwargs[self.argname]
@@ -123,7 +210,10 @@ def update(self, **kwargs):
m2 + delta * delta * count * original_count / self.count
)

def compute(self):
def compute(self) -> Statistics:
"""Compute and return the mean and variance statistics in a
``Statistics`` dataclass object.
"""
variance = self.m2 / self.count
standard_deviation = variance**0.5
sem = standard_deviation / (self.count**0.5)
@@ -135,8 +225,44 @@ def compute(self):


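As a reference point (not part of the commit), the same Welford recurrences can be written as plain Python, one sample at a time; the helper name and print-out below are ours. For the two batches used in the docstring above it reproduces mean 2.0, standard deviation ~1.2247 and standard error ~0.4330:

import math

def welford_stream(values):
  """Single-pass mean / variance via Welford's algorithm (scalar version)."""
  count, mean, m2 = 0, 0.0, 0.0
  for x in values:
    count += 1
    delta = x - mean           # deviation from the old mean
    mean += delta / count      # fold the sample into the running mean
    m2 += delta * (x - mean)   # accumulate the sum of squared deviations
  variance = m2 / count        # population variance, matching m2 / count above
  std = math.sqrt(variance)
  sem = std / math.sqrt(count)
  return mean, std, sem

print(welford_stream([1, 2, 3, 4] + [3, 2, 1, 0]))  # (2.0, 1.2247..., 0.4330...)
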
class Accuracy(Average):
"""Accuracy metric. This metric subclasses :class:`Average`,
and so shares the same ``reset`` and ``compute`` method
implementations. Unlike :class:`Average`, no string needs to
be passed to ``Accuracy`` during construction.
Example usage::
>>> from flax import nnx
>>> import jax, jax.numpy as jnp
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
>>> metrics = nnx.metrics.Accuracy()
>>> metrics.compute()
Array(nan, dtype=float32)
>>> metrics.update(logits=logits, labels=labels)
>>> metrics.compute()
Array(0.6, dtype=float32)
>>> metrics.update(logits=logits2, labels=labels2)
>>> metrics.compute()
Array(0.7, dtype=float32)
>>> metrics.reset()
>>> metrics.compute()
Array(nan, dtype=float32)
"""

def update(self, *, logits: jax.Array, labels: jax.Array, **_) -> None: # type: ignore[override]
"""In-place update this ``Metric``.
Args:
logits: the predicted activations. These values are
argmax-ed along the trailing dimension before being compared
to the labels.
labels: the ground truth integer labels.
"""
if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32:
raise ValueError(
f'Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}=='
@@ -150,20 +276,65 @@ class MultiMetric(Metric):
Example usage::
>>> import jax, jax.numpy as jnp
>>> from flax import nnx
...
>>> metrics = nnx.MultiMetric(
... accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
... )
>>> metrics
MultiMetric(
accuracy=Accuracy(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
)
),
loss=Average(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
)
)
)
>>> metrics.accuracy
Accuracy(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
)
)
>>> metrics.loss
Average(
argname='values',
total=MetricState(
value=Array(shape=(), dtype=float32)
),
count=MetricState(
value=Array(shape=(), dtype=int32)
)
)
>>> logits = jax.random.normal(jax.random.key(0), (5, 2))
>>> labels = jnp.array([1, 1, 0, 1, 0])
>>> logits2 = jax.random.normal(jax.random.key(1), (5, 2))
>>> labels2 = jnp.array([0, 1, 1, 1, 1])
...
>>> batch_loss = jnp.array([1, 2, 3, 4])
>>> batch_loss2 = jnp.array([3, 2, 1, 0])
...
>>> metrics.compute()
{'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)}
>>> metrics.update(logits=logits, labels=labels, values=batch_loss)
@@ -178,23 +349,43 @@ class MultiMetric(Metric):
"""

def __init__(self, **metrics):
"""Pass in key-word arguments to the constructor, e.g.
``MultiMetric(keyword1=Average(), keyword2=Accuracy(), ...)``.
Args:
**metrics: the key-word arguments that will be used to access
the corresponding ``Metric``.
"""
# TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods
self._metric_names = []
for metric_name, metric in metrics.items():
self._metric_names.append(metric_name)
vars(self)[metric_name] = metric

def reset(self):
def reset(self) -> None:
"""Reset all underlying ``Metric``'s."""
for metric_name in self._metric_names:
getattr(self, metric_name).reset()

def update(self, **updates):
def update(self, **updates) -> None:
"""In-place update all underlying ``Metric``'s in this ``MultiMetric``. All
``**updates`` will be passed to the ``update`` method of all underlying
``Metric``'s.
Args:
**updates: the key-word arguments that will be passed to the underlying ``Metric``'s
``update`` method.
"""
# TODO: should we give the option of updating only some of the metrics and not all? e.g. if for some kwargs==None, don't do update
# TODO: should we raise an error if a kwarg is passed into **updates that has no match with any underlying metric? e.g. user typo
for metric_name in self._metric_names:
getattr(self, metric_name).update(**updates)

def compute(self):
def compute(self) -> dict[str, tp.Any]:
"""Compute and return the value of all underlying ``Metric`` objects. This method
returns a dictionary mapping strings (defined by the keyword arguments
``**metrics`` passed to the constructor) to the corresponding metric value.
"""
return {
f'{metric_name}': getattr(self, metric_name).compute()
for metric_name in self._metric_names
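As an editorial aside (not part of the diff), the semantics documented above can be checked by hand: ``Accuracy.update`` argmax-es the trailing dimension of ``logits`` and compares the result to the integer ``labels``, while ``MultiMetric.update`` forwards the same keyword arguments to every sub-metric. A small sketch, using throwaway data mirroring the docstring:

import jax, jax.numpy as jnp
from flax import nnx

logits = jax.random.normal(jax.random.key(0), (5, 2))
labels = jnp.array([1, 1, 0, 1, 0])          # integer labels, as required by Accuracy.update
batch_loss = jnp.array([1.0, 2.0, 3.0, 4.0])

metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()
)
# One call feeds every sub-metric: Accuracy consumes logits/labels,
# Average consumes the 'values' keyword and ignores the rest.
metrics.update(logits=logits, labels=labels, values=batch_loss)

# The reported accuracy is just the fraction of argmax predictions that match.
manual = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
assert jnp.allclose(metrics.compute()['accuracy'], manual)
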
14 changes: 12 additions & 2 deletions flax/nnx/nnx/training/optimizer.py
@@ -102,16 +102,26 @@ class Optimizer(Object):
For more exotic use cases (e.g. multiple optimizers) it's probably best to
fork the class and modify it.
Args:
model: An NNX Module.
Attributes:
step: An ``OptState`` :class:`Variable` that tracks the step count.
model: The wrapped :class:`Module`.
tx: An Optax gradient transformation.
opt_state: The Optax optimizer state.
"""

def __init__(
self,
model: nnx.Module,
tx: optax.GradientTransformation,
):
"""
Instantiate the class and wrap the :class:`Module` and Optax gradient
transformation. Set the step count to 0.
Args:
model: An NNX Module.
tx: An Optax gradient transformation.
"""
self.step = OptState(jnp.array(0, dtype=jnp.uint32))
self.model = model
self.tx = tx
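To round out the documented ``__init__`` and ``update`` members, here is a minimal usage sketch (ours, not from the commit; the toy model, loss and learning rate are placeholders):

import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))           # any NNX Module
optimizer = nnx.Optimizer(model, optax.adam(1e-3))   # wraps model + tx, step starts at 0

def loss_fn(model):
  y = model(jnp.ones((4, 2)))
  return jnp.mean(y ** 2)

# Differentiate w.r.t. the module's parameters and apply one in-place update.
grads = nnx.grad(loss_fn)(optimizer.model)
optimizer.update(grads)
assert optimizer.step.value == 1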
