Skip to content

Commit

Permalink
Merge pull request #37 from Grutschus/20-label-distribution
Browse files Browse the repository at this point in the history
20 label distribution
  • Loading branch information
Grutschus authored Dec 1, 2023
2 parents 14b70f2 + 6c90ce9 commit 1ac4f08
Show file tree
Hide file tree
Showing 2 changed files with 443 additions and 0 deletions.
91 changes: 91 additions & 0 deletions datasets/transforms/sampling_strategy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
from typing import List, Tuple

import numpy as np
import pandas as pd

from registry import SAMPLING_STRATEGIES
Expand Down Expand Up @@ -64,3 +65,93 @@ def sample(self, annotation: pd.Series) -> List[IntervalInSeconds]:
clip_start += self.stride if self.overlap else self.clip_len + self.stride

return sample_list


@SAMPLING_STRATEGIES.register_module()
class GaussianSampling(SamplingStrategy):
"""Samples the center of each interval from a gaussian distribution that
is centered around the center of a given class interval.
Example: If the priority class is `fall` and the fall interval is (2, 4),
the center of clips are sampled from a Gaussian distribution centered around
3. The standard deviation can be freely chosen.
Args:
clip_len (float): Length of the clips to sample in seconds.
focus_interval_start_name (str): Name of the column in the annotation
that contains the start timestamp of the priority class interval.
Defaults to "fall_start".
focus_interval_end_name (str): Name of the column in the annotation
that contains the end timestamp of the priority class interval.
Defaults to "fall_end".
n_samples_per_sec (float | None): Number of samples per second. If None `1/clip_len`
is used to sample approximately the same number of samples as a
UniformSampling strategy. Defaults to None.
fallback_sampler (SamplingStrategy | dict | None): Sampler to use if the timestamps
of the focus interval are not present in the annotation. If None,
`UniformSampling` with equal `clip_len` is used. Defaults to None.
std (None | float): Standard deviation of the gaussian distribution. If None,
`min(focus_interval_center, total_length - focus_interval_center) / 3` is used.
Defaults to None.
"""

def __init__(
self,
clip_len: float,
focus_interval_start_name: str = "fall_start",
focus_interval_end_name: str = "fall_end",
n_samples_per_sec: float | None = None,
fallback_sampler: SamplingStrategy | dict | None = None,
std: None | float = None,
) -> None:
self.clip_len = clip_len
self.focus_interval_start_name = focus_interval_start_name
self.focus_interval_end_name = focus_interval_end_name
self.std = std
if fallback_sampler is None:
self.fallback_sampler: SamplingStrategy = UniformSampling(clip_len)
elif isinstance(fallback_sampler, dict):
self.fallback_sampler = SAMPLING_STRATEGIES.build(fallback_sampler)
else:
self.fallback_sampler = fallback_sampler
self.n_samples_per_sec = n_samples_per_sec

def sample(self, annotation: pd.Series) -> List[IntervalInSeconds]:
if (
self.focus_interval_start_name not in annotation.keys()
or self.focus_interval_end_name not in annotation.keys()
):
raise ValueError(
"Given focus interval names "
f"{self.focus_interval_start_name} and {self.focus_interval_end_name} "
"are not in the annotation."
)

focus_interval = (
annotation[self.focus_interval_start_name],
annotation[self.focus_interval_end_name],
)

if any(np.isnan(focus_interval)):
return self.fallback_sampler.sample(annotation)

mean = sum(focus_interval) / 2
std = self.std
if std is None:
std = min(mean, annotation["length"] - mean) / 3

n_samples_per_sec = self.n_samples_per_sec
if n_samples_per_sec is None:
n_samples_per_sec = 1.0 / self.clip_len

samples = np.random.normal(
mean, std, int(n_samples_per_sec * annotation["length"])
).round(decimals=2)

sample_list = []
for sample in samples:
start = max(0, sample - self.clip_len / 2)
end = min(annotation["length"], sample + self.clip_len / 2)
sample_list.append((start, end))

return sample_list
Loading

0 comments on commit 1ac4f08

Please sign in to comment.