Skip to content
60 changes: 31 additions & 29 deletions python/ray/data/preprocessors/scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pandas as pd

from ray.data.aggregate import AbsMax, Max, Mean, Min, Std
from ray.data.aggregate import AbsMax, ApproximateQuantile, Max, Mean, Min, Std
from ray.data.preprocessor import Preprocessor
from ray.util.annotations import PublicAPI

Expand Down Expand Up @@ -311,7 +311,7 @@ def __repr__(self):

@PublicAPI(stability="alpha")
class RobustScaler(Preprocessor):
r"""Scale and translate each column using quantiles.
r"""Scale and translate each column using approximate quantiles.

The general formula is given by

Expand All @@ -323,6 +323,9 @@ class RobustScaler(Preprocessor):
high and low quantiles, respectively. By default, :math:`\mu_{h}` is the third
quartile and :math:`\mu_{l}` is the first quartile.

Internally, the `ApproximateQuantile` aggregator is used to calculate the
approximate quantiles.

.. tip::
This scaler works well when your data contains many outliers.

Expand Down Expand Up @@ -377,53 +380,52 @@ class RobustScaler(Preprocessor):
columns will be the same as the input columns. If not None, the length of
``output_columns`` must match the length of ``columns``, otherwise an error
will be raised.
quantile_precision: Controls the accuracy and memory footprint of the sketch (K in KLL);
higher values yield lower error but use more memory. Defaults to 800. See
https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html
for details on accuracy and size.
"""

DEFAULT_QUANTILE_PRECISION = 800

def __init__(
    self,
    columns: List[str],
    quantile_range: Tuple[float, float] = (0.25, 0.75),
    output_columns: Optional[List[str]] = None,
    quantile_precision: int = DEFAULT_QUANTILE_PRECISION,
):
    """Store the scaler configuration and resolve the output column names.

    Args:
        columns: The columns to scale.
        quantile_range: The ``(low, high)`` quantiles used for scaling.
        output_columns: Names for the scaled columns. When ``None``, the
            input column names are reused; otherwise its length must match
            the length of ``columns``.
        quantile_precision: Precision setting forwarded to the approximate
            quantile sketch when the scaler is fitted.
    """
    super().__init__()

    # Plain configuration, kept verbatim for use at fit time.
    self.quantile_precision = quantile_precision
    self.quantile_range = quantile_range
    self.columns = columns

    # Resolve (and validate) the output names last, once the inputs are stored.
    resolved_output_columns = Preprocessor._derive_and_validate_output_columns(
        columns, output_columns
    )
    self.output_columns = resolved_output_columns

def _fit(self, dataset: "Dataset") -> Preprocessor:
    """Compute the low, median, and high quantile of every configured column.

    One ``ApproximateQuantile`` aggregator is built per column and all of
    them are evaluated with a single ``dataset.aggregate`` call. The results
    are stored in ``self.stats_`` under the keys ``low_quantile(<col>)``,
    ``median(<col>)`` and ``high_quantile(<col>)``.

    Args:
        dataset: The dataset whose columns are summarized.

    Returns:
        This preprocessor (``self``), with ``stats_`` populated.
    """
    # Order matters: the per-column result is unpacked as (low, median, high).
    quantiles = [
        self.quantile_range[0],
        0.50,
        self.quantile_range[1],
    ]
    aggregates = [
        ApproximateQuantile(
            on=col,
            quantiles=quantiles,
            quantile_precision=self.quantile_precision,
        )
        for col in self.columns
    ]
    aggregated = dataset.aggregate(*aggregates)

    self.stats_ = {}
    for col in self.columns:
        low_q, med_q, high_q = aggregated[f"approx_quantile({col})"]
        # NOTE(review): a reviewer suggested that these key names also show
        # the quantile value; the names are left unchanged here.
        self.stats_[f"low_quantile({col})"] = low_q
        self.stats_[f"median({col})"] = med_q
        self.stats_[f"high_quantile({col})"] = high_q

    return self

Expand Down