Add a is_static_jax property to TensorVariable's tag #182
Comments
I thought static parameters were effectively Regarding shared values, those don't really exist within the JAX compilation context; they're Also, if I recall,
I guess those are parameters, not actually static parameters (so they are not Constants), but because of XLA's limitations you still need to treat them as static at runtime. If we can mark them as static, it gives us a bit more flexibility to jit those functions.
Yes, that's right.
At a high level, we shouldn't add properties to classes unless they're directly relevant to the concepts/objects being modeled by the classes (e.g. a At a lower level, our class implementations need to remain as simple and "static" as possible. Doing so greatly improves the comprehensibility of our code, since it introduces fewer runtime and downstream logic surprises (e.g. avoiding questions like "What's this field, where did it come from, and how did it get set to this?"). Also, we could leverage some individually small, but cumulatively large, performance advantages from this situation (e.g. That said, the
Adding it to
I came across something like this in #631. There is a second problem: scalar symbolic variables become scalar NumPy arrays during execution, and these cannot be used as static arguments to JAX functions because they are not hashable.
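jax.jit uses static arguments as part of its compilation cache key, so they must be hashable, and a 0-d NumPy array is not. A minimal sketch of the problem and one possible workaround, assuming the linker converted such values to Python scalars before calling the jitted function (`to_static` is a hypothetical helper, not an existing Theano or JAX function):

```python
import numpy as np


def to_static(value):
    """Hypothetical helper: turn a 0-d NumPy array into a hashable Python scalar.

    Any other value is returned unchanged.
    """
    if isinstance(value, np.ndarray) and value.ndim == 0:
        return value.item()  # e.g. np.array(3) -> 3 (a Python int)
    return value


n = np.array(3)  # what a scalar symbolic variable looks like at run time

try:
    hash(n)
except TypeError as err:
    print(err)  # unhashable type: 'numpy.ndarray'

static_n = to_static(n)
print(type(static_n).__name__, hash(static_n) == hash(3))  # int True
```

Non-scalar arrays are left alone here; they would still be unusable as static arguments and would need a different conversion (e.g. to a tuple) if that were ever required.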
JAX's jit requires some function arguments to be static (for example, shape in jnp.reshape and length in jax.lax.scan). Currently, if these are symbolic inputs, jax.jit breaks in https://github.com/pymc-devs/Theano-PyMC/blob/a9275c3dcc998c8cca5719037e493809b23422ff/theano/sandbox/jax_linker.py#L80
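A minimal sketch, assuming a recent JAX release, of why such arguments have to be static: under jax.jit every non-static argument is replaced by an abstract tracer, so jnp.reshape cannot read concrete dimensions from a traced shape. The function names below are illustrative only.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `shape` is marked static, so JAX sees its concrete (hashable) value at trace time.
@partial(jax.jit, static_argnums=(1,))
def reshape_static(x, shape):
    return jnp.reshape(x, shape)


# Here `shape` is traced like any other input, so its values are unknown
# while tracing and jnp.reshape cannot build the output shape.
@jax.jit
def reshape_traced(x, shape):
    return jnp.reshape(x, shape)


x = jnp.arange(6)
print(reshape_static(x, (2, 3)).shape)  # (2, 3)

try:
    reshape_traced(x, jnp.array([2, 3]))
except Exception as err:  # exact error type varies across JAX versions
    print("tracing failed:", type(err).__name__)
```

A symbolic shape input reaching the jitted function through the linker ends up in the position of `reshape_traced`'s `shape`, which is why the linker needs to know which inputs belong in static_argnums.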
I propose we add a property to TensorVariable and SharedVariable.
Then we can detect the additional static_argnums. For now, users will need to mark these variables by hand; for example, we can do the following to make the tests pass:
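To illustrate the mechanism (this is a hypothetical sketch, not the snippet referred to above), the detection step could scan the graph inputs for a tag flag. `FakeVariable` and `static_argnums_from_tags` are made-up names; `FakeVariable.tag` loosely mimics Theano's `Variable.tag`, which accepts arbitrary attributes.

```python
from types import SimpleNamespace


class FakeVariable:
    """Hypothetical stand-in for a TensorVariable: only a name and a tag."""

    def __init__(self, name):
        self.name = name
        self.tag = SimpleNamespace()  # free-form attribute container


def static_argnums_from_tags(inputs):
    """Return the positions of inputs whose tag sets is_static_jax=True."""
    return tuple(
        i
        for i, var in enumerate(inputs)
        if getattr(var.tag, "is_static_jax", False)
    )


x = FakeVariable("x")
shape = FakeVariable("shape")
shape.tag.is_static_jax = True  # the user marks the variable by hand

print(static_argnums_from_tags([x, shape]))  # (1,)
```

The resulting tuple is the kind of value that could be passed as `jax.jit(fn, static_argnums=...)` when the graph is compiled; the runtime values at those positions would still need to be hashable when the function is called.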