|
1 | 1 | import abc |
2 | 2 | import json |
3 | 3 | import subprocess |
| 4 | +import typing |
4 | 5 | import urllib.request |
5 | 6 | from concurrent.futures import ThreadPoolExecutor, wait |
6 | | -from dataclasses import dataclass |
| 7 | +from dataclasses import dataclass, field |
7 | 8 | from itertools import product |
8 | 9 | from pathlib import Path |
9 | 10 |
|
|
23 | 24 | get_next_frame, |
24 | 25 | seek_to_pts, |
25 | 26 | ) |
| 27 | +from torchcodec._frame import FrameBatch |
26 | 28 | from torchcodec.decoders import VideoDecoder, VideoStreamMetadata |
27 | 29 |
|
28 | 30 | torch._dynamo.config.cache_size_limit = 100 |
@@ -824,6 +826,42 @@ def convert_result_to_df_item( |
824 | 826 | return df_item |
825 | 827 |
|
826 | 828 |
|
| 829 | +@dataclass |
| 830 | +class DecoderKind: |
| 831 | + display_name: str |
| 832 | + kind: typing.Type[AbstractDecoder] |
| 833 | + default_options: dict[str, str] = field(default_factory=dict) |
| 834 | + |
| 835 | + |
| 836 | +decoder_registry = { |
| 837 | + "decord": DecoderKind("DecordAccurate", DecordAccurate), |
| 838 | + "decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch), |
| 839 | + "torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore), |
| 840 | + "torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch), |
| 841 | + "torchcodec_core_nonbatch": DecoderKind( |
| 842 | + "TorchCodecCoreNonBatch", TorchCodecCoreNonBatch |
| 843 | + ), |
| 844 | + "torchcodec_core_compiled": DecoderKind( |
| 845 | + "TorchCodecCoreCompiled", TorchCodecCoreCompiled |
| 846 | + ), |
| 847 | + "torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic), |
| 848 | + "torchcodec_public_nonbatch": DecoderKind( |
| 849 | + "TorchCodecPublicNonBatch", TorchCodecPublicNonBatch |
| 850 | + ), |
| 851 | + "torchvision": DecoderKind( |
| 852 | + # We don't compare against TorchVision's "pyav" backend because it doesn't support |
| 853 | + # accurate seeks. |
| 854 | + "TorchVision[backend=video_reader]", |
| 855 | + TorchVision, |
| 856 | + {"backend": "video_reader"}, |
| 857 | + ), |
| 858 | + "torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder), |
| 859 | + "opencv": DecoderKind( |
| 860 | + "OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"} |
| 861 | + ), |
| 862 | +} |
| 863 | + |
| 864 | + |
827 | 865 | def run_benchmarks( |
828 | 866 | decoder_dict: dict[str, AbstractDecoder], |
829 | 867 | video_files_paths: list[Path], |
@@ -986,3 +1024,77 @@ def run_benchmarks( |
986 | 1024 | compare = benchmark.Compare(results) |
987 | 1025 | compare.print() |
988 | 1026 | return df_data |
| 1027 | + |
| 1028 | + |
| 1029 | +def verify_outputs(decoders_to_run, video_paths, num_samples): |
| 1030 | + # Reuse TorchCodecPublic decoder stream_index option, if provided. |
| 1031 | + options = decoder_registry["torchcodec_public"].default_options |
| 1032 | + if torchcodec_decoder := next( |
| 1033 | + ( |
| 1034 | + decoder |
| 1035 | + for name, decoder in decoders_to_run.items() |
| 1036 | + if "TorchCodecPublic" in name |
| 1037 | + ), |
| 1038 | + None, |
| 1039 | + ): |
| 1040 | + options["stream_index"] = ( |
| 1041 | + str(torchcodec_decoder._stream_index) |
| 1042 | + if torchcodec_decoder._stream_index is not None |
| 1043 | + else "" |
| 1044 | + ) |
| 1045 | + # Create default TorchCodecPublic decoder to use as a baseline |
| 1046 | + torchcodec_public_decoder = TorchCodecPublic(**options) |
| 1047 | + |
| 1048 | + # Get frames using each decoder |
| 1049 | + for video_file_path in video_paths: |
| 1050 | + metadata = get_metadata(video_file_path) |
| 1051 | + metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps" |
| 1052 | + print(f"{metadata_label=}") |
| 1053 | + |
| 1054 | + # Generate uniformly spaced PTS |
| 1055 | + duration = metadata.duration_seconds |
| 1056 | + pts_list = [i * duration / num_samples for i in range(num_samples)] |
| 1057 | + |
| 1058 | + # Get the frames from TorchCodecPublic as the baseline |
| 1059 | + torchcodec_public_results = decode_and_adjust_frames( |
| 1060 | + torchcodec_public_decoder, |
| 1061 | + video_file_path, |
| 1062 | + num_samples=num_samples, |
| 1063 | + pts_list=pts_list, |
| 1064 | + ) |
| 1065 | + |
| 1066 | + for decoder_name, decoder in decoders_to_run.items(): |
| 1067 | + print(f"video={video_file_path}, decoder={decoder_name}") |
| 1068 | + |
| 1069 | + frames = decode_and_adjust_frames( |
| 1070 | + decoder, |
| 1071 | + video_file_path, |
| 1072 | + num_samples=num_samples, |
| 1073 | + pts_list=pts_list, |
| 1074 | + ) |
| 1075 | + for f1, f2 in zip(torchcodec_public_results, frames): |
| 1076 | + torch.testing.assert_close(f1, f2) |
| 1077 | + print(f"Results of baseline TorchCodecPublic and {decoder_name} match!") |
| 1078 | + |
| 1079 | + |
| 1080 | +def decode_and_adjust_frames( |
| 1081 | + decoder, video_file_path, *, num_samples: int, pts_list: list[float] |
| 1082 | +): |
| 1083 | + frames = [] |
| 1084 | + # Decode non-sequential frames using decode_frames function |
| 1085 | + non_seq_frames = decoder.decode_frames(video_file_path, pts_list) |
| 1086 | + # TorchCodec's batch APIs return a FrameBatch, so we need to extract the frames |
| 1087 | + if isinstance(non_seq_frames, FrameBatch): |
| 1088 | + non_seq_frames = non_seq_frames.data |
| 1089 | + frames.extend(non_seq_frames) |
| 1090 | + |
| 1091 | + # Decode sequential frames using decode_first_n_frames function |
| 1092 | + seq_frames = decoder.decode_first_n_frames(video_file_path, num_samples) |
| 1093 | + if isinstance(seq_frames, FrameBatch): |
| 1094 | + seq_frames = seq_frames.data |
| 1095 | + frames.extend(seq_frames) |
| 1096 | + |
| 1097 | + # Check and convert frames to C,H,W for consistency with other decoders. |
| 1098 | + if frames[0].shape[-1] == 3: |
| 1099 | + frames = [frame.permute(-1, *range(frame.dim() - 1)) for frame in frames] |
| 1100 | + return frames |
0 commit comments