diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py index e4703332a..1075a5964 100644 --- a/src/jaxsim/parsers/descriptions/collision.py +++ b/src/jaxsim/parsers/descriptions/collision.py @@ -158,12 +158,11 @@ def __eq__(self, other: BoxCollision) -> bool: @dataclasses.dataclass class MeshCollision(CollisionShape): - - center: npt.NDArray[np.float64] + center: npt.NDArray def __eq__(self, other: Any) -> bool: if not isinstance(other, MeshCollision): return False return len(self.collidable_points) == len( other.collidable_points - ) and super().__eq__(other) + ) and super().__eq__(other) and (self.center == other.center).all() diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 0b110ec91..5de8f414c 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -250,10 +250,10 @@ def create_mesh_collision( points = mesh.vertices H = ( collision.pose.transform() if collision.pose is not None else np.eye(4) - ) # pose of the collision object + ) center_of_collision_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[ 0:-1 - ] # @ = matrix multiplication, hstack = stack arrays in sequence horizontally => center of the collision object + ] mesh_points_wrt_link = ( H @ np.hstack([points, np.vstack([1.0] * points.shape[0])]).T )[0:3, :]