Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations & code refactor #704

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions supervised/algorithms/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import uuid
from typing import Optional, Union

import numpy as np
import pandas as pd

from supervised.utils.common import construct_learner_name
from supervised.utils.importance import PermutationImportance
Expand All @@ -16,15 +18,15 @@ class BaseAlgorithm:
algorithm_name = "Unknown"
algorithm_short_name = "Unknown"

def __init__(self, params):
self.params = params
self.stop_training = False
self.library_version = None
self.model = None
self.uid = params.get("uid", str(uuid.uuid4()))
self.ml_task = params.get("ml_task")
self.model_file_path = None
self.name = "amazing_learner"
def __init__(self, params: dict):
    """Set up the attributes shared by every algorithm wrapper.

    Args:
        params: Algorithm configuration. The optional ``"uid"`` and
            ``"ml_task"`` keys are read here; the mapping itself is kept
            on ``self.params``.
    """
    self.params: dict = params
    self.stop_training: bool = False
    # Attributes that start out as ``None`` are annotated ``Optional[...]``;
    # a bare ``str`` annotation with a ``None`` value is an implicit
    # Optional, which type checkers reject.
    self.library_version: Optional[str] = None
    self.model: object = None
    # A random UUID4 string is used when the caller supplies no "uid".
    self.uid: str = params.get("uid", str(uuid.uuid4()))
    # ``dict.get`` without a default yields ``None`` when "ml_task" is absent.
    self.ml_task: Optional[str] = params.get("ml_task")
    self.model_file_path: Optional[str] = None
    self.name: str = "amazing_learner"

def set_learner_name(self, fold, repeat, repeats):
self.name = construct_learner_name(fold, repeat, repeats)
Expand All @@ -38,15 +40,15 @@ def reload(self):
self.load(self.model_file_path)

def fit(
self,
X,
y,
sample_weight=None,
X_validation=None,
y_validation=None,
sample_weight_validation=None,
log_to_file=None,
max_time=None,
self,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are you adding spaces?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is automatic reformatting by the IDE.

X: Union[np.ndarray, pd.DataFrame],
y: Union[np.ndarray, pd.Series],
sample_weight=None,
X_validation=None,
y_validation=None,
sample_weight_validation=None,
log_to_file=None,
max_time=None,
):
pass

Expand Down Expand Up @@ -76,18 +78,18 @@ def get_fname(self):
return f"{self.name}.{self.file_extension()}"

def interpret(
self,
X_train,
y_train,
X_validation,
y_validation,
model_file_path,
learner_name,
target_name=None,
class_names=None,
metric_name=None,
ml_task=None,
explain_level=2,
self,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spaces again?

X_train,
y_train,
X_validation,
y_validation,
model_file_path,
learner_name,
target_name=None,
class_names=None,
metric_name=None,
ml_task=None,
explain_level=2,
):
# do not produce feature importance for Baseline
if self.algorithm_short_name == "Baseline":
Expand Down
19 changes: 9 additions & 10 deletions supervised/algorithms/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class BaselineClassifierAlgorithm(SklearnAlgorithm, ClassifierMixin):
algorithm_name = "Baseline Classifier"
algorithm_short_name = "Baseline"

def __init__(self, params):
def __init__(self, params: dict):
super(BaselineClassifierAlgorithm, self).__init__(params)
logger.debug("BaselineClassifierAlgorithm.__init__")

self.library_version = sklearn.__version__
self.max_iters = additional.get("max_steps", 1)
self.library_version: str = sklearn.__version__
self.max_iters: int = additional.get("max_steps", 1)
self.model = DummyClassifier(
strategy="prior", random_state=params.get("seed", 1)
)
Expand All @@ -36,9 +36,9 @@ def file_extension(self):

def is_fitted(self):
return (
hasattr(self.model, "n_outputs_")
and self.model.n_outputs_ is not None
and self.model.n_outputs_ > 0
hasattr(self.model, "n_outputs_")
and self.model.n_outputs_ is not None
and self.model.n_outputs_ > 0
)


Expand All @@ -59,9 +59,9 @@ def file_extension(self):

def is_fitted(self):
return (
hasattr(self.model, "n_outputs_")
and self.model.n_outputs_ is not None
and self.model.n_outputs_ > 0
hasattr(self.model, "n_outputs_")
and self.model.n_outputs_ is not None
and self.model.n_outputs_ > 0
)


Expand All @@ -86,5 +86,4 @@ def is_fitted(self):
{},
)


