Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Added
flax/typing.py
file to unify typing annotations.Few things to note:
Filter
is defined differently in NNX, so I did not unify NNX's definition under theflax.typing
definitionAxis
is defined differently in NNX, so I did not unify NNX's definition under theflax.typing
definitionPartitioned
objects can also be used askernel_init
andbias_init
arguments. Currently to accommodate this, theInitializer
type annotation isUnion[jax.nn.initializers.Initializer, Callable[..., Any]]
. The more correct annotation would beUnion[jax.nn.initializers.Initializer, Callable[..., Partitioned[Any]]]
, but then the linter will complain about code that tries to callArray
methods or accessArray
attributes like.reshape
and.shape
on the initializedkernel
orbias
sincePartitioned[Any]
does not have any of these methods or attributes. The linter will also complain about functions that take in the initializedkernel
orbias
as an input argument, because they expect anArray
and notPartitioned[Any]
.