Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Development #210

Merged
merged 256 commits into from
Aug 14, 2024
Merged
Changes from 1 commit
Commits
Show all changes
256 commits
Select commit Hold shift + click to select a range
b074740
Merge branch 'development' into regul_optim_refactor
BalzaniEdoardo Jul 9, 2024
cb34670
refactor regularizers
clewis7 Jul 9, 2024
e152fcc
Merge remote-tracking branch 'origin' into regul_optim_refactor
clewis7 Jul 9, 2024
38a2366
move regularizer and solver to BaseRegressor
clewis7 Jul 9, 2024
54e1867
clean up solver and regularizer logic
clewis7 Jul 9, 2024
745e75c
fix get_params()
clewis7 Jul 9, 2024
6edc8a5
Merge branch 'regul_optim_refactor' of github.com:flatironinstitute/n…
BalzaniEdoardo Jul 9, 2024
2a6f5e1
fix regularizer strength param
clewis7 Jul 10, 2024
94f554d
fix lasso reg and solver instantiation
clewis7 Jul 10, 2024
fe743a5
requested changes
clewis7 Jul 10, 2024
8d0494b
add code comments and doc strings, lint files
clewis7 Jul 11, 2024
54973cd
fix lasso regularizer, move instantiate solver logic to base
clewis7 Jul 11, 2024
060eb36
instantiate solver updates
clewis7 Jul 11, 2024
69acde6
fix scaling for proximal gradient
clewis7 Jul 11, 2024
b66ad92
fix group lasso and model instantiation
clewis7 Jul 12, 2024
eacc0f7
small typo
clewis7 Jul 12, 2024
79b7866
shuffle things around
BalzaniEdoardo Jul 13, 2024
6d617c8
use isinstance for checking GroupLasso
BalzaniEdoardo Jul 13, 2024
13caece
remove type_checking
BalzaniEdoardo Jul 13, 2024
f15e176
fixed docs
BalzaniEdoardo Jul 13, 2024
3299d11
fixed table
BalzaniEdoardo Jul 15, 2024
60957a9
requested changes
clewis7 Jul 15, 2024
505a90b
added rescaling to all but exp decay basis
BalzaniEdoardo Jul 15, 2024
e65d75d
added rescaling to all but exp decay basis
BalzaniEdoardo Jul 15, 2024
786ac47
small changes
clewis7 Jul 15, 2024
adf31a1
add more docstrings
clewis7 Jul 15, 2024
4ff561c
black src
clewis7 Jul 15, 2024
d567efb
fix group lasso penalization
clewis7 Jul 15, 2024
722e9d1
trimmed parameters when default are available
BalzaniEdoardo Jul 16, 2024
1d13a46
added description of regularizer as string param
BalzaniEdoardo Jul 16, 2024
88c140e
requested changes
clewis7 Jul 16, 2024
3053259
Create a copy of basis in TransformerBasis.__init__
bagibence Jul 16, 2024
008df44
modified basis
BalzaniEdoardo Jul 16, 2024
11091ac
Make TransformerBasis compatible with pickle and CV, define __setattr…
bagibence Jul 16, 2024
8a13aca
Add tests for TransformerBasis
bagibence Jul 16, 2024
865765c
Don't set the original's kernel to None in TransformerBasis.__sklearn…
bagibence Jul 16, 2024
b8f4560
update tests
clewis7 Jul 16, 2024
4e0dcd2
Rename Basis._kernel to kernel_ to follow scikit-learn's convention
bagibence Jul 16, 2024
86b1c4d
remove switch case to support python3.9
clewis7 Jul 16, 2024
756222e
python3.9 must use Union for type checking
clewis7 Jul 16, 2024
9be4ad3
fixed tests
BalzaniEdoardo Jul 16, 2024
fceea56
run ci only for non-draft prs
BalzaniEdoardo Jul 17, 2024
a1f524c
Merge pull request #192 from flatironinstitute/fix_draft_pr_ci
BalzaniEdoardo Jul 17, 2024
b3d549f
Merge branch 'development' into regul_optim_refactor
BalzaniEdoardo Jul 17, 2024
0561ec4
increase test coverage
clewis7 Jul 17, 2024
64b2019
improved dof calc
BalzaniEdoardo Jul 17, 2024
94b6400
Merge branch 'regul_optim_refactor' of github.com:flatironinstitute/n…
BalzaniEdoardo Jul 17, 2024
f886eb1
Have Basis.to_transformer make a copy and update docs
bagibence Jul 17, 2024
be9fe56
Test that copying in Basis.to_transformer works
bagibence Jul 17, 2024
16e8289
fix bug
clewis7 Jul 17, 2024
ce65ab0
lint files
clewis7 Jul 17, 2024
e60957f
Define addition, multiplication, exponentiation for TransformerBasis
bagibence Jul 17, 2024
092b631
fix dof calc with pytrees, add regularizer tests
clewis7 Jul 17, 2024
d670a23
Add TransformerBasis to __all__ and update tests accordingly
bagibence Jul 17, 2024
e170a38
Add tests for TransformerBasis.__dir__
bagibence Jul 17, 2024
5641f0b
more tests, increase test coverage
clewis7 Jul 17, 2024
9f0dc51
change _call logic
BalzaniEdoardo Jul 17, 2024
956ee6a
fixed basis and tests
BalzaniEdoardo Jul 17, 2024
fe9a65d
bugfix cosyne
BalzaniEdoardo Jul 17, 2024
fb84bf2
changes
BalzaniEdoardo Jul 18, 2024
874b1f1
reverted changes
BalzaniEdoardo Jul 18, 2024
54a5872
removed unused
BalzaniEdoardo Jul 18, 2024
6d92592
renamed tests
BalzaniEdoardo Jul 18, 2024
ed3545a
fixed tests
BalzaniEdoardo Jul 18, 2024
e83d207
added tests for basis
BalzaniEdoardo Jul 18, 2024
b3f6b2b
linted
BalzaniEdoardo Jul 18, 2024
992a1e1
simplified call logic
BalzaniEdoardo Jul 18, 2024
7ae9b52
removed docsrtings incorrect info
BalzaniEdoardo Jul 18, 2024
a5a0c5b
bugfix ridge prox, removed if statement on lasso;
BalzaniEdoardo Jul 18, 2024
00f4208
Merge branch 'regul_optim_refactor' of github.com:flatironinstitute/n…
BalzaniEdoardo Jul 18, 2024
fb66051
simplified logic dof
BalzaniEdoardo Jul 18, 2024
9b5e5d2
move regularizer_stregnth to base regressor, update tests
clewis7 Jul 18, 2024
1f650aa
remove unused import
clewis7 Jul 18, 2024
cf6a38e
added tests improved dof calculation
BalzaniEdoardo Jul 18, 2024
87d577e
added tests improved dof calculation
BalzaniEdoardo Jul 18, 2024
2f8ab2d
fixed docs
BalzaniEdoardo Jul 18, 2024
3033d2f
added tests that must pass
BalzaniEdoardo Jul 18, 2024
3bd9cde
isorted tests
BalzaniEdoardo Jul 18, 2024
cc14f69
added test to pop glm
BalzaniEdoardo Jul 18, 2024
a862db9
fix regularizer strength param, update tests
clewis7 Jul 19, 2024
897dca6
switched to bounds
BalzaniEdoardo Jul 19, 2024
cca8075
fixed tests basis
BalzaniEdoardo Jul 19, 2024
998cfd3
linted
BalzaniEdoardo Jul 19, 2024
0c31798
fixed docs
BalzaniEdoardo Jul 19, 2024
ad9ce69
improved docstrings
BalzaniEdoardo Jul 19, 2024
b39d23b
Update src/nemos/basis.py
BalzaniEdoardo Jul 19, 2024
e2a6a92
removed unused if else
BalzaniEdoardo Jul 19, 2024
ea33517
black
BalzaniEdoardo Jul 19, 2024
71c9e3f
some requested changes
clewis7 Jul 22, 2024
d4b56cb
make degrees of freedom an attribute of glm, make calculation private
clewis7 Jul 22, 2024
c7faf44
separate param initialization and solver initialization
clewis7 Jul 22, 2024
2bece54
changed precision on clip
BalzaniEdoardo Jul 22, 2024
d7b47f1
enabled float64 for precise comparrison
BalzaniEdoardo Jul 22, 2024
a189cae
fix docs
clewis7 Jul 22, 2024
0ae67f3
fix docs
clewis7 Jul 22, 2024
781f8bb
change scale param in glm to scale_, update tests
clewis7 Jul 22, 2024
1a535f5
fix docs
clewis7 Jul 22, 2024
84914a9
fixed docs
BalzaniEdoardo Jul 23, 2024
feef74b
make _penalized() methods static
clewis7 Jul 23, 2024
b786369
remove LBFGSB
clewis7 Jul 23, 2024
be6e379
start working on updating contributing guide
clewis7 Jul 23, 2024
bbcc472
Apply suggestions from code review
clewis7 Jul 23, 2024
14f3c86
requested changes
clewis7 Jul 23, 2024
e63ad21
more contrib guide, stop CI run on draft
clewis7 Jul 23, 2024
a5d36bf
don't think this should run on draft either
clewis7 Jul 23, 2024
18c4630
Update src/nemos/_regularizer_builder.py
BalzaniEdoardo Jul 24, 2024
3949cc4
moved solver instantiation to base regressor
BalzaniEdoardo Jul 24, 2024
4f0a12d
deepcopied solver_kwargs
BalzaniEdoardo Jul 24, 2024
094b61a
fixed tests
BalzaniEdoardo Jul 24, 2024
404a266
removed unused inputs
BalzaniEdoardo Jul 24, 2024
769c312
additional text fixes
BalzaniEdoardo Jul 24, 2024
5c745ec
final text fixes and linting
BalzaniEdoardo Jul 24, 2024
6f4866b
improved comment
BalzaniEdoardo Jul 24, 2024
0f12021
added test and improved docs for GroupLasso mask
BalzaniEdoardo Jul 24, 2024
7ff0df9
Merge pull request #183 from flatironinstitute/regul_optim_refactor
BalzaniEdoardo Jul 24, 2024
87485a5
Merge branch 'development' into min_max_basis
BalzaniEdoardo Jul 24, 2024
54e0326
add tests for checking get_params()
clewis7 Jul 24, 2024
12fa0b6
added description on how to contribute to basis and updated dev notes
BalzaniEdoardo Jul 24, 2024
a00536b
changeing to dev notes
BalzaniEdoardo Jul 24, 2024
e68ccfc
merged dev
BalzaniEdoardo Jul 24, 2024
5ecc8cc
requested changes
clewis7 Jul 24, 2024
a0e873b
updated base regressor note
BalzaniEdoardo Jul 24, 2024
9672994
updated note on observation models
BalzaniEdoardo Jul 24, 2024
f33b559
fixed regularizer note
BalzaniEdoardo Jul 24, 2024
8cbe550
uodated glm guidelines
BalzaniEdoardo Jul 24, 2024
ad4d5e7
removed warning
BalzaniEdoardo Jul 24, 2024
3c127b5
renamed stuff
BalzaniEdoardo Jul 24, 2024
0344360
renamed stuff
BalzaniEdoardo Jul 24, 2024
fb35931
Merge branch 'contributing_guide' of github.com:flatironinstitute/nem…
BalzaniEdoardo Jul 24, 2024
849d55a
updated descr
BalzaniEdoardo Jul 24, 2024
df263d9
add more info
BalzaniEdoardo Jul 24, 2024
05ba9bf
correct method name
BalzaniEdoardo Jul 24, 2024
5d04de0
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
eb046cb
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
ee53981
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
e3057c0
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
51f344b
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
68b956a
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
7524ea3
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
09fb337
Update tests/test_glm.py
BalzaniEdoardo Jul 24, 2024
680027e
some minor tweaks
clewis7 Jul 24, 2024
58b375d
use sets when possible
BalzaniEdoardo Jul 24, 2024
8004c0a
small test typo
clewis7 Jul 24, 2024
fcba09c
Merge pull request #200 from flatironinstitute/get_params
BalzaniEdoardo Jul 24, 2024
33372fd
modified scheme
BalzaniEdoardo Jul 25, 2024
26dd16b
Merge branch 'contributing_guide' of github.com:flatironinstitute/nem…
BalzaniEdoardo Jul 25, 2024
b762094
Update CONTRIBUTING.md
BalzaniEdoardo Jul 25, 2024
55f6218
Update CONTRIBUTING.md
BalzaniEdoardo Jul 25, 2024
b9fd3f0
modified notes
BalzaniEdoardo Jul 25, 2024
393b882
fixed gitflow info
BalzaniEdoardo Jul 25, 2024
ddfed69
fixed notes
BalzaniEdoardo Jul 25, 2024
15008f4
added legend
BalzaniEdoardo Jul 25, 2024
7753105
added example
BalzaniEdoardo Jul 25, 2024
07c3494
added more info on docs
BalzaniEdoardo Jul 25, 2024
da1ba91
added more info on ipynb
BalzaniEdoardo Jul 25, 2024
f29c8e1
added info on methods
BalzaniEdoardo Jul 25, 2024
1fb8f46
typo
BalzaniEdoardo Jul 25, 2024
fc27bb4
update called
BalzaniEdoardo Jul 25, 2024
91611d9
describe the attrs and how they should be used
BalzaniEdoardo Jul 25, 2024
d748877
removed examples and concrete implementations
BalzaniEdoardo Jul 25, 2024
2ee3fde
removed examples and concrete implementations
BalzaniEdoardo Jul 25, 2024
ec9b078
improve guidelines
BalzaniEdoardo Jul 25, 2024
48c33e8
important mark
BalzaniEdoardo Jul 25, 2024
bc57cbe
merged min/max and regul refactor
BalzaniEdoardo Jul 25, 2024
8dd83f1
linted
BalzaniEdoardo Jul 25, 2024
2be4574
added example
BalzaniEdoardo Jul 25, 2024
a29eb3e
linted docstrings
BalzaniEdoardo Jul 25, 2024
fb2af51
linted docstrings
BalzaniEdoardo Jul 25, 2024
507fd1b
fixed docs
BalzaniEdoardo Jul 25, 2024
5da9ab1
updated info about gamma
BalzaniEdoardo Jul 26, 2024
b94951a
updated info about Base
BalzaniEdoardo Jul 26, 2024
de71113
fixed grammar
BalzaniEdoardo Jul 26, 2024
fe98584
fixed grammar
BalzaniEdoardo Jul 26, 2024
14b6a1e
imporoved description
BalzaniEdoardo Jul 26, 2024
7c99155
Update src/nemos/basis.py
BalzaniEdoardo Jul 27, 2024
390e84a
Update src/nemos/basis.py
BalzaniEdoardo Jul 27, 2024
a3e1ebd
moved to setter, surfaced exception
BalzaniEdoardo Jul 27, 2024
631dcd8
Update src/nemos/basis.py
BalzaniEdoardo Jul 27, 2024
02abd27
Merge branch 'min_max_basis' of github.com:flatironinstitute/nemos in…
BalzaniEdoardo Jul 27, 2024
2316f19
set to numnerical precision
BalzaniEdoardo Jul 27, 2024
3e12e39
Update src/nemos/basis.py
BalzaniEdoardo Jul 27, 2024
5ae3d06
added example for mspline
BalzaniEdoardo Jul 27, 2024
5735a3b
Update src/nemos/basis.py
BalzaniEdoardo Jul 27, 2024
3cfe3e0
linted
BalzaniEdoardo Jul 27, 2024
06db929
Merge branch 'min_max_basis' of github.com:flatironinstitute/nemos in…
BalzaniEdoardo Jul 27, 2024
e3d4a32
merged
BalzaniEdoardo Jul 27, 2024
8784eac
Merge pull request #191 from flatironinstitute/min_max_basis
BalzaniEdoardo Jul 27, 2024
0c757f5
merged conflicts
BalzaniEdoardo Jul 27, 2024
296b449
added test for get params transformer
BalzaniEdoardo Jul 27, 2024
06129b4
substantial changes in the plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Jul 29, 2024
bac99f4
refined tutorial
BalzaniEdoardo Jul 29, 2024
cc2285c
improved content of tutorial
BalzaniEdoardo Jul 29, 2024
eb3ce98
fixed capitalization
BalzaniEdoardo Jul 31, 2024
05487f0
fix capitalization readme
BalzaniEdoardo Jul 31, 2024
d8f188a
added schematic of pipeline
BalzaniEdoardo Jul 31, 2024
7b5b4db
Improve docstrings
bagibence Aug 2, 2024
058a738
fix name of workflow
billbrod Aug 5, 2024
0109eb5
fix alert syntax
billbrod Aug 5, 2024
7e2be08
adds missing word
billbrod Aug 5, 2024
a6d5a4b
small text fix
billbrod Aug 5, 2024
17d1a24
describe releases
billbrod Aug 5, 2024
d3e11fe
merged
BalzaniEdoardo Aug 6, 2024
ce30a41
added note on tox
BalzaniEdoardo Aug 6, 2024
7d47df5
updated tox ini
BalzaniEdoardo Aug 6, 2024
47bed51
Update CONTRIBUTING.md
BalzaniEdoardo Aug 8, 2024
7430340
Update docs/developers_notes/04-regularizer.md
BalzaniEdoardo Aug 8, 2024
90d7ffc
Merge pull request #198 from flatironinstitute/contributing_guide
BalzaniEdoardo Aug 8, 2024
7015833
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 8, 2024
f22d222
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 8, 2024
d527734
updated quickstart link and note on labeling
BalzaniEdoardo Aug 8, 2024
7d58c54
Merge branch 'issue_156' of github.com:bagibence/nemos into issue_156
BalzaniEdoardo Aug 8, 2024
5013473
fix convergence tests
clewis7 Aug 8, 2024
75c057d
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 8, 2024
6f10240
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 8, 2024
5693b77
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 8, 2024
073c935
added a sentence
BalzaniEdoardo Aug 8, 2024
96c8a76
Merge branch 'issue_156' of github.com:bagibence/nemos into issue_156
BalzaniEdoardo Aug 8, 2024
5379f37
added explicit opening figure
BalzaniEdoardo Aug 8, 2024
b7e0414
improved description
BalzaniEdoardo Aug 8, 2024
e2cee2c
added boilerplate code
BalzaniEdoardo Aug 8, 2024
e281225
Add some tests
bagibence Aug 8, 2024
c537679
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 10, 2024
d944c6d
addressed comments
BalzaniEdoardo Aug 10, 2024
8aeedc7
Merge branch 'issue_156' of github.com:bagibence/nemos into issue_156
BalzaniEdoardo Aug 10, 2024
66bf72b
Merge branch 'development' into issue_156
BalzaniEdoardo Aug 10, 2024
5f52fed
Update src/nemos/basis.py
BalzaniEdoardo Aug 10, 2024
0ac28e7
Update src/nemos/basis.py
BalzaniEdoardo Aug 10, 2024
d27f0c6
Update src/nemos/basis.py
BalzaniEdoardo Aug 10, 2024
18c19e5
Update src/nemos/basis.py
BalzaniEdoardo Aug 10, 2024
91f5e7a
improved example
BalzaniEdoardo Aug 10, 2024
e175cf6
moved pipeline tests
BalzaniEdoardo Aug 10, 2024
749c522
Merge branch 'issue_156' of github.com:bagibence/nemos into issue_156
BalzaniEdoardo Aug 10, 2024
016ea2e
change error msg
BalzaniEdoardo Aug 10, 2024
93767ef
fixed prox-operator loss
BalzaniEdoardo Aug 12, 2024
205c4e8
test with groups larger than 1
BalzaniEdoardo Aug 12, 2024
49d8577
linted
BalzaniEdoardo Aug 12, 2024
c3dbe3d
Merge pull request #206 from flatironinstitute/update_converge
BalzaniEdoardo Aug 12, 2024
18cf739
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 13, 2024
aab99c8
Merge branch 'development' into issue_156
BalzaniEdoardo Aug 13, 2024
746b45e
Merge branch 'issue_156' of github.com:bagibence/nemos into issue_156
BalzaniEdoardo Aug 13, 2024
1f5cff0
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 13, 2024
01bb2b5
Update docs/api_guide/plot_06_sklearn_pipeline_cv_demo.py
BalzaniEdoardo Aug 13, 2024
a16bc6d
Update docs/quickstart.md
BalzaniEdoardo Aug 13, 2024
58b9d99
added emojis
BalzaniEdoardo Aug 13, 2024
935e911
added cv figure in svg
BalzaniEdoardo Aug 13, 2024
7a10328
Merge branch 'issue_156' of github.com:bagibence/nemos into issue_156
BalzaniEdoardo Aug 13, 2024
6008602
Merge pull request #169 from bagibence/issue_156
BalzaniEdoardo Aug 13, 2024
32132b1
fixed basis docstrings
BalzaniEdoardo Aug 14, 2024
b6b0789
Merge pull request #209 from flatironinstitute/docstrings_call
BalzaniEdoardo Aug 14, 2024
b5f0808
updated package version
BalzaniEdoardo Aug 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
moved pipeline tests
BalzaniEdoardo committed Aug 10, 2024
commit e175cf6b56056e37569b6fd6f9afb0540a5aae43
160 changes: 0 additions & 160 deletions tests/test_basis.py
Original file line number Diff line number Diff line change
@@ -4699,166 +4699,6 @@ def test_transformerbasis_pickle(tmpdir, basis_cls, n_basis_funcs):
assert trans_bas2.n_basis_funcs == n_basis_funcs