AlgorithmsRegistry.add(REGRESSION, BaselineRegressorAlgorithm, {}, {}, additional, {})
6 changes: 4 additions & 2 deletions supervised/algorithms/factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from supervised.algorithms.algorithm import BaseAlgorithm
from supervised.algorithms.registry import BINARY_CLASSIFICATION, AlgorithmsRegistry

logger = logging.getLogger(__name__)
Expand All @@ -8,8 +9,9 @@


class AlgorithmFactory(object):

@classmethod
def get_algorithm(cls, params):
def get_algorithm(cls, params) -> BaseAlgorithm:
alg_type = params.get("model_type", "Xgboost")
ml_task = params.get("ml_task", BINARY_CLASSIFICATION)

Expand All @@ -20,7 +22,7 @@ def get_algorithm(cls, params):
raise AutoMLException(f"Cannot get algorithm class. {str(e)}")

@classmethod
def load(cls, json_desc, learner_path, lazy_load):
def load(cls, json_desc, learner_path, lazy_load) -> BaseAlgorithm:
learner = AlgorithmFactory.get_algorithm(json_desc.get("params"))
learner.set_params(json_desc, learner_path)
if not lazy_load:
Expand Down
30 changes: 17 additions & 13 deletions supervised/algorithms/registry.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# tasks that can be handled by the package

from typing import List, Type

BINARY_CLASSIFICATION = "binary_classification"
MULTICLASS_CLASSIFICATION = "multiclass_classification"
REGRESSION = "regression"


class AlgorithmsRegistry:
from supervised.algorithms.algorithm import BaseAlgorithm
registry = {
BINARY_CLASSIFICATION: {},
MULTICLASS_CLASSIFICATION: {},
Expand All @@ -13,13 +17,13 @@ class AlgorithmsRegistry:

@staticmethod
def add(
task_name,
model_class,
model_params,
required_preprocessing,
additional,
default_params,
):
task_name: str,
model_class: Type[BaseAlgorithm],
model_params: dict,
required_preprocessing: list,
additional: dict,
default_params: dict,
) -> None:
model_information = {
"class": model_class,
"params": model_params,
Expand All @@ -32,33 +36,33 @@ def add(
] = model_information

@staticmethod
def get_supported_ml_tasks():
def get_supported_ml_tasks() -> List[str]:
    """Return the names of the ML tasks that have a registry entry.

    Returns:
        A new list holding the top-level keys of
        ``AlgorithmsRegistry.registry``.
    """
    # ``dict.keys()`` is a view object, not a list; materialise it so the
    # returned value matches the declared ``List[str]`` type.
    return list(AlgorithmsRegistry.registry)

@staticmethod
def get_algorithm_class(ml_task, algorithm_name):
def get_algorithm_class(ml_task: str, algorithm_name: str) -> Type[BaseAlgorithm]:
    """Look up the class registered for ``algorithm_name`` under ``ml_task``.

    Raises:
        KeyError: If the task or the algorithm is not in the registry.
    """
    model_information = AlgorithmsRegistry.registry[ml_task][algorithm_name]
    return model_information["class"]

@staticmethod
def get_long_name(ml_task, algorithm_name):
def get_long_name(ml_task: str, algorithm_name: str) -> str:
    """Return the ``algorithm_name`` attribute of the registered class.

    Raises:
        KeyError: If the task or the algorithm is not in the registry.
    """
    model_class = AlgorithmsRegistry.registry[ml_task][algorithm_name]["class"]
    return model_class.algorithm_name

@staticmethod
def get_max_rows_limit(ml_task, algorithm_name):
def get_max_rows_limit(ml_task: str, algorithm_name: str) -> int:
    """Return the ``"max_rows_limit"`` entry of the algorithm's ``"additional"`` settings.

    Raises:
        KeyError: If the task, the algorithm, or the setting is missing.
    """
    additional = AlgorithmsRegistry.registry[ml_task][algorithm_name]["additional"]
    return additional["max_rows_limit"]

@staticmethod
def get_max_cols_limit(ml_task, algorithm_name):
def get_max_cols_limit(ml_task: str, algorithm_name: str) -> int:
    """Return the ``"max_cols_limit"`` entry of the algorithm's ``"additional"`` settings.

    Raises:
        KeyError: If the task, the algorithm, or the setting is missing.
    """
    additional = AlgorithmsRegistry.registry[ml_task][algorithm_name]["additional"]
    return additional["max_cols_limit"]

@staticmethod
def get_eval_metric(algorithm_name, ml_task, automl_eval_metric):
def get_eval_metric(ml_task: str, algorithm_name: str, automl_eval_metric: str):
if algorithm_name == "Xgboost":
return xgboost_eval_metric(ml_task, automl_eval_metric)

Expand Down
Loading
Loading