diff --git a/nodes.py b/nodes.py index 40a2ce8..a8700a1 100644 --- a/nodes.py +++ b/nodes.py @@ -39,7 +39,7 @@ StitchingRetargetingNetwork, ) from .liveportrait.utils.camera import get_rotation_matrix -from .liveportrait.utils.crop import _transform_img_kornia +from .liveportrait.utils.crop import _transform_img_kornia, _transform_pts class InferenceConfig: @@ -616,20 +616,22 @@ def INPUT_TYPES(s): ], ), "rotate": ("BOOLEAN", {"default": True}), + "output_masks": ("BOOLEAN", {"default": False}), }, } - RETURN_TYPES = ("IMAGE", "CROPINFO",) - RETURN_NAMES = ("cropped_image", "crop_info",) + RETURN_TYPES = ("IMAGE", "CROPINFO", "MASK") + RETURN_NAMES = ("cropped_image", "crop_info", "mask") FUNCTION = "process" CATEGORY = "LivePortrait" - def process(self, pipeline, cropper, source_image, dsize, scale, vx_ratio, vy_ratio, face_index, face_index_order, rotate): + def process(self, pipeline, cropper, source_image, dsize, scale, vx_ratio, vy_ratio, face_index, face_index_order, rotate, output_masks): source_image_np = (source_image.contiguous() * 255).byte().numpy() # Initialize lists crop_info_list = [] cropped_images_list = [] + masks_list = [] source_info = [] source_rot_list = [] f_s_list = [] @@ -640,7 +642,8 @@ def process(self, pipeline, cropper, source_image, dsize, scale, vx_ratio, vy_ra for i in tqdm(range(len(source_image_np)), desc='Detecting, cropping, and processing..', total=len(source_image_np)): # Cropping operation crop_info, cropped_image_256 = cropper.crop_single_image(source_image_np[i], dsize, scale, vy_ratio, vx_ratio, face_index, face_index_order, rotate) - + if output_masks: + mask = np.zeros(source_image_np[i].shape[:2],dtype=float) # Processing source images if crop_info: crop_info_list.append(crop_info) @@ -661,6 +664,11 @@ def process(self, pipeline, cropper, source_image, dsize, scale, vx_ratio, vy_ra f_s = pipeline.live_portrait_wrapper.extract_feature_3d(I_s) f_s_list.append(f_s) + if output_masks: + pts = np.array([[0,0],[0,dsize],[dsize,dsize],[dsize,0]],dtype=float) + pts = (_transform_pts(pts,crop_info["M_c2o"])+0.5).astype(np.int32) + cv2.fillPoly(mask, pts=[pts], color=1.0) + del I_s else: @@ -671,7 +679,10 @@ def process(self, pipeline, cropper, source_image, dsize, scale, vx_ratio, vy_ra x_s_list.append(None) source_info.append(None) source_rot_list.append(None) - + + if output_masks: + masks_list.append(torch.from_numpy(mask).unsqueeze(0)) + # Update progress bar pbar.update(1) @@ -679,6 +690,8 @@ def process(self, pipeline, cropper, source_image, dsize, scale, vx_ratio, vy_ra torch.stack([torch.from_numpy(np_array) for np_array in cropped_images_list]) / 255 ) + if output_masks: + masks_list = torch.cat(masks_list,axis=0).float() crop_info_dict = { 'crop_info_list': crop_info_list, @@ -688,7 +701,7 @@ def process(self, pipeline, cropper, source_image, dsize, scale, vx_ratio, vy_ra 'source_info': source_info } - return (cropped_tensors_out, crop_info_dict) + return (cropped_tensors_out, crop_info_dict, masks_list) class LivePortraitRetargeting: @classmethod