Speed up time-based samplers by 20X and index-based samplers by 1.5X #284

Merged: 36 commits, Oct 24, 2024

Commits
823c8a3
Let get_frames_at_indices op return a 3-tuple instead of single Tensor
NicolasHug Oct 22, 2024
61b4937
Add deduplication logic
NicolasHug Oct 22, 2024
f7a70ba
Added sorting logic
NicolasHug Oct 22, 2024
133c213
minor opt
NicolasHug Oct 22, 2024
f391582
Comments
NicolasHug Oct 22, 2024
d475890
scaffolding
NicolasHug Oct 22, 2024
14e2876
Added logic
NicolasHug Oct 22, 2024
9c9e462
Added test
NicolasHug Oct 22, 2024
2bce920
Use C++ decoding APIs in sampler
NicolasHug Oct 22, 2024
b8284cc
Remove parameter, just sort if not already sorted
NicolasHug Oct 22, 2024
be80996
Remove parameter
NicolasHug Oct 22, 2024
12c0e29
Merge branch 'main' of github.com:pytorch/torchcodec into pts_sort_an…
NicolasHug Oct 23, 2024
4dda5b7
Rename
NicolasHug Oct 23, 2024
7d26623
Fix last frame request
NicolasHug Oct 23, 2024
3a8839d
Better fix
NicolasHug Oct 23, 2024
5e114f2
Merge branch 'pts_sort_and_dedup' into samplers_hack
NicolasHug Oct 23, 2024
a76a6ad
Clean up
NicolasHug Oct 23, 2024
1482529
Frame and FrameBatch improvements
NicolasHug Oct 23, 2024
4661237
Fix mypy?
NicolasHug Oct 23, 2024
d43dd91
Use timestamps as parameter name
NicolasHug Oct 23, 2024
1c3a736
Merge branch 'frame_improvements' into samplers_hack
NicolasHug Oct 23, 2024
f2feab9
better
NicolasHug Oct 23, 2024
92b7954
minor refac
NicolasHug Oct 23, 2024
8dd9b0a
Address comments
NicolasHug Oct 23, 2024
046580b
Merge branch 'pts_sort_and_dedup' into samplers_hack
NicolasHug Oct 23, 2024
3268606
fix
NicolasHug Oct 23, 2024
1772d46
Merge branch 'main' of github.com:pytorch/torchcodec into samplers_hack
NicolasHug Oct 23, 2024
c03294b
Fix binary search of getFramesDisplayedByTimestamps
NicolasHug Oct 24, 2024
5ab33b9
Comment
NicolasHug Oct 24, 2024
fa374bc
comment
NicolasHug Oct 24, 2024
dce5876
Merge branch 'main' of github.com:pytorch/torchcodec into samplers_hack
NicolasHug Oct 24, 2024
efa1d81
Merge branch 'fix_pts' into samplers_hack
NicolasHug Oct 24, 2024
c75417b
Nits
NicolasHug Oct 24, 2024
978a996
Merge branch 'main' of github.com:pytorch/torchcodec into samplers_hack
NicolasHug Oct 24, 2024
da59be4
merge fixes
NicolasHug Oct 24, 2024
ad1dd3a
Simplify even further
NicolasHug Oct 24, 2024
36 changes: 18 additions & 18 deletions src/torchcodec/samplers/_common.py
@@ -1,7 +1,7 @@
from typing import Callable, Union

import torch
from torchcodec import Frame, FrameBatch
from torch import Tensor
from torchcodec import FrameBatch

_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]

@@ -42,22 +42,6 @@ def _error_policy(
}


def _chunk_list(lst, chunk_size):
# return list of sublists of length chunk_size
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def _to_framebatch(frames: list[Frame]) -> FrameBatch:
# IMPORTANT: see other IMPORTANT note in _decode_all_clips_indices and
# _decode_all_clips_timestamps
data = torch.stack([frame.data for frame in frames])
pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
return FrameBatch(
data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
)


def _validate_common_params(*, decoder, num_frames_per_clip, policy):
if len(decoder) < 1:
raise ValueError(
@@ -72,3 +56,19 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
raise ValueError(
f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}."
)


def _make_5d_framebatch(
*,
data: Tensor,
pts_seconds: Tensor,
duration_seconds: Tensor,
num_clips: int,
num_frames_per_clip: int,
) -> FrameBatch:
last_3_dims = data.shape[-3:]
return FrameBatch(
data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
)
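
Note on the new helper: _make_5d_framebatch does not decode or copy anything, it only views the flat tensors returned by the batch decoding ops in a clips-first layout. A minimal sketch of the equivalent reshape, with made-up dimensions, assuming NCHW frames and the FrameBatch fields shown above:

import torch
from torchcodec import FrameBatch

# Pretend the core op decoded num_clips * num_frames_per_clip frames in one flat batch.
num_clips, num_frames_per_clip = 4, 3
data = torch.zeros(num_clips * num_frames_per_clip, 3, 270, 480)  # (N, C, H, W)
pts_seconds = torch.zeros(num_clips * num_frames_per_clip)
duration_seconds = torch.zeros(num_clips * num_frames_per_clip)

# Same reshape as _make_5d_framebatch: view the flat batch as
# (num_clips, num_frames_per_clip, C, H, W) without copying.
batch = FrameBatch(
    data=data.view(num_clips, num_frames_per_clip, *data.shape[-3:]),
    pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
    duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
)
assert batch.data.shape == (4, 3, 3, 270, 480)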
74 changes: 19 additions & 55 deletions src/torchcodec/samplers/_index_based.py
@@ -1,14 +1,14 @@
from typing import List, Literal, Optional
from typing import Literal, Optional

import torch

from torchcodec import Frame, FrameBatch
from torchcodec import FrameBatch
from torchcodec.decoders import VideoDecoder
from torchcodec.decoders._core import get_frames_at_indices
from torchcodec.samplers._common import (
_chunk_list,
_make_5d_framebatch,
_POLICY_FUNCTION_TYPE,
_POLICY_FUNCTIONS,
_to_framebatch,
_validate_common_params,
)

@@ -117,51 +117,6 @@ def _build_all_clips_indices(
return all_clips_indices


def _decode_all_clips_indices(
decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int
) -> list[FrameBatch]:
# This takes the list of all the frames to decode (in arbitrary order),
# decode all the frames, and then packs them into clips of length
# num_frames_per_clip.
#
# To avoid backwards seeks (which are slow), we:
# - sort all the frame indices to be decoded
# - dedup them
# - decode all unique frames in sorted order
# - re-assemble the decoded frames back to their original order
#
# TODO: Write this in C++ so we can avoid the copies that happen in `_to_framebatch`

all_clips_indices_sorted, argsort = zip(
*sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices))
)
previous_decoded_frame = None
all_decoded_frames = [None] * len(all_clips_indices)
for i, j in enumerate(argsort):
frame_index = all_clips_indices_sorted[i]
if (
previous_decoded_frame is not None # then we know i > 0
and frame_index == all_clips_indices_sorted[i - 1]
):
# Avoid decoding the same frame twice.
# IMPORTANT: this is only correct because a copy of the frame will
# happen within `_to_framebatch` when we call torch.stack.
# If a copy isn't made, the same underlying memory will be used for
# the 2 consecutive frames. When we re-write this, we should make
# sure to explicitly copy the data.
decoded_frame = previous_decoded_frame
else:
decoded_frame = decoder.get_frame_at(index=frame_index)
previous_decoded_frame = decoded_frame
all_decoded_frames[j] = decoded_frame

all_clips: list[list[Frame]] = _chunk_list(
all_decoded_frames, chunk_size=num_frames_per_clip
)

return [_to_framebatch(clip) for clip in all_clips]


def _generic_index_based_sampler(
kind: Literal["random", "regular"],
decoder: VideoDecoder,
@@ -174,7 +129,7 @@ def _generic_index_based_sampler(
# Important note: sampling_range_end defines the upper bound of where a clip
# can *start*, not where a clip can end.
policy: Literal["repeat_last", "wrap", "error"],
) -> List[FrameBatch]:
) -> FrameBatch:

_validate_common_params(
decoder=decoder,
@@ -221,9 +176,18 @@ def _generic_index_based_sampler(
num_frames_in_video=len(decoder),
policy_fun=_POLICY_FUNCTIONS[policy],
)
return _decode_all_clips_indices(
decoder,
all_clips_indices=all_clips_indices,

# TODO: Use public method of decoder, when it exists
frames, pts_seconds, duration_seconds = get_frames_at_indices(
decoder._decoder,
stream_index=decoder.stream_index,
frame_indices=all_clips_indices,
)
return _make_5d_framebatch(
data=frames,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
num_clips=num_clips,
num_frames_per_clip=num_frames_per_clip,
)

@@ -237,7 +201,7 @@ def clips_at_random_indices(
sampling_range_start: int = 0,
sampling_range_end: Optional[int] = None, # interval is [start, end).
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> List[FrameBatch]:
) -> FrameBatch:
return _generic_index_based_sampler(
kind="random",
decoder=decoder,
@@ -259,7 +223,7 @@ def clips_at_regular_indices(
sampling_range_start: int = 0,
sampling_range_end: Optional[int] = None, # interval is [start, end).
policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> List[FrameBatch]:
) -> FrameBatch:

return _generic_index_based_sampler(
kind="regular",
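
For context on where the index-based speed-up comes from: the removed _decode_all_clips_indices decoded frames one by one in a Python loop, while the new code hands the whole index list to the C++ get_frames_at_indices op in a single call. Per the commit messages ("Added sorting logic", "Add deduplication logic"), that op applies the same sort-and-dedup strategy the deleted comments describe. A rough Python sketch of that strategy, for illustration only, not the actual C++ implementation:

def decode_sorted_and_deduped(decode_one, frame_indices):
    # Visit indices in sorted order to avoid slow backward seeks, decode each
    # unique index only once, and scatter results back to the caller's order.
    order = sorted(range(len(frame_indices)), key=lambda i: frame_indices[i])
    out = [None] * len(frame_indices)
    cached_index, cached_frame = None, None
    for i in order:
        index = frame_indices[i]
        if index != cached_index:
            cached_frame = decode_one(index)  # the only place decoding happens
            cached_index = index
        out[i] = cached_frame  # duplicates reuse the cached frame
    return out

# e.g. decode_sorted_and_deduped(lambda i: decoder.get_frame_at(index=i), [7, 2, 2, 9, 7])
# decodes frames 2, 7 and 9 once each, in ascending order, and returns 5 frames.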
75 changes: 19 additions & 56 deletions src/torchcodec/samplers/_time_based.py
@@ -1,14 +1,13 @@
from typing import List, Literal, Optional
from typing import Literal, Optional

import torch

from torchcodec import Frame, FrameBatch
from torchcodec.decoders import VideoDecoder
from torchcodec import FrameBatch
from torchcodec.decoders._core import get_frames_by_pts
from torchcodec.samplers._common import (
_chunk_list,
_make_5d_framebatch,
_POLICY_FUNCTION_TYPE,
_POLICY_FUNCTIONS,
_to_framebatch,
_validate_common_params,
)

@@ -147,51 +146,6 @@ def _build_all_clips_timestamps(
return all_clips_timestamps


def _decode_all_clips_timestamps(
decoder: VideoDecoder, all_clips_timestamps: list[float], num_frames_per_clip: int
) -> list[FrameBatch]:
# This is 99% the same as _decode_all_clips_indices. The only change is the
# call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx)

all_clips_timestamps_sorted, argsort = zip(
*sorted(
(frame_index, i) for (i, frame_index) in enumerate(all_clips_timestamps)
)
)
previous_decoded_frame = None
all_decoded_frames = [None] * len(all_clips_timestamps)
for i, j in enumerate(argsort):
frame_pts_seconds = all_clips_timestamps_sorted[i]
if (
previous_decoded_frame is not None # then we know i > 0
and frame_pts_seconds == all_clips_timestamps_sorted[i - 1]
):
# Avoid decoding the same frame twice.
# Unfortunatly this is unlikely to lead to speed-up as-is: it's
# pretty unlikely that 2 pts will be the same since pts are float
# contiguous values. Theoretically the dedup can still happen, but
# it would be much more efficient to implement it at the frame index
# level. We should do that once we implement that in C++.
# See also https://github.com/pytorch/torchcodec/issues/256.
#
# IMPORTANT: this is only correct because a copy of the frame will
# happen within `_to_framebatch` when we call torch.stack.
# If a copy isn't made, the same underlying memory will be used for
# the 2 consecutive frames. When we re-write this, we should make
# sure to explicitly copy the data.
decoded_frame = previous_decoded_frame
else:
decoded_frame = decoder.get_frame_displayed_at(seconds=frame_pts_seconds)
previous_decoded_frame = decoded_frame
all_decoded_frames[j] = decoded_frame

all_clips: list[list[Frame]] = _chunk_list(
all_decoded_frames, chunk_size=num_frames_per_clip
)

return [_to_framebatch(clip) for clip in all_clips]


def _generic_time_based_sampler(
kind: Literal["random", "regular"],
decoder,
@@ -204,7 +158,7 @@ def _generic_time_based_sampler(
sampling_range_start: Optional[float],
sampling_range_end: Optional[float], # interval is [start, end).
policy: str = "repeat_last",
) -> List[FrameBatch]:
) -> FrameBatch:
# Note: *everywhere*, sampling_range_end denotes the upper bound of where a
# clip can start. This is an *open* upper bound, i.e. we will make sure no
# clip starts exactly at (or above) sampling_range_end.
@@ -246,6 +200,7 @@ def _generic_time_based_sampler(
sampling_range_end, # excluded
seconds_between_clip_starts,
)
num_clips = len(clip_start_seconds)

all_clips_timestamps = _build_all_clips_timestamps(
clip_start_seconds=clip_start_seconds,
@@ -255,9 +210,17 @@
policy_fun=_POLICY_FUNCTIONS[policy],
)

return _decode_all_clips_timestamps(
decoder,
all_clips_timestamps=all_clips_timestamps,
# TODO: Use public method of decoder, when it exists
frames, pts_seconds, duration_seconds = get_frames_by_pts(
decoder._decoder,
stream_index=decoder.stream_index,
timestamps=all_clips_timestamps,
)
return _make_5d_framebatch(
data=frames,
pts_seconds=pts_seconds,
duration_seconds=duration_seconds,
num_clips=num_clips,
num_frames_per_clip=num_frames_per_clip,
)

Expand All @@ -272,7 +235,7 @@ def clips_at_random_timestamps(
sampling_range_start: Optional[float] = None,
sampling_range_end: Optional[float] = None, # interval is [start, end).
policy: str = "repeat_last",
) -> List[FrameBatch]:
) -> FrameBatch:
return _generic_time_based_sampler(
kind="random",
decoder=decoder,
@@ -296,7 +259,7 @@ def clips_at_regular_timestamps(
sampling_range_start: Optional[float] = None,
sampling_range_end: Optional[float] = None, # interval is [start, end).
policy: str = "repeat_last",
) -> List[FrameBatch]:
) -> FrameBatch:
return _generic_time_based_sampler(
kind="regular",
decoder=decoder,
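
The user-facing consequence of this PR is that every sampler now returns a single 5-dimensional FrameBatch rather than a list of per-clip FrameBatches. A hedged usage sketch; the seconds_between_clip_starts and num_frames_per_clip keywords are assumed from the sampler internals above, not copied from the public signatures:

from torchcodec.decoders import VideoDecoder
from torchcodec.samplers import clips_at_regular_timestamps

decoder = VideoDecoder("video.mp4")  # any local video file
clips = clips_at_regular_timestamps(
    decoder,
    seconds_between_clip_starts=1.0,  # assumed keyword, see _generic_time_based_sampler
    num_frames_per_clip=4,            # assumed keyword, validated in _validate_common_params
)

# Before this PR: a Python list with one 4-frame FrameBatch per clip.
# After: one FrameBatch whose leading dimension indexes clips.
print(clips.data.shape)          # (num_clips, 4, C, H, W)
print(clips.pts_seconds.shape)   # (num_clips, 4)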