Skip to content

Commit b97e7bf

Browse files
committed
Annotate types for ints passed as str/None, extract opencv conversion to function
1 parent 2e23cfa commit b97e7bf

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 38 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -174,11 +174,7 @@ def decode_frames(self, video_file, pts_list):
174174
if current_frame in approx_frame_indices: # only decompress needed
175175
ret, frame = cap.retrieve()
176176
if ret:
177-
# OpenCV uses BGR, change to RGB
178-
frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
179-
# Update to C, H, W
180-
frame = np.transpose(frame, (2, 0, 1))
181-
frame = torch.from_numpy(frame)
177+
frame = self.convert_frame_to_rgb_tensor(frame)
182178
frames.append(frame)
183179

184180
if len(frames) == len(approx_frame_indices):
@@ -200,11 +196,7 @@ def decode_first_n_frames(self, video_file, n):
200196
raise ValueError("Could not grab video frame")
201197
ret, frame = cap.retrieve()
202198
if ret:
203-
# OpenCV uses BGR, change to RGB
204-
frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
205-
# Update to C, H, W
206-
frame = np.transpose(frame, (2, 0, 1))
207-
frame = torch.from_numpy(frame)
199+
frame = self.convert_frame_to_rgb_tensor(frame)
208200
frames.append(frame)
209201
cap.release()
210202
assert len(frames) == n
@@ -219,9 +211,23 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
219211
]
220212
return frames
221213

214+
def convert_frame_to_rgb_tensor(self, frame):
215+
# OpenCV uses BGR, change to RGB
216+
frame = self.cv2.cvtColor(frame, self.cv2.COLOR_BGR2RGB)
217+
# Update to C, H, W
218+
frame = np.transpose(frame, (2, 0, 1))
219+
# Convert to tensor
220+
frame = torch.from_numpy(frame)
221+
return frame
222+
222223

223224
class TorchCodecCore(AbstractDecoder):
224-
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
225+
def __init__(
226+
self,
227+
num_threads: str | None = None,
228+
color_conversion_library=None,
229+
device="cpu",
230+
):
225231
self._num_threads = int(num_threads) if num_threads else None
226232
self._color_conversion_library = color_conversion_library
227233
self._device = device
@@ -259,7 +265,12 @@ def decode_first_n_frames(self, video_file, n):
259265

260266

261267
class TorchCodecCoreNonBatch(AbstractDecoder):
262-
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
268+
def __init__(
269+
self,
270+
num_threads: str | None = None,
271+
color_conversion_library=None,
272+
device="cpu",
273+
):
263274
self._num_threads = num_threads
264275
self._color_conversion_library = color_conversion_library
265276
self._device = device
@@ -328,7 +339,12 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
328339

329340

330341
class TorchCodecCoreBatch(AbstractDecoder):
331-
def __init__(self, num_threads=None, color_conversion_library=None, device="cpu"):
342+
def __init__(
343+
self,
344+
num_threads: str | None = None,
345+
color_conversion_library=None,
346+
device="cpu",
347+
):
332348
self._print_each_iteration_time = False
333349
self._num_threads = int(num_threads) if num_threads else None
334350
self._color_conversion_library = color_conversion_library
@@ -369,10 +385,10 @@ def decode_first_n_frames(self, video_file, n):
369385
class TorchCodecPublic(AbstractDecoder):
370386
def __init__(
371387
self,
372-
num_ffmpeg_threads=None,
388+
num_ffmpeg_threads: str | None = None,
373389
device="cpu",
374390
seek_mode="exact",
375-
stream_index=None,
391+
stream_index: str | None = None,
376392
):
377393
self._num_ffmpeg_threads = num_ffmpeg_threads
378394
self._device = device
@@ -433,7 +449,12 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
433449

434450

435451
class TorchCodecPublicNonBatch(AbstractDecoder):
436-
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="approximate"):
452+
def __init__(
453+
self,
454+
num_ffmpeg_threads: str | None = None,
455+
device="cpu",
456+
seek_mode="approximate",
457+
):
437458
self._num_ffmpeg_threads = num_ffmpeg_threads
438459
self._device = device
439460
self._seek_mode = seek_mode
@@ -536,7 +557,7 @@ def decode_first_n_frames(self, video_file, n):
536557

537558

538559
class TorchAudioDecoder(AbstractDecoder):
539-
def __init__(self, stream_index=None):
560+
def __init__(self, stream_index: str | None = None):
540561
import torchaudio # noqa: F401
541562

542563
self.torchaudio = torchaudio

0 commit comments

Comments
 (0)