Skip to content

Commit

Permalink
Merge pull request #18 from clemense/clemense/add_equality_check
Browse files Browse the repository at this point in the history
Add equality operator for URDF model
  • Loading branch information
clemense authored May 5, 2022
2 parents 37429e1 + faac022 commit a62bdc2
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 14 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Changelog

## Version 0.0.45 (development)
## Version 0.0.45 (upcoming, development)
- Upgrade to trimesh version 3.11.2
- Add `__eq__` operator to URDF based on equality of individual elements (order-invariant)

## Version 0.0.44 (development)
- Parse and write `name` attribute of `material` element
Expand Down
193 changes: 180 additions & 13 deletions src/yourdfpy/urdf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from numpy.lib.npyio import load
import six
import copy
import logging
Expand All @@ -16,27 +15,85 @@
_logger = logging.getLogger(__name__)


@dataclass
def _array_eq(arr1, arr2):
if arr1 is None and arr2 is None:
return True
return (
isinstance(arr1, np.ndarray)
and isinstance(arr2, np.ndarray)
and arr1.shape == arr2.shape
and (arr1 == arr2).all()
)


@dataclass(eq=False)
class TransmissionJoint:
name: str
hardware_interfaces: List[str] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, TransmissionJoint):
return NotImplemented
return (
self.name == other.name
and all(
self_hi in other.hardware_interfaces
for self_hi in self.hardware_interfaces
)
and all(
other_hi in self.hardware_interfaces
for other_hi in other.hardware_interfaces
)
)

@dataclass

@dataclass(eq=False)
class Actuator:
name: str
mechanical_reduction: Optional[float] = None
# The follwing is only valid for ROS Indigo and prior versions
hardware_interfaces: List[str] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Actuator):
return NotImplemented
return (
self.name == other.name
and self.mechanical_reduction == other.mechanical_reduction
and all(
self_hi in other.hardware_interfaces
for self_hi in self.hardware_interfaces
)
and all(
other_hi in self.hardware_interfaces
for other_hi in other.hardware_interfaces
)
)


@dataclass
@dataclass(eq=False)
class Transmission:
name: str
type: Optional[str] = None
joints: List[TransmissionJoint] = field(default_factory=list)
actuators: List[Actuator] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Transmission):
return NotImplemented
return (
self.name == other.name
and self.type == other.type
and all(self_joint in other.joints for self_joint in self.joints)
and all(other_joint in self.joints for other_joint in other.joints)
and all(
self_actuator in other.actuators for self_actuator in self.actuators
)
and all(
other_actuator in self.actuators for other_actuator in other.actuators
)
)


@dataclass
class Calibration:
Expand Down Expand Up @@ -70,16 +127,33 @@ class Cylinder:
length: float


@dataclass
@dataclass(eq=False)
class Box:
size: np.ndarray

def __eq__(self, other):
if not isinstance(other, Box):
return NotImplemented
return _array_eq(self.size, other.size)

@dataclass

@dataclass(eq=False)
class Mesh:
filename: str
scale: Optional[Union[float, np.ndarray]] = None

def __eq__(self, other):
if not isinstance(other, Mesh):
return NotImplemented

if self.filename != other.filename:
return False

if isinstance(self.scale, float) and isinstance(other.scale, float):
return self.scale == other.scale

return _array_eq(self.scale, other.scale)


@dataclass
class Geometry:
Expand All @@ -89,10 +163,15 @@ class Geometry:
mesh: Optional[Mesh] = None


@dataclass
@dataclass(eq=False)
class Color:
rgba: np.ndarray

def __eq__(self, other):
if not isinstance(other, Color):
return NotImplemented
return _array_eq(self.rgba, other.rgba)


@dataclass
class Texture:
Expand All @@ -106,35 +185,80 @@ class Material:
texture: Optional[Texture] = None


@dataclass
@dataclass(eq=False)
class Visual:
name: Optional[str] = None
origin: Optional[np.ndarray] = None
geometry: Optional[Geometry] = None # That's not really optional according to ROS
material: Optional[Material] = None

def __eq__(self, other):
if not isinstance(other, Visual):
return NotImplemented
return (
self.name == other.name
and _array_eq(self.origin, other.origin)
and self.geometry == other.geometry
and self.material == other.material
)

@dataclass

