diff --git a/src/jaxsim/parsers/descriptions/collision.py b/src/jaxsim/parsers/descriptions/collision.py
index e4703332a..1075a5964 100644
--- a/src/jaxsim/parsers/descriptions/collision.py
+++ b/src/jaxsim/parsers/descriptions/collision.py
@@ -158,12 +158,11 @@ def __eq__(self, other: BoxCollision) -> bool:
 
 @dataclasses.dataclass
 class MeshCollision(CollisionShape):
-
-    center: npt.NDArray[np.float64]
+    center: npt.NDArray
 
     def __eq__(self, other: Any) -> bool:
         if not isinstance(other, MeshCollision):
             return False
         return len(self.collidable_points) == len(
             other.collidable_points
-        ) and super().__eq__(other)
+        ) and super().__eq__(other) and (self.center == other.center).all()
diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py
index 0b110ec91..5de8f414c 100644
--- a/src/jaxsim/parsers/rod/utils.py
+++ b/src/jaxsim/parsers/rod/utils.py
@@ -250,10 +250,10 @@ def create_mesh_collision(
     points = mesh.vertices
     H = (
         collision.pose.transform() if collision.pose is not None else np.eye(4)
-    )  # pose of the collision object
+    )
     center_of_collision_wrt_link = (H @ np.hstack([0, 0, 0, 1.0]))[
         0:-1
-    ]  # @ = matrix multiplication, hstack = stack arrays in sequence horizontally => center of the collision object
+    ]
     mesh_points_wrt_link = (
         H @ np.hstack([points, np.vstack([1.0] * points.shape[0])]).T
     )[0:3, :]