@pytest.mark.parametrize(
"bas",
[
basis.MSplineBasis(5),
basis.BSplineBasis(5),
basis.CyclicBSplineBasis(5),
basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)),
basis.RaisedCosineBasisLinear(5),
]
)
def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("eval", bas), ("fit", model)])

pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y)


@pytest.mark.parametrize(
"bas",
[
basis.MSplineBasis(5),
basis.BSplineBasis(5),
basis.CyclicBSplineBasis(5),
basis.RaisedCosineBasisLinear(5),
basis.RaisedCosineBasisLog(5),
],
)
def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(3, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv = 3)
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)

@pytest.mark.parametrize(
"bas",
[
basis.MSplineBasis(5),
basis.BSplineBasis(5),
basis.CyclicBSplineBasis(5),
basis.RaisedCosineBasisLinear(5),
basis.RaisedCosineBasisLog(5),
],
)
def test_sklearn_transformer_pipeline_cv_multiprocess(bas, poissonGLM_model_instantiation):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(3, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv = 3, n_jobs=3)
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)

@pytest.mark.parametrize(
"bas_cls",
[
basis.MSplineBasis,
basis.BSplineBasis,
basis.CyclicBSplineBasis,
basis.RaisedCosineBasisLinear,
basis.RaisedCosineBasisLog,
],
)
def test_sklearn_transformer_pipeline_cv_directly_over_basis(bas_cls, poissonGLM_model_instantiation):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas_cls(5))
pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)])
param_grid = dict(
transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20))
)
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv = 3)
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)

