- 
                Notifications
    You must be signed in to change notification settings 
- Fork 67
BETA CUDA interface: support for approximate mode and time-based APIs #917
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
78ab058
              b45decc
              316f218
              d0192ec
              515deb5
              13fad10
              eb8de72
              4f7a4fb
              dcf3124
              0ad7370
              aad142e
              2592888
              b5fe9bc
              5605c90
              7494259
              560b376
              88196c5
              2a78b84
              5d194e5
              d1e51b3
              f9c7297
              b7bbfb2
              390fd7c
              f55dcc0
              7e4dd10
              f614846
              aa6e253
              186eaa4
              1cb4890
              c5b32a4
              70873bf
              799f1dd
              8cc80e5
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -35,16 +35,20 @@ static bool g_cuda_beta = registerDeviceInterface( | |||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| static int CUDAAPI | ||||||||||||||||||||||||||||||||||||||||||||
| pfnSequenceCallback(void* pUserData, CUVIDEOFORMAT* videoFormat) { | ||||||||||||||||||||||||||||||||||||||||||||
| BetaCudaDeviceInterface* decoder = | ||||||||||||||||||||||||||||||||||||||||||||
| static_cast<BetaCudaDeviceInterface*>(pUserData); | ||||||||||||||||||||||||||||||||||||||||||||
| auto decoder = static_cast<BetaCudaDeviceInterface*>(pUserData); | ||||||||||||||||||||||||||||||||||||||||||||
| return decoder->streamPropertyChange(videoFormat); | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| static int CUDAAPI | ||||||||||||||||||||||||||||||||||||||||||||
| pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* pPicParams) { | ||||||||||||||||||||||||||||||||||||||||||||
| BetaCudaDeviceInterface* decoder = | ||||||||||||||||||||||||||||||||||||||||||||
| static_cast<BetaCudaDeviceInterface*>(pUserData); | ||||||||||||||||||||||||||||||||||||||||||||
| return decoder->frameReadyForDecoding(pPicParams); | ||||||||||||||||||||||||||||||||||||||||||||
| pfnDecodePictureCallback(void* pUserData, CUVIDPICPARAMS* picParams) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto decoder = static_cast<BetaCudaDeviceInterface*>(pUserData); | ||||||||||||||||||||||||||||||||||||||||||||
| return decoder->frameReadyForDecoding(picParams); | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| static int CUDAAPI | ||||||||||||||||||||||||||||||||||||||||||||
| pfnDisplayPictureCallback(void* pUserData, CUVIDPARSERDISPINFO* dispInfo) { | ||||||||||||||||||||||||||||||||||||||||||||
| auto decoder = static_cast<BetaCudaDeviceInterface*>(pUserData); | ||||||||||||||||||||||||||||||||||||||||||||
| return decoder->frameReadyInDisplayOrder(dispInfo); | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) { | ||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -142,7 +146,7 @@ BetaCudaDeviceInterface::BetaCudaDeviceInterface(const torch::Device& device) | |||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { | ||||||||||||||||||||||||||||||||||||||||||||
| // TODONVDEC P0: we probably need to free the frames that have been decoded by | ||||||||||||||||||||||||||||||||||||||||||||
| // NVDEC but not yet "mapped" - i.e. those that are still in frameBuffer_? | ||||||||||||||||||||||||||||||||||||||||||||
| // NVDEC but not yet "mapped" - i.e. those that are still in readyFrames_? | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| if (decoder_) { | ||||||||||||||||||||||||||||||||||||||||||||
| NVDECCache::getCache(device_.index()) | ||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -218,7 +222,7 @@ void BetaCudaDeviceInterface::initialize(const AVStream* avStream) { | |||||||||||||||||||||||||||||||||||||||||||
| parserParams.pUserData = this; | ||||||||||||||||||||||||||||||||||||||||||||
| parserParams.pfnSequenceCallback = pfnSequenceCallback; | ||||||||||||||||||||||||||||||||||||||||||||
| parserParams.pfnDecodePicture = pfnDecodePictureCallback; | ||||||||||||||||||||||||||||||||||||||||||||
| parserParams.pfnDisplayPicture = nullptr; | ||||||||||||||||||||||||||||||||||||||||||||
| parserParams.pfnDisplayPicture = pfnDisplayPictureCallback; | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| CUresult result = cuvidCreateVideoParser(&videoParser_, &parserParams); | ||||||||||||||||||||||||||||||||||||||||||||
| TORCH_CHECK( | ||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -274,10 +278,6 @@ int BetaCudaDeviceInterface::sendPacket(ReferenceAVPacket& packet) { | |||||||||||||||||||||||||||||||||||||||||||
| cuvidPacket.flags = CUVID_PKT_TIMESTAMP; | ||||||||||||||||||||||||||||||||||||||||||||
| cuvidPacket.timestamp = packet->pts; | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| // Like DALI: store packet PTS in queue to later assign to frames as they | ||||||||||||||||||||||||||||||||||||||||||||
| // come out | ||||||||||||||||||||||||||||||||||||||||||||
| packetsPtsQueue.push(packet->pts); | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||
| // End of stream packet | ||||||||||||||||||||||||||||||||||||||||||||
| cuvidPacket.flags = CUVID_PKT_ENDOFSTREAM; | ||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -329,70 +329,38 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) { | |||||||||||||||||||||||||||||||||||||||||||
| // ready to be decoded, i.e. the parser received all the necessary packets for a | ||||||||||||||||||||||||||||||||||||||||||||
| // given frame. It means we can send that frame to be decoded by the hardware | ||||||||||||||||||||||||||||||||||||||||||||
| // NVDEC decoder by calling cuvidDecodePicture which is non-blocking. | ||||||||||||||||||||||||||||||||||||||||||||
| int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* pPicParams) { | ||||||||||||||||||||||||||||||||||||||||||||
| int BetaCudaDeviceInterface::frameReadyForDecoding(CUVIDPICPARAMS* picParams) { | ||||||||||||||||||||||||||||||||||||||||||||
| if (isFlushing_) { | ||||||||||||||||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| TORCH_CHECK(pPicParams != nullptr, "Invalid picture parameters"); | ||||||||||||||||||||||||||||||||||||||||||||
| TORCH_CHECK(picParams != nullptr, "Invalid picture parameters"); | ||||||||||||||||||||||||||||||||||||||||||||
| TORCH_CHECK(decoder_, "Decoder not initialized before picture decode"); | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| // Send frame to be decoded by NVDEC - non-blocking call. | ||||||||||||||||||||||||||||||||||||||||||||
| CUresult result = cuvidDecodePicture(*decoder_.get(), pPicParams); | ||||||||||||||||||||||||||||||||||||||||||||
| if (result != CUDA_SUCCESS) { | ||||||||||||||||||||||||||||||||||||||||||||
| return 0; // Yes, you're reading that right, 0 mean error. | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
| CUresult result = cuvidDecodePicture(*decoder_.get(), picParams); | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| // The frame was sent to be decoded on the NVDEC hardware. Now we store some | ||||||||||||||||||||||||||||||||||||||||||||
| // relevant info into our frame buffer so that we can retrieve the decoded | ||||||||||||||||||||||||||||||||||||||||||||
| // frame later when receiveFrame() is called. | ||||||||||||||||||||||||||||||||||||||||||||
| // Importantly we need to 'guess' the PTS of that frame. The heuristic we use | ||||||||||||||||||||||||||||||||||||||||||||
| // (like in DALI) is that the frames are ready to be decoded in the same order | ||||||||||||||||||||||||||||||||||||||||||||
| // as the packets were sent to the parser. So we assign the PTS of the frame | ||||||||||||||||||||||||||||||||||||||||||||
| // by popping the PTS of the oldest packet in our packetsPtsQueue (note: | ||||||||||||||||||||||||||||||||||||||||||||
| // oldest doesn't necessarily mean lowest PTS!). | ||||||||||||||||||||||||||||||||||||||||||||
| // Yes, you're reading that right, 0 means error, 1 means success | ||||||||||||||||||||||||||||||||||||||||||||
| return (result == CUDA_SUCCESS); | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| TORCH_CHECK( | ||||||||||||||||||||||||||||||||||||||||||||
| // TODONVDEC P0 the queue may be empty, handle that. | ||||||||||||||||||||||||||||||||||||||||||||
| !packetsPtsQueue.empty(), | ||||||||||||||||||||||||||||||||||||||||||||
| "PTS queue is empty when decoding a frame"); | ||||||||||||||||||||||||||||||||||||||||||||
| int64_t guessedPts = packetsPtsQueue.front(); | ||||||||||||||||||||||||||||||||||||||||||||
| packetsPtsQueue.pop(); | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| // Field values taken from DALI | ||||||||||||||||||||||||||||||||||||||||||||
| CUVIDPARSERDISPINFO dispInfo = {}; | ||||||||||||||||||||||||||||||||||||||||||||
| dispInfo.picture_index = pPicParams->CurrPicIdx; | ||||||||||||||||||||||||||||||||||||||||||||
| dispInfo.progressive_frame = !pPicParams->field_pic_flag; | ||||||||||||||||||||||||||||||||||||||||||||
| dispInfo.top_field_first = pPicParams->bottom_field_flag ^ 1; | ||||||||||||||||||||||||||||||||||||||||||||
| dispInfo.repeat_first_field = 0; | ||||||||||||||||||||||||||||||||||||||||||||
| dispInfo.timestamp = guessedPts; | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| FrameBuffer::Slot* slot = frameBuffer_.findEmptySlot(); | ||||||||||||||||||||||||||||||||||||||||||||
| slot->dispInfo = dispInfo; | ||||||||||||||||||||||||||||||||||||||||||||
| slot->guessedPts = guessedPts; | ||||||||||||||||||||||||||||||||||||||||||||
| slot->occupied = true; | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| return 1; | ||||||||||||||||||||||||||||||||||||||||||||
| int BetaCudaDeviceInterface::frameReadyInDisplayOrder( | ||||||||||||||||||||||||||||||||||||||||||||
| CUVIDPARSERDISPINFO* dispInfo) { | ||||||||||||||||||||||||||||||||||||||||||||
| readyFrames_.push(*dispInfo); | ||||||||||||||||||||||||||||||||||||||||||||
| return 1; // success | ||||||||||||||||||||||||||||||||||||||||||||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To clarify for my understanding, when the  Are the function signatures for this and other callbacks defined somwhere in documentation? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, your understanding is correct! The  torchcodec/src/torchcodec/_core/nvcuvid_include/nvcuvid.h Lines 501 to 509 in 6377dfc 
 
 torchcodec/src/torchcodec/_core/BetaCudaDeviceInterface.cpp Lines 512 to 513 in 6377dfc 
 
 Not in the docs, but in the headers: torchcodec/src/torchcodec/_core/nvcuvid_include/nvcuvid.h Lines 529 to 533 in 6377dfc 
 Strictly speaking, this is the callaback we're defining: torchcodec/src/torchcodec/_core/BetaCudaDeviceInterface.cpp Lines 49 to 53 in 6377dfc 
 It's a pure C function that calls the corresponding method on the Interface object. We have to do this gymnastic because the pure C callbacks have no notion of the Interface object. | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| // Moral equivalent of avcodec_receive_frame(). Here, we look for a decoded | ||||||||||||||||||||||||||||||||||||||||||||
| // frame with the exact desired PTS in our frame buffer. This logic is only | ||||||||||||||||||||||||||||||||||||||||||||
| // valid in exact seek_mode, for now. | ||||||||||||||||||||||||||||||||||||||||||||
| int BetaCudaDeviceInterface::receiveFrame( | ||||||||||||||||||||||||||||||||||||||||||||
| UniqueAVFrame& avFrame, | ||||||||||||||||||||||||||||||||||||||||||||
| int64_t desiredPts) { | ||||||||||||||||||||||||||||||||||||||||||||
| FrameBuffer::Slot* slot = frameBuffer_.findFrameWithExactPts(desiredPts); | ||||||||||||||||||||||||||||||||||||||||||||
| if (slot == nullptr) { | ||||||||||||||||||||||||||||||||||||||||||||
| // Moral equivalent of avcodec_receive_frame(). | ||||||||||||||||||||||||||||||||||||||||||||
| int BetaCudaDeviceInterface::receiveFrame(UniqueAVFrame& avFrame) { | ||||||||||||||||||||||||||||||||||||||||||||
| if (readyFrames_.empty()) { | ||||||||||||||||||||||||||||||||||||||||||||
| // No frame found, instruct caller to try again later after sending more | ||||||||||||||||||||||||||||||||||||||||||||
| // packets. | ||||||||||||||||||||||||||||||||||||||||||||
| return AVERROR(EAGAIN); | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| slot->occupied = false; | ||||||||||||||||||||||||||||||||||||||||||||
| slot->guessedPts = -1; | ||||||||||||||||||||||||||||||||||||||||||||
| CUVIDPARSERDISPINFO dispInfo = readyFrames_.front(); | ||||||||||||||||||||||||||||||||||||||||||||
| readyFrames_.pop(); | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| CUVIDPROCPARAMS procParams = {}; | ||||||||||||||||||||||||||||||||||||||||||||
| CUVIDPARSERDISPINFO dispInfo = slot->dispInfo; | ||||||||||||||||||||||||||||||||||||||||||||
| procParams.progressive_frame = dispInfo.progressive_frame; | ||||||||||||||||||||||||||||||||||||||||||||
| procParams.top_field_first = dispInfo.top_field_first; | ||||||||||||||||||||||||||||||||||||||||||||
| procParams.unpaired_field = dispInfo.repeat_first_field < 0; | ||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -452,7 +420,7 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame( | |||||||||||||||||||||||||||||||||||||||||||
| avFrame->width = width; | ||||||||||||||||||||||||||||||||||||||||||||
| avFrame->height = height; | ||||||||||||||||||||||||||||||||||||||||||||
| avFrame->format = AV_PIX_FMT_CUDA; | ||||||||||||||||||||||||||||||||||||||||||||
| avFrame->pts = dispInfo.timestamp; // == guessedPts | ||||||||||||||||||||||||||||||||||||||||||||
| avFrame->pts = dispInfo.timestamp; | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| // TODONVDEC P0: Zero division error!!! | ||||||||||||||||||||||||||||||||||||||||||||
| // TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the | ||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -518,13 +486,8 @@ void BetaCudaDeviceInterface::flush() { | |||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| isFlushing_ = false; | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| for (auto& slot : frameBuffer_) { | ||||||||||||||||||||||||||||||||||||||||||||
| slot.occupied = false; | ||||||||||||||||||||||||||||||||||||||||||||
| slot.guessedPts = -1; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| std::queue<int64_t> empty; | ||||||||||||||||||||||||||||||||||||||||||||
| packetsPtsQueue.swap(empty); | ||||||||||||||||||||||||||||||||||||||||||||
| std::queue<CUVIDPARSERDISPINFO> emptyQueue; | ||||||||||||||||||||||||||||||||||||||||||||
| std::swap(readyFrames_, emptyQueue); | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| eofSent_ = false; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | @@ -544,26 +507,4 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( | |||||||||||||||||||||||||||||||||||||||||||
| avFrame, frameOutput, preAllocatedOutputTensor); | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| BetaCudaDeviceInterface::FrameBuffer::Slot* | ||||||||||||||||||||||||||||||||||||||||||||
| BetaCudaDeviceInterface::FrameBuffer::findEmptySlot() { | ||||||||||||||||||||||||||||||||||||||||||||
| for (auto& slot : frameBuffer_) { | ||||||||||||||||||||||||||||||||||||||||||||
| if (!slot.occupied) { | ||||||||||||||||||||||||||||||||||||||||||||
| return &slot; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
| frameBuffer_.emplace_back(); | ||||||||||||||||||||||||||||||||||||||||||||
| return &frameBuffer_.back(); | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| BetaCudaDeviceInterface::FrameBuffer::Slot* | ||||||||||||||||||||||||||||||||||||||||||||
| BetaCudaDeviceInterface::FrameBuffer::findFrameWithExactPts( | ||||||||||||||||||||||||||||||||||||||||||||
| int64_t desiredPts) { | ||||||||||||||||||||||||||||||||||||||||||||
| for (auto& slot : frameBuffer_) { | ||||||||||||||||||||||||||||||||||||||||||||
| if (slot.occupied && slot.guessedPts == desiredPts) { | ||||||||||||||||||||||||||||||||||||||||||||
| return &slot; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
| return nullptr; | ||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||
|  | ||||||||||||||||||||||||||||||||||||||||||||
| } // namespace facebook::torchcodec | ||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the key difference, correct? That is, by registering this callback, we get the new behavior and can delete all of the relevant code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes that's correct