Skip to content

Commit d88cf1a

Browse files
authored
Verify decoder outputs (#728)
1 parent dd44f57 commit d88cf1a

File tree

2 files changed

+147
-74
lines changed

2 files changed

+147
-74
lines changed

benchmarks/decoders/benchmark_decoders.py

Lines changed: 34 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -8,66 +8,18 @@
88
import importlib.resources
99
import os
1010
import platform
11-
import typing
12-
from dataclasses import dataclass, field
1311
from pathlib import Path
1412

1513
import torch
1614

1715
from benchmark_decoders_library import (
18-
AbstractDecoder,
19-
DecordAccurate,
20-
DecordAccurateBatch,
21-
OpenCVDecoder,
16+
decoder_registry,
2217
plot_data,
2318
run_benchmarks,
24-
TorchAudioDecoder,
25-
TorchCodecCore,
26-
TorchCodecCoreBatch,
27-
TorchCodecCoreCompiled,
28-
TorchCodecCoreNonBatch,
29-
TorchCodecPublic,
30-
TorchCodecPublicNonBatch,
31-
TorchVision,
19+
verify_outputs,
3220
)
3321

3422

35-
@dataclass
36-
class DecoderKind:
37-
display_name: str
38-
kind: typing.Type[AbstractDecoder]
39-
default_options: dict[str, str] = field(default_factory=dict)
40-
41-
42-
decoder_registry = {
43-
"decord": DecoderKind("DecordAccurate", DecordAccurate),
44-
"decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch),
45-
"torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore),
46-
"torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch),
47-
"torchcodec_core_nonbatch": DecoderKind(
48-
"TorchCodecCoreNonBatch", TorchCodecCoreNonBatch
49-
),
50-
"torchcodec_core_compiled": DecoderKind(
51-
"TorchCodecCoreCompiled", TorchCodecCoreCompiled
52-
),
53-
"torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic),
54-
"torchcodec_public_nonbatch": DecoderKind(
55-
"TorchCodecPublicNonBatch", TorchCodecPublicNonBatch
56-
),
57-
"torchvision": DecoderKind(
58-
# We don't compare against TorchVision's "pyav" backend because it doesn't support
59-
# accurate seeks.
60-
"TorchVision[backend=video_reader]",
61-
TorchVision,
62-
{"backend": "video_reader"},
63-
),
64-
"torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder),
65-
"opencv": DecoderKind(
66-
"OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"}
67-
),
68-
}
69-
70-
7123
def in_fbcode() -> bool:
7224
return "FB_PAR_RUNTIME_FILES" in os.environ
7325

@@ -148,6 +100,12 @@ def main() -> None:
148100
type=str,
149101
default="benchmarks.png",
150102
)
103+
parser.add_argument(
104+
"--verify-outputs",
105+
help="Verify that the outputs of the decoders are the same",
106+
default=False,
107+
action=argparse.BooleanOptionalAction,
108+
)
151109

152110
args = parser.parse_args()
153111
specified_decoders = set(args.decoders.split(","))
@@ -177,29 +135,32 @@ def main() -> None:
177135
if entry.is_file() and entry.name.endswith(".mp4"):
178136
video_paths.append(entry.path)
179137

180-
results = run_benchmarks(
181-
decoders_to_run,
182-
video_paths,
183-
num_uniform_samples,
184-
num_sequential_frames_from_start=[1, 10, 100],
185-
min_runtime_seconds=args.min_run_seconds,
186-
benchmark_video_creation=args.bm_video_creation,
187-
)
188-
data = {
189-
"experiments": results,
190-
"system_metadata": {
191-
"cpu_count": os.cpu_count(),
192-
"system": platform.system(),
193-
"machine": platform.machine(),
194-
"python_version": str(platform.python_version()),
195-
"cuda": (
196-
torch.cuda.get_device_properties(0).name
197-
if torch.cuda.is_available()
198-
else "not available"
199-
),
200-
},
201-
}
202-
plot_data(data, args.plot_path)
138+
if args.verify_outputs:
139+
verify_outputs(decoders_to_run, video_paths, num_uniform_samples)
140+
else:
141+
results = run_benchmarks(
142+
decoders_to_run,
143+
video_paths,
144+
num_uniform_samples,
145+
num_sequential_frames_from_start=[1, 10, 100],
146+
min_runtime_seconds=args.min_run_seconds,
147+
benchmark_video_creation=args.bm_video_creation,
148+
)
149+
data = {
150+
"experiments": results,
151+
"system_metadata": {
152+
"cpu_count": os.cpu_count(),
153+
"system": platform.system(),
154+
"machine": platform.machine(),
155+
"python_version": str(platform.python_version()),
156+
"cuda": (
157+
torch.cuda.get_device_properties(0).name
158+
if torch.cuda.is_available()
159+
else "not available"
160+
),
161+
},
162+
}
163+
plot_data(data, args.plot_path)
203164

204165

205166
if __name__ == "__main__":

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import abc
22
import json
33
import subprocess
4+
import typing
45
import urllib.request
56
from concurrent.futures import ThreadPoolExecutor, wait
6-
from dataclasses import dataclass
7+
from dataclasses import dataclass, field
78
from itertools import product
89
from pathlib import Path
910

