Skip to content

Commit

Permalink
Merge pull request #15 from YifuXu1127/master
Browse files Browse the repository at this point in the history
Left Hand
  • Loading branch information
kelvin34501 authored Dec 4, 2024
2 parents 933be97 + a31bca9 commit a38f87e
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 8 deletions.
15 changes: 11 additions & 4 deletions manotorch/axislayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@

class AxisAdaptiveLayer(torch.nn.Module):

def __init__(self):
def __init__(self, side: str = "right"):
super(AxisAdaptiveLayer, self).__init__()
self.joints_mapping = [5, 6, 7, 9, 10, 11, 17, 18, 19, 13, 14, 15, 1, 2, 3]
self.parent_joints_mappings = [0, 5, 6, 0, 9, 10, 0, 17, 18, 0, 13, 14, 0, 1, 2]
up_axis_base = np.vstack((np.array([[0, 1, 0]]).repeat(13, axis=0), np.array([[1, 1, 1]]).repeat(3, axis=0)))
self.side = side
if side == "right":
up_axis_base = np.vstack((np.array([[0, 1, 0]]).repeat(13, axis=0), np.array([[1, 1, 1]]).repeat(3, axis=0)))
elif side == "left":
up_axis_base = np.vstack((np.array([[0, 1, 0]]).repeat(13, axis=0), np.array([[-1, 1, 1]]).repeat(3, axis=0)))
self.register_buffer("up_axis_base", torch.from_numpy(up_axis_base).float().unsqueeze(0))

def forward(self, hand_joints, transf):
Expand All @@ -34,7 +38,10 @@ def forward(self, hand_joints, transf):
# b_axis = hand_joints[:, self.joints_mapping] - hand_joints[:, [i + 1 for i in self.joints_mapping]]
b_axis = hand_joints[:, self.parent_joints_mappings] - hand_joints[:, self.joints_mapping]
b_axis = (transf[:, 1:, :3, :3].transpose(2, 3) @ b_axis.unsqueeze(-1)).squeeze(-1)
b_axis_init = torch.tensor([1, 0, 0]).float().unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1).to(b_axis.device)
if self.side == "right":
b_axis_init = torch.tensor([1, 0, 0]).float().unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1).to(b_axis.device)
elif self.side == "left":
b_axis_init = torch.tensor([-1, 0, 0]).float().unsqueeze(0).unsqueeze(0).repeat(bs, 1, 1).to(b_axis.device)
b_axis = torch.cat((b_axis_init, b_axis), dim=1) # (B, 16, 3)

l_axis = torch.cross(b_axis, self.up_axis_base.expand(bs, 16, 3))
Expand All @@ -60,7 +67,7 @@ def __init__(self, side: str = "right", mano_assets_root: str = "assets/mano"):
tmpl_joints = tmpl_mano.joints
tmpl_transf_abs = tmpl_mano.transforms_abs # tmpl_T_g_p

tmpl_b_axis, tmpl_u_axis, tmpl_l_axis = AxisAdaptiveLayer()(tmpl_joints, tmpl_transf_abs) # (1, 16, 3)
tmpl_b_axis, tmpl_u_axis, tmpl_l_axis = AxisAdaptiveLayer(side=side)(tmpl_joints, tmpl_transf_abs) # (1, 16, 3)
tmpl_R_p_a = torch.cat((tmpl_b_axis.unsqueeze(-1), tmpl_u_axis.unsqueeze(-1), tmpl_l_axis.unsqueeze(-1)), dim=3)
zero_tsl = torch.zeros(1, 16, 3, 1)
zero_pad = torch.tensor([[[[0, 0, 0, 1]]]]).repeat(*zero_tsl.shape[0:2], 1, 1)
Expand Down
3 changes: 2 additions & 1 deletion scripts/simple_anatomy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ def main():
center_idx=9,
mano_assets_root="assets/mano",
use_pca=False,
side="right",
flat_hand_mean=True)
hand_faces = mano_layer.th_faces # (NF, 3)

axisFK = AxisLayerFK(mano_assets_root="assets/mano")
axisFK = AxisLayerFK(side=mano_layer.side,mano_assets_root="assets/mano")
anatomyLoss = AnatomyConstraintLossEE()
anatomyLoss.setup()

Expand Down
4 changes: 2 additions & 2 deletions scripts/simple_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main(args):
mano_assets_root="assets/mano",
flat_hand_mean=False,
)
axis_layer = AxisLayerFK(mano_assets_root="assets/mano")
axis_layer = AxisLayerFK(side=mano_layer.side, mano_assets_root="assets/mano")
anchor_layer = AnchorLayer(anchor_root="assets/anchor")

BS = 1
Expand All @@ -42,7 +42,7 @@ def main(args):
bul_axes_loc = torch.eye(3).reshape(1, 1, 3, 3).repeat(BS, 16, 1, 1).to(verts.device)
bul_axes_glb = torch.matmul(T_g_a[:, :, :3, :3], bul_axes_loc) # (B, 16, 3, 3)

b_axes_dir = bul_axes_glb[:, :, :, 0].numpy() # bend direction (B, 16, 3)
b_axes_dir = bul_axes_glb[:, :, :, 0].numpy() # back direction (B, 16, 3)
u_axes_dir = bul_axes_glb[:, :, :, 1].numpy() # up direction (B, 16, 3)
l_axes_dir = bul_axes_glb[:, :, :, 2].numpy() # left direction (B, 16, 3)

Expand Down
3 changes: 2 additions & 1 deletion scripts/simple_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ def main():
center_idx=9,
mano_assets_root="assets/mano",
use_pca=False,
side="right",
flat_hand_mean=True)
hand_faces = mano_layer.th_faces # (NF, 3)

axisFK = AxisLayerFK(mano_assets_root="assets/mano")
axisFK = AxisLayerFK(side=mano_layer.side,mano_assets_root="assets/mano")
composed_ee = torch.zeros((1, 16, 3))

# transform order of right hand
Expand Down

0 comments on commit a38f87e

Please sign in to comment.