diff --git a/bofire/surrogates/feature_importance.py b/bofire/surrogates/feature_importance.py
index 10d5fb088..46ab0402d 100644
--- a/bofire/surrogates/feature_importance.py
+++ b/bofire/surrogates/feature_importance.py
@@ -48,10 +48,13 @@ def permutation_importance(
         k.name: {feature.key: [] for feature in model.inputs} for k in metrics.keys()
     }
     pred = model.predict(X)
-    original_metrics = {
-        k.name: metrics[k](y[output_key].values, pred[output_key + "_pred"].values)  # type: ignore
-        for k in metrics.keys()
-    }
+    if len(pred) >= 2:
+        original_metrics = {
+            k.name: metrics[k](y[output_key].values, pred[output_key + "_pred"].values)  # type: ignore
+            for k in metrics.keys()
+        }
+    else:
+        original_metrics = {k.name: np.nan for k in metrics.keys()}
 
     for feature in model.inputs:
         for _ in range(n_repeats):
@@ -62,9 +65,13 @@
             pred = model.predict(X_i)
             # compute scores
             for metricenum, metric in metrics.items():
-                prelim_results[metricenum.name][feature.key].append(
-                    metric(y[output_key].values, pred[output_key + "_pred"].values)  # type: ignore
-                )
+                if len(pred) >= 2:
+                    prelim_results[metricenum.name][feature.key].append(
+                        metric(y[output_key].values, pred[output_key + "_pred"].values)  # type: ignore
+                    )
+                else:
+                    prelim_results[metricenum.name][feature.key].append(np.nan)  # type: ignore
+
     # convert dictionaries to dataframe for easier postprocessing and statistics
     # and return
     results = {}
diff --git a/bofire/surrogates/trainable.py b/bofire/surrogates/trainable.py
index ddf3d82e2..c38cd6d3a 100644
--- a/bofire/surrogates/trainable.py
+++ b/bofire/surrogates/trainable.py
@@ -96,6 +96,8 @@ def cross_validate(
             raise NotImplementedError(
                 "Cross validation not implemented for multi-output models"
             )
+        # first filter the experiments based on the model setting
+        experiments = self._preprocess_experiments(experiments)
         n = len(experiments)
         if folds > n:
             warnings.warn(
@@ -117,8 +119,6 @@
         # instantiate kfold object
         cv = KFold(n_splits=folds, shuffle=True, random_state=random_state)
         key = self.outputs.get_keys()[0]  # type: ignore
-        # first filter the experiments based on the model setting
-        experiments = self._preprocess_experiments(experiments)
         train_results = []
         test_results = []
         # now get the indices for the split
diff --git a/tests/bofire/surrogates/test_feature_importance.py b/tests/bofire/surrogates/test_feature_importance.py
index b1d77c288..0ddfd9c14 100644
--- a/tests/bofire/surrogates/test_feature_importance.py
+++ b/tests/bofire/surrogates/test_feature_importance.py
@@ -61,6 +61,22 @@ def test_permutation_importance():
     assert list(results[m.name].index) == ["mean", "std"]
 
 
+def test_permutation_importance_nan():
+    model, experiments = get_model_and_data()
+    X = experiments[model.inputs.get_keys()][:1]
+    y = experiments[["y"]][:1]
+    model.fit(experiments=experiments)
+    results = permutation_importance(model=model, X=X, y=y, n_repeats=5)
+    assert isinstance(results, dict)
+    assert len(results) == len(metrics)
+    for m in metrics.keys():
+        assert m.name in results.keys()
+        assert isinstance(results[m.name], pd.DataFrame)
+        assert list(results[m.name].columns) == model.inputs.get_keys()
+        assert list(results[m.name].index) == ["mean", "std"]
+        assert len(results[m.name].dropna()) == 0
+
+
 @pytest.mark.parametrize("use_test", [True, False])
 def test_permutation_importance_hook(use_test):
     model, experiments = get_model_and_data()