@@ -174,11 +174,7 @@ def decode_frames(self, video_file, pts_list):
174174 if current_frame in approx_frame_indices : # only decompress needed
175175 ret , frame = cap .retrieve ()
176176 if ret :
177- # OpenCV uses BGR, change to RGB
178- frame = self .cv2 .cvtColor (frame , self .cv2 .COLOR_BGR2RGB )
179- # Update to C, H, W
180- frame = np .transpose (frame , (2 , 0 , 1 ))
181- frame = torch .from_numpy (frame )
177+ frame = self .convert_frame_to_rgb_tensor (frame )
182178 frames .append (frame )
183179
184180 if len (frames ) == len (approx_frame_indices ):
@@ -200,11 +196,7 @@ def decode_first_n_frames(self, video_file, n):
200196 raise ValueError ("Could not grab video frame" )
201197 ret , frame = cap .retrieve ()
202198 if ret :
203- # OpenCV uses BGR, change to RGB
204- frame = self .cv2 .cvtColor (frame , self .cv2 .COLOR_BGR2RGB )
205- # Update to C, H, W
206- frame = np .transpose (frame , (2 , 0 , 1 ))
207- frame = torch .from_numpy (frame )
199+ frame = self .convert_frame_to_rgb_tensor (frame )
208200 frames .append (frame )
209201 cap .release ()
210202 assert len (frames ) == n
@@ -219,9 +211,23 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
219211 ]
220212 return frames
221213
214+ def convert_frame_to_rgb_tensor (self , frame ):
215+ # OpenCV hands back frames in BGR channel order; swap the channels to RGB.
216+ frame = self .cv2 .cvtColor (frame , self .cv2 .COLOR_BGR2RGB )
217+ # Move the last (channel) axis to the front so the layout becomes C, H, W.
218+ frame = np .transpose (frame , (2 , 0 , 1 ))
219+ # Wrap the NumPy array as a torch tensor and return it to the caller.
220+ frame = torch .from_numpy (frame )
221+ return frame
222+
222223
223224class TorchCodecCore (AbstractDecoder ):
224- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
225+ def __init__ (
226+ self ,
227+ num_threads : str | None = None ,
228+ color_conversion_library = None ,
229+ device = "cpu" ,
230+ ):
225231 self ._num_threads = int (num_threads ) if num_threads else None
226232 self ._color_conversion_library = color_conversion_library
227233 self ._device = device
@@ -259,7 +265,12 @@ def decode_first_n_frames(self, video_file, n):
259265
260266
261267class TorchCodecCoreNonBatch (AbstractDecoder ):
262- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
268+ def __init__ (
269+ self ,
270+ num_threads : str | None = None ,
271+ color_conversion_library = None ,
272+ device = "cpu" ,
273+ ):
263274 self ._num_threads = num_threads
264275 self ._color_conversion_library = color_conversion_library
265276 self ._device = device
@@ -328,7 +339,12 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
328339
329340
330341class TorchCodecCoreBatch (AbstractDecoder ):
331- def __init__ (self , num_threads = None , color_conversion_library = None , device = "cpu" ):
342+ def __init__ (
343+ self ,
344+ num_threads : str | None = None ,
345+ color_conversion_library = None ,
346+ device = "cpu" ,
347+ ):
332348 self ._print_each_iteration_time = False
333349 self ._num_threads = int (num_threads ) if num_threads else None
334350 self ._color_conversion_library = color_conversion_library
@@ -369,10 +385,10 @@ def decode_first_n_frames(self, video_file, n):
369385class TorchCodecPublic (AbstractDecoder ):
370386 def __init__ (
371387 self ,
372- num_ffmpeg_threads = None ,
388+ num_ffmpeg_threads : str | None = None ,
373389 device = "cpu" ,
374390 seek_mode = "exact" ,
375- stream_index = None ,
391+ stream_index : str | None = None ,
376392 ):
377393 self ._num_ffmpeg_threads = num_ffmpeg_threads
378394 self ._device = device
@@ -433,7 +449,12 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
433449
434450
435451class TorchCodecPublicNonBatch (AbstractDecoder ):
436- def __init__ (self , num_ffmpeg_threads = None , device = "cpu" , seek_mode = "approximate" ):
452+ def __init__ (
453+ self ,
454+ num_ffmpeg_threads : str | None = None ,
455+ device = "cpu" ,
456+ seek_mode = "approximate" ,
457+ ):
437458 self ._num_ffmpeg_threads = num_ffmpeg_threads
438459 self ._device = device
439460 self ._seek_mode = seek_mode
@@ -536,7 +557,7 @@ def decode_first_n_frames(self, video_file, n):
536557
537558
538559class TorchAudioDecoder (AbstractDecoder ):
539- def __init__ (self , stream_index = None ):
560+ def __init__ (self , stream_index : str | None = None ):
540561 import torchaudio # noqa: F401
541562
542563 self .torchaudio = torchaudio
0 commit comments