@@ -426,35 +426,36 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
426426 dst = allocateEmptyHWCTensor (frameDims, device_);
427427 }
428428
429- torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex (device_);
430-
431- // Create a CUDA event and attach it to the AVFrame's CUDA stream. That's the
432- // NVDEC stream, i.e. the CUDA stream that the frame was decoded on.
433- // We will be waiting for this event to complete before calling the NPP
434- // functions, to ensure NVDEC has finished decoding the frame before running
435- // the NPP color-conversion.
436- // Note that our code is generic and assumes that the NVDEC's stream can be
437- // arbitrary, but unfortunately we know it's hardcoded to be the default
438- // stream by FFmpeg:
429+ // We need to make sure NVDEC has finished decoding a frame before
430+ // color-converting it with NPP.
431+ // So we make the NPP stream wait for NVDEC to finish.
432+ // If we're in the default CUDA interface, we figure out the NVDEC stream from
433+ // the avFrame's hardware context. But in reality, we know that this stream is
434+ // hardcoded to be the default stream by FFmpeg:
439435 // https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
440- at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream (deviceIndex);
436+ // If we're in the BETA CUDA interface, we know the NVDEC stream was set with
437+ // getCurrentCUDAStream(), so it's the same as the nppStream.
438+ at::cuda::CUDAStream nppStream =
439+ at::cuda::getCurrentCUDAStream (device_.index ());
440+ // We can't create a CUDAStream without assigning it a value so we initialize
441+ // it to the nppStream, which is valid for the BETA interface.
442+ at::cuda::CUDAStream nvdecStream = nppStream;
441443 if (hwFramesCtx) {
442- // TODONVDEC P2 this block won't be hit from the beta interface because
443- // there is no hwFramesCtx, but we should still make sure there's no CUDA
444- // stream sync issue in the beta interface.
444+ // Default interface path
445445 TORCH_CHECK (
446446 hwFramesCtx->device_ctx != nullptr ,
447447 " The AVFrame's hw_frames_ctx does not have a device_ctx. " );
448448 auto cudaDeviceCtx =
449449 static_cast <AVCUDADeviceContext*>(hwFramesCtx->device_ctx ->hwctx );
450450 TORCH_CHECK (cudaDeviceCtx != nullptr , " The hardware context is null" );
451- at::cuda::CUDAEvent nvdecDoneEvent;
452- at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
453- c10::cuda::getStreamFromExternal (cudaDeviceCtx->stream , deviceIndex);
454- nvdecDoneEvent.record (nvdecStream);
455- // Don't start NPP work before NVDEC is done decoding the frame!
456- nvdecDoneEvent.block (nppStream);
451+ nvdecStream = // That's always the default stream. Sad.
452+ c10::cuda::getStreamFromExternal (
453+ cudaDeviceCtx->stream , device_.index ());
457454 }
455+ // Don't start NPP work before NVDEC is done decoding the frame!
456+ at::cuda::CUDAEvent nvdecDoneEvent;
457+ nvdecDoneEvent.record (nvdecStream);
458+ nvdecDoneEvent.block (nppStream);
458459
459460 // Create the NPP context if we haven't yet.
460461 nppCtx_->hStream = nppStream.stream ();
0 commit comments