Skip to content

Commit

Permalink
add include_labcodes (#168)
Browse files Browse the repository at this point in the history
* add include_labcodes

* update black
  • Loading branch information
jduerholt authored Apr 21, 2023
1 parent d0f4d76 commit 0947f53
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
11 changes: 11 additions & 0 deletions bofire/surrogates/trainable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def cross_validate(
experiments: pd.DataFrame,
folds: int = -1,
include_X: bool = False,
include_labcodes: bool = False,
hooks: Dict[
str,
Callable[
Expand Down Expand Up @@ -83,6 +84,8 @@ def cross_validate(
Tuple[CvResults, CvResults, Dict[str, List[Any]]]: First CvResults object reflects the training data,
second CvResults object the test data, dictionary object holds the return values of the applied hooks.
"""
if include_labcodes and "labcode" not in experiments.columns:
raise ValueError("No labcodes available for the provided experiments.")

if len(self.output_features) > 1: # type: ignore
raise NotImplementedError(
Expand Down Expand Up @@ -114,6 +117,12 @@ def cross_validate(
X_test = experiments.iloc[test_index][self.input_features.get_keys()] # type: ignore
y_train = experiments.iloc[train_index][self.output_features.get_keys()] # type: ignore
y_test = experiments.iloc[test_index][self.output_features.get_keys()] # type: ignore
train_labcodes = (
experiments.iloc[train_index]["labcode"] if include_labcodes else None
)
test_labcodes = (
experiments.iloc[test_index]["labcode"] if include_labcodes else None
)
# now fit the model
self._fit(X_train, y_train)
# now do the scoring
Expand All @@ -127,6 +136,7 @@ def cross_validate(
predicted=y_train_pred[key + "_pred"],
standard_deviation=y_train_pred[key + "_sd"],
X=X_train if include_X else None,
labcodes=train_labcodes,
)
)
test_results.append(
Expand All @@ -136,6 +146,7 @@ def cross_validate(
predicted=y_test_pred[key + "_pred"],
standard_deviation=y_test_pred[key + "_sd"],
X=X_test if include_X else None,
labcodes=test_labcodes,
)
)
# now call the hooks if available
Expand Down
13 changes: 10 additions & 3 deletions tests/bofire/surrogates/test_cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def test_model_cross_validate_descriptor():
assert len(test_cv.results) == efolds


@pytest.mark.parametrize("include_X", [True, False])
def test_model_cross_validate_include_X(include_X):
@pytest.mark.parametrize("include_X, include_labcodes", [[True, False], [False, True]])
def test_model_cross_validate_include_X(include_X, include_labcodes):
input_features = Inputs(
features=[
ContinuousInput(
Expand All @@ -93,6 +93,7 @@ def test_model_cross_validate_include_X(include_X):
)
output_features = Outputs(features=[ContinuousOutput(key="y")])
experiments = input_features.sample(n=10)
experiments["labcode"] = [str(i) for i in range(10)]
experiments.eval("y=((x_1**2 + x_2 - 11)**2+(x_1 + x_2**2 -7)**2)", inplace=True)
experiments["valid_y"] = 1
model = SingleTaskGPSurrogate(
Expand All @@ -101,14 +102,20 @@ def test_model_cross_validate_include_X(include_X):
)
model = surrogates.map(model)
train_cv, test_cv, _ = model.cross_validate(
experiments, folds=5, include_X=include_X
experiments, folds=5, include_X=include_X, include_labcodes=include_labcodes
)
if include_X:
assert train_cv.results[0].X.shape == (8, 2)
assert test_cv.results[0].X.shape == (2, 2)
if include_X is False:
assert train_cv.results[0].X is None
assert test_cv.results[0].X is None
if include_labcodes:
assert train_cv.results[0].labcodes.shape == (8,)
assert test_cv.results[0].labcodes.shape == (2,)
else:
assert train_cv.results[0].labcodes is None
assert train_cv.results[0].labcodes is None


def test_model_cross_validate_hooks():
Expand Down

0 comments on commit 0947f53

Please sign in to comment.