Skip to content

Commit

Permalink
Make category expansion work
Browse files Browse the repository at this point in the history
  • Loading branch information
stanmart committed Dec 19, 2024
1 parent 5995d78 commit 9c0a221
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
6 changes: 5 additions & 1 deletion src/glum/_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2484,7 +2484,11 @@ def _set_up_and_check_fit_args(
for col in self.categorical_levels_
}

if any(X.dtypes == "category"):
if any(
dtype
for _, dtype in X.schema
if isinstance(dtype, (nw.Categorical, nw.Enum))
): # do we want to expand penalties for strings that we treat as categoricals?
P1 = _expand_categorical_penalties(
self.P1, X, drop_first, self.has_missing_category_
)
Expand Down
12 changes: 8 additions & 4 deletions src/glum/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _add_missing_categories(


def _expand_categorical_penalties(
penalty, X, drop_first, has_missing_category
penalty, X: nw.DataFrame, drop_first: bool, has_missing_category: dict[str, bool]
) -> Union[np.ndarray, str]:
"""Determine penalty matrices ``P1`` or ``P2`` after expanding categorical columns.
Expand All @@ -135,9 +135,13 @@ def _expand_categorical_penalties(

expanded_penalty = [] # type: ignore

for element, (column, dt) in zip(penalty, X.dtypes.items()):
if isinstance(dt, pd.CategoricalDtype):
length = len(dt.categories) + has_missing_category[column] - drop_first
for element, (column, dt) in zip(penalty, X.schema.items()):
if isinstance(dt, (nw.Enum, nw.Categorical)):
length = (
len(X[column].cat.get_categories())
+ has_missing_category[column]
- drop_first
)
expanded_penalty.extend(element for _ in range(length))
else:
expanded_penalty.append(element)
Expand Down

0 comments on commit 9c0a221

Please sign in to comment.