From 06cbfed88f8c07cbe4c6dce2327ab706f8ed0a22 Mon Sep 17 00:00:00 2001 From: polinaeterna Date: Wed, 31 Jul 2024 14:37:42 +0200 Subject: [PATCH] compute stats for datetimes --- .../worker/src/worker/statistics_utils.py | 120 +++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/services/worker/src/worker/statistics_utils.py b/services/worker/src/worker/statistics_utils.py index f2651bb091..4048a9f8be 100644 --- a/services/worker/src/worker/statistics_utils.py +++ b/services/worker/src/worker/statistics_utils.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Copyright 2024 The HuggingFace Authors. +import datetime import enum import io import logging @@ -50,11 +51,12 @@ class ColumnType(str, enum.Enum): STRING_TEXT = "string_text" AUDIO = "audio" IMAGE = "image" + DATETIME = "datetime" class Histogram(TypedDict): hist: list[int] - bin_edges: list[Union[int, float]] + bin_edges: list[Union[int, float, str]] class NumericalStatisticsItem(TypedDict): @@ -68,6 +70,17 @@ class NumericalStatisticsItem(TypedDict): histogram: Optional[Histogram] +class DatetimeStatisticsItem(TypedDict): + nan_count: int + nan_proportion: float + min: Optional[str] # might be None in very rare cases when the whole column is only None values + max: Optional[str] + mean: Optional[str] + median: Optional[str] + std: Optional[str] # string representation of timedelta + histogram: Optional[Histogram] + + class CategoricalStatisticsItem(TypedDict): nan_count: int nan_proportion: float @@ -83,7 +96,9 @@ class BoolStatisticsItem(TypedDict): frequencies: dict[str, int] -SupportedStatistics = Union[NumericalStatisticsItem, CategoricalStatisticsItem, BoolStatisticsItem] +SupportedStatistics = Union[ + NumericalStatisticsItem, CategoricalStatisticsItem, BoolStatisticsItem, DatetimeStatisticsItem +] class StatisticsPerColumnItem(TypedDict): @@ -699,3 +714,104 @@ def get_shape(example: Optional[Union[bytes, dict[str, Any]]]) -> Union[tuple[No @classmethod def transform(cls, example: Optional[Union[bytes, dict[str, Any]]]) -> Optional[int]: return cls.get_width(example) + + +class DatetimeColumn(Column): + transform_column = IntColumn + + @classmethod + def compute_transformed_data( + cls, + data: pl.DataFrame, + column_name: str, + transformed_column_name: str, + min_date: datetime.datetime, + ) -> pl.DataFrame: + return data.select((pl.col(column_name) - min_date).dt.total_seconds().alias(transformed_column_name)) + + @staticmethod + def shift_and_convert_to_string(min_date, seconds) -> str: + return datetime_to_string(min_date + datetime.timedelta(seconds=seconds)) + + @classmethod + def _compute_statistics( + cls, + data: pl.DataFrame, + column_name: str, + n_samples: int, + ) -> DatetimeStatisticsItem: + nan_count, nan_proportion = nan_count_proportion(data, column_name, n_samples) + if nan_count == n_samples: # all values are None + return DatetimeStatisticsItem( + nan_count=n_samples, + nan_proportion=1.0, + min=None, + max=None, + mean=None, + median=None, + std=None, + histogram=None, + ) + + min_date = data[column_name].min() + timedelta_column_name = f"{column_name}_timedelta" + # compute distribution of time passed from min date in **seconds** + timedelta_df = cls.compute_transformed_data(data, column_name, timedelta_column_name, min_date) + timedelta_stats: NumericalStatisticsItem = cls.transform_column.compute_statistics( + timedelta_df, + column_name=timedelta_column_name, + n_samples=n_samples, + ) + for stat in ("max", "mean", "median"): + timedelta_stats[stat] = cls.shift_and_convert_to_string(min_date, timedelta_stats[stat]) + + bin_edges = [ + cls.shift_and_convert_to_string(min_date, seconds) for seconds in timedelta_stats["histogram"]["bin_edges"] + ] + + return DatetimeStatisticsItem( + nan_count=nan_count, + nan_proportion=nan_proportion, + min=datetime_to_string(min_date), + max=timedelta_stats["max"], + mean=timedelta_stats["mean"], + median=timedelta_stats["median"], + std=str(timedelta_stats["std"]), + histogram=Histogram( + hist=timedelta_stats["histogram"]["hist"], + bin_edges=bin_edges, + ), + ) + + def compute_and_prepare_response(self, data: pl.DataFrame) -> StatisticsPerColumnItem: + stats = self.compute_statistics(data, column_name=self.name, n_samples=self.n_samples) + return StatisticsPerColumnItem( + column_name=self.name, + column_type=ColumnType.DATETIME, + column_statistics=stats, + ) + + +def datetime_to_string(dt: datetime.datetime, format: str = "%Y-%m-%d %H:%M:%S") -> str: + """ + Convert a datetime.datetime object to a string. + + Args: + dt (datetime): The datetime object to convert. + format (str, optional): The format of the output string. Defaults to "%Y-%m-%d %H:%M:%S". + + Returns: + str: The datetime object as a string. + """ + return dt.strftime(format) + + +if __name__ == "__main__": + path = "/home/polina/workspace/notebooks/stats/data/fineweb/000_00000.parquet" + column_name = "date" + data = pl.read_parquet(path, columns=[column_name]) + data = data.select(pl.col(column_name).cast(pl.Datetime)) + n_samples = data.shape[0] + # column = DatetimeColumn(feature_name=column_name, n_samples=data.shape[0]) + stats = DatetimeColumn.compute_statistics(data, column_name=column_name, n_samples=n_samples) + print(stats)