Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
jsfreischuetz committed Jul 23, 2024
1 parent bf82d9f commit 6f08843
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def surrogate_predict(
raise RuntimeError(
"Surrogate model can make predictions *only* after "
"all initial points have been evaluated "
f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}"
f"{sum(len(o.config.index) for o in self._observations)} <= {self.base_optimizer._initial_design._n_configs}"
)
if self.base_optimizer._config_selector._model is None:
raise RuntimeError("Surrogate model is not yet trained")
Expand Down Expand Up @@ -428,7 +428,8 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace.
configs : list
List of ConfigSpace configs.
"""
values = [config.to_dict() for (_, config) in configs.astype("O").iterrows()]
return [
ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict())
for (_, config) in configs.astype("O").iterrows()
ConfigSpace.Configuration(self.optimizer_parameter_space, values=value)
for value in values
]
3 changes: 0 additions & 3 deletions mlos_core/mlos_core/optimizers/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,6 @@ class Observations:
def __init__(self, observations: List[Observation] = []):
self.observations = observations

def __len__(self) -> int:
return len(self.observations)

def append(self, observation: Observation) -> None:
"""
Appends an observation to the collection.
Expand Down
4 changes: 2 additions & 2 deletions mlos_core/mlos_core/optimizers/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def get_observations(self) -> Observations:
observations : Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]
A triplet of (config, score, context) DataFrames of observations.
"""
if len(self._observations) == 0:
if sum(len(o.config.index) for o in self._observations) == 0:
raise ValueError("No observations registered yet.")

return self._observations
Expand All @@ -243,7 +243,7 @@ def get_best_observations(
observations : Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]
A triplet of best (config, score, context) DataFrames of best observations.
"""
if len(self._observations) == 0:
if sum(len(o.config.index) for o in self._observations) == 0:
raise ValueError("No observations registered yet.")

configs, scores, contexts, metadata = self._observations.to_legacy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pandas as pd
import pytest

from mlos_core.mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer
from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer
from mlos_core.optimizers import BaseOptimizer, OptimizerType
from mlos_core.optimizers.bayesian_optimizers import BaseBayesianOptimizer

Expand Down
25 changes: 12 additions & 13 deletions mlos_core/mlos_core/tests/optimizers/optimizer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import pandas as pd
import pytest

from mlos_core.optimizers.observations import Suggestion
from mlos_core.optimizers.observations import Observations, Suggestion
from mlos_core.optimizers import (
BaseOptimizer,
ConcreteOptimizer,
Expand Down Expand Up @@ -337,17 +337,16 @@ def objective(point: pd.DataFrame) -> pd.DataFrame:
)

# Retrieve and check all observations
for all_configs, all_scores, all_contexts, _metadata in (
optimizer.get_observations().to_legacy(),
llamatune_optimizer.get_observations().to_legacy(),
):
for obs in [optimizer.get_observations(), llamatune_optimizer.get_observations()]:
assert isinstance(obs, Observations)
all_configs, all_scores, all_contexts, _metadata = obs.to_legacy()
assert isinstance(all_configs, pd.DataFrame)
assert isinstance(all_scores, pd.DataFrame)
assert all_contexts is None
assert set(all_configs.columns) == {"x", "y"}
assert set(all_scores.columns) == {"score"}
assert len(all_configs) == num_iters
assert len(all_scores) == num_iters
assert len(all_configs.index) == num_iters
assert len(all_scores.index) == num_iters

# .surrogate_predict method not currently implemented if space adapter is employed
if isinstance(llamatune_optimizer, BaseBayesianOptimizer):
Expand Down Expand Up @@ -425,19 +424,19 @@ def objective(point: pd.DataFrame) -> pd.DataFrame:
for _ in range(max_iterations):
suggestion = optimizer.suggest()
assert isinstance(suggestion, Suggestion)
assert isinstance(suggestion.context, pd.DataFrame)
assert (suggestion.context.columns == ["x", "y"]).all()
assert isinstance(suggestion.config, pd.DataFrame)
assert (suggestion.config.columns == ["x", "y"]).all()
# Check suggestion values are the expected dtype
assert isinstance(suggestion.context["x"].iloc[0], np.integer)
assert isinstance(suggestion.context["y"].iloc[0], np.floating)
assert isinstance(suggestion.config["x"].iloc[0], np.integer)
assert isinstance(suggestion.config["y"].iloc[0], np.floating)
# Check that suggestion is in the space
test_configuration = CS.Configuration(
optimizer.parameter_space, suggestion.context.astype("O").iloc[0].to_dict()
optimizer.parameter_space, suggestion.config.astype("O").iloc[0].to_dict()
)
# Raises an error if outside of configuration space
test_configuration.is_valid_configuration()
# Test registering the suggested configuration with a score.
observation = objective(suggestion.context)
observation = objective(suggestion.config)
assert isinstance(observation, pd.DataFrame)
optimizer.register(observation=suggestion.evaluate(observation))

