Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

proxy more methods for nnx.Variable #4234

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions flax/nnx/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

import dataclasses
import enum
import functools
from functools import partial
import typing as tp
Expand Down Expand Up @@ -654,6 +655,17 @@ def __floor__(self) -> A:
def __ceil__(self) -> A:
return self.value.__ceil__() # type: ignore

def __array__(self, *args, **kwargs):
return self.value.__array__(*args, **kwargs) # type: ignore

def __array_namespace__(self, *args, **kwargs):
return self.value.__array_namespace__(*args, **kwargs) # type: ignore

def __dlpack__(self, *args, **kwargs):
return self.value.__dlpack__(*args, **kwargs) # type: ignore

def __dlpack_device__(self) -> tuple[enum.Enum, int]:
return self.value.__dlpack_device__() # type: ignore

class Param(Variable[A]):
"""The canonical learnable parameter. All learnable parameters
Expand Down
15 changes: 15 additions & 0 deletions tests/nnx/variable_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import typing as tp

import numpy as np
import jax
import jax.numpy as jnp

Expand Down Expand Up @@ -81,6 +82,20 @@ def __call__(self, x):

self.assertEqual(result, 6)

def test_proxy_dlpack(self):
v = nnx.Param(jnp.ones((2, 3)))

self.assertEqual(np.from_dlpack(v).shape, (2, 3))

def test_proxy_array(self):
v = nnx.Param(jnp.ones((2, 3)))

self.assertEqual(np.asarray(v).shape, (2, 3))

def test_proxy_array_namespace(self):
v = nnx.Param(jnp.ones((2, 3)))

self.assertEqual(v.__array_namespace__(), jnp)

if __name__ == '__main__':
absltest.main()