Simplified and optimized logic for calculating per-metric subsampling rate for MapData (#3106)

Summary:

This refines the logic for calculating per-metric subsampling rates in `MapData.subsample` and adds a (probably premature) performance optimization: a binary search over a sorted list of candidate rates replaces the previous linear search.
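The optimization relies on the fact that the total number of rows kept is non-increasing in the subsampling rate `keep_every`, so evaluating candidates in descending order yields an ascending sequence of totals that `bisect_right` can search. A minimal standalone sketch of the old linear scan versus the new vectorized search, using made-up group sizes and a made-up row limit (not values from the commit):

```python
from bisect import bisect_right

import numpy as np

group_sizes = np.array([100, 80, 55])  # hypothetical rows per progression group
limit_rows_per_metric = 60  # hypothetical cap on total rows for one metric

# Old approach: linear scan over candidate `keep_every` values, smallest first.
keep_every_linear = 1
for k in range(1, group_sizes.max() + 1):
    if np.ceil(group_sizes / k).sum() <= limit_rows_per_metric:
        keep_every_linear = k
        break

# New approach: with candidates in descending order, the total kept rows
# form an ascending sequence, so `bisect_right` finds the cutoff directly.
ks = np.arange(group_sizes.max(), 0, -1)
total_sizes = np.sum(-(-group_sizes // ks[:, np.newaxis]), axis=1)  # ceil division
i = bisect_right(total_sizes, limit_rows_per_metric)
keep_every_binary = ks[i - 1] if i > 0 else 1

assert keep_every_linear == keep_every_binary == 4
```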

Differential Revision: D66366076
ltiao authored and facebook-github-bot committed Nov 23, 2024
1 parent f3ed01f commit b7b8491
Showing 1 changed file with 28 additions and 15 deletions.
ax/core/map_data.py (43 changes: 28 additions & 15 deletions)
```diff
@@ -7,12 +7,14 @@
 
 from __future__ import annotations
 
+from bisect import bisect_right
 from collections.abc import Iterable, Sequence
 from copy import deepcopy
 from logging import Logger
 from typing import Any, Generic, TypeVar
 
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 from ax.core.data import Data
 from ax.core.types import TMapTrialEvaluation
```
```diff
@@ -412,6 +414,10 @@ def subsample(
         )
 
 
+def _ceil_divide(a: int | npt.NDArray, b: int | npt.NDArray) -> int | npt.NDArray:
+    return -np.floor_divide(-a, b)
+
+
 def _subsample_one_metric(
     map_df: pd.DataFrame,
     map_key: str | None = None,
```
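The new `_ceil_divide` helper uses the identity ceil(a / b) == -floor(-a / b) to get ceiling division in pure integer arithmetic, with no float round-trip through `np.ceil`, and it broadcasts over arrays. A quick sanity check, assuming positive operands as used in this file:

```python
import numpy as np

def _ceil_divide(a, b):
    # Integer ceiling division: ceil(a / b) == -floor(-a / b) for b > 0.
    return -np.floor_divide(-a, b)

assert _ceil_divide(7, 2) == 4 and _ceil_divide(8, 2) == 4
# Elementwise over arrays, agreeing with the float-based formulation.
sizes = np.array([10, 11, 1])
assert np.array_equal(_ceil_divide(sizes, 3), np.ceil(sizes / 3).astype(int))
```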
```diff
@@ -421,30 +427,37 @@ def _subsample_one_metric(
     include_first_last: bool = True,
 ) -> pd.DataFrame:
     """Helper function to subsample a dataframe that holds a single metric."""
+
+    grouped_map_df = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS)
     derived_keep_every = 1
     if keep_every is not None:
         derived_keep_every = keep_every
-    elif limit_rows_per_group is not None:
-        max_rows = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().max()
-        derived_keep_every = np.ceil(max_rows / limit_rows_per_group)
-    elif limit_rows_per_metric is not None:
-        group_sizes = map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS).size().to_numpy()
-        # search for the `keep_every` such that when you apply it to each group,
-        # the total number of rows is smaller than `limit_rows_per_metric`.
-        for k in range(1, group_sizes.max() + 1):
-            if (np.ceil(group_sizes / k)).sum() <= limit_rows_per_metric:
-                derived_keep_every = k
-                break
-        # if no such `k` is found, then `derived_keep_every` stays as 1.
-
+    else:
+        group_sizes = grouped_map_df.size()
+        max_rows = group_sizes.max()
+        if limit_rows_per_group is not None:
+            derived_keep_every = _ceil_divide(max_rows, limit_rows_per_group)
+        elif limit_rows_per_metric is not None:
+            # Search for the smallest `keep_every` such that applying it to
+            # each group brings the total number of rows to at most
+            # `limit_rows_per_metric`.
+            ks = np.arange(max_rows, 0, -1)
+            # Total row counts, ascending because `ks` is descending.
+            total_sizes = np.sum(
+                _ceil_divide(group_sizes.values, ks[..., np.newaxis]), axis=1
+            )
+            # Binary search for the rightmost total within the limit.
+            i = bisect_right(total_sizes, limit_rows_per_metric)
+            # If no such `keep_every` exists, `derived_keep_every` stays at 1.
+            if i > 0:
+                derived_keep_every = ks[i - 1]
 
     if derived_keep_every <= 1:
         filtered_map_df = map_df
     else:
         filtered_dfs = []
-        for _, df_g in map_df.groupby(MapData.DEDUPLICATE_BY_COLUMNS):
+        for _, df_g in grouped_map_df:
             df_g = df_g.sort_values(map_key)
             if include_first_last:
-                rows_per_group = int(np.ceil(len(df_g) / derived_keep_every))
+                rows_per_group = _ceil_divide(len(df_g), derived_keep_every)
                 linspace_idcs = np.linspace(0, len(df_g) - 1, rows_per_group)
                 idcs = np.round(linspace_idcs).astype(int)
                 filtered_df = df_g.iloc[idcs]
```
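In the `include_first_last` branch, rows are kept at indices from a rounded `np.linspace` grid over each sorted group, so the first and last observations always survive subsampling. A toy trace with a hypothetical 10-row group and `keep_every=3`:

```python
import numpy as np

n = 10          # rows in one sorted group (hypothetical)
keep_every = 3
rows_per_group = -(-n // keep_every)  # _ceil_divide(n, keep_every) -> 4
linspace_idcs = np.linspace(0, n - 1, rows_per_group)
idcs = np.round(linspace_idcs).astype(int)
print(idcs)  # [0 3 6 9]: evenly spaced, endpoints always retained
```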
