Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hotfix repr #296

Merged
merged 3 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "nemos"
version = "0.2.0"
version = "0.2.1"
authors = [{name = "nemos authors"}]
description = "NEural MOdelS, a statistical modeling framework for neuroscience."
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
__version__ = "0.2.0"
__version__ = "0.2.1"

from . import (
basis,
Expand Down
8 changes: 7 additions & 1 deletion src/nemos/glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,12 @@ def _get_optimal_solver_params_config(self):
def __repr__(self):
return format_repr(self, multiline=True)

def __sklearn_clone__(self) -> GLM:
"""Clone the PopulationGLM, dropping feature_mask"""
params = self.get_params(deep=False)
klass = self.__class__(**params)
return klass


class PopulationGLM(GLM):
"""
Expand Down Expand Up @@ -1633,7 +1639,7 @@ def _predict(
+ bs
)

def __sklearn_clone__(self) -> GLM:
def __sklearn_clone__(self) -> PopulationGLM:
"""Clone the PopulationGLM, dropping feature_mask"""
params = self.get_params(deep=False)
params.pop("feature_mask")
Expand Down
2 changes: 1 addition & 1 deletion src/nemos/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def format_repr(
)
if repr_param:
if k in use_name_keys:
v = v.__name__
v = getattr(v, "__name__", repr(v))
elif isinstance(v, str):
v = repr(v)
disp_params.append(f"{k}={v}")
Expand Down
16 changes: 12 additions & 4 deletions tests/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,12 +1478,16 @@ def test_deviance_against_statsmodels(self, poissonGLM_model_instantiation):
def test_compatibility_with_sklearn_cv(self, poissonGLM_model_instantiation):
X, y, model, true_params, firing_rate = poissonGLM_model_instantiation
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

def test_compatibility_with_sklearn_cv_gamma(self, gammaGLM_model_instantiation):
X, y, model, true_params, firing_rate = gammaGLM_model_instantiation
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

@pytest.mark.parametrize(
"regr_setup, glm_class",
Expand Down Expand Up @@ -3572,12 +3576,16 @@ def test_deviance_against_statsmodels(self, poisson_population_GLM_model):
def test_compatibility_with_sklearn_cv(self, poisson_population_GLM_model):
X, y, model, true_params, firing_rate = poisson_population_GLM_model
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

def test_compatibility_with_sklearn_cv_gamma(self, gamma_population_GLM_model):
X, y, model, true_params, firing_rate = gamma_population_GLM_model
param_grid = {"solver_name": ["BFGS", "GradientDescent"]}
GridSearchCV(model, param_grid).fit(X, y)
cls = GridSearchCV(model, param_grid).fit(X, y)
# check that the repr works after cloning
repr(cls)

def test_sklearn_clone(self, poisson_population_GLM_model):
X, y, model, true_params, firing_rate = poisson_population_GLM_model
Expand Down
3 changes: 3 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import warnings
from contextlib import nullcontext as does_not_raise
from copy import deepcopy

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -595,6 +596,8 @@ def __repr__(self):
(Example(a=0, b=False, c=None), None, [], "Example(a=0, b=False, d=1)"),
# Falsey values excluded2
(Example(a=0, b=[], c={}), None, [], "Example(a=0, d=1)"),
# function without the __name__
(nmo.observation_models.PoissonObservations(deepcopy(jax.numpy.exp)),None, [], "PoissonObservations(inverse_link_function=<PjitFunction>)")
],
)
def test_format_repr(obj, exclude_keys, use_name_keys, expected):
Expand Down
Loading