proxy more methods for nnx.Variable #4234

Status: Open. njzjz wants to merge 3 commits into main.

@njzjz njzjz commented Sep 29, 2024

What does this PR do?

Proxy __array__, __array_namespace__, __dlpack__, and __dlpack_device__ for flax.nnx.Variable.

__array_namespace__, __dlpack__, and __dlpack_device__ are specified in the Array API standard. __array__ is specified by NumPy. All four methods are already supported by jax.Array.

Checklist

  • This PR fixes a minor issue (e.g., a typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a GitHub issue/discussion (please add a link).
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
cgarciae (Collaborator) commented Oct 1, 2024

Hey @njzjz, do you have a specific use case in mind?

njzjz (Author) commented Oct 1, 2024

Hey @njzjz, do you have a specific use case in mind?

I have a complex method that supports multiple backends, such as NumPy and JAX, through the Array API. An example is shown below.

import array_api_compat

def f(x, w, b):
    xp = array_api_compat.array_namespace(x, w, b)
    return xp.matmul(w, x) + b

When w and b are JAX Arrays, it works as expected, as JAX has fully supported the Array API (xref: jax-ml/jax#22818). However, when w and b are nnx.Variable, it doesn't work without this PR.
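
The failure mode described above can be reproduced without Flax. The sketch below rests on several assumptions: `Box` and `ProxyBox` are hypothetical wrapper classes standing in for nnx.Variable, NumPy stands in for JAX, and `namespace_of` is a hypothetical, simplified imitation of what array_api_compat.array_namespace does, not its actual implementation.

```python
import numpy as np


def namespace_of(*arrays):
    """Hypothetical, simplified stand-in for array_api_compat.array_namespace."""
    namespaces = []
    for a in arrays:
        if isinstance(a, np.ndarray):
            # NumPy < 2.0 has no ndarray.__array_namespace__, so special-case it.
            ns = np
        elif hasattr(a, "__array_namespace__"):
            ns = a.__array_namespace__()
        else:
            raise TypeError(f"{type(a).__name__} is not an Array API object")
        if ns not in namespaces:
            namespaces.append(ns)
    if len(namespaces) != 1:
        raise TypeError("expected arrays from exactly one namespace")
    return namespaces[0]


class Box:
    """Hypothetical wrapper that holds an array but proxies nothing."""

    def __init__(self, value):
        self.value = value


class ProxyBox(Box):
    """Same wrapper, forwarding the namespace lookup to the inner array."""

    def __array_namespace__(self, *args, **kwargs):
        # Delegate to the helper so the sketch also runs on NumPy < 2.0.
        return namespace_of(self.value)


x = np.ones((3,))

# A wrapper without the proxy method is rejected by the namespace lookup.
try:
    namespace_of(x, Box(np.eye(3)))
    bare_wrapper_accepted = True
except TypeError:
    bare_wrapper_accepted = False

# With the proxy method, the wrapper resolves to the inner array's namespace.
xp = namespace_of(x, ProxyBox(np.eye(3)))
```

In this sketch the only difference between the rejected and the accepted wrapper is the single forwarded method, which is the kind of proxying the PR proposes for nnx.Variable.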

In addition, without this PR, one cannot call np.asarray(w) or np.from_dlpack(w) to convert an nnx.Variable directly to a numpy.ndarray.

Currently, I work around this by subclassing nnx.Param, as shown below. I think these methods should be added to nnx.Variable itself.

from flax import nnx

# Each method forwards to the wrapped array stored in self.value.
class ArrayAPIParam(nnx.Param):
    def __array__(self, *args, **kwargs):
        return self.value.__array__(*args, **kwargs)

    def __array_namespace__(self, *args, **kwargs):
        return self.value.__array_namespace__(*args, **kwargs)

    def __dlpack__(self, *args, **kwargs):
        return self.value.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self, *args, **kwargs):
        return self.value.__dlpack_device__(*args, **kwargs)
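
To illustrate what the NumPy-facing proxies buy, here is a minimal sketch that does not depend on Flax or JAX. `ArrayProxy` is a hypothetical stand-in for a Variable-like container and a NumPy array stands in for the inner jax.Array; the real nnx.Variable may behave differently in its details.

```python
import numpy as np


class ArrayProxy:
    """Hypothetical stand-in for a Variable-like container around an array."""

    def __init__(self, value):
        self.value = value

    def __array__(self, dtype=None, copy=None):
        # Hook used by np.asarray. `copy` is accepted for NumPy 2.x callers
        # and ignored in this sketch.
        return np.asarray(self.value, dtype=dtype)

    def __dlpack__(self, *args, **kwargs):
        # Forward whatever arguments the consumer passes to the inner array.
        return self.value.__dlpack__(*args, **kwargs)

    def __dlpack_device__(self):
        return self.value.__dlpack_device__()


inner = np.arange(6.0).reshape(2, 3)
proxy = ArrayProxy(inner)

via_array = np.asarray(proxy)        # goes through __array__
via_dlpack = np.from_dlpack(proxy)   # goes through __dlpack__
```

Both conversions accept the container directly, so callers do not need to know that the array is wrapped.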

cgarciae (Collaborator) commented Oct 2, 2024

Thanks! Can you say a bit more about how you are using this inside the nnx.Module? Would using the Variable's .value property to extract the inner array work for your use case?

My personal opinion is that we should merge this PR because Variable is already a proxy. I just want to understand this better because __jax_array__ is not a public JAX API, which can cause issues from time to time.

njzjz (Author) commented Oct 2, 2024

Thanks! Can you say a bit more about how you are using this inside the nnx.Module? Would using the Variable's .value property to extract the inner array work for your use case?

An example is below:

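import array_api_compat
import jax
import jax.numpy as jnp
from flax import nnx

def f(x, w, b):
    # Pick the array namespace shared by all inputs.
    xp = array_api_compat.array_namespace(x, w, b)
    return xp.matmul(w, x) + b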

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return f(x, self.w, self.b)

You are right, self.w.value should work. I don't quite understand how nnx.Param works. Will the following code work?

    self._w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self._b = nnx.Param(jnp.zeros((dout,)))
    self.w = self._w.value
    self.b = self._b.value

cgarciae (Collaborator) commented Oct 8, 2024

@njzjz Variable just holds an inner .value, and to make it feel like an Array we are implementing all these proxy methods. But we do require that Arrays are held in Variables, so ideally you would do this:

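import array_api_compat
import jax
import jax.numpy as jnp
from flax import nnx

def f(x, w, b):
    # Pick the array namespace shared by all inputs.
    xp = array_api_compat.array_namespace(x, w, b)
    return xp.matmul(w, x) + b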

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.uniform(key, (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))
    self.din, self.dout = din, dout

  def __call__(self, x: jax.Array):
    return f(x, self.w.value, self.b.value)
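
As a side note on the earlier question about caching `.value` in `__init__`, the plain-Python sketch below shows one reason that pattern can be fragile. It is not a statement about how NNX tracks state internally: `Holder` and `Layer` are hypothetical stand-ins, and NumPy arrays replace JAX arrays.

```python
import numpy as np


class Holder:
    """Hypothetical stand-in for a Variable: a mutable box around an array."""

    def __init__(self, value):
        self.value = value


class Layer:
    def __init__(self):
        self._w = Holder(np.zeros(2))
        # Cached reference to whatever array the holder contains right now.
        self.w = self._w.value


layer = Layer()

# Simulate a parameter update that rebinds the holder's value.
layer._w.value = np.ones(2)

# The cached attribute still points at the old array.
cached_is_stale = not np.array_equal(layer.w, layer._w.value)
```

Reading `.value` at call time, as in the example above, avoids this kind of stale reference.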
