find-threshold: CLI command for multi-label classifier threshold tuning #11280

Merged
Changes from 11 commits
38 commits
0e5cd6b
Add foundation for find-threshold CLI functionality.
rmitsch Aug 5, 2022
4981700
Finish first draft for find-threshold.
rmitsch Aug 8, 2022
1d0f5d3
Add tests.
rmitsch Aug 8, 2022
a7b56e8
Revert adjusted import statements.
rmitsch Aug 8, 2022
d689d97
Fix mypy errors.
rmitsch Aug 9, 2022
6c3ae8d
Fix imports.
rmitsch Aug 9, 2022
63c8028
Harmonize arguments with spacy evaluate command.
rmitsch Aug 30, 2022
3a0a385
Generalize component and threshold handling. Harmonize arguments with…
rmitsch Sep 1, 2022
51863cd
Fix Spancat test.
rmitsch Sep 1, 2022
ea9737a
Add beta parameter to Scorer and PRFScore.
rmitsch Sep 1, 2022
110850f
Make beta a component scorer setting.
rmitsch Sep 2, 2022
24b69a1
Remove beta.
rmitsch Sep 2, 2022
73432c6
Update nlp.config (workaround).
rmitsch Sep 2, 2022
20c4a0d
Reload pipeline on threshold change. Adjust tests. Remove confection …
rmitsch Sep 5, 2022
03666f6
Remove assumption of component being a Pipe object or having a .cfg a…
rmitsch Sep 5, 2022
b61cf87
Adjust test output and reference values.
rmitsch Sep 5, 2022
9c00b28
Remove beta references. Delete universe.json.
rmitsch Sep 12, 2022
65e41a5
Reverting unnecessary changes. Removing unused default values. Renami…
rmitsch Sep 29, 2022
604c5ea
Update spacy/cli/find_threshold.py
rmitsch Sep 29, 2022
58d5c99
Remove adding labels in tests.
rmitsch Sep 29, 2022
9e2eea1
Merge remote-tracking branch 'upstream/master' into feature/classifie…
adrianeboyd Oct 21, 2022
08c0c41
Remove unused error
adrianeboyd Oct 21, 2022
67596fc
Undo changes to PRFScorer
adrianeboyd Oct 21, 2022
9d947a4
Change default value for n_trials. Log table iteratively.
rmitsch Oct 24, 2022
19dd45f
Add warnings for pointless applications of find_threshold().
rmitsch Oct 28, 2022
5bacad8
Fix imports.
rmitsch Oct 28, 2022
5de02dc
Adjust type check of TextCategorizer to exclude subclasses.
rmitsch Oct 28, 2022
34c6c3b
Change check of if there's only one unique value in scores.
rmitsch Nov 11, 2022
ba857c6
Attempt merging after reconciling diverging master branches.
rmitsch Nov 11, 2022
5500a58
Update spacy/cli/find_threshold.py
rmitsch Nov 17, 2022
d080808
Incorporate feedback.
rmitsch Nov 17, 2022
7b4da3f
Fix test issue. Update docstring.
rmitsch Nov 17, 2022
809588d
Update docs & docstring.
rmitsch Nov 17, 2022
42a8208
Merge branch 'master' into feature/classifier-threshold-tuning
rmitsch Nov 17, 2022
3f9d879
Update spacy/tests/test_cli.py
rmitsch Nov 17, 2022
dd84d65
Add examples to docs. Rename _nlp to nlp in tests.
rmitsch Nov 17, 2022
0ee2257
Update spacy/cli/find_threshold.py
rmitsch Nov 17, 2022
bbfef28
Update spacy/cli/find_threshold.py
rmitsch Nov 17, 2022
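
For orientation before the diff: the sketch below shows how the `find_threshold` function exported here (see the `spacy/cli/__init__.py` change) might be called from Python. The paths, the `textcat_multilabel` pipe name, the `threshold` config key, and the `cats_macro_f` score key are illustrative assumptions, not values taken from this PR.

```python
# Minimal sketch, assuming a trained multi-label textcat pipeline saved at
# "training/model-best" and dev data at "corpus/dev.spacy" (both hypothetical).
from spacy.cli import find_threshold

best_threshold, best_score = find_threshold(
    model="training/model-best",
    data_path="corpus/dev.spacy",
    pipe_name="textcat_multilabel",
    threshold_key="threshold",
    scores_key="cats_macro_f",
    n_trials=10,
    silent=False,
)
print(best_threshold, best_score)
```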
1 change: 1 addition & 0 deletions spacy/cli/__init__.py
@@ -27,6 +27,7 @@
from .project.push import project_push # noqa: F401
from .project.pull import project_pull # noqa: F401
from .project.document import project_document # noqa: F401
from .find_threshold import find_threshold # noqa: F401


@app.command("link", no_args_is_help=True, deprecated=True, hidden=True)
173 changes: 173 additions & 0 deletions spacy/cli/find_threshold.py
@@ -0,0 +1,173 @@
import functools
from functools import partial
import operator
from pathlib import Path
import logging
from typing import Optional, Tuple, Any, Dict, List

import numpy
import wasabi.tables

from ..pipeline import TrainablePipe, Pipe
from ..errors import Errors
from ..training import Corpus
from ._util import app, Arg, Opt, import_code, setup_gpu
from .. import util

_DEFAULTS = {
"average": "micro",
"n_trials": 10,
"beta": 1,
"use_gpu": -1,
"gold_preproc": False,
}


@app.command(
"find-threshold",
context_settings={"allow_extra_args": False, "ignore_unknown_options": True},
)
def find_threshold_cli(
# fmt: off
model: str = Arg(..., help="Model name or path"),
data_path: Path = Arg(..., help="Location of binary evaluation data in .spacy format", exists=True),
pipe_name: str = Arg(..., help="Name of pipe to examine thresholds for"),
threshold_key: str = Arg(..., help="Key of threshold attribute in component's configuration"),
scores_key: str = Arg(..., help="Name of score metric to optimize"),
n_trials: int = Opt(_DEFAULTS["n_trials"], "--n_trials", "-n", help="Number of trials to determine optimal thresholds"),
beta: float = Opt(_DEFAULTS["beta"], "--beta", help="Beta for F-score calculation. Ignored if a different metric is used"),
code_path: Optional[Path] = Opt(None, "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
use_gpu: int = Opt(_DEFAULTS["use_gpu"], "--gpu-id", "-g", help="GPU ID or -1 for CPU"),
gold_preproc: bool = Opt(_DEFAULTS["gold_preproc"], "--gold-preproc", "-G", help="Use gold preprocessing"),
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
# fmt: on
):
"""
Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric from the CLI.
model (Path): Path to file with trained model.
data_path (Path): Path to DocBin file with docs to use for threshold search.
pipe_name (str): Name of pipe to examine thresholds for.
threshold_key (str): Key of threshold attribute in component's configuration.
scores_key (str): Name of score metric to optimize.
n_trials (int): Number of trials to determine optimal thresholds.
beta (float): Beta for F-score calculation.
code_path (Optional[Path]): Path to Python file with additional code (registered functions) to be imported.
use_gpu (int): GPU ID or -1 for CPU.
gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
to train/test skew.
verbose (bool): Display more information for debugging purposes.
"""

util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
import_code(code_path)
find_threshold(
model=model,
data_path=data_path,
pipe_name=pipe_name,
threshold_key=threshold_key,
scores_key=scores_key,
n_trials=n_trials,
beta=beta,
use_gpu=use_gpu,
gold_preproc=gold_preproc,
silent=False,
)


def find_threshold(
model: str,
data_path: Path,
pipe_name: str,
threshold_key: str,
scores_key: str,
*,
n_trials: int = _DEFAULTS["n_trials"], # type: ignore
beta: float = _DEFAULTS["beta"], # type: ignore
use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore
gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore
silent: bool = True,
) -> Tuple[float, float]:
"""
Runs prediction trials for `textcat` models with varying thresholds to maximize the specified metric.
model (Union[str, Path]): Path to file with trained model.
data_path (Union[str, Path]): Path to DocBin file with docs to use for threshold search.
pipe_name (str): Name of pipe to examine thresholds for.
threshold_key (str): Key of threshold attribute in component's configuration.
scores_key (str): Name of score metric to optimize.
n_trials (int): Number of trials to determine optimal thresholds.
beta (float): Beta for F-score calculation.
use_gpu (int): GPU ID or -1 for CPU.
gold_preproc (bool): Whether to use gold preprocessing. Gold preprocessing helps the annotations align to the
tokenization, and may result in sequences of more consistent length. However, it may reduce runtime accuracy due
to train/test skew.
silent (bool): Whether to print non-error-related output to stdout.
RETURNS (Tuple[float, float]): Best found threshold with corresponding F-score.
"""

setup_gpu(use_gpu, silent=silent)
data_path = util.ensure_path(data_path)
if not data_path.exists():
wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
nlp = util.load_model(model)

pipe: Optional[Pipe] = None
try:
pipe = nlp.get_pipe(pipe_name)
except KeyError as err:
wasabi.msg.fail(title=str(err), exits=1)

if not isinstance(pipe, TrainablePipe):
raise TypeError(Errors.E1044)
if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045)
setattr(pipe, "scorer", partial(pipe.scorer.func, beta=beta))

if not silent:
wasabi.msg.info(
title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
f"trials and beta = {beta}."
)

# Load evaluation corpus.
corpus = Corpus(data_path, gold_preproc=gold_preproc)
dev_dataset = list(corpus(nlp))
config_keys = threshold_key.split(".")

def set_nested_item(
config: Dict[str, Any], keys: List[str], value: float
) -> Dict[str, Any]:
"""Set item in nested dictionary. Adapated from https://stackoverflow.com/a/54138200.
config (Dict[str, Any]): Configuration dictionary.
keys (List[Any]):
value (float): Value to set.
RETURNS (Dict[str, Any]): Updated dictionary.
"""
functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
return config

# Evaluate with varying threshold values.
scores: Dict[float, float] = {}
for threshold in numpy.linspace(0, 1, n_trials):
pipe.cfg = set_nested_item(pipe.cfg, config_keys, threshold)
scores[threshold] = nlp.evaluate(dev_dataset)[scores_key]
if not isinstance(scores[threshold], (float, int)):
wasabi.msg.fail(
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
f"scores.",
exits=1,
)

best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
if not silent:
print(
f"Best threshold: {round(best_threshold, ndigits=4)} with value of {scores[best_threshold]}.",
wasabi.tables.table(
data=list(scores.items()),
header=["Threshold", scores_key],
),
)

return best_threshold, scores[best_threshold]
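
The search above boils down to a nested-dictionary write plus a linear sweep over candidate thresholds. Below is a self-contained sketch of that logic outside spaCy; `toy_score` is a hypothetical stand-in for `nlp.evaluate(dev_dataset)[scores_key]`.

```python
# Standalone sketch of the sweep in find_threshold() above.
import functools
import operator

import numpy

def set_nested_item(config, keys, value):
    # Walk to the parent dict of the final key, then assign in place.
    functools.reduce(operator.getitem, keys[:-1], config)[keys[-1]] = value
    return config

def toy_score(threshold):
    # Hypothetical stand-in for nlp.evaluate(dev_dataset)[scores_key]:
    # a metric that happens to peak near threshold = 0.33.
    return 1.0 - abs(threshold - 0.33)

cfg = {"model": {"threshold": 0.5}}
scores = {}
for threshold in numpy.linspace(0, 1, 10):
    cfg = set_nested_item(cfg, ["model", "threshold"], threshold)
    scores[threshold] = toy_score(threshold)

best_threshold = max(scores, key=scores.get)
print(best_threshold, scores[best_threshold])
```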
2 changes: 2 additions & 0 deletions spacy/errors.py
@@ -939,6 +939,8 @@ class Errors(metaclass=ErrorsWithCodes):
"`{arg2}`={arg2_values} but these arguments are conflicting.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.")
E1044 = ("`find_threshold()` only supports components of type `TrainablePipe`.")
E1045 = ("`find_threshold()` only supports components with a `scorer` attribute.")


# Deprecated model shortcuts, only used in errors and warnings
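
For context on how these constants surface at the raise sites in `find_threshold.py`: to my understanding of spaCy's `ErrorsWithCodes` metaclass (not shown in this diff), attribute access returns the message prefixed with its code, roughly as in this sketch.

```python
from spacy.errors import Errors

# Attribute access yields the code-prefixed message string, so the
# TypeError raised in find_threshold() above reads approximately:
print(Errors.E1044)
# [E1044] `find_threshold()` only supports components of type `TrainablePipe`.
```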
5 changes: 3 additions & 2 deletions spacy/pipeline/spancat.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
@@ -165,8 +166,8 @@ def spancat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:


@registry.scorers("spacy.spancat_scorer.v1")
def make_spancat_scorer():
return spancat_score
def make_spancat_scorer(beta: float = 1.0):
return partial(spancat_score, beta=beta)


class SpanCategorizer(TrainablePipe):
5 changes: 3 additions & 2 deletions spacy/pipeline/textcat_multilabel.py
@@ -1,3 +1,4 @@
from functools import partial
from typing import Iterable, Optional, Dict, List, Callable, Any
from thinc.types import Floats2d
from thinc.api import Model, Config
@@ -121,8 +122,8 @@ def textcat_multilabel_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:


@registry.scorers("spacy.textcat_multilabel_scorer.v1")
def make_textcat_multilabel_scorer():
return textcat_multilabel_score
def make_textcat_multilabel_scorer(beta: float = 1.0):
return partial(textcat_multilabel_score, beta=beta)


class MultiLabel_TextCategorizer(TextCategorizer):
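
Both scorer factories follow the same pattern: `functools.partial` fixes `beta` at registration time, so downstream code calls the scorer without threading the parameter through. A minimal sketch of that pattern, using the standard F-beta definition as general background (`fbeta` is a hypothetical helper, not code from this PR):

```python
from functools import partial

def fbeta(precision: float, recall: float, beta: float = 1.0) -> float:
    # Standard F-beta: weights recall beta times as much as precision.
    if precision + recall == 0.0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)

# Analogous to make_spancat_scorer(beta=2.0): beta is baked in once,
# callers just pass precision and recall.
f2_scorer = partial(fbeta, beta=2.0)
print(f2_scorer(0.8, 0.6))  # ~0.6316; beta=2 favors recall over precision
```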