Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Ludvig committed Aug 30, 2024
1 parent d6d3486 commit e8f9f2c
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
6 changes: 4 additions & 2 deletions generalize/model/cross_validate/nested_cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

# TODO Consider order of the arguments
# TODO Should outer_split be allowed to have a split per repetition?
# TODO Add option to retrieve an attribute from the pipeline. E.g. "nmf__H" for NMF components.
# Should probably be a list of attribute names to extract from "best_estimator_" objects


def nested_cross_validate(
Expand Down Expand Up @@ -420,7 +422,7 @@ def nested_cross_validate(
# Ensure grid names point to their pipeline step
pipeline_keys = list(pipe.named_steps.keys())
for key in grid.keys():
if not "__" in key or key.split("__")[0] not in pipeline_keys:
if "__" not in key or key.split("__")[0] not in pipeline_keys:
messenger(f"Pipeline keys for debugging:\n{pipeline_keys}")
raise ValueError(
f"Grid keys must be prefixed by either 'model__' or "
Expand Down Expand Up @@ -1028,7 +1030,7 @@ def get_header(path):

# Change random IDs to letter IDs (AA, AB, AC, ...)
for in_res, in_coefs in zip(inner_results, best_coefficients):
if not "random_id" in in_res:
if "random_id" not in in_res:
messenger("Bad inner results data frame: ", in_res)
with messenger.indentation(add_indent=2):
messenger("with column names: ", in_res.columns)
Expand Down
3 changes: 1 addition & 2 deletions generalize/model/pipeline/pipeline_designer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import FunctionTransformer
from sklearn.base import BaseEstimator


from generalize.model.transformers import (
DimTransformerWrapper,
Expand Down Expand Up @@ -56,7 +56,6 @@ def _identity(x):


class PipelineDesigner:

# name -> (transformer, kwargs)
PRECONFIGURED_TRANSFORMERS: Dict[str, Tuple[Callable, Dict[str, Any]]] = {
"identity": (
Expand Down

0 comments on commit e8f9f2c

Please sign in to comment.