@pytest.mark.parametrize(
"bas_cls",
[
basis.MSplineBasis,
basis.BSplineBasis,
basis.CyclicBSplineBasis,
basis.RaisedCosineBasisLinear,
basis.RaisedCosineBasisLog,
],
)
def test_sklearn_transformer_pipeline_cv_illegal_combination(bas_cls, poissonGLM_model_instantiation):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas_cls(5))
pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)])
param_grid = dict(
transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)),
transformerbasis__n_basis_funcs=(3, 5, 10),
)
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv = 3)
with pytest.raises(
ValueError, match="Set either _basis or parameters for _basis, not both."
):
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


@pytest.mark.parametrize(
"bas, expected_nans",
[
(basis.MSplineBasis(5), 0),
(basis.BSplineBasis(5), 0),
(basis.CyclicBSplineBasis(5), 0),
(basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)), 0),
(basis.RaisedCosineBasisLinear(5), 0),
(basis.RaisedCosineBasisLog(5), 0),
(basis.RaisedCosineBasisLog(5) + basis.MSplineBasis(5), 0),
(basis.MSplineBasis(5, mode="conv", window_size=3), 6),
(basis.BSplineBasis(5, mode="conv", window_size=3), 6),
(
basis.CyclicBSplineBasis(
5, mode="conv", window_size=3, predictor_causality="acausal"
),
4,
),
(
basis.OrthExponentialBasis(
5, decay_rates=np.linspace(0.1, 1, 5), mode="conv", window_size=7
),
14,
),
(basis.RaisedCosineBasisLinear(5, mode="conv", window_size=3), 6),
(basis.RaisedCosineBasisLog(5, mode="conv", window_size=3), 6),
(
basis.RaisedCosineBasisLog(5, mode="conv", window_size=3)
+ basis.MSplineBasis(5),
6,
),
(
basis.RaisedCosineBasisLog(5, mode="conv", window_size=3)
* basis.MSplineBasis(5),
6,
),
],
)
def test_sklearn_transformer_pipeline_pynapple(
bas, poissonGLM_model_instantiation, expected_nans
):
X, y, model, _, _ = poissonGLM_model_instantiation

