-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add custom concatenate/stack functions #8
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@patrick-kidger FYI
sympy2jax/sympy_module.py
Outdated
@@ -34,7 +37,16 @@ def fn_(*args): | |||
return fn_ | |||
|
|||
|
|||
def _args_as_array(fn): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This "works", but looking for other ideas as it means clients cannot specify other args to stack/concatenate. Encapsulating the inputs as an array in sympy would require some more type handling there. Let me know if you have other ideas here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not clear to me why this is necessary. Can't we just use stack/concatenate directly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These interfaces take a sequence of arrays as the first argument. jnp.stack, for example: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.stack.html
jax.numpy.stack(arrays, axis=0, out=None, dtype=None)
IIUC to use this directly, I'd need to either already have an array, or to encode my input as an array/tuple/sequence within sympy - which I believe is not something that is supported in sympy2jax currently. e.g. stack(Array([x, y , z]))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They accept arraylikes, not just arrays (i.e. bool/int/float/complex as well.) E.g. jnp.stack([1, 2])
works just fine.
That aside, most of the time the input should already be an array anyway (from earlier operations).
Note that what you're doing only works at the moment because jnp.array
is doing the stack/concenate for you. A stack or concatenate op is the identity function on a single argument.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code here is not about changing the type to Array (that was just an implementation choice), it's for putting the arguments into the right position.
The 0th argument to jnp.stack
needs to be an arraylike of values to concatenate. AFAIK I have no way of constructing a sequence like this in sympy2jax compatible sympy (let me know if that's incorrect and please share some example sympy that would achieve this).
In our case we have several state variables we're trying to stitch together and would like to emit as an array, e.g. state_x, state_y
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, you're also talking about just combining the args. Okay:
def _single_args(fn):
def _fn(*args):
return fn(args)
return _fn
_single_args(jnp.stack)
?
Indeed I don't think passing the other arguments to stack
is important. (Some of them are holdovers from numpy and aren't implemented in JAX anyway.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Sorry - we were talking past eachother a bit. My original comment was lamenting the usage of the wrapper function in general (for combining the args).
You're right that the call to jnp.array was not needed - I wasn't really even looking at this. Removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other than the one question, this LGTM!
I think it makes sense to include in sympy2jax so that folks don't need to think about how to implement this. But it could easily be done by a client, it's true.
Looks like the pre-commit checks failed. FYI you can arrange to have these run automatically whenever you commit. They will either pass and commit, or fail and autoformat. If the latter, you can then check you're happy with the autoformat, followed by (See CONTRIBUTING.MD.) |
Fixed, I missed this repo was using isort. I'll try out the precommit hook next time. |
Merged! Ty. |
This is to support stacking / concatenating arrays within a sympy expression.
For discussion: