Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skinned mesh + SMPL example adjustments #291

Merged
merged 3 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions examples/08_smpl_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,21 @@ def __init__(self, model_path: Path) -> None:
assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
body_dict = dict(**np.load(model_path, allow_pickle=True))

self._J_regressor = body_dict["J_regressor"]
self._weights = body_dict["weights"]
self._v_template = body_dict["v_template"]
self._posedirs = body_dict["posedirs"]
self._shapedirs = body_dict["shapedirs"]
self._faces = body_dict["f"]

self.num_joints: int = self._weights.shape[-1]
self.num_betas: int = self._shapedirs.shape[-1]
self.J_regressor = body_dict["J_regressor"]
self.weights = body_dict["weights"]
self.v_template = body_dict["v_template"]
self.posedirs = body_dict["posedirs"]
self.shapedirs = body_dict["shapedirs"]
self.faces = body_dict["f"]

self.num_joints: int = self.weights.shape[-1]
self.num_betas: int = self.shapedirs.shape[-1]
self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
# Get shaped vertices + joint positions, when all local poses are identity.
v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)
v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)

# Local SE(3) transforms.
T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
Expand All @@ -63,13 +63,13 @@ def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutpu

# Linear blend skinning.
pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)
v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta)
v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
v_posed = np.einsum(
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta
)
return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)
return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint)


def main(model_path: Path) -> None:
Expand All @@ -86,6 +86,13 @@ def main(model_path: Path) -> None:
num_joints=model.num_joints,
parent_idx=model.parent_idx,
)
body_handle = server.scene.add_mesh_simple(
"/human",
model.v_template,
model.faces,
wireframe=gui_elements.gui_wireframe.value,
color=gui_elements.gui_rgb.value,
)
while True:
# Do nothing if no change.
time.sleep(0.02)
Expand All @@ -94,21 +101,20 @@ def main(model_path: Path) -> None:

gui_elements.changed = False

# Compute SMPL outputs.
# If anything has changed, re-compute SMPL outputs.
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=tf.SO3.exp(
# (num_joints, 3)
np.array([x.value for x in gui_elements.gui_joints])
).as_matrix(),
)
server.scene.add_mesh_simple(
"/human",
smpl_outputs.vertices,
smpl_outputs.faces,
wireframe=gui_elements.gui_wireframe.value,
color=gui_elements.gui_rgb.value,
)

# Update the mesh properties based on the SMPL model output + GUI
# elements.
body_handle.vertices = smpl_outputs.vertices
body_handle.wireframe = gui_elements.gui_wireframe.value
body_handle.color = gui_elements.gui_rgb.value

# Match transform control gizmos to joint positions.
for i, control in enumerate(gui_elements.transform_controls):
Expand Down Expand Up @@ -146,7 +152,7 @@ def set_changed(_) -> None:
with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
gui_rgb = server.gui.add_rgb("Color", initial_value=(90, 200, 255))
gui_wireframe = server.gui.add_checkbox("Wireframe", initial_value=False)
gui_show_controls = server.gui.add_checkbox("Handles", initial_value=False)
gui_show_controls = server.gui.add_checkbox("Handles", initial_value=True)

gui_rgb.on_update(set_changed)
gui_wireframe.on_update(set_changed)
Expand Down
86 changes: 50 additions & 36 deletions examples/25_smpl_visualizer_skinned.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,27 @@ def __init__(self, model_path: Path) -> None:
assert model_path.suffix.lower() == ".npz", "Model should be an .npz file!"
body_dict = dict(**np.load(model_path, allow_pickle=True))

self._J_regressor = body_dict["J_regressor"]
self._weights = body_dict["weights"]
self._v_template = body_dict["v_template"]
self._posedirs = body_dict["posedirs"]
self._shapedirs = body_dict["shapedirs"]
self._faces = body_dict["f"]

self.num_joints: int = self._weights.shape[-1]
self.num_betas: int = self._shapedirs.shape[-1]
self.J_regressor = body_dict["J_regressor"]
self.weights = body_dict["weights"]
self.v_template = body_dict["v_template"]
self.posedirs = body_dict["posedirs"]
self.shapedirs = body_dict["shapedirs"]
self.faces = body_dict["f"]

self.num_joints: int = self.weights.shape[-1]
self.num_betas: int = self.shapedirs.shape[-1]
self.parent_idx: np.ndarray = body_dict["kintree_table"][0]

def get_tpose(self, betas: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
# Get shaped vertices + joint positions, when all local poses are identity.
v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)
return v_tpose, j_tpose

def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutputs:
# Get shaped vertices + joint positions, when all local poses are identity.
v_tpose = self._v_template + np.einsum("vxb,b->vx", self._shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self._J_regressor, v_tpose)
v_tpose = self.v_template + np.einsum("vxb,b->vx", self.shapedirs, betas)
j_tpose = np.einsum("jv,vx->jx", self.J_regressor, v_tpose)

# Local SE(3) transforms.
T_parent_joint = np.zeros((self.num_joints, 4, 4)) + np.eye(4)
Expand All @@ -69,13 +75,13 @@ def get_outputs(self, betas: np.ndarray, joint_rotmats: np.ndarray) -> SmplOutpu

# Linear blend skinning.
pose_delta = (joint_rotmats[1:, ...] - np.eye(3)).flatten()
v_blend = v_tpose + np.einsum("byn,n->by", self._posedirs, pose_delta)
v_blend = v_tpose + np.einsum("byn,n->by", self.posedirs, pose_delta)
v_delta = np.ones((v_blend.shape[0], self.num_joints, 4))
v_delta[:, :, :3] = v_blend[:, None, :] - j_tpose[None, :, :]
v_posed = np.einsum(
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self._weights, v_delta
"jxy,vj,vjy->vx", T_world_joint[:, :3, :], self.weights, v_delta
)
return SmplOutputs(v_posed, self._faces, T_world_joint, T_parent_joint)
return SmplOutputs(v_posed, self.faces, T_world_joint, T_parent_joint)


def main(model_path: Path) -> None:
Expand All @@ -92,23 +98,14 @@ def main(model_path: Path) -> None:
num_joints=model.num_joints,
parent_idx=model.parent_idx,
)
smpl_outputs = model.get_outputs(
betas=np.array([x.value for x in gui_elements.gui_betas]),
joint_rotmats=np.zeros((model.num_joints, 3, 3)) + np.eye(3),
)

bone_wxyzs = np.array(
[tf.SO3.from_matrix(R).wxyz for R in smpl_outputs.T_world_joint[:, :3, :3]]
)
bone_positions = smpl_outputs.T_world_joint[:, :3, 3]

skinned_handle = server.scene.add_mesh_skinned(
v_tpose, j_tpose = model.get_tpose(np.zeros((model.num_betas,)))
mesh_handle = server.scene.add_mesh_skinned(
"/human",
smpl_outputs.vertices,
smpl_outputs.faces,
bone_wxyzs=bone_wxyzs,
bone_positions=bone_positions,
skin_weights=model._weights,
v_tpose,
model.faces,
bone_wxyzs=tf.SO3.identity(batch_axes=(model.num_joints,)).wxyz,
bone_positions=j_tpose,
skin_weights=model.weights,
wireframe=gui_elements.gui_wireframe.value,
color=gui_elements.gui_rgb.value,
)
Expand All @@ -119,10 +116,19 @@ def main(model_path: Path) -> None:
if not gui_elements.changed:
continue

# Shapes changed: update vertices / joint positions.
if gui_elements.betas_changed:
v_tpose, j_tpose = model.get_tpose(
np.array([gui_beta.value for gui_beta in gui_elements.gui_betas])
)
mesh_handle.vertices = v_tpose
mesh_handle.bone_positions = j_tpose

gui_elements.changed = False
gui_elements.betas_changed = False

# Render as wireframe?
skinned_handle.wireframe = gui_elements.gui_wireframe.value
mesh_handle.wireframe = gui_elements.gui_wireframe.value

# Compute SMPL outputs.
smpl_outputs = model.get_outputs(
Expand All @@ -139,10 +145,10 @@ def main(model_path: Path) -> None:
# Match transform control gizmos to joint positions.
for i, control in enumerate(gui_elements.transform_controls):
control.position = smpl_outputs.T_parent_joint[i, :3, 3]
skinned_handle.bones[i].wxyz = tf.SO3.from_matrix(
mesh_handle.bones[i].wxyz = tf.SO3.from_matrix(
smpl_outputs.T_world_joint[i, :3, :3]
).wxyz
skinned_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]
mesh_handle.bones[i].position = smpl_outputs.T_world_joint[i, :3, 3]


@dataclass
Expand All @@ -156,7 +162,10 @@ class GuiElements:
transform_controls: List[viser.TransformControlsHandle]

changed: bool
"""This flag will be flipped to True whenever the mesh needs to be re-generated."""
"""This flag will be flipped to True whenever any input is changed."""

betas_changed: bool
"""This flag will be flipped to True whenever the shape changes."""


def make_gui_elements(
Expand All @@ -170,7 +179,11 @@ def make_gui_elements(
tab_group = server.gui.add_tab_group()

def set_changed(_) -> None:
out.changed = True # out is define later!
out.changed = True # out is defined later!

def set_betas_changed(_) -> None:
out.betas_changed = True
out.changed = True

# GUI elements: mesh settings + visibility.
with tab_group.add_tab("View", viser.Icon.VIEWFINDER):
Expand Down Expand Up @@ -220,7 +233,7 @@ def _(_):
f"beta{i}", min=-5.0, max=5.0, step=0.01, initial_value=0.0
)
gui_betas.append(beta)
beta.on_update(set_changed)
beta.on_update(set_betas_changed)

# GUI elements: joint angles.
with tab_group.add_tab("Joints", viser.Icon.ANGLE):
Expand Down Expand Up @@ -295,6 +308,7 @@ def _(_) -> None:
gui_joints,
transform_controls=transform_controls,
changed=True,
betas_changed=False,
)
return out

Expand Down
11 changes: 7 additions & 4 deletions src/viser/_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,17 +551,20 @@ class SkinnedMeshProps(MeshProps):

Vertices are internally canonicalized to float32, faces to uint32."""

bone_wxyzs: Tuple[Tuple[float, float, float, float], ...]
"""Tuple of quaternions representing bone orientations. Synchronized automatically when assigned."""
bone_positions: Tuple[Tuple[float, float, float], ...]
"""Tuple of positions representing bone positions. Synchronized automatically when assigned."""
bone_wxyzs: npt.NDArray[np.float32]
"""Array of quaternions representing bone orientations (B, 4). Synchronized automatically when assigned."""
bone_positions: npt.NDArray[np.float32]
"""Array of positions representing bone positions (B, 3). Synchronized automatically when assigned."""
skin_indices: npt.NDArray[np.uint16]
"""Array of skin indices. Should have shape (V, 4). Synchronized automatically when assigned."""
skin_weights: npt.NDArray[np.float32]
"""Array of skin weights. Should have shape (V, 4). Synchronized automatically when assigned."""

def __post_init__(self):
# Check shapes.
assert self.bone_wxyzs.shape[-1] == 4
assert self.bone_positions.shape[-1] == 3
assert self.bone_wxyzs.shape[0] == self.bone_positions.shape[0]
assert self.vertices.shape[-1] == 3
assert self.faces.shape[-1] == 3
assert self.skin_weights is not None
Expand Down
15 changes: 2 additions & 13 deletions src/viser/_scene_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1065,19 +1065,8 @@ def add_mesh_skinned(
flat_shading=flat_shading,
side=side,
material=material,
bone_wxyzs=tuple(
(
float(wxyz[0]),
float(wxyz[1]),
float(wxyz[2]),
float(wxyz[3]),
)
for wxyz in bone_wxyzs.astype(np.float32)
),
bone_positions=tuple(
(float(xyz[0]), float(xyz[1]), float(xyz[2]))
for xyz in bone_positions.astype(np.float32)
),
bone_wxyzs=bone_wxyzs.astype(np.float32),
bone_positions=bone_positions.astype(np.float32),
skin_indices=top4_skin_indices.astype(np.uint16),
skin_weights=top4_skin_weights.astype(np.float32),
),
Expand Down
9 changes: 7 additions & 2 deletions src/viser/_scene_handles.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,13 @@ def __setattr__(self, name: str, value: Any) -> None:
elif hint == onpt.NDArray[np.uint8] and "color" in name:
value = colors_to_uint8(value)

if getattr(handle._impl.props, name) == value:
# Do nothing. Assumes equality is defined for the prop value.
current_value = getattr(handle._impl.props, name)

# Do nothing if the value hasn't changed.
if isinstance(current_value, np.ndarray):
if current_value.data == value.data:
return
elif current_value == value:
return

setattr(handle._impl.props, name, value)
Expand Down
30 changes: 26 additions & 4 deletions src/viser/client/src/MessageHandler.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,34 @@ function useMessageHandler() {
initialized: false,
poses: [],
};

const bone_wxyzs = new Float32Array(
message.props.bone_wxyzs.buffer.slice(
message.props.bone_wxyzs.byteOffset,
message.props.bone_wxyzs.byteOffset +
message.props.bone_wxyzs.byteLength,
),
);
const bone_positions = new Float32Array(
message.props.bone_positions.buffer.slice(
message.props.bone_positions.byteOffset,
message.props.bone_positions.byteOffset +
message.props.bone_positions.byteLength,
),
);
for (let i = 0; i < message.props.bone_wxyzs!.length; i++) {
const wxyz = message.props.bone_wxyzs[i];
const position = message.props.bone_positions[i];
viewer.skinnedMeshState.current[message.name].poses.push({
wxyz: wxyz,
position: position,
wxyz: [
bone_wxyzs[4 * i],
bone_wxyzs[4 * i + 1],
bone_wxyzs[4 * i + 2],
bone_wxyzs[4 * i + 3],
],
position: [
bone_positions[3 * i],
bone_positions[3 * i + 1],
bone_positions[3 * i + 2],
],
});
}
}
Expand Down
Loading
Loading