Fix int/float typing in video_utils.py (#8234)
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
Authored by bryant1410 and NicolasHug on Jan 31, 2024
1 parent 0be6c7e commit 806dba6
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions torchvision/datasets/video_utils.py
@@ -89,7 +89,7 @@ class VideoClips:
         video_paths (List[str]): paths to the video files
         clip_length_in_frames (int): size of a clip in number of frames
         frames_between_clips (int): step (in frames) between each clip
-        frame_rate (int, optional): if specified, it will resample the video
+        frame_rate (float, optional): if specified, it will resample the video
             so that it has `frame_rate`, and then the clips will be defined
             on the resampled video
         num_workers (int): how many subprocesses to use for data loading.
@@ -102,7 +102,7 @@ def __init__(
         video_paths: List[str],
         clip_length_in_frames: int = 16,
         frames_between_clips: int = 1,
-        frame_rate: Optional[int] = None,
+        frame_rate: Optional[float] = None,
         _precomputed_metadata: Optional[Dict[str, Any]] = None,
         num_workers: int = 0,
         _video_width: int = 0,
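
With the annotation widened to `Optional[float]`, fractional frame rates such as NTSC's 29.97 fps now type-check cleanly. A minimal usage sketch, not part of this commit — the file names are placeholder assumptions:

from torchvision.datasets.video_utils import VideoClips

# Placeholder paths; VideoClips reads and indexes these files on construction.
video_paths = ["clip_a.mp4", "clip_b.mp4"]
clips = VideoClips(
    video_paths=video_paths,
    clip_length_in_frames=16,
    frames_between_clips=1,
    frame_rate=29.97,  # flagged by type checkers when this was Optional[int]
)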
@@ -136,7 +136,7 @@ def __init__(
 
     def _compute_frame_pts(self) -> None:
         self.video_pts = []  # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
-        self.video_fps: List[int] = []  # len = num_videos
+        self.video_fps: List[float] = []  # len = num_videos
 
         # strategy: use a DataLoader to parallelize read_video_timestamps
         # so need to create a dummy dataset first
@@ -203,15 +203,15 @@ def subset(self, indices: List[int]) -> "VideoClips":
 
     @staticmethod
     def compute_clips_for_video(
-        video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
+        video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[float], frame_rate: Optional[float] = None
     ) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
         if fps is None:
             # if for some reason the video doesn't have fps (because doesn't have a video stream)
             # set the fps to 1. The value doesn't matter, because video_pts is empty anyway
             fps = 1
         if frame_rate is None:
             frame_rate = fps
-        total_frames = len(video_pts) * (float(frame_rate) / fps)
+        total_frames = len(video_pts) * frame_rate / fps
         _idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
         video_pts = video_pts[_idxs]
         clips = unfold(video_pts, num_frames, step)
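
Dropping the `float(...)` cast is safe: Python's true division always returns a float, and a fractional `frame_rate` now flows through unconverted. A worked sketch of the arithmetic above, with illustrative numbers:

import math

num_pts = 300        # frames present in the video stream
fps = 29.97          # native rate; fractional rates are why float matters here
frame_rate = 15.0    # requested resampling rate

total_frames = num_pts * frame_rate / fps  # true division: already a float
print(total_frames)                        # ~150.15
print(int(math.floor(total_frames)))       # 150 frames survive resampling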
@@ -227,7 +227,7 @@ def compute_clips_for_video(
             idxs = unfold(_idxs, num_frames, step)
         return clips, idxs
 
-    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
+    def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[float] = None) -> None:
         """
         Compute all consecutive sequences of clips from video_pts.
         Always returns clips of size `num_frames`, meaning that the
@@ -275,8 +275,8 @@ def get_clip_location(self, idx: int) -> Tuple[int, int]:
         return video_idx, clip_idx
 
     @staticmethod
-    def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
-        step = float(original_fps) / new_fps
+    def _resample_video_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
+        step = original_fps / new_fps
         if step.is_integer():
             # optimization: if step is integer, don't need to perform
             # advanced indexing
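
Because `original_fps / new_fps` is true division, `step` is a float even when both arguments are whole numbers, so `step.is_integer()` stays valid without the old cast. A self-contained sketch of this resampling logic; the non-integer branch is an assumption modeled on the surrounding code, since the diff cuts off here:

from typing import Union

import torch

def resample_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
    step = original_fps / new_fps  # true division: always a float
    if step.is_integer():
        # Integer stride: plain slicing avoids advanced indexing.
        return slice(None, None, int(step))
    # Assumed non-integer branch: map each output frame to a source frame.
    idxs = torch.arange(num_frames, dtype=torch.float32) * step
    return idxs.floor().to(torch.int64)

print(resample_idx(5, 30.0, 15.0))   # slice(None, None, 2)
print(resample_idx(5, 29.97, 15.0))  # tensor([0, 1, 3, 5, 7])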
