[SHOW Visualization] Which part of code to refer #10

Open
jameskuma opened this issue Jun 4, 2024 · 6 comments
@jameskuma

Dear author,

Thank you for this awesome work!

I ran the inference part of this repo on the SHOW dataset, and I only get a bunch of .npz files.

However, how do I visualize them with the visualization tool in TalkSHOW? Which part of the code should I use to visualize the results?
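
For reference, a minimal sketch to inspect what the saved outputs contain (the results/ glob and the key names are assumptions; adjust them to wherever and however the inference script saved the files):

import glob

import numpy as np

# Print the arrays stored in the first saved output file.
for path in sorted(glob.glob('results/**/*.np[yz]', recursive=True))[:1]:
    data = np.load(path, allow_pickle=True)
    print(path)
    if hasattr(data, 'files'):              # .npz archive
        for key in data.files:
            print(f'  {key}: {np.asarray(data[key]).shape}')
    else:                                   # plain .npy array
        print(f'  array: {np.asarray(data).shape}')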

Best regards

@jameskuma
Author

I tried to use the TalkSHOW code to visualize the data, but I got a bad result.

[image: rendered result]

Do you know the reason? My code is as follows (adapted from TalkSHOW/scripts/demo.py):

import os

import cv2
import numpy as np
import pyrender
import smplx
import torch
import tqdm
from smplx.utils import SMPLOutput

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Constant lower-body, eye, and global-orient channels from TalkSHOW (data_utils/lower_body.py).
lower_pose = torch.tensor(
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0747, -0.0158, -0.0152, -1.1826512813568115, 0.23866955935955048,
     0.15146760642528534, -1.2604516744613647, -0.3160211145877838,
     -0.1603458970785141, 1.1654603481292725, 0.0, 0.0, 1.2521806955337524, 0.041598282754421234, -0.06312154978513718,
     0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
lower_pose_stand = torch.tensor([
    8.9759e-04, 7.1074e-04, -5.9163e-06, 8.9759e-04, 7.1074e-04, -5.9163e-06,
    3.0747, -0.0158, -0.0152,
    -3.6665e-01, -8.8455e-03, 1.6113e-01, -3.6665e-01, -8.8455e-03, 1.6113e-01,
    -3.9716e-01, -4.0229e-02, -1.2637e-01,
    7.9163e-01, 6.8519e-02, -1.5091e-01, 7.9163e-01, 6.8519e-02, -1.5091e-01,
    7.8632e-01, -4.3810e-02, 1.4375e-02,
    -1.0675e-01, 1.2635e-01, 1.6711e-02, -1.0675e-01, 1.2635e-01, 1.6711e-02, ])

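# part2full (adapted from TalkSHOW's data_utils/lower_body.py) expands the reduced pose vector
# (jaw + upper body + hands + expression; 232-dim here) back to the full 265-dim SMPL-X parameter
# vector by re-inserting the constant eye, global-orient, and lower-body channels defined above.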
def part2full(input, stand=False):
    if stand:
        lp = torch.zeros_like(lower_pose)
        lp[6:9] = torch.tensor([3.0747, -0.0158, -0.0152])
        lp = lp.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)
    else:
        lp = lower_pose.unsqueeze(dim=0).repeat(input.shape[0], 1).to(input.device)

    input = torch.cat([input[:, :3],
                       lp[:, :15],
                       input[:, 3:6],
                       lp[:, 15:21],
                       input[:, 6:9],
                       lp[:, 21:27],
                       input[:, 9:12],
                       lp[:, 27:],
                       input[:, 12:]]
                      , dim=1)
    return input

def main():
    # * create smplx model
    print('init smplx model...')
    dtype = torch.float64
    smplx_path = './visualise/'
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        dtype=dtype,
    )
    smplx_model = smplx.create(**model_params).to(device)
    # * load smplx param
    # this is DiffSHEG output
    pred_smplx = np.load('results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_1/Forrest_tts.npy')
    pred_smplx = torch.from_numpy(pred_smplx).float().to(device)[0][:100]
    pred_smplx = part2full(pred_smplx, stand=True).to(dtype)  # cast to the model's float64 dtype after filling the lower body
    
    # * pred_smplx size: [n_frames, param_dim]
    import tqdm
    vertices = []
    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    for frame_ind in tqdm.tqdm(range(pred_smplx.shape[0]), desc='infer mesh vertices per frame'):
        sample_output: SMPLOutput = smplx_model.forward(
            betas=betas,
            jaw_pose=pred_smplx[frame_ind][0:3].unsqueeze_(dim=0),
            leye_pose=pred_smplx[frame_ind][3:6].unsqueeze_(dim=0),
            reye_pose=pred_smplx[frame_ind][6:9].unsqueeze_(dim=0),
            global_orient=pred_smplx[frame_ind][9:12].unsqueeze_(dim=0),
            body_pose=pred_smplx[frame_ind][12:75].unsqueeze_(dim=0),
            left_hand_pose=pred_smplx[frame_ind][75:120].unsqueeze_(dim=0),
            right_hand_pose=pred_smplx[frame_ind][120:165].unsqueeze_(dim=0),
            expression=pred_smplx[frame_ind][165:265].unsqueeze_(dim=0),
            return_verts=True,
        )
        vertices.append(sample_output.vertices.detach().cpu().numpy().squeeze())
    vertices = np.asarray(vertices)

    print(vertices.shape)

    # * debug Render
    exp_dir = 'exp/speech2smplx'
    os.makedirs(exp_dir, exist_ok=True)
    num_frames = vertices.shape[0]

    # the dataset/render coordinate conventions are inverted; flip Y and Z
    vertices = vertices.reshape(vertices.shape[0], -1, 3)
    vertices[:, :, 1] = -vertices[:, :, 1]
    vertices[:, :, 2] = -vertices[:, :, 2]

    width, height = 800, 1440
    viewport_height = 1440
    z_offset = 1.8

    video_fname = 'demo'
    os.makedirs(f'{exp_dir}/video_frames', exist_ok=True)

    writer = cv2.VideoWriter(f'{exp_dir}/{video_fname}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 30, (width, height), True)
    center = np.mean(vertices[0], axis=0)

    render_helper = pyrender.OffscreenRenderer(viewport_width=800, viewport_height=viewport_height)

    class Struct(object):
        def __init__(self, **kwargs):
            for key, val in kwargs.items():
                setattr(self, key, val)

    path = os.path.join(os.getcwd(), 'visualise/smplx/SMPLX_NEUTRAL.npz')
    model_data = np.load(path, allow_pickle=True)
    data_struct = Struct(**model_data)

    for i_frame in tqdm.tqdm(range(num_frames), desc='render debug image'):
        vectice = vertices[i_frame]
        # TODO: save per-frame vertices as npz
        # render_mesh_helper is assumed to come from a VOCA/TalkSHOW-style rendering utility; import it per your setup
        imgi = render_mesh_helper((vectice, data_struct.f), center, camera='o', r=render_helper, y=0.7, z_offset=z_offset)
        imgi = imgi.astype(np.uint8)
        # save image as frame
        cv2.imwrite(f'{exp_dir}/video_frames/{i_frame:04d}.png', imgi)
        # save image as video
        writer.write(imgi)
    writer.release()

if __name__ == '__main__':
    main()

@JeremyCJM
Owner

JeremyCJM commented Jun 5, 2024

Hi James, you may want to pay attention to the code here: def extract_pose(self, pose). The order of channels for pose should be carefully aligned with the pose in the visualization code of TalkSHOW.

@jameskuma
Author

> Hi James, you may want to pay attention to the code here: def extract_pose(self, pose). The order of channels for pose should be carefully aligned with the pose in the visualization code of TalkSHOW.

Thank you for the reply!

Yes, so what is the channel order of these output files? I read these npy files and found that they are [n_frame, 232], where 232 is exactly the same as the output dimension of SHOW/TalkSHOW.

The order is important since I need to input them to this function to get the mesh:

pred_smplx = torch.from_numpy(np.load('Forrest_tts.npy')).float()

sample_output: SMPLOutput = smplx_model.forward(
    betas=betas,
    jaw_pose=pred_smplx[0][0:3].unsqueeze_(dim=0),
    leye_pose=pred_smplx[0][3:6].unsqueeze_(dim=0),
    reye_pose=pred_smplx[0][6:9].unsqueeze_(dim=0),
    global_orient=pred_smplx[0][9:12].unsqueeze_(dim=0),
    body_pose=pred_smplx[0][12:75].unsqueeze_(dim=0),
    left_hand_pose=pred_smplx[0][75:120].unsqueeze_(dim=0),
    right_hand_pose=pred_smplx[0][120:165].unsqueeze_(dim=0),
    expression=pred_smplx[0][165:265].unsqueeze_(dim=0),
    return_verts=True,
)
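
For what it's worth, here is a small sanity-check sketch, assuming the 232 channels are the reduced TalkSHOW layout that part2full expands back to the full 265-dim SMPL-X vector (jaw 0:3, eyes 3:9, global orient 9:12, body 12:75, hands 75:165, expression 165:265 after expansion):

import numpy as np
import torch
from data_utils.lower_body import part2full  # TalkSHOW helper that re-inserts the constant channels

pred = torch.from_numpy(np.load('Forrest_tts.npy')).float().squeeze()
assert pred.shape[-1] == 232, pred.shape   # reduced layout: no eyes, global orient, or lower body

full = part2full(pred, stand=True)
assert full.shape[-1] == 265, full.shape   # jaw(3) + eyes(6) + orient(3) + body(63) + hands(90) + exp(100)

# These slices should now line up with the smplx_model.forward(...) call above.
jaw, leye, reye = full[:, 0:3], full[:, 3:6], full[:, 6:9]
global_orient, body_pose = full[:, 9:12], full[:, 12:75]
left_hand, right_hand, expression = full[:, 75:120], full[:, 120:165], full[:, 165:265]

If either assertion fails, the file uses a different channel layout and needs to be reordered before calling part2full.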

@Mumuwei

Mumuwei commented Jun 6, 2024

> (quoting @jameskuma's reply above)

Hello, did you render the result correctly?

@JeremyCJM
Owner

Hi @jameskuma, this is my code to visualize the SHOW results, which is modified from the visualization code in TalkSHOW. Remember to specify the face_path and gesture_path arguments.

import os
import sys

# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
sys.path.append(os.getcwd())

from transformers import Wav2Vec2Processor
from glob import glob

import numpy as np
import json
import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

import time


def init_model(model_name, model_path, args, config):
    if model_name == 's2g_face':
        generator = s2g_face(
            args,
            config,
        )
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(
            args,
            config,
        )
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(
            args,
            config,
        )
    else:
        raise NotImplementedError

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])

    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

    return generator


def init_dataloader(data_root, speakers, args, config):
    if data_root.endswith('.csv'):
        raise NotImplementedError
    else:
        data_class = torch_data
    if 'smplx' in config.Model.model_name or 's2g' in config.Model.model_name:
        data_base = torch_data(
            data_root=data_root,
            speakers=speakers,
            split='test',
            limbscaling=False,
            normalization=config.Data.pose.normalization,
            norm_method=config.Data.pose.norm_method,
            split_trans_zero=False,
            num_pre_frames=config.Data.pose.pre_pose_length,
            num_generate_length=config.Data.pose.generate_length,
            num_frames=30,
            aud_feat_win_size=config.Data.aud.aud_feat_win_size,
            aud_feat_dim=config.Data.aud.aud_feat_dim,
            feat_method=config.Data.aud.feat_method,
            smplx=True,
            audio_sr=22000,
            convert_to_6d=config.Data.pose.convert_to_6d,
            expression=config.Data.pose.expression,
            config=config
        )
    else:
        data_base = torch_data(
            data_root=data_root,
            speakers=speakers,
            split='val',
            limbscaling=False,
            normalization=config.Data.pose.normalization,
            norm_method=config.Data.pose.norm_method,
            split_trans_zero=False,
            num_pre_frames=config.Data.pose.pre_pose_length,
            aud_feat_win_size=config.Data.aud.aud_feat_win_size,
            aud_feat_dim=config.Data.aud.aud_feat_dim,
            feat_method=config.Data.aud.feat_method
        )
    if config.Data.pose.normalization:
        norm_stats_fn = os.path.join(os.path.dirname(args.model_path), "norm_stats.npy")
        norm_stats = np.load(norm_stats_fn, allow_pickle=True)
        data_base.data_mean = norm_stats[0]
        data_base.data_std = norm_stats[1]
    else:
        norm_stats = None

    data_base.get_dataset()
    infer_set = data_base.all_dataset
    infer_loader = data.DataLoader(data_base.all_dataset, batch_size=1, shuffle=False)

    return infer_set, infer_loader, norm_stats


def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    vertices_list = []
    poses_list = []
    expression = torch.zeros([1, 50])

    for i in result_list:
        vertices = []
        poses = []
        for j in range(i.shape[0]):
            output = smplx_model(betas=betas,
                                 expression=i[j][165:265].unsqueeze_(dim=0) if exp else expression,
                                 jaw_pose=i[j][0:3].unsqueeze_(dim=0),
                                 leye_pose=i[j][3:6].unsqueeze_(dim=0),
                                 reye_pose=i[j][6:9].unsqueeze_(dim=0),
                                 global_orient=i[j][9:12].unsqueeze_(dim=0),
                                 body_pose=i[j][12:75].unsqueeze_(dim=0),
                                 left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
                                 right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
                                 return_verts=True)
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
            # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
            pose = output.body_pose
            poses.append(pose.detach().cpu())
        vertices = np.asarray(vertices)
        vertices_list.append(vertices)
        poses = torch.cat(poses, dim=0)
        poses_list.append(poses)
    if require_pose:
        return vertices_list, poses_list
    else:
        return vertices_list, None


global_orient = torch.tensor([3.0747, -0.0158, -0.0152])


def infer(data_root, g_body, g_face, g_body2, exp_name, infer_loader, infer_set, device, norm_stats, smplx,
          smplx_model, rendertool, args=None, config=None, face_path=None, gesture_path=None):
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = 1
    face = False
    if face:
        body_static = torch.zeros([1, 162], device='cuda')
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)
    stand = False
    j = 0
    gt_0 = None

    face_list = os.listdir(face_path)
    face_list.sort()

    gesture_list = os.listdir(gesture_path)
    gesture_list.sort()

    for idx, bat in enumerate(infer_loader):
        poses_ = bat['poses'].to(torch.float32).to(device)
        if poses_.shape[-1] == 300:
            # import pdb; pdb.set_trace()
            j = j + 1
            if j > 1000:
                continue
            id = bat['speaker'].to('cuda') - 20
            if config.Data.pose.expression:
                expression = bat['expression'].to(device).to(torch.float32)
                poses = torch.cat([poses_, expression], dim=1)
            else:
                poses = poses_
            cur_wav_file = bat['aud_file'][0]
            npy_file_name = 'visualise/video/' + config.Log.name + '/' + \
                        cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1] + '.npy'

            if os.path.exists(npy_file_name):
                continue
            
            betas = bat['betas'][0].to(torch.float64).to('cuda')
            # betas = torch.zeros([1, 300], dtype=torch.float64).to('cuda')
            gt = poses.to('cuda').squeeze().transpose(1, 0)
            if config.Data.pose.normalization: # false
                gt = denormalize(gt, norm_stats[0], norm_stats[1]).squeeze(dim=0)
            if config.Data.pose.convert_to_6d: # false
                if config.Data.pose.expression:
                    gt_exp = gt[:, -100:]
                    gt = gt[:, :-100]

                gt = gt.reshape(gt.shape[0], -1, 6)

                gt = matrix_to_axis_angle(rotation_6d_to_matrix(gt)).reshape(gt.shape[0], -1)
                gt = torch.cat([gt, gt_exp], -1)
            if face: # false
                gt = torch.cat([gt[:, :3], body_static.repeat(gt.shape[0], 1), gt[:, -100:]], dim=-1)

            result_list = [gt]

            # cur_wav_file = '.\\training_data\\1_song_(Vocals).wav'

            ############################ Prediction ############################
            pred_face = np.load(os.path.join(face_path, face_list[idx]))

            pred_face = torch.tensor(pred_face).squeeze().to('cuda')
            pred_jaw = pred_face[:, :3]
            pred_face = pred_face[:, 3:]

            for i in range(num_sample):
                pred_res = np.load(os.path.join(gesture_path,gesture_list[idx]))
                pred = torch.tensor(pred_res).squeeze().to('cuda')

                if pred.shape[0] < pred_face.shape[0]:
                    repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
                    pred = torch.cat([pred, repeat_frame], dim=0)
                else:
                    pred = pred[:pred_face.shape[0], :]


                # pred = torch.cat([pred, pred_face], dim=-1)
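                # reduced layout fed to part2full below: jaw(3) + upper-body gesture + expression(100)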
                pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)

                pred = part2full(pred, stand)


                result_list.append(pred)

            vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

            result_list = [res.to('cpu') for res in result_list]
            dict = np.concatenate(result_list[1:], axis=0)
            file_name = 'visualise/video/' + config.Log.name + '/' + \
                        cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
            np.save(file_name, dict)

            rendertool._render_sequences(cur_wav_file, vertices_list[1:], stand=stand, face=face)


def main():
    parser = parse_args()
    args = parser.parse_args()
    device = torch.device(args.gpu)
    torch.cuda.set_device(device)

    config = load_JsonConfig(args.config_file)

    face_model_name = args.face_model_name
    face_model_path = args.face_model_path
    body_model_name = args.body_model_name
    body_model_path = args.body_model_path
    smplx_path = './visualise/'

    os.environ['smplx_npz_path'] = config.smplx_npz_path
    os.environ['extra_joint_path'] = config.extra_joint_path
    os.environ['j14_regressor_path'] = config.j14_regressor_path

    print('init model...')
    generator = init_model(body_model_name, body_model_path, args, config)
    generator2 = None
    generator_face = init_model(face_model_name, face_model_path, args, config)
    print('init dataloader...')
    infer_set, infer_loader, norm_stats = init_dataloader(config.Data.data_root, args.speakers, args, config)

    print('init smplx model...')
    dtype = torch.float64
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        # gender='ne',
                        dtype=dtype, )
    smplx_model = smpl.create(**model_params).to('cuda')
    

    if args.rename != None:
        config.Log.name = args.rename
    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)
    
    infer(config.Data.data_root, generator, generator_face, generator2, args.exp_name, infer_loader, infer_set, device,
          norm_stats, True, smplx_model, rendertool, args, config, face_path=args.face_path, gesture_path=args.gesture_path)


if __name__ == '__main__':
    main()
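
A sketch of how the two arguments might be populated, assuming DiffSHEG saved per-clip .npy files into face/ and gesture/ subfolders of the run directory (the folder names are assumptions; the exact layout varies between the paths used in this thread):

import os

# Hypothetical run directory; substitute your own DiffSHEG result folder.
result_root = 'results/talkshow_88/test_custom_audio/.../pid_1'

face_path = os.path.join(result_root, 'face')        # assumed: jaw(3) + expression(100) per clip
gesture_path = os.path.join(result_root, 'gesture')  # assumed: upper-body gesture per clip

print(sorted(os.listdir(face_path)))
print(sorted(os.listdir(gesture_path)))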

@TashvikDhamija

Hello, I get similar results to @jameskuma.

I tried to check whether there is a mismatch between the DiffSHEG output parameters and the SHOW SMPL-X model inputs, but everything seems okay. Has anyone found the right way to render the SHOW results?

@JeremyCJM I tried running your code, but I cannot figure out what face_path and gesture_path are, since the DiffSHEG model only gives one npy output. Also, I am not quite sure why it creates a dataset and loader for the whole TalkSHOW dataset while inferring a single output. Can you help me use your code for a single inference from the .npy output DiffSHEG gives?

Any help in visualising would be appreciated!

Forrest_tts_diffsheg_show.mp4

Here is my code:

import os
import sys
os.environ["PYOPENGL_PLATFORM"] = "egl"
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
sys.path.append(os.getcwd())

from transformers import Wav2Vec2Processor
from glob import glob

import numpy as np
import json
import smplx as smpl

from nets import *
from trainer.options import parse_args
from data_utils import torch_data
from trainer.config import load_JsonConfig

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from data_utils.rotation_conversion import rotation_6d_to_matrix, matrix_to_axis_angle
from data_utils.lower_body import part2full, pred2poses, poses2pred, poses2poses
from visualise.rendering import RenderTool

global device
device = 'cpu'

def init_model(model_name, model_path, args, config):
    if model_name == 's2g_face':
        generator = s2g_face(
            args,
            config,
        )
    elif model_name == 's2g_body_vq':
        generator = s2g_body_vq(
            args,
            config,
        )
    elif model_name == 's2g_body_pixel':
        generator = s2g_body_pixel(
            args,
            config,
        )
    elif model_name == 's2g_LS3DCG':
        generator = LS3DCG(
            args,
            config,
        )
    else:
        raise NotImplementedError

    model_ckpt = torch.load(model_path, map_location=torch.device('cpu'))
    if model_name == 'smplx_S2G':
        generator.generator.load_state_dict(model_ckpt['generator']['generator'])

    elif 'generator' in list(model_ckpt.keys()):
        generator.load_state_dict(model_ckpt['generator'])
    else:
        model_ckpt = {'generator': model_ckpt}
        generator.load_state_dict(model_ckpt)

    return generator

def get_vertices(smplx_model, betas, result_list, exp, require_pose=False):
    vertices_list = []
    poses_list = []
    expression = torch.zeros([1, 50])

    for i in result_list:
        vertices = []
        poses = []
        for j in range(i.shape[0]):
            output = smplx_model(betas=betas,
                                 expression=i[j][165:265].unsqueeze_(dim=0),
                                 jaw_pose=i[j][0:3].unsqueeze_(dim=0),
                                 leye_pose=i[j][3:6].unsqueeze_(dim=0),
                                 reye_pose=i[j][6:9].unsqueeze_(dim=0),
                                 global_orient=i[j][9:12].unsqueeze_(dim=0),
                                 body_pose=i[j][12:75].unsqueeze_(dim=0),
                                 left_hand_pose=i[j][75:120].unsqueeze_(dim=0),
                                 right_hand_pose=i[j][120:165].unsqueeze_(dim=0),
                                 return_verts=True)
            vertices.append(output.vertices.detach().cpu().numpy().squeeze())
            # pose = torch.cat([output.body_pose, output.left_hand_pose, output.right_hand_pose], dim=1)
            pose = output.body_pose
            poses.append(pose.detach().cpu())
        vertices = np.asarray(vertices)
        vertices_list.append(vertices)
        poses = torch.cat(poses, dim=0)
        poses_list.append(poses)
    if require_pose:
        return vertices_list, poses_list
    else:
        return vertices_list, None


global_orient = torch.tensor([3.0747, -0.0158, -0.0152])


def infer(g_body, g_face, smplx_model, rendertool, config, args):
    betas = torch.zeros([1, 300], dtype=torch.float64).to(device)
    am = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-phoneme")
    am_sr = 16000
    num_sample = args.num_sample
    cur_wav_file = args.audio_file
    id = args.id
    face = args.only_face
    stand = args.stand
    if face:
        body_static = torch.zeros([1, 162], device=device)
        body_static[:, 6:9] = torch.tensor([3.0747, -0.0158, -0.0152]).reshape(1, 3).repeat(body_static.shape[0], 1)

    # result_list = []

    # pred_face = g_face.infer_on_audio(cur_wav_file,
    #                                   initial_pose=None,
    #                                   norm_stats=None,
    #                                   w_pre=False,
    #                                   # id=id,
    #                                   frame=None,
    #                                   am=am,
    #                                   am_sr=am_sr
    #                                   )
    # pred_face = torch.tensor(pred_face).squeeze().to(device)
    # # pred_face = torch.zeros([gt.shape[0], 105])

    # if config.Data.pose.convert_to_6d:
    #     pred_jaw = pred_face[:, :6].reshape(pred_face.shape[0], -1, 6)
    #     pred_jaw = matrix_to_axis_angle(rotation_6d_to_matrix(pred_jaw)).reshape(pred_face.shape[0], -1)
    #     pred_face = pred_face[:, 6:]
    # else:
    #     pred_jaw = pred_face[:, :3]
    #     pred_face = pred_face[:, 3:]

    # id = torch.tensor([id], device=device)

    # for i in range(num_sample):
    #     pred_res = g_body.infer_on_audio(cur_wav_file,
    #                                      initial_pose=None,
    #                                      norm_stats=None,
    #                                      txgfile=None,
    #                                      id=id,
    #                                      var=None,
    #                                      fps=30,
    #                                      w_pre=False
    #                                      )
    #     pred = torch.tensor(pred_res).squeeze().to(device)

    #     if pred.shape[0] < pred_face.shape[0]:
    #         repeat_frame = pred[-1].unsqueeze(dim=0).repeat(pred_face.shape[0] - pred.shape[0], 1)
    #         pred = torch.cat([pred, repeat_frame], dim=0)
    #     else:
    #         pred = pred[:pred_face.shape[0], :]

    #     body_or_face = False
    #     if pred.shape[1] < 275:
    #         body_or_face = True
    #     if config.Data.pose.convert_to_6d:
    #         pred = pred.reshape(pred.shape[0], -1, 6)
    #         pred = matrix_to_axis_angle(rotation_6d_to_matrix(pred))
    #         pred = pred.reshape(pred.shape[0], -1)

    #     if config.Model.model_name == 's2g_LS3DCG':
    #         pred = torch.cat([pred[:, :3], pred[:, 103:], pred[:, 3:103]], dim=-1)
    #     else:
    #         pred = torch.cat([pred_jaw, pred, pred_face], dim=-1)

    #     # pred[:, 9:12] = global_orient
    #     pred = part2full(pred, stand)
    #     if face:
    #         pred = torch.cat([pred[:, :3], body_static.repeat(pred.shape[0], 1), pred[:, -100:]], dim=-1)
    #     # result_list[0] = poses2pred(result_list[0], stand)
    #     # if gt_0 is None:
    #     #     gt_0 = gt
    #     # pred = pred2poses(pred, gt_0)
    #     # result_list[0] = poses2poses(result_list[0], gt_0)

    #     result_list.append(pred)

    result_list = torch.from_numpy(np.load('../DiffSHEG/results/talkshow_88/test_custom_audio/talkshow_GesExpr_unify_addHubert_encodeHubert_mdlpIncludeX_condRes_LN_ClsFree/fixStart10/ckpt_e2599_ddim25_lastStepInterp/pid_4/gesture/Forrest_tts.npy'))
    result_list = part2full(result_list[0], stand=True).unsqueeze(0)
    print(result_list.shape)
    vertices_list, _ = get_vertices(smplx_model, betas, result_list, config.Data.pose.expression)

    result_list = [res.to('cpu') for res in result_list]
    dict = np.concatenate(result_list[:], axis=0)
    file_name = 'visualise/video/' + config.Log.name + '/' + \
                cur_wav_file.split('\\')[-1].split('.')[-2].split('/')[-1]
    np.save(file_name, dict)
    rendertool._render_sequences(cur_wav_file, vertices_list, stand=stand, face=face, whole_body=args.whole_body)


def main():
    parser = parse_args()
    args = parser.parse_args()
    # device = torch.device(args.gpu)
    # torch.cuda.set_device(device)


    config = load_JsonConfig(args.config_file)

    face_model_name = args.face_model_name
    face_model_path = args.face_model_path
    body_model_name = args.body_model_name
    body_model_path = args.body_model_path
    smplx_path = './visualise/'

    os.environ['smplx_npz_path'] = config.smplx_npz_path
    os.environ['extra_joint_path'] = config.extra_joint_path
    os.environ['j14_regressor_path'] = config.j14_regressor_path

    print('init model...')
    generator = init_model(body_model_name, body_model_path, args, config)
    generator2 = None
    generator_face = init_model(face_model_name, face_model_path, args, config)

    print('init smplx model...')
    dtype = torch.float64
    model_params = dict(model_path=smplx_path,
                        model_type='smplx',
                        create_global_orient=True,
                        create_body_pose=True,
                        create_betas=True,
                        num_betas=300,
                        create_left_hand_pose=True,
                        create_right_hand_pose=True,
                        use_pca=False,
                        flat_hand_mean=False,
                        create_expression=True,
                        num_expression_coeffs=100,
                        num_pca_comps=12,
                        create_jaw_pose=True,
                        create_leye_pose=True,
                        create_reye_pose=True,
                        create_transl=False,
                        # gender='ne',
                        dtype=dtype, )
    smplx_model = smpl.create(**model_params).to(device)
    print('init rendertool...')
    rendertool = RenderTool('visualise/video/' + config.Log.name)

    infer(generator, generator_face, smplx_model, rendertool, config, args)


if __name__ == '__main__':
    main()
