Translation behaviour with 1D SX arrays #16
Thank you for pointing this out; dimensionality has always been a headache. I have added a new test on inputs and modified several existing ones to mix different dimensions. All inputs are now raveled, so native CasADi indexing works as expected, as well as […]. The only concern is the performance of raveling, but as far as I know only structural metadata is modified, so no overhead should be added.
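Here is a minimal sketch of the raveling idea. It assumes every input is flattened in column-major (Fortran) order, which matches how CasADi stores dense matrices, and that each element is then addressed by a flat offset `row + col * n_rows`. The helper names `flat_offset` and `ravel_inputs` are hypothetical. This is not the code that landed in the repository.

```python
import jax.numpy as jnp


def flat_offset(row: int, col: int, n_rows: int) -> int:
    # Column-major offset of element (row, col) in a dense n_rows x n_cols matrix
    return row + col * n_rows


def ravel_inputs(*inputs):
    # Hypothetical helper: flatten each input in column-major order, so
    # (2,), (2, 1) and (1, 2) inputs all become shape (2,) and a single
    # flat index is always valid
    return [jnp.ravel(x, order="F") for x in inputs]


x_1d = jnp.array([1.0, 2.0])          # 1D input, shape (2,)
x_col = jnp.array([[1.0], [2.0]])     # column input, shape (2, 1)
M = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])      # 2x3 matrix input

flat = ravel_inputs(x_1d, x_col, M)
# Element (1, 2) of M (value 6.0) sits at column-major offset 1 + 2 * 2 = 5
m_12 = flat[2][flat_offset(1, 2, n_rows=2)]
```

For vector inputs, row-major and column-major flattening coincide, so the ravel amounts to a reshape. For matrix inputs, `order="F"` implies a transpose, so whether that step is free depends on how XLA handles it.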
* feat: optimising comp graph
* feat: graph translation
* Feat/pendulum rollout example (#9)
* added pendulum rollout as example
* fixed README
* fix: rebase
* Update README.md (added hyperlink for landing)
* fix: update branch for colab benchmark
* added demo notebook (#13)
* added demo notebook
* minor readme change
* fix: rebase
* fix: rebase and fix tests
* fix: pre-commit
* feat: graph expansion
* fix: benchmarks for more powers
* fix: disable graph compression
* fix: densify structural zeros
* fix: adaptive dimensionality #16
* fix: translate as graph_translate
* feat: test expand, docs
* feat: test examples
* fix: exclude running examples from workflow
* fix: remove ast debug files
* fix: organize imports
* fix: ops fix
* fix: update plots for graph translation
* fix: add plots to website

---------

Co-authored-by: Simeon Nedelchev <simkaned@gmail.com>
Co-authored-by: Lev Kozlov <kozlov.l.a10@gmail.com>
Hello,
Thank you very much for this nice framework! 👍
I ran into some trouble using the conversion with a 1D array input, and I was wondering what the best way to handle it would be.
Let's consider a simple case: a function that performs `A @ x`, where `A` is a constant matrix and `x` is the input:

```python
import jax.numpy as jnp

A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
```
Writing and testing this function in pure JAX would look like this:
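A minimal sketch of the pure-JAX version follows. It assumes `x` is a 1D array of length 2. The test value `[1.0, 2.0]` is illustrative and does not come from the report.

```python
import jax.numpy as jnp

# Constant 3x2 matrix from the example
A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])


def f(x):
    # x has shape (2,), so A @ x has shape (3,)
    return A @ x


x = jnp.array([1.0, 2.0])
y = f(x)  # shape (3,), values [3., -1., -2.]
```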
This outputs:
Now, if I want to apply A to a batch of x, I would simply do:
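This sketch assumes the batching is done with `jax.vmap`. It maps `f` over a batch of shape `(n, 2)` and returns `(n, 3)`. The batch values are arbitrary.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])


def f(x):
    return A @ x


# Batch of 4 input vectors, shape (4, 2)
xs = jnp.arange(8.0).reshape(4, 2)

# vmap maps f over the leading axis: each call sees one x of shape (2,)
ys = jax.vmap(f)(xs)  # shape (4, 3)
```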
This outputs:
Now let's do the same in CasADi instead, with a direct translation of the previous JAX function:
This raises `IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.`
This happens because, although `x = ca.SX.sym("x", 2)` is a vector (CasADi treats it as a 2×1 column), accesses to `x` get translated into `inputs[0][0, 0]` and `inputs[0][1, 0]` by the template `OP_INPUT: "inputs[{0}][{1}, {2}]"`. However, the JAX array passed as input has no second dimension, so the column index leads to the `IndexError`.
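The mismatch can be reproduced without CasADi. The sketch below hand-writes a function in the style of the generated code. The name `translated_f` and its body are hypothetical stand-ins, not jaxadi's actual output. It reads `x` through `[row, col]` indices, as the 2×1 `SX` symbol suggests, and that fails on a 1D JAX array.

```python
import jax.numpy as jnp


def translated_f(*inputs):
    # Hypothetical stand-in for generated code: CasADi sees x as a 2x1
    # column vector, so each element is read as inputs[0][row, col]
    x0 = inputs[0][0, 0]
    x1 = inputs[0][1, 0]
    # A @ x for A = [[1, 1], [-1, 0], [0, -1]], written out element-wise
    return jnp.stack([x0 + x1, -x0, -x1])


x = jnp.array([1.0, 2.0])  # 1D input of shape (2,)
try:
    translated_f(x)
    raised = False
except IndexError as err:
    # Two indices on a 1-dimensional array: there is no column axis
    raised = True
    print(err)
```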
One workaround is to change the shape of the input so that it has the missing axis:
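Here is a sketch of that workaround. It uses the same hypothetical stand-in `translated_f`, which indexes `inputs[0][row, col]` as the generated code does. Each input is stored as `(2, 1)`, and a batch is stored as `(n, 2, 1)`.

```python
import jax
import jax.numpy as jnp


def translated_f(*inputs):
    # Hypothetical stand-in for generated code that indexes inputs[0][row, col]
    x0 = inputs[0][0, 0]
    x1 = inputs[0][1, 0]
    return jnp.stack([x0 + x1, -x0, -x1])


# Give x the missing column axis: shape (2, 1) matches CasADi's 2x1 vector
x = jnp.array([[1.0], [2.0]])
y = translated_f(x)  # shape (3,)

# For a batch, store inputs as (n, 2, 1); vmap strips the leading axis
xs = jnp.arange(8.0).reshape(4, 2, 1)
ys = jax.vmap(translated_f)(xs)  # shape (4, 3)
```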
Another is to add a dummy axis with a wrapper:
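This sketch shows the wrapper approach. It again uses the hypothetical stand-in `translated_f`. The wrapper `f_1d` (also a hypothetical name) accepts a plain `(2,)` vector and adds the trailing axis before calling it, so the caller's batch can stay `(n, 2)`.

```python
import jax
import jax.numpy as jnp


def translated_f(*inputs):
    # Hypothetical stand-in for generated code that indexes inputs[0][row, col]
    x0 = inputs[0][0, 0]
    x1 = inputs[0][1, 0]
    return jnp.stack([x0 + x1, -x0, -x1])


def f_1d(x):
    # Wrapper: add a dummy trailing axis so a (2,) input becomes (2, 1)
    return translated_f(x[:, None])


xs = jnp.arange(8.0).reshape(4, 2)  # batch of 1D inputs, shape (4, 2)
ys = jax.vmap(f_1d)(xs)             # shape (4, 3)
```

With this approach, the data layout seen by the caller stays 1D. The extra axis exists only inside the wrapper.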
What do you think is the best way to handle this 1D case in your framework? Would it be possible to tweak the translation function so that it emits only `inputs[0][0]` and `inputs[0][1]` when the SX variable is detected as 1D?

Thank you for your help!
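Here is one way the requested tweak could look, as a sketch: the index template is chosen from the symbol's shape, and a single index is emitted for column vectors. The helper `input_access` and its arguments are hypothetical and do not reflect jaxadi's internals. Only the template string `inputs[{0}][{1}, {2}]` comes from the report.

```python
from __future__ import annotations


def input_access(arg: int, row: int, col: int, shape: tuple[int, int]) -> str:
    """Return the source snippet that reads one element of input `arg`.

    `shape` is the (rows, cols) shape of the CasADi symbol; CasADi
    vectors such as SX.sym("x", 2) are (2, 1) column vectors.
    """
    if shape[1] == 1:
        # Column vector: the JAX input may be 1D, so use a single index
        return f"inputs[{arg}][{row}]"
    # General matrix: keep the original two-index template
    return "inputs[{0}][{1}, {2}]".format(arg, row, col)


print(input_access(0, 1, 0, (2, 1)))  # inputs[0][1]
print(input_access(0, 1, 2, (3, 3)))  # inputs[0][1, 2]
```

A caveat of this design is that a caller who already passes a `(2, 1)` array would then get 1-element slices instead of scalars. That is one argument for normalizing input shapes (for example, by raveling) rather than branching on them.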
Best,