Skip to content

Commit

Permalink
Update the parser logic to use the rod library
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Oct 5, 2022
1 parent b077555 commit 48e18ed
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 98 deletions.
109 changes: 65 additions & 44 deletions src/jaxsim/parsers/sdf/parser.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import dataclasses
import pathlib
from pathlib import Path
from typing import Dict, List, NamedTuple, Union
from typing import Dict, List, NamedTuple, Optional, Union

import jax.numpy as jnp
import numpy as np
import pysdf
import rod

from jaxsim import logging
from jaxsim.math.quaternion import Quaternion
Expand All @@ -24,39 +25,57 @@ class SDFData(NamedTuple):
joint_descriptions: List[descriptions.JointDescription]
collision_shapes: List[descriptions.CollisionShape]

sdf_tree: pysdf.Model = None
sdf_model: Optional[rod.Model] = None
model_pose: kinematic_graph.RootPose = kinematic_graph.RootPose()


def extract_data_from_sdf(
sdf: Union[Path, str],
sdf: Union[pathlib.Path, str], model_name: Optional[str] = None
) -> SDFData:

if isinstance(sdf, str) and len(sdf) < 500 and Path(sdf).is_file():
sdf = Path(sdf)
# Parse the SDF resource
sdf_element = rod.Sdf.load(sdf=sdf)

# Get the SDF string
sdf_string = sdf if isinstance(sdf, str) else sdf.read_text()
if len(sdf_element.models()) == 0:
raise RuntimeError("Failed to find any model in SDF resource")

# Parse the tree
sdf_tree = pysdf.SDF.from_xml(sdf_string=sdf_string, remove_blank_text=True)
# Assume the SDF resource has only one model, or the desired model name is given
sdf_models = {m.name: m for m in sdf_element.models()}
sdf_model = (
sdf_element.models()[0] if len(sdf_models) == 1 else sdf_models[model_name]
)
logging.debug(msg=f"Found model '{sdf_model.name}' in SDF resource")

# Detect fixed-base models by checking the existence of joints having world as parent
sdf_joints_with_world_parent = [
j for j in sdf_model.joints() if j.parent == "world"
]
fixed_base = len(sdf_joints_with_world_parent) > 0

# Detect whether the model is fixed base by checking joints with world parent exist.
# This link is a special link used to specify that the model's base should be fixed.
fixed_base = len([j for j in sdf_tree.model.joints if j.parent == "world"]) > 0
logging.debug(
msg="Model '{}' is {}".format(
sdf_model.name, "fixed-base" if fixed_base else "floating-base"
)
)

# Base link of the model. We take the first link in the SDF description.
base_link_name = sdf_tree.model.links[0].name
# We extract the link connected to 'world', and consider it as base link.
# Instead, for floating-base models, we consider the first link as base link.
base_link_name = (
sdf_joints_with_world_parent[0].name
if fixed_base
else sdf_model.links()[0].name
)
logging.debug(msg=f"Considering '{base_link_name}' as base link")

# Pose of the model
if sdf_tree.model.pose is None:
if sdf_model.pose is None:
model_pose = kinematic_graph.RootPose()

else:
w_H_m = utils.from_sdf_pose(pose=sdf_tree.model.pose)
W_H_M = sdf_model.pose.transform()
model_pose = kinematic_graph.RootPose(
root_position=w_H_m[0:3, 3],
root_quaternion=Quaternion.from_dcm(dcm=w_H_m[0:3, 0:3]),
root_position=W_H_M[0:3, 3],
root_quaternion=Quaternion.from_dcm(dcm=W_H_M[0:3, 0:3]),
)

# ===========
Expand All @@ -69,9 +88,9 @@ def extract_data_from_sdf(
name=l.name,
mass=jnp.float32(l.inertial.mass),
inertia=utils.from_sdf_inertial(inertial=l.inertial),
pose=utils.from_sdf_pose(pose=l.pose) if l.pose is not None else np.eye(4),
pose=l.pose.transform() if l.pose is not None else np.eye(4),
)
for l in sdf_tree.model.links
for l in sdf_model.links()
if l.inertial.mass > 0
]

