Translation behaviour with 1D SX arrays #16

Closed · paLeziart opened this issue Nov 28, 2024 · 1 comment

@paLeziart
Hello,

Thank you very much for this nice framework! 👍

I ran into some trouble using the conversion with a 1D array input, and I was wondering what the best way to handle it would be.

Let's consider a simple case: a function that performs A @ x, where A is a constant matrix and x is the input:
A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])

Coding and testing this function in pure JAX would look like this:

import jax
import jax.numpy as jnp


@jax.jit
def jax_Ax(x):
    Ax = jnp.zeros(3)
    Ax = Ax.at[0].set(jnp.sum(x, axis=-1))
    Ax = Ax.at[1].set(-x[0])
    Ax = Ax.at[2].set(-x[1])
    return Ax


if __name__ == "__main__":
    key = jax.random.key(0)

    x = jax.random.uniform(key, (2,))
    Ax = jax_Ax(x)
    print("Input: ", x)
    print("Output: ", Ax)

This outputs:

Input:  [0.21629536 0.8041241 ]
Output:  [ 1.0204195  -0.21629536 -0.8041241 ]
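As a quick sanity check (a minimal sketch reusing jax_Ax and x from the script above), the elementwise construction matches the plain matrix product:

A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
assert jnp.allclose(jax_Ax(x), A @ x)  # elementwise construction == A @ x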

Now, if I want to apply A to a batch of x, I would simply do:

import jax
import jax.numpy as jnp


@jax.jit
def jax_Ax(x):
    Ax = jnp.zeros(3)
    Ax = Ax.at[0].set(jnp.sum(x, axis=-1))
    Ax = Ax.at[1].set(-x[0])
    Ax = Ax.at[2].set(-x[1])
    return Ax


if __name__ == "__main__":
    key = jax.random.key(0)

    x = jax.random.uniform(key, (4, 2))
    Ax = jax.vmap(jax_Ax)(x)
    print("Input:\n", x)
    print("Output:\n", Ax)

This outputs:

Input:
 [[0.53222644 0.34965682]
 [0.35358644 0.9524387 ]
 [0.10100961 0.9829632 ]
 [0.1953876  0.8708868 ]]
Output:
 [[ 0.88188326 -0.53222644 -0.34965682]
 [ 1.3060251  -0.35358644 -0.9524387 ]
 [ 1.0839728  -0.10100961 -0.9829632 ]
 [ 1.0662744  -0.1953876  -0.8708868 ]]
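As an aside (again assuming the A defined earlier), the batched result is just a matrix product, which gives a one-line consistency check:

A = jnp.array([[1.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
assert jnp.allclose(jax.vmap(jax_Ax)(x), x @ A.T)  # row-wise A @ x_i == x @ A.T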

Now, let's do the same in CasADi, with a direct translation of the previous JAX function:

import casadi as ca
import jax
import jax.numpy as jnp
import jaxadi

def create_Ax():
    x = ca.SX.sym("x", 2)
    Ax = ca.SX.sym("Ax", 3)

    Ax[0] = x[0] + x[1]
    Ax[1] = -x[0]
    Ax[2] = -x[1]

    # Create CasADi function
    ca_Ax = ca.Function("ca_Ax", [x], [Ax], ["x"], ["Ax"])

    # Convert CasADi function to JAX
    jax_Ax = jaxadi.convert(ca_Ax, compile=True)

    return jax_Ax

if __name__ == "__main__":
    key = jax.random.key(0)

    jax_Ax = create_Ax()
    x = jax.random.uniform(key, (4, 2))
    Ax = jax.vmap(jax_Ax)(x)
    print("Input:\n", x)
    print("Output:\n", Ax)

This raises an IndexError: Too many indices: 0-dimensional array indexed with 1 regular index.

This happens because, even though x = ca.SX.sym("x", 2), accesses to x get translated into inputs[0][0, 0] and inputs[0][1, 0] by OP_INPUT: "inputs[{0}][{1}, {2}]", via:

this_shape = in_shapes[i_idx[0]]
rows, cols = this_shape  # Get the shape of the input
row_number = i_idx[1] % rows  # Compute row index for JAX
column_number = i_idx[1] // rows  # Compute column index for JAX
workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)

But the input has no second dimension, so the column index triggers the IndexError above.
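The failure is easy to reproduce in isolation, since indexing a 1D JAX array with two indices is exactly what the generated code ends up doing:

import jax.numpy as jnp

x = jnp.zeros(2)  # 1D input, shape (2,)
x[1]              # fine: a single index into a 1D array
x[1, 0]           # raises IndexError: too many indices for a 1D array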

The workaround is either to reshape the input so it has the missing axis:

if __name__ == "__main__":
    key = jax.random.key(0)

    jax_Ax = create_Ax()
    x = jax.random.uniform(key, (4, 2, 1))
    Ax = jax.vmap(jax_Ax)(x)
    print("Input:\n", x)
    print("Output:\n", Ax)
This outputs:

Input:
 [[[0.53222644]
  [0.34965682]]

 [[0.35358644]
  [0.9524387 ]]

 [[0.10100961]
  [0.9829632 ]]

 [[0.1953876 ]
  [0.8708868 ]]]
Output:
 [Array([[[ 0.88188326],
        [-0.53222644],
        [-0.34965682]],

       [[ 1.3060251 ],
        [-0.35358644],
        [-0.9524387 ]],

       [[ 1.0839728 ],
        [-0.10100961],
        [-0.9829632 ]],

       [[ 1.0662744 ],
        [-0.1953876 ],
        [-0.8708868 ]]], dtype=float32)]

Or to expand a fake axis with a wrapper:

@jax.jit
def wrap(x):
    return jax_Ax(jnp.expand_dims(x, 1))

if __name__ == "__main__":
    key = jax.random.key(0)

    jax_Ax = create_Ax()
    x = jax.random.uniform(key, (4, 2))
    Ax = jax.vmap(wrap)(x)
    print("Input:\n", x)
    print("Output:\n", Ax)
This outputs:

Input:
 [[0.53222644 0.34965682]
 [0.35358644 0.9524387 ]
 [0.10100961 0.9829632 ]
 [0.1953876  0.8708868 ]]
Output:
 [Array([[[ 0.88188326],
        [-0.53222644],
        [-0.34965682]],

       [[ 1.3060251 ],
        [-0.35358644],
        [-0.9524387 ]],

       [[ 1.0839728 ],
        [-0.10100961],
        [-0.9829632 ]],

       [[ 1.0662744 ],
        [-0.1953876 ],
        [-0.8708868 ]]], dtype=float32)]

What do you think is the best way to handle this 1D case in your framework? Would it be possible to tweak the translation function so it emits only inputs[0][0] and inputs[0][1] when the SX variable is detected as 1D?
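A rough sketch of what that tweak could look like (hypothetical code, not jaxadi's actual implementation), reusing the variable names from the snippet above and branching on the column count:

rows, cols = in_shapes[i_idx[0]]
if cols == 1:
    # 1D (column-vector) SX input: emit a single index, e.g. inputs[0][1]
    workers[o_idx[0]] = "inputs[{0}][{1}]".format(i_idx[0], i_idx[1])
else:
    row_number = i_idx[1] % rows
    column_number = i_idx[1] // rows
    workers[o_idx[0]] = OP_JAX_VALUE_DICT[op].format(i_idx[0], row_number, column_number)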

Thank you for your help!

Best,

@mattephi (Member) commented Dec 7, 2024

Thank you for pointing this out; dimensionality has always been a headache. I have added a new test on inputs and modified several existing ones to mix different dimensions.

Now all the inputs are raveled, so native CasADi indexing works as expected, and the (n, 1) vs (n,) problems are resolved.

The only concern is the performance of raveling, but as far as I know only structural metadata is modified, so it should add no overhead.
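For illustration, a minimal sketch of the idea (hypothetical code, not the actual generated translation): ravel each input before indexing, so (n,) and (n, 1) inputs share the same flat layout:

import jax
import jax.numpy as jnp

def jax_Ax_raveled(x):
    x = jnp.ravel(x)  # typically a metadata-only reshape, no copy
    return jnp.stack([x[0] + x[1], -x[0], -x[1]])

key = jax.random.key(0)
Ax1 = jax.vmap(jax_Ax_raveled)(jax.random.uniform(key, (4, 2)))     # (n,) inputs
Ax2 = jax.vmap(jax_Ax_raveled)(jax.random.uniform(key, (4, 2, 1)))  # (n, 1) inputs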

mattephi closed this as completed Dec 7, 2024
lvjonok added a commit that referenced this issue Feb 8, 2025 (its history includes "fix: adaptive dimensionality #16").