Skip to content

Commit df65225

Browse files
kyuds and owenowenisme authored
[Data] Use Approximate Quantile for RobustScaler Preprocessor (#58371)
## Description Currently Ray Data has a preprocessor called `RobustScaler`. This scales the data based on given quantiles. Calculating the quantiles involves sorting the entire dataset by column for each column (C sorts for C number of columns), which, for a large dataset, will require a lot of calculations. ** MAJOR EDIT **: had to replace the original `tdigest` with `ddsketch` as I couldn't actually find well-maintained tdigest libraries for python. ddsketch is better maintained. ** MAJOR EDIT 2 **: discussed offline to use `ApproximateQuantile` aggregator ## Related issues N/A ## Additional information N/A --------- Signed-off-by: kyuds <kyuseung1016@gmail.com> Signed-off-by: Daniel Shin <kyuseung1016@gmail.com> Co-authored-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>
1 parent 5e71d58 commit df65225

File tree

1 file changed

+31
-29
lines changed

1 file changed

+31
-29
lines changed

python/ray/data/preprocessors/scaler.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pandas as pd
55

6-
from ray.data.aggregate import AbsMax, Max, Mean, Min, Std
6+
from ray.data.aggregate import AbsMax, ApproximateQuantile, Max, Mean, Min, Std
77
from ray.data.preprocessor import Preprocessor
88
from ray.util.annotations import PublicAPI
99

@@ -311,7 +311,7 @@ def __repr__(self):
311311

312312
@PublicAPI(stability="alpha")
313313
class RobustScaler(Preprocessor):
314-
r"""Scale and translate each column using quantiles.
314+
r"""Scale and translate each column using approximate quantiles.
315315
316316
The general formula is given by
317317
@@ -323,6 +323,9 @@ class RobustScaler(Preprocessor):
323323
high and low quantiles, respectively. By default, :math:`\mu_{h}` is the third
324324
quartile and :math:`\mu_{l}` is the first quartile.
325325
326+
Internally, the `ApproximateQuantile` aggregator is used to calculate the
327+
approximate quantiles.
328+
326329
.. tip::
327330
This scaler works well when your data contains many outliers.
328331
@@ -377,53 +380,52 @@ class RobustScaler(Preprocessor):
377380
columns will be the same as the input columns. If not None, the length of
378381
``output_columns`` must match the length of ``columns``, othwerwise an error
379382
will be raised.
383+
quantile_precision: Controls the accuracy and memory footprint of the sketch (K in KLL);
384+
higher values yield lower error but use more memory. Defaults to 800. See
385+
https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html
386+
for details on accuracy and size.
380387
"""
381388

389+
DEFAULT_QUANTILE_PRECISION = 800
390+
382391
def __init__(
383392
self,
384393
columns: List[str],
385394
quantile_range: Tuple[float, float] = (0.25, 0.75),
386395
output_columns: Optional[List[str]] = None,
396+
quantile_precision: int = DEFAULT_QUANTILE_PRECISION,
387397
):
388398
super().__init__()
389399
self.columns = columns
390400
self.quantile_range = quantile_range
401+
self.quantile_precision = quantile_precision
391402

392403
self.output_columns = Preprocessor._derive_and_validate_output_columns(
393404
columns, output_columns
394405
)
395406

396407
def _fit(self, dataset: "Dataset") -> Preprocessor:
397-
low = self.quantile_range[0]
398-
med = 0.50
399-
high = self.quantile_range[1]
400-
401-
num_records = dataset.count()
402-
max_index = num_records - 1
403-
split_indices = [int(percentile * max_index) for percentile in (low, med, high)]
408+
quantiles = [
409+
self.quantile_range[0],
410+
0.50,
411+
self.quantile_range[1],
412+
]
413+
aggregates = [
414+
ApproximateQuantile(
415+
on=col,
416+
quantiles=quantiles,
417+
quantile_precision=self.quantile_precision,
418+
)
419+
for col in self.columns
420+
]
421+
aggregated = dataset.aggregate(*aggregates)
404422

405423
self.stats_ = {}
406-
407-
# TODO(matt): Handle case where quantile lands between 2 numbers.
408-
# The current implementation will simply choose the closest index.
409-
# This will affect the results of small datasets more than large datasets.
410424
for col in self.columns:
411-
filtered_dataset = dataset.map_batches(
412-
lambda df: df[[col]], batch_format="pandas"
413-
)
414-
sorted_dataset = filtered_dataset.sort(col)
415-
_, low, med, high = sorted_dataset.split_at_indices(split_indices)
416-
417-
def _get_first_value(ds: "Dataset", c: str):
418-
return ds.take(1)[0][c]
419-
420-
low_val = _get_first_value(low, col)
421-
med_val = _get_first_value(med, col)
422-
high_val = _get_first_value(high, col)
423-
424-
self.stats_[f"low_quantile({col})"] = low_val
425-
self.stats_[f"median({col})"] = med_val
426-
self.stats_[f"high_quantile({col})"] = high_val
425+
low_q, med_q, high_q = aggregated[f"approx_quantile({col})"]
426+
self.stats_[f"low_quantile({col})"] = low_q
427+
self.stats_[f"median({col})"] = med_q
428+
self.stats_[f"high_quantile({col})"] = high_q
427429

428430
return self
429431

0 commit comments

Comments (0)