@@ -146,8 +146,84 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
146146 return frames
147147
148148
class OpenCVDecoder(AbstractDecoder):
    """Benchmark decoder that reads video frames through ``cv2.VideoCapture``.

    ``cv2`` is imported inside ``__init__`` and stored on the instance, so the
    module can be loaded without OpenCV installed as long as this class is not
    instantiated.
    """

    def __init__(self, backend):
        """Store the OpenCV capture backend selected by name.

        Args:
            backend: Key into the supported backends table (currently only
                ``"FFMPEG"``).
        """
        import cv2

        self.cv2 = cv2

        self._available_backends = {"FFMPEG": cv2.CAP_FFMPEG}
        # NOTE(review): an unknown backend name silently becomes None here and
        # is passed to VideoCapture as-is -- confirm that is intended.
        self._backend = self._available_backends.get(backend)

        self._print_each_iteration_time = False

    def _open_capture(self, video_file):
        """Open ``video_file`` and return the capture, or raise ValueError."""
        cap = self.cv2.VideoCapture(video_file, self._backend)
        if not cap.isOpened():
            cap.release()
            raise ValueError("Could not open video stream")
        return cap

    def decode_frames(self, video_file, pts_list):
        """Decode the frames closest to the requested timestamps.

        Each timestamp is mapped to an approximate frame index with
        ``int(pts * fps)``. The stream is read sequentially from the start;
        only frames whose index was requested are decompressed.

        Args:
            video_file: Source passed to ``cv2.VideoCapture``.
            pts_list: Timestamps to decode (multiplied by the reported FPS,
                so presumably seconds).

        Returns:
            A list with one tensor per entry of ``pts_list``, in the same
            order. Timestamps that map to the same frame index share the same
            tensor object.

        Raises:
            ValueError: If the stream cannot be opened, or a requested frame
                cannot be grabbed or retrieved.
        """
        cap = self._open_capture(video_file)
        try:
            fps = cap.get(self.cv2.CAP_PROP_FPS)
            approx_frame_indices = [int(pts * fps) for pts in pts_list]

            # Track distinct indices: several timestamps may fall on the same
            # frame, and each frame can only be retrieved once while scanning.
            pending = set(approx_frame_indices)
            decoded = {}
            current_frame = 0
            while pending:
                if not cap.grab():
                    raise ValueError("Could not grab video frame")
                if current_frame in pending:  # only decompress needed
                    ret, frame = cap.retrieve()
                    if not ret:
                        raise ValueError("Could not retrieve video frame")
                    decoded[current_frame] = self.convert_frame_to_rgb_tensor(frame)
                    pending.discard(current_frame)
                current_frame += 1
        finally:
            # Release the capture even when decoding fails part-way.
            cap.release()

        return [decoded[index] for index in approx_frame_indices]

    def decode_first_n_frames(self, video_file, n):
        """Decode the first ``n`` frames of the stream.

        Args:
            video_file: Source passed to ``cv2.VideoCapture``.
            n: Number of leading frames to decode.

        Returns:
            A list of ``n`` tensors in stream order.

        Raises:
            ValueError: If the stream cannot be opened, or any of the first
                ``n`` frames cannot be grabbed or retrieved.
        """
        cap = self._open_capture(video_file)
        try:
            frames = []
            for _ in range(n):
                if not cap.grab():
                    raise ValueError("Could not grab video frame")
                ret, frame = cap.retrieve()
                if not ret:
                    raise ValueError("Could not retrieve video frame")
                frames.append(self.convert_frame_to_rgb_tensor(frame))
        finally:
            # Release the capture even when decoding fails part-way.
            cap.release()
        return frames

    def decode_and_resize(self, *args, **kwargs):
        """Always raise: resizing is deliberately unsupported for OpenCV.

        Raises:
            ValueError: Unconditionally, with the reason given in the message.
        """
        raise ValueError(
            "OpenCV doesn't apply antialias while pytorch does by default, this is potentially an unfair comparison"
        )

    def convert_frame_to_rgb_tensor(self, frame):
        """Convert a frame returned by ``retrieve()`` into an RGB torch tensor.

        Args:
            frame: Image array in BGR channel order (assumed H, W, C layout --
                TODO confirm against the capture backend).

        Returns:
            A tensor built with ``torch.from_numpy`` after the axes are
            reordered to C, H, W.
        """
        # OpenCV uses BGR, change to RGB
        frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
        # Update to C, H, W
        frame = np.transpose(frame, (2, 0, 1))
        # Convert to tensor
        frame = torch.from_numpy(frame)
        return frame
