Commit af902a3

Add support of tokenized input for coref and srl predictors (allenai#2076)

* add support of tokenized input for coref and srl predictors

* change method signature

* fix PR review comments
Alon Eirew authored and matt-gardner committed Nov 18, 2018
1 parent e3e8e1c commit af902a3
Showing 6 changed files with 110 additions and 29 deletions.
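In short, the new API lets callers hand the predictors pre-tokenized text instead of a raw string. A minimal usage sketch, mirroring the tests added below (the archive path is a placeholder, not part of this commit):

from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor

archive = load_archive('/path/to/coref/model.tar.gz')  # placeholder path
predictor = Predictor.from_archive(archive, 'coreference-resolution')

# Previously, only raw strings were accepted:
predictor.predict(document="This is a single string document about a test.")

# New in this commit: pass an already tokenized document directly.
predictor.predict_tokenized(['This', 'is', 'a', 'single', 'string', 'document',
                             'about', 'a', 'test', '.', 'Sometimes', 'it',
                             'contains', 'coreferent', 'parts', '.'])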
1 change: 1 addition & 0 deletions .gitignore
@@ -11,6 +11,7 @@ dist/

.envrc
.python-version
.idea


# jupyter notebooks
2 changes: 2 additions & 0 deletions allennlp/data/tokenizers/word_splitter.py
@@ -131,6 +131,7 @@ def split_words(self, sentence: str) -> List[Token]:
def _remove_spaces(tokens: List[spacy.tokens.Token]) -> List[spacy.tokens.Token]:
    return [token for token in tokens if not token.is_space]


@WordSplitter.register('spacy')
class SpacyWordSplitter(WordSplitter):
"""
@@ -154,6 +155,7 @@ def split_words(self, sentence: str) -> List[Token]:
        # This works because our Token class matches spacy's.
        return _remove_spaces(self.spacy(sentence))


@WordSplitter.register('openai')
class OpenAISplitter(WordSplitter):
"""
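For context, the module-level `_remove_spaces` helper exists because spacy's tokenizer emits standalone whitespace tokens when the input contains runs of extra spaces. A small sketch of the behavior it guards against, assuming a standard `en_core_web_sm` pipeline is installed:

import spacy

nlp = spacy.load('en_core_web_sm')
doc = nlp('This  sentence  has  extra  spaces')

print([t.text for t in doc])
# The doubled spaces surface as whitespace tokens, roughly:
# ['This', ' ', 'sentence', ' ', 'has', ' ', 'extra', ' ', 'spaces']
print([t.text for t in doc if not t.is_space])  # what _remove_spaces keeps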
33 changes: 32 additions & 1 deletion allennlp/predictors/coref.py
@@ -1,7 +1,9 @@
from typing import List

from overrides import overrides

from allennlp.common.util import get_spacy_model
from allennlp.common.util import JsonDict
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
from allennlp.predictors.predictor import Predictor
@@ -53,6 +55,35 @@ def predict(self, document: str) -> JsonDict:
"""
return self.predict_json({"document" : document})

    def predict_tokenized(self, tokenized_document: List[str]) -> JsonDict:
        """
        Predict the coreference clusters in the given tokenized document.

        Parameters
        ----------
        tokenized_document : ``List[str]``
            The already tokenized document, as a list of words.

        Returns
        -------
        A dictionary representation of the predicted coreference clusters.
        """
        instance = self._words_list_to_instance(tokenized_document)
        return self.predict_instance(instance)

    def _words_list_to_instance(self, words: List[str]) -> Instance:
        """
        Create an instance from a list of words representing an already tokenized
        document, so that tokenization can be skipped when the user already has
        that information.
        """
        spacy_document = self._spacy.tokenizer.tokens_from_list(words)
        for pipe in filter(None, self._spacy.pipeline):
            pipe[1](spacy_document)

        sentences = [[token.text for token in sentence] for sentence in spacy_document.sents]
        instance = self._dataset_reader.text_to_instance(sentences)
        return instance

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
77 changes: 54 additions & 23 deletions allennlp/predictors/semantic_role_labeler.py
@@ -41,8 +41,33 @@ def predict(self, sentence: str) -> JsonDict:
        -------
        A dictionary representation of the semantic roles in the sentence.
        """
        return self.predict_json({"sentence": sentence})

    def predict_tokenized(self, tokenized_sentence: List[str]) -> JsonDict:
        """
        Predicts the semantic roles of the supplied sentence tokens and returns
        a dictionary with the results.

        Parameters
        ----------
        tokenized_sentence : ``List[str]``
            The sentence tokens to parse via semantic role labeling.

        Returns
        -------
        A dictionary representation of the semantic roles in the sentence.
        """
        spacy_doc = self._tokenizer.spacy.tokenizer.tokens_from_list(tokenized_sentence)
        for pipe in filter(None, self._tokenizer.spacy.pipeline):
            pipe[1](spacy_doc)

        tokens = [token for token in spacy_doc]
        instances = self.tokens_to_instances(tokens)

        if not instances:
            return sanitize({"verbs": [], "words": tokens})

        return self.predict_instances(instances)
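The SRL side can be exercised the same way as the coref predictor; a usage sketch with a placeholder archive path:

from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor

archive = load_archive('/path/to/srl/model.tar.gz')  # placeholder path
predictor = Predictor.from_archive(archive, 'semantic-role-labeling')

result = predictor.predict_tokenized(
        ["The", "cat", "sat", "on", "the", "mat", "."])
for verb in result["verbs"]:
    print(verb["verb"], "->", verb["description"])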

    @staticmethod
    def make_srl_string(words: List[str], tags: List[str]) -> str:
@@ -71,6 +96,17 @@ def make_srl_string(words: List[str], tags: List[str]) -> str:
    def _json_to_instance(self, json_dict: JsonDict):
        raise NotImplementedError("The SRL model uses a different API for creating instances.")

    def tokens_to_instances(self, tokens):
        words = [token.text for token in tokens]
        instances: List[Instance] = []
        for i, word in enumerate(tokens):
            if word.pos_ == "VERB":
                verb_labels = [0 for _ in words]
                verb_labels[i] = 1
                instance = self._dataset_reader.text_to_instance(tokens, verb_labels)
                instances.append(instance)
        return instances
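`tokens_to_instances` yields one instance per verb, each carrying a one-hot indicator over the tokens that marks the predicate. The indicator logic in isolation (the POS tags below are assumed for illustration, not computed):

def one_hot_verb_labels(pos_tags):
    # One indicator list per verb position, as tokens_to_instances builds them.
    return [[1 if j == i else 0 for j in range(len(pos_tags))]
            for i, pos in enumerate(pos_tags) if pos == "VERB"]

# "The cat sat and slept ."
print(one_hot_verb_labels(["DET", "NOUN", "VERB", "CCONJ", "VERB", "PUNCT"]))
# [[0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0]]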

    def _sentence_to_srl_instances(self, json_dict: JsonDict) -> List[Instance]:
        """
        The SRL model has a slightly different API from other models, as the model is run
@@ -92,15 +128,7 @@ def _sentence_to_srl_instances(self, json_dict: JsonDict) -> List[Instance]:
"""
sentence = json_dict["sentence"]
tokens = self._tokenizer.split_words(sentence)
words = [token.text for token in tokens]
instances: List[Instance] = []
for i, word in enumerate(tokens):
if word.pos_ == "VERB":
verb_labels = [0 for _ in words]
verb_labels[i] = 1
instance = self._dataset_reader.text_to_instance(tokens, verb_labels)
instances.append(instance)
return instances
return self.tokens_to_instances(tokens)

    @overrides
    def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:
@@ -178,6 +206,21 @@ def predict_batch_json(self, inputs: List[JsonDict]) -> List[JsonDict]:

        return sanitize(return_dicts)

    def predict_instances(self, instances: List[Instance]) -> JsonDict:
        outputs = self._model.forward_on_instances(instances)

        results = {"verbs": [], "words": outputs[0]["words"]}
        for output in outputs:
            tags = output['tags']
            description = self.make_srl_string(output["words"], tags)
            results["verbs"].append({
                "verb": output["verb"],
                "description": description,
                "tags": tags,
            })

        return sanitize(results)
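For reference, the dictionary shape `predict_instances` assembles looks like the following (values are illustrative, not real model output):

{
    "words": ["The", "squirrel", "wrote", "a", "test", "."],
    "verbs": [{
        "verb": "wrote",
        "description": "[ARG0: The squirrel] [V: wrote] [ARG1: a test] .",
        "tags": ["B-ARG0", "I-ARG0", "B-V", "B-ARG1", "I-ARG1", "O"],
    }],
}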

    @overrides
    def predict_json(self, inputs: JsonDict) -> JsonDict:
        """
@@ -198,16 +241,4 @@ def predict_json(self, inputs: JsonDict) -> JsonDict:
        if not instances:
            return sanitize({"verbs": [], "words": self._tokenizer.split_words(inputs["sentence"])})
        return self.predict_instances(instances)
11 changes: 10 additions & 1 deletion allennlp/tests/predictors/coref_test.py
@@ -13,12 +13,21 @@ def test_uses_named_inputs(self):
        predictor = Predictor.from_archive(archive, 'coreference-resolution')

        result = predictor.predict_json(inputs)
        self.assert_predict_result(result)

        document = ['This', 'is', 'a', 'single', 'string',
                    'document', 'about', 'a', 'test', '.', 'Sometimes',
                    'it', 'contains', 'coreferent', 'parts', '.']

        result_doc_words = predictor.predict_tokenized(document)
        self.assert_predict_result(result_doc_words)

    @staticmethod
    def assert_predict_result(result):
        document = result["document"]
        assert document == ['This', 'is', 'a', 'single', 'string',
                            'document', 'about', 'a', 'test', '.', 'Sometimes',
                            'it', 'contains', 'coreferent', 'parts', '.']

        clusters = result["clusters"]
        assert isinstance(clusters, list)
        for cluster in clusters:
15 changes: 11 additions & 4 deletions allennlp/tests/predictors/srl_test.py
@@ -13,21 +13,28 @@ def test_uses_named_inputs(self):
        archive = load_archive(self.FIXTURES_ROOT / 'srl' / 'serialization' / 'model.tar.gz')
        predictor = Predictor.from_archive(archive, 'semantic-role-labeling')

        result_json = predictor.predict_json(inputs)
        self.assert_predict_result(result_json)

        words = ["The", "squirrel", "wrote", "a", "unit", "test",
                 "to", "make", "sure", "its", "nuts", "worked", "as", "designed", "."]

        result_words = predictor.predict_tokenized(words)
        self.assert_predict_result(result_words)

    @staticmethod
    def assert_predict_result(result):
        words = result.get("words")
        assert words == ["The", "squirrel", "wrote", "a", "unit", "test",
                         "to", "make", "sure", "its", "nuts", "worked", "as", "designed", "."]

        num_words = len(words)
        verbs = result.get("verbs")
        assert verbs is not None
        assert isinstance(verbs, list)

        assert any(v["verb"] == "wrote" for v in verbs)
        assert any(v["verb"] == "make" for v in verbs)
        assert any(v["verb"] == "worked" for v in verbs)

        for verb in verbs:
            tags = verb.get("tags")
            assert tags is not None