Expand Down
101 changes: 101 additions & 0 deletions test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
,x,y
0,1.8663263131194956,1.3131832170213436
1,2.3560757511413075,2.3560757511413075
2,2.3399274243564108,0.8177778158479249
3,0.8293927654292901,0.8293927654292901
4,2.405616532605058,2.8744180610511156
5,2.6277979042262842,2.6277979042262842
6,1.0734518098736001,1.502985376570376
7,2.050388805516409,2.050388805516409
8,2.1381060809487007,1.1107522643711847
9,1.6835885581968748,1.6835885581968748
10,1.509249495923429,0.041305348772046724
11,2.318479864837122,2.318479864837122
12,2.6479235719083496,1.0946579517041168
13,1.846188535300481,1.846188535300481
14,0.22614372492892965,1.1064720180059235
15,2.7994203059475646,2.7994203059475646
16,1.954134429679732,1.1916077331784627
17,2.3661904288222364,2.3661904288222364
18,0.9505083665066137,1.7042959578782075
19,2.6073821686836776,2.6073821686836776
20,1.308520271687038,2.4064429262404774
21,0.4313004735436937,0.4313004735436937
22,2.112782913355006,2.1137439245687175
23,0.6563763170222657,0.6563763170222657
24,2.7746028858466953,1.3264222662125298
25,2.7279478769174177,2.7279478769174177
26,0.1794276683395557,0.5528612514414409
27,0.1420658364045454,0.1420658364045454
28,2.0246428307469904,1.7838743398033465
29,1.5999304889962516,1.5999304889962516
30,0.12997218808441047,1.6842992401901937
31,0.989005336862745,0.989005336862745
32,1.5089004993378552,0.33568295272321147
33,1.8215811186554536,1.8215811186554536
34,1.697833929151594,0.02029218597000837
35,1.8523251264128913,1.8523251264128913
36,2.7363686592994627,2.3715723991711
37,2.9762443985650844,2.9762443985650844
38,2.8764052864585996,2.3758924058749193
39,0.8557528800735293,0.8557528800735293
40,1.874750115917733,1.4342813870120237
41,0.5870255359976947,0.5870255359976947
42,1.1469523560945194,0.16162105543870975
43,1.3549452247825773,1.3549452247825773
44,2.9460142245658636,0.37182810146088896
45,0.3581426937787452,0.3581426937787452
46,2.2155691684300405,1.7619109003919537
47,1.4148976029611033,1.4148976029611033
48,0.3213804515815989,0.6876556963818538
49,2.699895584510026,2.699895584510026
50,1.2502606134080796,1.6075549875948476
51,0.018625549761388194,0.018625549761388194
52,0.9019251173109034,1.3106795165268306
53,1.8364469911972727,1.8364469911972727
54,2.754594226141719,1.877210009887606
55,2.1179926952453196,2.1179926952453196
56,0.4495011479697817,2.2381902274101497
57,2.4930209773006133,2.4930209773006133
58,1.9011773068529374,1.3149296433672826
59,0.4577183240235161,0.4577183240235161
60,1.7052288457415705,1.5846728327551816
61,2.8542862912607796,2.8542862912607796
62,1.4410775355300482,1.507678690147651
63,1.6106345787732292,1.6106345787732292
64,2.457606201192475,0.17134691426657966
65,2.0082652292236465,2.0082652292236465
66,2.3013498851384173,2.1243460859328116
67,2.39060155117559,2.39060155117559
68,1.6732824852823485,2.897509595976383
69,0.44147069967899155,0.44147069967899155
70,0.08894100160624674,1.7816804778743154
71,0.34219709622798744,0.34219709622798744
72,2.8524295502523667,0.9771222432760417
73,0.5808560704613316,0.5808560704613316
74,1.3734349466922826,2.7612077132792634
75,2.637207484544027,2.637207484544027
76,0.7578472651395907,1.0440263786080384
77,0.5477661947409262,0.5477661947409262
78,2.7053881541129763,2.1195844895153924
79,2.179975384686422,2.179975384686422
80,2.700263510429123,2.3374914023079727
81,1.7974643418128773,1.7974643418128773
82,0.8733757346936245,0.45418579322229635
83,1.0055239774483047,1.0055239774483047
84,1.9726553314734583,0.2200276308978545
85,0.16501918621867095,0.16501918621867095
86,0.9695844417635293,1.7714454133889586
87,2.5616957013769253,2.5616957013769253
88,0.861187275000027,0.5192016804443764
89,0.40206361799652846,0.40206361799652846
90,2.9839614859328836,0.5384936084081755
91,0.9526404690815875,0.9526404690815875
92,1.7048742139773216,0.028045723500733533
93,2.7019458634652755,2.7019458634652755
94,2.931724292767761,1.6706840374105045
95,0.2543215301825704,0.2543215301825704
96,0.9990073971873003,2.185286029109016
97,0.42730612002545165,0.42730612002545165
98,1.6574068184926365,0.8191297790510568
99,2.923485414261779,2.923485414261779

0 comments on commit 6f08843

Please sign in to comment.