# transform input to pynapple
ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]])
X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep)
y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep)
bas = basis.TransformerBasis(bas)
# fit a pipeline & predict from pynapple
pipe = pipeline.Pipeline([("eval", bas), ("fit", model)])
pipe.fit(X_nap[:, : bas._basis._n_input_dimensionality] ** 2, y_nap)

# get rate
rate = pipe.predict(X_nap[:, : bas._basis._n_input_dimensionality] ** 2)
# check rate is Tsd with same time info
assert isinstance(rate, nap.Tsd)
assert np.all(rate.t == X_nap.t)
assert np.all(rate.time_support == X_nap.time_support)
assert np.sum(np.isnan(rate.d)) == expected_nans


@pytest.mark.parametrize(
"tsd",
[
174 changes: 174 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import numpy as np
import pynapple as nap
import pytest
from sklearn import pipeline
from sklearn.model_selection import GridSearchCV

from nemos import basis


@pytest.mark.parametrize(
"bas",
[
basis.MSplineBasis(5),
basis.BSplineBasis(5),
basis.CyclicBSplineBasis(5),
basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)),
basis.RaisedCosineBasisLinear(5),
],
)
def test_sklearn_transformer_pipeline(bas, poissonGLM_model_instantiation):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("eval", bas), ("fit", model)])

