Skip to content

Commit

Permalink
Make filtering on main var as well. Debugging print
Browse files Browse the repository at this point in the history
  • Loading branch information
Ludvig committed Aug 26, 2024
1 parent 77580d0 commit 9184b3b
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions generalize/model/cross_validate/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,11 +607,11 @@ def simplest_model_refit_strategy(cv_results):
)
].copy()

# Only keep solutions where the `other_vars`
# Only keep solutions where all specified variables
# are equal to or higher/lower (specified per var)
# than the best solution
if other_vars is not None:
for var_nm, var_direction in reversed(other_vars):
for var_nm, var_direction in reversed(other_vars) + [main_var]:
made_threshold_cv_results = made_threshold_cv_results.loc[
get_direction_fn(var_direction)(
made_threshold_cv_results[var_nm],
Expand All @@ -622,16 +622,11 @@ def simplest_model_refit_strategy(cv_results):
ascending=var_direction == "minimize",
kind="stable", # NOTE: Required for iterative sorting!
)
print(made_threshold_cv_results)

selected_index = (
made_threshold_cv_results.sort_values(
by=main_var[0],
ascending=score_direction == "minimize",
kind="stable", # NOTE: Required for iterative sorting!
)
.reset_index(drop=True)
.loc[0, "original_index"]
)
selected_index = made_threshold_cv_results.reset_index(drop=True).loc[
0, "original_index"
]

if messenger.verbose:
messenger("simplest_model_refit_strategy: ")
Expand Down

0 comments on commit 9184b3b

Please sign in to comment.