forked from mljar/mljar-supervised
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathregistry.py
71 lines (58 loc) · 2.2 KB
/
registry.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# tasks that can be handled by the package
from typing import List, Type
# Task identifiers; each one is a top-level key of AlgorithmsRegistry.registry.
BINARY_CLASSIFICATION = "binary_classification"
MULTICLASS_CLASSIFICATION = "multiclass_classification"
REGRESSION = "regression"
class AlgorithmsRegistry:
    """Class-level catalogue of the algorithms registered for each ML task.

    Nothing is stored per instance: every method is a ``@staticmethod`` that
    reads or writes the shared ``registry`` dictionary, laid out as
    ``registry[task_name][algorithm_short_name] -> model information dict``.
    """

    # Imported inside the class body (as in the original), so the name is
    # resolved for the method annotations below and is also reachable as
    # ``AlgorithmsRegistry.BaseAlgorithm``.
    from supervised.algorithms.algorithm import BaseAlgorithm

    # One (initially empty) sub-registry per supported task.
    registry = {
        BINARY_CLASSIFICATION: {},
        MULTICLASS_CLASSIFICATION: {},
        REGRESSION: {},
    }

    @staticmethod
    def add(
        task_name: str,
        model_class: Type[BaseAlgorithm],
        model_params: dict,
        required_preprocessing: list,
        additional: dict,
        default_params: dict,
    ) -> None:
        """Register ``model_class`` for ``task_name``.

        The entry is stored under ``model_class.algorithm_short_name``; an
        existing entry with the same short name for that task is overwritten.

        Args:
            task_name: Key of ``registry`` the algorithm is registered under.
            model_class: Algorithm class; stored under the ``"class"`` key.
            model_params: Stored under the ``"params"`` key.
            required_preprocessing: Stored under ``"required_preprocessing"``.
            additional: Stored under ``"additional"``; the ``get_max_*_limit``
                getters read ``"max_rows_limit"`` / ``"max_cols_limit"``
                from it.
            default_params: Stored under the ``"default_params"`` key.

        Raises:
            KeyError: If ``task_name`` is not a key of ``registry``.
        """
        model_information = {
            "class": model_class,
            "params": model_params,
            "required_preprocessing": required_preprocessing,
            "additional": additional,
            "default_params": default_params,
        }
        AlgorithmsRegistry.registry[task_name][
            model_class.algorithm_short_name
        ] = model_information

    @staticmethod
    def get_supported_ml_tasks() -> List[str]:
        """Return the names of the tasks present in ``registry``.

        A new list is built on every call so the result matches the declared
        ``List[str]`` return type instead of exposing a live ``dict_keys``
        view of the shared registry.
        """
        return list(AlgorithmsRegistry.registry)

    @staticmethod
    def get_algorithm_class(ml_task: str, algorithm_name: str) -> Type[BaseAlgorithm]:
        """Return the class registered as ``algorithm_name`` for ``ml_task``.

        Raises:
            KeyError: If the task or the algorithm is not registered.
        """
        return AlgorithmsRegistry.registry[ml_task][algorithm_name]["class"]

    @staticmethod
    def get_long_name(ml_task: str, algorithm_name: str) -> str:
        """Return the ``algorithm_name`` attribute of the registered class.

        Raises:
            KeyError: If the task or the algorithm is not registered.
        """
        return AlgorithmsRegistry.registry[ml_task][algorithm_name][
            "class"
        ].algorithm_name

    @staticmethod
    def get_max_rows_limit(ml_task: str, algorithm_name: str) -> int:
        """Return the ``"max_rows_limit"`` entry of the algorithm's ``additional`` dict.

        Raises:
            KeyError: If the task, the algorithm or the entry is missing.
        """
        return AlgorithmsRegistry.registry[ml_task][algorithm_name]["additional"][
            "max_rows_limit"
        ]

    @staticmethod
    def get_max_cols_limit(ml_task: str, algorithm_name: str) -> int:
        """Return the ``"max_cols_limit"`` entry of the algorithm's ``additional`` dict.

        Raises:
            KeyError: If the task, the algorithm or the entry is missing.
        """
        return AlgorithmsRegistry.registry[ml_task][algorithm_name]["additional"][
            "max_cols_limit"
        ]

    @staticmethod
    def get_eval_metric(ml_task: str, algorithm_name: str, automl_eval_metric: str):
        """Return the evaluation metric to use for ``algorithm_name``.

        For ``"Xgboost"`` the metric is translated by ``xgboost_eval_metric``;
        for every other algorithm ``automl_eval_metric`` is returned unchanged.
        """
        if algorithm_name == "Xgboost":
            # NOTE(review): xgboost_eval_metric is not defined in this class;
            # it must be available in module scope at call time - confirm it
            # is imported elsewhere in this file.
            return xgboost_eval_metric(ml_task, automl_eval_metric)
        return automl_eval_metric
# Import algorithm to be registered