diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..b6a0f46
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,36 @@
+name: Core Tests
+
+on:
+ push:
+ branches: [ main, init_dev ]
+ pull_request:
+
+permissions:
+ contents: read
+
+jobs:
+ build:
+
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [ "3.7", "3.8", "3.9", "3.10" ]
+
+ steps:
+ - uses: actions/checkout@v3
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ pip install --upgrade -e .[dev]
+ - name: Run Black
+ run: |
+ black dex-retargeting/ tests/ --check
+ - name: Run Pyright
+ run: |
+ pyright
+ - name: Test with pytest
+ run: |
+ pytest
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e8ac431
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+build
+.vscode
+.pyc
+*.blend
+cmake-build-debug
+/cmake-build-debug/
+/.ccls-cache/
+/.dir-locals.el
+__pycache__
+.ini
+*.convex.*
+imgui.ini
+.mypy_cache
+.DS_Store
+/.idea
+/log
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..a9822ef
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,4 @@
+[submodule "assets"]
+ path = assets
+ url = https://github.com/dexsuite/dex-urdf.git
+ branch = init_dev
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..15904fb
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+The MIT License (MIT)
+
+Copyright (c) 2023 Yuzhe Qin
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e12dc25
--- /dev/null
+++ b/README.md
@@ -0,0 +1,9 @@
+Dex Retargeting
+---
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/assets b/assets
new file mode 160000
index 0000000..32ce565
--- /dev/null
+++ b/assets
@@ -0,0 +1 @@
+Subproject commit 32ce565393e9e44948fb4f86c55964ef65c2d33b
diff --git a/dex_retargeting/__init__.py b/dex_retargeting/__init__.py
new file mode 100644
index 0000000..f102a9c
--- /dev/null
+++ b/dex_retargeting/__init__.py
@@ -0,0 +1 @@
+__version__ = "0.0.1"
diff --git a/dex_retargeting/configs/offline/allegro_hand_right.yml b/dex_retargeting/configs/offline/allegro_hand_right.yml
new file mode 100644
index 0000000..fcada15
--- /dev/null
+++ b/dex_retargeting/configs/offline/allegro_hand_right.yml
@@ -0,0 +1,12 @@
+retargeting:
+ type: position
+ urdf_path: allegro_hand/allegro_hand_right.urdf
+ use_camera_frame_retargeting: False
+
+ target_joint_names: null
+ target_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
+
+ target_link_human_indices: [ 4, 8, 12, 16 ]
+
+ # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
+ low_pass_alpha: 0.2
diff --git a/dex_retargeting/configs/teleop/allegro_hand_left.yml b/dex_retargeting/configs/teleop/allegro_hand_left.yml
new file mode 100644
index 0000000..ff6bf91
--- /dev/null
+++ b/dex_retargeting/configs/teleop/allegro_hand_left.yml
@@ -0,0 +1,17 @@
+retargeting:
+ type: vector
+ urdf_path: allegro_hand/allegro_hand_left.urdf
+ use_camera_frame_retargeting: False
+
+ # Target refers to the retargeting target, which is the robot hand
+ target_joint_names: null
+ target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ]
+ target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
+ scaling_factor: 1.6
+
+ # Source refers to the retargeting input, which usually corresponds to the human hand
+ # The joint indices of human hand joint which corresponds to each link in the target_link_names
+ target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ]
+
+ # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
+ low_pass_alpha: 0.2
diff --git a/dex_retargeting/configs/teleop/allegro_hand_right.yml b/dex_retargeting/configs/teleop/allegro_hand_right.yml
new file mode 100644
index 0000000..11296e1
--- /dev/null
+++ b/dex_retargeting/configs/teleop/allegro_hand_right.yml
@@ -0,0 +1,17 @@
+retargeting:
+ type: vector
+ urdf_path: allegro_hand/allegro_hand_right.urdf
+ use_camera_frame_retargeting: False
+
+ # Target refers to the retargeting target, which is the robot hand
+ target_joint_names: null
+ target_origin_link_names: [ "wrist", "wrist", "wrist", "wrist" ]
+ target_task_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
+ scaling_factor: 1.6
+
+ # Source refers to the retargeting input, which usually corresponds to the human hand
+ # The joint indices of human hand joint which corresponds to each link in the target_link_names
+ target_link_human_indices: [ [ 0, 0, 0, 0 ], [ 4, 8, 12, 16 ] ]
+
+ # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
+ low_pass_alpha: 0.2
diff --git a/dex_retargeting/configs/teleop/allegro_hand_right_dexpilot.yml b/dex_retargeting/configs/teleop/allegro_hand_right_dexpilot.yml
new file mode 100644
index 0000000..dd5883d
--- /dev/null
+++ b/dex_retargeting/configs/teleop/allegro_hand_right_dexpilot.yml
@@ -0,0 +1,13 @@
+retargeting:
+ type: DexPilot
+ urdf_path: allegro_hand/allegro_hand_right.urdf
+ use_camera_frame_retargeting: False
+
+ # Target refers to the retargeting target, which is the robot hand
+ target_joint_names: null
+ wrist_link_name: "wrist"
+ finger_tip_link_names: [ "link_15.0_tip", "link_3.0_tip", "link_7.0_tip", "link_11.0_tip" ]
+ scaling_factor: 1.6
+
+ # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
+ low_pass_alpha: 0.2
diff --git a/dex_retargeting/configs/teleop/schunk_svh_hand_right.yml b/dex_retargeting/configs/teleop/schunk_svh_hand_right.yml
new file mode 100644
index 0000000..bf622e3
--- /dev/null
+++ b/dex_retargeting/configs/teleop/schunk_svh_hand_right.yml
@@ -0,0 +1,17 @@
+retargeting:
+ type: vector
+ urdf_path: schunk_hand/schunk_svh_hand_right.urdf
+ use_camera_frame_retargeting: False
+
+ # Target refers to the retargeting target, which is the robot hand
+ target_joint_names: null
+ target_origin_link_names: [ "right_hand_base_link","right_hand_base_link", "right_hand_base_link", "right_hand_base_link", "right_hand_base_link", ]
+ target_task_link_names: [ "right_hand_c", "right_hand_t", "right_hand_s", "right_hand_r", "right_hand_q" ]
+ scaling_factor: 1.1
+
+ # Source refers to the retargeting input, which usually corresponds to the human hand
+ # The joint indices of human hand joint which corresponds to each link in the target_link_names
+ target_link_human_indices: [ [ 0, 0, 0, 0, 0 ], [ 4, 8, 12, 16, 20, ] ]
+
+ # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
+ low_pass_alpha: 0.2
diff --git a/dex_retargeting/configs/teleop/shadow_hand_right.yml b/dex_retargeting/configs/teleop/shadow_hand_right.yml
new file mode 100644
index 0000000..b1835c1
--- /dev/null
+++ b/dex_retargeting/configs/teleop/shadow_hand_right.yml
@@ -0,0 +1,17 @@
+retargeting:
+ type: vector
+ urdf_path: shadow_hand/shadow_hand_right.urdf
+ use_camera_frame_retargeting: False
+
+ # Target refers to the retargeting target, which is the robot hand
+ target_joint_names: null
+ target_origin_link_names: [ "palm", "palm", "palm", "palm", "palm", "palm", "palm", "palm", "palm", "palm" ]
+ target_task_link_names: [ "thtip", "fftip", "mftip", "rftip", "lftip", "thmiddle", "ffmiddle", "mfmiddle", "rfmiddle", "lfmiddle" ]
+ scaling_factor: 1.2
+
+ # Source refers to the retargeting input, which usually corresponds to the human hand
+ # The joint indices of human hand joint which corresponds to each link in the target_link_names
+ target_link_human_indices: [ [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ], [ 4, 8, 12, 16, 20, 2, 6, 10, 14, 18 ] ]
+
+ # A smaller alpha means stronger filtering, i.e. more smooth but also larger latency
+ low_pass_alpha: 0.2
diff --git a/dex_retargeting/optimizer.py b/dex_retargeting/optimizer.py
new file mode 100644
index 0000000..6dda050
--- /dev/null
+++ b/dex_retargeting/optimizer.py
@@ -0,0 +1,475 @@
+from abc import abstractmethod
+from typing import List
+
+import nlopt
+import numpy as np
+import sapien.core as sapien
+import torch
+
+
+class Optimizer:
+ retargeting_type = "BASE"
+
+ def __init__(
+ self, robot: sapien.Articulation, target_joint_names: List[str], target_link_human_indices: np.ndarray
+ ):
+ self.robot = robot
+ self.robot_dof = robot.dof
+ self.model = robot.create_pinocchio_model()
+
+ joint_names = [joint.get_name() for joint in robot.get_active_joints()]
+ target_joint_index = []
+ for target_joint_name in target_joint_names:
+ if target_joint_name not in joint_names:
+ raise ValueError(f"Joint {target_joint_name} given does not appear to be in robot XML.")
+ target_joint_index.append(joint_names.index(target_joint_name))
+ self.target_joint_names = target_joint_names
+ self.target_joint_indices = np.array(target_joint_index)
+ self.fixed_joint_indices = np.array([i for i in range(robot.dof) if i not in target_joint_index], dtype=int)
+ self.opt = nlopt.opt(nlopt.LD_SLSQP, len(target_joint_index))
+ self.dof = len(target_joint_index)
+
+ # Target
+ self.target_link_human_indices = target_link_human_indices
+
+ def set_joint_limit(self, joint_limits: np.ndarray):
+ if joint_limits.shape != (self.dof, 2):
+ raise ValueError(f"Expect joint limits have shape: {(self.dof, 2)}, but get {joint_limits.shape}")
+ self.opt.set_lower_bounds(joint_limits[:, 0].tolist())
+ self.opt.set_upper_bounds(joint_limits[:, 1].tolist())
+
+ def get_last_result(self):
+ return self.opt.last_optimize_result()
+
+ def get_link_names(self):
+ return [link.get_name() for link in self.robot.get_links()]
+
+ def get_link_indices(self, target_link_names):
+ target_link_index = []
+ for target_link_name in target_link_names:
+ if target_link_name not in self.get_link_names():
+ raise ValueError(f"Body {target_link_name} given does not appear to be in robot XML.")
+ target_link_index.append(self.get_link_names().index(target_link_name))
+ return target_link_index
+
+ @abstractmethod
+ def retarget(self, ref_value, fixed_qpos, last_qpos=None):
+ pass
+
+ def optimize(self, objective_fn, last_qpos):
+ self.opt.set_min_objective(objective_fn)
+ try:
+ qpos = self.opt.optimize(last_qpos)
+ except RuntimeError as e:
+ print(e)
+ return np.array(last_qpos)
+ return qpos
+
+
+class PositionOptimizer(Optimizer):
+ retargeting_type = "position"
+
+ def __init__(
+ self,
+ robot: sapien.Articulation,
+ target_joint_names: List[str],
+ target_link_names: List[str],
+ target_link_human_indices: np.ndarray,
+ huber_delta=0.02,
+ norm_delta=4e-3,
+ ):
+ super().__init__(robot, target_joint_names, target_link_human_indices)
+ self.body_names = target_link_names
+ self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta)
+ self.norm_delta = norm_delta
+
+ # Sanity check and cache link indices
+ self.target_link_indices = self.get_link_indices(target_link_names)
+
+ # Use local jacobian if target link name <= 2, otherwise first cache all jacobian and then get all
+ # This is only for the speed but will not affect the performance
+ if len(target_link_names) <= 40:
+ self.use_sparse_jacobian = True
+ else:
+ self.use_sparse_jacobian = False
+ self.opt.set_ftol_abs(1e-5)
+
+ def _get_objective_function(self, target_pos: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray):
+ qpos = np.zeros(self.robot_dof)
+ qpos[self.fixed_joint_indices] = fixed_qpos
+ torch_target_pos = torch.as_tensor(target_pos)
+ torch_target_pos.requires_grad_(False)
+
+ def objective(x: np.ndarray, grad: np.ndarray) -> float:
+ qpos[self.target_joint_indices] = x
+ self.model.compute_forward_kinematics(qpos)
+ target_link_poses = [self.model.get_link_pose(index) for index in self.target_link_indices]
+ body_pos = np.array([pose.p for pose in target_link_poses])
+
+ # Torch computation for accurate loss and grad
+ torch_body_pos = torch.as_tensor(body_pos)
+ torch_body_pos.requires_grad_()
+
+ # Loss term for kinematics retargeting based on 3D position error
+ huber_distance = self.huber_loss(torch_body_pos, torch_target_pos)
+ # huber_distance = torch.norm(torch_body_pos - torch_target_pos, dim=1).mean()
+ result = huber_distance.cpu().detach().item()
+
+ if grad.size > 0:
+ if self.use_sparse_jacobian:
+ jacobians = []
+ for i, index in enumerate(self.target_link_indices):
+ link_spatial_jacobian = self.model.compute_single_link_local_jacobian(qpos, index)[
+ :3, self.target_joint_indices
+ ]
+ link_rot = self.model.get_link_pose(index).to_transformation_matrix()[:3, :3]
+ link_kinematics_jacobian = link_rot @ link_spatial_jacobian
+ jacobians.append(link_kinematics_jacobian)
+ jacobians = np.stack(jacobians, axis=0)
+ else:
+ self.model.compute_full_jacobian(qpos)
+ jacobians = [
+ self.model.get_link_jacobian(index, local=True)[:3, self.target_joint_indices]
+ for index in self.target_link_indices
+ ]
+
+ huber_distance.backward()
+ grad_pos = torch_body_pos.grad.cpu().numpy()[:, None, :]
+ grad_qpos = np.matmul(grad_pos, np.array(jacobians))
+ grad_qpos = grad_qpos.mean(1).sum(0)
+
+ grad_qpos += 2 * self.norm_delta * (x - last_qpos)
+
+ grad[:] = grad_qpos[:]
+
+ return result
+
+ return objective
+
+ def retarget(self, ref_value, fixed_qpos, last_qpos=None):
+ if len(fixed_qpos) != len(self.fixed_joint_indices):
+ raise ValueError(
+ f"Optimizer has {len(self.fixed_joint_indices)} joints but non_target_qpos {fixed_qpos} is given"
+ )
+ if last_qpos is None:
+ last_qpos = np.zeros(self.dof)
+ if isinstance(last_qpos, np.ndarray):
+ last_qpos = last_qpos.astype(np.float32)
+ last_qpos = list(last_qpos)
+ objective_fn = self._get_objective_function(ref_value, fixed_qpos, np.array(last_qpos).astype(np.float32))
+ return self.optimize(objective_fn, last_qpos)
+
+
+class VectorOptimizer(Optimizer):
+ retargeting_type = "VECTOR"
+
+ def __init__(
+ self,
+ robot: sapien.Articulation,
+ target_joint_names: List[str],
+ target_origin_link_names: List[str],
+ target_task_link_names: List[str],
+ target_link_human_indices: np.ndarray,
+ huber_delta=0.02,
+ norm_delta=4e-3,
+ scaling=1.0,
+ ):
+ super().__init__(robot, target_joint_names, target_link_human_indices)
+ self.origin_link_names = target_origin_link_names
+ self.task_link_names = target_task_link_names
+ self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta)
+ self.norm_delta = norm_delta
+ self.scaling = scaling
+
+ # Computation cache for better performance
+ # For one link used in multiple vectors, e.g. hand palm, we do not want to compute it multiple times
+ self.computed_link_names = list(set(target_origin_link_names).union(set(target_task_link_names)))
+ self.origin_link_indices = torch.tensor(
+ [self.computed_link_names.index(name) for name in target_origin_link_names]
+ )
+ self.task_link_indices = torch.tensor([self.computed_link_names.index(name) for name in target_task_link_names])
+
+ # Sanity check and cache link indices
+ self.robot_link_indices = self.get_link_indices(self.computed_link_names)
+
+ # Use local jacobian if target link name <= 2, otherwise first cache all jacobian and then get all
+ # This is only for the speed but will not affect the performance
+ if len(self.computed_link_names) <= 40:
+ self.use_sparse_jacobian = True
+ else:
+ self.use_sparse_jacobian = False
+ self.opt.set_ftol_abs(1e-6)
+
+ def _get_objective_function(self, target_vector: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray):
+ qpos = np.zeros(self.robot_dof)
+ qpos[self.fixed_joint_indices] = fixed_qpos
+ torch_target_vec = torch.as_tensor(target_vector) * self.scaling
+ torch_target_vec.requires_grad_(False)
+
+ def objective(x: np.ndarray, grad: np.ndarray) -> float:
+ qpos[self.target_joint_indices] = x
+ self.model.compute_forward_kinematics(qpos)
+ target_link_poses = [self.model.get_link_pose(index) for index in self.robot_link_indices]
+ body_pos = np.array([pose.p for pose in target_link_poses])
+
+ # Torch computation for accurate loss and grad
+ torch_body_pos = torch.as_tensor(body_pos)
+ torch_body_pos.requires_grad_()
+
+ # Index link for computation
+ origin_link_pos = torch_body_pos[self.origin_link_indices, :]
+ task_link_pos = torch_body_pos[self.task_link_indices, :]
+ robot_vec = task_link_pos - origin_link_pos
+
+ # Loss term for kinematics retargeting based on 3D position error
+ huber_distance = self.huber_loss(robot_vec, torch_target_vec)
+ result = huber_distance.cpu().detach().item()
+
+ if grad.size > 0:
+ if self.use_sparse_jacobian:
+ jacobians = []
+ for i, index in enumerate(self.robot_link_indices):
+ link_spatial_jacobian = self.model.compute_single_link_local_jacobian(qpos, index)[
+ :3, self.target_joint_indices
+ ]
+ link_rot = self.model.get_link_pose(index).to_transformation_matrix()[:3, :3]
+ link_kinematics_jacobian = link_rot @ link_spatial_jacobian
+ jacobians.append(link_kinematics_jacobian)
+ jacobians = np.stack(jacobians, axis=0)
+ else:
+ self.model.compute_full_jacobian(qpos)
+ jacobians = [
+ self.model.get_link_jacobian(index, local=True)[:3, self.target_joint_indices]
+ for index in self.robot_link_indices
+ ]
+
+ huber_distance.backward()
+ grad_pos = torch_body_pos.grad.cpu().numpy()[:, None, :]
+ grad_qpos = np.matmul(grad_pos, np.array(jacobians))
+ grad_qpos = grad_qpos.mean(1).sum(0)
+
+ grad_qpos += 2 * self.norm_delta * (x - last_qpos)
+
+ grad[:] = grad_qpos[:]
+
+ return result
+
+ return objective
+
+ def retarget(self, ref_value, fixed_qpos, last_qpos=None):
+ if len(fixed_qpos) != len(self.fixed_joint_indices):
+ raise ValueError(
+ f"Optimizer has {len(self.fixed_joint_indices)} joints but non_target_qpos {fixed_qpos} is given"
+ )
+ if last_qpos is None:
+ last_qpos = np.zeros(self.dof)
+ last_qpos = list(last_qpos)
+ objective_fn = self._get_objective_function(ref_value, fixed_qpos, np.array(last_qpos).astype(np.float32))
+ return self.optimize(objective_fn, last_qpos)
+
+
+class DexPilotAllegroOptimizer(Optimizer):
+ retargeting_type = "DEXPILOT"
+
+ def __init__(
+ self,
+ robot: sapien.Articulation,
+ target_joint_names: List[str],
+ finger_tip_link_names: List[str],
+ wrist_link_name: str,
+ # DexPilot parameters
+ gamma=2.5e-3,
+ project_dist=0.03,
+ escape_dist=0.05,
+ eta1=1e-4,
+ eta2=3e-2,
+ scaling=1.0,
+ ):
+ is_four_finger = len(finger_tip_link_names) == 4
+ link_names = [wrist_link_name] + finger_tip_link_names
+ if is_four_finger:
+ origin_link_index = [2, 3, 4, 3, 4, 4, 0, 0, 0, 0]
+ task_link_index = [1, 1, 1, 2, 2, 3, 1, 2, 3, 4]
+ target_origin_link_names = [link_names[index] for index in origin_link_index]
+ target_task_link_names = [link_names[index] for index in task_link_index]
+ target_link_human_indices = np.array(
+ [[8, 12, 16, 12, 16, 16, 0, 0, 0, 0], [4, 4, 4, 8, 8, 12, 4, 8, 12, 16]]
+ )
+ else:
+ raise NotImplementedError
+
+ super().__init__(robot, target_joint_names, target_link_human_indices)
+ self.origin_link_names = target_origin_link_names
+ self.task_link_names = target_task_link_names
+ self.scaling = scaling
+ # self.norm_delta = norm_delta
+ # self.huber_loss = torch.nn.SmoothL1Loss(beta=huber_delta)
+
+ # DexPilot parameters
+ self.gamma = gamma
+ self.project_dist = project_dist
+ self.escape_dist = escape_dist
+ self.eta1 = eta1
+ self.eta2 = eta2
+
+ # Computation cache for better performance
+ # For one link used in multiple vectors, e.g. hand palm, we do not want to compute it multiple times
+ self.computed_link_names = list(set(target_origin_link_names).union(set(target_task_link_names)))
+ self.origin_link_indices = torch.tensor(
+ [self.computed_link_names.index(name) for name in target_origin_link_names]
+ )
+ self.task_link_indices = torch.tensor([self.computed_link_names.index(name) for name in target_task_link_names])
+
+ # Sanity check and cache link indices
+ self.robot_link_indices = self.get_link_indices(self.computed_link_names)
+
+ # Use local jacobian if target link name <= 2, otherwise first cache all jacobian and then get all
+ # This is only for the speed but will not affect the performance
+ if len(self.computed_link_names) <= 40:
+ self.use_sparse_jacobian = True
+ else:
+ self.use_sparse_jacobian = False
+ self.opt.set_ftol_abs(1e-6)
+
+ # DexPilot cache
+ self.projected = np.zeros(6, dtype=bool)
+ self.s2_project_index_origin = np.array([1, 2, 2], dtype=int)
+ self.s2_project_index_task = np.array([0, 0, 1], dtype=int)
+ self.projected_dist = np.array([eta1] * 3 + [eta2] * 3)
+
+ def _get_objective_function(self, target_vector: np.ndarray, fixed_qpos: np.ndarray, last_qpos: np.ndarray):
+ target_vector = target_vector.astype(np.float32)
+ qpos = np.zeros(self.robot_dof)
+ qpos[self.fixed_joint_indices] = fixed_qpos
+
+ # Update projection indicator
+ target_vec_dist = np.linalg.norm(target_vector[:6], axis=1)
+ self.projected[:3][target_vec_dist[0:3] < self.project_dist] = True
+ self.projected[:3][target_vec_dist[0:3] > self.escape_dist] = False
+ self.projected[3:6] = np.logical_and(
+ self.projected[:3][self.s2_project_index_origin], self.projected[:3][self.s2_project_index_task]
+ )
+
+ # Update weight vector
+ normal_weight = np.ones(6, dtype=np.float32)
+ high_weight = np.array([200] * 3 + [400] * 3, dtype=np.float32)
+ weight = np.where(self.projected, high_weight, normal_weight)
+ weight = torch.from_numpy(np.concatenate([weight, np.ones(4, dtype=np.float32) * 10]))
+
+ # Compute reference distance vector
+ normal_vec = target_vector * self.scaling # (10, 3)
+ dir_vec = target_vector[:6] / (target_vec_dist[:, None] + 1e-6) # (6, 3)
+ projected_vec = dir_vec * self.projected_dist[:, None] # (6, 3)
+
+ # Compute final reference vector
+ reference_vec = np.where(self.projected[:, None], projected_vec, normal_vec[:6]) # (6, 3)
+ reference_vec = np.concatenate([reference_vec, normal_vec[6:10]], axis=0) # (10, 3)
+ torch_ref_vec = torch.as_tensor(reference_vec, dtype=torch.float32)
+ torch_ref_vec.requires_grad_(False)
+
+ def objective(x: np.ndarray, grad: np.ndarray) -> float:
+ qpos[self.target_joint_indices] = x
+ self.model.compute_forward_kinematics(qpos)
+ target_link_poses = [self.model.get_link_pose(index) for index in self.robot_link_indices]
+ body_pos = np.array([pose.p for pose in target_link_poses])
+
+ # Torch computation for accurate loss and grad
+ torch_body_pos = torch.as_tensor(body_pos)
+ torch_body_pos.requires_grad_()
+
+ # Index link for computation
+ origin_link_pos = torch_body_pos[self.origin_link_indices, :]
+ task_link_pos = torch_body_pos[self.task_link_indices, :]
+ robot_vec = task_link_pos - origin_link_pos
+
+ # Loss term for kinematics retargeting based on 3D position error
+ error = robot_vec - torch_ref_vec
+ mse_loss = torch.sum(error * error, dim=1) # (10)
+ weighted_mse_loss = torch.sum(mse_loss * weight)
+ result = weighted_mse_loss.cpu().detach().item()
+
+ if grad.size > 0:
+ if self.use_sparse_jacobian:
+ jacobians = []
+ for i, index in enumerate(self.robot_link_indices):
+ link_spatial_jacobian = self.model.compute_single_link_local_jacobian(qpos, index)[
+ :3, self.target_joint_indices
+ ]
+ link_rot = self.model.get_link_pose(index).to_transformation_matrix()[:3, :3]
+ link_kinematics_jacobian = link_rot @ link_spatial_jacobian
+ jacobians.append(link_kinematics_jacobian)
+ jacobians = np.stack(jacobians, axis=0)
+ else:
+ self.model.compute_full_jacobian(qpos)
+ jacobians = [
+ self.model.get_link_jacobian(index, local=True)[:3, self.target_joint_indices]
+ for index in self.robot_link_indices
+ ]
+
+ weighted_mse_loss.backward()
+ grad_pos = torch_body_pos.grad.cpu().numpy()[:, None, :]
+ grad_qpos = np.matmul(grad_pos, np.array(jacobians))
+ grad_qpos = grad_qpos.mean(1).sum(0)
+
+ # Finally, γ = 2.5 × 10−3 is a weight on regularizing the Allegro angles to zero
+ # (equivalent to fully opened the hand)
+ grad_qpos += 2 * self.gamma * x
+
+ grad[:] = grad_qpos[:]
+
+ return result
+
+ return objective
+
+ def retarget(self, ref_value, fixed_qpos, last_qpos=None):
+ if len(fixed_qpos) != len(self.fixed_joint_indices):
+ raise ValueError(
+ f"Optimizer has {len(self.fixed_joint_indices)} joints but non_target_qpos {fixed_qpos} is given"
+ )
+ if last_qpos is None:
+ last_qpos = np.zeros(self.dof)
+ objective_fn = self._get_objective_function(ref_value, fixed_qpos, np.array(last_qpos).astype(np.float32))
+ return self.optimize(objective_fn, last_qpos)
+
+
+class DexPilotAllegroV4Optimizer(DexPilotAllegroOptimizer):
+ origin_link_names = [
+ "link_3.0_tip",
+ "link_7.0_tip",
+ "link_11.0_tip",
+ "link_7.0_tip",
+ "link_11.0_tip",
+ "link_11.0_tip",
+ "wrist",
+ "wrist",
+ "wrist",
+ "wrist", # root
+ ]
+ task_link_names = [
+ "link_15.0_tip",
+ "link_15.0_tip",
+ "link_15.0_tip",
+ "link_7.0_tip",
+ "link_7.0_tip",
+ "link_11.0_tip",
+ "link_15.0_tip",
+ "link_3.0_tip",
+ "link_7.0_tip",
+ "link_11.0_tip",
+ ]
+ target_link_human_indices = np.array([[8, 12, 16, 12, 16, 16, 0, 0, 0, 0], [4, 4, 4, 8, 8, 12, 4, 8, 12, 16]])
+
+ def __init__(
+ self,
+ robot: sapien.Articulation,
+ target_joint_names: List[str],
+ # DexPilot parameters
+ gamma=2.5e-3,
+ project_dist=0.03,
+ escape_dist=0.05,
+ eta1=1e-4,
+ eta2=3e-2,
+ scaling=2.0,
+ ):
+ super().__init__(robot, target_joint_names, gamma, project_dist, escape_dist, eta1, eta2)
+ self.scaling = scaling
diff --git a/dex_retargeting/optimizer_utils.py b/dex_retargeting/optimizer_utils.py
new file mode 100644
index 0000000..430fc7f
--- /dev/null
+++ b/dex_retargeting/optimizer_utils.py
@@ -0,0 +1,112 @@
+import numpy as np
+import sapien.core as sapien
+from sapien.core import Pose
+from transforms3d.euler import euler2quat
+
+
+class LPFilter:
+ def __init__(self, alpha):
+ self.alpha = alpha
+ self.y = None
+ self.is_init = False
+
+ def next(self, x):
+ if not self.is_init:
+ self.y = x
+ self.is_init = True
+ return self.y.copy()
+ self.y = self.y + self.alpha * (x - self.y)
+ return self.y.copy()
+
+ def reset(self):
+ self.y = None
+ self.is_init = False
+
+
+def add_dummy_free_joint(
+ robot_builder: sapien.ArticulationBuilder,
+ joint_indicator=(True, True, True, True, True, True),
+ translation_range=(-1, 1),
+ rotation_range=(-np.pi, np.pi),
+):
+ assert len(joint_indicator) == 6
+ new_root = robot_builder.create_link_builder()
+ parent = new_root
+
+ # Prepare link and joint properties
+ joint_types = ["prismatic"] * 3 + ["revolute"] * 3
+ joint_limit = [translation_range] * 3 + [rotation_range] * 3
+ joint_name = [f"dummy_{name}_translation_joint" for name in "xyz"] + [
+ f"dummy_{name}_rotation_joint" for name in "xyz"
+ ]
+ link_name = [f"dummy_{name}_translation_link" for name in "xyz"] + [f"dummy_{name}_rotation_link" for name in "xyz"]
+
+ # Find root link which has no parent
+ root_link_builder = None
+ for link_builder in robot_builder.get_link_builders():
+ if link_builder.get_parent() == -1:
+ root_link_builder = link_builder
+ break
+ assert root_link_builder is not None
+
+ # Build free root
+ valid_joint_num = 0
+ for i in range(6):
+ # Joint orders are x,y,z translation and then x,y,z rotation, 6 in total
+ # If joint indicator for specific index is False, do not build this joint
+ if not joint_indicator[i]:
+ continue
+ valid_joint_num += 1
+
+ # Add small inertia property for more stable simulation
+ parent.set_mass_and_inertia(1e-4, Pose(np.zeros(3)), np.ones(3) * 1e-6)
+ parent.set_name(link_name[i])
+
+ # The last joint will connect the last dummy link to the original root link of the robot
+ if valid_joint_num < sum(joint_indicator):
+ child = robot_builder.create_link_builder(parent)
+ else:
+ child = root_link_builder
+ child.set_parent(parent.get_index())
+ child.set_joint_name(joint_name[i])
+
+ # Build joint
+ if i == 3 or i == 0:
+ child.set_joint_properties(joint_types[i], limits=np.array([joint_limit[i]]))
+ elif i == 4 or i == 1:
+ child.set_joint_properties(
+ joint_types[i],
+ limits=np.array([joint_limit[i]]),
+ pose_in_child=Pose(q=euler2quat(0, 0, np.pi / 2)),
+ pose_in_parent=Pose(q=euler2quat(0, 0, np.pi / 2)),
+ )
+ elif i == 2 or i == 5:
+ child.set_joint_properties(
+ joint_types[i],
+ limits=np.array([joint_limit[i]]),
+ pose_in_parent=Pose(q=euler2quat(0, -np.pi / 2, 0)),
+ pose_in_child=Pose(q=euler2quat(0, -np.pi / 2, 0)),
+ )
+ parent = child
+
+ return parent
+
+
+class SAPIENKinematicsModelStandalone:
+ def __init__(self, urdf_path, add_dummy_translation=False, add_dummy_rotation=False):
+ self.engine = sapien.Engine()
+ self.scene = self.engine.create_scene()
+ loader = self.scene.create_urdf_loader()
+
+ builder = loader.load_file_as_articulation_builder(urdf_path)
+ if add_dummy_rotation or add_dummy_translation:
+ dummy_joint_indicator = (add_dummy_translation,) * 3 + (add_dummy_rotation,) * 3
+ add_dummy_free_joint(builder, dummy_joint_indicator)
+ self.robot = builder.build(fix_root_link=True)
+ self.robot.set_pose(sapien.Pose())
+ self.robot.set_qpos(np.zeros(self.robot.dof))
+ self.scene.step()
+
+ def release(self):
+ self.scene = None
+ self.engine = None
diff --git a/dex_retargeting/retargeting_config.py b/dex_retargeting/retargeting_config.py
new file mode 100644
index 0000000..c075aea
--- /dev/null
+++ b/dex_retargeting/retargeting_config.py
@@ -0,0 +1,199 @@
+from dataclasses import dataclass
+from typing import List, Optional, Dict
+from pathlib import Path
+
+import numpy as np
+import yaml
+
+from dex_retargeting.optimizer_utils import LPFilter
+from dex_retargeting.seq_retarget import SeqRetargeting
+
+
+@dataclass
+class RetargetingConfig:
+ type: str
+ urdf_path: str
+ use_camera_frame_retargeting: bool
+
+ # Source refers to the retargeting input, which usually corresponds to the human hand
+ # The joint indices of human hand joint which corresponds to each link in the target_link_names
+
+ target_link_human_indices: Optional[np.ndarray] = None
+
+ # Position retargeting link names
+ target_link_names: Optional[List[str]] = None
+
+ # Vector retargeting link names
+ target_joint_names: Optional[List[str]] = None
+ target_origin_link_names: Optional[List[str]] = None
+ target_task_link_names: Optional[List[str]] = None
+
+ # DexPilot retargeting link names
+ finger_tip_link_names: Optional[List[str]] = None
+ wrist_link_name: Optional[str] = None
+
+ # Scaling factor for vector retargeting only
+ # For example, Allegro is 1.6 times larger than normal human hand, then this scaling factor should be 1.6
+ scaling_factor: float = 1.0
+
+ # Optimization hyperparameter
+ normal_delta: float = 4e-3
+ huber_delta: float = 2e-2
+
+ # Joint limit tag
+ has_joint_limits: bool = True
+
+ # Low pass filter
+ low_pass_alpha: float = 0.1
+
+ _TYPE = ["vector", "position", "dexpilot"]
+ _DEFAULT_URDF_DIR = "./"
+
+ def __post_init__(self):
+ # Retargeting type check
+ self.type = self.type.lower()
+ if self.type not in self._TYPE:
+ raise ValueError(f"Retargeting type must be one of {self._TYPE}")
+
+ # Vector retargeting requires: target_origin_link_names + target_task_link_names
+ # Position retargeting requires: target_link_names
+ if self.type == "vector":
+ if self.target_origin_link_names is None or self.target_task_link_names is None:
+ raise ValueError(f"Vector retargeting requires: target_origin_link_names + target_task_link_names")
+ if len(self.target_task_link_names) != len(self.target_origin_link_names):
+ raise ValueError(f"Vector retargeting origin and task links dim mismatch")
+ if self.target_link_human_indices.shape != (2, len(self.target_origin_link_names)):
+ raise ValueError(f"Vector retargeting link names and link indices dim mismatch")
+ if self.target_link_human_indices is None:
+ raise ValueError(f"Vector retargeting requires: target_link_human_indices")
+
+ elif self.type == "position":
+ if self.target_link_names is None:
+ raise ValueError(f"Position retargeting requires: target_link_names")
+ self.target_link_human_indices = self.target_link_human_indices.squeeze()
+ if self.target_link_human_indices.shape != (len(self.target_link_human_indices),):
+ raise ValueError(f"Position retargeting link names and link indices dim mismatch")
+ if self.target_link_human_indices is None:
+ raise ValueError(f"Position retargeting requires: target_link_human_indices")
+
+ elif self.type == "dexpilot":
+ if self.finger_tip_link_names is None or self.wrist_link_name is None:
+ raise ValueError(f"Position retargeting requires: finger_tip_link_names + wrist_link_name")
+
+ # URDF path check
+ urdf_path = Path(self.urdf_path)
+ if not urdf_path.is_absolute():
+ urdf_path = self._DEFAULT_URDF_DIR / urdf_path
+ urdf_path = urdf_path.absolute()
+ if not urdf_path.exists():
+ raise ValueError(f"URDF path {urdf_path} does not exist")
+ self.urdf_path = str(urdf_path)
+
+ @classmethod
+ def set_default_urdf_dir(cls, urdf_dir):
+ path = Path(urdf_dir)
+ if not path.exists():
+ raise ValueError(f"URDF dir {urdf_dir} not exists.")
+ cls._DEFAULT_URDF_DIR = urdf_dir
+
+ @classmethod
+ def load_from_file(cls, config_path, override: Optional[Dict] = None):
+ path = Path(config_path)
+ if not path.is_absolute():
+ path = path.absolute()
+
+ with path.open("r") as f:
+ yaml_config = yaml.load(f, Loader=yaml.FullLoader)
+ cfg = yaml_config["retargeting"]
+ if "target_link_human_indices" in cfg:
+ cfg["target_link_human_indices"] = np.array(cfg["target_link_human_indices"])
+ if override is not None:
+ for key, value in override.items():
+ cfg[key] = value
+ config = RetargetingConfig(**cfg)
+ return config
+
+ def build(self) -> SeqRetargeting:
+ from dex_retargeting.optimizer import (
+ VectorOptimizer,
+ PositionOptimizer,
+ DexPilotAllegroOptimizer,
+ DexPilotAllegroV4Optimizer,
+ )
+ from dex_retargeting.optimizer_utils import SAPIENKinematicsModelStandalone
+ from dex_retargeting import yourdfpy as urdf
+ import tempfile
+
+ # Process the URDF with yourdfpy to better find file path
+ robot_urdf = urdf.URDF.load(self.urdf_path)
+ urdf_name = self.urdf_path.split("/")[-1]
+ temp_dir = tempfile.mkdtemp(prefix="teleop-")
+ temp_path = f"{temp_dir}/{urdf_name}"
+ robot_urdf.write_xml_file(temp_path)
+ sapien_model = SAPIENKinematicsModelStandalone(temp_path, add_dummy_rotation=self.use_camera_frame_retargeting)
+ robot = sapien_model.robot
+ joint_names = (
+ self.target_joint_names
+ if self.target_joint_names is not None
+ else [joint.get_name() for joint in robot.get_active_joints()]
+ )
+ if self.type == "position":
+ optimizer = PositionOptimizer(
+ robot,
+ joint_names,
+ target_link_names=self.target_link_names,
+ target_link_human_indices=self.target_link_human_indices,
+ norm_delta=self.normal_delta,
+ huber_delta=self.huber_delta,
+ )
+ elif self.type == "vector":
+ optimizer = VectorOptimizer(
+ robot,
+ joint_names,
+ target_origin_link_names=self.target_origin_link_names,
+ target_task_link_names=self.target_task_link_names,
+ target_link_human_indices=self.target_link_human_indices,
+ scaling=self.scaling_factor,
+ norm_delta=self.normal_delta,
+ huber_delta=self.huber_delta,
+ )
+ elif self.type == "dexpilot":
+ optimizer = DexPilotAllegroOptimizer(
+ robot,
+ joint_names,
+ finger_tip_link_names=self.finger_tip_link_names,
+ wrist_link_name=self.wrist_link_name,
+ scaling=self.scaling_factor,
+ )
+ else:
+ raise RuntimeError()
+
+ if 0 <= self.low_pass_alpha <= 1:
+ lp_filter = LPFilter(self.low_pass_alpha)
+ else:
+ lp_filter = None
+
+ retargeting = SeqRetargeting(
+ optimizer,
+ has_joint_limits=self.has_joint_limits,
+ use_camera_frame_retargeting=self.use_camera_frame_retargeting,
+ lp_filter=lp_filter,
+ )
+ # TODO: hack here for SAPIEN
+ retargeting.scene = sapien_model.scene
+ return retargeting
+
+
+def get_retargeting_config(config_path) -> RetargetingConfig:
+ config = RetargetingConfig.load_from_file(config_path)
+ return config
+
+
+if __name__ == "__main__":
+ # Path below is relative to this file
+ from pathlib import Path
+
+ test_config = get_retargeting_config(str(Path(__file__).parent / "configs/allegro_hand.yml"))
+ print(test_config)
+ opt = test_config.build()
+ print(opt.optimizer.target_link_human_indices)
diff --git a/dex_retargeting/seq_retarget.py b/dex_retargeting/seq_retarget.py
new file mode 100644
index 0000000..1e5e097
--- /dev/null
+++ b/dex_retargeting/seq_retarget.py
@@ -0,0 +1,66 @@
+import numpy as np
+from time import time
+
+from dex_retargeting.optimizer import Optimizer
+from dex_retargeting.optimizer_utils import LPFilter
+from typing import Optional
+
+
+class SeqRetargeting:
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ has_joint_limits=True,
+ use_camera_frame_retargeting=False,
+ lp_filter: Optional[LPFilter] = None,
+ ):
+ self.optimizer = optimizer
+ robot = self.optimizer.robot
+ self.use_camera_frame_retargeting = use_camera_frame_retargeting
+
+ # Joint limit
+ self.has_joint_limits = has_joint_limits
+ joint_limits = np.ones_like(robot.get_qlimits())
+ joint_limits[:, 0] = -1e4 # a large value is equivalent to no limit
+ joint_limits[:, 1] = 1e4
+ if has_joint_limits:
+ joint_limits[:] = robot.get_qlimits()[:]
+ self.optimizer.set_joint_limit(joint_limits[self.optimizer.target_joint_indices])
+ self.joint_limits = joint_limits
+
+ # Temporal information
+ self.last_qpos = joint_limits.mean(1)[self.optimizer.target_joint_indices]
+ self.accumulated_time = 0
+ self.num_retargeting = 0
+
+ # Filter
+ self.filter = lp_filter
+
+ # TODO: hack here
+ self.scene = None
+
+ def retarget(self, ref_value, fixed_qpos=np.array([])):
+ tic = time()
+ qpos = self.optimizer.retarget(
+ ref_value=ref_value.astype(np.float32),
+ fixed_qpos=fixed_qpos.astype(np.float32),
+ last_qpos=self.last_qpos.astype(np.float32),
+ )
+ self.accumulated_time += time() - tic
+ self.num_retargeting += 1
+ self.last_qpos = qpos
+ robot_qpos = np.zeros(self.optimizer.robot.dof)
+ robot_qpos[self.optimizer.fixed_joint_indices] = fixed_qpos
+ robot_qpos[self.optimizer.target_joint_indices] = qpos
+ if self.filter is not None:
+ robot_qpos = self.filter.next(robot_qpos)
+ return robot_qpos
+
+ def verbose(self):
+ min_value = self.optimizer.opt.last_optimum_value()
+ print(f"Retargeting {self.num_retargeting} times takes: {self.accumulated_time}s")
+ print(f"Last distance: {min_value}")
+
+ def reset(self):
+ self.last_qpos = self.joint_limits.mean(1)[self.optimizer.target_joint_indices]
+ self.num_retargeting = 0
diff --git a/dex_retargeting/yourdfpy.py b/dex_retargeting/yourdfpy.py
new file mode 100644
index 0000000..4555d36
--- /dev/null
+++ b/dex_retargeting/yourdfpy.py
@@ -0,0 +1,2178 @@
+# Code from yourdfpy with small modification for deprecated warning
+# Source: https://github.com/clemense/yourdfpy/blob/main/src/yourdfpy/urdf.py
+
+import copy
+import logging
+import os
+from dataclasses import dataclass, field, is_dataclass
+from functools import partial
+from typing import Dict, List, Optional, Union
+
+import anytree
+import numpy as np
+import six
+import trimesh
+import trimesh.transformations as tra
+from anytree import Node, LevelOrderIter
+from lxml import etree
+
+_logger = logging.getLogger(__name__)
+
+
+def _array_eq(arr1, arr2):
+ if arr1 is None and arr2 is None:
+ return True
+ return (
+ isinstance(arr1, np.ndarray)
+ and isinstance(arr2, np.ndarray)
+ and arr1.shape == arr2.shape
+ and (arr1 == arr2).all()
+ )
+
+
+@dataclass(eq=False)
+class TransmissionJoint:
+ name: str
+ hardware_interfaces: List[str] = field(default_factory=list)
+
+ def __eq__(self, other):
+ if not isinstance(other, TransmissionJoint):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and all(self_hi in other.hardware_interfaces for self_hi in self.hardware_interfaces)
+ and all(other_hi in self.hardware_interfaces for other_hi in other.hardware_interfaces)
+ )
+
+
+@dataclass(eq=False)
+class Actuator:
+ name: str
+ mechanical_reduction: Optional[float] = None
+ # The follwing is only valid for ROS Indigo and prior versions
+ hardware_interfaces: List[str] = field(default_factory=list)
+
+ def __eq__(self, other):
+ if not isinstance(other, Actuator):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and self.mechanical_reduction == other.mechanical_reduction
+ and all(self_hi in other.hardware_interfaces for self_hi in self.hardware_interfaces)
+ and all(other_hi in self.hardware_interfaces for other_hi in other.hardware_interfaces)
+ )
+
+
+@dataclass(eq=False)
+class Transmission:
+ name: str
+ type: Optional[str] = None
+ joints: List[TransmissionJoint] = field(default_factory=list)
+ actuators: List[Actuator] = field(default_factory=list)
+
+ def __eq__(self, other):
+ if not isinstance(other, Transmission):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and self.type == other.type
+ and all(self_joint in other.joints for self_joint in self.joints)
+ and all(other_joint in self.joints for other_joint in other.joints)
+ and all(self_actuator in other.actuators for self_actuator in self.actuators)
+ and all(other_actuator in self.actuators for other_actuator in other.actuators)
+ )
+
+
+@dataclass
+class Calibration:
+ rising: Optional[float] = None
+ falling: Optional[float] = None
+
+
+@dataclass
+class Mimic:
+ joint: str
+ multiplier: Optional[float] = None
+ offset: Optional[float] = None
+
+
+@dataclass
+class SafetyController:
+ soft_lower_limit: Optional[float] = None
+ soft_upper_limit: Optional[float] = None
+ k_position: Optional[float] = None
+ k_velocity: Optional[float] = None
+
+
+@dataclass
+class Sphere:
+ radius: float
+
+
+@dataclass
+class Cylinder:
+ radius: float
+ length: float
+
+
+@dataclass(eq=False)
+class Box:
+ size: np.ndarray
+
+ def __eq__(self, other):
+ if not isinstance(other, Box):
+ return NotImplemented
+ return _array_eq(self.size, other.size)
+
+
+@dataclass(eq=False)
+class Mesh:
+ filename: str
+ scale: Optional[Union[float, np.ndarray]] = None
+
+ def __eq__(self, other):
+ if not isinstance(other, Mesh):
+ return NotImplemented
+
+ if self.filename != other.filename:
+ return False
+
+ if isinstance(self.scale, float) and isinstance(other.scale, float):
+ return self.scale == other.scale
+
+ return _array_eq(self.scale, other.scale)
+
+
+@dataclass
+class Geometry:
+ box: Optional[Box] = None
+ cylinder: Optional[Cylinder] = None
+ sphere: Optional[Sphere] = None
+ mesh: Optional[Mesh] = None
+
+
+@dataclass(eq=False)
+class Color:
+ rgba: np.ndarray
+
+ def __eq__(self, other):
+ if not isinstance(other, Color):
+ return NotImplemented
+ return _array_eq(self.rgba, other.rgba)
+
+
+@dataclass
+class Texture:
+ filename: str
+
+
+@dataclass
+class Material:
+ name: Optional[str] = None
+ color: Optional[Color] = None
+ texture: Optional[Texture] = None
+
+
+@dataclass(eq=False)
+class Visual:
+ name: Optional[str] = None
+ origin: Optional[np.ndarray] = None
+ geometry: Optional[Geometry] = None # That's not really optional according to ROS
+ material: Optional[Material] = None
+
+ def __eq__(self, other):
+ if not isinstance(other, Visual):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and _array_eq(self.origin, other.origin)
+ and self.geometry == other.geometry
+ and self.material == other.material
+ )
+
+
+@dataclass(eq=False)
+class Collision:
+ name: str
+ origin: Optional[np.ndarray] = None
+ geometry: Geometry = None
+
+ def __eq__(self, other):
+ if not isinstance(other, Collision):
+ return NotImplemented
+ return self.name == other.name and _array_eq(self.origin, other.origin) and self.geometry == other.geometry
+
+
+@dataclass(eq=False)
+class Inertial:
+ origin: Optional[np.ndarray] = None
+ mass: Optional[float] = None
+ inertia: Optional[np.ndarray] = None
+
+ def __eq__(self, other):
+ if not isinstance(other, Inertial):
+ return NotImplemented
+ return (
+ _array_eq(self.origin, other.origin) and self.mass == other.mass and _array_eq(self.inertia, other.inertia)
+ )
+
+
+@dataclass(eq=False)
+class Link:
+ name: str
+ inertial: Optional[Inertial] = None
+ visuals: List[Visual] = field(default_factory=list)
+ collisions: List[Collision] = field(default_factory=list)
+
+ def __eq__(self, other):
+ if not isinstance(other, Link):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and self.inertial == other.inertial
+ and all(self_visual in other.visuals for self_visual in self.visuals)
+ and all(other_visual in self.visuals for other_visual in other.visuals)
+ and all(self_collision in other.collisions for self_collision in self.collisions)
+ and all(other_collision in self.collisions for other_collision in other.collisions)
+ )
+
+
+@dataclass
+class Dynamics:
+ damping: Optional[float] = None
+ friction: Optional[float] = None
+
+
+@dataclass
+class Limit:
+ effort: Optional[float] = None
+ velocity: Optional[float] = None
+ lower: Optional[float] = None
+ upper: Optional[float] = None
+
+
+@dataclass(eq=False)
+class Joint:
+ name: str
+ type: str = None
+ parent: str = None
+ child: str = None
+ origin: np.ndarray = None
+ axis: np.ndarray = None
+ dynamics: Optional[Dynamics] = None
+ limit: Optional[Limit] = None
+ mimic: Optional[Mimic] = None
+ calibration: Optional[Calibration] = None
+ safety_controller: Optional[SafetyController] = None
+
+ def __eq__(self, other):
+ if not isinstance(other, Joint):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and self.type == other.type
+ and self.parent == other.parent
+ and self.child == other.child
+ and _array_eq(self.origin, other.origin)
+ and _array_eq(self.axis, other.axis)
+ and self.dynamics == other.dynamics
+ and self.limit == other.limit
+ and self.mimic == other.mimic
+ and self.calibration == other.calibration
+ and self.safety_controller == other.safety_controller
+ )
+
+
+@dataclass(eq=False)
+class Robot:
+ name: str
+ links: List[Link] = field(default_factory=list)
+ joints: List[Joint] = field(default_factory=list)
+ materials: List[Material] = field(default_factory=list)
+ transmission: List[str] = field(default_factory=list)
+ gazebo: List[str] = field(default_factory=list)
+
+ def __eq__(self, other):
+ if not isinstance(other, Robot):
+ return NotImplemented
+ return (
+ self.name == other.name
+ and all(self_link in other.links for self_link in self.links)
+ and all(other_link in self.links for other_link in other.links)
+ and all(self_joint in other.joints for self_joint in self.joints)
+ and all(other_joint in self.joints for other_joint in other.joints)
+ and all(self_material in other.materials for self_material in self.materials)
+ and all(other_material in self.materials for other_material in other.materials)
+ and all(self_transmission in other.transmission for self_transmission in self.transmission)
+ and all(other_transmission in self.transmission for other_transmission in other.transmission)
+ and all(self_gazebo in other.gazebo for self_gazebo in self.gazebo)
+ and all(other_gazebo in self.gazebo for other_gazebo in other.gazebo)
+ )
+
+
+class URDFError(Exception):
+ """General URDF exception."""
+
+ def __init__(self, msg):
+ super(URDFError, self).__init__()
+ self.msg = msg
+
+ def __str__(self):
+ return type(self).__name__ + ": " + self.msg
+
+ def __repr__(self):
+ return type(self).__name__ + '("' + self.msg + '")'
+
+
+class URDFIncompleteError(URDFError):
+ """Raised when needed data for an object isn't there."""
+
+ pass
+
+
+class URDFAttributeValueError(URDFError):
+ """Raised when attribute value is not contained in the set of allowed values."""
+
+ pass
+
+
+class URDFBrokenRefError(URDFError):
+ """Raised when a referenced object is not found in the scope."""
+
+ pass
+
+
+class URDFMalformedError(URDFError):
+ """Raised when data is found to be corrupted in some way."""
+
+ pass
+
+
+class URDFUnsupportedError(URDFError):
+ """Raised when some unexpectedly unsupported feature is found."""
+
+ pass
+
+
+class URDFSaveValidationError(URDFError):
+ """Raised when XML validation fails when saving."""
+
+ pass
+
+
+def _str2float(s):
+ """Cast string to float if it is not None. Otherwise return None.
+
+ Args:
+ s (str): String to convert or None.
+
+ Returns:
+ str or NoneType: The converted string or None.
+ """
+ return float(s) if s is not None else None
+
+
+def apply_visual_color(
+ geom: trimesh.Trimesh,
+ visual: Visual,
+ material_map: Dict[str, Material],
+) -> None:
+ """Apply the color of the visual material to the mesh.
+
+ Args:
+ geom: Trimesh to color.
+ visual: Visual description from XML.
+ material_map: Dictionary mapping material names to their definitions.
+ """
+ if visual.material is None:
+ return
+
+ if visual.material.color is not None:
+ color = visual.material.color
+ elif visual.material.name is not None and visual.material.name in material_map:
+ color = material_map[visual.material.name].color
+ else:
+ return
+
+ if color is None:
+ return
+ if isinstance(geom.visual, trimesh.visual.ColorVisuals):
+ geom.visual.face_colors[:] = [int(255 * channel) for channel in color.rgba]
+
+
+def filename_handler_null(fname):
+ """A lazy filename handler that simply returns its input.
+
+ Args:
+ fname (str): A file name.
+
+ Returns:
+ str: Same file name.
+ """
+ return fname
+
+
+def filename_handler_ignore_directive(fname):
+ """A filename handler that removes anything before (and including) '://'.
+
+ Args:
+ fname (str): A file name.
+
+ Returns:
+ str: The file name without the prefix.
+ """
+ if "://" in fname or ":\\\\" in fname:
+ return ":".join(fname.split(":")[1:])[2:]
+ return fname
+
+
+def filename_handler_ignore_directive_package(fname):
+ """A filename handler that removes the 'package://' directive and the package it refers to.
+ It subsequently calls filename_handler_ignore_directive, i.e., it removes any other directive.
+
+ Args:
+ fname (str): A file name.
+
+ Returns:
+ str: The file name without 'package://' and the package name.
+ """
+ if fname.startswith("package://"):
+ string_length = len("package://")
+ return os.path.join(*os.path.normpath(fname[string_length:]).split(os.path.sep)[1:])
+ return filename_handler_ignore_directive(fname)
+
+
+def filename_handler_add_prefix(fname, prefix):
+ """A filename handler that adds a prefix.
+
+ Args:
+ fname (str): A file name.
+ prefix (str): A prefix.
+
+ Returns:
+ str: Prefix plus file name.
+ """
+ return prefix + fname
+
+
+def filename_handler_absolute2relative(fname, dir):
+ """A filename handler that turns an absolute file name into a relative one.
+
+ Args:
+ fname (str): A file name.
+ dir (str): A directory.
+
+ Returns:
+ str: The file name relative to the directory.
+ """
+ # TODO: that's not right
+ if fname.startswith(dir):
+ return fname[len(dir) :]
+ return fname
+
+
+def filename_handler_relative(fname, dir):
+ """A filename handler that joins a file name with a directory.
+
+ Args:
+ fname (str): A file name.
+ dir (str): A directory.
+
+ Returns:
+ str: The directory joined with the file name.
+ """
+ return os.path.join(dir, filename_handler_ignore_directive_package(fname))
+
+
+def filename_handler_relative_to_urdf_file(fname, urdf_fname):
+ return filename_handler_relative(fname, os.path.dirname(urdf_fname))
+
+
+def filename_handler_relative_to_urdf_file_recursive(fname, urdf_fname, level=0):
+ if level == 0:
+ return filename_handler_relative_to_urdf_file(fname, urdf_fname)
+ return filename_handler_relative_to_urdf_file_recursive(fname, os.path.split(urdf_fname)[0], level=level - 1)
+
+
+def _create_filename_handlers_to_urdf_file_recursive(urdf_fname):
+ return [
+ partial(
+ filename_handler_relative_to_urdf_file_recursive,
+ urdf_fname=urdf_fname,
+ level=i,
+ )
+ for i in range(len(os.path.normpath(urdf_fname).split(os.path.sep)))
+ ]
+
+
+def filename_handler_meta(fname, filename_handlers):
+ """A filename handler that calls other filename handlers until the resulting file name points to an existing file.
+
+ Args:
+ fname (str): A file name.
+ filename_handlers (list(fn)): A list of function pointers to filename handlers.
+
+ Returns:
+ str: The resolved file name that points to an existing file or the input if none of the files exists.
+ """
+ for fn in filename_handlers:
+ candidate_fname = fn(fname=fname)
+ _logger.debug(f"Checking filename: {candidate_fname}")
+ if os.path.isfile(candidate_fname):
+ return candidate_fname
+ _logger.warning(f"Unable to resolve filename: {fname}")
+ return fname
+
+
+def filename_handler_magic(fname, dir):
+ """A magic filename handler.
+
+ Args:
+ fname (str): A file name.
+ dir (str): A directory.
+
+ Returns:
+ str: The file name that exists or the input if nothing is found.
+ """
+ return filename_handler_meta(
+ fname=fname,
+ filename_handlers=[
+ partial(filename_handler_relative, dir=dir),
+ filename_handler_ignore_directive,
+ ]
+ + _create_filename_handlers_to_urdf_file_recursive(urdf_fname=dir),
+ )
+
+
+def validation_handler_strict(errors):
+ """A validation handler that does not allow any errors.
+
+ Args:
+ errors (list[yourdfpy.URDFError]): List of errors.
+
+ Returns:
+ bool: Whether any errors were found.
+ """
+ return len(errors) == 0
+
+
+class URDF:
+ def __init__(
+ self,
+ robot: Robot = None,
+ build_scene_graph: bool = True,
+ build_collision_scene_graph: bool = False,
+ load_meshes: bool = True,
+ load_collision_meshes: bool = False,
+ filename_handler=None,
+ mesh_dir: str = "",
+ force_mesh: bool = False,
+ force_collision_mesh: bool = True,
+ build_tree: bool = False,
+ ):
+ """A URDF model.
+
+ Args:
+ robot (Robot): The robot model. Defaults to None.
+ build_scene_graph (bool, optional): Wheter to build a scene graph to enable transformation queries and forward kinematics. Defaults to True.
+ build_collision_scene_graph (bool, optional): Wheter to build a scene graph for elements. Defaults to False.
+ load_meshes (bool, optional): Whether to load the meshes referenced in the elements. Defaults to True.
+ load_collision_meshes (bool, optional): Whether to load the collision meshes referenced in the elements. Defaults to False.
+ filename_handler ([type], optional): Any function f(in: str) -> str, that maps filenames in the URDF to actual resources. Can be used to customize treatment of `package://` directives or relative/absolute filenames. Defaults to None.
+ mesh_dir (str, optional): A root directory used for loading meshes. Defaults to "".
+ force_mesh (bool, optional): Each loaded geometry will be concatenated into a single one (instead of being turned into a graph; in case the underlying file contains multiple geometries). This might loose texture information but the resulting scene graph will be smaller. Defaults to False.
+ force_collision_mesh (bool, optional): Same as force_mesh, but for collision scene. Defaults to True.
+ build_tree (bool, optional): Build the tree structure for global kinematics computation
+ """
+ if filename_handler is None:
+ self._filename_handler = partial(filename_handler_magic, dir=mesh_dir)
+ else:
+ self._filename_handler = filename_handler
+
+ self.robot = robot
+ self._create_maps()
+ self._update_actuated_joints()
+
+ self._cfg = self.zero_cfg
+
+ if build_scene_graph or build_collision_scene_graph:
+ self._base_link = self._determine_base_link()
+ else:
+ self._base_link = None
+
+ self._errors = []
+
+ if build_scene_graph:
+ self._scene = self._create_scene(
+ use_collision_geometry=False,
+ load_geometry=load_meshes,
+ force_mesh=force_mesh,
+ force_single_geometry_per_link=force_mesh,
+ )
+ else:
+ self._scene = None
+
+ if build_collision_scene_graph:
+ self._scene_collision = self._create_scene(
+ use_collision_geometry=True,
+ load_geometry=load_collision_meshes,
+ force_mesh=force_collision_mesh,
+ force_single_geometry_per_link=force_collision_mesh,
+ )
+ else:
+ self._scene_collision = None
+
+ if build_tree:
+ self.tree_root = self.build_tree()
+ else:
+ self.tree_root = None
+
+ @property
+ def scene(self) -> trimesh.Scene:
+ """A scene object representing the URDF model.
+
+ Returns:
+ trimesh.Scene: A trimesh scene object.
+ """
+ return self._scene
+
+ @property
+ def collision_scene(self) -> trimesh.Scene:
+ """A scene object representing the elements of the URDF model
+
+ Returns:
+ trimesh.Scene: A trimesh scene object.
+ """
+ return self._scene_collision
+
+ @property
+ def link_map(self) -> dict:
+ """A dictionary mapping link names to link objects.
+
+ Returns:
+ dict: Mapping from link name (str) to Link.
+ """
+ return self._link_map
+
+ @property
+ def joint_map(self) -> dict:
+ """A dictionary mapping joint names to joint objects.
+
+ Returns:
+ dict: Mapping from joint name (str) to Joint.
+ """
+ return self._joint_map
+
+ @property
+ def joint_names(self):
+ """List of joint names.
+
+ Returns:
+ list[str]: List of joint names of the URDF model.
+ """
+ return [j.name for j in self.robot.joints]
+
+ @property
+ def actuated_joints(self):
+ """List of actuated joints. This excludes mimic and fixed joints.
+
+ Returns:
+ list[Joint]: List of actuated joints of the URDF model.
+ """
+ return self._actuated_joints
+
+ @property
+ def actuated_dof_indices(self):
+ """List of DOF indices per actuated joint. Can be used to reference configuration.
+
+ Returns:
+ list[list[int]]: List of DOF indices per actuated joint.
+ """
+ return self._actuated_dof_indices
+
+ @property
+ def actuated_joint_indices(self):
+ """List of indices of all joints that are actuated, i.e., not of type mimic or fixed.
+
+ Returns:
+ list[int]: List of indices of actuated joints.
+ """
+ return self._actuated_joint_indices
+
+ @property
+ def actuated_joint_names(self):
+ """List of names of actuated joints. This excludes mimic and fixed joints.
+
+ Returns:
+ list[str]: List of names of actuated joints of the URDF model.
+ """
+ return [j.name for j in self._actuated_joints]
+
+ @property
+ def num_actuated_joints(self):
+ """Number of actuated joints.
+
+ Returns:
+ int: Number of actuated joints.
+ """
+ return len(self.actuated_joints)
+
+ @property
+ def num_dofs(self):
+ """Number of degrees of freedom of actuated joints. Depending on the type of the joint, the number of DOFs might vary.
+
+ Returns:
+ int: Degrees of freedom.
+ """
+ total_num_dofs = 0
+ for j in self._actuated_joints:
+ if j.type in ["revolute", "prismatic", "continuous"]:
+ total_num_dofs += 1
+ elif j.type == "floating":
+ total_num_dofs += 6
+ elif j.type == "planar":
+ total_num_dofs += 2
+ return total_num_dofs
+
+ @property
+ def zero_cfg(self):
+ """Return the zero configuration.
+
+ Returns:
+ np.ndarray: The zero configuration.
+ """
+ return np.zeros(self.num_dofs)
+
+ @property
+ def center_cfg(self):
+ """Return center configuration of URDF model by using the average of each joint's limits if present, otherwise zero.
+
+ Returns:
+ (n), float: Default configuration of URDF model.
+ """
+ config = []
+ config_names = []
+ for j in self._actuated_joints:
+ if j.type == "revolute" or j.type == "prismatic":
+ if j.limit is not None:
+ cfg = [j.limit.lower + 0.5 * (j.limit.upper - j.limit.lower)]
+ else:
+ cfg = [0.0]
+ elif j.type == "continuous":
+ cfg = [0.0]
+ elif j.type == "floating":
+ cfg = [0.0] * 6
+ elif j.type == "planar":
+ cfg = [0.0] * 2
+
+ config.append(cfg)
+ config_names.append(j.name)
+
+ for i, j in enumerate(self.robot.joints):
+ if j.mimic is not None:
+ index = config_names.index(j.mimic.joint)
+ config[i][0] = config[index][0] * j.mimic.multiplier + j.mimic.offset
+
+ if len(config) == 0:
+ return np.array([], dtype=np.float64)
+ return np.concatenate(config)
+
+ @property
+ def cfg(self):
+ """Current configuration.
+
+ Returns:
+ np.ndarray: Current configuration of URDF model.
+ """
+ return self._cfg
+
+ @property
+ def base_link(self):
+ """Name of URDF base/root link.
+
+ Returns:
+ str: Name of base link of URDF model.
+ """
+ return self._base_link
+
+ @property
+ def errors(self) -> list:
+ """A list with validation errors.
+
+ Returns:
+ list: A list of validation errors.
+ """
+ return self._errors
+
+ def clear_errors(self):
+ """Clear the validation error log."""
+ self._errors = []
+
+ def show(self, collision_geometry=False, callback=None):
+ """Open a simpler viewer displaying the URDF model.
+
+ Args:
+ collision_geometry (bool, optional): Whether to display the or elements. Defaults to False.
+ """
+ if collision_geometry:
+ if self._scene_collision is None:
+ raise ValueError(
+ "No collision scene available. Use build_collision_scene_graph=True and load_collision_meshes=True during loading."
+ )
+ else:
+ self._scene_collision.show(callback=callback)
+ else:
+ if self._scene is None:
+ raise ValueError("No scene available. Use build_scene_graph=True and load_meshes=True during loading.")
+ elif len(self._scene.bounds_corners) < 1:
+ raise ValueError(
+ "Scene is empty, maybe meshes failed to load? Use build_scene_graph=True and load_meshes=True during loading."
+ )
+ else:
+ self._scene.show(callback=callback)
+
+ def validate(self, validation_fn=None) -> bool:
+ """Validate URDF model.
+
+ Args:
+ validation_fn (function, optional): A function f(list[yourdfpy.URDFError]) -> bool. None uses the strict handler (any error leads to False). Defaults to None.
+
+ Returns:
+ bool: Whether the model is valid.
+ """
+ self._errors = []
+ self._validate_robot(self.robot)
+
+ if validation_fn is None:
+ validation_fn = validation_handler_strict
+
+ return validation_fn(self._errors)
+
+ def _create_maps(self):
+ self._material_map = {}
+ for m in self.robot.materials:
+ self._material_map[m.name] = m
+
+ self._joint_map = {}
+ for j in self.robot.joints:
+ self._joint_map[j.name] = j
+
+ self._link_map = {}
+ for l in self.robot.links:
+ self._link_map[l.name] = l
+
+ def _update_actuated_joints(self):
+ self._actuated_joints = []
+ self._actuated_joint_indices = []
+ self._actuated_dof_indices = []
+
+ dof_indices_cnt = 0
+ for i, j in enumerate(self.robot.joints):
+ if j.mimic is None and j.type != "fixed":
+ self._actuated_joints.append(j)
+ self._actuated_joint_indices.append(i)
+
+ if j.type in ["prismatic", "revolute", "continuous"]:
+ self._actuated_dof_indices.append([dof_indices_cnt])
+ dof_indices_cnt += 1
+ elif j.type == "floating":
+ self._actuated_dof_indices.append([dof_indices_cnt, dof_indices_cnt + 1, dof_indices_cnt + 2])
+ dof_indices_cnt += 3
+ elif j.type == "planar":
+ self._actuated_dof_indices.append([dof_indices_cnt, dof_indices_cnt + 1])
+ dof_indices_cnt += 2
+
+ def _validate_required_attribute(self, attribute, error_msg, allowed_values=None):
+ if attribute is None:
+ self._errors.append(URDFIncompleteError(error_msg))
+ elif isinstance(attribute, str) and len(attribute) == 0:
+ self._errors.append(URDFIncompleteError(error_msg))
+
+ if allowed_values is not None and attribute is not None:
+ if attribute not in allowed_values:
+ self._errors.append(URDFAttributeValueError(error_msg))
+
+ @staticmethod
+ def load(fname_or_file, **kwargs):
+ """Load URDF file from filename or file object.
+
+ Args:
+ fname_or_file (str or file object): A filename or file object, file-like object, stream representing the URDF file.
+ **build_scene_graph (bool, optional): Wheter to build a scene graph to enable transformation queries and forward kinematics. Defaults to True.
+ **build_collision_scene_graph (bool, optional): Wheter to build a scene graph for elements. Defaults to False.
+ **load_meshes (bool, optional): Whether to load the meshes referenced in the elements. Defaults to True.
+ **load_collision_meshes (bool, optional): Whether to load the collision meshes referenced in the elements. Defaults to False.
+ **filename_handler ([type], optional): Any function f(in: str) -> str, that maps filenames in the URDF to actual resources. Can be used to customize treatment of `package://` directives or relative/absolute filenames. Defaults to None.
+ **mesh_dir (str, optional): A root directory used for loading meshes. Defaults to "".
+ **force_mesh (bool, optional): Each loaded geometry will be concatenated into a single one (instead of being turned into a graph; in case the underlying file contains multiple geometries). This might loose texture information but the resulting scene graph will be smaller. Defaults to False.
+ **force_collision_mesh (bool, optional): Same as force_mesh, but for collision scene. Defaults to True.
+
+ Raises:
+ ValueError: If filename does not exist.
+
+ Returns:
+ yourdfpy.URDF: URDF model.
+ """
+ if isinstance(fname_or_file, six.string_types):
+ if not os.path.isfile(fname_or_file):
+ raise ValueError("{} is not a file".format(fname_or_file))
+
+ if not "mesh_dir" in kwargs:
+ kwargs["mesh_dir"] = os.path.dirname(fname_or_file)
+
+ try:
+ parser = etree.XMLParser(remove_blank_text=True)
+ tree = etree.parse(fname_or_file, parser=parser)
+ xml_root = tree.getroot()
+ except Exception as e:
+ _logger.error(e)
+ _logger.error("Using different parsing approach.")
+
+ events = ("start", "end", "start-ns", "end-ns")
+ xml = etree.iterparse(fname_or_file, recover=True, events=events)
+
+ # Iterate through all XML elements
+ for action, elem in xml:
+ # Skip comments and processing instructions,
+ # because they do not have names
+ if not (isinstance(elem, etree._Comment) or isinstance(elem, etree._ProcessingInstruction)):
+ # Remove a namespace URI in the element's name
+ # elem.tag = etree.QName(elem).localname
+ if action == "end" and ":" in elem.tag:
+ elem.getparent().remove(elem)
+
+ xml_root = xml.root
+
+ # Remove comments
+ etree.strip_tags(xml_root, etree.Comment)
+ etree.cleanup_namespaces(xml_root)
+
+ return URDF(robot=URDF._parse_robot(xml_element=xml_root), **kwargs)
+
+ def contains(self, key, value, element=None) -> bool:
+ """Checks recursively whether the URDF tree contains the provided key-value pair.
+
+ Args:
+ key (str): A key.
+ value (str): A value.
+ element (etree.Element, optional): The XML element from which to start the recursive search. None means URDF root. Defaults to None.
+
+ Returns:
+ bool: Whether the key-value pair was found.
+ """
+ if element is None:
+ element = self.robot
+
+ result = False
+ for field in element.__dataclass_fields__:
+ field_value = getattr(element, field)
+ if is_dataclass(field_value):
+ result = result or self.contains(key=key, value=value, element=field_value)
+ elif isinstance(field_value, list) and len(field_value) > 0 and is_dataclass(field_value[0]):
+ for field_value_element in field_value:
+ result = result or self.contains(key=key, value=value, element=field_value_element)
+ else:
+ if key == field and value == field_value:
+ result = True
+ return result
+
+ def _determine_base_link(self):
+ """Get the base link of the URDF tree by extracting all links without parents.
+ In case multiple links could be root chose the first.
+
+ Returns:
+ str: Name of the base link.
+ """
+ link_names = [l.name for l in self.robot.links]
+
+ for j in self.robot.joints:
+ link_names.remove(j.child)
+
+ if len(link_names) == 0:
+ # raise Error?
+ return None
+
+ return link_names[0]
+
+ def _forward_kinematics_joint(self, joint, q=None):
+ origin = np.eye(4) if joint.origin is None else joint.origin
+
+ if joint.mimic is not None:
+ if joint.mimic.joint in self.actuated_joint_names:
+ mimic_joint_index = self.actuated_joint_names.index(joint.mimic.joint)
+ q = self._cfg[mimic_joint_index] * joint.mimic.multiplier + joint.mimic.offset
+ else:
+ _logger.warning(
+ f"Joint '{joint.name}' is supposed to mimic '{joint.mimic.joint}'. But this joint is not actuated - will assume (0.0 + offset)."
+ )
+ q = 0.0 + joint.mimic.offset
+
+ if joint.type in ["revolute", "prismatic", "continuous"]:
+ if q is None:
+ # Use internal cfg vector for forward kinematics
+ q = self.cfg[self.actuated_dof_indices[self.actuated_joint_names.index(joint.name)]]
+
+ if joint.type == "prismatic":
+ matrix = origin @ tra.translation_matrix(q * joint.axis)
+ else:
+ matrix = origin @ tra.rotation_matrix(q, joint.axis)
+ else:
+ # this includes: floating, planar, fixed
+ matrix = origin
+
+ return matrix, q
+
+ def update_cfg(self, configuration):
+ """Update joint configuration of URDF; does forward kinematics.
+
+ Args:
+ configuration (dict, list[float], tuple[float] or np.ndarray): A mapping from joints or joint names to configuration values, or a list containing a value for each actuated joint.
+
+ Raises:
+ ValueError: Raised if dimensionality of configuration does not match number of actuated joints of URDF model.
+ TypeError: Raised if configuration is neither a dict, list, tuple or np.ndarray.
+ """
+ joint_cfg = []
+
+ if isinstance(configuration, dict):
+ for joint in configuration:
+ if isinstance(joint, six.string_types):
+ joint_cfg.append((self._joint_map[joint], configuration[joint]))
+ elif isinstance(joint, Joint):
+ # TODO: Joint is not hashable; so this branch will not succeed
+ joint_cfg.append((joint, configuration[joint]))
+ elif isinstance(configuration, (list, tuple, np.ndarray)):
+ if len(configuration) == len(self.robot.joints):
+ for joint, value in zip(self.robot.joints, configuration):
+ joint_cfg.append((joint, value))
+ elif len(configuration) == self.num_actuated_joints:
+ for joint, value in zip(self._actuated_joints, configuration):
+ joint_cfg.append((joint, value))
+ else:
+ raise ValueError(
+ f"Dimensionality of configuration ({len(configuration)}) doesn't match number of all ({len(self.robot.joints)}) or actuated joints ({self.num_actuated_joints})."
+ )
+ else:
+ raise TypeError("Invalid type for configuration")
+
+ # append all mimic joints in the update
+ for j, q in joint_cfg + [(j, 0.0) for j in self.robot.joints if j.mimic is not None]:
+ matrix, joint_q = self._forward_kinematics_joint(j, q=q)
+
+ # update internal configuration vector - only consider actuated joints
+ if j.name in self.actuated_joint_names:
+ self._cfg[self.actuated_dof_indices[self.actuated_joint_names.index(j.name)]] = joint_q
+
+ if self._scene is not None:
+ self._scene.graph.update(frame_from=j.parent, frame_to=j.child, matrix=matrix)
+ if self._scene_collision is not None:
+ self._scene_collision.graph.update(frame_from=j.parent, frame_to=j.child, matrix=matrix)
+
+ def get_transform(self, frame_to, frame_from=None, collision_geometry=False):
+ """Get the transform from one frame to another.
+
+ Args:
+ frame_to (str): Node name.
+ frame_from (str, optional): Node name. If None it will be set to self.base_frame. Defaults to None.
+ collision_geometry (bool, optional): Whether to use the collision geometry scene graph (instead of the visual geometry). Defaults to False.
+
+ Raises:
+ ValueError: Raised if scene graph wasn't constructed during intialization.
+
+ Returns:
+ (4, 4) float: Homogeneous transformation matrix
+ """
+ if collision_geometry:
+ if self._scene_collision is None:
+ raise ValueError("No collision scene available. Use build_collision_scene_graph=True during loading.")
+ else:
+ return self._scene_collision.graph.get(frame_to=frame_to, frame_from=frame_from)[0]
+ else:
+ if self._scene is None:
+ raise ValueError("No scene available. Use build_scene_graph=True during loading.")
+ else:
+ return self._scene.graph.get(frame_to=frame_to, frame_from=frame_from)[0]
+
+ def _link_mesh(self, link, collision_geometry=True):
+ geometries = link.collisions if collision_geometry else link.visuals
+
+ if len(geometries) == 0:
+ return None
+
+ meshes = []
+ for g in geometries:
+ for m in g.geometry.meshes:
+ m = m.copy()
+ pose = g.origin
+ if g.geometry.mesh is not None:
+ if g.geometry.mesh.scale is not None:
+ S = np.eye(4)
+ S[:3, :3] = np.diag(g.geometry.mesh.scale)
+ pose = pose.dot(S)
+ m.apply_transform(pose)
+ meshes.append(m)
+ if len(meshes) == 0:
+ return None
+ self._collision_mesh = meshes[0] + meshes[1:]
+ return self._collision_mesh
+
+ def _geometry2trimeshscene(self, geometry, load_file, force_mesh, skip_materials):
+ new_s = None
+ if geometry.box is not None:
+ new_s = trimesh.primitives.Box(extents=geometry.box.size).scene()
+ elif geometry.sphere is not None:
+ new_s = trimesh.primitives.Sphere(radius=geometry.sphere.radius).scene()
+ elif geometry.cylinder is not None:
+ new_s = trimesh.primitives.Cylinder(
+ radius=geometry.cylinder.radius, height=geometry.cylinder.length
+ ).scene()
+ elif geometry.mesh is not None and load_file:
+ new_filename = self._filename_handler(fname=geometry.mesh.filename)
+
+ if os.path.isfile(new_filename):
+ _logger.debug(f"Loading {geometry.mesh.filename} as {new_filename}")
+
+ if force_mesh:
+ new_g = trimesh.load(
+ new_filename,
+ ignore_broken=True,
+ force="mesh",
+ skip_materials=skip_materials,
+ )
+
+ # add original filename
+ if "file_path" not in new_g.metadata:
+ new_g.metadata["file_path"] = os.path.abspath(new_filename)
+ new_g.metadata["file_name"] = os.path.basename(new_filename)
+
+ new_s = trimesh.Scene()
+ new_s.add_geometry(new_g)
+ else:
+ new_s = trimesh.load(
+ new_filename,
+ ignore_broken=True,
+ force="scene",
+ skip_materials=skip_materials,
+ )
+
+ if "file_path" in new_s.metadata:
+ for i, (_, geom) in enumerate(new_s.geometry.items()):
+ if "file_path" not in geom.metadata:
+ geom.metadata["file_path"] = new_s.metadata["file_path"]
+ geom.metadata["file_name"] = new_s.metadata["file_name"]
+ geom.metadata["file_element"] = i
+
+ # scale mesh appropriately
+ if geometry.mesh.scale is not None:
+ if isinstance(geometry.mesh.scale, float):
+ new_s = new_s.scaled(geometry.mesh.scale)
+ elif isinstance(geometry.mesh.scale, np.ndarray):
+ new_s = new_s.scaled(geometry.mesh.scale)
+ else:
+ _logger.warning(f"Warning: Can't interpret scale '{geometry.mesh.scale}'")
+ else:
+ _logger.warning(f"Can't find {new_filename}")
+ return new_s
+
+ def _add_geometries_to_scene(
+ self,
+ s,
+ geometries,
+ link_name,
+ load_geometry,
+ force_mesh,
+ force_single_geometry,
+ skip_materials,
+ ):
+ if force_single_geometry:
+ tmp_scene = trimesh.Scene(base_frame=link_name)
+
+ first_geom_name = None
+
+ for v in geometries:
+ if v.geometry is not None:
+ if first_geom_name is None:
+ first_geom_name = v.name
+
+ new_s = self._geometry2trimeshscene(
+ geometry=v.geometry,
+ load_file=load_geometry,
+ force_mesh=force_mesh,
+ skip_materials=skip_materials,
+ )
+ if new_s is not None:
+ origin = v.origin if v.origin is not None else np.eye(4)
+
+ if force_single_geometry:
+ for name, geom in new_s.geometry.items():
+ if isinstance(v, Visual):
+ apply_visual_color(geom, v, self._material_map)
+ tmp_scene.add_geometry(
+ geometry=geom,
+ geom_name=v.name,
+ parent_node_name=link_name,
+ transform=origin @ new_s.graph.get(name)[0],
+ )
+ else:
+ # The following map is used to deal with glb format
+ # when the graph node and geometry have different names
+ geom_name_map = {new_s.graph[node_name][1]: node_name for node_name in new_s.graph.nodes}
+ for name, geom in new_s.geometry.items():
+ if isinstance(v, Visual):
+ apply_visual_color(geom, v, self._material_map)
+ s.add_geometry(
+ geometry=geom,
+ geom_name=v.name,
+ parent_node_name=link_name,
+ transform=origin @ new_s.graph.get(geom_name_map[name])[0],
+ )
+
+ if force_single_geometry and len(tmp_scene.geometry) > 0:
+ s.add_geometry(
+ geometry=tmp_scene.dump(concatenate=True),
+ geom_name=first_geom_name,
+ parent_node_name=link_name,
+ transform=np.eye(4),
+ )
+
+ def _create_scene(
+ self,
+ use_collision_geometry=False,
+ load_geometry=True,
+ force_mesh=False,
+ force_single_geometry_per_link=False,
+ ):
+ s = trimesh.scene.Scene(base_frame=self._base_link)
+
+ for j in self.robot.joints:
+ matrix, _ = self._forward_kinematics_joint(j)
+
+ s.graph.update(frame_from=j.parent, frame_to=j.child, matrix=matrix)
+
+ for l in self.robot.links:
+ if l.name not in s.graph.nodes and l.name != s.graph.base_frame:
+ _logger.warning(f"{l.name} not connected via joints. Will add link to base frame.")
+ s.graph.update(frame_from=s.graph.base_frame, frame_to=l.name)
+
+ meshes = l.collisions if use_collision_geometry else l.visuals
+ self._add_geometries_to_scene(
+ s,
+ geometries=meshes,
+ link_name=l.name,
+ load_geometry=load_geometry,
+ force_mesh=force_mesh,
+ force_single_geometry=force_single_geometry_per_link,
+ skip_materials=use_collision_geometry,
+ )
+
+ return s
+
+ def _successors(self, node):
+ """
+ Get all nodes of the scene that succeeds a specified node.
+
+ Parameters
+ ------------
+ node : any
+ Hashable key in `scene.graph`
+
+ Returns
+ -----------
+ subnodes : set[str]
+ Set of nodes.
+ """
+ # get every node that is a successor to specified node
+ # this includes `node`
+ return self._scene.graph.transforms.successors(node)
+
+ def _create_subrobot(self, robot_name, root_link_name):
+ subrobot = Robot(name=robot_name)
+ subnodes = self._successors(node=root_link_name)
+
+ if len(subnodes) > 0:
+ for node in subnodes:
+ if node in self.link_map:
+ subrobot.links.append(copy.deepcopy(self.link_map[node]))
+ for joint_name, joint in self.joint_map.items():
+ if joint.parent in subnodes and joint.child in subnodes:
+ subrobot.joints.append(copy.deepcopy(self.joint_map[joint_name]))
+
+ return subrobot
+
+ def split_along_joints(self, joint_type="floating", **kwargs):
+ """Split URDF model along a particular joint type.
+ The result is a set of URDF models which together compose the original URDF.
+
+ Args:
+ joint_type (str, or list[str], optional): Type of joint to use for splitting. Defaults to "floating".
+ **kwargs: Arguments delegated to URDF constructor of new URDF models.
+
+ Returns:
+ list[(np.ndarray, yourdfpy.URDF)]: A list of tuples (np.ndarray, yourdfpy.URDF) whereas each homogeneous 4x4 matrix describes the root transformation of the respective URDF model w.r.t. the original URDF.
+ """
+ root_urdf = URDF(robot=copy.deepcopy(self.robot), build_scene_graph=False, load_meshes=False)
+ result = []
+
+ joint_types = joint_type if isinstance(joint_type, list) else [joint_type]
+
+ # find all relevant joints
+ joint_names = [j.name for j in self.robot.joints if j.type in joint_types]
+ for joint_name in joint_names:
+ root_link = self.link_map[self.joint_map[joint_name].child]
+ new_robot = self._create_subrobot(
+ robot_name=root_link.name,
+ root_link_name=root_link.name,
+ )
+
+ result.append(
+ (
+ self._scene.graph.get(root_link.name)[0],
+ URDF(robot=new_robot, **kwargs),
+ )
+ )
+
+ # remove links and joints from root robot
+ for j in new_robot.joints:
+ root_urdf.robot.joints.remove(root_urdf.joint_map[j.name])
+ for l in new_robot.links:
+ root_urdf.robot.links.remove(root_urdf.link_map[l.name])
+
+ # remove joint that connects root urdf to root_link
+ if root_link.name in [j.child for j in root_urdf.robot.joints]:
+ root_urdf.robot.joints.remove(
+ root_urdf.robot.joints[[j.child for j in root_urdf.robot.joints].index(root_link.name)]
+ )
+
+ result.insert(0, (np.eye(4), URDF(robot=root_urdf.robot, **kwargs)))
+
+ return result
+
+ def validate_filenames(self):
+ for l in self.robot.links:
+ meshes = [m.geometry.mesh for m in l.collisions + l.visuals if m.geometry.mesh is not None]
+ for m in meshes:
+ _logger.debug(m.filename, "-->", self._filename_handler(m.filename))
+ if not os.path.isfile(self._filename_handler(m.filename)):
+ return False
+ return True
+
+ def write_xml(self):
+ """Write URDF model to an XML element hierarchy.
+
+ Returns:
+ etree.ElementTree: XML data.
+ """
+ xml_element = self._write_robot(self.robot)
+ return etree.ElementTree(xml_element)
+
+ def write_xml_string(self, **kwargs):
+ """Write URDF model to a string.
+
+ Returns:
+ str: String of the xml representation of the URDF model.
+ """
+ xml_element = self.write_xml()
+ return etree.tostring(xml_element, xml_declaration=True, *kwargs)
+
+ def write_xml_file(self, fname):
+ """Write URDF model to an xml file.
+
+ Args:
+ fname (str): Filename of the file to be written. Usually ends in `.urdf`.
+ """
+ xml_element = self.write_xml()
+ xml_element.write(fname, xml_declaration=True, pretty_print=True)
+
+ def _parse_mimic(xml_element):
+ if xml_element is None:
+ return None
+
+ return Mimic(
+ joint=xml_element.get("joint"),
+ multiplier=_str2float(xml_element.get("multiplier", 1.0)),
+ offset=_str2float(xml_element.get("offset", 0.0)),
+ )
+
+ def _write_mimic(self, xml_parent, mimic):
+ etree.SubElement(
+ xml_parent,
+ "mimic",
+ attrib={
+ "joint": mimic.joint,
+ "multiplier": str(mimic.multiplier),
+ "offset": str(mimic.offset),
+ },
+ )
+
+ def _parse_safety_controller(xml_element):
+ if xml_element is None:
+ return None
+
+ return SafetyController(
+ soft_lower_limit=_str2float(xml_element.get("soft_lower_limit")),
+ soft_upper_limit=_str2float(xml_element.get("soft_upper_limit")),
+ k_position=_str2float(xml_element.get("k_position")),
+ k_velocity=_str2float(xml_element.get("k_velocity")),
+ )
+
+ def _write_safety_controller(self, xml_parent, safety_controller):
+ etree.SubElement(
+ xml_parent,
+ "safety_controller",
+ attrib={
+ "soft_lower_limit": str(safety_controller.soft_lower_limit),
+ "soft_upper_limit": str(safety_controller.soft_upper_limit),
+ "k_position": str(safety_controller.k_position),
+ "k_velocity": str(safety_controller.k_velocity),
+ },
+ )
+
+ def _parse_transmission_joint(xml_element):
+ if xml_element is None:
+ return None
+
+ transmission_joint = TransmissionJoint(name=xml_element.get("name"))
+
+ for h in xml_element.findall("hardware_interface"):
+ transmission_joint.hardware_interfaces.append(h.text)
+
+ return transmission_joint
+
+ def _write_transmission_joint(self, xml_parent, transmission_joint):
+ xml_element = etree.SubElement(
+ xml_parent,
+ "joint",
+ attrib={
+ "name": str(transmission_joint.name),
+ },
+ )
+ for h in transmission_joint.hardware_interfaces:
+ tmp = etree.SubElement(
+ xml_element,
+ "hardwareInterface",
+ )
+ tmp.text = h
+
+ def _parse_actuator(xml_element):
+ if xml_element is None:
+ return None
+
+ actuator = Actuator(name=xml_element.get("name"))
+ if xml_element.find("mechanicalReduction"):
+ actuator.mechanical_reduction = float(xml_element.find("mechanicalReduction").text)
+
+ for h in xml_element.findall("hardwareInterface"):
+ actuator.hardware_interfaces.append(h.text)
+
+ return actuator
+
+ def _write_actuator(self, xml_parent, actuator):
+ xml_element = etree.SubElement(
+ xml_parent,
+ "actuator",
+ attrib={
+ "name": str(actuator.name),
+ },
+ )
+ if actuator.mechanical_reduction is not None:
+ tmp = etree.SubElement("mechanicalReduction")
+ tmp.text = str(actuator.mechanical_reduction)
+
+ for h in actuator.hardware_interfaces:
+ tmp = etree.SubElement(
+ xml_element,
+ "hardwareInterface",
+ )
+ tmp.text = h
+
+ def _parse_transmission(xml_element):
+ if xml_element is None:
+ return None
+
+ transmission = Transmission(name=xml_element.get("name"))
+
+ for j in xml_element.findall("joint"):
+ transmission.joints.append(URDF._parse_transmission_joint(j))
+ for a in xml_element.findall("actuator"):
+ transmission.actuators.append(URDF._parse_actuator(a))
+
+ return transmission
+
+ def _write_transmission(self, xml_parent, transmission):
+ xml_element = etree.SubElement(
+ xml_parent,
+ "transmission",
+ attrib={
+ "name": str(transmission.name),
+ },
+ )
+
+ for j in transmission.joints:
+ self._write_transmission_joint(xml_element, j)
+
+ for a in transmission.actuators:
+ self._write_actuator(xml_element, a)
+
+ def _parse_calibration(xml_element):
+ if xml_element is None:
+ return None
+
+ return Calibration(
+ rising=_str2float(xml_element.get("rising")),
+ falling=_str2float(xml_element.get("falling")),
+ )
+
+ def _write_calibration(self, xml_parent, calibration):
+ etree.SubElement(
+ xml_parent,
+ "calibration",
+ attrib={
+ "rising": str(calibration.rising),
+ "falling": str(calibration.falling),
+ },
+ )
+
+ def _parse_box(xml_element):
+ return Box(size=np.array(xml_element.attrib["size"].split(), dtype=float))
+
+ def _write_box(self, xml_parent, box):
+ etree.SubElement(xml_parent, "box", attrib={"size": " ".join(map(str, box.size))})
+
+ def _parse_cylinder(xml_element):
+ return Cylinder(
+ radius=float(xml_element.attrib["radius"]),
+ length=float(xml_element.attrib["length"]),
+ )
+
+ def _write_cylinder(self, xml_parent, cylinder):
+ etree.SubElement(
+ xml_parent,
+ "cylinder",
+ attrib={"radius": str(cylinder.radius), "length": str(cylinder.length)},
+ )
+
+ def _parse_sphere(xml_element):
+ return Sphere(radius=float(xml_element.attrib["radius"]))
+
+ def _write_sphere(self, xml_parent, sphere):
+ etree.SubElement(xml_parent, "sphere", attrib={"radius": str(sphere.radius)})
+
+ def _parse_scale(xml_element):
+ if "scale" in xml_element.attrib:
+ s = xml_element.get("scale").split()
+ if len(s) == 0:
+ return None
+ elif len(s) == 1:
+ return float(s[0])
+ else:
+ return np.array(list(map(float, s)))
+ return None
+
+ def _write_scale(self, xml_parent, scale):
+ if scale is not None:
+ if isinstance(scale, float) or isinstance(scale, int):
+ xml_parent.set("scale", " ".join([str(scale)] * 3))
+ else:
+ xml_parent.set("scale", " ".join(map(str, scale)))
+
+ def _parse_mesh(xml_element):
+ return Mesh(filename=xml_element.get("filename"), scale=URDF._parse_scale(xml_element))
+
+ def _write_mesh(self, xml_parent, mesh):
+ # TODO: turn into different filename handler
+ xml_element = etree.SubElement(
+ xml_parent,
+ "mesh",
+ attrib={"filename": self._filename_handler(mesh.filename)},
+ )
+
+ self._write_scale(xml_element, mesh.scale)
+
+ def _parse_geometry(xml_element):
+ geometry = Geometry()
+ if xml_element[0].tag == "box":
+ geometry.box = URDF._parse_box(xml_element[0])
+ elif xml_element[0].tag == "cylinder":
+ geometry.cylinder = URDF._parse_cylinder(xml_element[0])
+ elif xml_element[0].tag == "sphere":
+ geometry.sphere = URDF._parse_sphere(xml_element[0])
+ elif xml_element[0].tag == "mesh":
+ geometry.mesh = URDF._parse_mesh(xml_element[0])
+ else:
+ raise ValueError(f"Unknown tag: {xml_element[0].tag}")
+
+ return geometry
+
+ def _validate_geometry(self, geometry):
+ if geometry is None:
+ self._errors.append(URDFIncompleteError(" is missing."))
+
+ num_nones = sum(
+ [
+ x is not None
+ for x in [
+ geometry.box,
+ geometry.cylinder,
+ geometry.sphere,
+ geometry.mesh,
+ ]
+ ]
+ )
+ if num_nones < 1:
+ self._errors.append(
+ URDFIncompleteError(
+ "One of , , , needs to be defined as a child of ."
+ )
+ )
+ elif num_nones > 1:
+ self._errors.append(
+ URDFError(
+ "Too many of , , , defined as a child of . Only one allowed."
+ )
+ )
+
+ def _write_geometry(self, xml_parent, geometry):
+ if geometry is None:
+ return
+
+ xml_element = etree.SubElement(xml_parent, "geometry")
+ if geometry.box is not None:
+ self._write_box(xml_element, geometry.box)
+ elif geometry.cylinder is not None:
+ self._write_cylinder(xml_element, geometry.cylinder)
+ elif geometry.sphere is not None:
+ self._write_sphere(xml_element, geometry.sphere)
+ elif geometry.mesh is not None:
+ self._write_mesh(xml_element, geometry.mesh)
+
+ def _parse_origin(xml_element):
+ if xml_element is None:
+ return None
+
+ xyz = xml_element.get("xyz", default="0 0 0")
+ rpy = xml_element.get("rpy", default="0 0 0")
+
+ return tra.compose_matrix(
+ translate=np.array(list(map(float, xyz.split()))),
+ angles=np.array(list(map(float, rpy.split()))),
+ )
+
+ def _write_origin(self, xml_parent, origin):
+ if origin is None:
+ return
+
+ etree.SubElement(
+ xml_parent,
+ "origin",
+ attrib={
+ "xyz": " ".join(map(str, tra.translation_from_matrix(origin))),
+ "rpy": " ".join(map(str, tra.euler_from_matrix(origin))),
+ },
+ )
+
+ def _parse_color(xml_element):
+ if xml_element is None:
+ return None
+
+ rgba = xml_element.get("rgba", default="1 1 1 1")
+
+ return Color(rgba=np.array(list(map(float, rgba.split()))))
+
+ def _write_color(self, xml_parent, color):
+ if color is None:
+ return
+
+ etree.SubElement(xml_parent, "color", attrib={"rgba": " ".join(map(str, color.rgba))})
+
+ def _parse_texture(xml_element):
+ if xml_element is None:
+ return None
+
+ # TODO: use texture filename handler
+ return Texture(filename=xml_element.get("filename", default=None))
+
+ def _write_texture(self, xml_parent, texture):
+ if texture is None:
+ return
+
+ # TODO: use texture filename handler
+ etree.SubElement(xml_parent, "texture", attrib={"filename": texture.filename})
+
+ def _parse_material(xml_element):
+ if xml_element is None:
+ return None
+
+ material = Material(name=xml_element.get("name"))
+ material.color = URDF._parse_color(xml_element.find("color"))
+ material.texture = URDF._parse_texture(xml_element.find("texture"))
+
+ return material
+
+ def _write_material(self, xml_parent, material):
+ if material is None:
+ return
+
+ attrib = {"name": material.name} if material.name is not None else {}
+ xml_element = etree.SubElement(
+ xml_parent,
+ "material",
+ attrib=attrib,
+ )
+
+ self._write_color(xml_element, material.color)
+ self._write_texture(xml_element, material.texture)
+
+ def _parse_visual(xml_element):
+ visual = Visual(name=xml_element.get("name"))
+
+ visual.geometry = URDF._parse_geometry(xml_element.find("geometry"))
+ visual.origin = URDF._parse_origin(xml_element.find("origin"))
+ visual.material = URDF._parse_material(xml_element.find("material"))
+
+ return visual
+
+ def _validate_visual(self, visual):
+ self._validate_geometry(visual.geometry)
+
+ def _write_visual(self, xml_parent, visual):
+ attrib = {"name": visual.name} if visual.name is not None else {}
+ xml_element = etree.SubElement(
+ xml_parent,
+ "visual",
+ attrib=attrib,
+ )
+
+ self._write_geometry(xml_element, visual.geometry)
+ self._write_origin(xml_element, visual.origin)
+ self._write_material(xml_element, visual.material)
+
+ def _parse_collision(xml_element):
+ collision = Collision(name=xml_element.get("name"))
+
+ collision.geometry = URDF._parse_geometry(xml_element.find("geometry"))
+ collision.origin = URDF._parse_origin(xml_element.find("origin"))
+
+ return collision
+
+ def _validate_collision(self, collision):
+ self._validate_geometry(collision.geometry)
+
+ def _write_collision(self, xml_parent, collision):
+ attrib = {"name": collision.name} if collision.name is not None else {}
+ xml_element = etree.SubElement(
+ xml_parent,
+ "collision",
+ attrib=attrib,
+ )
+
+ self._write_geometry(xml_element, collision.geometry)
+ self._write_origin(xml_element, collision.origin)
+
+ def _parse_inertia(xml_element):
+ if xml_element is None:
+ return None
+
+ x = xml_element
+
+ return np.array(
+ [
+ [
+ x.get("ixx", default=1.0),
+ x.get("ixy", default=0.0),
+ x.get("ixz", default=0.0),
+ ],
+ [
+ x.get("ixy", default=0.0),
+ x.get("iyy", default=1.0),
+ x.get("iyz", default=0.0),
+ ],
+ [
+ x.get("ixz", default=0.0),
+ x.get("iyz", default=0.0),
+ x.get("izz", default=1.0),
+ ],
+ ],
+ dtype=np.float64,
+ )
+
+ def _write_inertia(self, xml_parent, inertia):
+ if inertia is None:
+ return None
+
+ etree.SubElement(
+ xml_parent,
+ "inertia",
+ attrib={
+ "ixx": str(inertia[0, 0]),
+ "ixy": str(inertia[0, 1]),
+ "ixz": str(inertia[0, 2]),
+ "iyy": str(inertia[1, 1]),
+ "iyz": str(inertia[1, 2]),
+ "izz": str(inertia[2, 2]),
+ },
+ )
+
+ def _parse_mass(xml_element):
+ if xml_element is None:
+ return None
+
+ return _str2float(xml_element.get("value", default=0.0))
+
+ def _write_mass(self, xml_parent, mass):
+ if mass is None:
+ return
+
+ etree.SubElement(
+ xml_parent,
+ "mass",
+ attrib={
+ "value": str(mass),
+ },
+ )
+
+ def _parse_inertial(xml_element):
+ if xml_element is None:
+ return None
+
+ inertial = Inertial()
+ inertial.origin = URDF._parse_origin(xml_element.find("origin"))
+ inertial.inertia = URDF._parse_inertia(xml_element.find("inertia"))
+ inertial.mass = URDF._parse_mass(xml_element.find("mass"))
+
+ return inertial
+
+ def _write_inertial(self, xml_parent, inertial):
+ if inertial is None:
+ return
+
+ xml_element = etree.SubElement(xml_parent, "inertial")
+
+ self._write_origin(xml_element, inertial.origin)
+ self._write_mass(xml_element, inertial.mass)
+ self._write_inertia(xml_element, inertial.inertia)
+
+ def _parse_link(xml_element):
+ link = Link(name=xml_element.attrib["name"])
+
+ link.inertial = URDF._parse_inertial(xml_element.find("inertial"))
+
+ for v in xml_element.findall("visual"):
+ link.visuals.append(URDF._parse_visual(v))
+
+ for c in xml_element.findall("collision"):
+ link.collisions.append(URDF._parse_collision(c))
+
+ return link
+
+ def _validate_link(self, link):
+ self._validate_required_attribute(attribute=link.name, error_msg="The tag misses a 'name' attribute.")
+
+ for v in link.visuals:
+ self._validate_visual(v)
+
+ for c in link.collisions:
+ self._validate_collision(c)
+
+ def _write_link(self, xml_parent, link):
+ xml_element = etree.SubElement(
+ xml_parent,
+ "link",
+ attrib={
+ "name": link.name,
+ },
+ )
+
+ self._write_inertial(xml_element, link.inertial)
+ for visual in link.visuals:
+ self._write_visual(xml_element, visual)
+ for collision in link.collisions:
+ self._write_collision(xml_element, collision)
+
+ def _parse_axis(xml_element):
+ if xml_element is None:
+ return np.array([1.0, 0, 0])
+
+ xyz = xml_element.get("xyz", "1 0 0")
+ results = []
+ for x in xyz.split():
+ try:
+ x = float(x)
+ except ValueError:
+ x = 0
+ results.append(x)
+ return np.array(results)
+ # return np.array(list(map(float, xyz.split())))
+
+ def _write_axis(self, xml_parent, axis):
+ if axis is None:
+ return
+
+ etree.SubElement(xml_parent, "axis", attrib={"xyz": " ".join(map(str, axis))})
+
+ def _parse_limit(xml_element):
+ if xml_element is None:
+ return None
+
+ return Limit(
+ effort=_str2float(xml_element.get("effort", default=None)),
+ velocity=_str2float(xml_element.get("velocity", default=None)),
+ lower=_str2float(xml_element.get("lower", default=None)),
+ upper=_str2float(xml_element.get("upper", default=None)),
+ )
+
+ def _validate_limit(self, limit, type):
+ if type in ["revolute", "prismatic"]:
+ self._validate_required_attribute(
+ limit,
+ error_msg="The of a (prismatic, revolute) joint is missing.",
+ )
+
+ if limit is not None:
+ self._validate_required_attribute(
+ limit.upper,
+ error_msg="Tag of joint is missing attribute 'upper'.",
+ )
+ self._validate_required_attribute(
+ limit.lower,
+ error_msg="Tag of joint is missing attribute 'lower'.",
+ )
+
+ if limit is not None:
+ self._validate_required_attribute(
+ limit.effort,
+ error_msg="Tag of joint is missing attribute 'effort'.",
+ )
+
+ self._validate_required_attribute(
+ limit.velocity,
+ error_msg="Tag of joint is missing attribute 'velocity'.",
+ )
+
+ def _write_limit(self, xml_parent, limit):
+ if limit is None:
+ return
+
+ attrib = {}
+ if limit.effort is not None:
+ attrib["effort"] = str(limit.effort)
+ if limit.velocity is not None:
+ attrib["velocity"] = str(limit.velocity)
+ if limit.lower is not None:
+ attrib["lower"] = str(limit.lower)
+ if limit.upper is not None:
+ attrib["upper"] = str(limit.upper)
+
+ etree.SubElement(
+ xml_parent,
+ "limit",
+ attrib=attrib,
+ )
+
+ def _parse_dynamics(xml_element):
+ if xml_element is None:
+ return None
+
+ dynamics = Dynamics()
+ dynamics.damping = xml_element.get("damping", default=None)
+ dynamics.friction = xml_element.get("friction", default=None)
+
+ return dynamics
+
+ def _write_dynamics(self, xml_parent, dynamics):
+ if dynamics is None:
+ return
+
+ attrib = {}
+ if dynamics.damping is not None:
+ attrib["damping"] = str(dynamics.damping)
+ if dynamics.friction is not None:
+ attrib["friction"] = str(dynamics.friction)
+
+ etree.SubElement(
+ xml_parent,
+ "dynamics",
+ attrib=attrib,
+ )
+
+ def _parse_joint(xml_element):
+ joint = Joint(name=xml_element.attrib["name"])
+
+ joint.type = xml_element.get("type", default=None)
+ joint.parent = xml_element.find("parent").get("link")
+ joint.child = xml_element.find("child").get("link")
+ joint.origin = URDF._parse_origin(xml_element.find("origin"))
+ joint.axis = URDF._parse_axis(xml_element.find("axis"))
+ joint.limit = URDF._parse_limit(xml_element.find("limit"))
+ joint.dynamics = URDF._parse_dynamics(xml_element.find("dynamics"))
+ joint.mimic = URDF._parse_mimic(xml_element.find("mimic"))
+ joint.calibration = URDF._parse_calibration(xml_element.find("calibration"))
+ joint.safety_controller = URDF._parse_safety_controller(xml_element.find("safety_controller"))
+
+ return joint
+
+ def _validate_joint(self, joint):
+ self._validate_required_attribute(
+ attribute=joint.name,
+ error_msg="The tag misses a 'name' attribute.",
+ )
+
+ allowed_types = [
+ "revolute",
+ "continuous",
+ "prismatic",
+ "fixed",
+ "floating",
+ "planar",
+ ]
+ self._validate_required_attribute(
+ attribute=joint.type,
+ error_msg=f"The tag misses a 'type' attribute or value is not part of allowed values [{', '.join(allowed_types)}].",
+ allowed_values=allowed_types,
+ )
+
+ self._validate_required_attribute(
+ joint.parent,
+ error_msg=f"The of a is missing.",
+ )
+
+ self._validate_required_attribute(
+ joint.child,
+ error_msg=f"The of a is missing.",
+ )
+
+ self._validate_limit(joint.limit, type=joint.type)
+
+ def _write_joint(self, xml_parent, joint):
+ xml_element = etree.SubElement(
+ xml_parent,
+ "joint",
+ attrib={
+ "name": joint.name,
+ "type": joint.type,
+ },
+ )
+
+ etree.SubElement(xml_element, "parent", attrib={"link": joint.parent})
+ etree.SubElement(xml_element, "child", attrib={"link": joint.child})
+ self._write_origin(xml_element, joint.origin)
+ self._write_axis(xml_element, joint.axis)
+ self._write_limit(xml_element, joint.limit)
+ self._write_dynamics(xml_element, joint.dynamics)
+
+ @staticmethod
+ def _parse_robot(xml_element):
+ robot = Robot(name=xml_element.attrib["name"])
+
+ for l in xml_element.findall("link"):
+ robot.links.append(URDF._parse_link(l))
+ for j in xml_element.findall("joint"):
+ robot.joints.append(URDF._parse_joint(j))
+ for m in xml_element.findall("material"):
+ robot.materials.append(URDF._parse_material(m))
+ return robot
+
+ def _validate_robot(self, robot):
+ if robot is not None:
+ self._validate_required_attribute(
+ attribute=robot.name,
+ error_msg="The tag misses a 'name' attribute.",
+ )
+
+ for l in robot.links:
+ self._validate_link(l)
+
+ for j in robot.joints:
+ self._validate_joint(j)
+
+ def _write_robot(self, robot):
+ xml_element = etree.Element("robot", attrib={"name": robot.name})
+ for link in robot.links:
+ self._write_link(xml_element, link)
+ for joint in robot.joints:
+ self._write_joint(xml_element, joint)
+ for material in robot.materials:
+ self._write_material(xml_element, material)
+
+ return xml_element
+
+ def __eq__(self, other):
+ if not isinstance(other, URDF):
+ raise NotImplemented
+ return self.robot == other.robot
+
+ @property
+ def filename_handler(self):
+ return self._filename_handler
+
+ def build_tree(self):
+ parent_child_map: Dict[str, List[str]] = {}
+ for joint in self.robot.joints:
+ if joint.parent in parent_child_map:
+ parent_child_map[joint.parent].append(joint.child)
+ else:
+ parent_child_map[joint.parent] = [joint.child]
+
+ # Sort link with bfs order
+ bfs_link_list = [self.base_link]
+ to_be_handle_list = [self.base_link]
+ while len(to_be_handle_list) > 0:
+ parent = to_be_handle_list.pop(0)
+ if parent not in parent_child_map:
+ continue
+
+ children = parent_child_map[parent]
+ to_be_handle_list.extend(children)
+ bfs_link_list.extend(children)
+ bfs_joint_list = []
+ for link_name in bfs_link_list[1:]:
+ joint_index = [i for i in range(len(self.robot.joints)) if self.robot.joints[i].child == link_name][0]
+ bfs_joint_list.append(self.robot.joints[joint_index])
+
+ # Build tree
+ root = Node(self.base_link, matrix=np.eye(4))
+ for joint in bfs_joint_list:
+ matrix, _ = self._forward_kinematics_joint(joint, 0)
+ parent_node = anytree.search.findall_by_attr(root, value=joint.parent)[0]
+ node = Node(joint.child, parent=parent_node, matrix=matrix)
+ return root
+
+ def update_kinematics(self, configuration):
+ joint_cfg = []
+
+ if isinstance(configuration, dict):
+ for joint in configuration:
+ if isinstance(joint, six.string_types):
+ joint_cfg.append((self._joint_map[joint], configuration[joint]))
+ elif isinstance(joint, Joint):
+ # TODO: Joint is not hashable; so this branch will not succeed
+ joint_cfg.append((joint, configuration[joint]))
+ elif isinstance(configuration, (list, tuple, np.ndarray)):
+ if len(configuration) == len(self.robot.joints):
+ for joint, value in zip(self.robot.joints, configuration):
+ joint_cfg.append((joint, value))
+ elif len(configuration) == self.num_actuated_joints:
+ for joint, value in zip(self._actuated_joints, configuration):
+ joint_cfg.append((joint, value))
+ else:
+ raise ValueError(
+ f"Dimensionality of configuration ({len(configuration)}) doesn't match number of all ({len(self.robot.joints)}) or actuated joints ({self.num_actuated_joints})."
+ )
+ else:
+ raise TypeError("Invalid type for configuration")
+
+ # append all mimic joints in the update
+ for j, q in joint_cfg + [(j, 0.0) for j in self.robot.joints if j.mimic is not None]:
+ matrix, _ = self._forward_kinematics_joint(j, q=q)
+ node = anytree.search.findall_by_attr(self.tree_root, j.child)[0]
+ node.matrix = matrix
+
+ for node in LevelOrderIter(self.tree_root):
+ if node.name == self.base_link:
+ node.global_pose = np.eye(4)
+ else:
+ node.global_pose = node.parent.global_pose @ node.matrix
+
+ def get_link_global_transform(self, link_name):
+ node = anytree.search.findall_by_attr(self.tree_root, link_name)[0]
+
+ return node.global_pose
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..0e54558
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,8 @@
+[tool.black]
+line_length = 120
+
+[tool.pytest.ini_options]
+addopts = "--disable-warnings"
+testpaths = [
+ "tests",
+]
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..15d5ba8
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,88 @@
+import re
+from pathlib import Path
+
+from setuptools import setup, find_packages
+
+_here = Path(__file__).resolve().parent
+name = "dex_retargeting"
+
+# Reference: https://github.com/kevinzakka/mjc_viewer/blob/main/setup.py
+with open(_here / name / "__init__.py") as f:
+ meta_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M)
+ if meta_match:
+ version = meta_match.group(1)
+ else:
+ raise RuntimeError("Unable to find __version__ string.")
+
+core_requirements = [
+ "numpy",
+ "torch",
+ "sapien>=2.0.0",
+ "nlopt",
+ "trimesh",
+ "anytree",
+ "pycollada",
+]
+
+test_requirements = [
+ "pytest",
+ "black",
+ "isort",
+ "pytest-xdist",
+ "pyright",
+ "ruff",
+]
+
+classifiers = [
+ "Development Status :: 3 - Alpha",
+ "Intended Audience :: Developers",
+ "Intended Audience :: Science/Research",
+ "License :: OSI Approved :: MIT License",
+ "Natural Language :: English",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
+]
+
+
+def setup_package():
+ # Meta information of the project
+ author = "Yuzhe Qin"
+ author_email = "y1qin@ucsd.edu"
+ description = "Hand pose retargeting for dexterous robot hand."
+ url = "https://github.com/dexsuite/dex-retargeting"
+ with open(_here / "README.md", "r") as file:
+ readme = file.read()
+
+ # Package data
+ packages = find_packages(".")
+ print(f"Packages: {packages}")
+
+ setup(
+ name=name,
+ version=version,
+ author=author,
+ author_email=author_email,
+ maintainer=author,
+ maintainer_email=author_email,
+ description=description,
+ long_description=readme,
+ long_description_content_type="text/markdown",
+ url=url,
+ license='MIT',
+ license_files=("LICENSE",),
+ packages=packages,
+ python_requires='>=3.7,<3.11',
+ zip_safe=True,
+ install_requires=core_requirements,
+ extras_require={
+ "test": test_requirements,
+ },
+ classifiers=classifiers,
+ )
+
+
+setup_package()
diff --git a/tests/test_retargeting_config.py b/tests/test_retargeting_config.py
new file mode 100644
index 0000000..1b9a7fc
--- /dev/null
+++ b/tests/test_retargeting_config.py
@@ -0,0 +1,24 @@
+from pathlib import Path
+
+import pytest
+
+from dex_retargeting.retargeting_config import RetargetingConfig
+from dex_retargeting.seq_retarget import SeqRetargeting
+from utils import VECTOR_CONFIG_DICT, POSITION_CONFIG_DICT, DEXPILOT_CONFIG_DICT
+
+
+class TestRetargetingConfig:
+ config_dir = Path(__file__).parent.parent / "dex_retargeting" / "configs"
+ robot_dir = Path(__file__).parent.parent / "assets" / "robots"
+ RetargetingConfig.set_default_urdf_dir(str(robot_dir.absolute()))
+
+ config_paths = (
+ list(VECTOR_CONFIG_DICT.values()) + list(POSITION_CONFIG_DICT.values()) + list(DEXPILOT_CONFIG_DICT.values())
+ )
+
+ @pytest.mark.parametrize("config_path", config_paths)
+ def test_config_parsing(self, config_path):
+ config_path = self.config_dir / config_path
+ config = RetargetingConfig.load_from_file(config_path)
+ retargeting = config.build()
+ assert type(retargeting) == SeqRetargeting
diff --git a/tests/test_sapien_optimizer.py b/tests/test_sapien_optimizer.py
new file mode 100644
index 0000000..dea42aa
--- /dev/null
+++ b/tests/test_sapien_optimizer.py
@@ -0,0 +1,125 @@
+from pathlib import Path
+from time import time
+
+import numpy as np
+import pytest
+import sapien.core as sapien
+
+from dex_retargeting.optimizer import VectorOptimizer, PositionOptimizer
+from dex_retargeting.retargeting_config import RetargetingConfig
+from utils import ROBOT_NAMES, VECTOR_CONFIG_DICT, POSITION_CONFIG_DICT
+
+
+class TestVectorOptimizer:
+ np.set_printoptions(precision=4)
+ config_dir = Path(__file__).parent.parent / "dex_retargeting" / "configs"
+ robot_dir = Path(__file__).parent.parent / "assets" / "robots"
+ RetargetingConfig.set_default_urdf_dir(str(robot_dir.absolute()))
+
+ @staticmethod
+ def generate_vector_retargeting_data_gt(robot: sapien.Articulation, optimizer: VectorOptimizer):
+ joint_limit = robot.get_qlimits()
+ random_qpos = np.random.uniform(joint_limit[:, 0], joint_limit[:, 1])
+ robot.set_qpos(random_qpos)
+
+ random_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.robot_link_indices])
+ origin_pos = random_pos[optimizer.origin_link_indices]
+ task_pos = random_pos[optimizer.task_link_indices]
+ random_target_vector = task_pos - origin_pos
+ init_qpos = np.clip(random_qpos + np.random.randn(robot.dof) * 0.5, joint_limit[:, 0], joint_limit[:, 1])
+
+ return random_qpos, init_qpos, random_target_vector
+
+ @staticmethod
+ def generate_position_retargeting_data_gt(robot: sapien.Articulation, optimizer: PositionOptimizer):
+ joint_limit = robot.get_qlimits()
+ random_qpos = np.random.uniform(joint_limit[:, 0], joint_limit[:, 1])
+ robot.set_qpos(random_qpos)
+
+ random_target_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.target_link_indices])
+ init_qpos = np.clip(random_qpos + np.random.randn(robot.dof) * 0.5, joint_limit[:, 0], joint_limit[:, 1])
+
+ return random_qpos, init_qpos, random_target_pos
+
+ @pytest.mark.parametrize("robot_name", ROBOT_NAMES[:1])
+ def test_position_optimizer(self, robot_name):
+ config_path = self.config_dir / POSITION_CONFIG_DICT[robot_name]
+
+ # Note: The parameters below are adjusted solely for this test
+ # The smoothness penalty is deactivated here, meaning no low pass filter and no continuous joint value
+ # This is because the test is focused solely on the efficiency of single step optimization
+ override = dict()
+ override["normal_delta"] = 0
+ config = RetargetingConfig.load_from_file(config_path, override)
+
+ retargeting = config.build()
+
+ robot = retargeting.optimizer.robot
+ optimizer = retargeting.optimizer
+
+ num_optimization = 100
+ tic = time()
+ errors = dict(pos=[], joint=[])
+ np.random.seed(1)
+ for i in range(num_optimization):
+ # Sampled random position
+ random_qpos, init_qpos, random_target_pos = self.generate_position_retargeting_data_gt(robot, optimizer)
+
+ # Optimized position
+ computed_qpos = optimizer.retarget(random_target_pos, fixed_qpos=[], last_qpos=init_qpos[:])
+ robot.set_qpos(np.array(computed_qpos))
+ computed_target_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.target_link_indices])
+
+ # Position difference
+ error = np.mean(np.linalg.norm(computed_target_pos - random_target_pos, axis=1))
+ errors["pos"].append(error)
+
+ tac = time()
+ print(f"Mean optimization position error: {np.mean(errors['pos'])}")
+ print(f"Retargeting computation takes {tac - tic}s for {num_optimization} times")
+ assert np.mean(errors["pos"]) < 1e-2
+
+ @pytest.mark.parametrize("robot_name", ROBOT_NAMES)
+ def test_vector_optimizer(self, robot_name):
+ config_path = self.config_dir / VECTOR_CONFIG_DICT[robot_name]
+
+ # Note: The parameters below are adjusted solely for this test
+ # For retargeting from human to robot, their values should remain the default in the retargeting config
+ # The smoothness penalty is deactivated here, meaning no low pass filter and no continuous joint value
+ # This is because the test is focused solely on the efficiency of single step optimization
+ override = dict()
+ override["type"] = "vector" # cast it to vector retargeting even if it is DexPilot retargeting
+ override["low_pass_alpha"] = 0
+ override["scaling_factor"] = 1.0
+ override["normal_delta"] = 0
+ config = RetargetingConfig.load_from_file(config_path, override)
+
+ retargeting = config.build()
+
+ robot = retargeting.optimizer.robot
+ optimizer = retargeting.optimizer
+
+ num_optimization = 100
+ tic = time()
+ errors = dict(pos=[], joint=[])
+ np.random.seed(1)
+ for i in range(num_optimization):
+ # Sampled random vector
+ random_qpos, init_qpos, random_target_vector = self.generate_vector_retargeting_data_gt(robot, optimizer)
+
+ # Optimized vector
+ computed_qpos = optimizer.retarget(random_target_vector, fixed_qpos=[], last_qpos=init_qpos[:])
+ robot.set_qpos(np.array(computed_qpos))
+ computed_pos = np.array([robot.get_links()[i].get_pose().p for i in optimizer.robot_link_indices])
+ computed_origin_pos = computed_pos[optimizer.origin_link_indices]
+ computed_task_pos = computed_pos[optimizer.task_link_indices]
+ computed_target_vector = computed_task_pos - computed_origin_pos
+
+ # Vector difference
+ error = np.mean(np.linalg.norm(computed_target_vector - random_target_vector, axis=1))
+ errors["pos"].append(error)
+
+ tac = time()
+ print(f"Mean optimization vector error: {np.mean(errors['pos'])}")
+ print(f"Retargeting computation takes {tac - tic}s for {num_optimization} times")
+ assert np.mean(errors["pos"]) < 1e-2
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 0000000..7ce9961
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,17 @@
+VECTOR_CONFIG_DICT = {
+ "allegro_right": "teleop/allegro_hand_right.yml",
+ "allegro_left": "teleop/allegro_hand_left.yml",
+ "shadow_right": "teleop/shadow_hand_right.yml",
+ "svh_right": "teleop/schunk_svh_hand_right.yml",
+}
+POSITION_CONFIG_DICT = {
+ "allegro_right": "offline/allegro_hand_right.yml",
+ # "allegro_left": "offline/allegro_hand_left.yml",
+ # "shadow_right": "offline/shadow_hand_right.yml",
+ # "svh_right": "offline/schunk_svh_hand_right.yml",
+}
+DEXPILOT_CONFIG_DICT = {
+ "allegro_right": "teleop/allegro_hand_right_dexpilot.yml",
+}
+
+ROBOT_NAMES = list(VECTOR_CONFIG_DICT.keys())