pipe.fit(X[:, : bas._basis._n_input_dimensionality] ** 2, y)


@pytest.mark.parametrize(
"bas",
[
basis.MSplineBasis(5),
basis.BSplineBasis(5),
basis.CyclicBSplineBasis(5),
basis.RaisedCosineBasisLinear(5),
basis.RaisedCosineBasisLog(5),
],
)
def test_sklearn_transformer_pipeline_cv(bas, poissonGLM_model_instantiation):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(3, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3)
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


@pytest.mark.parametrize(
"bas",
[
basis.MSplineBasis(5),
basis.BSplineBasis(5),
basis.CyclicBSplineBasis(5),
basis.RaisedCosineBasisLinear(5),
basis.RaisedCosineBasisLog(5),
],
)
def test_sklearn_transformer_pipeline_cv_multiprocess(
bas, poissonGLM_model_instantiation
):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas)
pipe = pipeline.Pipeline([("basis", bas), ("fit", model)])
param_grid = dict(basis__n_basis_funcs=(3, 5, 10))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3, n_jobs=3)
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


@pytest.mark.parametrize(
"bas_cls",
[
basis.MSplineBasis,
basis.BSplineBasis,
basis.CyclicBSplineBasis,
basis.RaisedCosineBasisLinear,
basis.RaisedCosineBasisLog,
],
)
def test_sklearn_transformer_pipeline_cv_directly_over_basis(
bas_cls, poissonGLM_model_instantiation
):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas_cls(5))
pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)])
param_grid = dict(transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)))
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3)
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


