diff --git a/torchaudio/csrc/ffmpeg/decoder.cpp b/torchaudio/csrc/ffmpeg/decoder.cpp index 4ac2dff8cc3..b9ff2356864 100644 --- a/torchaudio/csrc/ffmpeg/decoder.cpp +++ b/torchaudio/csrc/ffmpeg/decoder.cpp @@ -6,7 +6,11 @@ namespace ffmpeg { //////////////////////////////////////////////////////////////////////////////// // Decoder //////////////////////////////////////////////////////////////////////////////// -Decoder::Decoder(AVCodecParameters* pParam) : pCodecContext(pParam) {} +Decoder::Decoder( + AVCodecParameters* pParam, + const std::string& decoder_name, + const std::map& decoder_option) + : pCodecContext(pParam, decoder_name, decoder_option) {} int Decoder::process_packet(AVPacket* pPacket) { return avcodec_send_packet(pCodecContext, pPacket); diff --git a/torchaudio/csrc/ffmpeg/decoder.h b/torchaudio/csrc/ffmpeg/decoder.h index ba3c11dbc26..ed823c788b5 100644 --- a/torchaudio/csrc/ffmpeg/decoder.h +++ b/torchaudio/csrc/ffmpeg/decoder.h @@ -10,7 +10,10 @@ class Decoder { public: // Default constructable - Decoder(AVCodecParameters* pParam); + Decoder( + AVCodecParameters* pParam, + const std::string& decoder_name, + const std::map& decoder_option); // Custom destructor to clean up the resources ~Decoder() = default; // Non-copyable diff --git a/torchaudio/csrc/ffmpeg/ffmpeg.cpp b/torchaudio/csrc/ffmpeg/ffmpeg.cpp index 7acb92ee0db..98833e8d740 100644 --- a/torchaudio/csrc/ffmpeg/ffmpeg.cpp +++ b/torchaudio/csrc/ffmpeg/ffmpeg.cpp @@ -151,11 +151,22 @@ void AVCodecContextDeleter::operator()(AVCodecContext* p) { }; namespace { -AVCodecContext* get_codec_context(AVCodecParameters* pParams) { - const AVCodec* pCodec = avcodec_find_decoder(pParams->codec_id); +AVCodecContext* get_codec_context( + enum AVCodecID codec_id, + const std::string& decoder_name) { + const AVCodec* pCodec = decoder_name.empty() + ? avcodec_find_decoder(codec_id) + : avcodec_find_decoder_by_name(decoder_name.c_str()); if (!pCodec) { - throw std::runtime_error("Unknown codec."); + std::stringstream ss; + if (decoder_name.empty()) { + ss << "Unsupported codec: \"" << avcodec_get_name(codec_id) << "\", (" + << codec_id << ")."; + } else { + ss << "Unsupported codec: \"" << decoder_name << "\"."; + } + throw std::runtime_error(ss.str()); } AVCodecContext* pCodecContext = avcodec_alloc_context3(pCodec); @@ -167,16 +178,29 @@ AVCodecContext* get_codec_context(AVCodecParameters* pParams) { void init_codec_context( AVCodecContext* pCodecContext, - AVCodecParameters* pParams) { - const AVCodec* pCodec = avcodec_find_decoder(pParams->codec_id); + AVCodecParameters* pParams, + const std::string& decoder_name, + std::map decoder_option) { + const AVCodec* pCodec = decoder_name.empty() + ? avcodec_find_decoder(pParams->codec_id) + : avcodec_find_decoder_by_name(decoder_name.c_str()); + + // No need to check if pCodec is null as it's been already checked in + // get_codec_context if (avcodec_parameters_to_context(pCodecContext, pParams) < 0) { throw std::runtime_error("Failed to set CodecContext parameter."); } - if (avcodec_open2(pCodecContext, pCodec, NULL) < 0) { + AVDictionary* opts = get_option_dict(decoder_option); + if (avcodec_open2(pCodecContext, pCodec, &opts) < 0) { throw std::runtime_error("Failed to initialize CodecContext."); } + auto unused_keys = clean_up_dict(opts); + if (unused_keys.size()) { + throw std::runtime_error( + "Unexpected decoder options: " + join(unused_keys)); + } if (pParams->codec_type == AVMEDIA_TYPE_AUDIO && !pParams->channel_layout) pParams->channel_layout = @@ -184,10 +208,13 @@ void init_codec_context( } } // namespace -AVCodecContextPtr::AVCodecContextPtr(AVCodecParameters* pParam) +AVCodecContextPtr::AVCodecContextPtr( + AVCodecParameters* pParam, + const std::string& decoder_name, + const std::map& decoder_option) : Wrapper( - get_codec_context(pParam)) { - init_codec_context(ptr.get(), pParam); + get_codec_context(pParam->codec_id, decoder_name)) { + init_codec_context(ptr.get(), pParam, decoder_name, decoder_option); } //////////////////////////////////////////////////////////////////////////////// // AVFilterGraph diff --git a/torchaudio/csrc/ffmpeg/ffmpeg.h b/torchaudio/csrc/ffmpeg/ffmpeg.h index 68645d64e77..ed6c581b0a0 100644 --- a/torchaudio/csrc/ffmpeg/ffmpeg.h +++ b/torchaudio/csrc/ffmpeg/ffmpeg.h @@ -118,7 +118,10 @@ struct AVCodecContextDeleter { }; struct AVCodecContextPtr : public Wrapper { - AVCodecContextPtr(AVCodecParameters* pParam); + AVCodecContextPtr( + AVCodecParameters* pParam, + const std::string& decoder, + const std::map& decoder_option); }; //////////////////////////////////////////////////////////////////////////////// diff --git a/torchaudio/csrc/ffmpeg/prototype.cpp b/torchaudio/csrc/ffmpeg/prototype.cpp index 69a0d269910..50db7d18e02 100644 --- a/torchaudio/csrc/ffmpeg/prototype.cpp +++ b/torchaudio/csrc/ffmpeg/prototype.cpp @@ -7,8 +7,10 @@ namespace ffmpeg { namespace { +using OptionDict = c10::Dict; + std::map convert_dict( - const c10::optional>& option) { + const c10::optional& option) { std::map opts; if (option) { for (auto& it : option.value()) { @@ -23,7 +25,7 @@ struct StreamerHolder : torch::CustomClassHolder { StreamerHolder( const std::string& src, c10::optional device, - c10::optional> option) + c10::optional option) : s(src, device.value_or(""), convert_dict(option)) {} }; @@ -32,7 +34,7 @@ using S = c10::intrusive_ptr; S init( const std::string& src, c10::optional device, - c10::optional> option) { + c10::optional option) { return c10::make_intrusive(src, device, option); } @@ -216,7 +218,7 @@ void add_basic_audio_stream( const c10::optional& sample_rate, const c10::optional& dtype) { std::string filter_desc = get_afilter_desc(sample_rate, dtype); - s->s.add_audio_stream(i, frames_per_chunk, num_chunks, filter_desc); + s->s.add_audio_stream(i, frames_per_chunk, num_chunks, filter_desc, "", {}); } void add_basic_video_stream( @@ -229,7 +231,7 @@ void add_basic_video_stream( const c10::optional& height, const c10::optional& format) { std::string filter_desc = get_vfilter_desc(frame_rate, width, height, format); - s->s.add_video_stream(i, frames_per_chunk, num_chunks, filter_desc); + s->s.add_video_stream(i, frames_per_chunk, num_chunks, filter_desc, "", {}); } void add_audio_stream( @@ -237,9 +239,16 @@ void add_audio_stream( int64_t i, int64_t frames_per_chunk, int64_t num_chunks, - const c10::optional& filter_desc) { + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_options) { s->s.add_audio_stream( - i, frames_per_chunk, num_chunks, filter_desc.value_or("")); + i, + frames_per_chunk, + num_chunks, + filter_desc.value_or(""), + decoder.value_or(""), + convert_dict(decoder_options)); } void add_video_stream( @@ -247,9 +256,16 @@ void add_video_stream( int64_t i, int64_t frames_per_chunk, int64_t num_chunks, - const c10::optional& filter_desc) { + const c10::optional& filter_desc, + const c10::optional& decoder, + const c10::optional& decoder_options) { s->s.add_video_stream( - i, frames_per_chunk, num_chunks, filter_desc.value_or("")); + i, + frames_per_chunk, + num_chunks, + filter_desc.value_or(""), + decoder.value_or(""), + convert_dict(decoder_options)); } void remove_stream(S s, int64_t i) { @@ -293,7 +309,7 @@ std::tuple, int64_t> load(const std::string& src) { int i = s.find_best_audio_stream(); auto sinfo = s.get_src_stream_info(i); int64_t sample_rate = static_cast(sinfo.sample_rate); - s.add_audio_stream(i, -1, -1, ""); + s.add_audio_stream(i, -1, -1, "", "", {}); process_all_packets(s); auto tensors = s.pop_chunks(); return std::make_tuple<>(tensors[0], sample_rate); diff --git a/torchaudio/csrc/ffmpeg/stream_processor.cpp b/torchaudio/csrc/ffmpeg/stream_processor.cpp index 33f95e71be7..0e4b91af412 100644 --- a/torchaudio/csrc/ffmpeg/stream_processor.cpp +++ b/torchaudio/csrc/ffmpeg/stream_processor.cpp @@ -6,8 +6,11 @@ namespace ffmpeg { using KeyType = StreamProcessor::KeyType; -StreamProcessor::StreamProcessor(AVCodecParameters* codecpar) - : decoder(codecpar) {} +StreamProcessor::StreamProcessor( + AVCodecParameters* codecpar, + const std::string& decoder_name, + const std::map& decoder_option) + : decoder(codecpar, decoder_name, decoder_option) {} //////////////////////////////////////////////////////////////////////////////// // Configurations diff --git a/torchaudio/csrc/ffmpeg/stream_processor.h b/torchaudio/csrc/ffmpeg/stream_processor.h index 846c5b25c6f..50c7f17633a 100644 --- a/torchaudio/csrc/ffmpeg/stream_processor.h +++ b/torchaudio/csrc/ffmpeg/stream_processor.h @@ -25,7 +25,10 @@ class StreamProcessor { std::map sinks; public: - StreamProcessor(AVCodecParameters* codecpar); + StreamProcessor( + AVCodecParameters* codecpar, + const std::string& decoder_name, + const std::map& decoder_option); ~StreamProcessor() = default; // Non-copyable StreamProcessor(const StreamProcessor&) = delete; diff --git a/torchaudio/csrc/ffmpeg/streamer.cpp b/torchaudio/csrc/ffmpeg/streamer.cpp index bc1c0064ec3..d7bdb1d6770 100644 --- a/torchaudio/csrc/ffmpeg/streamer.cpp +++ b/torchaudio/csrc/ffmpeg/streamer.cpp @@ -156,26 +156,34 @@ void Streamer::add_audio_stream( int i, int frames_per_chunk, int num_chunks, - std::string filter_desc) { + std::string filter_desc, + const std::string& decoder, + const std::map& decoder_option) { add_stream( i, AVMEDIA_TYPE_AUDIO, frames_per_chunk, num_chunks, - std::move(filter_desc)); + std::move(filter_desc), + decoder, + decoder_option); } void Streamer::add_video_stream( int i, int frames_per_chunk, int num_chunks, - std::string filter_desc) { + std::string filter_desc, + std::string decoder, + std::map decoder_option) { add_stream( i, AVMEDIA_TYPE_VIDEO, frames_per_chunk, num_chunks, - std::move(filter_desc)); + std::move(filter_desc), + decoder, + decoder_option); } void Streamer::add_stream( @@ -183,12 +191,15 @@ void Streamer::add_stream( AVMediaType media_type, int frames_per_chunk, int num_chunks, - std::string filter_desc) { + std::string filter_desc, + std::string decoder, + std::map decoder_option) { validate_src_stream_type(i, media_type); AVStream* stream = pFormatContext->streams[i]; stream->discard = AVDISCARD_DEFAULT; if (!processors[i]) - processors[i] = std::make_unique(stream->codecpar); + processors[i] = std::make_unique( + stream->codecpar, decoder, decoder_option); int key = processors[i]->add_stream( stream->time_base, stream->codecpar, diff --git a/torchaudio/csrc/ffmpeg/streamer.h b/torchaudio/csrc/ffmpeg/streamer.h index 7b94514560a..1d2a0ee084e 100644 --- a/torchaudio/csrc/ffmpeg/streamer.h +++ b/torchaudio/csrc/ffmpeg/streamer.h @@ -66,12 +66,16 @@ class Streamer { int i, int frames_per_chunk, int num_chunks, - std::string filter_desc); + std::string filter_desc, + std::string decoder, + std::map decoder_option); void add_video_stream( int i, int frames_per_chunk, int num_chunks, - std::string filter_desc); + std::string filter_desc, + std::string decoder, + std::map decoder_option); void remove_stream(int i); private: @@ -80,7 +84,9 @@ class Streamer { AVMediaType media_type, int frames_per_chunk, int num_chunks, - std::string filter_desc); + std::string filter_desc, + const std::string& decoder, + const std::map& decoder_option); public: ////////////////////////////////////////////////////////////////////////////// diff --git a/torchaudio/prototype/io/streamer.py b/torchaudio/prototype/io/streamer.py index 34e863a6675..c65743070da 100644 --- a/torchaudio/prototype/io/streamer.py +++ b/torchaudio/prototype/io/streamer.py @@ -354,6 +354,8 @@ def add_audio_stream( buffer_chunk_size: int = 3, stream_index: Optional[int] = None, filter_desc: Optional[str] = None, + decoder: Optional[str] = None, + decoder_options: Optional[Dict[str, str]] = None, ): """Add output audio stream @@ -374,10 +376,22 @@ def add_audio_stream( The list of available filters can be found at https://ffmpeg.org/ffmpeg-filters.html Note that complex filters are not supported. + + decoder (str or None, optional): The name of the decoder to be used. + When provided, use the specified decoder instead of the default one. + + decoder_options (dict or None, optional): Options passed to decoder. + Mapping from str to str. """ i = self.default_audio_stream if stream_index is None else stream_index torch.ops.torchaudio.ffmpeg_streamer_add_audio_stream( - self._s, i, frames_per_chunk, buffer_chunk_size, filter_desc + self._s, + i, + frames_per_chunk, + buffer_chunk_size, + filter_desc, + decoder, + decoder_options, ) def add_video_stream( @@ -386,6 +400,8 @@ def add_video_stream( buffer_chunk_size: int = 3, stream_index: Optional[int] = None, filter_desc: Optional[str] = None, + decoder: Optional[str] = None, + decoder_options: Optional[Dict[str, str]] = None, ): """Add output video stream @@ -406,10 +422,22 @@ def add_video_stream( The list of available filters can be found at https://ffmpeg.org/ffmpeg-filters.html Note that complex filters are not supported. + + decoder (str or None, optional): The name of the decoder to be used. + When provided, use the specified decoder instead of the default one. + + decoder_options (dict or None, optional): Options passed to decoder. + Mapping from str to str. """ i = self.default_video_stream if stream_index is None else stream_index torch.ops.torchaudio.ffmpeg_streamer_add_video_stream( - self._s, i, frames_per_chunk, buffer_chunk_size, filter_desc + self._s, + i, + frames_per_chunk, + buffer_chunk_size, + filter_desc, + decoder, + decoder_options, ) def remove_stream(self, i: int):