From 6f088437ece7178dbf5c118bceebe41050ee919a Mon Sep 17 00:00:00 2001 From: Johannes Freischuetz Date: Tue, 23 Jul 2024 01:10:09 +0000 Subject: [PATCH] save --- .../bayesian_optimizers/smac_optimizer.py | 7 +- .../mlos_core/optimizers/observations.py | 3 - mlos_core/mlos_core/optimizers/optimizer.py | 4 +- .../optimizers/bayesian_optimizers_test.py | 2 +- .../tests/optimizers/optimizer_test.py | 25 +++-- test.csv | 101 ++++++++++++++++++ 6 files changed, 120 insertions(+), 22 deletions(-) create mode 100644 test.csv diff --git a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py index 8acbae93cb..5742a220fc 100644 --- a/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py +++ b/mlos_core/mlos_core/optimizers/bayesian_optimizers/smac_optimizer.py @@ -375,7 +375,7 @@ def surrogate_predict( raise RuntimeError( "Surrogate model can make predictions *only* after " "all initial points have been evaluated " - f"{len(self._observations)} <= {self.base_optimizer._initial_design._n_configs}" + f"{sum(len(o.config.index) for o in self._observations)} <= {self.base_optimizer._initial_design._n_configs}" ) if self.base_optimizer._config_selector._model is None: raise RuntimeError("Surrogate model is not yet trained") @@ -428,7 +428,8 @@ def _to_configspace_configs(self, *, configs: pd.DataFrame) -> List[ConfigSpace. configs : list List of ConfigSpace configs. """ + values = [config.to_dict() for (_, config) in configs.astype("O").iterrows()] return [ - ConfigSpace.Configuration(self.optimizer_parameter_space, values=config.to_dict()) - for (_, config) in configs.astype("O").iterrows() + ConfigSpace.Configuration(self.optimizer_parameter_space, values=value) + for value in values ] diff --git a/mlos_core/mlos_core/optimizers/observations.py b/mlos_core/mlos_core/optimizers/observations.py index 2743c7348e..5d87ea9e25 100644 --- a/mlos_core/mlos_core/optimizers/observations.py +++ b/mlos_core/mlos_core/optimizers/observations.py @@ -166,9 +166,6 @@ class Observations: def __init__(self, observations: List[Observation] = []): self.observations = observations - def __len__(self) -> int: - return len(self.observations) - def append(self, observation: Observation) -> None: """ Appends an observation to the collection. diff --git a/mlos_core/mlos_core/optimizers/optimizer.py b/mlos_core/mlos_core/optimizers/optimizer.py index 43b7d8b396..2c608a6613 100644 --- a/mlos_core/mlos_core/optimizers/optimizer.py +++ b/mlos_core/mlos_core/optimizers/optimizer.py @@ -217,7 +217,7 @@ def get_observations(self) -> Observations: observations : Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]] A triplet of (config, score, context) DataFrames of observations. """ - if len(self._observations) == 0: + if sum(len(o.config.index) for o in self._observations) == 0: raise ValueError("No observations registered yet.") return self._observations @@ -243,7 +243,7 @@ def get_best_observations( observations : Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]] A triplet of best (config, score, context) DataFrames of best observations. """ - if len(self._observations) == 0: + if sum(len(o.config.index) for o in self._observations) == 0: raise ValueError("No observations registered yet.") configs, scores, contexts, metadata = self._observations.to_legacy() diff --git a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py index e7cfbbfee4..900ccc1281 100644 --- a/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py +++ b/mlos_core/mlos_core/tests/optimizers/bayesian_optimizers_test.py @@ -10,7 +10,7 @@ import pandas as pd import pytest -from mlos_core.mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer +from mlos_core.optimizers.bayesian_optimizers.smac_optimizer import SmacOptimizer from mlos_core.optimizers import BaseOptimizer, OptimizerType from mlos_core.optimizers.bayesian_optimizers import BaseBayesianOptimizer diff --git a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py index b44e4c0c56..a01185aaa9 100644 --- a/mlos_core/mlos_core/tests/optimizers/optimizer_test.py +++ b/mlos_core/mlos_core/tests/optimizers/optimizer_test.py @@ -13,7 +13,7 @@ import pandas as pd import pytest -from mlos_core.optimizers.observations import Suggestion +from mlos_core.optimizers.observations import Observations, Suggestion from mlos_core.optimizers import ( BaseOptimizer, ConcreteOptimizer, @@ -337,17 +337,16 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: ) # Retrieve and check all observations - for all_configs, all_scores, all_contexts, _metadata in ( - optimizer.get_observations().to_legacy(), - llamatune_optimizer.get_observations().to_legacy(), - ): + for obs in [optimizer.get_observations(), llamatune_optimizer.get_observations()]: + assert isinstance(obs, Observations) + all_configs, all_scores, all_contexts, _metadata = obs.to_legacy() assert isinstance(all_configs, pd.DataFrame) assert isinstance(all_scores, pd.DataFrame) assert all_contexts is None assert set(all_configs.columns) == {"x", "y"} assert set(all_scores.columns) == {"score"} - assert len(all_configs) == num_iters - assert len(all_scores) == num_iters + assert len(all_configs.index) == num_iters + assert len(all_scores.index) == num_iters # .surrogate_predict method not currently implemented if space adapter is employed if isinstance(llamatune_optimizer, BaseBayesianOptimizer): @@ -425,19 +424,19 @@ def objective(point: pd.DataFrame) -> pd.DataFrame: for _ in range(max_iterations): suggestion = optimizer.suggest() assert isinstance(suggestion, Suggestion) - assert isinstance(suggestion.context, pd.DataFrame) - assert (suggestion.context.columns == ["x", "y"]).all() + assert isinstance(suggestion.config, pd.DataFrame) + assert (suggestion.config.columns == ["x", "y"]).all() # Check suggestion values are the expected dtype - assert isinstance(suggestion.context["x"].iloc[0], np.integer) - assert isinstance(suggestion.context["y"].iloc[0], np.floating) + assert isinstance(suggestion.config["x"].iloc[0], np.integer) + assert isinstance(suggestion.config["y"].iloc[0], np.floating) # Check that suggestion is in the space test_configuration = CS.Configuration( - optimizer.parameter_space, suggestion.context.astype("O").iloc[0].to_dict() + optimizer.parameter_space, suggestion.config.astype("O").iloc[0].to_dict() ) # Raises an error if outside of configuration space test_configuration.is_valid_configuration() # Test registering the suggested configuration with a score. - observation = objective(suggestion.context) + observation = objective(suggestion.config) assert isinstance(observation, pd.DataFrame) optimizer.register(observation=suggestion.evaluate(observation)) diff --git a/test.csv b/test.csv new file mode 100644 index 0000000000..8a27835dba --- /dev/null +++ b/test.csv @@ -0,0 +1,101 @@ +,x,y +0,1.8663263131194956,1.3131832170213436 +1,2.3560757511413075,2.3560757511413075 +2,2.3399274243564108,0.8177778158479249 +3,0.8293927654292901,0.8293927654292901 +4,2.405616532605058,2.8744180610511156 +5,2.6277979042262842,2.6277979042262842 +6,1.0734518098736001,1.502985376570376 +7,2.050388805516409,2.050388805516409 +8,2.1381060809487007,1.1107522643711847 +9,1.6835885581968748,1.6835885581968748 +10,1.509249495923429,0.041305348772046724 +11,2.318479864837122,2.318479864837122 +12,2.6479235719083496,1.0946579517041168 +13,1.846188535300481,1.846188535300481 +14,0.22614372492892965,1.1064720180059235 +15,2.7994203059475646,2.7994203059475646 +16,1.954134429679732,1.1916077331784627 +17,2.3661904288222364,2.3661904288222364 +18,0.9505083665066137,1.7042959578782075 +19,2.6073821686836776,2.6073821686836776 +20,1.308520271687038,2.4064429262404774 +21,0.4313004735436937,0.4313004735436937 +22,2.112782913355006,2.1137439245687175 +23,0.6563763170222657,0.6563763170222657 +24,2.7746028858466953,1.3264222662125298 +25,2.7279478769174177,2.7279478769174177 +26,0.1794276683395557,0.5528612514414409 +27,0.1420658364045454,0.1420658364045454 +28,2.0246428307469904,1.7838743398033465 +29,1.5999304889962516,1.5999304889962516 +30,0.12997218808441047,1.6842992401901937 +31,0.989005336862745,0.989005336862745 +32,1.5089004993378552,0.33568295272321147 +33,1.8215811186554536,1.8215811186554536 +34,1.697833929151594,0.02029218597000837 +35,1.8523251264128913,1.8523251264128913 +36,2.7363686592994627,2.3715723991711 +37,2.9762443985650844,2.9762443985650844 +38,2.8764052864585996,2.3758924058749193 +39,0.8557528800735293,0.8557528800735293 +40,1.874750115917733,1.4342813870120237 +41,0.5870255359976947,0.5870255359976947 +42,1.1469523560945194,0.16162105543870975 +43,1.3549452247825773,1.3549452247825773 +44,2.9460142245658636,0.37182810146088896 +45,0.3581426937787452,0.3581426937787452 +46,2.2155691684300405,1.7619109003919537 +47,1.4148976029611033,1.4148976029611033 +48,0.3213804515815989,0.6876556963818538 +49,2.699895584510026,2.699895584510026 +50,1.2502606134080796,1.6075549875948476 +51,0.018625549761388194,0.018625549761388194 +52,0.9019251173109034,1.3106795165268306 +53,1.8364469911972727,1.8364469911972727 +54,2.754594226141719,1.877210009887606 +55,2.1179926952453196,2.1179926952453196 +56,0.4495011479697817,2.2381902274101497 +57,2.4930209773006133,2.4930209773006133 +58,1.9011773068529374,1.3149296433672826 +59,0.4577183240235161,0.4577183240235161 +60,1.7052288457415705,1.5846728327551816 +61,2.8542862912607796,2.8542862912607796 +62,1.4410775355300482,1.507678690147651 +63,1.6106345787732292,1.6106345787732292 +64,2.457606201192475,0.17134691426657966 +65,2.0082652292236465,2.0082652292236465 +66,2.3013498851384173,2.1243460859328116 +67,2.39060155117559,2.39060155117559 +68,1.6732824852823485,2.897509595976383 +69,0.44147069967899155,0.44147069967899155 +70,0.08894100160624674,1.7816804778743154 +71,0.34219709622798744,0.34219709622798744 +72,2.8524295502523667,0.9771222432760417 +73,0.5808560704613316,0.5808560704613316 +74,1.3734349466922826,2.7612077132792634 +75,2.637207484544027,2.637207484544027 +76,0.7578472651395907,1.0440263786080384 +77,0.5477661947409262,0.5477661947409262 +78,2.7053881541129763,2.1195844895153924 +79,2.179975384686422,2.179975384686422 +80,2.700263510429123,2.3374914023079727 +81,1.7974643418128773,1.7974643418128773 +82,0.8733757346936245,0.45418579322229635 +83,1.0055239774483047,1.0055239774483047 +84,1.9726553314734583,0.2200276308978545 +85,0.16501918621867095,0.16501918621867095 +86,0.9695844417635293,1.7714454133889586 +87,2.5616957013769253,2.5616957013769253 +88,0.861187275000027,0.5192016804443764 +89,0.40206361799652846,0.40206361799652846 +90,2.9839614859328836,0.5384936084081755 +91,0.9526404690815875,0.9526404690815875 +92,1.7048742139773216,0.028045723500733533 +93,2.7019458634652755,2.7019458634652755 +94,2.931724292767761,1.6706840374105045 +95,0.2543215301825704,0.2543215301825704 +96,0.9990073971873003,2.185286029109016 +97,0.42730612002545165,0.42730612002545165 +98,1.6574068184926365,0.8191297790510568 +99,2.923485414261779,2.923485414261779