@@ -138,6 +138,24 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
138138 return UniqueCUvideodecoder (decoder, CUvideoDecoderDeleter{});
139139}
140140
141+ cudaVideoCodec validateCodecSupport (AVCodecID codecId) {
142+ switch (codecId) {
143+ case AV_CODEC_ID_H264:
144+ return cudaVideoCodec_H264;
145+ case AV_CODEC_ID_HEVC:
146+ return cudaVideoCodec_HEVC;
147+ // TODONVDEC P0: support more codecs
148+ // case AV_CODEC_ID_AV1: return cudaVideoCodec_AV1;
149+ // case AV_CODEC_ID_MPEG4: return cudaVideoCodec_MPEG4;
150+ // case AV_CODEC_ID_VP8: return cudaVideoCodec_VP8;
151+ // case AV_CODEC_ID_VP9: return cudaVideoCodec_VP9;
152+ // case AV_CODEC_ID_MJPEG: return cudaVideoCodec_JPEG;
153+ default : {
154+ TORCH_CHECK (false , " Unsupported codec type: " , avcodec_get_name (codecId));
155+ }
156+ }
157+ }
158+
141159} // namespace
142160
143161BetaCudaDeviceInterface::BetaCudaDeviceInterface (const torch::Device& device)
@@ -163,29 +181,62 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() {
163181 }
164182}
165183
166- void BetaCudaDeviceInterface::initializeInterface (AVStream* avStream) {
167- torch::Tensor dummyTensorForCudaInitialization = torch::empty (
168- {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
184+ void BetaCudaDeviceInterface::initializeBSF (
185+ const AVCodecParameters* codecPar,
186+ const UniqueDecodingAVFormatContext& avFormatCtx) {
187+ // Setup bit stream filters (BSF):
188+ // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
189+ // This is only needed for some formats, like H264 or HEVC.
169190
170- TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
171- timeBase_ = avStream->time_base ;
191+ TORCH_CHECK (codecPar != nullptr , " codecPar cannot be null" );
192+ TORCH_CHECK (avFormatCtx != nullptr , " AVFormatContext cannot be null" );
193+ TORCH_CHECK (
194+ avFormatCtx->iformat != nullptr ,
195+ " AVFormatContext->iformat cannot be null" );
196+ std::string filterName;
197+
198+ // Matching logic is taken from DALI
199+ switch (codecPar->codec_id ) {
200+ case AV_CODEC_ID_H264: {
201+ const std::string formatName = avFormatCtx->iformat ->long_name
202+ ? avFormatCtx->iformat ->long_name
203+ : " " ;
204+
205+ if (formatName == " QuickTime / MOV" ||
206+ formatName == " FLV (Flash Video)" ||
207+ formatName == " Matroska / WebM" || formatName == " raw H.264 video" ) {
208+ filterName = " h264_mp4toannexb" ;
209+ }
210+ break ;
211+ }
172212
173- const AVCodecParameters* codecpar = avStream->codecpar ;
174- TORCH_CHECK (codecpar != nullptr , " CodecParameters cannot be null" );
213+ case AV_CODEC_ID_HEVC: {
214+ const std::string formatName = avFormatCtx->iformat ->long_name
215+ ? avFormatCtx->iformat ->long_name
216+ : " " ;
175217
176- TORCH_CHECK (
177- // TODONVDEC P0 support more
178- avStream->codecpar ->codec_id == AV_CODEC_ID_H264,
179- " Can only do H264 for now" );
218+ if (formatName == " QuickTime / MOV" ||
219+ formatName == " FLV (Flash Video)" ||
220+ formatName == " Matroska / WebM" || formatName == " raw HEVC video" ) {
221+ filterName = " hevc_mp4toannexb" ;
222+ }
223+ break ;
224+ }
180225
181- // Setup bit stream filters (BSF):
182- // https://ffmpeg.org/doxygen/7.0/group__lavc__bsf.html
183- // This is only needed for some formats, like H264 or HEVC. TODONVDEC P1: For
184- // now we apply BSF unconditionally, but it should be optional and dependent
185- // on codec and container.
186- const AVBitStreamFilter* avBSF = av_bsf_get_by_name (" h264_mp4toannexb" );
226+ default :
227+ // No bitstream filter needed for other codecs
228+ // TODONVDEC P1 MPEG4 will need one!
229+ break ;
230+ }
231+
232+ if (filterName.empty ()) {
233+ // Only initialize BSF if we actually need one
234+ return ;
235+ }
236+
237+ const AVBitStreamFilter* avBSF = av_bsf_get_by_name (filterName.c_str ());
187238 TORCH_CHECK (
188- avBSF != nullptr , " Failed to find h264_mp4toannexb bitstream filter" );
239+ avBSF != nullptr , " Failed to find bitstream filter: " , filterName );
189240
190241 AVBSFContext* avBSFContext = nullptr ;
191242 int retVal = av_bsf_alloc (avBSF, &avBSFContext);
@@ -196,7 +247,7 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
196247
197248 bitstreamFilter_.reset (avBSFContext);
198249
199- retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecpar );
250+ retVal = avcodec_parameters_copy (bitstreamFilter_->par_in , codecPar );
200251 TORCH_CHECK (
201252 retVal >= AVSUCCESS,
202253 " Failed to copy codec parameters: " ,
@@ -207,10 +258,25 @@ void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) {
207258 retVal == AVSUCCESS,
208259 " Failed to initialize bitstream filter: " ,
209260 getFFMPEGErrorStringFromErrorCode (retVal));
261+ }
262+
263+ void BetaCudaDeviceInterface::initializeInterface (
264+ const AVStream* avStream,
265+ const UniqueDecodingAVFormatContext& avFormatCtx) {
266+ torch::Tensor dummyTensorForCudaInitialization = torch::empty (
267+ {1 }, torch::TensorOptions ().dtype (torch::kUInt8 ).device (device_));
268+
269+ TORCH_CHECK (avStream != nullptr , " AVStream cannot be null" );
270+ timeBase_ = avStream->time_base ;
271+
272+ const AVCodecParameters* codecPar = avStream->codecpar ;
273+ TORCH_CHECK (codecPar != nullptr , " CodecParameters cannot be null" );
274+
275+ initializeBSF (codecPar, avFormatCtx);
210276
211277 // Create parser. Default values that aren't obvious are taken from DALI.
212278 CUVIDPARSERPARAMS parserParams = {};
213- parserParams.CodecType = cudaVideoCodec_H264 ;
279+ parserParams.CodecType = validateCodecSupport (codecPar-> codec_id ) ;
214280 parserParams.ulMaxNumDecodeSurfaces = 8 ;
215281 parserParams.ulMaxDisplayDelay = 0 ;
216282 // Callback setup, all are triggered by the parser within a call
0 commit comments