forked from avinashpaliwal/Super-SloMo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
video_to_slomo.py
executable file
·217 lines (174 loc) · 7.68 KB
/
video_to_slomo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
#!/usr/bin/env python3
import argparse
import ctypes
import os
import os.path
import platform
import subprocess
from shutil import rmtree, move

import torch
import torchvision.transforms as transforms
from PIL import Image
from tqdm import tqdm

import dataloader
import model
# Command-line interface. The arguments are parsed as soon as this module is
# executed; the resulting `args` namespace is a module-level global that the
# functions below read directly.
parser = argparse.ArgumentParser()
# The value is joined with "ffmpeg" to build the executable path, so it must
# be a directory, not the path of the binary itself.
parser.add_argument("--ffmpeg_dir", type=str, default="", help='path to the directory containing the ffmpeg executable')
parser.add_argument("--video", type=str, required=True, help='path of video to be converted')
parser.add_argument("--checkpoint", type=str, required=True, help='path of checkpoint for pretrained model')
parser.add_argument("--fps", type=float, default=30, help='specify fps of output video. Default: 30.')
parser.add_argument("--sf", type=int, required=True, help='specify the slomo factor N. This will increase the frames by Nx. Example sf=2 ==> 2x frames')
parser.add_argument("--batch_size", type=int, default=1, help='Specify batch size for faster conversion. This will depend on your cpu/gpu memory. Default: 1')
# The help text used to advertise "output.mp4" although the actual default is
# "output.mkv" (and check() rejects anything that is not .mkv).
parser.add_argument("--output", type=str, default="output.mkv", help='Specify output file name (must be a .mkv file). Default: output.mkv')
args = parser.parse_args()
def check():
    """
    Checks the validity of commandline arguments.

    Reads the module-level `args` namespace (`sf`, `batch_size`, `fps` and
    `output`).

    Parameters
    ----------
        None

    Returns
    -------
        error : string
            Newline-separated error messages, one per failed check, or a
            blank string if every argument is valid.
    """
    problems = []
    if args.sf < 2:
        problems.append("Error: --sf/slomo factor has to be atleast 2")
    if args.batch_size < 1:
        problems.append("Error: --batch_size has to be atleast 1")
    if args.fps < 1:
        problems.append("Error: --fps has to be atleast 1")
    # Test the actual suffix: a plain substring test would also accept names
    # such as "clip.mkv.mp4", which do not select the mkv container.
    if not args.output.endswith(".mkv"):
        problems.append("output needs to have mkv container")
    # Report every problem at once instead of letting later checks overwrite
    # earlier ones.
    return "\n".join(problems)
def extract_frames(video, outDir):
    """
    Converts the `video` to images.

    Runs the ffmpeg executable found under `args.ffmpeg_dir` and writes the
    frames as `outDir/%06d.png`. The command is passed to the OS as an
    argument list (no shell), so paths containing spaces, quotes or shell
    metacharacters reach ffmpeg verbatim.

    Parameters
    ----------
        video : string
            full path to the video file.
        outDir : string
            path to directory to output the extracted images.

    Returns
    -------
        error : string
            Error message if error occurs otherwise blank string.
    """
    ffmpeg = os.path.join(args.ffmpeg_dir, "ffmpeg")
    framePattern = '{}/%06d.png'.format(outDir)
    print('{} -i {} -vsync 0 {}'.format(ffmpeg, video, framePattern))
    try:
        retn = subprocess.run([ffmpeg, "-i", video, "-vsync", "0", framePattern]).returncode
    except OSError:
        # The ffmpeg executable is missing or cannot be started; report it
        # through the same error string as a failed conversion.
        retn = 1
    if retn:
        return "Error converting file:{}. Exiting.".format(video)
    return ""
def create_video(dir):
    """
    Encodes the numbered PNG frames in `dir` into the output video.

    Runs the ffmpeg executable found under `args.ffmpeg_dir`, reading
    `dir/%d.png` at `args.fps` frames per second and writing `args.output`
    with the ffvhuff codec. The command is passed to the OS as an argument
    list (no shell), so paths containing spaces, quotes or shell
    metacharacters reach ffmpeg verbatim.

    Parameters
    ----------
        dir : string
            path to the directory holding the frames to encode.

    Returns
    -------
        error : string
            Error message if error occurs otherwise blank string.
    """
    ffmpeg = os.path.join(args.ffmpeg_dir, "ffmpeg")
    framePattern = '{}/%d.png'.format(dir)
    print('{} -r {} -i {} -vcodec ffvhuff {}'.format(ffmpeg, args.fps, framePattern, args.output))
    try:
        retn = subprocess.run([ffmpeg, "-r", str(args.fps), "-i", framePattern, "-vcodec", "ffvhuff", args.output]).returncode
    except OSError:
        # The ffmpeg executable is missing or cannot be started; report it
        # through the same error string as a failed encode.
        retn = 1
    if retn:
        return "Error creating output video. Exiting."
    return ""
def _make_work_dir():
    """Creates a fresh, hidden scratch directory in the CWD and returns its path."""
    IS_WINDOWS = 'Windows' == platform.system()
    extractionDir = "tmpSuperSloMo"
    if not IS_WINDOWS:
        # Assuming UNIX-like system where "." indicates hidden directories
        extractionDir = "." + extractionDir
    # Start from a clean slate if a previous run left the directory behind.
    if os.path.isdir(extractionDir):
        rmtree(extractionDir)
    os.mkdir(extractionDir)
    if IS_WINDOWS:
        FILE_ATTRIBUTE_HIDDEN = 0x02
        # ctypes.windll only exists on Windows
        ctypes.windll.kernel32.SetFileAttributesW(extractionDir, FILE_ATTRIBUTE_HIDDEN)
    return extractionDir


