Skip to content

Commit

Permalink
feat: make ColumnDropper dataframe-agnostic
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed May 5, 2024
1 parent 8cabda3 commit 97225a1
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ maintainers = [
]

dependencies = [
"narwhals>=0.7.16",
"pandas>=1.1.5",
"scikit-learn>=1.0",
"importlib-metadata >= 1.0; python_version < '3.8'",
Expand Down Expand Up @@ -111,4 +112,3 @@ markers = [
"formulaic: tests that require formulaic (deselect with '-m \"not formulaic\"')",
"umap: tests that require umap (deselect with '-m \"not umap\"')"
]

17 changes: 6 additions & 11 deletions sklego/preprocessing/pandastransformers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import narwhals as nw
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
Expand Down Expand Up @@ -106,9 +107,9 @@ def fit(self, X, y=None):
If dropping the specified columns would result in an empty output DataFrame.
"""
self.columns_ = as_list(self.columns)
self._check_X_for_type(X)
X = nw.from_native(X)
self._check_column_names(X)
self.feature_names_ = X.columns.drop(self.columns_).tolist()
self.feature_names_ = [x for x in X.columns if x not in self.columns_]
self._check_column_length()
return self

Expand All @@ -131,10 +132,10 @@ def transform(self, X):
If `X` is not a `pd.DataFrame` object.
"""
check_is_fitted(self, ["feature_names_"])
self._check_X_for_type(X)
X = nw.from_native(X)
if self.columns_:
return X.drop(columns=self.columns_)
return X
return nw.to_native(X.drop(self.columns_))
return nw.to_native(X)

def get_feature_names(self):
"""Alias for `.feature_names_` attribute"""
Expand All @@ -151,12 +152,6 @@ def _check_column_names(self, X):
if len(non_existent_columns) > 0:
raise KeyError(f"{list(non_existent_columns)} column(s) not in DataFrame")

@staticmethod
def _check_X_for_type(X):
"""Checks if input of the Selector is of the required dtype"""
if not isinstance(X, pd.DataFrame):
raise TypeError("Provided variable X is not of type pandas.DataFrame")


class PandasTypeSelector(BaseEstimator, TransformerMixin):
"""The `PandasTypeSelector` transformer allows to select columns in a pandas DataFrame based on their type.
Expand Down
38 changes: 33 additions & 5 deletions tests/test_preprocessing/test_columndropper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pandas as pd
import polars as pl
import pytest
from pandas.testing import assert_frame_equal
from pandas.testing import assert_frame_equal as pandas_assert_frame_equal
from polars.testing import assert_frame_equal as polars_assert_frame_equal
from sklearn.pipeline import make_pipeline

from sklego.preprocessing import ColumnDropper
Expand All @@ -19,6 +21,19 @@ def df():
)


@pytest.fixture()
def df_polars():
    """Provide a six-row polars DataFrame with columns ``a`` through ``e``.

    Columns ``a``, ``b`` and ``e`` hold integers; ``c`` and ``d`` hold
    single-character strings.
    """
    columns = {
        "a": [1, 2, 3, 4, 5, 6],
        "b": [10, 9, 8, 7, 6, 5],
        "c": ["a", "b", "a", "b", "c", "c"],
        "d": ["b", "a", "a", "b", "a", "b"],
        "e": [0, 1, 0, 1, 0, 1],
    }
    return pl.DataFrame(columns)


def test_drop_two(df):
result_df = ColumnDropper(["a", "b"]).fit_transform(df)
expected_df = pd.DataFrame(
Expand All @@ -29,7 +44,7 @@ def test_drop_two(df):
}
)

assert_frame_equal(result_df, expected_df)
pandas_assert_frame_equal(result_df, expected_df)


def test_drop_one(df):
Expand All @@ -43,7 +58,7 @@ def test_drop_one(df):
}
)

assert_frame_equal(result_df, expected_df)
pandas_assert_frame_equal(result_df, expected_df)


def test_drop_all(df):
Expand All @@ -53,7 +68,7 @@ def test_drop_all(df):

def test_drop_none(df):
result_df = ColumnDropper([]).fit_transform(df)
assert_frame_equal(result_df, df)
pandas_assert_frame_equal(result_df, df)


def test_drop_not_in_frame(df):
Expand All @@ -73,10 +88,23 @@ def test_drop_one_in_pipeline(df):
}
)

assert_frame_equal(result_df, expected_df)
pandas_assert_frame_equal(result_df, expected_df)


def test_get_feature_names():
    """After fitting a dropper for column ``a``, only ``b`` is reported as a feature name."""
    frame = pd.DataFrame({"a": [4, 5, 6], "b": ["4", "5", "6"]})
    fitted = ColumnDropper("a").fit(frame)
    surviving = fitted.get_feature_names()
    assert surviving == ["b"]


def test_drop_two_polars(df_polars):
    """Dropping ``a`` and ``b`` from a polars frame leaves ``c``, ``d`` and ``e``.

    The comparison uses the polars frame-equality helper, so the result
    must itself be a polars DataFrame for the test to pass.
    """
    dropper = ColumnDropper(["a", "b"])
    actual = dropper.fit_transform(df_polars)

    remaining_columns = {
        "c": ["a", "b", "a", "b", "c", "c"],
        "d": ["b", "a", "a", "b", "a", "b"],
        "e": [0, 1, 0, 1, 0, 1],
    }
    expected = pl.DataFrame(remaining_columns)

    polars_assert_frame_equal(actual, expected)

0 comments on commit 97225a1

Please sign in to comment.