@dataclass(eq=False)
class Collision:
name: str
origin: Optional[np.ndarray] = None
geometry: Geometry = None

def __eq__(self, other):
if not isinstance(other, Collision):
return NotImplemented
return (
self.name == other.name
and _array_eq(self.origin, other.origin)
and self.geometry == other.geometry
)

@dataclass

@dataclass(eq=False)
class Inertial:
origin: Optional[np.ndarray] = None
mass: Optional[float] = None
inertia: Optional[np.ndarray] = None

def __eq__(self, other):
if not isinstance(other, Inertial):
return NotImplemented
return (
_array_eq(self.origin, other.origin)
and self.mass == other.mass
and _array_eq(self.inertia, other.inertia)
)


@dataclass
@dataclass(eq=False)
class Link:
name: str
inertial: Optional[Inertial] = None
visuals: List[Visual] = field(default_factory=list)
collisions: List[Collision] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Link):
return NotImplemented
return (
self.name == other.name
and self.inertial == other.inertial
and all(self_visual in other.visuals for self_visual in self.visuals)
and all(other_visual in self.visuals for other_visual in other.visuals)
and all(
self_collision in other.collisions for self_collision in self.collisions
)
and all(
other_collision in self.collisions
for other_collision in other.collisions
)
)


@dataclass
class Dynamics:
Expand All @@ -150,7 +274,7 @@ class Limit:
upper: Optional[float] = None


@dataclass
@dataclass(eq=False)
class Joint:
name: str
type: str = None
Expand All @@ -164,15 +288,53 @@ class Joint:
calibration: Optional[Calibration] = None
safety_controller: Optional[SafetyController] = None

def __eq__(self, other):
if not isinstance(other, Joint):
return NotImplemented
return (
self.name == other.name
and self.type == other.type
and self.parent == other.parent
and self.child == other.child
and _array_eq(self.origin, other.origin)
and _array_eq(self.axis, other.axis)
and self.dynamics == other.dynamics
and self.limit == other.limit
and self.mimic == other.mimic
and self.calibration == other.calibration
and self.safety_controller == other.safety_controller
)

@dataclass

@dataclass(eq=False)
class Robot:
name: str
links: List[Link] = field(default_factory=list)
joints: List[Joint] = field(default_factory=list)
transmission: List[str] = field(default_factory=list)
gazebo: List[str] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Robot):
return NotImplemented
return (
self.name == other.name
and all(self_link in other.links for self_link in self.links)
and all(other_link in self.links for other_link in other.links)
and all(self_joint in other.joints for self_joint in self.joints)
and all(other_joint in self.joints for other_joint in other.joints)
and all(
self_transmission in other.transmission
for self_transmission in self.transmission
)
and all(
other_transmission in self.transmission
for other_transmission in other.transmission
)
and all(self_gazebo in other.gazebo for self_gazebo in self.gazebo)
and all(other_gazebo in self.gazebo for other_gazebo in other.gazebo)
)


class URDFError(Exception):
"""General URDF exception."""
Expand Down Expand Up @@ -1982,3 +2144,8 @@ def _write_robot(self, robot):
self._write_joint(xml_element, joint)

return xml_element

def __eq__(self, other):
if not isinstance(other, URDF):
raise NotImplemented
return self.robot == other.robot
23 changes: 23 additions & 0 deletions tests/test_urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,26 @@ def test_validate():
def test_mimic_joint():
urdf_fname = os.path.join(DIR_MODELS, "franka", "franka.urdf")
urdf_model = urdf.URDF.load(urdf_fname)

assert True


def test_equality():
urdf_fname = os.path.join(DIR_MODELS, "franka", "franka.urdf")
urdf_model_0 = urdf.URDF.load(urdf_fname)

urdf_model_1 = urdf.URDF.load(urdf_fname)

assert urdf_model_0 == urdf_model_1


def test_equality_different_link_order():
robot_0 = _create_robot()
robot_0.links.append(urdf.Link(name="link_0"))
robot_0.links.append(urdf.Link(name="link_1"))

robot_1 = _create_robot()
robot_1.links.append(urdf.Link(name="link_1"))
robot_1.links.append(urdf.Link(name="link_0"))

assert robot_0 == robot_1

0 comments on commit a62bdc2

Please sign in to comment.