-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Avoid to use as Static attributes classes that do not have a __eq__ method that returns a scalar bool #105
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot @traversaro! This should solve the longstanding issue we had.
I believe that if GroundContact.body
does not need to be modified, we could prefer to make it a tuple instead of a list, what do you think @diegoferigo?
It LGTM anyway 🚀
I'm thinking to send a PR to |
…s and could be used as Static jax_dataclasses attributes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem reported in #103 (and I guess #84, even if I did not checked directly that one) is specifically caused by some
Static
attributes of jaxsim classes not having__eq__
methods that return objects that can be casted to bools.
Wow thanks a lot for this investigation @traversaro, that's great! As mentioned in #103 (comment), I was suspecting something related to __hash__
or __eq__
was involved, but as first attempt I blamed the implementation of the pytree's class, not the one of its static attributes. Good to know! Super happy if we can finally have this solved. It simplifies prototyping in a Jupyter notebook (generally, in interactive iPython).
I'm thinking to send a PR to
jax_dataclasses
to check this before JIT is applied
I was about to suggest it, you anticipated me :) I would start investigating if we can raise an error in _flatten
if not isinstance(treedef, collections.abc.Hashable)
(see collections.abc.Hashable
).
After sleeping on this, I guess that using Python introspection also the |
What about comparing def __eq__(self, other):
if not isinstance(other, type(self)):
return False
return self.__dict__ == other.__dict__ |
Don't you get the same problem of |
Note that the problem here is not that the fact that the structure is hashable, but that it can be compare with |
Yes sorry for the confusion, I missed a step. I guess that if a class is hashable, its You can refer to Python data model for further details on the interaction between |
I don't like too much the idea of |
Ack, probably then we can merge the custom |
Definitely, feel free to merge this PR if it's ready to be merged. |
Regarding this, I made some additional tests and the problem doesn't seem to be related to Static attributes only. In fact, trying with: Test scriptimport jax.numpy as jnp
import jaxsim.api as js
import rod.builder.primitives
import rod.urdf.exporter
rod_model = (
rod.builder.primitives.BoxBuilder(x=0.3, y=0.2, z=0.1, mass=1.0, name="box")
.build_model()
.add_link()
.add_inertial()
.add_visual()
.add_collision()
.build()
)
# Export the URDF string.
urdf_string = rod.urdf.exporter.UrdfExporter.sdf_to_urdf_string(
sdf=rod_model, pretty=True
)
model1 = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_string,
gravity=jnp.array([0, 0, -10]),
is_urdf=True,
)
model2 = js.model.JaxSimModel.build_from_model_description(
model_description=urdf_string,
gravity=jnp.array([0, 0, -10]),
is_urdf=True,
)
data1 = js.data.JaxSimModelData.build(model=model1)
data2 = js.data.JaxSimModelData.build(model=model2)
data1 == data2 I still obtain: ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all() What solved the issue was to implement a def __eq__(self, other):
if self.__class__ is not other.__class__:
return False
return all(
(key in other.__dict__)
and (
np.array_equal(self.__dict__[key], other.__dict__[key])
if isinstance(self.__dict__[key], np.ndarray)
or isinstance(self.__dict__[key], jnp.ndarray)
else self.__dict__[key] == other.__dict__[key]
)
for key in self.__dict__
) and then inside def __eq__(self, other):
return super().__eq__(other) which actually I thought it wasn't necessary as: >>> js.data.JaxSimModelData.__mro__
(<class 'jaxsim.api.data.JaxSimModelData'>, <class 'jaxsim.api.common.ModelDataWithVelocityRepresentation'>, <class 'jaxsim.utils.jaxsim_dataclass.JaxsimDataclass'>, <class 'abc.ABC'>, <class 'object'>) Edit: the |
While you are investigating, if not necessary, I'd suggest to use |
Note that some dataclasses/jax_dataclasses decorator take in input a
Yes, but non-Static attributes do not the constraint that |
Fix #103 .
Problem description
The problem reported in #103 (and I guess #84, even if I did not checked directly that one) is specifically caused by some
Static
attributes of jaxsim classes not having__eq__
methods that return objects that can be casted to bools. In particular, according to https://jax.readthedocs.io/en/latest/pytrees.html#pytrees any value that is returned as part ofaux_data
second return value of thetree_flatten
method of a class passed tojax.tree_util.register_pytree_node_class
must:As it is made even more explicit in jax-ml/jax#19547 (comment) :
The problem was triggered only in the second run of a
jit
function with the same instance, as that was the only time in which the program actually compared the value of static attributes.Solution proposed in this PR
This condition is not respected in jaxsim before this PR. In this PR, I fixed the jaxsim classes to fix the minimal example provided in #103 . This is done in two ways:
CollidablePoint
,BoxCollision
,SphereCollision
,LinkDescription
andRootPose
, as these classes containednp.array
orjnp.array
attributes, I defined custom__eq__
methods to insert appropriatly the.all()
method when comparing for equality arrays (commit: 1c34033)GroundContact
the situation was a bit more complex, as in that case the problematic attribute was thebody
that was anp.array
and was itself marked asStatic
attribute , and I could not re-define its__eq__
method to return a scalar bool. So, just for this case I decided to change thebody
attribute to be alist
instead ofnd.array
.Realistically, other classes are affected by the same problem, and I did not noticed them as the test reported in #103 is quite minimal (for example, no joint was involved). However, I think that for fixing those it is just a matter of having more complete tests (such as the one added in #102) and just iterating on those tests until no
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
appears anymore.Requires #104 to be merged before.