218+
219+
149220class TorchCodecCore (AbstractDecoder ):
150- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
221+ def __init__ (
222+ self ,
223+ num_threads : str | None = None ,
224+ color_conversion_library = None ,
225+ device = "cpu" ,
226+ ):
151227 self ._num_threads = int (num_threads ) if num_threads else None
152228 self ._color_conversion_library = color_conversion_library
153229 self ._device = device
@@ -185,7 +261,12 @@ def decode_first_n_frames(self, video_file, n):
185261
186262
187263class TorchCodecCoreNonBatch (AbstractDecoder ):
188- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
264+ def __init__ (
265+ self ,
266+ num_threads : str | None = None ,
267+ color_conversion_library = None ,
268+ device = "cpu" ,
269+ ):
189270 self ._num_threads = num_threads
190271 self ._color_conversion_library = color_conversion_library
191272 self ._device = device
@@ -254,7 +335,12 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
254335
255336
256337class TorchCodecCoreBatch (AbstractDecoder ):
257- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
338+ def __init__ (
339+ self ,
340+ num_threads : str | None = None ,
341+ color_conversion_library = None ,
342+ device = "cpu" ,
343+ ):
258344 self ._print_each_iteration_time = False
259345 self ._num_threads = int (num_threads ) if num_threads else None
260346 self ._color_conversion_library = color_conversion_library
@@ -293,10 +379,17 @@ def decode_first_n_frames(self, video_file, n):
293379
294380
295381class TorchCodecPublic (AbstractDecoder ):
296- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "exact" ):
382+ def __init__ (
383+ self ,
384+ num_ffmpeg_threads : str | None = None ,
385+ device = "cpu" ,
386+ seek_mode = "exact" ,
387+ stream_index : str | None = None ,
388+ ):
297389 self ._num_ffmpeg_threads = num_ffmpeg_threads
298390 self ._device = device
299391 self ._seek_mode = seek_mode
392+ self ._stream_index = int (stream_index ) if stream_index else None
300393
301394 from torchvision .transforms import v2 as transforms_v2
302395
@@ -311,6 +404,7 @@ def decode_frames(self, video_file, pts_list):
311404 num_ffmpeg_threads = num_ffmpeg_threads ,
312405 device = self ._device ,
313406 seek_mode = self ._seek_mode ,
407+ stream_index = self ._stream_index ,
314408 )
315409 return decoder .get_frames_played_at (pts_list )
316410
@@ -323,6 +417,7 @@ def decode_first_n_frames(self, video_file, n):
323417 num_ffmpeg_threads = num_ffmpeg_threads ,
324418 device = self ._device ,
325419 seek_mode = self ._seek_mode ,
420+ stream_index = self ._stream_index ,
326421 )
327422 frames = []
328423 count = 0
@@ -342,14 +437,20 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
342437 num_ffmpeg_threads = num_ffmpeg_threads ,
343438 device = self ._device ,
344439 seek_mode = self ._seek_mode ,
440+ stream_index = self ._stream_index ,
345441 )
346442 frames = decoder .get_frames_played_at (pts_list )
347443 frames = self .transforms_v2 .functional .resize (frames .data , (height , width ))
348444 return frames
349445
350446
351447class TorchCodecPublicNonBatch (AbstractDecoder ):
352- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "approximate" ):
448+ def __init__ (
449+ self ,
450+ num_ffmpeg_threads : str | None = None ,
451+ device = "cpu" ,
452+ seek_mode = "approximate" ,
453+ ):
353454 self ._num_ffmpeg_threads = num_ffmpeg_threads
354455 self ._device = device
355456 self ._seek_mode = seek_mode
@@ -452,19 +553,22 @@ def decode_first_n_frames(self, video_file, n):
452553
453554
454555class TorchAudioDecoder (AbstractDecoder ):
455- def __init__ (self ):
556+ def __init__ (self , stream_index : str | None = None ):
456557 import torchaudio # noqa: F401
457558
458559 self .torchaudio = torchaudio
459560
460561 from torchvision .transforms import v2 as transforms_v2
461562
462563 self .transforms_v2 = transforms_v2
564+ self ._stream_index = int (stream_index ) if stream_index else None
463565
464566 def decode_frames (self , video_file , pts_list ):
465567 stream_reader = self .torchaudio .io .StreamReader (src = video_file )
466568 stream_reader .add_basic_video_stream (
467- frames_per_chunk = 1 , decoder_option = {"threads" : "0" }
569+ frames_per_chunk = 1 ,
570+ decoder_option = {"threads" : "0" },
571+ stream_index = self ._stream_index ,
468572 )
469573 frames = []
470574 for pts in pts_list :
@@ -477,7 +581,9 @@ def decode_frames(self, video_file, pts_list):
477581 def decode_first_n_frames (self , video_file , n ):
478582 stream_reader = self .torchaudio .io .StreamReader (src = video_file )
479583 stream_reader .add_basic_video_stream (
480- frames_per_chunk = 1 , decoder_option = {"threads" : "0" }
584+ frames_per_chunk = 1 ,
585+ decoder_option = {"threads" : "0" },
586+ stream_index = self ._stream_index ,
481587 )
482588 frames = []
483589 frame_cnt = 0
@@ -492,7 +598,9 @@ def decode_first_n_frames(self, video_file, n):
492598 def decode_and_resize (self , video_file , pts_list , height , width , device ):
493599 stream_reader = self .torchaudio .io .StreamReader (src = video_file )
494600 stream_reader .add_basic_video_stream (
495- frames_per_chunk = 1 , decoder_option = {"threads" : "1" }
601+ frames_per_chunk = 1 ,
602+ decoder_option = {"threads" : "1" },
603+ stream_index = self ._stream_index ,
496604 )
497605 frames = []
498606 for pts in pts_list :
@@ -745,7 +853,8 @@ def run_benchmarks(
745853 # are using different random pts values across videos.
746854 random_pts_list = (torch .rand (num_samples ) * duration ).tolist ()
747855
748- for decoder_name , decoder in decoder_dict .items ():
856+ # The decoder items are sorted to perform and display the benchmarks in a consistent order.
857+ for decoder_name , decoder in sorted (decoder_dict .items (), key = lambda x : x [0 ]):
749858 print (f"video={ video_file_path } , decoder={ decoder_name } " )
750859
751860 if dataloader_parameters :
0 commit comments