Skip to content

Commit

Permalink
set meaningful names to models (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
pplonski committed May 27, 2020
1 parent f076bf6 commit 803b9e4
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/multi_class_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

automl = AutoML(
#results_path="AutoML_22",
algorithms=["Nearest Neighbors"],
#algorithms=["Nearest Neighbors"],
#algorithms=["Random Forest"],

# "Linear",
Expand Down
2 changes: 1 addition & 1 deletion supervised/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def stacked_ensemble_step(self):

params = copy.deepcopy(m.params)
params["validation"]["X_train_path"] = X_train_stacked_path
params["name"] = "stacked_" + params["name"]
params["name"] = "Stacked" + params["name"]
params["is_stacked"] = True

if "model_architecture_json" in params["learner"]:
Expand Down
16 changes: 11 additions & 5 deletions supervised/tuner/mljar_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def __init__(

self._unique_params_keys = []

def get_model_name(self, model_type, models_cnt, special=""):
    """Build a human-readable model name.

    The name is an optional prefix (e.g. "Default" or "Stacked"),
    followed by the algorithm type with spaces removed, followed by
    an underscore and the model counter, e.g. "DefaultRandomForest_3".
    """
    compact_type = model_type.replace(" ", "")
    return f"{special}{compact_type}_{models_cnt}"

def simple_algorithms_params(self):
models_cnt = 0
generated_params = []
Expand All @@ -55,7 +58,8 @@ def simple_algorithms_params(self):
params = self._get_model_params(model_type, seed=i + 1)
if params is None:
continue
params["name"] = f"model_{models_cnt + 1}"

params["name"] = self.get_model_name(model_type, models_cnt+1)

unique_params_key = MljarTuner.get_params_key(params)
if unique_params_key not in self._unique_params_keys:
Expand Down Expand Up @@ -102,8 +106,8 @@ def default_params(self, models_cnt):
model_type, seed=models_cnt + 1, params_type="default"
)
if params is None:
continue
params["name"] = f"model_{models_cnt + 1}"
continue
params["name"] = self.get_model_name(model_type, models_cnt+1, special="Default")

unique_params_key = MljarTuner.get_params_key(params)
if unique_params_key not in self._unique_params_keys:
Expand Down Expand Up @@ -139,7 +143,8 @@ def get_not_so_random_params(self, models_cnt):
params = self._get_model_params(model_type, seed=i + 1)
if params is None:
continue
params["name"] = f"model_{models_cnt + 1}"

params["name"] = self.get_model_name(model_type, models_cnt+1)

unique_params_key = MljarTuner.get_params_key(params)
if unique_params_key not in self._unique_params_keys:
Expand Down Expand Up @@ -201,7 +206,8 @@ def get_hill_climbing_params(self, current_models):
if p is not None:
all_params = copy.deepcopy(m.params)
all_params["learner"] = p
all_params["name"] = f"model_{model_max_index + 1}"

all_params["name"] = self.get_model_name(all_params["learner"]["model_type"], model_max_index+1)

unique_params_key = MljarTuner.get_params_key(all_params)
if unique_params_key not in self._unique_params_keys:
Expand Down

0 comments on commit 803b9e4

Please sign in to comment.