Add training.before_update callback #11739

Merged
Changes from 1 commit
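For context: training.before_update is a new optional callback, invoked by the training loop at the start of each update step with the nlp object and a dict of loop state. A minimal sketch of the config-driven registration that the earlier revision of the test below exercised; the name "customize_before_update" and the function body are illustrative, not part of the PR:

from spacy import registry


@registry.callbacks("customize_before_update")
def make_customize_before_update():
    def before_update(nlp, args):
        # The args dict carries the loop state; the test in this diff
        # relies on the "step" and "epoch" keys.
        print(f"step={args['step']} epoch={args['epoch']}")

    return before_update

The registered name can then be wired into the config, as the removed test lines below do with config["training"]["before_update"] = {"@callbacks": "..."}.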
6 changes: 0 additions & 6 deletions spacy/tests/conftest.py
@@ -48,12 +48,6 @@ def getopt(opt):
             pytest.skip("not referencing any issues")
 
 
-@pytest.fixture
-def test_dir(request):
-    print(request.fspath)
-    return Path(request.fspath).parent
-
-
 # Fixtures for language tokenizers (languages sorted alphabetically)
 
 
82 changes: 31 additions & 51 deletions spacy/tests/training/test_training.py
@@ -12,11 +12,10 @@
 from spacy.training.alignment_array import AlignmentArray
 from spacy.training.align import get_alignments
 from spacy.training.converters import json_to_docs
-from spacy.training.initialize import init_nlp
-from spacy.training.loop import train
+from spacy.training.loop import train_while_improving
 from spacy.util import get_words_and_spaces, load_model_from_path, minibatch
-from spacy.util import load_config_from_str, registry, load_model_from_config
-from thinc.api import compounding
+from spacy.util import load_config_from_str, load_model_from_config
+from thinc.api import compounding, Adam
 
 from ..util import make_tempdir
 
@@ -1146,59 +1145,40 @@ def test_retokenized_docs(doc):
 [components.tagger.model.tok2vec]
 @architectures = "spacy.Tok2VecListener.v1"
 width = ${components.tok2vec.model.width}
-
-[corpora]
-
-[corpora.train]
-@readers = "spacy.Corpus.v1"
-path = null
-
-[corpora.dev]
-@readers = "spacy.Corpus.v1"
-path = null
-
-[training]
-train_corpus = "corpora.train"
-dev_corpus = "corpora.dev"
-seed = 1
-gpu_allocator = "pytorch"
-dropout = 0.1
-accumulate_gradient = 3
-patience = 5000
-max_epochs = 1
-max_steps = 6
-eval_frequency = 10
-
-[training.batcher]
-@batchers = "spacy.batch_by_padded.v1"
-discard_oversize = False
-get_length = null
-size = 1
-buffer = 256
 """
 
 
-def test_training_before_update(test_dir):
-    ran_before_update = False
+def test_training_before_update(doc):
+    def before_update(nlp, args):
+        assert args["step"] == 0
+        assert args["epoch"] == 1
 
-    @registry.callbacks(f"test_training_before_update_callback")
-    def make_before_creation():
-        def before_update(nlp, args):
-            nonlocal ran_before_update
-            ran_before_update = True
-            assert "step" in args
-            assert "epoch" in args
+        # Raise an error here as the rest of the loop
+        # will not run to completion due to uninitialized
+        # models.
+        raise ValueError("ran_before_update")
 
-        return before_update
+    def generate_batch():
+        yield 1, [Example(doc, doc)]
 
     config = Config().from_str(training_config_string, interpolate=False)
config["corpora"]["train"]["path"] = str(test_dir / "toy-en-corpus.spacy")
config["corpora"]["dev"]["path"] = str(test_dir / "toy-en-corpus.spacy")
config["training"]["before_update"] = {
"@callbacks": "test_training_before_update_callback"
}
nlp = load_model_from_config(config, auto_fill=True, validate=True)
optimizer = Adam()
generator = train_while_improving(
nlp,
optimizer,
generate_batch(),
lambda: None,
dropout=0.1,
eval_frequency=100,
accumulate_gradient=10,
patience=10,
max_steps=100,
exclude=[],
annotating_components=[],
before_update=before_update,
)

nlp = init_nlp(config)
train(nlp)
assert ran_before_update == True
with pytest.raises(ValueError, match="ran_before_update"):
for _ in generator:
pass
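The rewritten test drives train_while_improving directly instead of running the full train() loop against a saved corpus: the callback raises at step 0 of epoch 1, so the loop aborts before the uninitialized models would fail anyway, and the test only needs to see the ValueError surface. A rough sketch of the loop-side contract this relies on (a hypothetical simplification, not the actual spacy.training.loop implementation):

def run_updates_sketch(nlp, batches, before_update=None):
    # Hypothetical outline of what train_while_improving does with the
    # callback; each item from the batch generator pairs an epoch number
    # with a list of Examples, which is why generate_batch() above yields
    # (1, [Example(doc, doc)]).
    for step, (epoch, batch) in enumerate(batches):
        if before_update is not None:
            # Invoked before any weights are touched, so the test's
            # ValueError stops the loop before a real update runs.
            before_update(nlp, {"step": step, "epoch": epoch})
        # ... nlp.update(batch, ...) and periodic evaluation would follow ...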
Binary file removed spacy/tests/training/toy-en-corpus.spacy