feature(Metric): MetricRegistry to avoid duplicate metric

szemyd committed Nov 20, 2023
1 parent 315154a commit a196f9d
Showing 11 changed files with 96 additions and 58 deletions.
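
In short, the commit replaces the module-level metric collections (all_regression_metrics, binary_classification_balanced_metrics, and so on) with MetricRegistryRegression and MetricRegistryClassification classes whose attributes are deep copies of those collections. A minimal sketch of the new call pattern, assembled from the updated example and test files below; the final assertion and the comment about shared instances are my reading of the deepcopy calls, not something stated in the commit:

import numpy as np

from krisi import library, score

# New style: metrics come from a registry object rather than module-level lists.
registry = library.MetricRegistryRegression()

score(
    y=np.random.random(1000),
    predictions=np.random.random(1000),
    default_metrics=registry.all_regression_metrics,  # previously: all_regression_metrics
).print()

# MetricRegistryClassification builds its copies in __init__, so two registries
# hand out distinct Metric objects instead of sharing one module-level instance
# (presumably the "duplicate metric" problem the commit title refers to).
reg_a = library.MetricRegistryClassification()
reg_b = library.MetricRegistryClassification()
assert reg_a.f_one_score_macro is not reg_b.f_one_score_macro
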
2 changes: 1 addition & 1 deletion docs/examples/evaluate_classification.py
@@ -17,6 +17,6 @@
probabilities=probs,
# dataset_type="classification_multilabel", # if automatic inference of dataset type fails
calculation="both",
default_metrics=library.default_metrics_classification.binary_classification_balanced_metrics,
default_metrics=library.MetricRegistryClassification().binary_classification_balanced_metrics,
)
sc.print()
2 changes: 1 addition & 1 deletion docs/examples/evaluate_classification_multilabel.py
@@ -17,5 +17,5 @@
probabilities=probs,
# dataset_type="classification_multilabel", # if automatic inference of dataset type fails
calculation="single",
default_metrics=library.multiclass_classification_metrics,
default_metrics=library.MetricRegistryClassification().multiclass_classification_metrics,
).print()
13 changes: 4 additions & 9 deletions docs/examples/evaluate_default_metric_selection.py
@@ -6,27 +6,22 @@

import numpy as np

from krisi import score
from krisi.evaluate.library.default_metrics_regression import (
all_regression_metrics,
low_computation_regression_metrics,
minimal_regression_metrics,
)
from krisi import library, score

score(
y=np.random.random(1000),
predictions=np.random.random(1000),
default_metrics=all_regression_metrics, # This is the default
default_metrics=library.MetricRegistryRegression().all_regression_metrics, # This is the default
).print()

score(
y=np.random.random(1000),
predictions=np.random.random(1000),
default_metrics=minimal_regression_metrics,
default_metrics=library.MetricRegistryRegression().minimal_regression_metrics,
).print(input_analysis=False)

score(
y=np.random.random(1000),
predictions=np.random.random(1000),
default_metrics=low_computation_regression_metrics,
default_metrics=library.MetricRegistryRegression().low_computation_regression_metrics,
).print(input_analysis=False)
24 changes: 6 additions & 18 deletions src/krisi/evaluate/library/__init__.py
@@ -4,30 +4,18 @@
from krisi.evaluate.type import DatasetType

from .benchmarking_models import *
from .default_metrics_classification import *
from .default_metrics_classification import (
binary_classification_balanced_metrics,
binary_classification_imbalanced_metrics,
minimal_binary_classification_metrics,
minimal_multiclass_classification_metrics,
multiclass_classification_metrics,
)
from .default_metrics_regression import *
from .default_metrics_regression import (
all_regression_metrics,
low_computation_regression_metrics,
minimal_regression_metrics,
)
from .default_metrics_classification import MetricRegistryClassification
from .default_metrics_regression import MetricRegistryRegression


def get_default_metrics_for_dataset_type(type: DatasetType) -> List[Metric]:
if type == DatasetType.classification_binary_balanced:
return binary_classification_balanced_metrics
return MetricRegistryClassification().binary_classification_balanced_metrics
elif type == DatasetType.classification_binary_imbalanced:
return binary_classification_imbalanced_metrics
return MetricRegistryClassification().binary_classification_imbalanced_metrics
elif type == DatasetType.classification_multiclass:
return multiclass_classification_metrics
return MetricRegistryClassification().multiclass_classification_metrics
elif type == DatasetType.regression:
return all_regression_metrics
return MetricRegistryRegression().all_regression_metrics
else:
raise ValueError(f"Unknown dataset type {type}")
46 changes: 43 additions & 3 deletions src/krisi/evaluate/library/default_metrics_classification.py
@@ -1,3 +1,5 @@
from copy import deepcopy

import pandas as pd
from sklearn.metrics import (
accuracy_score,
@@ -311,9 +313,6 @@
imbalance_ratio_pred_y,
]


"""~"""
minimal_binary_classification_metrics = [accuracy_binary, f_one_score_binary]
"""~"""
multiclass_classification_metrics = [
cross_entropy,
@@ -332,3 +331,44 @@
f_one_score_weighted,
# roc_auc_multi_weighted,
]


class MetricRegistryClassification:
def __init__(self):
self.accuracy_binary = deepcopy(accuracy_binary)
self.accuracy_binary_balanced = deepcopy(accuracy_binary_balanced)
self.avg_precision_micro = deepcopy(avg_precision_micro)
self.avg_precision_macro = deepcopy(avg_precision_macro)
self.avg_precision_weighted = deepcopy(avg_precision_weighted)
self.accuracy_binary = deepcopy(accuracy_binary)
self.recall_binary = deepcopy(recall_binary)
self.recall_macro = deepcopy(recall_macro)
self.precision_binary = deepcopy(precision_binary)
self.precision_macro = deepcopy(precision_macro)
self.matthew_corr = deepcopy(matthew_corr)
self.s_score = deepcopy(s_score)
self.f_one_score_binary = deepcopy(f_one_score_binary)
self.f_one_score_macro = deepcopy(f_one_score_macro)
self.f_one_score_micro = deepcopy(f_one_score_micro)
self.f_one_score_weighted = deepcopy(f_one_score_weighted)
self.kappa = deepcopy(kappa)
self.brier_score = deepcopy(brier_score)
self.calibration = deepcopy(calibration)
self.roc_auc_binary_micro = deepcopy(roc_auc_binary_micro)
self.roc_auc_binary_macro = deepcopy(roc_auc_binary_macro)
self.roc_auc_binary_weighted = deepcopy(roc_auc_binary_weighted)
self.roc_auc_multi_micro = deepcopy(roc_auc_multi_micro)
self.roc_auc_multi_macro = deepcopy(roc_auc_multi_macro)
self.imbalance_ratio_y = deepcopy(imbalance_ratio_y)
self.imbalance_ratio_pred_y = deepcopy(imbalance_ratio_pred_y)
self.cross_entropy = deepcopy(cross_entropy)

self.binary_classification_balanced_metrics = deepcopy(
binary_classification_balanced_metrics
)
self.binary_classification_imbalanced_metrics = deepcopy(
binary_classification_imbalanced_metrics
)
self.multiclass_classification_metrics = deepcopy(
multiclass_classification_metrics
)
7 changes: 7 additions & 0 deletions src/krisi/evaluate/library/default_metrics_regression.py
@@ -1,3 +1,4 @@
from copy import deepcopy
from typing import List

import numpy as np
@@ -195,3 +196,9 @@
if metric.comp_complexity is not ComputationalComplexity.high
]
""" ~ """


class MetricRegistryRegression:
all_regression_metrics = deepcopy(all_regression_metrics)
minimal_regression_metrics = deepcopy(minimal_regression_metrics)
low_computation_regression_metrics = deepcopy(low_computation_regression_metrics)
13 changes: 5 additions & 8 deletions tests/test_benchmarking.py
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd

from krisi import library
from krisi.evaluate import score
from krisi.evaluate.benchmark import zscore
from krisi.evaluate.library.benchmarking_models import (
@@ -9,10 +10,6 @@
RandomClassifierChunked,
WorstModel,
)
from krisi.evaluate.library.default_metrics_classification import (
binary_classification_balanced_metrics,
f_one_score_macro,
)
from krisi.sharedtypes import Task
from krisi.utils.data import (
generate_synthetic_data,
@@ -32,7 +29,7 @@ def test_benchmarking_random():
predictions,
probabilities,
sample_weight=sample_weight,
default_metrics=[f_one_score_macro],
default_metrics=[library.MetricRegistryClassification().f_one_score_macro],
benchmark_models=RandomClassifier(),
)
sc.print()
@@ -83,7 +80,7 @@ def test_benchmarking_random_all_metrics():
predictions,
probabilities,
sample_weight=sample_weight,
default_metrics=binary_classification_balanced_metrics,
default_metrics=library.MetricRegistryClassification().binary_classification_balanced_metrics,
benchmark_models=RandomClassifierChunked(2),
)
sc.print()
@@ -101,7 +98,7 @@ def test_perfect_to_best():
predictions,
probabilities,
sample_weight=sample_weight,
default_metrics=binary_classification_balanced_metrics,
default_metrics=library.MetricRegistryClassification().binary_classification_balanced_metrics,
benchmark_models=[PerfectModel(), WorstModel()],
)
sc.print()
@@ -129,7 +126,7 @@ def test_benchmark_zscore():
predictions,
probabilities,
sample_weight=sample_weight,
default_metrics=binary_classification_balanced_metrics,
default_metrics=library.MetricRegistryClassification().binary_classification_balanced_metrics,
benchmark_models=[PerfectModel(), WorstModel()],
)
sc.print()
4 changes: 3 additions & 1 deletion tests/test_imbalanced.py
@@ -10,7 +10,9 @@


def test_benchmarking_random_all_metrics():
groupped_metric = library.binary_classification_imbalanced_metrics
groupped_metric = (
library.MetricRegistryClassification().binary_classification_imbalanced_metrics
)
X, y = generate_synthetic_data(task=Task.classification, num_obs=1000)
sample_weight = pd.Series([1.0] * len(y))
preds_probs = generate_synthetic_predictions_binary(y, sample_weight)
16 changes: 12 additions & 4 deletions tests/unit/test_metric.py
@@ -39,7 +39,11 @@ def dummy_func(

return dummy_func

for metric in library.binary_classification_imbalanced_metrics:
for (
metric
) in (
library.MetricRegistryClassification().binary_classification_imbalanced_metrics
):
if isinstance(metric, Group):
for m in metric.metrics:
m.func = dummy_func_inject_kwags(m.parameters)
@@ -52,7 +56,7 @@ def dummy_func(
pd.Series(np.random.randint(100, size=data_len)),
pd.concat([probs_0, probs_0 - 1], axis="columns", copy=False),
sample_weight=pd.Series(np.random.rand(data_len)),
default_metrics=library.binary_classification_imbalanced_metrics,
default_metrics=library.MetricRegistryClassification().binary_classification_imbalanced_metrics,
calculation=Calculation.single,
)

@@ -84,7 +88,7 @@ def dummy_func(

return dummy_func

for metric in library.binary_classification_imbalanced_metrics:
for (
metric
) in (
library.MetricRegistryClassification().binary_classification_imbalanced_metrics
):
if isinstance(metric, Group):
for m in metric.metrics:
m.func = dummy_func_inject_kwags(m.parameters)
@@ -97,7 +105,7 @@ def dummy_func(
pd.Series(np.random.randint(100, size=data_len)),
pd.concat([probs_0, probs_0 - 1], axis="columns", copy=False),
sample_weight=pd.Series(np.random.rand(data_len)),
default_metrics=library.binary_classification_imbalanced_metrics,
default_metrics=library.MetricRegistryClassification().binary_classification_imbalanced_metrics,
calculation=Calculation.rolling,
rolling_args={"window": window_size, "min_periods": min_periods},
)
6 changes: 3 additions & 3 deletions tests/unit/test_scorecard_getters.py
@@ -2,16 +2,16 @@
import pandas as pd
import pytest

from krisi import library
from krisi.evaluate import score
from krisi.evaluate.library.benchmarking_models import RandomClassifierChunked
from krisi.evaluate.library.default_metrics_classification import f_one_score_macro


def test_spreading_comparions_results():
sc = score(
pd.Series(np.random.randint(2, size=100)),
pd.Series(np.random.randint(2, size=100)),
default_metrics=[f_one_score_macro],
default_metrics=[library.MetricRegistryClassification().f_one_score_macro],
benchmark_models=RandomClassifierChunked(0.05),
)

@@ -24,7 +24,7 @@ def test_getting_no_skill_metric():
sc = score(
pd.Series(np.random.randint(2, size=100)),
pd.Series(np.random.randint(2, size=100)),
default_metrics=[f_one_score_macro],
default_metrics=[library.MetricRegistryClassification().f_one_score_macro],
benchmark_models=RandomClassifierChunked(0.05),
)

21 changes: 11 additions & 10 deletions tests/unit/test_scorecard_union.py
@@ -1,4 +1,5 @@
from krisi import library, score
from krisi.evaluate.library.benchmarking_models import RandomClassifier
from krisi.utils.data import generate_random_classification


@@ -7,20 +8,20 @@ def test_scorecard_union():
num_labels=2, num_samples=1000
)
sc_1 = score(
y,
predictions,
probabilities,
y=y,
predictions=predictions,
probabilities=probabilities,
sample_weight=sample_weight,
default_metrics=library.binary_classification_balanced_metrics,
benchmark_models=library.RandomClassifier(),
default_metrics=library.MetricRegistryClassification().binary_classification_balanced_metrics,
benchmark_models=RandomClassifier(),
)
sc_2 = score(
y,
predictions,
probabilities,
y=y,
predictions=predictions,
probabilities=probabilities,
sample_weight=sample_weight,
default_metrics=library.binary_classification_balanced_metrics,
benchmark_models=library.RandomClassifier(),
default_metrics=library.MetricRegistryClassification().binary_classification_balanced_metrics,
benchmark_models=RandomClassifier(),
)
metric_key = "precision_binary"
sc_substracted = sc_1.subtract(sc_2)
