feature: No need to define calculation type; if a benchmark model is defined, then benchmarking is done.
szemyd committed Nov 20, 2023
1 parent 61bf9f7 commit 625f6f4
Showing 10 changed files with 23 additions and 41 deletions.
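Concretely, `calculation` now takes a single value (`"single"`, `"rolling"`, or `"both"`) instead of a list, and the `"benchmark"` option is gone: benchmarking runs whenever `benchmark_models` is passed (see the `score.py` change below). A minimal sketch of the new call, pieced together from the updated examples and tests in this diff:

```python
import numpy as np
import pandas as pd

from krisi import score

# Synthetic binary targets and predictions, as in the updated unit tests below.
y = pd.Series(np.random.randint(2, size=100))
predictions = pd.Series(np.random.randint(2, size=100))

# "both" replaces the old list form [Calculation.single, Calculation.rolling];
# "benchmark" is no longer a valid calculation value.
sc = score(y, predictions, calculation="both")
sc.print()
```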
3 changes: 1 addition & 2 deletions docs/examples/evaluate_classification.py
@@ -6,7 +6,6 @@


from krisi import library, score
from krisi.evaluate.type import Calculation
from krisi.utils.data import generate_random_classification

y, preds, probs, sample_weight = generate_random_classification(
@@ -16,7 +15,7 @@
y=y,
predictions=preds,
# dataset_type="classification_multilabel", # if automatic inference of dataset type fails
calculation=[Calculation.single, Calculation.rolling],
calculation="both",
default_metrics=library.default_metrics_classification.binary_classification_balanced_metrics,
)
sc.print()
3 changes: 1 addition & 2 deletions docs/examples/evaluate_classification_with_probabilities.py
@@ -6,7 +6,6 @@


from krisi import score
from krisi.evaluate.type import Calculation
from krisi.utils.data import generate_random_classification

y, preds, probs, sample_weight = generate_random_classification(
@@ -26,7 +25,7 @@
predictions=preds,
probabilities=probs,
# dataset_type="classification_binary_balanced", # if automatic inference of dataset type fails
calculation=[Calculation.single, Calculation.rolling],
calculation="both",
)
sc.print()
sc.generate_report()
3 changes: 1 addition & 2 deletions docs/examples/evaluate_probabilities.py
@@ -8,7 +8,6 @@
import numpy as np

from krisi import score
from krisi.evaluate.type import Calculation
from krisi.utils.data import create_probabilities

num_labels = 3
@@ -19,7 +18,7 @@
y=np.random.randint(0, num_labels, num_samples),
predictions=np.random.randint(0, num_labels, num_samples),
probabilities=probabilities,
calculation=[Calculation.single, Calculation.rolling]
calculation="both"
# dataset_type="classification_multilabel", # if automatic inference of dataset type fails
)
sc.print()
3 changes: 1 addition & 2 deletions docs/examples/evaluate_probabilities_binary.py
@@ -8,7 +8,6 @@
import numpy as np

from krisi import score
from krisi.evaluate.type import Calculation
from krisi.utils.data import create_probabilities

num_labels = 2
@@ -19,7 +18,7 @@
y=np.random.randint(0, num_labels, num_samples),
predictions=np.random.randint(0, num_labels, num_samples),
probabilities=probabilities,
calculation=[Calculation.single, Calculation.rolling]
calculation="both"
# dataset_type="classification_multilabel", # if automatic inference of dataset type fails
)
sc.print()
2 changes: 1 addition & 1 deletion docs/walkthroughs/a_full_rundown_notebook.ipynb
@@ -498,7 +498,7 @@
}
],
"source": [
"scorecard_rolling = score(y, predictions, calculation='rolling') # calculation can be 'rolling', 'single' or 'benchmark'"
"scorecard_rolling = score(y, predictions, calculation='rolling') # calculation can be 'rolling', 'single' or 'both'"
]
},
{
2 changes: 1 addition & 1 deletion docs/walkthroughs/a_full_rundown_notebook.md
@@ -288,7 +288,7 @@ The most important feature of `krisi` is that you can evaluate metrics over time


```python
scorecard_rolling = score(y, predictions, calculation='rolling') # calculation can be 'rolling', 'single' or 'benchmark'
scorecard_rolling = score(y, predictions, calculation='rolling') # calculation can be 'rolling', 'single' or 'both'
```

/Users/daniel/mambaforge/envs/krisitest/lib/python3.10/site-packages/sklearn/metrics/_regression.py:918: UndefinedMetricWarning:
39 changes: 16 additions & 23 deletions src/krisi/evaluate/score.py
@@ -27,9 +27,7 @@ def score(
custom_metrics: Optional[Union[List[Metric], Metric]] = None,
dataset_type: Optional[Union[DatasetType, str]] = None,
sample_type: Union[str, SampleTypes] = SampleTypes.outofsample,
calculation: Union[
List[Union[Calculation, str]], Union[Calculation, str]
] = Calculation.single,
calculation: Union[Calculation, str] = Calculation.single,
rolling_args: Optional[Dict[str, Any]] = None,
raise_exceptions: bool = False,
benchmark_models: Optional[Union[Model, List[Model]]] = None,
@@ -63,11 +61,11 @@
- `SampleTypes.outofsample`
- `SampleTypes.insample`
calculation: Union[ List[Union[Calculation, str]], Union[Calculation, str] ], optional
Whether it should evaluate `Metrics` on a rolling basis or on the whole prediction or benchmark, by default Calculation.single
Whether it should evaluate `Metrics` on a rolling basis or on the whole prediction or both, by default Calculation.single
- `Calculation.single`
- `Calculation.rolling`
- `Calculation.benchmark`
- `Calculation.both`
rolling_args : Dict[str, Any], optional
Arguments to be passed onto `pd.DataFrame.rolling`.
Default:
@@ -86,9 +84,8 @@
If Calculation type is incorrectly specified.
"""

calculations = [
Calculation.from_str(calculation) for calculation in wrap_in_list(calculation)
]
calculation = Calculation.from_str(calculation)

benchmark_models = (
wrap_in_list(benchmark_models) if benchmark_models is not None else None
)
@@ -110,24 +107,20 @@
**kwargs,
)

assert any(
[
calc
in [
Calculation.single,
Calculation.benchmark,
Calculation.rolling,
]
for calc in calculations
]
), f"Calculation type {calculation} not recognized."
assert calculation in [
Calculation.single,
Calculation.rolling,
Calculation.both,
], f"Calculation type {calculation} not recognized."

if Calculation.single in calculations:
if calculation == Calculation.single:
sc.evaluate()
elif calculation == Calculation.rolling:
sc.evaluate_over_time()
elif calculation == Calculation.both:
sc.evaluate()
if Calculation.rolling in calculations:
sc.evaluate_over_time()
if Calculation.benchmark in calculations:
assert benchmark_models is not None, "You need to define a benchmark model!"
if benchmark_models is not None:
sc.evaluate_benchmark(benchmark_models)

sc.cleanup_group()
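With the dispatch above, benchmarking is no longer requested through `calculation`; passing `benchmark_models` alone triggers `sc.evaluate_benchmark(...)`. A sketch mirroring the updated `tests/unit/test_scorecard_getters.py` below; the import paths for `f_one_score_macro` and `RandomClassifierChunked` are assumptions, since the tests' import blocks are collapsed in this diff:

```python
import numpy as np
import pandas as pd

from krisi import score
from krisi.evaluate.library import f_one_score_macro  # assumed import path (collapsed in the diff)
from krisi.evaluate.benchmarking import RandomClassifierChunked  # assumed import path (collapsed in the diff)

sc = score(
    pd.Series(np.random.randint(2, size=100)),  # targets
    pd.Series(np.random.randint(2, size=100)),  # predictions
    default_metrics=[f_one_score_macro],
    # no calculation="benchmark" needed anymore; this argument alone enables benchmarking
    benchmark_models=RandomClassifierChunked(0.05),
)
sc.print()
```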
2 changes: 1 addition & 1 deletion src/krisi/evaluate/type.py
@@ -57,7 +57,7 @@ class SampleTypes(ParsableEnum):
class Calculation(ParsableEnum):
single = "single"
rolling = "rolling"
benchmark = "benchmark"
both = "both"


class SaveModes(ParsableEnum):
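For reference, the strings accepted by `score` map directly onto this enum via `Calculation.from_str`. A small sketch under the assumption that `ParsableEnum.from_str` simply resolves a string to the member of the same name (only the call site in `score.py` is visible in this diff):

```python
from krisi.evaluate.type import Calculation

# Assumed behaviour of ParsableEnum.from_str: a string resolves to the matching member.
assert Calculation.from_str("single") == Calculation.single
assert Calculation.from_str("rolling") == Calculation.rolling
assert Calculation.from_str("both") == Calculation.both

# "benchmark" is no longer a member, so it is no longer a valid calculation value.
```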
5 changes: 0 additions & 5 deletions tests/test_benchmarking.py
@@ -13,7 +13,6 @@
binary_classification_balanced_metrics,
f_one_score_macro,
)
from krisi.evaluate.type import Calculation
from krisi.sharedtypes import Task
from krisi.utils.data import (
generate_synthetic_data,
@@ -34,7 +33,6 @@ def test_benchmarking_random():
probabilities,
sample_weight=sample_weight,
default_metrics=[f_one_score_macro],
calculation=[Calculation.single, Calculation.benchmark],
benchmark_models=RandomClassifier(),
)
sc.print()
@@ -86,7 +84,6 @@ def test_benchmarking_random_all_metrics():
probabilities,
sample_weight=sample_weight,
default_metrics=binary_classification_balanced_metrics,
calculation=[Calculation.benchmark],
benchmark_models=RandomClassifierChunked(2),
)
sc.print()
@@ -105,7 +102,6 @@ def test_perfect_to_best():
probabilities,
sample_weight=sample_weight,
default_metrics=binary_classification_balanced_metrics,
calculation=Calculation.benchmark,
benchmark_models=[PerfectModel(), WorstModel()],
)
sc.print()
@@ -134,7 +130,6 @@ def test_benchmark_zscore():
probabilities,
sample_weight=sample_weight,
default_metrics=binary_classification_balanced_metrics,
calculation=Calculation.benchmark,
benchmark_models=[PerfectModel(), WorstModel()],
)
sc.print()
2 changes: 0 additions & 2 deletions tests/unit/test_scorecard_getters.py
@@ -12,7 +12,6 @@ def test_spreading_comparions_results():
pd.Series(np.random.randint(2, size=100)),
pd.Series(np.random.randint(2, size=100)),
default_metrics=[f_one_score_macro],
calculation=["benchmark"],
benchmark_models=RandomClassifierChunked(0.05),
)

@@ -26,7 +25,6 @@ def test_getting_no_skill_metric():
pd.Series(np.random.randint(2, size=100)),
pd.Series(np.random.randint(2, size=100)),
default_metrics=[f_one_score_macro],
calculation=["benchmark"],
benchmark_models=RandomClassifierChunked(0.05),
)

