Skip to content

Commit c355d28

Browse files
authoredJan 21, 2025
Merge pull request #296 from flatironinstitute/hotfix_repr
Hotfix repr
2 parents ec5b381 + d329376 commit c355d28

File tree

6 files changed

+25
-8
lines changed

6 files changed

+25
-8
lines changed
 

‎pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "nemos"
7-
version = "0.2.0"
7+
version = "0.2.1"
88
authors = [{name = "nemos authors"}]
99
description = "NEural MOdelS, a statistical modeling framework for neuroscience."
1010
readme = "README.md"

‎src/nemos/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python3
2-
__version__ = "0.2.0"
2+
__version__ = "0.2.1"
33

44
from . import (
55
basis,

‎src/nemos/glm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,12 @@ def _get_optimal_solver_params_config(self):
11261126
def __repr__(self):
11271127
return format_repr(self, multiline=True)
11281128

1129+
def __sklearn_clone__(self) -> GLM:
1130+
"""Clone the PopulationGLM, dropping feature_mask"""
1131+
params = self.get_params(deep=False)
1132+
klass = self.__class__(**params)
1133+
return klass
1134+
11291135

11301136
class PopulationGLM(GLM):
11311137
"""
@@ -1633,7 +1639,7 @@ def _predict(
16331639
+ bs
16341640
)
16351641

1636-
def __sklearn_clone__(self) -> GLM:
1642+
def __sklearn_clone__(self) -> PopulationGLM:
16371643
"""Clone the PopulationGLM, dropping feature_mask"""
16381644
params = self.get_params(deep=False)
16391645
params.pop("feature_mask")

‎src/nemos/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ def format_repr(
534534
)
535535
if repr_param:
536536
if k in use_name_keys:
537-
v = v.__name__
537+
v = getattr(v, "__name__", repr(v))
538538
elif isinstance(v, str):
539539
v = repr(v)
540540
disp_params.append(f"{k}={v}")

‎tests/test_glm.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,12 +1478,16 @@ def test_deviance_against_statsmodels(self, poissonGLM_model_instantiation):
14781478
def test_compatibility_with_sklearn_cv(self, poissonGLM_model_instantiation):
14791479
X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
14801480
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
1481-
GridSearchCV(model, param_grid).fit(X, y)
1481+
cls = GridSearchCV(model, param_grid).fit(X, y)
1482+
# check that the repr works after cloning
1483+
repr(cls)
14821484

14831485
def test_compatibility_with_sklearn_cv_gamma(self, gammaGLM_model_instantiation):
14841486
X, y, model, true_params, firing_rate = gammaGLM_model_instantiation
14851487
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
1486-
GridSearchCV(model, param_grid).fit(X, y)
1488+
cls = GridSearchCV(model, param_grid).fit(X, y)
1489+
# check that the repr works after cloning
1490+
repr(cls)
14871491

14881492
@pytest.mark.parametrize(
14891493
"regr_setup, glm_class",
@@ -3572,12 +3576,16 @@ def test_deviance_against_statsmodels(self, poisson_population_GLM_model):
35723576
def test_compatibility_with_sklearn_cv(self, poisson_population_GLM_model):
35733577
X, y, model, true_params, firing_rate = poisson_population_GLM_model
35743578
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
3575-
GridSearchCV(model, param_grid).fit(X, y)
3579+
cls = GridSearchCV(model, param_grid).fit(X, y)
3580+
# check that the repr works after cloning
3581+
repr(cls)
35763582

35773583
def test_compatibility_with_sklearn_cv_gamma(self, gamma_population_GLM_model):
35783584
X, y, model, true_params, firing_rate = gamma_population_GLM_model
35793585
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
3580-
GridSearchCV(model, param_grid).fit(X, y)
3586+
cls = GridSearchCV(model, param_grid).fit(X, y)
3587+
# check that the repr works after cloning
3588+
repr(cls)
35813589

35823590
def test_sklearn_clone(self, poisson_population_GLM_model):
35833591
X, y, model, true_params, firing_rate = poisson_population_GLM_model

‎tests/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22
from contextlib import nullcontext as does_not_raise
3+
from copy import deepcopy
34

45
import jax
56
import jax.numpy as jnp
@@ -595,6 +596,8 @@ def __repr__(self):
595596
(Example(a=0, b=False, c=None), None, [], "Example(a=0, b=False, d=1)"),
596597
# Falsey values excluded2
597598
(Example(a=0, b=[], c={}), None, [], "Example(a=0, d=1)"),
599+
# function without the __name__
600+
(nmo.observation_models.PoissonObservations(deepcopy(jax.numpy.exp)),None, [], "PoissonObservations(inverse_link_function=<PjitFunction>)")
598601
],
599602
)
600603
def test_format_repr(obj, exclude_keys, use_name_keys, expected):

0 commit comments

Comments
 (0)