Add support for half precision in face renderer #704

Open · wants to merge 3 commits into base: main
3 changes: 3 additions & 0 deletions app_sadtalker.py
@@ -65,6 +65,7 @@ def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warp
is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
+half = gr.Checkbox(label="Use half precision")
submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')

with gr.Tabs(elem_id="sadtalker_genearted"):
@@ -78,6 +79,7 @@ def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warp
preprocess_type,
is_still_mode,
enhancer,
+half,
batch_size,
size_of_image,
pose_style
@@ -92,6 +94,7 @@ def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warp
preprocess_type,
is_still_mode,
enhancer,
+half,
batch_size,
size_of_image,
pose_style
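For reference, the Gradio pattern this diff relies on: a Checkbox value arrives in the callback as a plain bool, matched positionally to the `inputs` list, so the new `half` component must sit at the same index as the corresponding parameter of `fn`. A minimal self-contained sketch of that pattern (the names here are hypothetical, not this PR's code):

import gradio as gr

def generate(text, half):
    # the checkbox arrives as a plain bool, in the same position as in `inputs`
    mode = "fp16" if half else "fp32"
    return f"would render '{text}' in {mode}"

with gr.Blocks() as demo:
    text = gr.Textbox(label="input")
    half = gr.Checkbox(label="Use half precision")
    out = gr.Textbox(label="output")
    gr.Button("Generate").click(fn=generate, inputs=[text, half], outputs=[out])

demo.launch()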
1 change: 1 addition & 0 deletions docs/best_practice.md
@@ -16,6 +16,7 @@ Advanced configuration options for `inference.py`:
| ref Mode (pose) | `--ref_pose` | None | A video path from which the head pose is borrowed.
| 3D Mode | `--face3dvis` | False | Needs additional installation. More details on generating the 3D face can be found [here](docs/face3d.md).
| free-view Mode | `--input_yaw`,<br> `--input_pitch`,<br> `--input_roll` | None | Generating a novel-view or free-view 4D talking head from a single image. More details can be found [here](https://github.com/Winfredy/SadTalker#generating-4d-free-view-talking-examples-from-audio-and-a-single-image).
+| half precision | `--half` | False | Use half precision to speed up inference on the face renderer.


### About `--preprocess`
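Given the new table row above, a typical command line enabling the flag would look like the following (a sketch; the audio and image paths are placeholders, all other options keep their defaults):

python inference.py --driven_audio <audio>.wav --source_image <image>.png --half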
3 changes: 2 additions & 1 deletion inference.py
@@ -37,7 +37,7 @@ def main(args):

audio_to_coeff = Audio2Coeff(sadtalker_paths, device)

-animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device, args.half)

#crop image and extract 3dmm from image
first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
@@ -118,6 +118,7 @@ def main(args):
parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" )
parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" )
parser.add_argument("--half", action="store_true", help="use half precision or not" )


# net structure and parameters
5 changes: 3 additions & 2 deletions src/facerender/animate.py
@@ -32,7 +32,7 @@

class AnimateFromCoeff():

-def __init__(self, sadtalker_path, device):
+def __init__(self, sadtalker_path, device, half=False):

with open(sadtalker_path['facerender_yaml']) as f:
config = yaml.safe_load(f)
@@ -82,6 +82,7 @@ def __init__(self, sadtalker_path, device):
self.mapping.eval()

self.device = device
+self.half = half

def load_cpk_facevid2vid_safetensor(self, checkpoint_path, generator=None,
kp_detector=None, he_estimator=None,
@@ -182,7 +183,7 @@ def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, backgr

predictions_video = make_animation(source_image, source_semantics, target_semantics,
self.generator, self.kp_extractor, self.he_estimator, self.mapping,
-yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
+yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True, use_half=self.half)

predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
predictions_video = predictions_video[:frame_num]
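Callers opt in through the new third constructor argument. A sketch of that wiring, mirroring how inference.py builds the pipeline (`init_path` and its argument order are assumptions lifted from the surrounding code, and the snippet needs the SadTalker checkpoints on disk to actually run):

import torch
from src.utils.init_path import init_path
from src.facerender.animate import AnimateFromCoeff

device = "cuda" if torch.cuda.is_available() else "cpu"
sadtalker_paths = init_path('./checkpoints', './src/config', 256, False, 'crop')
renderer = AnimateFromCoeff(sadtalker_paths, device, half=True)  # fp16 face renderer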
5 changes: 5 additions & 0 deletions src/facerender/modules/make_animation.py
@@ -105,6 +105,7 @@ def make_animation(source_image, source_semantics, target_semantics,
use_exp=True, use_half=False):
with torch.no_grad():
predictions = []
+generator = generator.half() if use_half else generator

kp_canonical = kp_detector(source_image)
he_source = mapping(source_semantics)
@@ -125,6 +126,10 @@
kp_driving = keypoint_transformation(kp_canonical, he_driving)

kp_norm = kp_driving
+if use_half:
+    source_image = source_image.half()
+    kp_source = {k: v.half() for k, v in kp_source.items()}
+    kp_norm = {k: v.half() for k, v in kp_norm.items()}
out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
'''
source_image_new = out['prediction'].squeeze(1)
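The diff applies the standard PyTorch fp16 inference pattern: convert the module's weights to float16, then make sure every tensor entering it has a matching dtype. A self-contained sketch of that pattern with a generic module (not the SadTalker generator):

import torch

model = torch.nn.Linear(16, 16)
x = torch.randn(2, 16)

use_half = torch.cuda.is_available()  # fp16 kernels generally only pay off on GPU
if use_half:
    model = model.half().cuda()   # weights to float16
    x = x.half().cuda()           # inputs must match the weights' dtype

with torch.no_grad():             # inference only; fp16 training would need loss scaling
    y = model(x)

print(y.dtype)  # torch.float16 when use_half, else torch.float32

An alternative design here would be torch.autocast, which keeps weights in fp32 and casts per operation; that would remove the manual .half() calls on source_image, kp_source, and kp_norm, at the cost of slightly different numerics.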
4 changes: 2 additions & 2 deletions src/gradio_demo.py
@@ -34,7 +34,7 @@ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy


def test(self, source_image, driven_audio, preprocess='crop',
-still_mode=False, use_enhancer=False, batch_size=1, size=256,
+still_mode=False, use_enhancer=False, use_half=False, batch_size=1, size=256,
pose_style = 0, exp_scale=1.0,
use_ref_video = False,
ref_video = None,
@@ -48,7 +48,7 @@ def test(self, source_image, driven_audio, preprocess='crop',

self.audio_to_coeff = Audio2Coeff(self.sadtalker_paths, self.device)
self.preprocess_model = CropAndExtract(self.sadtalker_paths, self.device)
-self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device)
+self.animate_from_coeff = AnimateFromCoeff(self.sadtalker_paths, self.device, use_half)

time_tag = str(uuid.uuid4())
save_dir = os.path.join(result_dir, time_tag)