Skip to content

Commit

Permalink
Merge pull request #3448 from superbobry:fork-configuration
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579243331
  • Loading branch information
Flax Authors committed Nov 3, 2023
2 parents 0b126b8 + 44c2a88 commit c1023d9
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 31 deletions.
3 changes: 2 additions & 1 deletion docs/api_reference/flax.config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ flax.config package

.. automodule:: flax.configurations
:members:
:exclude-members: temp_flip_flag
:undoc-members:
:exclude-members: FlagHolder, bool_flag, temp_flip_flag, static_bool_env
135 changes: 105 additions & 30 deletions flax/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,115 @@
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Global configuration options for Flax.
Now a wrapper over jax.config, in which all config vars have a 'flax\_' prefix.
To modify a config value on run time, call:
``flax.config.update('flax_<config_name>', <value>)``
"""
"""Global configuration flags for Flax."""

import os
from contextlib import contextmanager
from typing import Any, Generic, NoReturn, TypeVar, overload

_T = TypeVar('_T')


class Config:
# See https://google.github.io/pytype/faq.html.
_HAS_DYNAMIC_ATTRIBUTES = True

def __init__(self):
self._values = {}

def _add_option(self, name, default):
if name in self._values:
raise RuntimeError(f'Config option {name} already defined')
self._values[name] = default

def _read(self, name):
try:
return self._values[name]
except KeyError:
raise LookupError(f'Unrecognized config option: {name}')

@overload
def update(self, name: str, value: Any, /) -> None:
...

@overload
def update(self, holder: 'FlagHolder[_T]', value: _T, /) -> None:
...

def update(self, name_or_holder, value, /):
"""Modify the value of a given flag.
Args:
name_or_holder: the name of the flag to modify or the corresponding
flag holder object.
value: new value to set.
"""
name = name_or_holder
if isinstance(name_or_holder, FlagHolder):
name = name_or_holder.name
if name not in self._values:
raise LookupError(f'Unrecognized config option: {name}')
self._values[name] = value

from jax import config as jax_config

# Keep a wrapper at the flax namespace, in case we make our implementation
# in the future.
config = jax_config
config = Config()

# Config parsing utils


def define_bool_state(name, default, help):
"""Set up a boolean flag using JAX's config system.
class FlagHolder(Generic[_T]):
def __init__(self, name, help):
self.name = name
self.__name__ = name[4:] if name.startswith('flax_') else name
self.__doc__ = f'Flag holder for `{name}`.\n\n{help}'

The flag will actually be stored as an environment variable of
'FLAX_<UPPERCASE_NAME>'. JAX config ensures that the flag can be overwritten
on runtime with `flax.config.update('flax_<config_name>', <value>)`.
def __bool__(self) -> NoReturn:
raise TypeError(
"bool() not supported for instances of type '{0}' "
"(did you mean to use '{0}.value' instead?)".format(type(self).__name__)
)

@property
def value(self) -> _T:
return config._read(self.name)


def bool_flag(name: str, *, default: bool, help: str) -> FlagHolder[bool]:
"""Set up a boolean flag.
Example::
enable_foo = bool_flag(
name='flax_enable_foo',
default=False,
help='Enable foo.',
)
Now the ``FLAX_ENABLE_FOO`` shell environment variable can be used to
control the process-level value of the flag, in addition to using e.g.
``config.update("flax_enable_foo", True)`` directly.
Args:
name: converted to lowercase to define the name of the flag. It is
converted to uppercase to define the corresponding shell environment
variable.
default: a default value for the flag.
help: used to populate the docstring of the returned flag holder object.
Returns:
A flag holder object for accessing the value of the flag.
"""
return jax_config.define_bool_state('flax_' + name, default, help)
name = name.lower()
config._add_option(name, static_bool_env(name.upper(), default))
fh = FlagHolder[bool](name, help)
setattr(Config, name, property(lambda _: fh.value, doc=help))
return fh


def static_bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
This is deprecated. Please use define_bool_state() unless your flag
This is deprecated. Please use bool_flag() unless your flag
will be used in a static method and does not require runtime updates.
True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
Expand Down Expand Up @@ -90,39 +165,39 @@ def temp_flip_flag(var_name: str, var_value: bool):
# Whether to use the lazy rng implementation.
flax_lazy_rng = static_bool_env('FLAX_LAZY_RNG', True)

flax_filter_frames = define_bool_state(
name='filter_frames',
flax_filter_frames = bool_flag(
name='flax_filter_frames',
default=True,
help='Whether to hide flax-internal stack frames from tracebacks.',
)

flax_profile = define_bool_state(
name='profile',
flax_profile = bool_flag(
name='flax_profile',
default=True,
help='Whether to run Module methods under jax.named_scope for profiles.',
)

flax_use_orbax_checkpointing = define_bool_state(
name='use_orbax_checkpointing',
flax_use_orbax_checkpointing = bool_flag(
name='flax_use_orbax_checkpointing',
default=True,
help='Whether to use Orbax to save checkpoints.',
)

flax_preserve_adopted_names = define_bool_state(
name='preserve_adopted_names',
flax_preserve_adopted_names = bool_flag(
name='flax_preserve_adopted_names',
default=False,
help="When adopting outside modules, don't clobber existing names.",
)

# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = define_bool_state(
name='return_frozendict',
flax_return_frozendict = bool_flag(
name='flax_return_frozendict',
default=False,
help='Whether to return FrozenDicts when calling init or apply.',
)

flax_fix_rng = define_bool_state(
name='fix_rng_separator',
flax_fix_rng = bool_flag(
name='flax_fix_rng_separator',
default=False,
help=(
'Whether to add separator characters when folding in static data into'
Expand Down
52 changes: 52 additions & 0 deletions tests/configurations_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2023 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

from absl.testing import absltest

from flax.configurations import bool_flag, config


class MyTestCase(absltest.TestCase):
def setUp(self):
super().setUp()
self.enter_context(mock.patch.object(config, '_values', {}))
self._flag = bool_flag('test', default=False, help='Just a test flag.')

def test_duplicate_flag(self):
with self.assertRaisesRegex(RuntimeError, 'already defined'):
bool_flag(self._flag.name, default=False, help='Another test flag.')

def test_default(self):
self.assertFalse(self._flag.value)
self.assertFalse(config.test)

def test_typed_update(self):
config.update(self._flag, True)
self.assertTrue(self._flag.value)
self.assertTrue(config.test)

def test_untyped_update(self):
config.update(self._flag.name, True)
self.assertTrue(self._flag.value)
self.assertTrue(config.test)

def test_update_unknown_flag(self):
with self.assertRaisesRegex(LookupError, 'Unrecognized config option'):
config.update('unknown', True)


if __name__ == '__main__':
absltest.main()

0 comments on commit c1023d9

Please sign in to comment.