Skip to content

Commit

Permalink
feat: one-line translator (#6)
Browse files Browse the repository at this point in the history
* feat: one-line translator

* fix: remove legacy translator

* fix: version bump
  • Loading branch information
mattephi authored Sep 9, 2024
1 parent 10ea41d commit d3a4dc0
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 367 deletions.
4 changes: 2 additions & 2 deletions examples/04_mjx.py → examples/03_mjx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pinocchio as pin
import pinocchio.casadi as cpin
from robot_descriptions.iiwa14_mj_description import MJCF_PATH
from jaxadi import convert, translate
from jaxadi import convert
import mujoco
import mujoco.mjx as mjx

Expand Down Expand Up @@ -137,7 +137,7 @@ def mjx_fk(joint_pos):
print(f"\nSpeedup factors:")
print(f"Generated JAX vs Casadi: {casadi_time / jax_time:.2f}x")
print(f"MJX vs Casadi: {casadi_time / mjx_time:.2f}x")
print(f"MJX vs Generated JAX: {jax_time / mjx_time:.2f}x")
print(f"Generated JAX vs MJX: {mjx_time / jax_time:.2f}x")

# Verify results
print("\nVerifying performance test results:")
Expand Down
109 changes: 0 additions & 109 deletions examples/03_pinocchio.py

This file was deleted.

73 changes: 0 additions & 73 deletions examples/_pinocchio_old.py

This file was deleted.

2 changes: 1 addition & 1 deletion jaxadi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._compile import lower
from ._convert import convert
from ._translate import translate, legacy_translate
from ._translate import translate
from ._declare import declare
50 changes: 0 additions & 50 deletions jaxadi/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,56 +48,6 @@
OP_TWICE,
)

OP_JAX_DICT = {
OP_ASSIGN: "\n work = work.at[{0}].set(work[{1}])",
OP_ADD: "\n work = work.at[{0}].set(work[{1}] + work[{2}])",
OP_SUB: "\n work = work.at[{0}].set(work[{1}] - work[{2}])",
OP_MUL: "\n work = work.at[{0}].set(work[{1}] * work[{2}])",
OP_DIV: "\n work = work.at[{0}].set(work[{1}] / work[{2}])",
OP_NEG: "\n work = work.at[{0}].set(-work[{1}])",
OP_EXP: "\n work = work.at[{0}].set(jnp.exp(work[{1}]))",
OP_LOG: "\n work = work.at[{0}].set(jnp.log(work[{1}]))",
OP_POW: "\n work = work.at[{0}].set(jnp.power(work[{1}], work[{2}]))",
OP_CONSTPOW: "\n work = work.at[{0}].set(jnp.power(work[{1}], work[{2}]))",
OP_SQRT: "\n work = work.at[{0}].set(jnp.sqrt(work[{1}]))",
OP_SQ: "\n work = work.at[{0}].set(work[{1}] * work[{2}])",
OP_TWICE: "\n work = work.at[{0}].set(2 * work[{1}])",
OP_SIN: "\n work = work.at[{0}].set(jnp.sin(work[{1}]))",
OP_COS: "\n work = work.at[{0}].set(jnp.cos(work[{1}]))",
OP_TAN: "\n work = work.at[{0}].set(jnp.tan(work[{1}]))",
OP_ASIN: "\n work = work.at[{0}].set(jnp.arcsin(work[{1}]))",
OP_ACOS: "\n work = work.at[{0}].set(jnp.arccos(work[{1}]))",
OP_ATAN: "\n work = work.at[{0}].set(jnp.arctan(work[{1}]))",
OP_LT: "\n work = work.at[{0}].set(work[{1}] < work[{2}])",
OP_LE: "\n work = work.at[{0}].set(work[{1}] <= work[{2}])",
OP_EQ: "\n work = work.at[{0}].set(work[{1}] == work[{2}])",
OP_NE: "\n work = work.at[{0}].set(work[{1}] != work[{2}])",
OP_NOT: "\n work = work.at[{0}].set(jnp.logical_not(work[{1}]))",
OP_AND: "\n work = work.at[{0}].set(jnp.logical_and(work[{1}], work[{2}]))",
OP_OR: "\n work = work.at[{0}].set(jnp.logical_or(work[{1}], work[{2}]))",
OP_FLOOR: "\n work = work.at[{0}].set(jnp.floor(work[{1}]))",
OP_CEIL: "\n work = work.at[{0}].set(jnp.ceil(work[{1}]))",
OP_FMOD: "\n work = work.at[{0}].set(jnp.fmod(work[{1}], work[{2}]))",
OP_FABS: "\n work = work.at[{0}].set(jnp.abs(work[{1}]))",
OP_SIGN: "\n work = work.at[{0}].set(jnp.sign(work[{1}]))",
OP_COPYSIGN: "\n work = work.at[{0}].set(jnp.copysign(work[{1}], work[{2}]))",
OP_IF_ELSE_ZERO: "\n work = work.at[{0}].set(jnp.where(work[{1}] == 0, 0, work[{2}]))",
OP_ERF: "\n work = work.at[{0}].set(jax.scipy.special.erf(work[{1}]))",
OP_FMIN: "\n work = work.at[{0}].set(jnp.minimum(work[{1}], work[{2}]))",
OP_FMAX: "\n work = work.at[{0}].set(jnp.maximum(work[{1}], work[{2}]))",
OP_INV: "\n work = work.at[{0}].set(1.0 / work[{1}])",
OP_SINH: "\n work = work.at[{0}].set(jnp.sinh(work[{1}]))",
OP_COSH: "\n work = work.at[{0}].set(jnp.cosh(work[{1}]))",
OP_TANH: "\n work = work.at[{0}].set(jnp.tanh(work[{1}]))",
OP_ASINH: "\n work = work.at[{0}].set(jnp.arcsinh(work[{1}]))",
OP_ACOSH: "\n work = work.at[{0}].set(jnp.arccosh(work[{1}]))",
OP_ATANH: "\n work = work.at[{0}].set(jnp.arctanh(work[{1}]))",
OP_ATAN2: "\n work = work.at[{0}].set(jnp.arctan2(work[{1}], work[{2}]))",
OP_CONST: "\n work = work.at[{0}].set({1:.16f})",
OP_INPUT: "\n work = work.at[{0}].set(inputs[{1}][{2}, {3}])",
OP_OUTPUT: "\n outputs[{0}] = outputs[{0}].at[{1}, {2}].set(work[{3}][0])",
}

OP_JAX_VALUE_DICT = {
OP_ASSIGN: "work[{0}]",
OP_ADD: "work[{0}] + work[{1}]",
Expand Down
Loading

0 comments on commit d3a4dc0

Please sign in to comment.