Expand All @@ -86,6 +105,7 @@ def extract_data_from_sdf(
# to the world and combine their pose
if fixed_base:

# Create a massless word link
world_link = descriptions.LinkDescription(
name="world", mass=0, inertia=np.zeros(shape=(6, 6))
)
Expand All @@ -100,20 +120,18 @@ def extract_data_from_sdf(
parent=world_link,
child=links_dict[j.child],
jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
axis=utils.from_sdf_string_list(string_list=j.axis.xyz.text)
axis=np.array(j.axis.xyz.xyz)
if j.axis is not None
and j.axis.xyz is not None
and j.axis.xyz.text is not None
and j.axis.xyz.xyz is not None
else None,
pose=utils.from_sdf_pose(pose=j.pose)
if j.pose is not None
else np.eye(4),
pose=j.pose.transform() if j.pose is not None else np.eye(4),
)
for j in sdf_tree.model.joints
for j in sdf_model.joints()
if j.type == "fixed"
and j.parent == "world"
and j.child in links_dict.keys()
and j.pose.relative_to == "__model__"
and j.pose.relative_to in {"__model__", None}
]

logging.debug(
Expand Down Expand Up @@ -146,14 +164,14 @@ def extract_data_from_sdf(
# ============

# Check that all joint poses are expressed w.r.t. their parent link
for j in sdf_tree.model.joints:
for j in sdf_model.joints():

if j.pose is None:
continue

if j.parent == "world":

if j.pose.relative_to == "__model__":
if j.pose.relative_to in {"__model__", None}:
continue

raise ValueError("Pose of fixed joint connecting to 'world' link not valid")
Expand All @@ -169,12 +187,12 @@ def extract_data_from_sdf(
parent=links_dict[j.parent],
child=links_dict[j.child],
jtype=utils.axis_to_jtype(axis=j.axis, type=j.type),
axis=utils.from_sdf_string_list(j.axis.xyz.text)
axis=np.array(j.axis.xyz.xyz)
if j.axis is not None
and j.axis.xyz is not None
and j.axis.xyz.text is not None
and j.axis.xyz.xyz is not None
else None,
pose=utils.from_sdf_pose(pose=j.pose) if j.pose is not None else np.eye(4),
pose=j.pose.transform() if j.pose is not None else np.eye(4),
initial_position=0.0,
position_limit=(
float(j.axis.limit.lower)
Expand All @@ -192,20 +210,23 @@ def extract_data_from_sdf(
friction_viscous=j.axis.dynamics.damping
if j.axis is not None
and j.axis.dynamics is not None
and j.axis.dynamics.friction is not None
and j.axis.dynamics.damping is not None
else 0.0,
# position_limit_damper=1_000.0,
# position_limit_spring=1.0,
position_limit_damper=j.axis.limit.dissipation
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.dissipation is not None
else 0.0,
# else 0.0,
else 1_000.0,
position_limit_spring=j.axis.limit.stiffness
if j.axis is not None
and j.axis.limit is not None
and j.axis.limit.stiffness is not None
else 0.0,
)
for j in sdf_tree.model.joints
for j in sdf_model.joints()
if j.type in {"revolute", "prismatic", "fixed"}
and j.parent != "world"
and j.child in links_dict.keys()
Expand All @@ -215,7 +236,7 @@ def extract_data_from_sdf(
joint_dict = {j.child.name: j.name for j in joints}

# Check that all the link poses are expressed wrt their parent joint
for l in sdf_tree.model.links:
for l in sdf_model.links():

if l.name not in links_dict:
continue
Expand All @@ -241,10 +262,10 @@ def extract_data_from_sdf(
collisions: List[descriptions.CollisionShape] = []

# Parse the collisions
for link in sdf_tree.model.links:
for collision in link.colliders:
for link in sdf_model.links():
for collision in link.collisions():

if collision.geometry.box.to_xml() != "<box/>":
if collision.geometry.box is not None:

box_collision = utils.create_box_collision(
collision=collision,
Expand All @@ -253,7 +274,7 @@ def extract_data_from_sdf(

collisions.append(box_collision)

if collision.geometry.sphere.to_xml() != "<sphere/>":
if collision.geometry.sphere is not None:

sphere_collision = utils.create_sphere_collision(
collision=collision,
Expand All @@ -263,14 +284,14 @@ def extract_data_from_sdf(
collisions.append(sphere_collision)

return SDFData(
model_name=sdf_tree.model.name,
model_name=sdf_model.name,
link_descriptions=links,
joint_descriptions=joints,
collision_shapes=collisions,
fixed_base=fixed_base,
base_link_name=base_link_name,
model_pose=model_pose,
sdf_tree=sdf_tree.model,
sdf_model=sdf_model,
)


Expand Down Expand Up @@ -298,6 +319,6 @@ def build_model_from_sdf(sdf: Union[Path, str]) -> descriptions.ModelDescription
)

# Store the parsed SDF tree as extra info
model = dataclasses.replace(model, extra_info=dict(sdf_tree=sdf_data.sdf_tree))
model = dataclasses.replace(model, extra_info=dict(sdf_model=sdf_data.sdf_model))

return model
Loading

0 comments on commit 48e18ed

Please sign in to comment.