From 9c0a22185f4ad7583302e107a6a4867ac6410334 Mon Sep 17 00:00:00 2001 From: Martin Stancsics Date: Thu, 19 Dec 2024 12:01:43 +0100 Subject: [PATCH] Make category expansion work --- src/glum/_glm.py | 6 +++++- src/glum/_util.py | 12 ++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/glum/_glm.py b/src/glum/_glm.py index 35e69b7d..456f4054 100644 --- a/src/glum/_glm.py +++ b/src/glum/_glm.py @@ -2484,7 +2484,11 @@ def _set_up_and_check_fit_args( for col in self.categorical_levels_ } - if any(X.dtypes == "category"): + if any( + dtype + for _, dtype in X.schema + if isinstance(dtype, (nw.Categorical, nw.Enum)) + ): # do we want to expand penalties for strings that we treat as categoricals? P1 = _expand_categorical_penalties( self.P1, X, drop_first, self.has_missing_category_ ) diff --git a/src/glum/_util.py b/src/glum/_util.py index 61216d0b..7904718c 100644 --- a/src/glum/_util.py +++ b/src/glum/_util.py @@ -112,7 +112,7 @@ def _add_missing_categories( def _expand_categorical_penalties( - penalty, X, drop_first, has_missing_category + penalty, X: nw.DataFrame, drop_first: bool, has_missing_category: dict[str, bool] ) -> Union[np.ndarray, str]: """Determine penalty matrices ``P1`` or ``P2`` after expanding categorical columns. @@ -135,9 +135,13 @@ def _expand_categorical_penalties( expanded_penalty = [] # type: ignore - for element, (column, dt) in zip(penalty, X.dtypes.items()): - if isinstance(dt, pd.CategoricalDtype): - length = len(dt.categories) + has_missing_category[column] - drop_first + for element, (column, dt) in zip(penalty, X.schema.items()): + if isinstance(dt, (nw.Enum, nw.Categorical)): + length = ( + len(X[column].cat.get_categories()) + + has_missing_category[column] + - drop_first + ) expanded_penalty.extend(element for _ in range(length)) else: expanded_penalty.append(element)