Skip to content

Commit

Permalink
New mesh wrapping algorithms with relative tests
Browse files Browse the repository at this point in the history
- New mesh wrapping algorithms (mesh decimation, object mapping, aap, select points over axis)
- Implemented tests of above except first algorithm
- Updated manifold3d dependency (used in object mapping)
- Restructured meshes module
  • Loading branch information
lorycontixd committed Nov 15, 2024
1 parent 22da2cd commit 8058b7c
Show file tree
Hide file tree
Showing 4 changed files with 335 additions and 98 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ dependencies = [
"rod >= 0.3.3",
"typing_extensions ; python_version < '3.12'",
"trimesh",
"manifold3d",
]

[project.optional-dependencies]
Expand Down
276 changes: 199 additions & 77 deletions src/jaxsim/parsers/rod/meshes.py
Original file line number Diff line number Diff line change
@@ -1,80 +1,202 @@
import trimesh
import numpy as np


def extract_points_vertex_extraction(mesh: trimesh.Trimesh) -> np.ndarray:
"""Extracts the points of a mesh using the vertices of the mesh as colliders.
Args:
mesh: The mesh to extract the points from.
Returns:
The points of the mesh.
"""
return mesh.vertices


def extract_points_random_surface_sampling(
mesh: trimesh.Trimesh, num_points: int
) -> np.ndarray:
"""Extracts the points of a mesh by sampling the surface of the mesh randomly.
Args:
mesh: The mesh to extract the points from.
num_points: The number of points to sample.
Returns:
The points of the mesh.
"""
return mesh.sample(num_points)


def extract_points_uniform_surface_sampling(
mesh: trimesh.Trimesh, num_points: int
) -> np.ndarray:
"""Extracts the points of a mesh by sampling the surface of the mesh uniformly.
Args:
mesh: The mesh to extract the points from.
num_points: The number of points to sample.
Returns:
The points of the mesh.
"""
return trimesh.sample.sample_surface_even(mesh=mesh, count=num_points)


def extract_points_aap(
mesh: trimesh.Trimesh,
aap_axis: str,
aap_value: float,
aap_direction: str,
) -> np.ndarray:
"""Extracts the points of a mesh that are on one side of an axis-aligned plane (AAP).
Args:
mesh: The mesh to extract the points from.
aap_axis: The axis of the AAP.
aap_value: The value of the AAP.
aap_direction: The direction of the AAP.
Returns:
The points of the mesh that are on one side of the AAP.
"""
if aap_direction == "higher":
aap_operator = np.greater
elif aap_direction == "lower":
aap_operator = np.less
import rod


def parse_object_mapping_object(obj) -> trimesh.Trimesh:
if isinstance(obj, trimesh.Trimesh):
return obj
elif isinstance(obj, dict):
if "type" not in obj:
raise ValueError("Object type not specified")
if obj["type"] == "box":
if "extents" not in obj:
raise ValueError("Box extents not specified")
return trimesh.creation.box(extents=obj["extents"])
elif obj["type"] == "sphere":
if "radius" not in obj:
raise ValueError("Sphere radius not specified")
return trimesh.creation.icosphere(subdivisions=4, radius=obj["radius"])
else:
raise ValueError(f"Invalid object type {obj['type']}")
elif isinstance(obj, rod.builder.primitive_builder.PrimitiveBuilder):
raise NotImplementedError("PrimitiveBuilder not implemented")
else:
raise ValueError("Invalid direction for axis-aligned plane")

if aap_axis == "x":
points = mesh.vertices[aap_operator(mesh.vertices[:, 0], aap_value)]
elif aap_axis == "y":
points = mesh.vertices[aap_operator(mesh.vertices[:, 1], aap_value)]
elif aap_axis == "z":
points = mesh.vertices[aap_operator(mesh.vertices[:, 2], aap_value)]
else:
raise ValueError("Invalid axis for axis-aligned plane")

return points
raise ValueError("Invalid object type")


class MeshMapping:
@staticmethod
def vertex_extraction(mesh: trimesh.Trimesh) -> np.ndarray:
"""Extracts the points of a mesh using the vertices of the mesh as colliders.
Args:
mesh: The mesh to extract the points from.
Returns:
The points of the mesh.
"""
return mesh.vertices

