Skip to content

Commit

Permalink
fix tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
BalzaniEdoardo committed Dec 20, 2024
1 parent 055dd45 commit 7bcdf96
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions docs/how_to_guide/plot_06_sklearn_pipeline_cv_demo.md
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,8 @@ scores = np.zeros((len(regularizer_strength) * len(n_basis_funcs), n_folds))
coeffs = {}
# initialize basis and model
basis = nmo.basis.TransformerBasis(nmo.basis.RaisedCosineLinearEval(6))
basis = nmo.basis.RaisedCosineLinearEval(6).set_input_shape(1)
basis = nmo.basis.TransformerBasis(basis)
model = nmo.glm.GLM(regularizer="Ridge")
# loop over combinations
Expand Down Expand Up @@ -441,13 +442,13 @@ We are now able to capture the distribution of the firing rate appropriately: bo

In the previous example we set the number of basis functions of the [`Basis`](nemos.basis._basis.Basis) wrapped in our [`TransformerBasis`](nemos.basis._transformer_basis.TransformerBasis). However, if we are for example not sure about the type of basis functions we want to use, or we have already defined some basis functions of our own, then we can use cross-validation to directly evaluate those as well.

Here we include `transformerbasis___basis` in the parameter grid to try different values for `TransformerBasis._basis`:
Here we include `transformerbasis__basis` in the parameter grid to try different values for `TransformerBasis.basis`:


```{code-cell} ipython3
param_grid = dict(
glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
transformerbasis___basis=(
transformerbasis__basis=(
nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1),
nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1),
nmo.basis.RaisedCosineLogEval(5).set_input_shape(1),
Expand Down Expand Up @@ -481,7 +482,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_)
# Read out the number of basis functions
cvdf["transformerbasis_config"] = [
f"{b.__class__.__name__} - {b.n_basis_funcs}"
for b in cvdf["param_transformerbasis___basis"]
for b in cvdf["param_transformerbasis__basis"]
]
cvdf_wide = cvdf.pivot(
Expand Down Expand Up @@ -537,7 +538,7 @@ Please note that because it would lead to unexpected behavior, mixing the two wa
param_grid = dict(
glm__regularizer_strength=(0.1, 0.01, 0.001, 1e-6),
transformerbasis__n_basis_funcs=(3, 5, 10, 20, 100),
transformerbasis___basis=(
transformerbasis__basis=(
nmo.basis.RaisedCosineLinearEval(5).set_input_shape(1),
nmo.basis.RaisedCosineLinearEval(10).set_input_shape(1),
nmo.basis.RaisedCosineLogEval(5).set_input_shape(1),
Expand Down Expand Up @@ -592,7 +593,7 @@ cvdf = pd.DataFrame(gridsearch.cv_results_)
# Read out the number of basis functions
cvdf["transformerbasis_config"] = [
f"{b.__class__.__name__} - {b.n_basis_funcs}"
for b in cvdf["param_transformerbasis___basis"]
for b in cvdf["param_transformerbasis__basis"]
]
cvdf_wide = cvdf.pivot(
Expand Down

0 comments on commit 7bcdf96

Please sign in to comment.