@pytest.mark.parametrize(
"bas_cls",
[
basis.MSplineBasis,
basis.BSplineBasis,
basis.CyclicBSplineBasis,
basis.RaisedCosineBasisLinear,
basis.RaisedCosineBasisLog,
],
)
def test_sklearn_transformer_pipeline_cv_illegal_combination(
bas_cls, poissonGLM_model_instantiation
):
X, y, model, _, _ = poissonGLM_model_instantiation
bas = basis.TransformerBasis(bas_cls(5))
pipe = pipeline.Pipeline([("transformerbasis", bas), ("fit", model)])
param_grid = dict(
transformerbasis___basis=(bas_cls(5), bas_cls(10), bas_cls(20)),
transformerbasis__n_basis_funcs=(3, 5, 10),
)
gridsearch = GridSearchCV(pipe, param_grid=param_grid, cv=3)
with pytest.raises(
ValueError, match="Set either _basis or parameters for _basis, not both."
):
gridsearch.fit(X[:, : bas._n_input_dimensionality] ** 2, y)


@pytest.mark.parametrize(
"bas, expected_nans",
[
(basis.MSplineBasis(5), 0),
(basis.BSplineBasis(5), 0),
(basis.CyclicBSplineBasis(5), 0),
(basis.OrthExponentialBasis(5, decay_rates=np.arange(1, 6)), 0),
(basis.RaisedCosineBasisLinear(5), 0),
(basis.RaisedCosineBasisLog(5), 0),
(basis.RaisedCosineBasisLog(5) + basis.MSplineBasis(5), 0),
(basis.MSplineBasis(5, mode="conv", window_size=3), 6),
(basis.BSplineBasis(5, mode="conv", window_size=3), 6),
(
basis.CyclicBSplineBasis(
5, mode="conv", window_size=3, predictor_causality="acausal"
),
4,
),
(
basis.OrthExponentialBasis(
5, decay_rates=np.linspace(0.1, 1, 5), mode="conv", window_size=7
),
14,
),
(basis.RaisedCosineBasisLinear(5, mode="conv", window_size=3), 6),
(basis.RaisedCosineBasisLog(5, mode="conv", window_size=3), 6),
(
basis.RaisedCosineBasisLog(5, mode="conv", window_size=3)
+ basis.MSplineBasis(5),
6,
),
(
basis.RaisedCosineBasisLog(5, mode="conv", window_size=3)
* basis.MSplineBasis(5),
6,
),
],
)
def test_sklearn_transformer_pipeline_pynapple(
bas, poissonGLM_model_instantiation, expected_nans
):
X, y, model, _, _ = poissonGLM_model_instantiation

# transform input to pynapple
ep = nap.IntervalSet(start=[0, 20.5], end=[20, X.shape[0]])
X_nap = nap.TsdFrame(t=np.arange(X.shape[0]), d=X, time_support=ep)
y_nap = nap.Tsd(t=np.arange(X.shape[0]), d=y, time_support=ep)
bas = basis.TransformerBasis(bas)
# fit a pipeline & predict from pynapple
pipe = pipeline.Pipeline([("eval", bas), ("fit", model)])
pipe.fit(X_nap[:, : bas._basis._n_input_dimensionality] ** 2, y_nap)

# get rate
rate = pipe.predict(X_nap[:, : bas._basis._n_input_dimensionality] ** 2)
# check rate is Tsd with same time info
assert isinstance(rate, nap.Tsd)
assert np.all(rate.t == X_nap.t)
assert np.all(rate.time_support == X_nap.time_support)
assert np.sum(np.isnan(rate.d)) == expected_nans