def _interpolate_frames(extractionPath, outputPath):
    """Writes each source frame plus `args.sf - 1` interpolated frames per pair as numbered PNGs into `outputPath`."""
    # Initialize transforms
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)

    negmean = [x * -1 for x in mean]
    revNormalize = transforms.Normalize(mean=negmean, std=std)

    # Temporary fix for issue #7 https://github.com/avinashpaliwal/Super-SloMo/issues/7 -
    # - Removed per channel mean subtraction for CPU.
    # NOTE(review): `device` is a torch.device, and torch.device objects do
    # not appear to compare equal to plain strings, so this branch is
    # presumably never taken. Switching to `device.type == "cpu"` would
    # change the output on CPU - confirm the intent before touching it.
    if (device == "cpu"):
        transform = transforms.Compose([transforms.ToTensor()])
        TP = transforms.Compose([transforms.ToPILImage()])
    else:
        transform = transforms.Compose([transforms.ToTensor(), normalize])
        TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

    # Load data
    videoFrames = dataloader.Video(root=extractionPath, transform=transform)
    videoFramesloader = torch.utils.data.DataLoader(videoFrames, batch_size=args.batch_size, shuffle=False)

    # Initialize model
    flowComp = model.UNet(6, 4)
    flowComp.to(device)
    for param in flowComp.parameters():
        param.requires_grad = False
    ArbTimeFlowIntrp = model.UNet(20, 5)
    ArbTimeFlowIntrp.to(device)
    for param in ArbTimeFlowIntrp.parameters():
        param.requires_grad = False

    flowBackWarp = model.backWarp(videoFrames.dim[0], videoFrames.dim[1], device)
    flowBackWarp = flowBackWarp.to(device)

    # NOTE(review): torch.load unpickles the file; only pass checkpoints
    # that come from a trusted source.
    dict1 = torch.load(args.checkpoint, map_location='cpu')
    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])
    flowComp.load_state_dict(dict1['state_dictFC'])

    # Interpolate frames
    frameCounter = 1

    with torch.no_grad():
        for frame0, frame1 in tqdm(videoFramesloader):

            I0 = frame0.to(device)
            I1 = frame1.to(device)
            # The last batch can hold fewer pairs than args.batch_size, so
            # index by the real batch length to avoid running off the end.
            batchLen = frame0.shape[0]

            flowOut = flowComp(torch.cat((I0, I1), dim=1))
            F_0_1 = flowOut[:,:2,:,:]
            F_1_0 = flowOut[:,2:,:,:]

            # Save reference frames in output folder. Sample `b` of the batch
            # owns the `args.sf` consecutive file numbers starting at
            # frameCounter + args.sf * b.
            for batchIndex in range(batchLen):
                framePath = os.path.join(outputPath, str(frameCounter + args.sf * batchIndex) + ".png")
                (TP(frame0[batchIndex].detach())).resize(videoFrames.origDim, Image.BILINEAR).save(framePath)
            frameCounter += 1

            # Generate intermediate frames
            for intermediateIndex in range(1, args.sf):
                t = float(intermediateIndex) / args.sf
                temp = -t * (1 - t)
                fCoeff = [temp, t * t, (1 - t) * (1 - t), temp]

                # Initial flows from time t to each end frame, formed as a
                # weighted combination of the two computed flows.
                F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0
                F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0

                g_I0_F_t_0 = flowBackWarp(I0, F_t_0)
                g_I1_F_t_1 = flowBackWarp(I1, F_t_1)

                intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))

                # Channels 0-3 are added to the initial flows as residuals;
                # channel 4 is squashed into a blending weight (presumably
                # the visibility map of the Super SloMo paper).
                F_t_0_f = intrpOut[:, :2, :, :] + F_t_0
                F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1
                V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])
                V_t_1 = 1 - V_t_0

                g_I0_F_t_0_f = flowBackWarp(I0, F_t_0_f)
                g_I1_F_t_1_f = flowBackWarp(I1, F_t_1_f)

                wCoeff = [1 - t, t]

                Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)

                # Save intermediate frame
                for batchIndex in range(batchLen):
                    framePath = os.path.join(outputPath, str(frameCounter + args.sf * batchIndex) + ".png")
                    (TP(Ft_p[batchIndex].cpu().detach())).resize(videoFrames.origDim, Image.BILINEAR).save(framePath)
                frameCounter += 1

            # Set counter accounting for batching of frames
            frameCounter += args.sf * (batchLen - 1)


def main():
    """
    Runs the whole conversion for the parsed command-line `args`.

    Validates the arguments, extracts the input frames with ffmpeg into a
    hidden scratch directory, interpolates the in-between frames, encodes the
    result to `args.output` and removes the scratch directory again.

    Raises
    ------
        SystemExit
            Always. The exit status is 0 on success and 1 when argument
            validation, frame extraction or video creation reports an error.
    """
    # Check if arguments are okay
    error = check()
    if error:
        print(error)
        raise SystemExit(1)

    # Create extraction folder and extract frames
    extractionDir = _make_work_dir()
    try:
        extractionPath = os.path.join(extractionDir, "input")
        outputPath = os.path.join(extractionDir, "output")
        os.mkdir(extractionPath)
        os.mkdir(outputPath)
        error = extract_frames(args.video, extractionPath)
        if not error:
            _interpolate_frames(extractionPath, outputPath)
            # Generate video from interpolated frames
            error = create_video(outputPath)
    finally:
        # Remove temporary files, also when a step failed or raised.
        rmtree(extractionDir)

    if error:
        print(error)
        raise SystemExit(1)
    raise SystemExit(0)
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()