Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify body_models #9

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,4 @@ venv.bak/
# mypy
.mypy_cache/

smplx/models
82 changes: 40 additions & 42 deletions smplx/body_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
# for Intelligent Systems and the Max Planck Institute for Biological
# Cybernetics. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de

Expand Down Expand Up @@ -268,7 +269,10 @@ def __init__(self, model_path, data_struct=None,
self.register_buffer('v_template',
to_tensor(to_np(data_struct.v_template),
dtype=dtype))

# add bias
if self.v_template.shape[0] == 10475:
bias = torch.Tensor(np.array([[0, 0.1728, 0.0218]]))
self.v_template = self.v_template + bias
# The shape components
shapedirs = data_struct.shapedirs
# The shape components
Expand Down Expand Up @@ -352,24 +356,19 @@ def forward(self, betas=None, body_pose=None, global_orient=None,
'''
# If no shape and pose parameters are passed along, then use the
# ones from the module
global_orient = (global_orient if global_orient is not None else
self.global_orient)
body_pose = body_pose if body_pose is not None else self.body_pose
betas = betas if betas is not None else self.betas

bn = body_pose.shape[0]
global_orient = (global_orient if global_orient is not None else
self.global_orient[:bn])
betas = betas if betas is not None else self.betas[:bn]
if betas.shape[0] < bn:
betas = betas.expand(bn, -1)
apply_trans = transl is not None or hasattr(self, 'transl')
if transl is None and hasattr(self, 'transl'):
transl = self.transl
transl = self.transl[:bn]

full_pose = torch.cat([global_orient, body_pose], dim=1)

batch_size = max(betas.shape[0], global_orient.shape[0],
body_pose.shape[0])

if betas.shape[0] != batch_size:
num_repeats = int(batch_size / betas.shape[0])
betas = betas.expand(num_repeats, -1)

vertices, joints = lbs(betas, full_pose, self.v_template,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
Expand Down Expand Up @@ -577,32 +576,32 @@ def forward(self, betas=None, global_orient=None, body_pose=None,
'''
# If no shape and pose parameters are passed along, then use the
# ones from the module
global_orient = (global_orient if global_orient is not None else
self.global_orient)
body_pose = body_pose if body_pose is not None else self.body_pose
betas = betas if betas is not None else self.betas
bn = body_pose.shape[0]
global_orient = (global_orient if global_orient is not None else
self.global_orient[:bn])

betas = betas if betas is not None else self.betas[:bn]
left_hand_pose = (left_hand_pose if left_hand_pose is not None else
self.left_hand_pose)
self.left_hand_pose[:bn])
right_hand_pose = (right_hand_pose if right_hand_pose is not None else
self.right_hand_pose)
self.right_hand_pose[:bn])

apply_trans = transl is not None or hasattr(self, 'transl')
if transl is None:
if hasattr(self, 'transl'):
transl = self.transl
transl = self.transl[:bn]

if self.use_pca:
left_hand_pose = torch.einsum(
'bi,ij->bj', [left_hand_pose, self.left_hand_components])
right_hand_pose = torch.einsum(
'bi,ij->bj', [right_hand_pose, self.right_hand_components])

full_pose = torch.cat([global_orient, body_pose,
left_hand_pose,
right_hand_pose], dim=1)
full_pose = torch.cat([global_orient[:bn], body_pose[:bn],
left_hand_pose[:bn],
right_hand_pose[:bn]], dim=1)
full_pose += self.pose_mean

