-
Notifications
You must be signed in to change notification settings - Fork 660
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] remove flagslib #3733
[nnx] remove flagslib #3733
Conversation
866f60f
to
67202a7
Compare
67202a7
to
831eee9
Compare
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
831eee9
to
3719686
Compare
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #3733 +/- ##
==========================================
+ Coverage 58.43% 58.77% +0.34%
==========================================
Files 102 101 -1
Lines 12365 12409 +44
==========================================
+ Hits 7225 7293 +68
+ Misses 5140 5116 -24 ☔ View full report in Codecov by Sentry. |
@@ -0,0 +1,188 @@ | |||
# Copyright 2024 The Flax Authors. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file was deleted after #3742. I think you need to rebase.
flax/experimental/nnx/nnx/module.py
Outdated
@@ -109,6 +109,12 @@ def _meta_call(cls: tp.Type[M], *args, **kwargs) -> M: | |||
vars(module)[field.name] = None | |||
continue | |||
|
|||
if 'nnx_variable_constructor' not in field.metadata: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was deleted as well in #3742.
@@ -68,6 +68,17 @@ def call(cls: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P: | |||
with _mutable(obj), _initializing(obj): | |||
obj.__init__(*args, **kwargs) | |||
|
|||
if dataclasses.is_dataclass(obj): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was deleted as well in #3742.
@@ -492,31 +516,31 @@ def __init__(self): | |||
|
|||
class TestModuleDataclass: | |||
def test_basic(self): | |||
@dataclasses.dataclass | |||
@nnx.dataclass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test was modified by #3742.
@@ -48,12 +47,12 @@ def __init__(self, y) -> None: | |||
pytree.x = 4 | |||
|
|||
def test_immutable_pytree_dataclass(self): | |||
@dataclasses.dataclass(frozen=True) | |||
@nnx.dataclass(frozen=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file was modified by #3742.
flax/linen/fp8_ops.py
Outdated
@@ -12,81 +12,16 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
|
|||
import dataclasses | |||
import numpy as np | |||
import warnings |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes don't look related to this PR. Maybe rebasing will fix this?
flax/linen/linear.py
Outdated
@@ -31,12 +31,9 @@ | |||
from jax import eval_shape, lax | |||
from jax.core import ShapedArray | |||
|
|||
import opt_einsum | |||
|
|||
from flax.core import meta |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here, rebase
self, *filters: filterlib.Filter, **attributes: tp.Any | ||
) -> None: | ||
"""Sets the attributes of nested Modules including the current Module. | ||
If the attribute is not found in the Module, it is ignored. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would you consider a flag that would raise an error if the attribute isn't found?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
544091e
to
41b1e2c
Compare
Cleaned the PR to remove the spurious changes. |
e5a2b4e
to
db0e96a
Compare
db0e96a
to
2206c40
Compare
What does this PR do?
This PR removes the
flagslib
module is favor of aModule.set_attributes
method that recursively sets the attributes of all Modules in the Module graph. Basically, flags are now just Module attributes plus a mechanism to recursively set them.