@staticmethod
def random_surface_sampling(mesh: trimesh.Trimesh, num_points: int) -> np.ndarray:
"""Extracts the points of a mesh by sampling the surface of the mesh randomly.
Args:
mesh: The mesh to extract the points from.
num_points: The number of points to sample.
Returns:
The points of the mesh.
"""
return mesh.sample(num_points)

@staticmethod
def uniform_surface_sampling(mesh: trimesh.Trimesh, num_points: int) -> np.ndarray:
"""Extracts the points of a mesh by sampling the surface of the mesh uniformly.
Args:
mesh: The mesh to extract the points from.
num_points: The number of points to sample.
Returns:
The points of the mesh.
"""
return trimesh.sample.sample_surface_even(mesh=mesh, count=num_points)

@staticmethod
def aap(
mesh: trimesh.Trimesh,
axis: str,
direction: str,
aap_value: float,
) -> np.ndarray:
"""Axis Aligned Plane
Extracts the points of a mesh that are on one side of an axis-aligned plane (AAP).
This means that the algorithm considers all points above/below a certain value along one axis.
Args:
mesh: The mesh to extract the points from.
axis: The axis of the AAP.
direction: The direction of the AAP.
aap_value: The value of the AAP.
Returns:
The points of the mesh that are on one side of the AAP.
TODO: Implement inclined plane
"""
if direction == "higher":
aap_operator = np.greater
elif direction == "lower":
aap_operator = np.less
else:
raise ValueError("Invalid direction for axis-aligned plane")

if axis == "x":
points = mesh.vertices[aap_operator(mesh.vertices[:, 0], aap_value)]
elif axis == "y":
points = mesh.vertices[aap_operator(mesh.vertices[:, 1], aap_value)]
elif axis == "z":
points = mesh.vertices[aap_operator(mesh.vertices[:, 2], aap_value)]
else:
raise ValueError("Invalid axis for axis-aligned plane")

return points

@staticmethod
def select_points_over_axis(
mesh: trimesh.Trimesh,
axis: str,
direction: str,
n: int,
):
"""Select Points Over Axis.
Select N points over an axis, either starting from the lower or higher end.
Args:
mesh: The mesh to extract the points from.
axis: The axis along which to remove points.
direction: The direction of the AAP.
n: The number of points to remove.
Returns:
The points of the mesh.
"""
valid_dirs = ["higher", "lower"]
if direction not in valid_dirs:
raise ValueError(f"Invalid direction. Valid directions are {valid_dirs}")
arr = mesh.vertices

index = 0 if axis == "x" else 1 if axis == "y" else 2
# Sort the array in ascending order
sorted_arr = arr[arr[:, index].argsort()]

if direction == "lower":
# Select first N points
points = sorted_arr[:n]
elif direction == "higher":
# Select last N points
points = sorted_arr[-n:]
else:
raise ValueError(
f"Invalid direction {direction} for SelectPointsOverAxis method"
)

return points

@staticmethod
def object_mapping(
mesh: trimesh.Trimesh,
object: trimesh.Trimesh,
method: str = "subtraction",
**kwargs,
):
"""Object Mapping.
Removes points from a mesh that are inside another object, using subtraction or intersection.
The method can be either "subtraction" or "intersection".
The object can be a mesh or a primitive.
Args:
mesh: The mesh to extract the points from.
object: The object to use for mapping.
method: The method to use for mapping.
**kwargs: Additional arguments for the method.
Returns:
The points of the mesh.
"""
if method == "subtraction":
x: trimesh.Trimesh = mesh.difference(object, **kwargs)
x.show()
points = x.vertices
elif method == "intersection":
points = mesh.intersection(object, **kwargs).vertices
else:
raise ValueError("Invalid method for object mapping")

return points

@staticmethod
def mesh_decimation(
mesh: trimesh.Trimesh,
method: str = "",
nsamples: int = -1,
):
"""Object decimation algorithm to reduce the number of vertices in a mesh, then extract points.
Args:
mesh: The mesh to extract the points from.
method: The method to use for decimation.
nsamples: The number of desired samples.
Returns:
The points of the mesh.
"""

