diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..b6a0f46 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,36 @@ +name: Core Tests + +on: + push: + branches: [ main, init_dev ] + pull_request: + +permissions: + contents: read + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [ "3.7", "3.8", "3.9", "3.10" ] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install --upgrade -e .[dev] + - name: Run Black + run: | + black dex-retargeting/ tests/ --check + - name: Run Pyright + run: | + pyright + - name: Test with pytest + run: | + pytest \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e8ac431 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +build +.vscode +.pyc +*.blend +cmake-build-debug +/cmake-build-debug/ +/.ccls-cache/ +/.dir-locals.el +__pycache__ +.ini +*.convex.* +imgui.ini +.mypy_cache +.DS_Store +/.idea +/log diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..a9822ef --- /dev/null +++ b/.gitmodules @@ -0,0 +1,4 @@ +[submodule "assets"] + path = assets + url = https://github.com/dexsuite/dex-urdf.git + branch = init_dev diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..15904fb --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2023 Yuzhe Qin + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..e12dc25 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +Dex Retargeting +--- +

+ + + + License + +

\ No newline at end of file diff --git a/assets b/assets new file mode 160000 index 0000000..32ce565 --- /dev/null +++ b/assets @@ -0,0 +1 @@ +Subproject commit 32ce565393e9e44948fb4f86c55964ef65c2d33b diff --git a/dex_retargeting/__init__.py b/dex_retargeting/__init__.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/dex_retargeting/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/dex_retargeting/configs/offline/allegro_hand_right.yml b/dex_retargeting/configs/offline/allegro_hand_right.yml new file mode 100644 index 0000000..fcada15 --- /dev/null +++ b/dex_retargeting/configs/offline/allegro_hand_right.yml @@ -0,0 +1,12 @@ +retargeting: + type: position + urdf_path: allegro_hand/allegro_hand_right.urdf + use_camera_frame_retargeting: False + + target_joint_names: null + target_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ] + + target_link_human_indices: [ 4, 8, 12, 16 ] + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 diff --git a/dex_retargeting/configs/teleop/allegro_hand_left.yml b/dex_retargeting/configs/teleop/allegro_hand_left.yml new file mode 100644 index 0000000..ff6bf91 --- /dev/null +++ b/dex_retargeting/configs/teleop/allegro_hand_left.yml @@ -0,0 +1,17 @@ +retargeting: + type: vector + urdf_path: allegro_hand/allegro_hand_left.urdf + use_camera_frame_retargeting: False + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ] + target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ] + scaling_factor: 1.6 + + # Source refers to the retargeting input, which usually corresponds to the human hand + # The joint indices of human hand joint which corresponds to each link in the target_link_names + target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ] + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 diff --git a/dex_retargeting/configs/teleop/allegro_hand_right.yml b/dex_retargeting/configs/teleop/allegro_hand_right.yml new file mode 100644 index 0000000..11296e1 --- /dev/null +++ b/dex_retargeting/configs/teleop/allegro_hand_right.yml @@ -0,0 +1,17 @@ +retargeting: + type: vector + urdf_path: allegro_hand/allegro_hand_right.urdf + use_camera_frame_retargeting: False + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ] + target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ] + scaling_factor: 1.6 + + # Source refers to the retargeting input, which usually corresponds to the human hand + # The joint indices of human hand joint which corresponds to each link in the target_link_names + target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ] + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 diff --git a/dex_retargeting/configs/teleop/allegro_hand_right_dexpilot.yml b/dex_retargeting/configs/teleop/allegro_hand_right_dexpilot.yml new file mode 100644 index 0000000..dd5883d --- /dev/null +++ b/dex_retargeting/configs/teleop/allegro_hand_right_dexpilot.yml @@ -0,0 +1,13 @@ +retargeting: + type: DexPilot + urdf_path: allegro_hand/allegro_hand_right.urdf + use_camera_frame_retargeting: False + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + wrist_link_name: "wrist" + finger_tip_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ] + scaling_factor: 1.6 + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 diff --git a/dex_retargeting/configs/teleop/schunk_svh_hand_right.yml b/dex_retargeting/configs/teleop/schunk_svh_hand_right.yml new file mode 100644 index 0000000..bf622e3 --- /dev/null +++ b/dex_retargeting/configs/teleop/schunk_svh_hand_right.yml @@ -0,0 +1,17 @@ +retargeting: + type: vector + urdf_path: schunk_hand/schunk_svh_hand_right.urdf + use_camera_frame_retargeting: False + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + target_origin_link_names: [ "right_hand_base_link","right_hand_base_link", "right_hand_base_link", "right_hand_base_link", "right_hand_base_link", ] + target_task_link_names: [ "right_hand_c", "right_hand_t", "right_hand_s", "right_hand_r", "right_hand_q" ] + scaling_factor: 1.1 + + # Source refers to the retargeting input, which usually corresponds to the human hand + # The joint indices of human hand joint which corresponds to each link in the target_link_names + target_link_human_indices: [ [ 0, 0, 0, 0, 0 ], [ 4, 8, 12, 16, 20, ] ] + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 diff --git a/dex_retargeting/configs/teleop/shadow_hand_right.yml b/dex_retargeting/configs/teleop/shadow_hand_right.yml new file mode 100644 index 0000000..b1835c1 --- /dev/null +++ b/dex_retargeting/configs/teleop/shadow_hand_right.yml @@ -0,0 +1,17 @@ +retargeting: + type: vector + urdf_path: shadow_hand/shadow_hand_right.urdf + use_camera_frame_retargeting: False + + # Target refers to the retargeting target, which is the robot hand + target_joint_names: null + target_origin_link_names: [ "palm", "palm", "palm", "palm", "palm", "palm", "palm", "palm", "palm", "palm" ] + target_task_link_names: [ "thtip", "fftip", "mftip", "rftip", "lftip", "thmiddle", "ffmiddle", "mfmiddle", "rfmiddle", "lfmiddle" ] + scaling_factor: 1.2 + + # Source refers to the retargeting input, which usually corresponds to the human hand + # The joint indices of human hand joint which corresponds to each link in the target_link_names + target_link_human_indices: [ [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], [ 4, 8, 12, 16, 20, 2, 6, 10, 14, 18 ] ] + + # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency + low_pass_alpha: 0.2 diff --git a/dex_retargeting/optimizer.py b/dex_retargeting/optimizer.py new file mode 100644 index 0000000..6dda050 --- /dev/null +++ b/dex_retargeting/optimizer.py @@ -0,0 +1,475 @@ +from abc import abstractmethod +from typing import List + +import nlopt +import numpy as np +import sapien.core as sapien +import torch + + +class Optimizer: + retargeting_type = "BASE" + + def __init__( + self, robot: sapien.Articulation, target_joint_names: List[str], target_link_human_indices: np.ndarray + ): + self.robot = robot + self.robot_dof = robot.dof + self.model = robot.create_pinocchio_model() + + joint_names = [joint.get_name() for joint in robot.get_active_joints()] + target_joint_index = [] + for target_joint_name in target_joint_names: + if target_joint_name not in joint_names: + raise ValueError(f"Joint {target_joint_name} given does not appear to be in robot XML.") + target_joint_index.append(joint_names.index(target_joint_name)) + self.target_joint_names = target_joint_names + self.target_joint_indices = np.array(target_joint_index) + self.fixed_joint_indices = np.array([i for i in range(robot.dof) if i not in target_joint_index], dtype=int) + self.opt = nlopt.opt(nlopt.LD_SLSQP, len(target_joint_index)) + self.dof = len(target_joint_index) + + # Target + self.target_link_human_indices = target_link_human_indices + + def set_joint_limit(self, joint_limits: np.ndarray): + if joint_limits.shape != (self.dof, 2): + raise ValueError(f"Expect joint limits have shape: {(self.dof, 2)}, but get {joint_limits.shape}") + self.opt.set_lower_bounds(joint_limits[:, 0].tolist()) + self.opt.set_upper_bounds(joint_limits[:, 1].tolist()) + + def get_last_result(self): + return self.opt.last_optimize_result() + + def get_link_names(self): + return [link.get_name() for link in self.robot.get_links()] + + def get_link_indices(self, target_link_names): + target_link_index = [] + for target_link_name in target_link_names: + if target_link_name not in self.get_link_names(): + raise ValueError(f"Body {target_link_name} given does not appear to be in robot XML.") + target_link_index.append(self.get_link_names().index(target_link_name)) + return target_link_index + + @abstractmethod + def retarget(self, ref_value, fixed_qpos, last_qpos=None): + pass + + def optimize(self, objective_fn, last_qpos): + self.opt.set_min_objective(objective_fn) + try: + qpos = self.opt.optimize(last_qpos) + except RuntimeError as e: + print(e) + return np.array(last_qpos) + return qpos + + +class PositionOptimizer(Optimizer): + retargeting_type = "position" + + def __init__( + self, + robot: sapien.Articulation, + target_joint_names: List[str], + target_link_names: List[str], + target_link_human_indices: np.ndarray, + huber_delta=0.02, + norm_delta=4e-3, + ): + super().__init__(robot, target_joint_names, target_link_human_indices) + self.body_names = target_link_names + self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta) + self.norm_delta = norm_delta + + # Sanity check and cache link indices + self.target_link_indices = self.get_link_indices(target_link_names) + + # Use local jacobian if target link name <= 2, otherwise first cache all jacobian and then get all + # This is only for the speed but will not affect the performance + if len(target_link_names) <= 40: + self.use_sparse_jacobian = True + else: + self.use_sparse_jacobian = False + self.opt.set_ftol_abs(1e-5) + + def _get_objective_function(self, target_pos: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray): + qpos = np.zeros(self.robot_dof) + qpos[self.fixed_joint_indices] = fixed_qpos + torch_target_pos = torch.as_tensor(target_pos) + torch_target_pos.requires_grad_(False) + + def objective(x: np.ndarray, grad: np.ndarray) -> float: + qpos[self.target_joint_indices] = x + self.model.compute_forward_kinematics(qpos) + target_link_poses = [self.model.get_link_pose(index) for index in self.target_link_indices] + body_pos = np.array([pose.p for pose in target_link_poses]) + + # Torch computation for accurate loss and grad + torch_body_pos = torch.as_tensor(body_pos) + torch_body_pos.requires_grad_() + + # Loss term for kinematics retargeting based on 3D position error + huber_distance = self.huber_loss(torch_body_pos, torch_target_pos) + # huber_distance = torch.norm(torch_body_pos - torch_target_pos, dim=1).mean() + result = huber_distance.cpu().detach().item() + + if grad.size > 0: + if self.use_sparse_jacobian: + jacobians = [] + for i, index in enumerate(self.target_link_indices): + link_spatial_jacobian = self.model.compute_single_link_local_jacobian(qpos, index)[ + :3, self.target_joint_indices + ] + link_rot = self.model.get_link_pose(index).to_transformation_matrix()[:3, :3] + link_kinematics_jacobian = link_rot @ link_spatial_jacobian + jacobians.append(link_kinematics_jacobian) + jacobians = np.stack(jacobians, axis=0) + else: + self.model.compute_full_jacobian(qpos) + jacobians = [ + self.model.get_link_jacobian(index, local=True)[:3, self.target_joint_indices] + for index in self.target_link_indices + ] + + huber_distance.backward() + grad_pos = torch_body_pos.grad.cpu().numpy()[:, None, :] + grad_qpos = np.matmul(grad_pos, np.array(jacobians)) + grad_qpos = grad_qpos.mean(1).sum(0) + + grad_qpos += 2 * self.norm_delta * (x - last_qpos) + + grad[:] = grad_qpos[:] + + return result + + return objective + + def retarget(self, ref_value, fixed_qpos, last_qpos=None): + if len(fixed_qpos) != len(self.fixed_joint_indices): + raise ValueError( + f"Optimizer has {len(self.fixed_joint_indices)} joints but non_target_qpos {fixed_qpos} is given" + ) + if last_qpos is None: + last_qpos = np.zeros(self.dof) + if isinstance(last_qpos, np.ndarray): + last_qpos = last_qpos.astype(np.float32) + last_qpos = list(last_qpos) + objective_fn = self._get_objective_function(ref_value, fixed_qpos, np.array(last_qpos).astype(np.float32)) + return self.optimize(objective_fn, last_qpos) + + +class VectorOptimizer(Optimizer): + retargeting_type = "VECTOR" + + def __init__( + self, + robot: sapien.Articulation, + target_joint_names: List[str], + target_origin_link_names: List[str], + target_task_link_names: List[str], + target_link_human_indices: np.ndarray, + huber_delta=0.02, + norm_delta=4e-3, + scaling=1.0, + ): + super().__init__(robot, target_joint_names, target_link_human_indices) + self.origin_link_names = target_origin_link_names + self.task_link_names = target_task_link_names + self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta) + self.norm_delta = norm_delta + self.scaling = scaling + + # Computation cache for better performance + # For one link used in multiple vectors, e.g. hand palm, we do not want to compute it multiple times + self.computed_link_names = list(set(target_origin_link_names).union(set(target_task_link_names))) + self.origin_link_indices = torch.tensor( + [self.computed_link_names.index(name) for name in target_origin_link_names] + ) + self.task_link_indices = torch.tensor([self.computed_link_names.index(name) for name in target_task_link_names]) + + # Sanity check and cache link indices + self.robot_link_indices = self.get_link_indices(self.computed_link_names) + + # Use local jacobian if target link name <= 2, otherwise first cache all jacobian and then get all + # This is only for the speed but will not affect the performance + if len(self.computed_link_names) <= 40: + self.use_sparse_jacobian = True + else: + self.use_sparse_jacobian = False + self.opt.set_ftol_abs(1e-6) + + def _get_objective_function(self, target_vector: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray): + qpos = np.zeros(self.robot_dof) + qpos[self.fixed_joint_indices] = fixed_qpos + torch_target_vec = torch.as_tensor(target_vector) * self.scaling + torch_target_vec.requires_grad_(False) + + def objective(x: np.ndarray, grad: np.ndarray) -> float: + qpos[self.target_joint_indices] = x + self.model.compute_forward_kinematics(qpos) + target_link_poses = [self.model.get_link_pose(index) for index in self.robot_link_indices] + body_pos = np.array([pose.p for pose in target_link_poses]) + + # Torch computation for accurate loss and grad + torch_body_pos = torch.as_tensor(body_pos) + torch_body_pos.requires_grad_() + + # Index link for computation + origin_link_pos = torch_body_pos[self.origin_link_indices, :] + task_link_pos = torch_body_pos[self.task_link_indices, :] + robot_vec = task_link_pos - origin_link_pos + + # Loss term for kinematics retargeting based on 3D position error + huber_distance = self.huber_loss(robot_vec, torch_target_vec) + result = huber_distance.cpu().detach().item() + + if grad.size > 0: + if self.use_sparse_jacobian: + jacobians = [] + for i, index in enumerate(self.robot_link_indices): + link_spatial_jacobian = self.model.compute_single_link_local_jacobian(qpos, index)[ + :3, self.target_joint_indices + ] + link_rot = self.model.get_link_pose(index).to_transformation_matrix()[:3, :3] + link_kinematics_jacobian = link_rot @ link_spatial_jacobian + jacobians.append(link_kinematics_jacobian) + jacobians = np.stack(jacobians, axis=0) + else: + self.model.compute_full_jacobian(qpos) + jacobians = [ + self.model.get_link_jacobian(index, local=True)[:3, self.target_joint_indices] + for index in self.robot_link_indices + ] + + huber_distance.backward() + grad_pos = torch_body_pos.grad.cpu().numpy()[:, None, :] + grad_qpos = np.matmul(grad_pos, np.array(jacobians)) + grad_qpos = grad_qpos.mean(1).sum(0) + + grad_qpos += 2 * self.norm_delta * (x - last_qpos) + + grad[:] = grad_qpos[:] + + return result + + return objective + + def retarget(self, ref_value, fixed_qpos, last_qpos=None): + if len(fixed_qpos) != len(self.fixed_joint_indices): + raise ValueError( + f"Optimizer has {len(self.fixed_joint_indices)} joints but non_target_qpos {fixed_qpos} is given" + ) + if last_qpos is None: + last_qpos = np.zeros(self.dof) + last_qpos = list(last_qpos) + objective_fn = self._get_objective_function(ref_value, fixed_qpos, np.array(last_qpos).astype(np.float32)) + return self.optimize(objective_fn, last_qpos) + + +class DexPilotAllegroOptimizer(Optimizer): + retargeting_type = "DEXPILOT" + + def __init__( + self, + robot: sapien.Articulation, + target_joint_names: List[str], + finger_tip_link_names: List[str], + wrist_link_name: str, + # DexPilot parameters + gamma=2.5e-3, + project_dist=0.03, + escape_dist=0.05, + eta1=1e-4, + eta2=3e-2, + scaling=1.0, + ): + is_four_finger = len(finger_tip_link_names) == 4 + link_names = [wrist_link_name] + finger_tip_link_names + if is_four_finger: + origin_link_index = [2, 3, 4, 3, 4, 4, 0, 0, 0, 0] + task_link_index = [1, 1, 1, 2, 2, 3, 1, 2, 3, 4] + target_origin_link_names = [link_names[index] for index in origin_link_index] + target_task_link_names = [link_names[index] for index in task_link_index] + target_link_human_indices = np.array( + [[8, 12, 16, 12, 16, 16, 0, 0, 0, 0], [4, 4, 4, 8, 8, 12, 4, 8, 12, 16]] + ) + else: + raise NotImplementedError + + super().__init__(robot, target_joint_names, target_link_human_indices) + self.origin_link_names = target_origin_link_names + self.task_link_names = target_task_link_names + self.scaling = scaling + # self.norm_delta = norm_delta + # self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta) + + # DexPilot parameters + self.gamma = gamma + self.project_dist = project_dist + self.escape_dist = escape_dist + self.eta1 = eta1 + self.eta2 = eta2 + + # Computation cache for better performance + # For one link used in multiple vectors, e.g. hand palm, we do not want to compute it multiple times + self.computed_link_names = list(set(target_origin_link_names).union(set(target_task_link_names))) + self.origin_link_indices = torch.tensor( + [self.computed_link_names.index(name) for name in target_origin_link_names] + ) + self.task_link_indices = torch.tensor([self.computed_link_names.index(name) for name in target_task_link_names]) + + # Sanity check and cache link indices + self.robot_link_indices = self.get_link_indices(self.computed_link_names) + + # Use local jacobian if target link name <= 2, otherwise first cache all jacobian and then get all + # This is only for the speed but will not affect the performance + if len(self.computed_link_names) <= 40: + self.use_sparse_jacobian = True + else: + self.use_sparse_jacobian = False + self.opt.set_ftol_abs(1e-6) + + # DexPilot cache + self.projected = np.zeros(6, dtype=bool) + self.s2_project_index_origin = np.array([1, 2, 2], dtype=int) + self.s2_project_index_task = np.array([0, 0, 1], dtype=int) + self.projected_dist = np.array([eta1] * 3 + [eta2] * 3) + + def _get_objective_function(self, target_vector: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray): + target_vector = target_vector.astype(np.float32) + qpos = np.zeros(self.robot_dof) + qpos[self.fixed_joint_indices] = fixed_qpos + + # Update projection indicator + target_vec_dist = np.linalg.norm(target_vector[:6], axis=1) + self.projected[:3][target_vec_dist[0:3] < self.project_dist] = True + self.projected[:3][target_vec_dist[0:3] > self.escape_dist] = False + self.projected[3:6] = np.logical_and( + self.projected[:3][self.s2_project_index_origin], self.projected[:3][self.s2_project_index_task] + ) + + # Update weight vector + normal_weight = np.ones(6, dtype=np.float32) + high_weight = np.array([200] * 3 + [400] * 3, dtype=np.float32) + weight = np.where(self.projected, high_weight, normal_weight) + weight = torch.from_numpy(np.concatenate([weight, np.ones(4, dtype=np.float32) * 10])) + + # Compute reference distance vector + normal_vec = target_vector * self.scaling # (10, 3) + dir_vec = target_vector[:6] / (target_vec_dist[:, None] + 1e-6) # (6, 3) + projected_vec = dir_vec * self.projected_dist[:, None] # (6, 3) + + # Compute final reference vector + reference_vec = np.where(self.projected[:, None], projected_vec, normal_vec[:6]) # (6, 3) + reference_vec = np.concatenate([reference_vec, normal_vec[6:10]], axis=0) # (10, 3) + torch_ref_vec = torch.as_tensor(reference_vec, dtype=torch.float32) + torch_ref_vec.requires_grad_(False) + + def objective(x: np.ndarray, grad: np.ndarray) -> float: + qpos[self.target_joint_indices] = x + self.model.compute_forward_kinematics(qpos) + target_link_poses = [self.model.get_link_pose(index) for index in self.robot_link_indices] + body_pos = np.array([pose.p for pose in target_link_poses]) + + # Torch computation for accurate loss and grad + torch_body_pos = torch.as_tensor(body_pos) + torch_body_pos.requires_grad_() + + # Index link for computation + origin_link_pos = torch_body_pos[self.origin_link_indices, :] + task_link_pos = torch_body_pos[self.task_link_indices, :] + robot_vec = task_link_pos - origin_link_pos + + # Loss term for kinematics retargeting based on 3D position error + error = robot_vec - torch_ref_vec + mse_loss = torch.sum(error * error, dim=1) # (10) + weighted_mse_loss = torch.sum(mse_loss * weight) + result = weighted_mse_loss.cpu().detach().item() + + if grad.size > 0: + if self.use_sparse_jacobian: + jacobians = [] + for i, index in enumerate(self.robot_link_indices): + link_spatial_jacobian = self.model.compute_single_link_local_jacobian(qpos, index)[ + :3, self.target_joint_indices + ] + link_rot = self.model.get_link_pose(index).to_transformation_matrix()[:3, :3] + link_kinematics_jacobian = link_rot @ link_spatial_jacobian + jacobians.append(link_kinematics_jacobian) + jacobians = np.stack(jacobians, axis=0) + else: + self.model.compute_full_jacobian(qpos) + jacobians = [ + self.model.get_link_jacobian(index, local=True)[:3, self.target_joint_indices] + for index in self.robot_link_indices + ] + + weighted_mse_loss.backward() + grad_pos = torch_body_pos.grad.cpu().numpy()[:, None, :] + grad_qpos = np.matmul(grad_pos, np.array(jacobians)) + grad_qpos = grad_qpos.mean(1).sum(0) + + # Finally, γ = 2.5 × 10−3 is a weight on regularizing the Allegro angles to zero + # (equivalent to fully opened the hand) + grad_qpos += 2 * self.gamma * x + + grad[:] = grad_qpos[:] + + return result + + return objective + + def retarget(self, ref_value, fixed_qpos, last_qpos=None): + if len(fixed_qpos) != len(self.fixed_joint_indices): + raise ValueError( + f"Optimizer has {len(self.fixed_joint_indices)} joints but non_target_qpos {fixed_qpos} is given" + ) + if last_qpos is None: + last_qpos = np.zeros(self.dof) + objective_fn = self._get_objective_function(ref_value, fixed_qpos, np.array(last_qpos).astype(np.float32)) + return self.optimize(objective_fn, last_qpos) + + +class DexPilotAllegroV4Optimizer(DexPilotAllegroOptimizer): + origin_link_names = [ + "link_3.0_tip", + "link_7.0_tip", + "link_11.0_tip", + "link_7.0_tip", + "link_11.0_tip", + "link_11.0_tip", + "wrist", + "wrist", + "wrist", + "wrist", # root + ] + task_link_names = [ + "link_15.0_tip", + "link_15.0_tip", + "link_15.0_tip", + "link_7.0_tip", + "link_7.0_tip", + "link_11.0_tip", + "link_15.0_tip", + "link_3.0_tip", + "link_7.0_tip", + "link_11.0_tip", + ] + target_link_human_indices = np.array([[8, 12, 16, 12, 16, 16, 0, 0, 0, 0], [4, 4, 4, 8, 8, 12, 4, 8, 12, 16]]) + + def __init__( + self, + robot: sapien.Articulation, + target_joint_names: List[str], + # DexPilot parameters + gamma=2.5e-3, + project_dist=0.03, + escape_dist=0.05, + eta1=1e-4, + eta2=3e-2, + scaling=2.0, + ): + super().__init__(robot, target_joint_names, gamma, project_dist, escape_dist, eta1, eta2) + self.scaling = scaling diff --git a/dex_retargeting/optimizer_utils.py b/dex_retargeting/optimizer_utils.py new file mode 100644 index 0000000..430fc7f --- /dev/null +++ b/dex_retargeting/optimizer_utils.py @@ -0,0 +1,112 @@ +import numpy as np +import sapien.core as sapien +from sapien.core import Pose +from transforms3d.euler import euler2quat + + +class LPFilter: + def __init__(self, alpha): + self.alpha = alpha + self.y = None + self.is_init = False + + def next(self, x): + if not self.is_init: + self.y = x + self.is_init = True + return self.y.copy() + self.y = self.y + self.alpha * (x - self.y) + return self.y.copy() + + def reset(self): + self.y = None + self.is_init = False + + +def add_dummy_free_joint( + robot_builder: sapien.ArticulationBuilder, + joint_indicator=(True, True, True, True, True, True), + translation_range=(-1, 1), + rotation_range=(-np.pi, np.pi), +): + assert len(joint_indicator) == 6 + new_root = robot_builder.create_link_builder() + parent = new_root + + # Prepare link and joint properties + joint_types = ["prismatic"] * 3 + ["revolute"] * 3 + joint_limit = [translation_range] * 3 + [rotation_range] * 3 + joint_name = [f"dummy_{name}_translation_joint" for name in "xyz"] + [ + f"dummy_{name}_rotation_joint" for name in "xyz" + ] + link_name = [f"dummy_{name}_translation_link" for name in "xyz"] + [f"dummy_{name}_rotation_link" for name in "xyz"] + + # Find root link which has no parent + root_link_builder = None + for link_builder in robot_builder.get_link_builders(): + if link_builder.get_parent() == -1: + root_link_builder = link_builder + break + assert root_link_builder is not None + + # Build free root + valid_joint_num = 0 + for i in range(6): + # Joint orders are x,y,z translation and then x,y,z rotation, 6 in total + # If joint indicator for specific index is False, do not build this joint + if not joint_indicator[i]: + continue + valid_joint_num += 1 + + # Add small inertia property for more stable simulation + parent.set_mass_and_inertia(1e-4, Pose(np.zeros(3)), np.ones(3) * 1e-6) + parent.set_name(link_name[i]) + + # The last joint will connect the last dummy link to the original root link of the robot + if valid_joint_num < sum(joint_indicator): + child = robot_builder.create_link_builder(parent) + else: + child = root_link_builder + child.set_parent(parent.get_index()) + child.set_joint_name(joint_name[i]) + + # Build joint + if i == 3 or i == 0: + child.set_joint_properties(joint_types[i], limits=np.array([joint_limit[i]])) + elif i == 4 or i == 1: + child.set_joint_properties( + joint_types[i], + limits=np.array([joint_limit[i]]), + pose_in_child=Pose(q=euler2quat(0, 0, np.pi / 2)), + pose_in_parent=Pose(q=euler2quat(0, 0, np.pi / 2)), + ) + elif i == 2 or i == 5: + child.set_joint_properties( + joint_types[i], + limits=np.array([joint_limit[i]]), + pose_in_parent=Pose(q=euler2quat(0, -np.pi / 2, 0)), + pose_in_child=Pose(q=euler2quat(0, -np.pi / 2, 0)), + ) + parent = child + + return parent + + +class SAPIENKinematicsModelStandalone: + def __init__(self, urdf_path, add_dummy_translation=False, add_dummy_rotation=False): + self.engine = sapien.Engine() + self.scene = self.engine.create_scene() + loader = self.scene.create_urdf_loader() + + builder = loader.load_file_as_articulation_builder(urdf_path) + if add_dummy_rotation or add_dummy_translation: + dummy_joint_indicator = (add_dummy_translation,) * 3 + (add_dummy_rotation,) * 3 + add_dummy_free_joint(builder, dummy_joint_indicator) + self.robot = builder.build(fix_root_link=True) + self.robot.set_pose(sapien.Pose()) + self.robot.set_qpos(np.zeros(self.robot.dof)) + self.scene.step() + + def release(self): + self.scene = None + self.engine = None diff --git a/dex_retargeting/retargeting_config.py b/dex_retargeting/retargeting_config.py new file mode 100644 index 0000000..c075aea --- /dev/null +++ b/dex_retargeting/retargeting_config.py @@ -0,0 +1,199 @@ +from dataclasses import dataclass +from typing import List, Optional, Dict +from pathlib import Path + +import numpy as np +import yaml + +from dex_retargeting.optimizer_utils import LPFilter +from dex_retargeting.seq_retarget import SeqRetargeting + + +@dataclass +class RetargetingConfig: + type: str + urdf_path: str + use_camera_frame_retargeting: bool + + # Source refers to the retargeting input, which usually corresponds to the human hand + # The joint indices of human hand joint which corresponds to each link in the target_link_names + + target_link_human_indices: Optional[np.ndarray] = None + + # Position retargeting link names + target_link_names: Optional[List[str]] = None + + # Vector retargeting link names + target_joint_names: Optional[List[str]] = None + target_origin_link_names: Optional[List[str]] = None + target_task_link_names: Optional[List[str]] = None + + # DexPilot retargeting link names + finger_tip_link_names: Optional[List[str]] = None + wrist_link_name: Optional[str] = None + + # Scaling factor for vector retargeting only + # For example, Allegro is 1.6 times larger than normal human hand, then this scaling factor should be 1.6 + scaling_factor: float = 1.0 + + # Optimization hyperparameter + normal_delta: float = 4e-3 + huber_delta: float = 2e-2 + + # Joint limit tag + has_joint_limits: bool = True + + # Low pass filter + low_pass_alpha: float = 0.1 + + _TYPE = ["vector", "position", "dexpilot"] + _DEFAULT_URDF_DIR = "./" + + def __post_init__(self): + # Retargeting type check + self.type = self.type.lower() + if self.type not in self._TYPE: + raise ValueError(f"Retargeting type must be one of {self._TYPE}") + + # Vector retargeting requires: target_origin_link_names + target_task_link_names + # Position retargeting requires: target_link_names + if self.type == "vector": + if self.target_origin_link_names is None or self.target_task_link_names is None: + raise ValueError(f"Vector retargeting requires: target_origin_link_names + target_task_link_names") + if len(self.target_task_link_names) != len(self.target_origin_link_names): + raise ValueError(f"Vector retargeting origin and task links dim mismatch") + if self.target_link_human_indices.shape != (2, len(self.target_origin_link_names)): + raise ValueError(f"Vector retargeting link names and link indices dim mismatch") + if self.target_link_human_indices is None: + raise ValueError(f"Vector retargeting requires: target_link_human_indices") + + elif self.type == "position": + if self.target_link_names is None: + raise ValueError(f"Position retargeting requires: target_link_names") + self.target_link_human_indices = self.target_link_human_indices.squeeze() + if self.target_link_human_indices.shape != (len(self.target_link_human_indices),): + raise ValueError(f"Position retargeting link names and link indices dim mismatch") + if self.target_link_human_indices is None: + raise ValueError(f"Position retargeting requires: target_link_human_indices") + + elif self.type == "dexpilot": + if self.finger_tip_link_names is None or self.wrist_link_name is None: + raise ValueError(f"Position retargeting requires: finger_tip_link_names + wrist_link_name") + + # URDF path check + urdf_path = Path(self.urdf_path) + if not urdf_path.is_absolute(): + urdf_path = self._DEFAULT_URDF_DIR / urdf_path + urdf_path = urdf_path.absolute() + if not urdf_path.exists(): + raise ValueError(f"URDF path {urdf_path} does not exist") + self.urdf_path = str(urdf_path) + + @classmethod + def set_default_urdf_dir(cls, urdf_dir): + path = Path(urdf_dir) + if not path.exists(): + raise ValueError(f"URDF dir {urdf_dir} not exists.") + cls._DEFAULT_URDF_DIR = urdf_dir + + @classmethod + def load_from_file(cls, config_path, override: Optional[Dict] = None): + path = Path(config_path) + if not path.is_absolute(): + path = path.absolute() + + with path.open("r") as f: + yaml_config = yaml.load(f, Loader=yaml.FullLoader) + cfg = yaml_config["retargeting"] + if "target_link_human_indices" in cfg: + cfg["target_link_human_indices"] = np.array(cfg["target_link_human_indices"]) + if override is not None: + for key, value in override.items(): + cfg[key] = value + config = RetargetingConfig(**cfg) + return config + + def build(self) -> SeqRetargeting: + from dex_retargeting.optimizer import ( + VectorOptimizer, + PositionOptimizer, + DexPilotAllegroOptimizer, + DexPilotAllegroV4Optimizer, + ) + from dex_retargeting.optimizer_utils import SAPIENKinematicsModelStandalone + from dex_retargeting import yourdfpy as urdf + import tempfile + + # Process the URDF with yourdfpy to better find file path + robot_urdf = urdf.URDF.load(self.urdf_path) + urdf_name = self.urdf_path.split("/")[-1] + temp_dir = tempfile.mkdtemp(prefix="teleop-") + temp_path = f"{temp_dir}/{urdf_name}" + robot_urdf.write_xml_file(temp_path) + sapien_model = SAPIENKinematicsModelStandalone(temp_path, add_dummy_rotation=self.use_camera_frame_retargeting) + robot = sapien_model.robot + joint_names = ( + self.target_joint_names + if self.target_joint_names is not None + else [joint.get_name() for joint in robot.get_active_joints()] + ) + if self.type == "position": + optimizer = PositionOptimizer( + robot, + joint_names, + target_link_names=self.target_link_names, + target_link_human_indices=self.target_link_human_indices, + norm_delta=self.normal_delta, + huber_delta=self.huber_delta, + ) + elif self.type == "vector": + optimizer = VectorOptimizer( + robot, + joint_names, + target_origin_link_names=self.target_origin_link_names, + target_task_link_names=self.target_task_link_names, + target_link_human_indices=self.target_link_human_indices, + scaling=self.scaling_factor, + norm_delta=self.normal_delta, + huber_delta=self.huber_delta, + ) + elif self.type == "dexpilot": + optimizer = DexPilotAllegroOptimizer( + robot, + joint_names, + finger_tip_link_names=self.finger_tip_link_names, + wrist_link_name=self.wrist_link_name, + scaling=self.scaling_factor, + ) + else: + raise RuntimeError() + + if 0 <= self.low_pass_alpha <= 1: + lp_filter = LPFilter(self.low_pass_alpha) + else: + lp_filter = None + + retargeting = SeqRetargeting( + optimizer, + has_joint_limits=self.has_joint_limits, + use_camera_frame_retargeting=self.use_camera_frame_retargeting, + lp_filter=lp_filter, + ) + # TODO: hack here for SAPIEN + retargeting.scene = sapien_model.scene + return retargeting + + +def get_retargeting_config(config_path) -> RetargetingConfig: + config = RetargetingConfig.load_from_file(config_path) + return config + + +if __name__ == "__main__": + # Path below is relative to this file + from pathlib import Path + + test_config = get_retargeting_config(str(Path(__file__).parent / "configs/allegro_hand.yml")) + print(test_config) + opt = test_config.build() + print(opt.optimizer.target_link_human_indices) diff --git a/dex_retargeting/seq_retarget.py b/dex_retargeting/seq_retarget.py new file mode 100644 index 0000000..1e5e097 --- /dev/null +++ b/dex_retargeting/seq_retarget.py @@ -0,0 +1,66 @@ +import numpy as np +from time import time + +from dex_retargeting.optimizer import Optimizer +from dex_retargeting.optimizer_utils import LPFilter +from typing import Optional + + +class SeqRetargeting: + def __init__( + self, + optimizer: Optimizer, + has_joint_limits=True, + use_camera_frame_retargeting=False, + lp_filter: Optional[LPFilter] = None, + ): + self.optimizer = optimizer + robot = self.optimizer.robot + self.use_camera_frame_retargeting = use_camera_frame_retargeting + + # Joint limit + self.has_joint_limits = has_joint_limits + joint_limits = np.ones_like(robot.get_qlimits()) + joint_limits[:, 0] = -1e4 # a large value is equivalent to no limit + joint_limits[:, 1] = 1e4 + if has_joint_limits: + joint_limits[:] = robot.get_qlimits()[:] + self.optimizer.set_joint_limit(joint_limits[self.optimizer.target_joint_indices]) + self.joint_limits = joint_limits + + # Temporal information + self.last_qpos = joint_limits.mean(1)[self.optimizer.target_joint_indices] + self.accumulated_time = 0 + self.num_retargeting = 0 + + # Filter + self.filter = lp_filter + + # TODO: hack here + self.scene = None + + def retarget(self, ref_value, fixed_qpos=np.array([])): + tic = time() + qpos = self.optimizer.retarget( + ref_value=ref_value.astype(np.float32), + fixed_qpos=fixed_qpos.astype(np.float32), + last_qpos=self.last_qpos.astype(np.float32), + ) + self.accumulated_time += time() - tic + self.num_retargeting += 1 + self.last_qpos = qpos + robot_qpos = np.zeros(self.optimizer.robot.dof) + robot_qpos[self.optimizer.fixed_joint_indices] = fixed_qpos + robot_qpos[self.optimizer.target_joint_indices] = qpos + if self.filter is not None: + robot_qpos = self.filter.next(robot_qpos) + return robot_qpos + + def verbose(self): + min_value = self.optimizer.opt.last_optimum_value() + print(f"Retargeting {self.num_retargeting} times takes: {self.accumulated_time}s") + print(f"Last distance: {min_value}") + + def reset(self): + self.last_qpos = self.joint_limits.mean(1)[self.optimizer.target_joint_indices] + self.num_retargeting = 0 diff --git a/dex_retargeting/yourdfpy.py b/dex_retargeting/yourdfpy.py new file mode 100644 index 0000000..4555d36 --- /dev/null +++ b/dex_retargeting/yourdfpy.py @@ -0,0 +1,2178 @@ +# Code from yourdfpy with small modification for deprecated warning +# Source: https://github.com/clemense/yourdfpy/blob/main/src/yourdfpy/urdf.py + +import copy +import logging +import os +from dataclasses import dataclass, field, is_dataclass +from functools import partial +from typing import Dict, List, Optional, Union + +import anytree +import numpy as np +import six +import trimesh +import trimesh.transformations as tra +from anytree import Node, LevelOrderIter +from lxml import etree + +_logger = logging.getLogger(__name__) + + +def _array_eq(arr1, arr2): + if arr1 is None and arr2 is None: + return True + return ( + isinstance(arr1, np.ndarray) + and isinstance(arr2, np.ndarray) + and arr1.shape == arr2.shape + and (arr1 == arr2).all() + ) + + +@dataclass(eq=False) +class TransmissionJoint: + name: str + hardware_interfaces: List[str] = field(default_factory=list) + + def __eq__(self, other): + if not isinstance(other, TransmissionJoint): + return NotImplemented + return ( + self.name == other.name + and all(self_hi in other.hardware_interfaces for self_hi in self.hardware_interfaces) + and all(other_hi in self.hardware_interfaces for other_hi in other.hardware_interfaces) + ) + + +@dataclass(eq=False) +class Actuator: + name: str + mechanical_reduction: Optional[float] = None + # The follwing is only valid for ROS Indigo and prior versions + hardware_interfaces: List[str] = field(default_factory=list) + + def __eq__(self, other): + if not isinstance(other, Actuator): + return NotImplemented + return ( + self.name == other.name + and self.mechanical_reduction == other.mechanical_reduction + and all(self_hi in other.hardware_interfaces for self_hi in self.hardware_interfaces) + and all(other_hi in self.hardware_interfaces for other_hi in other.hardware_interfaces) + ) + + +@dataclass(eq=False) +class Transmission: + name: str + type: Optional[str] = None + joints: List[TransmissionJoint] = field(default_factory=list) + actuators: List[Actuator] = field(default_factory=list) + + def __eq__(self, other): + if not isinstance(other, Transmission): + return NotImplemented + return ( + self.name == other.name + and self.type == other.type + and all(self_joint in other.joints for self_joint in self.joints) + and all(other_joint in self.joints for other_joint in other.joints) + and all(self_actuator in other.actuators for self_actuator in self.actuators) + and all(other_actuator in self.actuators for other_actuator in other.actuators) + ) + + +@dataclass +class Calibration: + rising: Optional[float] = None + falling: Optional[float] = None + + +@dataclass +class Mimic: + joint: str + multiplier: Optional[float] = None + offset: Optional[float] = None + + +@dataclass +class SafetyController: + soft_lower_limit: Optional[float] = None + soft_upper_limit: Optional[float] = None + k_position: Optional[float] = None + k_velocity: Optional[float] = None + + +@dataclass +class Sphere: + radius: float + + +@dataclass +class Cylinder: + radius: float + length: float + + +@dataclass(eq=False) +class Box: + size: np.ndarray + + def __eq__(self, other): + if not isinstance(other, Box): + return NotImplemented + return _array_eq(self.size, other.size) + + +@dataclass(eq=False) +class Mesh: + filename: str + scale: Optional[Union[float, np.ndarray]] = None + + def __eq__(self, other): + if not isinstance(other, Mesh): + return NotImplemented + + if self.filename != other.filename: + return False + + if isinstance(self.scale, float) and isinstance(other.scale, float): + return self.scale == other.scale + + return _array_eq(self.scale, other.scale) + + +@dataclass +class Geometry: + box: Optional[Box] = None + cylinder: Optional[Cylinder] = None + sphere: Optional[Sphere] = None + mesh: Optional[Mesh] = None + + +@dataclass(eq=False) +class Color: + rgba: np.ndarray + + def __eq__(self, other): + if not isinstance(other, Color): + return NotImplemented + return _array_eq(self.rgba, other.rgba) + + +@dataclass +class Texture: + filename: str + + +@dataclass +class Material: + name: Optional[str] = None + color: Optional[Color] = None + texture: Optional[Texture] = None + + +@dataclass(eq=False) +class Visual: + name: Optional[str] = None + origin: Optional[np.ndarray] = None + geometry: Optional[Geometry] = None # That's not really optional according to ROS + material: Optional[Material] = None + + def __eq__(self, other): + if not isinstance(other, Visual): + return NotImplemented + return ( + self.name == other.name + and _array_eq(self.origin, other.origin) + and self.geometry == other.geometry + and self.material == other.material + ) + + +@dataclass(eq=False) +class Collision: + name: str + origin: Optional[np.ndarray] = None + geometry: Geometry = None + + def __eq__(self, other): + if not isinstance(other, Collision): + return NotImplemented + return self.name == other.name and _array_eq(self.origin, other.origin) and self.geometry == other.geometry + + +@dataclass(eq=False) +class Inertial: + origin: Optional[np.ndarray] = None + mass: Optional[float] = None + inertia: Optional[np.ndarray] = None + + def __eq__(self, other): + if not isinstance(other, Inertial): + return NotImplemented + return ( + _array_eq(self.origin, other.origin) and self.mass == other.mass and _array_eq(self.inertia, other.inertia) + ) + + +@dataclass(eq=False) +class Link: + name: str + inertial: Optional[Inertial] = None + visuals: List[Visual] = field(default_factory=list) + collisions: List[Collision] = field(default_factory=list) + + def __eq__(self, other): + if not isinstance(other, Link): + return NotImplemented + return ( + self.name == other.name + and self.inertial == other.inertial + and all(self_visual in other.visuals for self_visual in self.visuals) + and all(other_visual in self.visuals for other_visual in other.visuals) + and all(self_collision in other.collisions for self_collision in self.collisions) + and all(other_collision in self.collisions for other_collision in other.collisions) + ) + + +@dataclass +class Dynamics: + damping: Optional[float] = None + friction: Optional[float] = None + + +@dataclass +class Limit: + effort: Optional[float] = None + velocity: Optional[float] = None + lower: Optional[float] = None + upper: Optional[float] = None + + +@dataclass(eq=False) +class Joint: + name: str + type: str = None + parent: str = None + child: str = None + origin: np.ndarray = None + axis: np.ndarray = None + dynamics: Optional[Dynamics] = None + limit: Optional[Limit] = None + mimic: Optional[Mimic] = None + calibration: Optional[Calibration] = None + safety_controller: Optional[SafetyController] = None + + def __eq__(self, other): + if not isinstance(other, Joint): + return NotImplemented + return ( + self.name == other.name + and self.type == other.type + and self.parent == other.parent + and self.child == other.child + and _array_eq(self.origin, other.origin) + and _array_eq(self.axis, other.axis) + and self.dynamics == other.dynamics + and self.limit == other.limit + and self.mimic == other.mimic + and self.calibration == other.calibration + and self.safety_controller == other.safety_controller + ) + + +@dataclass(eq=False) +class Robot: + name: str + links: List[Link] = field(default_factory=list) + joints: List[Joint] = field(default_factory=list) + materials: List[Material] = field(default_factory=list) + transmission: List[str] = field(default_factory=list) + gazebo: List[str] = field(default_factory=list) + + def __eq__(self, other): + if not isinstance(other, Robot): + return NotImplemented + return ( + self.name == other.name + and all(self_link in other.links for self_link in self.links) + and all(other_link in self.links for other_link in other.links) + and all(self_joint in other.joints for self_joint in self.joints) + and all(other_joint in self.joints for other_joint in other.joints) + and all(self_material in other.materials for self_material in self.materials) + and all(other_material in self.materials for other_material in other.materials) + and all(self_transmission in other.transmission for self_transmission in self.transmission) + and all(other_transmission in self.transmission for other_transmission in other.transmission) + and all(self_gazebo in other.gazebo for self_gazebo in self.gazebo) + and all(other_gazebo in self.gazebo for other_gazebo in other.gazebo) + ) + + +class URDFError(Exception): + """General URDF exception.""" + + def __init__(self, msg): + super(URDFError, self).__init__() + self.msg = msg + + def __str__(self): + return type(self).__name__ + ": " + self.msg + + def __repr__(self): + return type(self).__name__ + '("' + self.msg + '")' + + +class URDFIncompleteError(URDFError): + """Raised when needed data for an object isn't there.""" + + pass + + +class URDFAttributeValueError(URDFError): + """Raised when attribute value is not contained in the set of allowed values.""" + + pass + + +class URDFBrokenRefError(URDFError): + """Raised when a referenced object is not found in the scope.""" + + pass + + +class URDFMalformedError(URDFError): + """Raised when data is found to be corrupted in some way.""" + + pass + + +class URDFUnsupportedError(URDFError): + """Raised when some unexpectedly unsupported feature is found.""" + + pass + + +class URDFSaveValidationError(URDFError): + """Raised when XML validation fails when saving.""" + + pass + + +def _str2float(s): + """Cast string to float if it is not None. Otherwise return None. + + Args: + s (str): String to convert or None. + + Returns: + str or NoneType: The converted string or None. + """ + return float(s) if s is not None else None + + +def apply_visual_color( + geom: trimesh.Trimesh, + visual: Visual, + material_map: Dict[str, Material], +) -> None: + """Apply the color of the visual material to the mesh. + + Args: + geom: Trimesh to color. + visual: Visual description from XML. + material_map: Dictionary mapping material names to their definitions. + """ + if visual.material is None: + return + + if visual.material.color is not None: + color = visual.material.color + elif visual.material.name is not None and visual.material.name in material_map: + color = material_map[visual.material.name].color + else: + return + + if color is None: + return + if isinstance(geom.visual, trimesh.visual.ColorVisuals): + geom.visual.face_colors[:] = [int(255 * channel) for channel in color.rgba] + + +def filename_handler_null(fname): + """A lazy filename handler that simply returns its input. + + Args: + fname (str): A file name. + + Returns: + str: Same file name. + """ + return fname + + +def filename_handler_ignore_directive(fname): + """A filename handler that removes anything before (and including) '://'. + + Args: + fname (str): A file name. + + Returns: + str: The file name without the prefix. + """ + if "://" in fname or ":\\\\" in fname: + return ":".join(fname.split(":")[1:])[2:] + return fname + + +def filename_handler_ignore_directive_package(fname): + """A filename handler that removes the 'package://' directive and the package it refers to. + It subsequently calls filename_handler_ignore_directive, i.e., it removes any other directive. + + Args: + fname (str): A file name. + + Returns: + str: The file name without 'package://' and the package name. + """ + if fname.startswith("package://"): + string_length = len("package://") + return os.path.join(*os.path.normpath(fname[string_length:]).split(os.path.sep)[1:]) + return filename_handler_ignore_directive(fname) + + +def filename_handler_add_prefix(fname, prefix): + """A filename handler that adds a prefix. + + Args: + fname (str): A file name. + prefix (str): A prefix. + + Returns: + str: Prefix plus file name. + """ + return prefix + fname + + +def filename_handler_absolute2relative(fname, dir): + """A filename handler that turns an absolute file name into a relative one. + + Args: + fname (str): A file name. + dir (str): A directory. + + Returns: + str: The file name relative to the directory. + """ + # TODO: that's not right + if fname.startswith(dir): + return fname[len(dir) :] + return fname + + +def filename_handler_relative(fname, dir): + """A filename handler that joins a file name with a directory. + + Args: + fname (str): A file name. + dir (str): A directory. + + Returns: + str: The directory joined with the file name. + """ + return os.path.join(dir, filename_handler_ignore_directive_package(fname)) + + +def filename_handler_relative_to_urdf_file(fname, urdf_fname): + return filename_handler_relative(fname, os.path.dirname(urdf_fname)) + + +def filename_handler_relative_to_urdf_file_recursive(fname, urdf_fname, level=0): + if level == 0: + return filename_handler_relative_to_urdf_file(fname, urdf_fname) + return filename_handler_relative_to_urdf_file_recursive(fname, os.path.split(urdf_fname)[0], level=level - 1) + + +def _create_filename_handlers_to_urdf_file_recursive(urdf_fname): + return [ + partial( + filename_handler_relative_to_urdf_file_recursive, + urdf_fname=urdf_fname, + level=i, + ) + for i in range(len(os.path.normpath(urdf_fname).split(os.path.sep))) + ] + + +def filename_handler_meta(fname, filename_handlers): + """A filename handler that calls other filename handlers until the resulting file name points to an existing file. + + Args: + fname (str): A file name. + filename_handlers (list(fn)): A list of function pointers to filename handlers. + + Returns: + str: The resolved file name that points to an existing file or the input if none of the files exists. + """ + for fn in filename_handlers: + candidate_fname = fn(fname=fname) + _logger.debug(f"Checking filename: {candidate_fname}") + if os.path.isfile(candidate_fname): + return candidate_fname + _logger.warning(f"Unable to resolve filename: {fname}") + return fname + + +def filename_handler_magic(fname, dir): + """A magic filename handler. + + Args: + fname (str): A file name. + dir (str): A directory. + + Returns: + str: The file name that exists or the input if nothing is found. + """ + return filename_handler_meta( + fname=fname, + filename_handlers=[ + partial(filename_handler_relative, dir=dir), + filename_handler_ignore_directive, + ] + + _create_filename_handlers_to_urdf_file_recursive(urdf_fname=dir), + ) + + +def validation_handler_strict(errors): + """A validation handler that does not allow any errors. + + Args: + errors (list[yourdfpy.URDFError]): List of errors. + + Returns: + bool: Whether any errors were found. + """ + return len(errors) == 0 + + +class URDF: + def __init__( + self, + robot: Robot = None, + build_scene_graph: bool = True, + build_collision_scene_graph: bool = False, + load_meshes: bool = True, + load_collision_meshes: bool = False, + filename_handler=None, + mesh_dir: str = "", + force_mesh: bool = False, + force_collision_mesh: bool = True, + build_tree: bool = False, + ): + """A URDF model. + + Args: + robot (Robot): The robot model. Defaults to None. + build_scene_graph (bool, optional): Wheter to build a scene graph to enable transformation queries and forward kinematics. Defaults to True. + build_collision_scene_graph (bool, optional): Wheter to build a scene graph for elements. Defaults to False. + load_meshes (bool, optional): Whether to load the meshes referenced in the elements. Defaults to True. + load_collision_meshes (bool, optional): Whether to load the collision meshes referenced in the elements. Defaults to False. + filename_handler ([type], optional): Any function f(in: str) -> str, that maps filenames in the URDF to actual resources. Can be used to customize treatment of `package://` directives or relative/absolute filenames. Defaults to None. + mesh_dir (str, optional): A root directory used for loading meshes. Defaults to "". + force_mesh (bool, optional): Each loaded geometry will be concatenated into a single one (instead of being turned into a graph; in case the underlying file contains multiple geometries). This might loose texture information but the resulting scene graph will be smaller. Defaults to False. + force_collision_mesh (bool, optional): Same as force_mesh, but for collision scene. Defaults to True. + build_tree (bool, optional): Build the tree structure for global kinematics computation + """ + if filename_handler is None: + self._filename_handler = partial(filename_handler_magic, dir=mesh_dir) + else: + self._filename_handler = filename_handler + + self.robot = robot + self._create_maps() + self._update_actuated_joints() + + self._cfg = self.zero_cfg + + if build_scene_graph or build_collision_scene_graph: + self._base_link = self._determine_base_link() + else: + self._base_link = None + + self._errors = [] + + if build_scene_graph: + self._scene = self._create_scene( + use_collision_geometry=False, + load_geometry=load_meshes, + force_mesh=force_mesh, + force_single_geometry_per_link=force_mesh, + ) + else: + self._scene = None + + if build_collision_scene_graph: + self._scene_collision = self._create_scene( + use_collision_geometry=True, + load_geometry=load_collision_meshes, + force_mesh=force_collision_mesh, + force_single_geometry_per_link=force_collision_mesh, + ) + else: + self._scene_collision = None + + if build_tree: + self.tree_root = self.build_tree() + else: + self.tree_root = None + + @property + def scene(self) -> trimesh.Scene: + """A scene object representing the URDF model. + + Returns: + trimesh.Scene: A trimesh scene object. + """ + return self._scene + + @property + def collision_scene(self) -> trimesh.Scene: + """A scene object representing the elements of the URDF model + + Returns: + trimesh.Scene: A trimesh scene object. + """ + return self._scene_collision + + @property + def link_map(self) -> dict: + """A dictionary mapping link names to link objects. + + Returns: + dict: Mapping from link name (str) to Link. + """ + return self._link_map + + @property + def joint_map(self) -> dict: + """A dictionary mapping joint names to joint objects. + + Returns: + dict: Mapping from joint name (str) to Joint. + """ + return self._joint_map + + @property + def joint_names(self): + """List of joint names. + + Returns: + list[str]: List of joint names of the URDF model. + """ + return [j.name for j in self.robot.joints] + + @property + def actuated_joints(self): + """List of actuated joints. This excludes mimic and fixed joints. + + Returns: + list[Joint]: List of actuated joints of the URDF model. + """ + return self._actuated_joints + + @property + def actuated_dof_indices(self): + """List of DOF indices per actuated joint. Can be used to reference configuration. + + Returns: + list[list[int]]: List of DOF indices per actuated joint. + """ + return self._actuated_dof_indices + + @property + def actuated_joint_indices(self): + """List of indices of all joints that are actuated, i.e., not of type mimic or fixed. + + Returns: + list[int]: List of indices of actuated joints. + """ + return self._actuated_joint_indices + + @property + def actuated_joint_names(self): + """List of names of actuated joints. This excludes mimic and fixed joints. + + Returns: + list[str]: List of names of actuated joints of the URDF model. + """ + return [j.name for j in self._actuated_joints] + + @property + def num_actuated_joints(self): + """Number of actuated joints. + + Returns: + int: Number of actuated joints. + """ + return len(self.actuated_joints) + + @property + def num_dofs(self): + """Number of degrees of freedom of actuated joints. Depending on the type of the joint, the number of DOFs might vary. + + Returns: + int: Degrees of freedom. + """ + total_num_dofs = 0 + for j in self._actuated_joints: + if j.type in ["revolute", "prismatic", "continuous"]: + total_num_dofs += 1 + elif j.type == "floating": + total_num_dofs += 6 + elif j.type == "planar": + total_num_dofs += 2 + return total_num_dofs + + @property + def zero_cfg(self): + """Return the zero configuration. + + Returns: + np.ndarray: The zero configuration. + """ + return np.zeros(self.num_dofs) + + @property + def center_cfg(self): + """Return center configuration of URDF model by using the average of each joint's limits if present, otherwise zero. + + Returns: + (n), float: Default configuration of URDF model. + """ + config = [] + config_names = [] + for j in self._actuated_joints: + if j.type == "revolute" or j.type == "prismatic": + if j.limit is not None: + cfg = [j.limit.lower + 0.5 * (j.limit.upper - j.limit.lower)] + else: + cfg = [0.0] + elif j.type == "continuous": + cfg = [0.0] + elif j.type == "floating": + cfg = [0.0] * 6 + elif j.type == "planar": + cfg = [0.0] * 2 + + config.append(cfg) + config_names.append(j.name) + + for i, j in enumerate(self.robot.joints): + if j.mimic is not None: + index = config_names.index(j.mimic.joint) + config[i][0] = config[index][0] * j.mimic.multiplier + j.mimic.offset + + if len(config) == 0: + return np.array([], dtype=np.float64) + return np.concatenate(config) + + @property + def cfg(self): + """Current configuration. + + Returns: + np.ndarray: Current configuration of URDF model. + """ + return self._cfg + + @property + def base_link(self): + """Name of URDF base/root link. + + Returns: + str: Name of base link of URDF model. + """ + return self._base_link + + @property + def errors(self) -> list: + """A list with validation errors. + + Returns: + list: A list of validation errors. + """ + return self._errors + + def clear_errors(self): + """Clear the validation error log.""" + self._errors = [] + + def show(self, collision_geometry=False, callback=None): + """Open a simpler viewer displaying the URDF model. + + Args: + collision_geometry (bool, optional): Whether to display the or elements. Defaults to False. + """ + if collision_geometry: + if self._scene_collision is None: + raise ValueError( + "No collision scene available. Use build_collision_scene_graph=True and load_collision_meshes=True during loading." + ) + else: + self._scene_collision.show(callback=callback) + else: + if self._scene is None: + raise ValueError("No scene available. Use build_scene_graph=True and load_meshes=True during loading.") + elif len(self._scene.bounds_corners) < 1: + raise ValueError( + "Scene is empty, maybe meshes failed to load? Use build_scene_graph=True and load_meshes=True during loading." + ) + else: + self._scene.show(callback=callback) + + def validate(self, validation_fn=None) -> bool: + """Validate URDF model. + + Args: + validation_fn (function, optional): A function f(list[yourdfpy.URDFError]) -> bool. None uses the strict handler (any error leads to False). Defaults to None. + + Returns: + bool: Whether the model is valid. + """ + self._errors = [] + self._validate_robot(self.robot) + + if validation_fn is None: + validation_fn = validation_handler_strict + + return validation_fn(self._errors) + + def _create_maps(self): + self._material_map = {} + for m in self.robot.materials: + self._material_map[m.name] = m + + self._joint_map = {} + for j in self.robot.joints: + self._joint_map[j.name] = j + + self._link_map = {} + for l in self.robot.links: + self._link_map[l.name] = l + + def _update_actuated_joints(self): + self._actuated_joints = [] + self._actuated_joint_indices = [] + self._actuated_dof_indices = [] + + dof_indices_cnt = 0 + for i, j in enumerate(self.robot.joints): + if j.mimic is None and j.type != "fixed": + self._actuated_joints.append(j) + self._actuated_joint_indices.append(i) + + if j.type in ["prismatic", "revolute", "continuous"]: + self._actuated_dof_indices.append([dof_indices_cnt]) + dof_indices_cnt += 1 + elif j.type == "floating": + self._actuated_dof_indices.append([dof_indices_cnt, dof_indices_cnt + 1, dof_indices_cnt + 2]) + dof_indices_cnt += 3 + elif j.type == "planar": + self._actuated_dof_indices.append([dof_indices_cnt, dof_indices_cnt + 1]) + dof_indices_cnt += 2 + + def _validate_required_attribute(self, attribute, error_msg, allowed_values=None): + if attribute is None: + self._errors.append(URDFIncompleteError(error_msg)) + elif isinstance(attribute, str) and len(attribute) == 0: + self._errors.append(URDFIncompleteError(error_msg)) + + if allowed_values is not None and attribute is not None: + if attribute not in allowed_values: + self._errors.append(URDFAttributeValueError(error_msg)) + + @staticmethod + def load(fname_or_file, **kwargs): + """Load URDF file from filename or file object. + + Args: + fname_or_file (str or file object): A filename or file object, file-like object, stream representing the URDF file. + **build_scene_graph (bool, optional): Wheter to build a scene graph to enable transformation queries and forward kinematics. Defaults to True. + **build_collision_scene_graph (bool, optional): Wheter to build a scene graph for elements. Defaults to False. + **load_meshes (bool, optional): Whether to load the meshes referenced in the elements. Defaults to True. + **load_collision_meshes (bool, optional): Whether to load the collision meshes referenced in the elements. Defaults to False. + **filename_handler ([type], optional): Any function f(in: str) -> str, that maps filenames in the URDF to actual resources. Can be used to customize treatment of `package://` directives or relative/absolute filenames. Defaults to None. + **mesh_dir (str, optional): A root directory used for loading meshes. Defaults to "". + **force_mesh (bool, optional): Each loaded geometry will be concatenated into a single one (instead of being turned into a graph; in case the underlying file contains multiple geometries). This might loose texture information but the resulting scene graph will be smaller. Defaults to False. + **force_collision_mesh (bool, optional): Same as force_mesh, but for collision scene. Defaults to True. + + Raises: + ValueError: If filename does not exist. + + Returns: + yourdfpy.URDF: URDF model. + """ + if isinstance(fname_or_file, six.string_types): + if not os.path.isfile(fname_or_file): + raise ValueError("{} is not a file".format(fname_or_file)) + + if not "mesh_dir" in kwargs: + kwargs["mesh_dir"] = os.path.dirname(fname_or_file) + + try: + parser = etree.XMLParser(remove_blank_text=True) + tree = etree.parse(fname_or_file, parser=parser) + xml_root = tree.getroot() + except Exception as e: + _logger.error(e) + _logger.error("Using different parsing approach.") + + events = ("start", "end", "start-ns", "end-ns") + xml = etree.iterparse(fname_or_file, recover=True, events=events) + + # Iterate through all XML elements + for action, elem in xml: + # Skip comments and processing instructions, + # because they do not have names + if not (isinstance(elem, etree._Comment) or isinstance(elem, etree._ProcessingInstruction)): + # Remove a namespace URI in the element's name + # elem.tag = etree.QName(elem).localname + if action == "end" and ":" in elem.tag: + elem.getparent().remove(elem) + + xml_root = xml.root + + # Remove comments + etree.strip_tags(xml_root, etree.Comment) + etree.cleanup_namespaces(xml_root) + + return URDF(robot=URDF._parse_robot(xml_element=xml_root), **kwargs) + + def contains(self, key, value, element=None) -> bool: + """Checks recursively whether the URDF tree contains the provided key-value pair. + + Args: + key (str): A key. + value (str): A value. + element (etree.Element, optional): The XML element from which to start the recursive search. None means URDF root. Defaults to None. + + Returns: + bool: Whether the key-value pair was found. + """ + if element is None: + element = self.robot + + result = False + for field in element.__dataclass_fields__: + field_value = getattr(element, field) + if is_dataclass(field_value): + result = result or self.contains(key=key, value=value, element=field_value) + elif isinstance(field_value, list) and len(field_value) > 0 and is_dataclass(field_value[0]): + for field_value_element in field_value: + result = result or self.contains(key=key, value=value, element=field_value_element) + else: + if key == field and value == field_value: + result = True + return result + + def _determine_base_link(self): + """Get the base link of the URDF tree by extracting all links without parents. + In case multiple links could be root chose the first. + + Returns: + str: Name of the base link. + """ + link_names = [l.name for l in self.robot.links] + + for j in self.robot.joints: + link_names.remove(j.child) + + if len(link_names) == 0: + # raise Error? + return None + + return link_names[0] + + def _forward_kinematics_joint(self, joint, q=None): + origin = np.eye(4) if joint.origin is None else joint.origin + + if joint.mimic is not None: + if joint.mimic.joint in self.actuated_joint_names: + mimic_joint_index = self.actuated_joint_names.index(joint.mimic.joint) + q = self._cfg[mimic_joint_index] * joint.mimic.multiplier + joint.mimic.offset + else: + _logger.warning( + f"Joint '{joint.name}' is supposed to mimic '{joint.mimic.joint}'. But this joint is not actuated - will assume (0.0 + offset)." + ) + q = 0.0 + joint.mimic.offset + + if joint.type in ["revolute", "prismatic", "continuous"]: + if q is None: + # Use internal cfg vector for forward kinematics + q = self.cfg[self.actuated_dof_indices[self.actuated_joint_names.index(joint.name)]] + + if joint.type == "prismatic": + matrix = origin @ tra.translation_matrix(q * joint.axis) + else: + matrix = origin @ tra.rotation_matrix(q, joint.axis) + else: + # this includes: floating, planar, fixed + matrix = origin + + return matrix, q + + def update_cfg(self, configuration): + """Update joint configuration of URDF; does forward kinematics. + + Args: + configuration (dict, list[float], tuple[float] or np.ndarray): A mapping from joints or joint names to configuration values, or a list containing a value for each actuated joint. + + Raises: + ValueError: Raised if dimensionality of configuration does not match number of actuated joints of URDF model. + TypeError: Raised if configuration is neither a dict, list, tuple or np.ndarray. + """ + joint_cfg = [] + + if isinstance(configuration, dict): + for joint in configuration: + if isinstance(joint, six.string_types): + joint_cfg.append((self._joint_map[joint], configuration[joint])) + elif isinstance(joint, Joint): + # TODO: Joint is not hashable; so this branch will not succeed + joint_cfg.append((joint, configuration[joint])) + elif isinstance(configuration, (list, tuple, np.ndarray)): + if len(configuration) == len(self.robot.joints): + for joint, value in zip(self.robot.joints, configuration): + joint_cfg.append((joint, value)) + elif len(configuration) == self.num_actuated_joints: + for joint, value in zip(self._actuated_joints, configuration): + joint_cfg.append((joint, value)) + else: + raise ValueError( + f"Dimensionality of configuration ({len(configuration)}) doesn't match number of all ({len(self.robot.joints)}) or actuated joints ({self.num_actuated_joints})." + ) + else: + raise TypeError("Invalid type for configuration") + + # append all mimic joints in the update + for j, q in joint_cfg + [(j, 0.0) for j in self.robot.joints if j.mimic is not None]: + matrix, joint_q = self._forward_kinematics_joint(j, q=q) + + # update internal configuration vector - only consider actuated joints + if j.name in self.actuated_joint_names: + self._cfg[self.actuated_dof_indices[self.actuated_joint_names.index(j.name)]] = joint_q + + if self._scene is not None: + self._scene.graph.update(frame_from=j.parent, frame_to=j.child, matrix=matrix) + if self._scene_collision is not None: + self._scene_collision.graph.update(frame_from=j.parent, frame_to=j.child, matrix=matrix) + + def get_transform(self, frame_to, frame_from=None, collision_geometry=False): + """Get the transform from one frame to another. + + Args: + frame_to (str): Node name. + frame_from (str, optional): Node name. If None it will be set to self.base_frame. Defaults to None. + collision_geometry (bool, optional): Whether to use the collision geometry scene graph (instead of the visual geometry). Defaults to False. + + Raises: + ValueError: Raised if scene graph wasn't constructed during intialization. + + Returns: + (4, 4) float: Homogeneous transformation matrix + """ + if collision_geometry: + if self._scene_collision is None: + raise ValueError("No collision scene available. Use build_collision_scene_graph=True during loading.") + else: + return self._scene_collision.graph.get(frame_to=frame_to, frame_from=frame_from)[0] + else: + if self._scene is None: + raise ValueError("No scene available. Use build_scene_graph=True during loading.") + else: + return self._scene.graph.get(frame_to=frame_to, frame_from=frame_from)[0] + + def _link_mesh(self, link, collision_geometry=True): + geometries = link.collisions if collision_geometry else link.visuals + + if len(geometries) == 0: + return None + + meshes = [] + for g in geometries: + for m in g.geometry.meshes: + m = m.copy() + pose = g.origin + if g.geometry.mesh is not None: + if g.geometry.mesh.scale is not None: + S = np.eye(4) + S[:3, :3] = np.diag(g.geometry.mesh.scale) + pose = pose.dot(S) + m.apply_transform(pose) + meshes.append(m) + if len(meshes) == 0: + return None + self._collision_mesh = meshes[0] + meshes[1:] + return self._collision_mesh + + def _geometry2trimeshscene(self, geometry, load_file, force_mesh, skip_materials): + new_s = None + if geometry.box is not None: + new_s = trimesh.primitives.Box(extents=geometry.box.size).scene() + elif geometry.sphere is not None: + new_s = trimesh.primitives.Sphere(radius=geometry.sphere.radius).scene() + elif geometry.cylinder is not None: + new_s = trimesh.primitives.Cylinder( + radius=geometry.cylinder.radius, height=geometry.cylinder.length + ).scene() + elif geometry.mesh is not None and load_file: + new_filename = self._filename_handler(fname=geometry.mesh.filename) + + if os.path.isfile(new_filename): + _logger.debug(f"Loading {geometry.mesh.filename} as {new_filename}") + + if force_mesh: + new_g = trimesh.load( + new_filename, + ignore_broken=True, + force="mesh", + skip_materials=skip_materials, + ) + + # add original filename + if "file_path" not in new_g.metadata: + new_g.metadata["file_path"] = os.path.abspath(new_filename) + new_g.metadata["file_name"] = os.path.basename(new_filename) + + new_s = trimesh.Scene() + new_s.add_geometry(new_g) + else: + new_s = trimesh.load( + new_filename, + ignore_broken=True, + force="scene", + skip_materials=skip_materials, + ) + + if "file_path" in new_s.metadata: + for i, (_, geom) in enumerate(new_s.geometry.items()): + if "file_path" not in geom.metadata: + geom.metadata["file_path"] = new_s.metadata["file_path"] + geom.metadata["file_name"] = new_s.metadata["file_name"] + geom.metadata["file_element"] = i + + # scale mesh appropriately + if geometry.mesh.scale is not None: + if isinstance(geometry.mesh.scale, float): + new_s = new_s.scaled(geometry.mesh.scale) + elif isinstance(geometry.mesh.scale, np.ndarray): + new_s = new_s.scaled(geometry.mesh.scale) + else: + _logger.warning(f"Warning: Can't interpret scale '{geometry.mesh.scale}'") + else: + _logger.warning(f"Can't find {new_filename}") + return new_s + + def _add_geometries_to_scene( + self, + s, + geometries, + link_name, + load_geometry, + force_mesh, + force_single_geometry, + skip_materials, + ): + if force_single_geometry: + tmp_scene = trimesh.Scene(base_frame=link_name) + + first_geom_name = None + + for v in geometries: + if v.geometry is not None: + if first_geom_name is None: + first_geom_name = v.name + + new_s = self._geometry2trimeshscene( + geometry=v.geometry, + load_file=load_geometry, + force_mesh=force_mesh, + skip_materials=skip_materials, + ) + if new_s is not None: + origin = v.origin if v.origin is not None else np.eye(4) + + if force_single_geometry: + for name, geom in new_s.geometry.items(): + if isinstance(v, Visual): + apply_visual_color(geom, v, self._material_map) + tmp_scene.add_geometry( + geometry=geom, + geom_name=v.name, + parent_node_name=link_name, + transform=origin @ new_s.graph.get(name)[0], + ) + else: + # The following map is used to deal with glb format + # when the graph node and geometry have different names + geom_name_map = {new_s.graph[node_name][1]: node_name for node_name in new_s.graph.nodes} + for name, geom in new_s.geometry.items(): + if isinstance(v, Visual): + apply_visual_color(geom, v, self._material_map) + s.add_geometry( + geometry=geom, + geom_name=v.name, + parent_node_name=link_name, + transform=origin @ new_s.graph.get(geom_name_map[name])[0], + ) + + if force_single_geometry and len(tmp_scene.geometry) > 0: + s.add_geometry( + geometry=tmp_scene.dump(concatenate=True), + geom_name=first_geom_name, + parent_node_name=link_name, + transform=np.eye(4), + ) + + def _create_scene( + self, + use_collision_geometry=False, + load_geometry=True, + force_mesh=False, + force_single_geometry_per_link=False, + ): + s = trimesh.scene.Scene(base_frame=self._base_link) + + for j in self.robot.joints: + matrix, _ = self._forward_kinematics_joint(j) + + s.graph.update(frame_from=j.parent, frame_to=j.child, matrix=matrix) + + for l in self.robot.links: + if l.name not in s.graph.nodes and l.name != s.graph.base_frame: + _logger.warning(f"{l.name} not connected via joints. Will add link to base frame.") + s.graph.update(frame_from=s.graph.base_frame, frame_to=l.name) + + meshes = l.collisions if use_collision_geometry else l.visuals + self._add_geometries_to_scene( + s, + geometries=meshes, + link_name=l.name, + load_geometry=load_geometry, + force_mesh=force_mesh, + force_single_geometry=force_single_geometry_per_link, + skip_materials=use_collision_geometry, + ) + + return s + + def _successors(self, node): + """ + Get all nodes of the scene that succeeds a specified node. + + Parameters + ------------ + node : any + Hashable key in `scene.graph` + + Returns + ----------- + subnodes : set[str] + Set of nodes. + """ + # get every node that is a successor to specified node + # this includes `node` + return self._scene.graph.transforms.successors(node) + + def _create_subrobot(self, robot_name, root_link_name): + subrobot = Robot(name=robot_name) + subnodes = self._successors(node=root_link_name) + + if len(subnodes) > 0: + for node in subnodes: + if node in self.link_map: + subrobot.links.append(copy.deepcopy(self.link_map[node])) + for joint_name, joint in self.joint_map.items(): + if joint.parent in subnodes and joint.child in subnodes: + subrobot.joints.append(copy.deepcopy(self.joint_map[joint_name])) + + return subrobot + + def split_along_joints(self, joint_type="floating", **kwargs): + """Split URDF model along a particular joint type. + The result is a set of URDF models which together compose the original URDF. + + Args: + joint_type (str, or list[str], optional): Type of joint to use for splitting. Defaults to "floating". + **kwargs: Arguments delegated to URDF constructor of new URDF models. + + Returns: + list[(np.ndarray, yourdfpy.URDF)]: A list of tuples (np.ndarray, yourdfpy.URDF) whereas each homogeneous 4x4 matrix describes the root transformation of the respective URDF model w.r.t. the original URDF. + """ + root_urdf = URDF(robot=copy.deepcopy(self.robot), build_scene_graph=False, load_meshes=False) + result = [] + + joint_types = joint_type if isinstance(joint_type, list) else [joint_type] + + # find all relevant joints + joint_names = [j.name for j in self.robot.joints if j.type in joint_types] + for joint_name in joint_names: + root_link = self.link_map[self.joint_map[joint_name].child] + new_robot = self._create_subrobot( + robot_name=root_link.name, + root_link_name=root_link.name, + ) + + result.append( + ( + self._scene.graph.get(root_link.name)[0], + URDF(robot=new_robot, **kwargs), + ) + ) + + # remove links and joints from root robot + for j in new_robot.joints: + root_urdf.robot.joints.remove(root_urdf.joint_map[j.name]) + for l in new_robot.links: + root_urdf.robot.links.remove(root_urdf.link_map[l.name]) + + # remove joint that connects root urdf to root_link + if root_link.name in [j.child for j in root_urdf.robot.joints]: + root_urdf.robot.joints.remove( + root_urdf.robot.joints[[j.child for j in root_urdf.robot.joints].index(root_link.name)] + ) + + result.insert(0, (np.eye(4), URDF(robot=root_urdf.robot, **kwargs))) + + return result + + def validate_filenames(self): + for l in self.robot.links: + meshes = [m.geometry.mesh for m in l.collisions + l.visuals if m.geometry.mesh is not None] + for m in meshes: + _logger.debug(m.filename, "-->", self._filename_handler(m.filename)) + if not os.path.isfile(self._filename_handler(m.filename)): + return False + return True + + def write_xml(self): + """Write URDF model to an XML element hierarchy. + + Returns: + etree.ElementTree: XML data. + """ + xml_element = self._write_robot(self.robot) + return etree.ElementTree(xml_element) + + def write_xml_string(self, **kwargs): + """Write URDF model to a string. + + Returns: + str: String of the xml representation of the URDF model. + """ + xml_element = self.write_xml() + return etree.tostring(xml_element, xml_declaration=True, *kwargs) + + def write_xml_file(self, fname): + """Write URDF model to an xml file. + + Args: + fname (str): Filename of the file to be written. Usually ends in `.urdf`. + """ + xml_element = self.write_xml() + xml_element.write(fname, xml_declaration=True, pretty_print=True) + + def _parse_mimic(xml_element): + if xml_element is None: + return None + + return Mimic( + joint=xml_element.get("joint"), + multiplier=_str2float(xml_element.get("multiplier", 1.0)), + offset=_str2float(xml_element.get("offset", 0.0)), + ) + + def _write_mimic(self, xml_parent, mimic): + etree.SubElement( + xml_parent, + "mimic", + attrib={ + "joint": mimic.joint, + "multiplier": str(mimic.multiplier), + "offset": str(mimic.offset), + }, + ) + + def _parse_safety_controller(xml_element): + if xml_element is None: + return None + + return SafetyController( + soft_lower_limit=_str2float(xml_element.get("soft_lower_limit")), + soft_upper_limit=_str2float(xml_element.get("soft_upper_limit")), + k_position=_str2float(xml_element.get("k_position")), + k_velocity=_str2float(xml_element.get("k_velocity")), + ) + + def _write_safety_controller(self, xml_parent, safety_controller): + etree.SubElement( + xml_parent, + "safety_controller", + attrib={ + "soft_lower_limit": str(safety_controller.soft_lower_limit), + "soft_upper_limit": str(safety_controller.soft_upper_limit), + "k_position": str(safety_controller.k_position), + "k_velocity": str(safety_controller.k_velocity), + }, + ) + + def _parse_transmission_joint(xml_element): + if xml_element is None: + return None + + transmission_joint = TransmissionJoint(name=xml_element.get("name")) + + for h in xml_element.findall("hardware_interface"): + transmission_joint.hardware_interfaces.append(h.text) + + return transmission_joint + + def _write_transmission_joint(self, xml_parent, transmission_joint): + xml_element = etree.SubElement( + xml_parent, + "joint", + attrib={ + "name": str(transmission_joint.name), + }, + ) + for h in transmission_joint.hardware_interfaces: + tmp = etree.SubElement( + xml_element, + "hardwareInterface", + ) + tmp.text = h + + def _parse_actuator(xml_element): + if xml_element is None: + return None + + actuator = Actuator(name=xml_element.get("name")) + if xml_element.find("mechanicalReduction"): + actuator.mechanical_reduction = float(xml_element.find("mechanicalReduction").text) + + for h in xml_element.findall("hardwareInterface"): + actuator.hardware_interfaces.append(h.text) + + return actuator + + def _write_actuator(self, xml_parent, actuator): + xml_element = etree.SubElement( + xml_parent, + "actuator", + attrib={ + "name": str(actuator.name), + }, + ) + if actuator.mechanical_reduction is not None: + tmp = etree.SubElement("mechanicalReduction") + tmp.text = str(actuator.mechanical_reduction) + + for h in actuator.hardware_interfaces: + tmp = etree.SubElement( + xml_element, + "hardwareInterface", + ) + tmp.text = h + + def _parse_transmission(xml_element): + if xml_element is None: + return None + + transmission = Transmission(name=xml_element.get("name")) + + for j in xml_element.findall("joint"): + transmission.joints.append(URDF._parse_transmission_joint(j)) + for a in xml_element.findall("actuator"): + transmission.actuators.append(URDF._parse_actuator(a)) + + return transmission + + def _write_transmission(self, xml_parent, transmission): + xml_element = etree.SubElement( + xml_parent, + "transmission", + attrib={ + "name": str(transmission.name), + }, + ) + + for j in transmission.joints: + self._write_transmission_joint(xml_element, j) + + for a in transmission.actuators: + self._write_actuator(xml_element, a) + + def _parse_calibration(xml_element): + if xml_element is None: + return None + + return Calibration( + rising=_str2float(xml_element.get("rising")), + falling=_str2float(xml_element.get("falling")), + ) + + def _write_calibration(self, xml_parent, calibration): + etree.SubElement( + xml_parent, + "calibration", + attrib={ + "rising": str(calibration.rising), + "falling": str(calibration.falling), + }, + ) + + def _parse_box(xml_element): + return Box(size=np.array(xml_element.attrib["size"].split(), dtype=float)) + + def _write_box(self, xml_parent, box): + etree.SubElement(xml_parent, "box", attrib={"size": " ".join(map(str, box.size))}) + + def _parse_cylinder(xml_element): + return Cylinder( + radius=float(xml_element.attrib["radius"]), + length=float(xml_element.attrib["length"]), + ) + + def _write_cylinder(self, xml_parent, cylinder): + etree.SubElement( + xml_parent, + "cylinder", + attrib={"radius": str(cylinder.radius), "length": str(cylinder.length)}, + ) + + def _parse_sphere(xml_element): + return Sphere(radius=float(xml_element.attrib["radius"])) + + def _write_sphere(self, xml_parent, sphere): + etree.SubElement(xml_parent, "sphere", attrib={"radius": str(sphere.radius)}) + + def _parse_scale(xml_element): + if "scale" in xml_element.attrib: + s = xml_element.get("scale").split() + if len(s) == 0: + return None + elif len(s) == 1: + return float(s[0]) + else: + return np.array(list(map(float, s))) + return None + + def _write_scale(self, xml_parent, scale): + if scale is not None: + if isinstance(scale, float) or isinstance(scale, int): + xml_parent.set("scale", " ".join([str(scale)] * 3)) + else: + xml_parent.set("scale", " ".join(map(str, scale))) + + def _parse_mesh(xml_element): + return Mesh(filename=xml_element.get("filename"), scale=URDF._parse_scale(xml_element)) + + def _write_mesh(self, xml_parent, mesh): + # TODO: turn into different filename handler + xml_element = etree.SubElement( + xml_parent, + "mesh", + attrib={"filename": self._filename_handler(mesh.filename)}, + ) + + self._write_scale(xml_element, mesh.scale) + + def _parse_geometry(xml_element): + geometry = Geometry() + if xml_element[0].tag == "box": + geometry.box = URDF._parse_box(xml_element[0]) + elif xml_element[0].tag == "cylinder": + geometry.cylinder = URDF._parse_cylinder(xml_element[0]) + elif xml_element[0].tag == "sphere": + geometry.sphere = URDF._parse_sphere(xml_element[0]) + elif xml_element[0].tag == "mesh": + geometry.mesh = URDF._parse_mesh(xml_element[0]) + else: + raise ValueError(f"Unknown tag: {xml_element[0].tag}") + + return geometry + + def _validate_geometry(self, geometry): + if geometry is None: + self._errors.append(URDFIncompleteError(" is missing.")) + + num_nones = sum( + [ + x is not None + for x in [ + geometry.box, + geometry.cylinder, + geometry.sphere, + geometry.mesh, + ] + ] + ) + if num_nones < 1: + self._errors.append( + URDFIncompleteError( + "One of , , , needs to be defined as a child of ." + ) + ) + elif num_nones > 1: + self._errors.append( + URDFError( + "Too many of , , , defined as a child of . Only one allowed." + ) + ) + + def _write_geometry(self, xml_parent, geometry): + if geometry is None: + return + + xml_element = etree.SubElement(xml_parent, "geometry") + if geometry.box is not None: + self._write_box(xml_element, geometry.box) + elif geometry.cylinder is not None: + self._write_cylinder(xml_element, geometry.cylinder) + elif geometry.sphere is not None: + self._write_sphere(xml_element, geometry.sphere) + elif geometry.mesh is not None: + self._write_mesh(xml_element, geometry.mesh) + + def _parse_origin(xml_element): + if xml_element is None: + return None + + xyz = xml_element.get("xyz", default="0 0 0") + rpy = xml_element.get("rpy", default="0 0 0") + + return tra.compose_matrix( + translate=np.array(list(map(float, xyz.split()))), + angles=np.array(list(map(float, rpy.split()))), + ) + + def _write_origin(self, xml_parent, origin): + if origin is None: + return + + etree.SubElement( + xml_parent, + "origin", + attrib={ + "xyz": " ".join(map(str, tra.translation_from_matrix(origin))), + "rpy": " ".join(map(str, tra.euler_from_matrix(origin))), + }, + ) + + def _parse_color(xml_element): + if xml_element is None: + return None + + rgba = xml_element.get("rgba", default="1 1 1 1") + + return Color(rgba=np.array(list(map(float, rgba.split())))) + + def _write_color(self, xml_parent, color): + if color is None: + return + + etree.SubElement(xml_parent, "color", attrib={"rgba": " ".join(map(str, color.rgba))}) + + def _parse_texture(xml_element): + if xml_element is None: + return None + + # TODO: use texture filename handler + return Texture(filename=xml_element.get("filename", default=None)) + + def _write_texture(self, xml_parent, texture): + if texture is None: + return + + # TODO: use texture filename handler + etree.SubElement(xml_parent, "texture", attrib={"filename": texture.filename}) + + def _parse_material(xml_element): + if xml_element is None: + return None + + material = Material(name=xml_element.get("name")) + material.color = URDF._parse_color(xml_element.find("color")) + material.texture = URDF._parse_texture(xml_element.find("texture")) + + return material + + def _write_material(self, xml_parent, material): + if material is None: + return + + attrib = {"name": material.name} if material.name is not None else {} + xml_element = etree.SubElement( + xml_parent, + "material", + attrib=attrib, + ) + + self._write_color(xml_element, material.color) + self._write_texture(xml_element, material.texture) + + def _parse_visual(xml_element): + visual = Visual(name=xml_element.get("name")) + + visual.geometry = URDF._parse_geometry(xml_element.find("geometry")) + visual.origin = URDF._parse_origin(xml_element.find("origin")) + visual.material = URDF._parse_material(xml_element.find("material")) + + return visual + + def _validate_visual(self, visual): + self._validate_geometry(visual.geometry) + + def _write_visual(self, xml_parent, visual): + attrib = {"name": visual.name} if visual.name is not None else {} + xml_element = etree.SubElement( + xml_parent, + "visual", + attrib=attrib, + ) + + self._write_geometry(xml_element, visual.geometry) + self._write_origin(xml_element, visual.origin) + self._write_material(xml_element, visual.material) + + def _parse_collision(xml_element): + collision = Collision(name=xml_element.get("name")) + + collision.geometry = URDF._parse_geometry(xml_element.find("geometry")) + collision.origin = URDF._parse_origin(xml_element.find("origin")) + + return collision + + def _validate_collision(self, collision): + self._validate_geometry(collision.geometry) + + def _write_collision(self, xml_parent, collision): + attrib = {"name": collision.name} if collision.name is not None else {} + xml_element = etree.SubElement( + xml_parent, + "collision", + attrib=attrib, + ) + + self._write_geometry(xml_element, collision.geometry) + self._write_origin(xml_element, collision.origin) + + def _parse_inertia(xml_element): + if xml_element is None: + return None + + x = xml_element + + return np.array( + [ + [ + x.get("ixx", default=1.0), + x.get("ixy", default=0.0), + x.get("ixz", default=0.0), + ], + [ + x.get("ixy", default=0.0), + x.get("iyy", default=1.0), + x.get("iyz", default=0.0), + ], + [ + x.get("ixz", default=0.0), + x.get("iyz", default=0.0), + x.get("izz", default=1.0), + ], + ], + dtype=np.float64, + ) + + def _write_inertia(self, xml_parent, inertia): + if inertia is None: + return None + + etree.SubElement( + xml_parent, + "inertia", + attrib={ + "ixx": str(inertia[0, 0]), + "ixy": str(inertia[0, 1]), + "ixz": str(inertia[0, 2]), + "iyy": str(inertia[1, 1]), + "iyz": str(inertia[1, 2]), + "izz": str(inertia[2, 2]), + }, + ) + + def _parse_mass(xml_element): + if xml_element is None: + return None + + return _str2float(xml_element.get("value", default=0.0)) + + def _write_mass(self, xml_parent, mass): + if mass is None: + return + + etree.SubElement( + xml_parent, + "mass", + attrib={ + "value": str(mass), + }, + ) + + def _parse_inertial(xml_element): + if xml_element is None: + return None + + inertial = Inertial() + inertial.origin = URDF._parse_origin(xml_element.find("origin")) + inertial.inertia = URDF._parse_inertia(xml_element.find("inertia")) + inertial.mass = URDF._parse_mass(xml_element.find("mass")) + + return inertial + + def _write_inertial(self, xml_parent, inertial): + if inertial is None: + return + + xml_element = etree.SubElement(xml_parent, "inertial") + + self._write_origin(xml_element, inertial.origin) + self._write_mass(xml_element, inertial.mass) + self._write_inertia(xml_element, inertial.inertia) + + def _parse_link(xml_element): + link = Link(name=xml_element.attrib["name"]) + + link.inertial = URDF._parse_inertial(xml_element.find("inertial")) + + for v in xml_element.findall("visual"): + link.visuals.append(URDF._parse_visual(v)) + + for c in xml_element.findall("collision"): + link.collisions.append(URDF._parse_collision(c)) + + return link + + def _validate_link(self, link): + self._validate_required_attribute(attribute=link.name, error_msg="The tag misses a 'name' attribute.") + + for v in link.visuals: + self._validate_visual(v) + + for c in link.collisions: + self._validate_collision(c) + + def _write_link(self, xml_parent, link): + xml_element = etree.SubElement( + xml_parent, + "link", + attrib={ + "name": link.name, + }, + ) + + self._write_inertial(xml_element, link.inertial) + for visual in link.visuals: + self._write_visual(xml_element, visual) + for collision in link.collisions: + self._write_collision(xml_element, collision) + + def _parse_axis(xml_element): + if xml_element is None: + return np.array([1.0, 0, 0]) + + xyz = xml_element.get("xyz", "1 0 0") + results = [] + for x in xyz.split(): + try: + x = float(x) + except ValueError: + x = 0 + results.append(x) + return np.array(results) + # return np.array(list(map(float, xyz.split()))) + + def _write_axis(self, xml_parent, axis): + if axis is None: + return + + etree.SubElement(xml_parent, "axis", attrib={"xyz": " ".join(map(str, axis))}) + + def _parse_limit(xml_element): + if xml_element is None: + return None + + return Limit( + effort=_str2float(xml_element.get("effort", default=None)), + velocity=_str2float(xml_element.get("velocity", default=None)), + lower=_str2float(xml_element.get("lower", default=None)), + upper=_str2float(xml_element.get("upper", default=None)), + ) + + def _validate_limit(self, limit, type): + if type in ["revolute", "prismatic"]: + self._validate_required_attribute( + limit, + error_msg="The of a (prismatic, revolute) joint is missing.", + ) + + if limit is not None: + self._validate_required_attribute( + limit.upper, + error_msg="Tag of joint is missing attribute 'upper'.", + ) + self._validate_required_attribute( + limit.lower, + error_msg="Tag of joint is missing attribute 'lower'.", + ) + + if limit is not None: + self._validate_required_attribute( + limit.effort, + error_msg="Tag of joint is missing attribute 'effort'.", + ) + + self._validate_required_attribute( + limit.velocity, + error_msg="Tag of joint is missing attribute 'velocity'.", + ) + + def _write_limit(self, xml_parent, limit): + if limit is None: + return + + attrib = {} + if limit.effort is not None: + attrib["effort"] = str(limit.effort) + if limit.velocity is not None: + attrib["velocity"] = str(limit.velocity) + if limit.lower is not None: + attrib["lower"] = str(limit.lower) + if limit.upper is not None: + attrib["upper"] = str(limit.upper) + + etree.SubElement( + xml_parent, + "limit", + attrib=attrib, + ) + + def _parse_dynamics(xml_element): + if xml_element is None: + return None + + dynamics = Dynamics() + dynamics.damping = xml_element.get("damping", default=None) + dynamics.friction = xml_element.get("friction", default=None) + + return dynamics + + def _write_dynamics(self, xml_parent, dynamics): + if dynamics is None: + return + + attrib = {} + if dynamics.damping is not None: + attrib["damping"] = str(dynamics.damping) + if dynamics.friction is not None: + attrib["friction"] = str(dynamics.friction) + + etree.SubElement( + xml_parent, + "dynamics", + attrib=attrib, + ) + + def _parse_joint(xml_element): + joint = Joint(name=xml_element.attrib["name"]) + + joint.type = xml_element.get("type", default=None) + joint.parent = xml_element.find("parent").get("link") + joint.child = xml_element.find("child").get("link") + joint.origin = URDF._parse_origin(xml_element.find("origin")) + joint.axis = URDF._parse_axis(xml_element.find("axis")) + joint.limit = URDF._parse_limit(xml_element.find("limit")) + joint.dynamics = URDF._parse_dynamics(xml_element.find("dynamics")) + joint.mimic = URDF._parse_mimic(xml_element.find("mimic")) + joint.calibration = URDF._parse_calibration(xml_element.find("calibration")) + joint.safety_controller = URDF._parse_safety_controller(xml_element.find("safety_controller")) + + return joint + + def _validate_joint(self, joint): + self._validate_required_attribute( + attribute=joint.name, + error_msg="The tag misses a 'name' attribute.", + ) + + allowed_types = [ + "revolute", + "continuous", + "prismatic", + "fixed", + "floating", + "planar", + ] + self._validate_required_attribute( + attribute=joint.type, + error_msg=f"The tag misses a 'type' attribute or value is not part of allowed values [{', '.join(allowed_types)}].", + allowed_values=allowed_types, + ) + + self._validate_required_attribute( + joint.parent, + error_msg=f"The of a is missing.", + ) + + self._validate_required_attribute( + joint.child, + error_msg=f"The of a is missing.", + ) + + self._validate_limit(joint.limit, type=joint.type) + + def _write_joint(self, xml_parent, joint): + xml_element = etree.SubElement( + xml_parent, + "joint", + attrib={ + "name": joint.name, + "type": joint.type, + }, + ) + + etree.SubElement(xml_element, "parent", attrib={"link": joint.parent}) + etree.SubElement(xml_element, "child", attrib={"link": joint.child}) + self._write_origin(xml_element, joint.origin) + self._write_axis(xml_element, joint.axis) + self._write_limit(xml_element, joint.limit) + self._write_dynamics(xml_element, joint.dynamics) + + @staticmethod + def _parse_robot(xml_element): + robot = Robot(name=xml_element.attrib["name"]) + + for l in xml_element.findall("link"): + robot.links.append(URDF._parse_link(l)) + for j in xml_element.findall("joint"): + robot.joints.append(URDF._parse_joint(j)) + for m in xml_element.findall("material"): + robot.materials.append(URDF._parse_material(m)) + return robot + + def _validate_robot(self, robot): + if robot is not None: + self._validate_required_attribute( + attribute=robot.name, + error_msg="The tag misses a 'name' attribute.", + ) + + for l in robot.links: + self._validate_link(l) + + for j in robot.joints: + self._validate_joint(j) + + def _write_robot(self, robot): + xml_element = etree.Element("robot", attrib={"name": robot.name}) + for link in robot.links: + self._write_link(xml_element, link) + for joint in robot.joints: + self._write_joint(xml_element, joint) + for material in robot.materials: + self._write_material(xml_element, material) + + return xml_element + + def __eq__(self, other): + if not isinstance(other, URDF): + raise NotImplemented + return self.robot == other.robot + + @property + def filename_handler(self): + return self._filename_handler + + def build_tree(self): + parent_child_map: Dict[str, List[str]] = {} + for joint in self.robot.joints: + if joint.parent in parent_child_map: + parent_child_map[joint.parent].append(joint.child) + else: + parent_child_map[joint.parent] = [joint.child] + + # Sort link with bfs order + bfs_link_list = [self.base_link] + to_be_handle_list = [self.base_link] + while len(to_be_handle_list) > 0: + parent = to_be_handle_list.pop(0) + if parent not in parent_child_map: + continue + + children = parent_child_map[parent] + to_be_handle_list.extend(children) + bfs_link_list.extend(children) + bfs_joint_list = [] + for link_name in bfs_link_list[1:]: + joint_index = [i for i in range(len(self.robot.joints)) if self.robot.joints[i].child == link_name][0] + bfs_joint_list.append(self.robot.joints[joint_index]) + + # Build tree + root = Node(self.base_link, matrix=np.eye(4)) + for joint in bfs_joint_list: + matrix, _ = self._forward_kinematics_joint(joint, 0) + parent_node = anytree.search.findall_by_attr(root, value=joint.parent)[0] + node = Node(joint.child, parent=parent_node, matrix=matrix) + return root + + def update_kinematics(self, configuration): + joint_cfg = [] + + if isinstance(configuration, dict): + for joint in configuration: + if isinstance(joint, six.string_types): + joint_cfg.append((self._joint_map[joint], configuration[joint])) + elif isinstance(joint, Joint): + # TODO: Joint is not hashable; so this branch will not succeed + joint_cfg.append((joint, configuration[joint])) + elif isinstance(configuration, (list, tuple, np.ndarray)): + if len(configuration) == len(self.robot.joints): + for joint, value in zip(self.robot.joints, configuration): + joint_cfg.append((joint, value)) + elif len(configuration) == self.num_actuated_joints: + for joint, value in zip(self._actuated_joints, configuration): + joint_cfg.append((joint, value)) + else: + raise ValueError( + f"Dimensionality of configuration ({len(configuration)}) doesn't match number of all ({len(self.robot.joints)}) or actuated joints ({self.num_actuated_joints})." + ) + else: + raise TypeError("Invalid type for configuration") + + # append all mimic joints in the update + for j, q in joint_cfg + [(j, 0.0) for j in self.robot.joints if j.mimic is not None]: + matrix, _ = self._forward_kinematics_joint(j, q=q) + node = anytree.search.findall_by_attr(self.tree_root, j.child)[0] + node.matrix = matrix + + for node in LevelOrderIter(self.tree_root): + if node.name == self.base_link: + node.global_pose = np.eye(4) + else: + node.global_pose = node.parent.global_pose @ node.matrix + + def get_link_global_transform(self, link_name): + node = anytree.search.findall_by_attr(self.tree_root, link_name)[0] + + return node.global_pose diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..0e54558 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,8 @@ +[tool.black] +line_length = 120 + +[tool.pytest.ini_options] +addopts = "--disable-warnings" +testpaths = [ + "tests", +] diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..15d5ba8 --- /dev/null +++ b/setup.py @@ -0,0 +1,88 @@ +import re +from pathlib import Path + +from setuptools import setup, find_packages + +_here = Path(__file__).resolve().parent +name = "dex_retargeting" + +# Reference: https://github.com/kevinzakka/mjc_viewer/blob/main/setup.py +with open(_here / name / "__init__.py") as f: + meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) + if meta_match: + version = meta_match.group(1) + else: + raise RuntimeError("Unable to find __version__ string.") + +core_requirements = [ + "numpy", + "torch", + "sapien>=2.0.0", + "nlopt", + "trimesh", + "anytree", + "pycollada", +] + +test_requirements = [ + "pytest", + "black", + "isort", + "pytest-xdist", + "pyright", + "ruff", +] + +classifiers = [ + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", +] + + +def setup_package(): + # Meta information of the project + author = "Yuzhe Qin" + author_email = "y1qin@ucsd.edu" + description = "Hand pose retargeting for dexterous robot hand." + url = "https://github.com/dexsuite/dex-retargeting" + with open(_here / "README.md", "r") as file: + readme = file.read() + + # Package data + packages = find_packages(".") + print(f"Packages: {packages}") + + setup( + name=name, + version=version, + author=author, + author_email=author_email, + maintainer=author, + maintainer_email=author_email, + description=description, + long_description=readme, + long_description_content_type="text/markdown", + url=url, + license='MIT', + license_files=("LICENSE",), + packages=packages, + python_requires='>=3.7,<3.11', + zip_safe=True, + install_requires=core_requirements, + extras_require={ + "test": test_requirements, + }, + classifiers=classifiers, + ) + + +setup_package() diff --git a/tests/test_retargeting_config.py b/tests/test_retargeting_config.py new file mode 100644 index 0000000..1b9a7fc --- /dev/null +++ b/tests/test_retargeting_config.py @@ -0,0 +1,24 @@ +from pathlib import Path + +import pytest + +from dex_retargeting.retargeting_config import RetargetingConfig +from dex_retargeting.seq_retarget import SeqRetargeting +from utils import VECTOR_CONFIG_DICT, POSITION_CONFIG_DICT, DEXPILOT_CONFIG_DICT + + +class TestRetargetingConfig: + config_dir = Path(__file__).parent.parent / "dex_retargeting" / "configs" + robot_dir = Path(__file__).parent.parent / "assets" / "robots" + RetargetingConfig.set_default_urdf_dir(str(robot_dir.absolute())) + + config_paths = ( + list(VECTOR_CONFIG_DICT.values()) + list(POSITION_CONFIG_DICT.values()) + list(DEXPILOT_CONFIG_DICT.values()) + ) + + @pytest.mark.parametrize("config_path", config_paths) + def test_config_parsing(self, config_path): + config_path = self.config_dir / config_path + config = RetargetingConfig.load_from_file(config_path) + retargeting = config.build() + assert type(retargeting) == SeqRetargeting diff --git a/tests/test_sapien_optimizer.py b/tests/test_sapien_optimizer.py new file mode 100644 index 0000000..dea42aa --- /dev/null +++ b/tests/test_sapien_optimizer.py @@ -0,0 +1,125 @@ +from pathlib import Path +from time import time + +import numpy as np +import pytest +import sapien.core as sapien + +from dex_retargeting.optimizer import VectorOptimizer, PositionOptimizer +from dex_retargeting.retargeting_config import RetargetingConfig +from utils import ROBOT_NAMES, VECTOR_CONFIG_DICT, POSITION_CONFIG_DICT + + +class TestVectorOptimizer: + np.set_printoptions(precision=4) + config_dir = Path(__file__).parent.parent / "dex_retargeting" / "configs" + robot_dir = Path(__file__).parent.parent / "assets" / "robots" + RetargetingConfig.set_default_urdf_dir(str(robot_dir.absolute())) + + @staticmethod + def generate_vector_retargeting_data_gt(robot: sapien.Articulation, optimizer: VectorOptimizer): + joint_limit = robot.get_qlimits() + random_qpos = np.random.uniform(joint_limit[:, 0], joint_limit[:, 1]) + robot.set_qpos(random_qpos) + + random_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.robot_link_indices]) + origin_pos = random_pos[optimizer.origin_link_indices] + task_pos = random_pos[optimizer.task_link_indices] + random_target_vector = task_pos - origin_pos + init_qpos = np.clip(random_qpos + np.random.randn(robot.dof) * 0.5, joint_limit[:, 0], joint_limit[:, 1]) + + return random_qpos, init_qpos, random_target_vector + + @staticmethod + def generate_position_retargeting_data_gt(robot: sapien.Articulation, optimizer: PositionOptimizer): + joint_limit = robot.get_qlimits() + random_qpos = np.random.uniform(joint_limit[:, 0], joint_limit[:, 1]) + robot.set_qpos(random_qpos) + + random_target_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.target_link_indices]) + init_qpos = np.clip(random_qpos + np.random.randn(robot.dof) * 0.5, joint_limit[:, 0], joint_limit[:, 1]) + + return random_qpos, init_qpos, random_target_pos + + @pytest.mark.parametrize("robot_name", ROBOT_NAMES[:1]) + def test_position_optimizer(self, robot_name): + config_path = self.config_dir / POSITION_CONFIG_DICT[robot_name] + + # Note: The parameters below are adjusted solely for this test + # The smoothness penalty is deactivated here, meaning no low pass filter and no continuous joint value + # This is because the test is focused solely on the efficiency of single step optimization + override = dict() + override["normal_delta"] = 0 + config = RetargetingConfig.load_from_file(config_path, override) + + retargeting = config.build() + + robot = retargeting.optimizer.robot + optimizer = retargeting.optimizer + + num_optimization = 100 + tic = time() + errors = dict(pos=[], joint=[]) + np.random.seed(1) + for i in range(num_optimization): + # Sampled random position + random_qpos, init_qpos, random_target_pos = self.generate_position_retargeting_data_gt(robot, optimizer) + + # Optimized position + computed_qpos = optimizer.retarget(random_target_pos, fixed_qpos=[], last_qpos=init_qpos[:]) + robot.set_qpos(np.array(computed_qpos)) + computed_target_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.target_link_indices]) + + # Position difference + error = np.mean(np.linalg.norm(computed_target_pos - random_target_pos, axis=1)) + errors["pos"].append(error) + + tac = time() + print(f"Mean optimization position error: {np.mean(errors['pos'])}") + print(f"Retargeting computation takes {tac - tic}s for {num_optimization} times") + assert np.mean(errors["pos"]) < 1e-2 + + @pytest.mark.parametrize("robot_name", ROBOT_NAMES) + def test_vector_optimizer(self, robot_name): + config_path = self.config_dir / VECTOR_CONFIG_DICT[robot_name] + + # Note: The parameters below are adjusted solely for this test + # For retargeting from human to robot, their values should remain the default in the retargeting config + # The smoothness penalty is deactivated here, meaning no low pass filter and no continuous joint value + # This is because the test is focused solely on the efficiency of single step optimization + override = dict() + override["type"] = "vector" # cast it to vector retargeting even if it is DexPilot retargeting + override["low_pass_alpha"] = 0 + override["scaling_factor"] = 1.0 + override["normal_delta"] = 0 + config = RetargetingConfig.load_from_file(config_path, override) + + retargeting = config.build() + + robot = retargeting.optimizer.robot + optimizer = retargeting.optimizer + + num_optimization = 100 + tic = time() + errors = dict(pos=[], joint=[]) + np.random.seed(1) + for i in range(num_optimization): + # Sampled random vector + random_qpos, init_qpos, random_target_vector = self.generate_vector_retargeting_data_gt(robot, optimizer) + + # Optimized vector + computed_qpos = optimizer.retarget(random_target_vector, fixed_qpos=[], last_qpos=init_qpos[:]) + robot.set_qpos(np.array(computed_qpos)) + computed_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.robot_link_indices]) + computed_origin_pos = computed_pos[optimizer.origin_link_indices] + computed_task_pos = computed_pos[optimizer.task_link_indices] + computed_target_vector = computed_task_pos - computed_origin_pos + + # Vector difference + error = np.mean(np.linalg.norm(computed_target_vector - random_target_vector, axis=1)) + errors["pos"].append(error) + + tac = time() + print(f"Mean optimization vector error: {np.mean(errors['pos'])}") + print(f"Retargeting computation takes {tac - tic}s for {num_optimization} times") + assert np.mean(errors["pos"]) < 1e-2 diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..7ce9961 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,17 @@ +VECTOR_CONFIG_DICT = { + "allegro_right": "teleop/allegro_hand_right.yml", + "allegro_left": "teleop/allegro_hand_left.yml", + "shadow_right": "teleop/shadow_hand_right.yml", + "svh_right": "teleop/schunk_svh_hand_right.yml", +} +POSITION_CONFIG_DICT = { + "allegro_right": "offline/allegro_hand_right.yml", + # "allegro_left": "offline/allegro_hand_left.yml", + # "shadow_right": "offline/shadow_hand_right.yml", + # "svh_right": "offline/schunk_svh_hand_right.yml", +} +DEXPILOT_CONFIG_DICT = { + "allegro_right": "teleop/allegro_hand_right_dexpilot.yml", +} + +ROBOT_NAMES = list(VECTOR_CONFIG_DICT.keys())