Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add equality operator for URDF model #18

Merged
merged 5 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Changelog

## Version 0.0.45 (development)
## Version 0.0.45 (upcoming, development)
- Upgrade to trimesh version 3.11.2
- Add `__eq__` operator to URDF based on equality of individual elements (order-invariant)

## Version 0.0.44 (development)
- Parse and write `name` attribute of `material` element
Expand Down
193 changes: 180 additions & 13 deletions src/yourdfpy/urdf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from numpy.lib.npyio import load
import six
import copy
import logging
Expand All @@ -16,27 +15,85 @@
_logger = logging.getLogger(__name__)


@dataclass
def _array_eq(arr1, arr2):
if arr1 is None and arr2 is None:
return True
return (
isinstance(arr1, np.ndarray)
and isinstance(arr2, np.ndarray)
and arr1.shape == arr2.shape
and (arr1 == arr2).all()
)


@dataclass(eq=False)
class TransmissionJoint:
name: str
hardware_interfaces: List[str] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, TransmissionJoint):
return NotImplemented
return (
self.name == other.name
and all(
self_hi in other.hardware_interfaces
for self_hi in self.hardware_interfaces
)
and all(
other_hi in self.hardware_interfaces
for other_hi in other.hardware_interfaces
)
)

@dataclass

@dataclass(eq=False)
class Actuator:
name: str
mechanical_reduction: Optional[float] = None
# The follwing is only valid for ROS Indigo and prior versions
hardware_interfaces: List[str] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Actuator):
return NotImplemented
return (
self.name == other.name
and self.mechanical_reduction == other.mechanical_reduction
and all(
self_hi in other.hardware_interfaces
for self_hi in self.hardware_interfaces
)
and all(
other_hi in self.hardware_interfaces
for other_hi in other.hardware_interfaces
)
)


@dataclass
@dataclass(eq=False)
class Transmission:
name: str
type: Optional[str] = None
joints: List[TransmissionJoint] = field(default_factory=list)
actuators: List[Actuator] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Transmission):
return NotImplemented
return (
self.name == other.name
and self.type == other.type
and all(self_joint in other.joints for self_joint in self.joints)
and all(other_joint in self.joints for other_joint in other.joints)
and all(
self_actuator in other.actuators for self_actuator in self.actuators
)
and all(
other_actuator in self.actuators for other_actuator in other.actuators
)
)


@dataclass
class Calibration:
Expand Down Expand Up @@ -70,16 +127,33 @@ class Cylinder:
length: float


@dataclass
@dataclass(eq=False)
class Box:
size: np.ndarray

def __eq__(self, other):
if not isinstance(other, Box):
return NotImplemented
return _array_eq(self.size, other.size)

@dataclass

@dataclass(eq=False)
class Mesh:
filename: str
scale: Optional[Union[float, np.ndarray]] = None

def __eq__(self, other):
if not isinstance(other, Mesh):
return NotImplemented

if self.filename != other.filename:
return False

if isinstance(self.scale, float) and isinstance(other.scale, float):
return self.scale == other.scale

return _array_eq(self.scale, other.scale)


@dataclass
class Geometry:
Expand All @@ -89,10 +163,15 @@ class Geometry:
mesh: Optional[Mesh] = None


@dataclass
@dataclass(eq=False)
class Color:
rgba: np.ndarray

def __eq__(self, other):
if not isinstance(other, Color):
return NotImplemented
return _array_eq(self.rgba, other.rgba)


@dataclass
class Texture:
Expand All @@ -106,35 +185,80 @@ class Material:
texture: Optional[Texture] = None


@dataclass
@dataclass(eq=False)
class Visual:
name: Optional[str] = None
origin: Optional[np.ndarray] = None
geometry: Optional[Geometry] = None # That's not really optional according to ROS
material: Optional[Material] = None

def __eq__(self, other):
if not isinstance(other, Visual):
return NotImplemented
return (
self.name == other.name
and _array_eq(self.origin, other.origin)
and self.geometry == other.geometry
and self.material == other.material
)

@dataclass

@dataclass(eq=False)
class Collision:
name: str
origin: Optional[np.ndarray] = None
geometry: Geometry = None

