-
Notifications
You must be signed in to change notification settings - Fork 33
Pass pre-allocate tensors in batch APIs to avoid copies #266
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8e06aa6
025bf27
f83ada9
72717bd
291bc87
887ae42
9418cb3
6b3da59
6a2190c
c8f2e79
5113b9c
9387537
bcb4e50
5db658e
e23acb7
96deb24
c2f2e59
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -847,7 +847,8 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter( | |||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( | ||||||||||||||||||||||||||||||||||
VideoDecoder::RawDecodedOutput& rawOutput) { | ||||||||||||||||||||||||||||||||||
VideoDecoder::RawDecodedOutput& rawOutput, | ||||||||||||||||||||||||||||||||||
std::optional<torch::Tensor> preAllocatedOutputTensor) { | ||||||||||||||||||||||||||||||||||
// Convert the frame to tensor. | ||||||||||||||||||||||||||||||||||
DecodedOutput output; | ||||||||||||||||||||||||||||||||||
int streamIndex = rawOutput.streamIndex; | ||||||||||||||||||||||||||||||||||
|
@@ -862,8 +863,10 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( | |||||||||||||||||||||||||||||||||
output.durationSeconds = ptsToSeconds( | ||||||||||||||||||||||||||||||||||
getDuration(frame), formatContext_->streams[streamIndex]->time_base); | ||||||||||||||||||||||||||||||||||
if (streamInfo.options.device.type() == torch::kCPU) { | ||||||||||||||||||||||||||||||||||
convertAVFrameToDecodedOutputOnCPU(rawOutput, output); | ||||||||||||||||||||||||||||||||||
convertAVFrameToDecodedOutputOnCPU( | ||||||||||||||||||||||||||||||||||
rawOutput, output, preAllocatedOutputTensor); | ||||||||||||||||||||||||||||||||||
} else if (streamInfo.options.device.type() == torch::kCUDA) { | ||||||||||||||||||||||||||||||||||
// TODO: handle pre-allocated output tensor | ||||||||||||||||||||||||||||||||||
convertAVFrameToDecodedOutputOnCuda( | ||||||||||||||||||||||||||||||||||
streamInfo.options.device, | ||||||||||||||||||||||||||||||||||
streamInfo.options, | ||||||||||||||||||||||||||||||||||
|
@@ -879,16 +882,24 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( | ||||||||||||||||||||||||||||||||||
VideoDecoder::RawDecodedOutput& rawOutput, | ||||||||||||||||||||||||||||||||||
DecodedOutput& output) { | ||||||||||||||||||||||||||||||||||
DecodedOutput& output, | ||||||||||||||||||||||||||||||||||
std::optional<torch::Tensor> preAllocatedOutputTensor) { | ||||||||||||||||||||||||||||||||||
int streamIndex = rawOutput.streamIndex; | ||||||||||||||||||||||||||||||||||
AVFrame* frame = rawOutput.frame.get(); | ||||||||||||||||||||||||||||||||||
auto& streamInfo = streams_[streamIndex]; | ||||||||||||||||||||||||||||||||||
if (output.streamType == AVMEDIA_TYPE_VIDEO) { | ||||||||||||||||||||||||||||||||||
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { | ||||||||||||||||||||||||||||||||||
int width = streamInfo.options.width.value_or(frame->width); | ||||||||||||||||||||||||||||||||||
int height = streamInfo.options.height.value_or(frame->height); | ||||||||||||||||||||||||||||||||||
torch::Tensor tensor = torch::empty( | ||||||||||||||||||||||||||||||||||
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); | ||||||||||||||||||||||||||||||||||
torch::Tensor tensor; | ||||||||||||||||||||||||||||||||||
if (preAllocatedOutputTensor.has_value()) { | ||||||||||||||||||||||||||||||||||
// TODO: check shape of preAllocatedOutputTensor? | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should TORCH_CHECK for height, width, shape, etc. here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have this a try thinking it would be a simple assert like assert `shape[-3] == H, W, 3` But it turns out it's not as simple. Some tensors come as HWC while some other come pas HWC. This is because the pre-allocated batched tensors are allocated as such: torchcodec/src/torchcodec/decoders/_core/VideoDecoder.cpp Lines 171 to 186 in c6a0a5a
It then me realize that everything works, but it's pretty magical. We end up doing the I want to fix this as an immediate follow-up if that's OK. I gave it a try here, but it's not trivial and it might be preferable not to overcomplexify this PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm in favor of @NicolasHug suggestion. The logic he points out is legacy from way back when, and it wasn't necessarily throught through in terms of long term maintenance and code health. Always doing it one way, and then permuting as needed on the way out, sounds easier and cleaner. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good to me |
||||||||||||||||||||||||||||||||||
tensor = preAllocatedOutputTensor.value(); | ||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||
int width = streamInfo.options.width.value_or(frame->width); | ||||||||||||||||||||||||||||||||||
int height = streamInfo.options.height.value_or(frame->height); | ||||||||||||||||||||||||||||||||||
tensor = torch::empty( | ||||||||||||||||||||||||||||||||||
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
rawOutput.data = tensor.data_ptr<uint8_t>(); | ||||||||||||||||||||||||||||||||||
convertFrameToBufferUsingSwsScale(rawOutput); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
@@ -981,7 +992,8 @@ void VideoDecoder::validateFrameIndex( | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( | ||||||||||||||||||||||||||||||||||
int streamIndex, | ||||||||||||||||||||||||||||||||||
int64_t frameIndex) { | ||||||||||||||||||||||||||||||||||
int64_t frameIndex, | ||||||||||||||||||||||||||||||||||
std::optional<torch::Tensor> preAllocatedOutputTensor) { | ||||||||||||||||||||||||||||||||||
validateUserProvidedStreamIndex(streamIndex); | ||||||||||||||||||||||||||||||||||
validateScannedAllStreams("getFrameAtIndex"); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
@@ -990,7 +1002,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
int64_t pts = stream.allFrames[frameIndex].pts; | ||||||||||||||||||||||||||||||||||
setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase)); | ||||||||||||||||||||||||||||||||||
return getNextDecodedOutputNoDemux(); | ||||||||||||||||||||||||||||||||||
return getNextDecodedOutputNoDemux(preAllocatedOutputTensor); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( | ||||||||||||||||||||||||||||||||||
|
@@ -1062,8 +1074,10 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( | |||||||||||||||||||||||||||||||||
BatchDecodedOutput output(numOutputFrames, options, streamMetadata); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
for (int64_t i = start, f = 0; i < stop; i += step, ++f) { | ||||||||||||||||||||||||||||||||||
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i); | ||||||||||||||||||||||||||||||||||
output.frames[f] = singleOut.frame; | ||||||||||||||||||||||||||||||||||
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]); | ||||||||||||||||||||||||||||||||||
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { | ||||||||||||||||||||||||||||||||||
output.frames[f] = singleOut.frame; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
output.ptsSeconds[f] = singleOut.ptsSeconds; | ||||||||||||||||||||||||||||||||||
output.durationSeconds[f] = singleOut.durationSeconds; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
@@ -1155,8 +1169,10 @@ VideoDecoder::getFramesDisplayedByTimestampInRange( | |||||||||||||||||||||||||||||||||
int64_t numFrames = stopFrameIndex - startFrameIndex; | ||||||||||||||||||||||||||||||||||
BatchDecodedOutput output(numFrames, options, streamMetadata); | ||||||||||||||||||||||||||||||||||
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { | ||||||||||||||||||||||||||||||||||
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i); | ||||||||||||||||||||||||||||||||||
output.frames[f] = singleOut.frame; | ||||||||||||||||||||||||||||||||||
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]); | ||||||||||||||||||||||||||||||||||
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { | ||||||||||||||||||||||||||||||||||
output.frames[f] = singleOut.frame; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
output.ptsSeconds[f] = singleOut.ptsSeconds; | ||||||||||||||||||||||||||||||||||
output.durationSeconds[f] = singleOut.durationSeconds; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
@@ -1173,9 +1189,10 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { | |||||||||||||||||||||||||||||||||
return rawOutput; | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux() { | ||||||||||||||||||||||||||||||||||
VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux( | ||||||||||||||||||||||||||||||||||
std::optional<torch::Tensor> preAllocatedOutputTensor) { | ||||||||||||||||||||||||||||||||||
auto rawOutput = getNextRawDecodedOutputNoDemux(); | ||||||||||||||||||||||||||||||||||
return convertAVFrameToDecodedOutput(rawOutput); | ||||||||||||||||||||||||||||||||||
return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); | ||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
void VideoDecoder::setCursorPtsInSeconds(double seconds) { | ||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR is a no-op for CUDA devices. I'm leaving-out CUDA pre-allocation because this is strongly tied to #189 and can be treated separately.