Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use structural pattern matching for jcal #77

Merged
merged 1 commit into from
Feb 2, 2024
Merged

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented Feb 2, 2024

This PR focuses on optimizing the computation speed of core algorithms modules that rely on jcalc, e.g. ABA, CRBA, RNEA and soft contacts. The implementation introduces structural pattern matching, to replace extensive if-else statements. The anticipated outcome is an average speed improvement of about 80% in functions handling joint motion subspace or joint transformation.

if-else implementation:

%timeit -n 5 _ = model.physics_model.joint_transforms(q)
>>> 884 ms ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%timeit -n 5 _ = model.physics_model.motion_subspaces(q)
>>> 879 ms ± 28.2 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

structural pattern matching implementation:

%timeit -n 5 _ = model.physics_model.motion_subspaces(q)
>>> 173 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

%timeit -n 5 _ = model.physics_model.joint_transforms(q)
>>> 178 ms ± 6.12 ms per loop (mean ± std. dev. of 7 runs, 5 loops each)

📚 Documentation preview 📚: https://jaxsim--77.org.readthedocs.build//77/

@flferretti flferretti self-assigned this Feb 2, 2024
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, thanks for the improvement! It would be interesting to understand the performance of the jitted execution since we never run the vanilla code. Maybe once lowered to XLA, the speed becomes comparable.

Few more details:

  • The joint type is a static member, stored in PhysicsModel._jtype_dict. Therefore, once compiled, only one branch becomes part of the compiled code. I suspect that this similarly happen also with the match-case statement.
  • Regardless, using match-case is more readable than if-else. The jcalc function was developed before Python 3.10.
  • Since JointType is an IntEnum using auto, the values of its fields are integers starting from 0. Probably we can rewrite the jcalc function using jax.lax.switch, that would maybe allow us to make the joint type not static. I have no clue how the the speed of both compilation and runtime would be affected (either positively or negatively). To do that, we need an enum class compatible with jax, I think equinox has one.

@flferretti
Copy link
Collaborator Author

Thanks Diego! For the third point I guess we can open an issue to discuss about it

@flferretti flferretti merged commit 0cdfa79 into main Feb 2, 2024
22 checks passed
@flferretti flferretti deleted the flferretti-patch-1 branch February 2, 2024 13:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants