Skip to content

Commit

Permalink
Simplified and optimized logic for calculating per-metric subsampling rate for MapData
Browse files Browse the repository at this point in the history

Differential Revision: D66366076
  • Loading branch information
ltiao authored and facebook-github-bot committed Nov 22, 2024
1 parent 1713245 commit a5bb39a
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions ax/core/map_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

from bisect import bisect_right
from collections.abc import Iterable, Sequence
from copy import deepcopy
from logging import Logger
Expand Down Expand Up @@ -412,6 +413,10 @@ def subsample(
)


def _ceil_divide(a, b):
return -np.floor_divide(-a, b)


def _subsample_one_metric(
map_df: pd.DataFrame,
map_key: str | None = None,
Expand All @@ -421,30 +426,37 @@ def _subsample_one_metric(
include_first_last: bool = True,
) -> pd.DataFrame:
"""Helper function to subsample a dataframe that holds a single metric."""

grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
derived_keep_every = 1
if keep_every is not None:
derived_keep_every = keep_every
elif limit_rows_per_group is not None:
max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max()
derived_keep_every = np.ceil(max_rows / limit_rows_per_group)
elif limit_rows_per_metric is not None:
group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy()
# search for the `keep_every` such that when you apply it to each group,
# the total number of rows is smaller than `limit_rows_per_metric`.
for k in range(1, group_sizes.max() + 1):
if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric:
derived_keep_every = k
break
# if no such `k` is found, then `derived_keep_every` stays as 1.

else:
group_sizes = grouped_map_df.size()
max_rows = group_sizes.max()
if limit_rows_per_group is not None:
derived_keep_every = _ceil_divide(max_rows, limit_rows_per_group)
elif limit_rows_per_metric is not None:
# search for the `keep_every` such that when you apply it to each group,
# the total number of rows is smaller than `limit_rows_per_metric`.
ks = np.arange(max_rows, 0, -1)
# total sizes in ascending order
total_sizes = _ceil_divide(group_sizes.values, ks[..., np.newaxis]).sum(1)
# binary search
i = bisect_right(total_sizes, limit_rows_per_metric)
# if no such `k` is found, then `derived_keep_every` stays as 1.
if i > 0:
derived_keep_every = ks[i - 1]
elif total_sizes[0] == limit_rows_per_metric:
derived_keep_every = ks[0]
if derived_keep_every <= 1:
filtered_map_df = map_df
else:
filtered_dfs = []
for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS):
for _, df_g in grouped_map_df:
df_g = df_g.sort_values(map_key)
if include_first_last:
rows_per_group = int(np.ceil(len(df_g) / derived_keep_every))
rows_per_group = _ceil_divide(len(df_g), derived_keep_every)
linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group)
idcs = np.round(linspace_idcs).astype(int)
filtered_df = df_g.iloc[idcs]
Expand Down

0 comments on commit a5bb39a

Please sign in to comment.