diff --git a/python/ray/data/preprocessors/scaler.py b/python/ray/data/preprocessors/scaler.py index 82ef81e390fc..f8da3acaa5fb 100644 --- a/python/ray/data/preprocessors/scaler.py +++ b/python/ray/data/preprocessors/scaler.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd -from ray.data.aggregate import AbsMax, Max, Mean, Min, Std +from ray.data.aggregate import AbsMax, ApproximateQuantile, Max, Mean, Min, Std from ray.data.preprocessor import Preprocessor from ray.util.annotations import PublicAPI @@ -311,7 +311,7 @@ def __repr__(self): @PublicAPI(stability="alpha") class RobustScaler(Preprocessor): - r"""Scale and translate each column using quantiles. + r"""Scale and translate each column using approximate quantiles. The general formula is given by @@ -323,6 +323,9 @@ class RobustScaler(Preprocessor): high and low quantiles, respectively. By default, :math:`\mu_{h}` is the third quartile and :math:`\mu_{l}` is the first quartile. + Internally, the `ApproximateQuantile` aggregator is used to calculate the + approximate quantiles. + .. tip:: This scaler works well when your data contains many outliers. @@ -377,53 +380,52 @@ class RobustScaler(Preprocessor): columns will be the same as the input columns. If not None, the length of ``output_columns`` must match the length of ``columns``, othwerwise an error will be raised. + quantile_precision: Controls the accuracy and memory footprint of the sketch (K in KLL); + higher values yield lower error but use more memory. Defaults to 800. See + https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html + for details on accuracy and size. """ + DEFAULT_QUANTILE_PRECISION = 800 + def __init__( self, columns: List[str], quantile_range: Tuple[float, float] = (0.25, 0.75), output_columns: Optional[List[str]] = None, + quantile_precision: int = DEFAULT_QUANTILE_PRECISION, ): super().__init__() self.columns = columns self.quantile_range = quantile_range + self.quantile_precision = quantile_precision self.output_columns = Preprocessor._derive_and_validate_output_columns( columns, output_columns ) def _fit(self, dataset: "Dataset") -> Preprocessor: - low = self.quantile_range[0] - med = 0.50 - high = self.quantile_range[1] - - num_records = dataset.count() - max_index = num_records - 1 - split_indices = [int(percentile * max_index) for percentile in (low, med, high)] + quantiles = [ + self.quantile_range[0], + 0.50, + self.quantile_range[1], + ] + aggregates = [ + ApproximateQuantile( + on=col, + quantiles=quantiles, + quantile_precision=self.quantile_precision, + ) + for col in self.columns + ] + aggregated = dataset.aggregate(*aggregates) self.stats_ = {} - - # TODO(matt): Handle case where quantile lands between 2 numbers. - # The current implementation will simply choose the closest index. - # This will affect the results of small datasets more than large datasets. for col in self.columns: - filtered_dataset = dataset.map_batches( - lambda df: df[[col]], batch_format="pandas" - ) - sorted_dataset = filtered_dataset.sort(col) - _, low, med, high = sorted_dataset.split_at_indices(split_indices) - - def _get_first_value(ds: "Dataset", c: str): - return ds.take(1)[0][c] - - low_val = _get_first_value(low, col) - med_val = _get_first_value(med, col) - high_val = _get_first_value(high, col) - - self.stats_[f"low_quantile({col})"] = low_val - self.stats_[f"median({col})"] = med_val - self.stats_[f"high_quantile({col})"] = high_val + low_q, med_q, high_q = aggregated[f"approx_quantile({col})"] + self.stats_[f"low_quantile({col})"] = low_q + self.stats_[f"median({col})"] = med_q + self.stats_[f"high_quantile({col})"] = high_q return self