Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added flax.typing #3624

Merged
merged 1 commit into from
Jan 30, 2024
Merged

added flax.typing #3624

merged 1 commit into from
Jan 30, 2024

Conversation

chiamp
Copy link
Collaborator

@chiamp chiamp commented Jan 16, 2024

Added flax/typing.py file to unify typing annotations.

Few things to note:

  • Filter is defined differently in NNX, so I did not unify NNX's definition under the flax.typing definition
  • Axis is defined differently in NNX, so I did not unify NNX's definition under the flax.typing definition
  • In addition to regular JAX initializers, Partitioned objects can also be used as kernel_init and bias_init arguments. Currently to accommodate this, the Initializer type annotation is Union[jax.nn.initializers.Initializer, Callable[..., Any]]. The more correct annotation would be Union[jax.nn.initializers.Initializer, Callable[..., Partitioned[Any]]], but then the linter will complain about code that tries to call Array methods or access Array attributes like .reshape and .shape on the initialized kernel or bias since Partitioned[Any] does not have any of these methods or attributes. The linter will also complain about functions that take in the initialized kernel or bias as an input argument, because they expect an Array and not Partitioned[Any].

@codecov-commenter
Copy link

codecov-commenter commented Jan 17, 2024

Codecov Report

Attention: 7 lines in your changes are missing coverage. Please review.

Comparison is base (74bdc7b) 56.19% compared to head (96d2d5f) 56.12%.
Report is 2 commits behind head on main.

Files Patch % Lines
flax/training/dynamic_scale.py 0.00% 3 Missing ⚠️
flax/linen/partitioning.py 0.00% 2 Missing ⚠️
flax/cursor.py 0.00% 1 Missing ⚠️
flax/linen/linear.py 92.30% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3624      +/-   ##
==========================================
- Coverage   56.19%   56.12%   -0.08%     
==========================================
  Files         101      102       +1     
  Lines       12214    12180      -34     
==========================================
- Hits         6864     6836      -28     
+ Misses       5350     5344       -6     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@chiamp chiamp requested a review from cgarciae January 17, 2024 02:14
@chiamp chiamp force-pushed the typing branch 2 times, most recently from 83468af to 333bcb0 Compare January 17, 2024 09:24
flax/core/__init__.py Outdated Show resolved Hide resolved
@chiamp chiamp force-pushed the typing branch 9 times, most recently from 09db07e to 6e22402 Compare January 24, 2024 18:57
@chiamp chiamp force-pushed the typing branch 6 times, most recently from 8034f3a to 96d2d5f Compare January 30, 2024 20:36
@copybara-service copybara-service bot merged commit 8995420 into google:main Jan 30, 2024
19 checks passed
@chiamp chiamp deleted the typing branch January 30, 2024 23:48
@jan1854 jan1854 mentioned this pull request Mar 30, 2024
14 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants