proxy more methods for nnx.Variable
#4234
base: main
Conversation
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Hey @njzjz, do you have a specific use case in mind?
I have a complex method that supports different backends such as NumPy and JAX via the Array API. An example is shown below.

```python
import array_api_compat

def f(x, w, b):
    xp = array_api_compat.array_namespace(x, w, b)
    return xp.matmul(w, x) + b
```

Without this PR, an `nnx.Param` cannot be passed to such a function, since `array_api_compat.array_namespace` does not find `__array_namespace__` on it. Currently, I use the workaround of inheriting from `nnx.Param` and proxying these methods:

```python
class ArrayAPIParam(nnx.Param):
    def __array__(self, *args, **kwargs):
        return self.value.__array__(*args, **kwargs)

    def __array_namespace__(self, *args, **kwargs):
        return self.value.__array_namespace__(*args, **kwargs)

    def __dlpack__(self, *args, **kwargs):
        return self.value.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self, *args, **kwargs):
        return self.value.__dlpack_device__(*args, **kwargs)
```
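For illustration, a minimal sketch of how the workaround subclass above can be exercised; the shapes and variable names are made up, and this assumes a JAX version whose `jax.Array` already implements `__array__` and `__array_namespace__`:

```python
# Illustrative only -- exercises the ArrayAPIParam workaround defined above.
import array_api_compat
import jax.numpy as jnp
import numpy as np

p = ArrayAPIParam(jnp.ones((2, 3)))

xp = array_api_compat.array_namespace(p)  # resolved through the proxied __array_namespace__
np_p = np.asarray(p)                      # goes through the proxied __array__
assert np_p.shape == (2, 3)
```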
Thanks! Can you say a bit more about how you are using this inside the module? My personal opinion is that we should merge this PR because Variable is already a proxy; I just want to understand this better.
An example is below:

```python
def f(x, w, b):
    xp = array_api_compat.array_namespace(x, w, b)
    return xp.matmul(w, x) + b


class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))
        self.din, self.dout = din, dout

    def __call__(self, x: jax.Array):
        return f(x, self.w, self.b)
```

You are right, one could also write:

```python
self._w = nnx.Param(jax.random.uniform(key, (din, dout)))
self._b = nnx.Param(jnp.zeros((dout,)))
self.w = self._w.value
self.b = self._b.value
```
@njzjz you can also pass the underlying arrays explicitly via `.value`:

```python
def f(x, w, b):
    xp = array_api_compat.array_namespace(x, w, b)
    return xp.matmul(w, x) + b


class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        key = rngs.params()
        self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))
        self.din, self.dout = din, dout

    def __call__(self, x: jax.Array):
        return f(x, self.w.value, self.b.value)
```
What does this PR do?
Proxy `__array__`, `__array_namespace__`, `__dlpack__`, and `__dlpack_device__` for `flax.nnx.Variable`. `__array_namespace__`, `__dlpack__`, and `__dlpack_device__` are specified in the Array API; `__array__` is specified by NumPy. These methods are already supported by `jax.Array`.
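With these methods proxied, a `Variable` can be handed directly to Array-API-based code. Below is a rough usage sketch, assuming this PR is applied and a JAX version whose arrays implement these protocols; the shapes and names are illustrative only, not part of the actual diff:

```python
# Illustrative usage only -- not taken from the PR diff.
import array_api_compat
import jax.numpy as jnp
import numpy as np
from flax import nnx

w = nnx.Param(jnp.ones((3, 2)))

xp = array_api_compat.array_namespace(w)  # forwarded to w.value.__array_namespace__
np_w = np.asarray(w)                      # forwarded to w.value.__array__
# __dlpack__ / __dlpack_device__ similarly allow exchange via e.g. np.from_dlpack(w)
```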
Checklist
- This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
- This change is discussed in a GitHub issue / discussion (please add a link).
- The documentation and docstrings adhere to the documentation guidelines.
- This change includes necessary high-coverage tests. (No quality testing = no merge!)