Skip to content

Commit 21271f5

Browse files
Add example of a sklearn like pipeline (#294)
* Add example of a sklearn like pipeline * Review comments * Update comment after discussion in #279 * update --------- Co-authored-by: MarcoGorelli <33491632+MarcoGorelli@users.noreply.github.com>
1 parent 44331de commit 21271f5

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Scikit-learn-like pipeline.
2+
3+
This is an example of how a (possibly) lazy data frame may be used
4+
in a sklearn-like pipeline.
5+
6+
The example is motivated by the prospect of a fully lazy, ONNX-based
7+
data frame implementation. The concept is that calls to `fit` are
8+
eager. They compute some state that is later meant to be transferred
9+
into the ONNX graph/model. That transfer happens when calling
10+
`transform`. The logic within the `transform` methods is "traced"
11+
lazily and the resulting lazy object is then exported to ONNX.
12+
"""
13+
from __future__ import annotations
14+
15+
from typing import TYPE_CHECKING, Any
16+
17+
if TYPE_CHECKING:
18+
from typing_extensions import Self
19+
20+
from dataframe_api.typing import Column, DataFrame, Scalar
21+
22+
23+
class Scaler:
    """Apply a standardization scaling factor to `column_names`."""

    # Mapping from column name to its scaling factor (the column's
    # standard deviation), computed eagerly by `fit`.
    scalings_: dict[str, Scalar]

    def __init__(self, column_names: list[str]) -> None:
        self.column_names = column_names

    def fit(self, df: DataFrame) -> Self:
        """Compute scaling factors from the given data frame.

        A typical data science workflow is to fit on one dataset,
        and then transform multiple datasets. Therefore, we
        make sure to `persist` within the `fit` method.
        """
        scalings = df.select(*self.column_names).std().persist()

        self.scalings_ = {
            column_name: scalings.col(column_name).get_value(0)
            for column_name in self.column_names
        }

        return self

    def transform(self, df: DataFrame) -> DataFrame:
        """Apply the "trained" scaling values.

        This function is guaranteed to not collect values.
        """
        # Scale only the columns seen during `fit`; others pass through
        # untouched via `assign` below.
        columns: list[Column] = [
            df.col(column_name) / self.scalings_[column_name]
            for column_name in df.column_names
            if column_name in self.column_names
        ]

        # Note: `assign` is not in-place
        return df.assign(*columns)
61+
62+
63+
class FeatureSelector:
    """Limit columns to those seen in training including their order."""

    def fit(self, df: DataFrame) -> Self:
        """Remember which columns were present, and in what order.

        This function is guaranteed to not collect values.
        """
        # Eagerly captured metadata only — no values are read here.
        self.columns_ = df.column_names
        return self

    def transform(self, df: DataFrame) -> DataFrame:
        """Restrict `df` to the training-time columns, in training order.

        This function is guaranteed to not collect values.
        """
        # Note: This assumes that select ensures the column order.
        return df.select(*self.columns_)
81+
82+
83+
class Pipeline:
    """Linear pipeline of transformers."""

    def __init__(self, steps: list[Any]) -> None:
        self.steps = steps

    def fit(self, df: DataFrame) -> Self:
        """Fit each step of the pipeline in turn on `df`.

        Calling this function may trigger a collection.
        """
        for transformer in self.steps:
            transformer.fit(df)

        # Mark the pipeline as fitted (sklearn-style trailing underscore).
        self.steps_ = self.steps
        return self

    def transform(self, df: DataFrame) -> DataFrame:
        """Feed `df` through the fitted steps, in order.

        This function is guaranteed to not trigger a collection.
        """
        result = df
        for transformer in self.steps_:
            result = transformer.transform(result)

        return result

0 commit comments

Comments
 (0)