vertices, joints = lbs(self.betas, full_pose, self.v_template,
vertices, joints = lbs(betas, full_pose, self.v_template,
self.shapedirs, self.posedirs,
self.J_regressor, self.parents,
self.lbs_weights, pose2rot=pose2rot,
Expand Down Expand Up @@ -875,35 +874,35 @@ def forward(self, betas=None, global_orient=None, body_pose=None,

# If no shape and pose parameters are passed along, then use the
# ones from the module
global_orient = (global_orient if global_orient is not None else
self.global_orient)

body_pose = body_pose if body_pose is not None else self.body_pose
betas = betas if betas is not None else self.betas

bn = body_pose.shape[0]
global_orient = (global_orient if global_orient is not None else
self.global_orient[:bn])
betas = betas if betas is not None else self.betas[:bn]
left_hand_pose = (left_hand_pose if left_hand_pose is not None else
self.left_hand_pose)
right_hand_pose = (right_hand_pose if right_hand_pose is not None else
self.right_hand_pose)
jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
leye_pose = leye_pose if leye_pose is not None else self.leye_pose
reye_pose = reye_pose if reye_pose is not None else self.reye_pose
expression = expression if expression is not None else self.expression

jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose[:bn]
leye_pose = leye_pose if leye_pose is not None else self.leye_pose[:bn]
reye_pose = reye_pose if reye_pose is not None else self.reye_pose[:bn]
expression = expression if expression is not None else self.expression[:bn]
apply_trans = transl is not None or hasattr(self, 'transl')
if transl is None:
if hasattr(self, 'transl'):
transl = self.transl
transl = self.transl[:bn]

if self.use_pca:
left_hand_pose = torch.einsum(
'bi,ij->bj', [left_hand_pose, self.left_hand_components])
right_hand_pose = torch.einsum(
'bi,ij->bj', [right_hand_pose, self.right_hand_components])

full_pose = torch.cat([global_orient, body_pose,
jaw_pose, leye_pose, reye_pose,
left_hand_pose,
right_hand_pose], dim=1)
full_pose = torch.cat([global_orient[:bn], body_pose[:bn],
jaw_pose[:bn], leye_pose[:bn], reye_pose[:bn],
left_hand_pose[:bn],
right_hand_pose[:bn]], dim=1)

# Add the mean pose of the model. Does not affect the body, only the
# hands when flat_hand_mean == False
Expand All @@ -926,7 +925,7 @@ def forward(self, betas=None, global_orient=None, body_pose=None,
lmk_faces_idx = self.lmk_faces_idx.unsqueeze(
dim=0).expand(batch_size, -1).contiguous()
lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
self.batch_size, 1, 1)
batch_size, 1, 1)
if self.use_face_contour:
dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
vertices, full_pose, self.dynamic_lmk_faces_idx,
Expand All @@ -938,7 +937,6 @@ def forward(self, betas=None, global_orient=None, body_pose=None,
lmk_bary_coords = torch.cat(
[lmk_bary_coords.expand(batch_size, -1, -1),
dyn_lmk_bary_coords], 1)

landmarks = vertices2landmarks(vertices, self.faces_tensor,
lmk_faces_idx,
lmk_bary_coords)
Expand Down
2 changes: 1 addition & 1 deletion tools/clean_ch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

def clean_fn(fn, output_folder='output'):
with open(fn, 'rb') as body_file:
body_data = pickle.load(body_file)
body_data = pickle.load(body_file, encoding='latin1')

output_dict = {}
for key, data in body_data.iteritems():
Expand Down
8 changes: 4 additions & 4 deletions tools/merge_smplh_mano.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ def merge_models(smplh_fn, mano_left_fn, mano_right_fn,
output_folder='output'):

with open(smplh_fn, 'rb') as body_file:
body_data = pickle.load(body_file)
body_data = pickle.load(body_file, encoding='latin1')

with open(mano_left_fn, 'rb') as lhand_file:
lhand_data = pickle.load(lhand_file)
lhand_data = pickle.load(lhand_file, encoding='latin1')

with open(mano_right_fn, 'rb') as rhand_file:
rhand_data = pickle.load(rhand_file)
rhand_data = pickle.load(rhand_file, encoding='latin1')

out_fn = osp.split(smplh_fn)[1]

Expand All @@ -50,7 +50,7 @@ def merge_models(smplh_fn, mano_left_fn, mano_right_fn,
output_data['hands_meanl'] = lhand_data['hands_mean']
output_data['hands_meanr'] = rhand_data['hands_mean']

for key, data in output_data.iteritems():
for key, data in output_data.items():
if 'chumpy' in str(type(data)):
output_data[key] = np.array(data)
else:
Expand Down