Add support of tokenized input for coref and srl predictors #2076
Changes from 2 commits
@@ -11,6 +11,7 @@ dist/

.envrc
.python-version
.idea

# jupyter notebooks
@@ -1,7 +1,9 @@
from typing import List

from overrides import overrides

from allennlp.common.util import get_spacy_model
from allennlp.common.util import JsonDict
from allennlp.common.util import get_spacy_model
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
from allennlp.predictors.predictor import Predictor
@@ -53,6 +55,38 @@ def predict(self, document: str) -> JsonDict:
        """
        return self.predict_json({"document" : document})

    def predict_from_list(self, tokenized_document: List[str]) -> JsonDict:
[Review comment] I'd prefer this to be called …
        """
        Predict the coreference clusters in the given document.

        Parameters
        ----------
        tokenized_document : ``List[str]``
            The document to predict on, given as a list of already-tokenized words.

        Returns
        -------
        A dictionary representation of the predicted coreference clusters.
        """
        return self.predict_words_list(tokenized_document)

    def predict_words_list(self, words_list: List[str]) -> JsonDict:
[Review comment] I don't know why you have this extra method - it looks like it's doing exactly the same thing as …
        instance = self._words_list_to_instance(words_list)
        return self.predict_instance(instance)

    def _words_list_to_instance(self, document_list: List[str]) -> Instance:
""" | ||
Create an instance from words list represent an already tokenized document, | ||
for skipping tokenization when that information already exist for the user | ||
""" | ||
spacy_document = self._spacy.tokenizer.tokens_from_list(document_list) | ||
[Review comment] It looks like you're calling the same spacy pipeline on the document twice here; once inside …
[Author reply] Here I'm using the actual spaCy tokenizer and not the WordSplitter one (which I use in srl); it returns a spacy.Doc object, and I then run the pipeline on the doc object only once.
[Review comment] Oh, I see - so yeah, you really don't need that extra method at all.
        for pipe in filter(None, self._spacy.pipeline):
            pipe[1](spacy_document)

        sentences = [[token.text for token in sentence] for sentence in spacy_document.sents]
        instance = self._dataset_reader.text_to_instance(sentences)
        return instance

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
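Following the reviewer's suggestion to drop the extra predict_words_list/_words_list_to_instance indirection, a consolidated version of the coref entry point might look like the sketch below. The method name predict_tokenized is an assumption (the reviewer's preferred name is elided above); the body only reuses calls that already appear in this diff.

    def predict_tokenized(self, tokenized_document: List[str]) -> JsonDict:
        # Build a spacy Doc directly from the pre-tokenized words ...
        spacy_document = self._spacy.tokenizer.tokens_from_list(tokenized_document)
        # ... then run the remaining pipeline components (tagger, parser) over it once,
        # so sentence boundaries are available for spacy_document.sents below.
        for pipe in filter(None, self._spacy.pipeline):
            pipe[1](spacy_document)
        sentences = [[token.text for token in sentence] for sentence in spacy_document.sents]
        instance = self._dataset_reader.text_to_instance(sentences)
        return self.predict_instance(instance)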
@@ -41,8 +41,23 @@ def predict(self, sentence: str) -> JsonDict:
        -------
        A dictionary representation of the semantic roles in the sentence.
        """
        return self.predict_json({"sentence" : sentence})
        return self.predict_json({"sentence": sentence})

    def predict_from_list(self, tokenized_sentence: List[str]) -> JsonDict:
""" | ||
Predicts the semantic roles of the supplied sentence tokens and returns a dictionary | ||
with the results. | ||
|
||
Parameters | ||
---------- | ||
tokenized_sentence, ``List[str]`` | ||
The sentence tokens to parse via semantic role labeling. | ||
|
||
Returns | ||
------- | ||
A dictionary representation of the semantic roles in the sentence. | ||
""" | ||
return self.predict_words_list(tokenized_sentence) | ||
|
||
@staticmethod | ||
def make_srl_string(words: List[str], tags: List[str]) -> str: | ||
|
@@ -71,6 +86,17 @@ def make_srl_string(words: List[str], tags: List[str]) -> str:
    def _json_to_instance(self, json_dict: JsonDict):
        raise NotImplementedError("The SRL model uses a different API for creating instances.")

    def tokens_to_instances(self, tokens):
        words = [token.text for token in tokens]
        instances: List[Instance] = []
        for i, word in enumerate(tokens):
            if word.pos_ == "VERB":
                verb_labels = [0 for _ in words]
                verb_labels[i] = 1
                instance = self._dataset_reader.text_to_instance(tokens, verb_labels)
                instances.append(instance)
        return instances

    def _sentence_to_srl_instances(self, json_dict: JsonDict) -> List[Instance]:
        """
        The SRL model has a slightly different API from other models, as the model is run
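For reference, a minimal usage sketch of the tokenized entry point added in this diff, mirroring the fixture-loading pattern used in the test file below; the archive path is a placeholder, not a path from this repository.

    from allennlp.models.archival import load_archive
    from allennlp.predictors import Predictor

    archive = load_archive("/path/to/srl-model.tar.gz")  # placeholder path
    predictor = Predictor.from_archive(archive, "semantic-role-labeling")

    # The tokens are used as given, so the predictor's own tokenization is skipped.
    tokens = ["The", "cat", "sat", "on", "the", "mat", "."]
    result = predictor.predict_from_list(tokens)
    print([verb["verb"] for verb in result["verbs"]])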
@@ -92,15 +118,7 @@ def _sentence_to_srl_instances(self, json_dict: JsonDict) -> List[Instance]:
        """
        sentence = json_dict["sentence"]
        tokens = self._tokenizer.split_words(sentence)
        words = [token.text for token in tokens]
        instances: List[Instance] = []
        for i, word in enumerate(tokens):
            if word.pos_ == "VERB":
                verb_labels = [0 for _ in words]
                verb_labels[i] = 1
                instance = self._dataset_reader.text_to_instance(tokens, verb_labels)
                instances.append(instance)
        return instances
        return self.tokens_to_instances(tokens)

    @overrides
    def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:
@@ -178,6 +196,21 @@ def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:

        return sanitize(return_dicts)

    def predict_instances(self, instances: List[Instance]) -> JsonDict:
        outputs = self._model.forward_on_instances(instances)

        results = {"verbs": [], "words": outputs[0]["words"]}
        for output in outputs:
            tags = output['tags']
            description = self.make_srl_string(output["words"], tags)
            results["verbs"].append({
                    "verb": output["verb"],
                    "description": description,
                    "tags": tags,
            })

        return sanitize(results)

    @overrides
    def predict_json(self, inputs: JsonDict) -> JsonDict:
        """
@@ -198,16 +231,17 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
        if not instances:
            return sanitize({"verbs": [], "words": self._tokenizer.split_words(inputs["sentence"])})

        outputs = self._model.forward_on_instances(instances)
        return self.predict_instances(instances)

        results = {"verbs": [], "words": outputs[0]["words"]}
        for output in outputs:
            tags = output['tags']
            description = self.make_srl_string(output["words"], tags)
            results["verbs"].append({
                    "verb": output["verb"],
                    "description": description,
                    "tags": tags,
            })

    def predict_words_list(self, words_list: List[str]) -> JsonDict:
[Review comment] Again, the logic in this method should just be moved into …
        """
        Create instances from a list of words representing an already tokenized sentence,
        so tokenization can be skipped when that information already exists for the user.
        """
        tokens = self._tokenizer.tokens_from_list(words_list)
[Review comment] You can't rely on this method actually existing, because you didn't create the method on the base class (I'm a little surprised that mypy didn't catch this; maybe because we're playing a little loose with the tokenizers inside of a predictor already...). But it's a lot easier than this: just do …
[Author reply] Did you mean here I should just add …
[Review comment] Instead of the line that you have, you should have …
[Author reply] Spacy POS tagging is required; I've tried to follow the class logic where all spacy tokenization and pipeline work happens under the hood via the class …
[Review comment] That's especially non-obvious from the diff I was looking at. And now that I look at the things that were hidden, I understand why mypy didn't catch this - I didn't realize that we specifically instantiated a …
        instances = self.tokens_to_instances(tokens)

        return sanitize(results)
        if not instances:
            return sanitize({"verbs": [], "words": tokens})

        return self.predict_instances(instances)
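The reviewer's exact one-line suggestion is elided above, but for illustration, one way to get POS-tagged spaCy tokens from a pre-tokenized list without relying on a tokens_from_list method on the word splitter is sketched below. This is an assumption about the direction being discussed, not code from the PR, and the model name is also an assumption.

    import spacy
    from spacy.tokens import Doc

    nlp = spacy.load("en_core_web_sm")
    # Construct the Doc from the words as given, skipping spaCy's own tokenizer.
    doc = Doc(nlp.vocab, words=["The", "squirrel", "wrote", "a", "test", "."])
    # Run the tagger/parser so that token.pos_ (needed for the "VERB" check) is filled in.
    for _, pipe in nlp.pipeline:
        pipe(doc)
    print([(token.text, token.pos_) for token in doc])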
@@ -13,21 +13,27 @@ def test_uses_named_inputs(self):
        archive = load_archive(self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'semantic-role-labeling')

        result = predictor.predict_json(inputs)
        result_json = predictor.predict_json(inputs)
        self.assert_predict_result(result_json)

        words = ["The", "squirrel", "wrote", "a", "unit", "test",
                 "to", "make", "sure", "its", "nuts", "worked", "as", "designed", "."]

        result_words = predictor.predict_words_list(words)
        self.assert_predict_result(result_words)

    @staticmethod
    def assert_predict_result(result):
        words = result.get("words")
        assert words == ["The", "squirrel", "wrote", "a", "unit", "test",
                         "to", "make", "sure", "its", "nuts", "worked", "as", "designed", "."]
        num_words = len(words)
[Review comment] I'd not remove these blank lines. They separate coherent segments.
        verbs = result.get("verbs")
        assert verbs is not None
        assert isinstance(verbs, list)

        assert any(v["verb"] == "wrote" for v in verbs)
        assert any(v["verb"] == "make" for v in verbs)
        assert any(v["verb"] == "worked" for v in verbs)

        for verb in verbs:
            tags = verb.get("tags")
            assert tags is not None
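A matching check for the coref predictor's tokenized entry point is not shown in this diff; a sketch of what one might look like follows. The fixture path, test name, and asserted keys are assumptions modelled on the SRL test above, not part of the PR.

    def test_coref_predict_from_list(self):
        archive = load_archive(self.FIXTURES_ROOT / 'coref' / 'serialization' / 'model.tar.gz')  # assumed fixture path
        predictor = Predictor.from_archive(archive, 'coreference-resolution')

        tokens = ["The", "squirrel", "buried", "its", "nuts", "."]
        result = predictor.predict_from_list(tokens)

        # The predictor should return a (possibly empty) list of coreference clusters.
        assert "clusters" in result
        assert isinstance(result["clusters"], list)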
[Review comment] We can't just add methods like this to a subclass without adding them to the base class. This breaks the API. You call self._tokenizer.tokens_from_list() in the Predictor below, but that will crash with any tokenizer except the spacy tokenizer.
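For illustration, a minimal sketch of the kind of base-class hook the reviewer is asking for: declare the tokenized entry point on Predictor itself so every subclass shares the API, and let the spacy-specific subclasses override it. The method name predict_tokenized and the default behaviour are assumptions, not code from this PR.

    from typing import Any, Dict, List

    JsonDict = Dict[str, Any]  # stand-in for allennlp.common.util.JsonDict

    class Predictor:
        """Sketch only; the real base class lives in allennlp/predictors/predictor.py."""

        def predict_tokenized(self, tokenized_input: List[str]) -> JsonDict:
            # Declared on the base class so callers can rely on it existing;
            # subclasses that accept pre-tokenized input (coref, SRL) override it.
            raise NotImplementedError(
                f"{self.__class__.__name__} does not support pre-tokenized input.")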