Unify output tensor allocation across all callers and devices #321
@@ -187,18 +187,34 @@ VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
   }
 }
 
-VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
-    int64_t numFrames,
-    const VideoStreamDecoderOptions& options,
-    const StreamMetadata& metadata)
-    : frames(torch::empty(
-          {numFrames,
-           options.height.value_or(*metadata.height),
-           options.width.value_or(*metadata.width),
-           3},
-          at::TensorOptions(options.device).dtype(torch::kUInt8))),
-      ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
+torch::Tensor VideoDecoder::allocateEmptyHWCTensorForStream(
+    int streamIndex,
+    std::optional<int> numFrames) {
+  auto metadata = containerMetadata_.streams[streamIndex];
+  auto options = streams_[streamIndex].options;
+  auto height = options.height.value_or(*metadata.height);
+  auto width = options.width.value_or(*metadata.width);
+
+  auto tensorOptions = torch::TensorOptions()
+                           .dtype(torch::kUInt8)
+                           .layout(torch::kStrided)
+                           .device(options.device.type());
+  if (numFrames.has_value()) {
+    return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
+  } else {
+    return torch::empty({height, width, 3}, tensorOptions);
+  }
+}
+
+VideoDecoder::BatchDecodedOutput VideoDecoder::allocateBatchDecodedOutput(
The previous `BatchDecodedOutput` constructor received the stream options and metadata as parameters; the replacement has to look them up from the decoder's own state (`streams_`, `containerMetadata_`) via the streamIndex. So I had to make the "constructor" a member itself.

I understand that, but by making it a member we are letting it access a lot more information than it strictly needs. It only needs height, width, and device, not all the other stuff in `VideoDecoder`. What's worse is that it can accidentally change the state of the `VideoDecoder`, because it's not `const`. I think it may be better to pass in the height and width after looking them up in the caller.

A general principle I like to apply is that classes should know how to create correct objects, and ideally, constructed objects should be valid. That means classes should have constructors, and the constructor is responsible for making the object valid. So I think we should keep the constructor. We may need a different way to fill in the pre-allocated output, though; we can take advantage of tensors having reasonable move-like semantics (see the sketch after this hunk).
+    int streamIndex,
+    int64_t numFrames) {
+  BatchDecodedOutput output;
+  output.frames = allocateEmptyHWCTensorForStream(streamIndex, numFrames);
+  output.ptsSeconds = torch::empty({numFrames}, {torch::kFloat64});
+  output.durationSeconds = torch::empty({numFrames}, {torch::kFloat64});
+  return output;
+}
 
 VideoDecoder::VideoDecoder() {}
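To illustrate the move-semantics remark in the thread above, here is a minimal hypothetical sketch (not code from this PR): `torch::Tensor` is a cheap reference-counted handle, so populating a default-constructed output after the fact rebinds handles rather than copying frame data.

// Hypothetical sketch, not part of this PR: assigning tensor members after
// construction is cheap because torch::Tensor is a reference-counted handle;
// assignment rebinds the handle rather than copying the underlying data.
#include <torch/torch.h>

struct BatchOutput {
  torch::Tensor frames;
  torch::Tensor ptsSeconds;
};

BatchOutput makeBatchOutput(int64_t numFrames, int height, int width) {
  BatchOutput output;
  output.frames = torch::empty(
      {numFrames, height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8));
  output.ptsSeconds = torch::empty({numFrames}, torch::kFloat64);
  return output; // NRVO/move: no tensor data is copied here.
}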

@@ -841,7 +857,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getDecodedOutputWithFilter(

 VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
     VideoDecoder::RawDecodedOutput& rawOutput,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   // Convert the frame to tensor.
   DecodedOutput output;
   int streamIndex = rawOutput.streamIndex;

@@ -875,7 +891,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 }

 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
-// Callers may pass a pre-allocated tensor, where the output frame tensor will
+// Callers must pass a pre-allocated tensor, where the output frame tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
 // speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
 // decoded frame directly into `preAllocatedtensor.data_ptr()`. We haven't yet
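For context on the note above: an illustrative sketch of how swscale can write straight into externally owned memory, which is why the pre-allocated tensor only yields a speed-up on the swscale path. The function name and setup here are assumptions for illustration, not this PR's code.

// Illustrative sketch only: sws_scale() accepts destination pointers and
// strides, so it can write its output directly into a pre-allocated
// {height, width, 3} uint8 CPU tensor, avoiding an extra copy.
#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

void scaleFrameIntoTensor(SwsContext* ctx, const AVFrame* src, torch::Tensor& dst) {
  // dst is assumed to be a contiguous {height, width, 3} uint8 CPU tensor.
  uint8_t* dstData[4] = {dst.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
  int dstLinesize[4] = {static_cast<int>(dst.size(1)) * 3, 0, 0, 0};
  sws_scale(
      ctx, src->data, src->linesize, 0, src->height, dstData, dstLinesize);
}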

@@ -886,50 +902,25 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     VideoDecoder::RawDecodedOutput& rawOutput,
     DecodedOutput& output,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
-  torch::Tensor tensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      int width = streamInfo.options.width.value_or(frame->width);
-      int height = streamInfo.options.height.value_or(frame->height);
-      if (preAllocatedOutputTensor.has_value()) {
-        tensor = preAllocatedOutputTensor.value();
-        auto shape = tensor.sizes();
-        TORCH_CHECK(
-            (shape.size() == 3) && (shape[0] == height) &&
-                (shape[1] == width) && (shape[2] == 3),
-            "Expected tensor of shape ",
-            height,
-            "x",
-            width,
-            "x3, got ",
-            shape);
-      } else {
-        tensor = torch::empty(
-            {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
-      }
-      rawOutput.data = tensor.data_ptr<uint8_t>();
+      rawOutput.data = preAllocatedOutputTensor.data_ptr<uint8_t>();
       convertFrameToBufferUsingSwsScale(rawOutput);
-
-      output.frame = tensor;
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
-      if (preAllocatedOutputTensor.has_value()) {
-        preAllocatedOutputTensor.value().copy_(tensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = tensor;
-      }
+      auto tmpTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      preAllocatedOutputTensor.copy_(tmpTensor);
     } else {
       throw std::runtime_error(
           "Invalid color conversion library: " +
           std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
     }
+    output.frame = preAllocatedOutputTensor;

   } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
     // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement

@@ -971,8 +962,11 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
         return seconds >= frameStartTime && seconds < frameEndTime;
       });
   // Convert the frame to tensor.
-  auto output = convertAVFrameToDecodedOutput(rawOutput);
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  auto streamIndex = rawOutput.streamIndex;
+  auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
In the light of #312, it would be more correct to use `rawOutput.frame`'s dimensions here, and not the stream's dimensions (see the sketch after this hunk).
+  auto output =
+      convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
+  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }
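A hypothetical sketch of the alternative suggested in the comment above: allocate from the decoded AVFrame's own dimensions rather than the stream-level options/metadata, so that streams with varying frame sizes (#312) are handled correctly. The function name is assumed for illustration, not this PR's code.

// Hypothetical sketch per the #312 comment above: the AVFrame carries its own
// width and height, so allocation can follow the actual decoded frame.
#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
}

torch::Tensor allocateEmptyHWCTensorForFrame(
    const AVFrame* frame,
    const torch::Device& device) {
  auto tensorOptions =
      torch::TensorOptions().dtype(torch::kUInt8).device(device);
  return torch::empty({frame->height, frame->width, 3}, tensorOptions);
}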

@@ -1009,15 +1003,17 @@ void VideoDecoder::validateFrameIndex(
 VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex) {
-  auto output = getFrameAtIndexInternal(streamIndex, frameIndex);
+  auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
+  auto output = getFrameAtIndexInternal(
+      streamIndex, frameIndex, preAllocatedOutputTensor);
   output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }

 VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
     int streamIndex,
     int64_t frameIndex,
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFrameAtIndex");

@@ -1057,7 +1053,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
   const auto& streamMetadata = containerMetadata_.streams[streamIndex];
   const auto& stream = streams_[streamIndex];
   const auto& options = stream.options;
-  BatchDecodedOutput output(frameIndices.size(), options, streamMetadata);
+  BatchDecodedOutput output =
+      allocateBatchDecodedOutput(streamIndex, frameIndices.size());

   auto previousIndexInVideo = -1;
   for (auto f = 0; f < frameIndices.size(); ++f) {

@@ -1149,8 +1146,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
       step > 0, "Step must be greater than 0; is " + std::to_string(step));

   int64_t numOutputFrames = std::ceil((stop - start) / double(step));
-  const auto& options = stream.options;
-  BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
+  BatchDecodedOutput output =
+      allocateBatchDecodedOutput(streamIndex, numOutputFrames);

   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     DecodedOutput singleOut =

@@ -1189,9 +1186,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
           "; must be less than or equal to " + std::to_string(maxSeconds) +
           ").");

-  const auto& stream = streams_[streamIndex];
-  const auto& options = stream.options;
-
   // Special case needed to implement a half-open range. At first glance, this
   // may seem unnecessary, as our search for stopFrame can return the end, and
   // we don't include stopFramIndex in our output. However, consider the

@@ -1210,7 +1204,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // values of the intervals will map to the same frame indices below. Hence,
   // we need this special case below.
   if (startSeconds == stopSeconds) {
-    BatchDecodedOutput output(0, options, streamMetadata);
+    BatchDecodedOutput output = allocateBatchDecodedOutput(streamIndex, 0);
     output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }

@@ -1226,6 +1220,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // 2. In order to establish if the start of an interval maps to a particular
   //    frame, we need to figure out if it is ordered after the frame's pts,
   //    but before the next frames's pts.
+  const auto& stream = streams_[streamIndex];
   auto startFrame = std::lower_bound(
       stream.allFrames.begin(),
       stream.allFrames.end(),

@@ -1245,7 +1240,8 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   int64_t startFrameIndex = startFrame - stream.allFrames.begin();
   int64_t stopFrameIndex = stopFrame - stream.allFrames.begin();
   int64_t numFrames = stopFrameIndex - startFrameIndex;
-  BatchDecodedOutput output(numFrames, options, streamMetadata);
+  BatchDecodedOutput output =
+      allocateBatchDecodedOutput(streamIndex, numFrames);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     DecodedOutput singleOut =
         getFrameAtIndexInternal(streamIndex, i, output.frames[f]);

@@ -1267,13 +1263,17 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
 }

 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
-  auto output = getNextFrameOutputNoDemuxInternal();
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  auto rawOutput = getNextRawDecodedOutputNoDemux();
+  auto streamIndex = rawOutput.streamIndex;
+  auto preAllocatedOutputTensor = allocateEmptyHWCTensorForStream(streamIndex);
Ditto: as above, per #312 it would be more correct to use `rawOutput.frame`'s dimensions here.
+  auto output =
+      convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
+  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }

 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal(
-    std::optional<torch::Tensor> preAllocatedOutputTensor) {
+    torch::Tensor preAllocatedOutputTensor) {
   auto rawOutput = getNextRawDecodedOutputNoDemux();
   return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
 }

@@ -157,8 +157,6 @@ class VideoDecoder {
       int streamIndex,
       const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());

-  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
-
   // ---- SINGLE FRAME SEEK AND DECODING API ----
   // Places the cursor at the first frame on or after the position in seconds.
   // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at

@@ -232,17 +230,16 @@ class VideoDecoder {
   DecodedOutput getFrameAtIndexInternal(
       int streamIndex,
       int64_t frameIndex,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);

   struct BatchDecodedOutput {
     torch::Tensor frames;
     torch::Tensor ptsSeconds;
     torch::Tensor durationSeconds;
-
-    explicit BatchDecodedOutput(
-        int64_t numFrames,
-        const VideoStreamDecoderOptions& options,
-        const StreamMetadata& metadata);
   };
+  BatchDecodedOutput allocateBatchDecodedOutput(
+      int streamIndex,
+      int64_t numFrames);
   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
   BatchDecodedOutput getFramesAtIndices(

@@ -301,6 +298,14 @@ class VideoDecoder {

   double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);

+  // --------------------------------------------------------------------------
+  // Tensor (frames) manipulation APIs
+  // --------------------------------------------------------------------------
+  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
This doesn't need to be in the header at all. It should just be a utility function in an anonymous namespace in the .cpp file; you can pass it the height and width. The same goes for the function below (see the sketch after this hunk).
+
+  torch::Tensor allocateEmptyHWCTensorForStream(
+      int streamIndex,
+      std::optional<int> numFrames = std::nullopt);

  private:
   struct FrameInfo {
     int64_t pts = 0;
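A sketch of the refactor suggested in the comment above, under the assumption that the helper moves into an anonymous namespace in VideoDecoder.cpp (names and exact shape are assumptions, not this PR's code). Because it receives height, width, and device explicitly, it cannot read or mutate VideoDecoder state.

// Hypothetical sketch of the suggestion above: a file-local helper with no
// access to VideoDecoder members; callers look up height/width themselves.
namespace {

torch::Tensor allocateEmptyHWCTensor(
    int height,
    int width,
    const torch::Device& device,
    std::optional<int64_t> numFrames = std::nullopt) {
  auto tensorOptions = torch::TensorOptions()
                           .dtype(torch::kUInt8)
                           .layout(torch::kStrided)
                           .device(device);
  if (numFrames.has_value()) {
    return torch::empty({*numFrames, height, width, 3}, tensorOptions);
  }
  return torch::empty({height, width, 3}, tensorOptions);
}

} // namespace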

@@ -385,14 +390,14 @@ class VideoDecoder {
   void convertFrameToBufferUsingSwsScale(RawDecodedOutput& rawOutput);
   DecodedOutput convertAVFrameToDecodedOutput(
       RawDecodedOutput& rawOutput,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);
   void convertAVFrameToDecodedOutputOnCPU(
       RawDecodedOutput& rawOutput,
       DecodedOutput& output,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);

   DecodedOutput getNextFrameOutputNoDemuxInternal(
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+      torch::Tensor preAllocatedOutputTensor);

   DecoderOptions options_;
   ContainerMetadata containerMetadata_;
Please double check this: to allocate a tensor, we need its height and width. To get those, we need a streamIndex[1]. To get the streamIndex, we need either a RawOutput, or an explicit streamIndex argument from the caller.

[1] Well, not strictly true: there are some cases where it's not needed, but the whole point of this PR is to unify that logic.

In the most generic case this isn't true, because a single stream can have frames with different heights and widths: #312. But ignoring that fact, to get the frame dimensions before decoding you do need the stream_index.