|
| 1 | +"""Scikit-learn-like pipeline. |
| 2 | +
|
| 3 | +This is an example of how a (possibly) lazy data frame may be used |
| 4 | +in a sklearn-like pipeline. |
| 5 | +
|
| 6 | +The example is motivated by the prospect of a fully lazy, ONNX-based |
| 7 | +data frame implementation. The concept is that calls to `fit` are |
| 8 | +eager. They compute some state that is later meant to be transferred |
| 9 | +into the ONNX graph/model. That transfer happens when calling |
| 10 | +`transform`. The logic within the `transform` methods is "traced" |
| 11 | +lazily and the resulting lazy object is then exported to ONNX. |
| 12 | +""" |
| 13 | +from __future__ import annotations |
| 14 | + |
| 15 | +from typing import TYPE_CHECKING, Any |
| 16 | + |
| 17 | +if TYPE_CHECKING: |
| 18 | + from typing_extensions import Self |
| 19 | + |
| 20 | + from dataframe_api.typing import Column, DataFrame, Scalar |
| 21 | + |
| 22 | + |
class Scaler:
    """Apply a standardization scaling factor to `column_names`.

    Scaling factors are computed eagerly in `fit` and applied lazily
    in `transform`, mirroring the eager-fit / lazy-transform split
    described in the module docstring.
    """

    # Maps column name -> standard deviation computed during `fit`.
    scalings_: dict[str, Scalar]

    def __init__(self, column_names: list[str]) -> None:
        self.column_names = column_names

    def fit(self, df: DataFrame) -> Self:
        """Compute scaling factors from given data frame.

        A typical data science workflow is to fit on one dataset,
        and then transform on multiple datasets. Therefore, we
        make sure to `persist` within the `fit` method.
        """
        scalings = df.select(*self.column_names).std().persist()

        self.scalings_ = {
            column_name: scalings.col(column_name).get_value(0)
            for column_name in self.column_names
        }

        return self

    def transform(self, df: DataFrame) -> DataFrame:
        """Apply the "trained" scaling values.

        This function is guaranteed to not collect values.
        """
        # Iterate `df.column_names` (not `self.column_names`) so the
        # scaled columns keep the data frame's own column ordering.
        columns: list[Column] = [
            df.col(column_name) / self.scalings_[column_name]
            for column_name in df.column_names
            if column_name in self.column_names
        ]

        # Note: `assign` is not in-place
        return df.assign(*columns)
| 61 | + |
| 62 | + |
class FeatureSelector:
    """Limit columns to those seen in training including their order."""

    def fit(self, df: DataFrame) -> Self:
        """Memorize the training frame's column names, in order.

        This function is guaranteed to not collect values.
        """
        observed = df.column_names
        self.columns_ = observed
        return self

    def transform(self, df: DataFrame) -> DataFrame:
        """Restrict `df` to the columns recorded during `fit`.

        This function is guaranteed to not collect values.
        """
        # `select` is relied upon to return the columns in the
        # requested (training-time) order.
        training_columns = self.columns_
        return df.select(*training_columns)
| 81 | + |
| 82 | + |
class Pipeline:
    """Linear pipeline of transformers."""

    def __init__(self, steps: list[Any]) -> None:
        # Each step must expose `fit(df)` and `transform(df)`.
        self.steps = steps

    def fit(self, df: DataFrame) -> Self:
        """Call fit on the steps of the pipeline subsequently.

        Each step is fit on the output of the previous steps'
        transforms (as in sklearn's `Pipeline`), so that every step
        sees the same data at fit time that it will receive at
        transform time.

        Calling this function may trigger a collection.
        """
        for step in self.steps:
            step.fit(df)
            # Feed the transformed frame to the next step's `fit`.
            df = step.transform(df)

        self.steps_ = self.steps
        return self

    def transform(self, df: DataFrame) -> DataFrame:
        """Call transform on all steps of this pipeline subsequently.

        This function is guaranteed to not trigger a collection.
        """
        for step in self.steps_:
            df = step.transform(df)

        return df
0 commit comments