Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forked a subset of JAX configuration APIs #3448

Merged
merged 1 commit into from
Nov 3, 2023

Conversation

superbobry
Copy link
Member

These APIs are internal to JAX and should not be used by other projects for managing their configuration.

@codecov-commenter
Copy link

codecov-commenter commented Oct 30, 2023

Codecov Report

Merging #3448 (44c2a88) into main (0b126b8) will increase coverage by 0.05%.
The diff coverage is 90.38%.

@@            Coverage Diff             @@
##             main    #3448      +/-   ##
==========================================
+ Coverage   83.50%   83.56%   +0.05%     
==========================================
  Files          56       56              
  Lines        6725     6766      +41     
==========================================
+ Hits         5616     5654      +38     
- Misses       1109     1112       +3     
Files Coverage Δ
flax/configurations.py 91.42% <90.38%> (+1.77%) ⬆️

@superbobry superbobry force-pushed the fork-configuration branch 2 times, most recently from a8031e5 to 927cf87 Compare October 30, 2023 10:42
@IvyZX
Copy link
Collaborator

IvyZX commented Oct 31, 2023

Thanks for working on it!

  1. We only have bool flags now, but with this setup, does adding non-bool flags just requires forking a few more lines?

  2. Since we no longer rely on an existing framework, can we have a simple unit test for this? To make sure the runtime update works, and that the flag value will not reset after re-importing the configuration.py. Thank you!

@superbobry
Copy link
Member Author

Thanks for having a look, Ivy!

  1. Yes, adding support for new flag types should be easy. E.g. here is how int_flag could look

    def int_flag(name: str, default: bool, help: str) -> FlagHolder[int]:
      name = name.lower()
      if default_env := os.getenv(name.upper()):
        try:
          default = int(default_env)
        except ValueError:
          raise ValueError(f'Invalid value "{default_env}" for flag {name}')
      config._add_option(name, default)
      fh = FlagHolder[int](name, help)
      setattr(Config, name, property(lambda _: fh.value, doc=help))
      return fh
  2. Will add a test, sure!

@superbobry superbobry marked this pull request as ready for review November 2, 2023 08:07
These APIs are internal to JAX and should not be used by other projects for
managing their configuration.
@copybara-service copybara-service bot merged commit c1023d9 into google:main Nov 3, 2023
19 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants