[Ready] motion_score_raft #478

Merged 4 commits on Nov 12, 2024. Changes shown from 3 commits.
10 changes: 10 additions & 0 deletions configs/config_all.yaml
@@ -465,6 +465,16 @@ process:
sampling_fps: 2 # the sampling rate, in frames per second, for computing optical flow
size: null # resize frames along the smaller edge before computing optical flow, or a sequence like (h, w)
max_size: null # maximum allowed for the longer edge of resized frames
divisible: 1 # the number that the frame dimensions must be divisible by
relative: false # whether to normalize the optical flow magnitude to [0, 1], relative to the frame's diagonal length
any_or_all: any # keep this sample when any/all videos meet the filter condition
- video_motion_score_raft_filter: # Keep samples with video motion scores (based on RAFT model) within a specific range.
min_score: 1.0 # the minimum motion score to keep samples
max_score: 10000.0 # the maximum motion score to keep samples
sampling_fps: 2 # the sampling rate, in frames per second, for computing optical flow
size: null # resize frames along the smaller edge before computing optical flow, or a sequence like (h, w)
max_size: null # maximum allowed for the longer edge of resized frames
divisible: 8 # the number that the frame dimensions must be divisible by
relative: false # whether to normalize the optical flow magnitude to [0, 1], relative to the frame's diagonal length
any_or_all: any # keep this sample when any/all videos meet the filter condition
- video_nsfw_filter: # filter samples according to the nsfw scores of videos in them
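For reference, the new entry maps one-to-one onto the operator's constructor. A minimal sketch of the programmatic equivalent (the class and its signature appear later in this diff; values are copied from the YAML above):

```python
# Sketch: programmatic equivalent of the video_motion_score_raft_filter
# entry above. The class is added later in this PR.
from data_juicer.ops.filter.video_motion_score_raft_filter import \
    VideoMotionScoreRaftFilter

op = VideoMotionScoreRaftFilter(
    min_score=1.0,      # minimum motion score to keep a sample
    max_score=10000.0,  # maximum motion score to keep a sample
    sampling_fps=2,     # sampling rate (frames per second) for optical flow
    size=None,          # no resizing along the smaller edge
    max_size=None,      # no cap on the longer edge
    divisible=8,        # RAFT expects dimensions divisible by 8
    relative=False,     # raw, unnormalized flow magnitude
    any_or_all='any')   # keep the sample if any of its videos passes
```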
8 changes: 5 additions & 3 deletions data_juicer/ops/filter/__init__.py
@@ -35,6 +35,7 @@
from .video_frames_text_similarity_filter import \
VideoFramesTextSimilarityFilter
from .video_motion_score_filter import VideoMotionScoreFilter
from .video_motion_score_raft_filter import VideoMotionScoreRaftFilter
from .video_nsfw_filter import VideoNSFWFilter
from .video_ocr_area_ratio_filter import VideoOcrAreaRatioFilter
from .video_resolution_filter import VideoResolutionFilter
@@ -57,7 +58,8 @@
'TextActionFilter', 'TextEntityDependencyFilter', 'TextLengthFilter',
'TokenNumFilter', 'VideoAestheticsFilter', 'VideoAspectRatioFilter',
'VideoDurationFilter', 'VideoFramesTextSimilarityFilter',
'VideoMotionScoreFilter', 'VideoNSFWFilter', 'VideoOcrAreaRatioFilter',
'VideoResolutionFilter', 'VideoTaggingFromFramesFilter',
'VideoWatermarkFilter', 'WordRepetitionFilter', 'WordsNumFilter'
'VideoMotionScoreFilter', 'VideoMotionScoreRaftFilter', 'VideoNSFWFilter',
'VideoOcrAreaRatioFilter', 'VideoResolutionFilter',
'VideoTaggingFromFramesFilter', 'VideoWatermarkFilter',
'WordRepetitionFilter', 'WordsNumFilter'
]
67 changes: 31 additions & 36 deletions data_juicer/ops/filter/video_motion_score_filter.py
@@ -7,6 +7,7 @@

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.mm_utils import calculate_resized_dimensions

from ..base_op import OPERATORS, UNFORKABLE, Filter

@@ -48,6 +49,7 @@ def __init__(self,
size: Union[PositiveInt, Tuple[PositiveInt],
Tuple[PositiveInt, PositiveInt], None] = None,
max_size: Optional[PositiveInt] = None,
divisible: PositiveInt = 1,
relative: bool = False,
any_or_all: str = 'any',
*args,
@@ -69,6 +71,7 @@
being resized according to size, size will be overruled so that the
longer edge is equal to max_size. As a result, the smaller edge may
be shorter than size. This is only supported if size is an int.
:param divisible: The number that the dimensions must be divisible by.
:param relative: If `True`, the optical flow magnitude is normalized to
a [0, 1] range, relative to the frame's diagonal length.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
@@ -92,6 +95,7 @@ def __init__(self,
size = (size, )
self.size = size
self.max_size = max_size
self.divisible = divisible
self.relative = relative

self.extra_kwargs = self._default_kwargs
@@ -104,7 +108,21 @@ def __init__(self,
f'Can only be one of ["any", "all"].')
self.any = (any_or_all == 'any')

def compute_stats_single(self, sample, context=False):
def setup_model(self, rank=None):
self.model = cv2.calcOpticalFlowFarneback

def compute_flow(self, prev_frame, curr_frame):
curr_frame = cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)
if prev_frame is None:
flow = None
else:
flow = self.model(prev_frame, curr_frame, None,
**self.extra_kwargs)
return flow, curr_frame

def compute_stats_single(self, sample, rank=None, context=False):
self.rank = rank

# check if it's computed already
if StatsKeys.video_motion_score in sample[Fields.stats]:
return sample
@@ -115,6 +133,8 @@ def compute_stats_single(self, sample, context=False):
[], dtype=np.float64)
return sample

self.setup_model(rank)

# load videos
loaded_video_keys = sample[self.video_key]
unique_motion_scores = {}
@@ -133,6 +153,11 @@
# at least two frames for computing optical flow
sampling_step = max(min(sampling_step, total_frames - 1),
1)
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
new_size = calculate_resized_dimensions(
(height, width), self.size, self.max_size,
self.divisible)

prev_frame = None
frame_count = 0
@@ -143,27 +168,21 @@
# a corrupt frame or reaching the end of the video.
break

height, width, _ = frame.shape
new_size = _compute_resized_output_size(
(height, width), self.size, self.max_size)
if new_size != (height, width):
frame = cv2.resize(frame,
new_size,
interpolation=cv2.INTER_AREA)

gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
if prev_frame is None:
prev_frame = gray_frame
# compute_flow returns the flow of shape (H, W, 2) and the
# transformed current frame, reused as prev_frame next iteration
flow, prev_frame = self.compute_flow(prev_frame, frame)
if flow is None:
continue

flow = cv2.calcOpticalFlowFarneback(
prev_frame, gray_frame, None, **self.extra_kwargs)
mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
frame_motion_score = np.mean(mag)
if self.relative:
frame_motion_score /= np.hypot(*flow.shape[:2])
frame_motion_score /= np.hypot(*frame.shape[:2])
video_motion_scores.append(frame_motion_score)
prev_frame = gray_frame

# quickly skip frames
frame_count += sampling_step
@@ -197,27 +216,3 @@ def process_single(self, sample):
return keep_bools.any()
else:
return keep_bools.all()


def _compute_resized_output_size(
frame_size: Tuple[int, int],
size: Union[Tuple[PositiveInt], Tuple[PositiveInt, PositiveInt]],
max_size: Optional[int] = None,
) -> Tuple[int, int]:
h, w = frame_size
short, long = (w, h) if w <= h else (h, w)

if size is None: # no change
new_short, new_long = short, long
elif len(size) == 1: # specified size only for the smallest edge
new_short = size[0]
new_long = int(new_short * long / short)
else: # specified both h and w
new_short, new_long = min(size), max(size)

if max_size is not None and new_long > max_size:
new_short = int(max_size * new_short / new_long)
new_long = max_size

new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
return new_h, new_w
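The refactor above changes only the structure, not the scoring: the base class still computes Farneback optical flow between consecutive grayscale frames and averages the magnitude. A self-contained sketch of that computation (the Farneback parameters below are OpenCV's commonly used values, assumed here; the op's actual defaults live in its `_default_kwargs`):

```python
import cv2
import numpy as np

def farneback_motion_score(prev_bgr, curr_bgr, relative=False):
    """Mean optical-flow magnitude between two same-sized BGR frames,
    mirroring the base filter's Farneback path (sketch)."""
    prev_gray = cv2.cvtColor(prev_bgr, cv2.COLOR_BGR2GRAY)
    curr_gray = cv2.cvtColor(curr_bgr, cv2.COLOR_BGR2GRAY)
    flow = cv2.calcOpticalFlowFarneback(
        prev_gray, curr_gray, None, pyr_scale=0.5, levels=3, winsize=15,
        iterations=3, poly_n=5, poly_sigma=1.2, flags=0)
    mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])  # per-pixel magnitude
    score = float(np.mean(mag))
    if relative:
        # normalize by the frame diagonal, as relative=True does in the op
        score /= np.hypot(*curr_bgr.shape[:2])
    return score
```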
80 changes: 80 additions & 0 deletions data_juicer/ops/filter/video_motion_score_raft_filter.py
@@ -0,0 +1,80 @@
import sys
from typing import Optional, Tuple, Union

from pydantic import PositiveFloat, PositiveInt

from data_juicer import cuda_device_count
from data_juicer.ops.filter.video_motion_score_filter import \
VideoMotionScoreFilter
from data_juicer.utils.lazy_loader import LazyLoader

from ..base_op import OPERATORS, UNFORKABLE

torch = LazyLoader('torch', 'torch')
tvm = LazyLoader('tvm', 'torchvision.models')
tvt = LazyLoader('tvt', 'torchvision.transforms')

OP_NAME = 'video_motion_score_raft_filter'


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class VideoMotionScoreRaftFilter(VideoMotionScoreFilter):
"""Filter to keep samples with video motion scores within a specified range.
This operator utilizes the RAFT (Recurrent All-Pairs Field Transforms)
model from torchvision to predict optical flow between video frames.

For further details, refer to the official torchvision documentation:
https://pytorch.org/vision/main/models/raft.html

The original paper on RAFT is available here:
https://arxiv.org/abs/2003.12039
"""

_accelerator = 'cuda'
_default_kwargs = {}

def __init__(self,
min_score: float = 1.0,
max_score: float = sys.float_info.max,
sampling_fps: PositiveFloat = 2,
size: Union[PositiveInt, Tuple[PositiveInt],
Tuple[PositiveInt, PositiveInt], None] = None,
max_size: Optional[PositiveInt] = None,
divisible: PositiveInt = 8,
relative: bool = False,
any_or_all: str = 'any',
*args,
**kwargs):
super().__init__(min_score, max_score, sampling_fps, size, max_size,
divisible, relative, any_or_all, *args, **kwargs)

def setup_model(self, rank=None):
self.model = tvm.optical_flow.raft_large(
weights=tvm.optical_flow.Raft_Large_Weights.DEFAULT,
progress=False)
if self.use_cuda():
rank = rank if rank is not None else 0
rank = rank % cuda_device_count()
self.device = f'cuda:{rank}'
else:
self.device = 'cpu'
self.model.to(self.device)
self.model.eval()

self.transforms = tvt.Compose([
tvt.ToTensor(),
tvt.Normalize(mean=0.5, std=0.5), # map [0, 1] into [-1, 1]
tvt.Lambda(lambda img: img.flip(-3).unsqueeze(0)), # BGR to RGB, add batch dim
])

def compute_flow(self, prev_frame, curr_frame):
curr_frame = self.transforms(curr_frame).to(self.device)
if prev_frame is None:
flow = None
else:
with torch.inference_mode():
flows = self.model(prev_frame, curr_frame)
flow = flows[-1][0].cpu().numpy().transpose(
(1, 2, 0)) # 2, H, W -> H, W, 2
return flow, curr_frame
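For a quick sanity check outside the op, the same RAFT scoring can be reproduced directly with torchvision. A sketch mirroring `setup_model` and `compute_flow` above (assumes torch and torchvision are installed, runs on CPU, and expects frame dimensions already divisible by 8):

```python
import numpy as np
import torch
from torchvision import transforms as tvt
from torchvision.models import optical_flow

model = optical_flow.raft_large(
    weights=optical_flow.Raft_Large_Weights.DEFAULT, progress=False).eval()

preprocess = tvt.Compose([
    tvt.ToTensor(),                    # HWC uint8 -> CHW float in [0, 1]
    tvt.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
    tvt.Lambda(lambda img: img.flip(-3).unsqueeze(0)),  # BGR->RGB, add batch
])

def raft_motion_score(prev_bgr: np.ndarray, curr_bgr: np.ndarray) -> float:
    prev_t, curr_t = preprocess(prev_bgr), preprocess(curr_bgr)
    with torch.inference_mode():
        flow = model(prev_t, curr_t)[-1][0]  # last refinement: (2, H, W)
    return float(flow.norm(dim=0).mean())    # mean flow magnitude
```

`model(prev, curr)` returns a list of iteratively refined flow fields, so indexing the last element matches `flows[-1]` in `compute_flow` above.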
54 changes: 53 additions & 1 deletion data_juicer/utils/mm_utils.py
@@ -3,7 +3,7 @@
import os
import re
import shutil
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import av
import numpy as np
@@ -164,6 +164,58 @@ def iou(box1, box2):
return 1.0 * intersection / union


def calculate_resized_dimensions(
original_size: Tuple[PositiveInt, PositiveInt],
target_size: Union[PositiveInt, Tuple[PositiveInt, PositiveInt]],
max_length: Optional[int] = None,
divisible: PositiveInt = 1) -> Tuple[int, int]:
"""
Resize dimensions based on specified constraints.

:param original_size: The original dimensions as (height, width).
:param target_size: Desired target size; can be a single integer
(short edge) or a tuple (height, width).
:param max_length: Maximum allowed length for the longer edge.
:param divisible: The number that the dimensions must be divisible by.
:return: Resized dimensions as (width, height), matching the dsize
order that cv2.resize expects at the call site.
"""

height, width = original_size
short_edge, long_edge = sorted((width, height))

# Normalize target_size to a tuple
if isinstance(target_size, int):
target_size = (target_size, )

# Initialize new dimensions
if target_size:
if len(target_size) == 1: # Only the smaller edge is specified
new_short_edge = target_size[0]
new_long_edge = int(new_short_edge * long_edge / short_edge)
else: # Both dimensions are specified
new_short_edge = min(target_size)
new_long_edge = max(target_size)
else: # No change
new_short_edge, new_long_edge = short_edge, long_edge

# Enforce maximum length constraint
if max_length is not None and new_long_edge > max_length:
scaling_factor = max_length / new_long_edge
new_short_edge = int(new_short_edge * scaling_factor)
new_long_edge = max_length

# Determine final dimensions based on original orientation
resized_dimensions = ((new_short_edge,
new_long_edge) if width <= height else
(new_long_edge, new_short_edge))

# Ensure final dimensions are divisible by the specified value
resized_dimensions = tuple(
int(dim / divisible) * divisible for dim in resized_dimensions)

return resized_dimensions


# Audios
def load_audios(paths):
return [load_audio(path) for path in paths]
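To make the helper concrete, a hand-traced sketch (numbers follow the code above; note the returned tuple is in the (width, height) order that `cv2.resize` consumes at its call site in the filter):

```python
from data_juicer.utils.mm_utils import calculate_resized_dimensions

# 1080x720 portrait frame, short edge resized to 256, divisible by 8:
# short 720 -> 256, long 1080 -> int(256 * 1080 / 720) = 384.
assert calculate_resized_dimensions((1080, 720), 256, None, 8) == (256, 384)

# Capping the longer edge at 256 rescales both edges:
# long 384 -> 256, short 256 -> int(256 * 256 / 384) = 170 -> floored to 168.
assert calculate_resized_dimensions((1080, 720), 256, 256, 8) == (168, 256)
```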
3 changes: 2 additions & 1 deletion docs/Operators.md
@@ -12,7 +12,7 @@ The operators in Data-Juicer are categorized into 5 types.
|-----------------------------------|:------:|-------------------------------------------------|
| [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data |
| [ Mapper ]( #mapper ) | 52 | Edits and transforms samples |
| [ Filter ]( #filter ) | 43 | Filters out low-quality samples |
| [ Filter ]( #filter ) | 44 | Filters out low-quality samples |
| [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples |
| [ Selector ]( #selector ) | 4 | Selects top samples based on ranking |

@@ -149,6 +149,7 @@ All the specific operators are listed below, each featured with several capabilities.
| video_duration_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep data samples whose videos' durations are within a specified range | [code](../data_juicer/ops/filter/video_duration_filter.py) | [tests](../tests/ops/filter/test_video_duration_filter.py) |
| video_frames_text_similarity_filter | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keep data samples whose similarities between sampled video frame images and text are within a specific range | [code](../data_juicer/ops/filter/video_frames_text_similarity_filter.py) | [tests](../tests/ops/filter/test_video_frames_text_similarity_filter.py) |
| video_motion_score_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep samples with video motion scores within a specific range | [code](../data_juicer/ops/filter/video_motion_score_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_filter.py) |
| video_motion_score_raft_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keep samples with video motion scores (based on RAFT model) within a specific range | [code](../data_juicer/ops/filter/video_motion_score_raft_filter.py) | [tests](../tests/ops/filter/test_video_motion_score_raft_filter.py) |
| video_nsfw_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keeps samples containing videos with NSFW scores below the threshold | [code](../data_juicer/ops/filter/video_nsfw_filter.py) | [tests](../tests/ops/filter/test_video_nsfw_filter.py) |
| video_ocr_area_ratio_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Keep data samples whose detected text area ratios for specified frames in the video are within a specified range | [code](../data_juicer/ops/filter/video_ocr_area_ratio_filter.py) | [tests](../tests/ops/filter/test_video_ocr_area_ratio_filter.py) |
| video_resolution_filter | ![Video](https://img.shields.io/badge/Video-F2B138?style=plastic) | Keeps samples containing videos with horizontal and vertical resolutions within the specified range | [code](../data_juicer/ops/filter/video_resolution_filter.py) | [tests](../tests/ops/filter/test_video_resolution_filter.py) |
Expand Down