@@ -23,6 +24,7 @@
2324
get_next_frame,
2425
seek_to_pts,
2526
)
27+
from torchcodec._frame import FrameBatch
2628
from torchcodec.decoders import VideoDecoder, VideoStreamMetadata
2729

2830
torch._dynamo.config.cache_size_limit = 100
@@ -824,6 +826,42 @@ def convert_result_to_df_item(
824826
return df_item
825827

826828

829+
@dataclass
830+
class DecoderKind:
831+
display_name: str
832+
kind: typing.Type[AbstractDecoder]
833+
default_options: dict[str, str] = field(default_factory=dict)
834+
835+
836+
decoder_registry = {
837+
"decord": DecoderKind("DecordAccurate", DecordAccurate),
838+
"decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch),
839+
"torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore),
840+
"torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch),
841+
"torchcodec_core_nonbatch": DecoderKind(
842+
"TorchCodecCoreNonBatch", TorchCodecCoreNonBatch
843+
),
844+
"torchcodec_core_compiled": DecoderKind(
845+
"TorchCodecCoreCompiled", TorchCodecCoreCompiled
846+
),
847+
"torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic),
848+
"torchcodec_public_nonbatch": DecoderKind(
849+
"TorchCodecPublicNonBatch", TorchCodecPublicNonBatch
850+
),
851+
"torchvision": DecoderKind(
852+
# We don't compare against TorchVision's "pyav" backend because it doesn't support
853+
# accurate seeks.
854+
"TorchVision[backend=video_reader]",
855+
TorchVision,
856+
{"backend": "video_reader"},
857+
),
858+
"torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder),
859+
"opencv": DecoderKind(
860+
"OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"}
861+
),
862+
}
863+
864+
827865
def run_benchmarks(
828866
decoder_dict: dict[str, AbstractDecoder],
829867
video_files_paths: list[Path],
@@ -986,3 +1024,77 @@ def run_benchmarks(
9861024
compare = benchmark.Compare(results)
9871025
compare.print()
9881026
return df_data
1027+
1028+
1029+
def verify_outputs(decoders_to_run, video_paths, num_samples):
1030+
# Reuse TorchCodecPublic decoder stream_index option, if provided.
1031+
options = decoder_registry["torchcodec_public"].default_options
1032+
if torchcodec_decoder := next(
1033+
(
1034+
decoder
1035+
for name, decoder in decoders_to_run.items()
1036+
if "TorchCodecPublic" in name
1037+
),
1038+
None,
1039+
):
1040+
options["stream_index"] = (
1041+
str(torchcodec_decoder._stream_index)
1042+
if torchcodec_decoder._stream_index is not None
1043+
else ""
1044+
)
1045+
# Create default TorchCodecPublic decoder to use as a baseline
1046+
torchcodec_public_decoder = TorchCodecPublic(**options)
1047+
1048+
# Get frames using each decoder
1049+
for video_file_path in video_paths:
1050+
metadata = get_metadata(video_file_path)
1051+
metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
1052+
print(f"{metadata_label=}")
1053+
1054+
# Generate uniformly spaced PTS
1055+
duration = metadata.duration_seconds
1056+
pts_list = [i * duration / num_samples for i in range(num_samples)]
1057+
1058+
# Get the frames from TorchCodecPublic as the baseline
1059+
torchcodec_public_results = decode_and_adjust_frames(
1060+
torchcodec_public_decoder,
1061+
video_file_path,
1062+
num_samples=num_samples,
1063+
pts_list=pts_list,
1064+
)
1065+
1066+
for decoder_name, decoder in decoders_to_run.items():
1067+
print(f"video={video_file_path}, decoder={decoder_name}")
1068+
1069+
frames = decode_and_adjust_frames(
1070+
decoder,
1071+
video_file_path,
1072+
num_samples=num_samples,
1073+
pts_list=pts_list,
1074+
)
1075+
for f1, f2 in zip(torchcodec_public_results, frames):
1076+
torch.testing.assert_close(f1, f2)
1077+
print(f"Results of baseline TorchCodecPublic and {decoder_name} match!")
1078+
1079+
1080+
def decode_and_adjust_frames(
1081+
decoder, video_file_path, *, num_samples: int, pts_list: list[float]
1082+
):
1083+
frames = []
1084+
# Decode non-sequential frames using decode_frames function
1085+
non_seq_frames = decoder.decode_frames(video_file_path, pts_list)
1086+
# TorchCodec's batch APIs return a FrameBatch, so we need to extract the frames
1087+
if isinstance(non_seq_frames, FrameBatch):
1088+
non_seq_frames = non_seq_frames.data
1089+
frames.extend(non_seq_frames)
1090+
1091+
# Decode sequential frames using decode_first_n_frames function
1092+
seq_frames = decoder.decode_first_n_frames(video_file_path, num_samples)
1093+
if isinstance(seq_frames, FrameBatch):
1094+
seq_frames = seq_frames.data
1095+
frames.extend(seq_frames)
1096+
1097+
# Check and convert frames to C,H,W for consistency with other decoders.
1098+
if frames[0].shape[-1] == 3:
1099+
frames = [frame.permute(-1, *range(frame.dim() - 1)) for frame in frames]
1100+
return frames

0 commit comments

Comments
 (0)