if method == "quadric":
mesh = mesh.simplify_quadric_decimation(nsamples // 3)
else:
raise ValueError("Invalid method for mesh decimation")

return mesh.vertices
51 changes: 40 additions & 11 deletions src/jaxsim/parsers/rod/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ class MeshMappingMethods(enum.IntEnum):
RandomSurfaceSampling = enum.auto() # Sample the surface of the mesh randomly
UniformSurfaceSampling = enum.auto() # Sample the surface of the mesh uniformly
MeshDecimation = enum.auto() # Decimate the mesh to a certain number of vertices
MeshMapping = enum.auto() # Subtraction of bounding box from mesh
ObjectMapping = enum.auto() # Subtraction of bounding box from mesh
AxisAlignedPlane = enum.auto() # Remove all points above a certain axis value
SelectPointsOverAxis = (
enum.auto()
) # Remove N highest or lowest points over a certain axis value


def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix:
Expand Down Expand Up @@ -226,9 +229,10 @@ def create_mesh_collision(
link_description: descriptions.LinkDescription,
method: MeshMappingMethods = MeshMappingMethods.VertexExtraction,
nsamples: int = 1000,
aap_axis: str = "z",
axis: str = "z",
direction: str = "lower",
aap_value: float = 0.0,
aap_direction: str = "lower",
object_mapping_object: trimesh.Trimesh | dict = None,
) -> descriptions.MeshCollision:
file = pathlib.Path(resolve_local_uri(uri=collision.geometry.mesh.uri))
_file_type = file.suffix.replace(".", "")
Expand All @@ -246,25 +250,50 @@ def create_mesh_collision(
# Extract the points from the mesh to use as colliders according to the provided method
match method:
case MeshMappingMethods.VertexExtraction:
points = meshes.extract_points_vertex_extraction(mesh=mesh)
points = meshes.MeshMapping.vertex_extraction(mesh=mesh)
case MeshMappingMethods.RandomSurfaceSampling:
points = meshes.extract_points_random_surface_sampling(
if nsamples > len(mesh.vertices):
logging.warning(
f"Number of samples {nsamples} is larger than the number of vertices {len(mesh.vertices)} in the mesh. Falling back to number of vertices"
)
nsamples = len(mesh.vertices)
points = meshes.MeshMapping.random_surface_sampling(
mesh=mesh, num_points=nsamples
)
case MeshMappingMethods.UniformSurfaceSampling:
points = meshes.extract_points_uniform_surface_sampling(
if nsamples > len(mesh.vertices):
logging.warning(
f"Number of samples {nsamples} is larger than the number of vertices {len(mesh.vertices)} in the mesh. Falling back to number of vertices"
)
nsamples = len(mesh.vertices)
points = meshes.MeshMapping.uniform_surface_sampling(
mesh=mesh, num_points=nsamples
)
case MeshMappingMethods.MeshDecimation:
raise NotImplementedError("Mesh decimation is not implemented yet")
case MeshMappingMethods.MeshMapping:
raise NotImplementedError("AABMapping is not implemented yet")
case MeshMappingMethods.ObjectMapping:
if object_mapping_object is None:
raise ValueError("Object mapping object was not provided")
obj = meshes.parse_object_mapping_object(object_mapping_object)
points = meshes.MeshMapping.object_mapping(mesh=mesh, object=obj)
case MeshMappingMethods.AxisAlignedPlane:
points = meshes.extract_points_aap(
points = meshes.MeshMapping.aap(
mesh=mesh,
aap_axis=aap_axis,
axis=axis,
direction=direction,
aap_value=aap_value,
aap_direction=aap_direction,
)
case MeshMappingMethods.SelectPointsOverAxis:
if nsamples > len(mesh.vertices):
logging.warning(
f"Number of samples {nsamples} is larger than the number of vertices {len(mesh.vertices)} in the mesh. Falling back to number of vertices"
)
nsamples = len(mesh.vertices)
points = meshes.MeshMapping.select_points_over_axis(
mesh=mesh,
axis=axis,
direction=direction,
n=nsamples,
)
case _:
raise ValueError("Invalid mesh mapping method")
Expand Down
Loading

0 comments on commit 8058b7c

Please sign in to comment.