Skip to content

Commit

Permalink
Merge pull request #41 from Grutschus/38-further-sampling-strategies
Browse files Browse the repository at this point in the history
38 further sampling strategies
  • Loading branch information
Grutschus authored Dec 4, 2023
2 parents 8d28ecd + e6c54c1 commit fd62a4e
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 33 deletions.
20 changes: 19 additions & 1 deletion datasets/high_quality_fall_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Callable, List, Optional, Union

import numpy as np
import pandas as pd
from mmaction.datasets import BaseActionDataset
from mmaction.registry import DATASETS
Expand Down Expand Up @@ -60,7 +61,10 @@ class HighQualityFallDataset(BaseActionDataset):
modality (str): Modality of data. Support ``'RGB'``, ``'Flow'``.
Defaults to ``'RGB'``.
test_mode (bool): Store True when building test or validation dataset.
Defaults to False."""
Defaults to False.
drop_ratios (List[float], optional): List of drop ratios for each class.
If None, no samples are dropped. Ignored for multi_class.
Defaults to None."""

def __init__(
self,
Expand All @@ -74,6 +78,7 @@ def __init__(
start_index: int = 0,
modality: str = "RGB",
test_mode: bool = False,
drop_ratios: List[float] | None = None,
**kwargs,
) -> None:
# Bug in MMENGINE: kwarg `custom_imports` is not removed from kwargs
Expand All @@ -90,6 +95,7 @@ def __init__(
self.label_strategy = LABEL_STRATEGIES.build(label_strategy) # type: LabelStrategy
else:
self.label_strategy = label_strategy
self.drop_ratios = drop_ratios
super().__init__(
ann_file,
pipeline=pipeline,
Expand All @@ -113,6 +119,8 @@ def load_data_list(self) -> List[dict]:
]

for clip, label in zip(sampled_clips, labels):
if self._should_drop(label, self.drop_ratios):
continue
data_list.append(
{
"filename": annotation["video_path"],
Expand All @@ -121,3 +129,13 @@ def load_data_list(self) -> List[dict]:
}
)
return data_list

def _should_drop(
self, label: int | list[int], drop_ratios: List[float] | None
) -> bool:
if self.test_mode or drop_ratios is None or not isinstance(label, int):
return False
for index, drop_ratio in enumerate(drop_ratios):
if np.random.rand() <= drop_ratio and index == label:
return True
return False
25 changes: 25 additions & 0 deletions datasets/transforms/sampling_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,28 @@ def sample(self, annotation: pd.Series) -> List[IntervalInSeconds]:
sample_list.append((start, end))

return sample_list


@SAMPLING_STRATEGIES.register_module()
class FilterSampling(SamplingStrategy):
    """Meta-sampling strategy that performs video-level filtering.

    It drops complete videos from the dataset by returning an empty sample
    list for them, and delegates to the wrapped ``sampler`` for every other
    video.

    Args:
        sampler: The wrapped sampling strategy, or a config dict from which
            it is built via ``SAMPLING_STRATEGIES``.
        filter_column_name: Name of the annotation field the filter
            inspects. Defaults to ``"category"``.
        values: Value or list of values the filter matches against.
            ``None`` is treated as "filtering disabled".
            NOTE(review): the meaning of ``None`` is an assumption; confirm
            against the intended configuration.
        blacklist: If True, videos whose field value is in ``values`` are
            dropped; if False, ``values`` acts as a whitelist and every
            other video is dropped. Defaults to True.
    """

    def __init__(
        self,
        sampler: SamplingStrategy | dict,
        filter_column_name: str = "category",
        values: str | list[str] | None = "ADL",
        blacklist: bool = True,
    ) -> None:
        if isinstance(sampler, dict):
            self.sampler = SAMPLING_STRATEGIES.build(sampler)
        else:
            self.sampler = sampler
        self.filter_column_name = filter_column_name
        # Normalise to a list (or None) so `sample` only needs a membership
        # test; a bare string would otherwise be matched by substring.
        if values is None:
            self.values = None
        elif isinstance(values, str):
            self.values = [values]
        else:
            self.values = list(values)
        self.blacklist = blacklist

    def sample(self, annotation: pd.Series) -> List[IntervalInSeconds]:
        """Return the wrapped sampler's samples, or ``[]`` if filtered out.

        Args:
            annotation: Annotation of a single video; must contain the
                field named by ``filter_column_name``.

        Returns:
            An empty list when the video is discarded by the filter,
            otherwise whatever ``self.sampler.sample(annotation)`` returns.
        """
        if self.values is not None:
            matches = annotation[self.filter_column_name] in self.values
            # Blacklist drops matching videos, whitelist drops the rest.
            if matches == self.blacklist:
                return []
        return self.sampler.sample(annotation)
Loading

0 comments on commit fd62a4e

Please sign in to comment.