import os
#os.environ['CUDA_VISIBLE_DEVICES'] = "0"
from models.psp import pSp
import torch
import dlib
import cv2
import PIL
import argparse
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms, utils
from argparse import Namespace
from datasets import augmentations
from scripts.align_all_parallel import align_face
from latent_optimization import latent_optimization
from utils.inference_utils import save_image, load_image, visualize, get_video_crop_parameter, tensor2cv2, tensor2label, labelcolormap
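
# Command-line options for the StyleGANEX video editing script.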
class TestOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser(description="StyleGANEX Video Editing")
        self.parser.add_argument("--data_path", type=str, default='./data/390.mp4', help="path of the target image/video")
        self.parser.add_argument("--ckpt", type=str, default='pretrained_models/styleganex_toonify_cartoon.pt', help="path of the saved model")
        self.parser.add_argument("--output_path", type=str, default='./output/', help="path of the output results")
        self.parser.add_argument("--scale_factor", type=float, default=1.0, help="scale of the editing degree")
        self.parser.add_argument("--cpu", action="store_true", help="if true, only use cpu")

    def parse(self):
        self.opt = self.parser.parse_args()
        args = vars(self.opt)
        print('Load options')
        for name, value in sorted(args.items()):
            print('%s: %s' % (str(name), str(value)))
        return self.opt
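
# Example invocation (all values shown are the argparse defaults above):
#   python video_editing.py --data_path ./data/390.mp4 \
#       --ckpt pretrained_models/styleganex_toonify_cartoon.pt --output_path ./output/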

if __name__ == "__main__":
    parser = TestOptions()
    args = parser.parse()
    print('*' * 98)

    device = "cpu" if args.cpu else "cuda"
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
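
    # Load the checkpoint and rebuild the model options it was trained with.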
    ckpt = torch.load(args.ckpt, map_location='cpu')
    opts = ckpt['opts']
    opts['checkpoint_path'] = args.ckpt
    opts['device'] = device
    opts = Namespace(**opts)
    pspex = pSp(opts).to(device).eval()
    pspex.latent_avg = pspex.latent_avg.to(device)
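
    # Editing direction stored in the checkpoint (if any), scaled by
    # --scale_factor to control the editing degree.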
    editing_w = None
    if 'editing_w' in ckpt:
        editing_w = ckpt['editing_w'].clone().to(device)[0:1] * args.scale_factor
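
    # Download dlib's 68-point facial landmark model on first run.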
    modelname = 'pretrained_models/shape_predictor_68_face_landmarks.dat'
    if not os.path.exists(modelname):
        import wget, bz2
        wget.download('http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', modelname + '.bz2')
        zipfile = bz2.BZ2File(modelname + '.bz2')
        data = zipfile.read()
        with open(modelname, 'wb') as f:
            f.write(data)
    landmarkpredictor = dlib.shape_predictor(modelname)

    print('Models loaded successfully!')
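
    # Read the first frame to find the face and compute a fixed crop window
    # that will be applied to every frame.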
    video_path = args.data_path
    video_cap = cv2.VideoCapture(video_path)
    success, frame = video_cap.read()
    assert success, 'Failed to read the first frame of %s' % video_path
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    paras = get_video_crop_parameter(frame, landmarkpredictor)
    assert paras is not None, 'StyleGANEX uses dlib.get_frontal_face_detector but sometimes it fails to detect a face. \
You can try several times or use other videos until a face is detected, \
then switch back to the original video.'
    h, w, top, bottom, left, right, scale = paras
    H, W = int(bottom - top), int(right - left)
    frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
    x1 = transform(frame).unsqueeze(0).to(device)
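
    # Aligned face crop from the first frame; passed to the encoder as its
    # second input (x2) and reused unchanged for every frame.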
    with torch.no_grad():
        x2 = align_face(frame, landmarkpredictor)
        x2 = transform(x2).unsqueeze(dim=0).to(device)
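
    # Name the output after the input file and the checkpoint. A single-frame
    # input is treated as an image, anything longer as a video; output frames
    # come out at 4x the crop resolution (4*W x 4*H).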
    save_name = '%s/%s_%s' % (args.output_path, os.path.basename(video_path).split('.')[0], os.path.basename(args.ckpt).split('.')[0])
    num = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if num == 1:  # input is an image
        save_name = save_name + '.jpg'
    else:  # input is a video
        save_name = save_name + '.mp4'
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        videoWriter = cv2.VideoWriter(save_name, fourcc, video_cap.get(cv2.CAP_PROP_FPS), (4 * W, 4 * H))
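
    # Translate every frame: the first frame (already loaded above) is reused
    # on iteration 0; later iterations read the next frame from the capture.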
    with torch.no_grad():
        for i in tqdm(range(num)):
            if i > 0:
                success, frame = video_cap.read()
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (w, h))[top:bottom, left:right]
                x1 = transform(frame).unsqueeze(0).to(device)
            y_hat = pspex(x1=x1, x2=x2, use_skip=pspex.opts.use_skip, zero_noise=True,
                          resize=False, editing_w=editing_w)
            y_hat = torch.clamp(y_hat, -1, 1)
            if num > 1:
                videoWriter.write(tensor2cv2(y_hat[0].cpu()))

    if num == 1:
        save_image(y_hat[0].cpu(), save_name)
        print('Image editing finished successfully!')
    else:
        videoWriter.release()
        print('Video editing finished successfully!')