Add a is_static_jax property to TensorVariable's tag #182

Open
junpenglao opened this issue Nov 18, 2020 · 5 comments
Labels
enhancement New feature or request JAX Involves JAX transpilation


@junpenglao
Contributor

jax.jit requires some function arguments to be static (for example, the shape in jnp.reshape or the length in jax.lax.scan). Currently, if these are symbolic inputs, jax.jit fails in
https://github.com/pymc-devs/Theano-PyMC/blob/a9275c3dcc998c8cca5719037e493809b23422ff/theano/sandbox/jax_linker.py#L80
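
To make the constraint concrete, here is a toy model (not JAX's implementation) of what `static_argnums` in `jax.jit(fn, static_argnums=...)` implies: the values of static arguments become part of the compilation cache key, so they must be known and hashable at call time, and every new value triggers a new specialization. `toy_jit` and `reshape_square` are hypothetical names used only for this sketch.

```python
from typing import Any, Callable, Dict, Sequence, Tuple


def toy_jit(fn: Callable[..., Any], static_argnums: Sequence[int] = ()) -> Callable[..., Any]:
    """Toy model of jax.jit's static arguments (NOT JAX's actual machinery).

    Each distinct tuple of static-argument values gets its own cache entry,
    standing in for a separately traced and compiled specialization.
    """
    static = tuple(sorted(static_argnums))
    cache: Dict[Tuple[Any, ...], Callable[..., Any]] = {}

    def wrapped(*args: Any) -> Any:
        # Static values form the cache key; an unhashable value raises TypeError here.
        key = tuple(args[i] for i in static)
        if key not in cache:
            # Stand-in for tracing + compiling a new specialization for these values.
            cache[key] = fn
        return cache[key](*args)

    wrapped.cache = cache  # type: ignore[attr-defined]  # lets callers count specializations
    return wrapped


def reshape_square(xs, n):
    # `n` plays the role of jnp.reshape's shape argument: it must be concrete.
    return [xs[i * n:(i + 1) * n] for i in range(n)]


f = toy_jit(reshape_square, static_argnums=(1,))
f([1, 2, 3, 4], 2)    # [[1, 2], [3, 4]]; first specialization (n=2)
f([5, 6, 7, 8], 2)    # reuses the n=2 specialization
f(list(range(9)), 3)  # new static value, so a second specialization
print(len(f.cache))   # 2
```

A symbolic Theano input has no concrete value at this point, which is why it cannot simply be dropped into the static slot.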

I propose we add a property to TensorVariable in:

diff --git a/theano/tensor/var.py b/theano/tensor/var.py
index 4cda4e5e1..6f2aaf398 100644
--- a/theano/tensor/var.py
+++ b/theano/tensor/var.py
@@ -872,6 +872,8 @@ class TensorVariable(_tensor_py_operators, Variable):
 
                 pdb.set_trace()
 
+    # JAX-only hint: pass this input to jax.jit as a static argument
+    is_static_jax = False
 
 TensorType.Variable = TensorVariable

and to SharedVariable in:

diff --git a/theano/compile/sharedvalue.py b/theano/compile/sharedvalue.py
index cc3dd3cce..ca3e7af3b 100644
--- a/theano/compile/sharedvalue.py
+++ b/theano/compile/sharedvalue.py
@@ -224,6 +224,9 @@ class SharedVariable(Variable):
     # We keep this just to raise an error
     value = property(_value_get, _value_set)
 
+    # JAX-only hint: pass this input to jax.jit as a static argument
+    is_static_jax = False
+
 
 def shared_constructor(ctor, remove=False):
     if remove:

Then we can detect the additional static_argnums in:

diff --git a/theano/sandbox/jax_linker.py b/theano/sandbox/jax_linker.py
index 59b61caf3..0093c3fa7 100644
--- a/theano/sandbox/jax_linker.py
+++ b/theano/sandbox/jax_linker.py
@@ -62,7 +62,9 @@ class JAXLinker(PerformLinker):
         # I suppose we can consider `Constant`s to be "static" according to
         # JAX.
         static_argnums = [
-            n for n, i in enumerate(self.fgraph.inputs) if isinstance(i, Constant)
+            n
+            for n, i in enumerate(self.fgraph.inputs)
+            if isinstance(i, Constant) or getattr(i, "is_static_jax", False)
         ]
 
         thunk_inputs = [storage_map[n] for n in self.fgraph.inputs]
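
The selection rule in the linker diff can be exercised outside Theano with minimal stand-in classes. `Variable`, `Constant`, and `compute_static_argnums` are hypothetical names for this sketch, not Theano's classes. Note that the rule reads the flag as a plain attribute, so it has to be a data attribute rather than a method: a bound method is always truthy and would mark every input as static.

```python
from typing import Any, List


class Variable:
    """Hypothetical stand-in for a graph input (not Theano's real class)."""

    def __init__(self, name: str) -> None:
        self.name = name


class Constant(Variable):
    """Hypothetical stand-in for a graph constant."""

    def __init__(self, name: str, value: Any) -> None:
        super().__init__(name)
        self.value = value


def compute_static_argnums(inputs: List[Variable]) -> List[int]:
    """Return the positions of inputs to pass to jax.jit as static arguments."""
    return [
        n
        for n, i in enumerate(inputs)
        # getattr with a default treats inputs that never set the flag as dynamic
        if isinstance(i, Constant) or getattr(i, "is_static_jax", False)
    ]


a = Variable("a")
b = Variable("b")
b.is_static_jax = True  # marked by hand, as in the proposed tests
c = Constant("c", 3)
print(compute_static_argnums([a, b, c]))  # [1, 2]
```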

For now, users will need to mark these variables by hand. For example, the following makes the tests pass:

diff --git a/tests/sandbox/test_jax.py b/tests/sandbox/test_jax.py
index 89c46ff9b..c3c3d7225 100644
--- a/tests/sandbox/test_jax.py
+++ b/tests/sandbox/test_jax.py
@@ -534,10 +534,10 @@ def test_jax_Reshape():
     compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(theano.config.floatX)])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_jax_Reshape_nonconcrete():
     a = tt.vector("a")
     b = tt.iscalar("b")
+    b.is_static_jax = True
     x = tt.basic.reshape(a, (b, b))
     x_fg = theano.gof.FunctionGraph([a, b], [x])
     compare_jax_and_py(
@@ -666,10 +666,10 @@ def test_tensor_basics():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
 
 
-@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
 def test_arange_nonconcrete():
 
     a = tt.scalar("a")
+    a.is_static_jax = True
     a.tag.test_value = 10
 
     out = tt.arange(a)
@@ -677,7 +677,6 @@ def test_arange_nonconcrete():
     compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

junpenglao added the enhancement (New feature or request) and JAX (Involves JAX transpilation) labels on Nov 18, 2020
@brandonwillard
Member

brandonwillard commented Nov 18, 2020

I thought static parameters were effectively Constants; is that not the correct indicator already?

Regarding shared values, those don't really exist within the JAX compilation context; they're Constants at that point. See #73 for an explanation.

Also, if I recall, static_argnums only applies to the inputs of the JITed function, which implies that such an indicator would only have relevance during the JAX transpilation process and nowhere else.

@junpenglao
Contributor Author

> I thought static parameters were effectively Constants; is that not the correct indicator already?

I guess these are parameters rather than truly static values (so they are not Constants), but because of XLA's limitations they still need to be treated as static at runtime. Being able to mark them as static would give us a bit more flexibility to JIT those functions.

> Also, if I recall, static_argnums only applies to the inputs of the JITed function, which implies that such an indicator would only have relevance during the JAX transpilation process and nowhere else.

Yes that's right.

@brandonwillard
Member

brandonwillard commented Nov 18, 2020

At a high level, we shouldn't add properties to classes unless they're directly relevant to the concepts/objects being modeled by the classes (e.g. a TensorVariable class should only concern itself with properties relating to tensors and/or variables). This particular addition is far too specific to the JAX transpilation process, and, just like the numerous C-related methods attached to our classes, we can incorporate this information differently.

From a lower level, our class implementations need to remain as simple and "static" as possible. Doing so greatly improves the comprehensibility of our code, since it introduces fewer runtime and downstream logic surprises (e.g. avoiding questions like "What's this field, where did it come from, and how did it get set to this?"). Also, we could leverage some individually small—but cumulatively large—performance advantages from this situation (e.g. __slots__), especially as graphs scale. See issue #72 for related concerns.

That said, the tag field is better for this situation; however, it's still not clear to me how we would use this information. Can you give a small example of how/when we could use it to avoid the limitations in jnp.reshape and others?
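
The `__slots__` point and the `tag` suggestion fit together: a slotted class rejects ad-hoc attributes such as `b.is_static_jax = True`, while a free-form `tag` namespace can still carry backend-specific hints. The sketch below assumes a hypothetical `SlottedVariable` class and uses `types.SimpleNamespace` as a stand-in for Theano's own `tag` object; `is_static_jax` here is a hypothetical helper, not an existing API.

```python
from types import SimpleNamespace


class SlottedVariable:
    """Hypothetical variable class using __slots__ (Theano's classes may differ)."""

    __slots__ = ("name", "tag")

    def __init__(self, name: str) -> None:
        self.name = name
        # Free-form namespace for backend-specific annotations (stand-in for Theano's tag).
        self.tag = SimpleNamespace()


def is_static_jax(var: SlottedVariable) -> bool:
    """Read the JAX-only hint from the tag; unmarked variables default to False."""
    return bool(getattr(var.tag, "is_static_jax", False))


b = SlottedVariable("b")
try:
    b.is_static_jax = True  # rejected: not a declared slot
except AttributeError:
    print("ad-hoc attribute rejected")
b.tag.is_static_jax = True  # accepted: the annotation lives on the tag
print(is_static_jax(b))  # True
```

Only the JAX linker would ever need to call a helper like this, which keeps the class definition itself free of transpiler-specific fields.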

@junpenglao
Contributor Author

> That said, the tag field is better for this situation; however, it's still not clear to me how we would use this information. Can you give a small example of how/when we could use it to avoid the limitations in jnp.reshape and others?

Adding it to tag is a nice compromise.
We could use it to avoid the limitation in jnp.reshape by marking the tensors passed as the shape argument as JAX-static; theano.function(..., mode='jax') would then work too (i.e., jax.jit would not raise an error).

@ricardoV94
Contributor

ricardoV94 commented Oct 30, 2021

I came across something like this in #631. There is a second problem: scalar symbolic variables become scalar (0-d) NumPy arrays during execution, and these cannot be used as static arguments to JAX functions because they are not hashable.
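
One possible workaround, sketched under the assumption that the linker could post-process the runtime values of inputs selected as static before calling the jitted function: convert 0-d arrays to the equivalent Python scalar with `ndarray.item()`, which is hashable. `to_static_value` is a hypothetical helper, not part of Theano or JAX, and every distinct converted value would still trigger a separate compilation.

```python
from typing import Any

import numpy as np


def to_static_value(value: Any) -> Any:
    """Make a runtime input usable as a JAX static argument (hypothetical helper).

    0-d NumPy arrays are unhashable, so convert them to the equivalent Python
    scalar with .item(); anything else is returned unchanged.
    """
    if isinstance(value, np.ndarray) and value.ndim == 0:
        return value.item()
    return value


runtime_b = np.array(3, dtype="int32")  # what a scalar symbolic input holds at run time
try:
    hash(runtime_b)
except TypeError as err:
    print("not hashable:", err)
static_b = to_static_value(runtime_b)
print(static_b, type(static_b).__name__)  # 3 int
```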

twiecki changed the title from "Design proposal: Add a is_static_jax property to TensorVariable" to "Add a is_static_jax property to TensorVariable's tag" on Oct 30, 2021
3 participants