Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/pendulum rollout example #9

Merged
merged 2 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,18 @@ JAXADI comes with several examples to help you get started:

3. [Function Conversion](examples/02_convert.py): See how to fully convert CasADi functions to JAX.

4. [Pinocchio Integration](examples/03_pinocchio.py): Explore how to convert Pinocchio-based CasADi functions to JAX.
4. [Pendulum Rollout](examples/03_pendulum_rollout.py): Batched rollout of the nonlinear passive nonlinear pendulum

5. [Pinocchio Integration](examples/04_pinocchio.py): Explore how to convert Pinocchio-based CasADi functions to JAX.

> **Note**: To run the Pinocchio example, ensure you have Pinocchio properly installed in your environment.
6. [MJX Comparison](examples/05_mjx.py): Compare the transformed Pinnocchio forward kinematics with one provided by Mujoco MJX

> **Note**: To run the Pinocchio and MJX examples, ensure you have them properly installed in your environment.

## Performance Benchmarks

(Consider adding a section about performance comparisons between CasADi and JAXADI-translated functions)
<!-- ## Performance Benchmarks

(Consider adding a section about performance comparisons between CasADi and JAXADI-translated functions) -->

<!-- ## Contributing

Expand Down
101 changes: 101 additions & 0 deletions examples/03_pendulum_rollout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import timeit
import casadi as ca
import jax
import jax.numpy as jnp
import numpy as np
from jaxadi import convert

# Static parameters
dt = 0.02
g = 9.81 # Acceleration due to gravity
L = 1.0 # Length of the pendulum
b = 0.1 # Damping coefficient
I = 1.0 # Moment of inertia
# Test parameters
batch_size = 4096
timesteps = 100


# Define the uncontrolled pendulum model in CasADi
def casadi_pendulum_model():
state = ca.SX.sym("state", 2)
theta, omega = state[0], state[1]

theta_dot = omega
omega_dot = (-b * omega - (g / L) * ca.sin(theta)) / I

next_theta = theta + theta_dot * dt
next_omega = omega + omega_dot * dt

next_state = ca.vertcat(next_theta, next_omega)
return ca.Function("pendulum_model", [state], [next_state])


# Create CasADi function
casadi_model = casadi_pendulum_model()

# Convert CasADi function to JAX
jax_model = convert(casadi_model, compile=True)


# Function to generate random inputs
def generate_random_inputs(batch_size):
return np.random.uniform(-np.pi, np.pi, (batch_size, 2))


# CasADi: Sequential Evaluation
def casadi_sequential_rollout(initial_states):
batch_size = initial_states.shape[0]
rollout_states = np.zeros((timesteps + 1, batch_size, 2))

rollout_states[0] = initial_states
for t in range(1, timesteps + 1):
rollout_states[t] = np.array([casadi_model(state).full().flatten() for state in rollout_states[t - 1]])

return rollout_states


# JAX: Vectorized Evaluation
@jax.jit
def jax_vectorized_rollout(initial_states):
def single_step(state):
return jnp.array(jax_model(state)).reshape(
2,
)

def scan_fn(carry, _):
next_state = jax.vmap(single_step)(carry)
return next_state, next_state

_, rollout_states = jax.lax.scan(scan_fn, initial_states, None, length=timesteps)
return jnp.concatenate([initial_states[None, ...], rollout_states], axis=0)


# Generate random inputs
initial_states = generate_random_inputs(batch_size)

# Warm-up call for JAX
print("Performing warm-up call for JAX...")
_ = jax_vectorized_rollout(initial_states)
print("Warm-up call completed.")
# Performance comparison
print("\nPerformance comparison:")
# Generate new random inputs
initial_states = generate_random_inputs(batch_size)

print(f"CasADi sequential rollout ({batch_size} pendulums, {timesteps} timesteps):")
casadi_time = timeit.timeit(lambda: casadi_sequential_rollout(initial_states), number=1)
print(f"Time: {casadi_time:.4f} seconds")

print(f"\nJAX vectorized rollout ({batch_size} pendulums, {timesteps} timesteps):")
jax_time = timeit.timeit(lambda: np.array(jax_vectorized_rollout(initial_states)), number=1)
print(f"Time: {jax_time:.4f} seconds")

print(f"\nSpeedup factor: {casadi_time / jax_time:.2f}x")

# Verify results
print("\nVerifying results:")
casadi_results = casadi_sequential_rollout(initial_states[:10])
jax_results = np.array(jax_vectorized_rollout(initial_states[:10]))

print("First 10 rollouts match:", np.allclose(casadi_results, jax_results, atol=1e-4))
File renamed without changes.
File renamed without changes.
Loading