def __eq__(self, other):
if not isinstance(other, Collision):
return NotImplemented
return (
self.name == other.name
and _array_eq(self.origin, other.origin)
and self.geometry == other.geometry
)

@dataclass

@dataclass(eq=False)
class Inertial:
origin: Optional[np.ndarray] = None
mass: Optional[float] = None
inertia: Optional[np.ndarray] = None

def __eq__(self, other):
if not isinstance(other, Inertial):
return NotImplemented
return (
_array_eq(self.origin, other.origin)
and self.mass == other.mass
and _array_eq(self.inertia, other.inertia)
)


@dataclass
@dataclass(eq=False)
class Link:
name: str
inertial: Optional[Inertial] = None
visuals: List[Visual] = field(default_factory=list)
collisions: List[Collision] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Link):
return NotImplemented
return (
self.name == other.name
and self.inertial == other.inertial
and all(self_visual in other.visuals for self_visual in self.visuals)
and all(other_visual in self.visuals for other_visual in other.visuals)
and all(
self_collision in other.collisions for self_collision in self.collisions
)
and all(
other_collision in self.collisions
for other_collision in other.collisions
)
)


@dataclass
class Dynamics:
Expand All @@ -150,7 +274,7 @@ class Limit:
upper: Optional[float] = None


@dataclass
@dataclass(eq=False)
class Joint:
name: str
type: str = None
Expand All @@ -164,15 +288,53 @@ class Joint:
calibration: Optional[Calibration] = None
safety_controller: Optional[SafetyController] = None

def __eq__(self, other):
if not isinstance(other, Joint):
return NotImplemented
return (
self.name == other.name
and self.type == other.type
and self.parent == other.parent
and self.child == other.child
and _array_eq(self.origin, other.origin)
and _array_eq(self.axis, other.axis)
and self.dynamics == other.dynamics
and self.limit == other.limit
and self.mimic == other.mimic
and self.calibration == other.calibration
and self.safety_controller == other.safety_controller
)

@dataclass

@dataclass(eq=False)
class Robot:
name: str
links: List[Link] = field(default_factory=list)
joints: List[Joint] = field(default_factory=list)
transmission: List[str] = field(default_factory=list)
gazebo: List[str] = field(default_factory=list)

def __eq__(self, other):
if not isinstance(other, Robot):
return NotImplemented
return (
self.name == other.name
and all(self_link in other.links for self_link in self.links)
and all(other_link in self.links for other_link in other.links)
and all(self_joint in other.joints for self_joint in self.joints)
and all(other_joint in self.joints for other_joint in other.joints)
and all(
self_transmission in other.transmission
for self_transmission in self.transmission
)
and all(
other_transmission in self.transmission
for other_transmission in other.transmission
)
and all(self_gazebo in other.gazebo for self_gazebo in self.gazebo)
and all(other_gazebo in self.gazebo for other_gazebo in other.gazebo)
)


class URDFError(Exception):
"""General URDF exception."""
Expand Down Expand Up @@ -1982,3 +2144,8 @@ def _write_robot(self, robot):
self._write_joint(xml_element, joint)

return xml_element

def __eq__(self, other):
if not isinstance(other, URDF):
raise NotImplemented
return self.robot == other.robot
23 changes: 23 additions & 0 deletions tests/test_urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,26 @@ def test_validate():
def test_mimic_joint():
urdf_fname = os.path.join(DIR_MODELS, "franka", "franka.urdf")
urdf_model = urdf.URDF.load(urdf_fname)

assert True


def test_equality():
urdf_fname = os.path.join(DIR_MODELS, "franka", "franka.urdf")
urdf_model_0 = urdf.URDF.load(urdf_fname)

urdf_model_1 = urdf.URDF.load(urdf_fname)

assert urdf_model_0 == urdf_model_1


def test_equality_different_link_order():
robot_0 = _create_robot()
robot_0.links.append(urdf.Link(name="link_0"))
robot_0.links.append(urdf.Link(name="link_1"))

robot_1 = _create_robot()
robot_1.links.append(urdf.Link(name="link_1"))
robot_1.links.append(urdf.Link(name="link_0"))

assert robot_0 == robot_1