Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update retokenization tool to support RoBERTa; fix WSC #903

Merged
merged 33 commits into from
Sep 13, 2019
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
0cf7b23
Rename namespaces to suppress warnings.
sleepinyourhat Jul 12, 2019
38c5581
Revert "Rename namespaces to suppress warnings."
sleepinyourhat Jul 12, 2019
0c4546b
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Jul 15, 2019
4e2734b
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Jul 21, 2019
df3a271
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Jul 22, 2019
9c1ba46
Initial attempt.
sleepinyourhat Jul 24, 2019
57076f3
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Jul 24, 2019
0665933
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Aug 6, 2019
c6c30fa
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Aug 8, 2019
e41718d
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Aug 21, 2019
174e564
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Aug 25, 2019
9db0ddf
Merge branch 'master' of https://github.com/nyu-mll/jiant into nyu-ma…
sleepinyourhat Aug 26, 2019
a5638c2
Fix WSC retokenization.
sleepinyourhat Aug 27, 2019
c4e1fc3
Remove obnoxious newline.
sleepinyourhat Aug 27, 2019
e2121fa
Merge branch 'master' of https://github.com/nyu-mll/jiant into fix-re…
sleepinyourhat Aug 27, 2019
48defcb
fix retokenize
HaokunLiu Aug 27, 2019
9ab7b6d
debug
HaokunLiu Aug 27, 2019
dde2b9a
WiC fix
HaokunLiu Aug 28, 2019
afbed57
add spaces in docstring
HaokunLiu Aug 28, 2019
5b13e5d
update record task
HaokunLiu Aug 28, 2019
507f9d4
clean up
HaokunLiu Aug 28, 2019
cb11cb7
Merge branch 'master' into fix-retokenization
sleepinyourhat Aug 29, 2019
e863f65
"@placeholder" fix
HaokunLiu Sep 4, 2019
1894e10
max_seq_len fix
HaokunLiu Sep 4, 2019
cf422b3
black
HaokunLiu Sep 4, 2019
2e94cce
add docstring
HaokunLiu Sep 7, 2019
c28ce99
update docstring
HaokunLiu Sep 7, 2019
d8ae32b
add test script for retokenize
HaokunLiu Sep 7, 2019
7402be1
Revert "add test script for retokenize"
HaokunLiu Sep 7, 2019
59f9e50
Create test_retokenize.py
HaokunLiu Sep 7, 2019
4cdb3da
update to pytorch_transformer 1.2.0
HaokunLiu Sep 10, 2019
08b256a
package, download updates
HaokunLiu Sep 11, 2019
a8abb35
Merge branch 'master' into fix-retokenization
HaokunLiu Sep 11, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions jiant/tasks/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,14 +269,19 @@ def get_split_text(self, split: str):
def load_data_for_path(self, path, split):
""" Load data """

def tokenize_preserve_placeholder(sent):
def tokenize_preserve_placeholder(sent, max_ent_length):
""" Tokenize questions while preserving @placeholder token """
sent_parts = sent.split("@placeholder")
assert len(sent_parts) == 2
sent_parts = [
tokenize_and_truncate(self.tokenizer_name, s, self.max_seq_len) for s in sent_parts
]
return sent_parts[0] + ["@placeholder"] + sent_parts[1]
placeholder_loc = len(
tokenize_and_truncate(
self.tokenizer_name, sent_parts[0], self.max_seq_len - max_ent_length
)
)
sent_tok = tokenize_and_truncate(
self.tokenizer_name, sent, self.max_seq_len - max_ent_length
)
return sent_tok[:placeholder_loc] + ["@placeholder"] + sent_tok[placeholder_loc:]

examples = []
data = [json.loads(d) for d in open(path, encoding="utf-8")]
Expand All @@ -287,9 +292,10 @@ def tokenize_preserve_placeholder(sent):
)
ent_idxs = item["passage"]["entities"]
ents = [item["passage"]["text"][idx["start"] : idx["end"] + 1] for idx in ent_idxs]
max_ent_length = max([idx["end"] - idx["start"] + 1 for idx in ent_idxs])
qas = item["qas"]
for qa in qas:
qst = tokenize_preserve_placeholder(qa["query"])
qst = tokenize_preserve_placeholder(qa["query"], max_ent_length)
qst_id = qa["idx"]
if "answers" in qa:
anss = [a["text"] for a in qa["answers"]]
Expand Down
3 changes: 1 addition & 2 deletions jiant/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2601,9 +2601,8 @@ def _process_preserving_word(sent, word):
sequence the marked word is located. """
sent_parts = sent.split(word)
sent_tok1 = tokenize_and_truncate(self._tokenizer_name, sent_parts[0], self.max_seq_len)
sent_tok2 = tokenize_and_truncate(self._tokenizer_name, sent_parts[1], self.max_seq_len)
sent_mid = tokenize_and_truncate(self._tokenizer_name, word, self.max_seq_len)
sent_tok = sent_tok1 + sent_mid + sent_tok2
sent_tok = tokenize_and_truncate(self._tokenizer_name, sent, self.max_seq_len)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks reasonable. Have you confirmed that no other tasks use the old logic?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no more tasks using old logic now.

start_idx = len(sent_tok1)
end_idx = start_idx + len(sent_mid)
assert end_idx > start_idx, "Invalid marked word indices. Something is wrong."
Expand Down
77 changes: 62 additions & 15 deletions jiant/utils/retokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
# install with: pip install python-Levenshtein
from Levenshtein.StringMatcher import StringMatcher

from .tokenizers import get_tokenizer
from .utils import unescape_moses
from jiant.utils.tokenizers import get_tokenizer
from jiant.utils.utils import unescape_moses


# Tokenizer instance for internal use.
_SIMPLE_TOKENIZER = SpaceTokenizer()
Expand Down Expand Up @@ -98,7 +99,9 @@ def realign_spans(record, tokenizer_name):
"""
Builds the indices alignment while also tokenizing the input
piece by piece.
Only BERT/XLNet and Moses tokenization is supported currently.
Currently, SentencePiece (for XLNet), WPM (for BERT), BPE (for GPT/XLM),
ByteBPE (for RoBERTa/GPT-2) and Moses (for Transformer-XL and default) tokenization are
supported.

Parameters
-----------------------
Expand Down Expand Up @@ -294,6 +297,22 @@ def process_wordpiece_for_alignment(t):
return "<w>" + t


def process_sentencepiece_for_alignment(t):
"""Add <w> markers to ensure word-boundary alignment."""
if t.startswith("▁"):
return "<w>" + re.sub(r"^▁", "", t)
else:
return t


def process_bytebpe_for_alignment(t):
"""Add <w> markers to ensure word-boundary alignment."""
if t.startswith("▁"):
return "<w>" + re.sub(r"^Ġ", "", t)
else:
return t


def space_tokenize_with_bow(sentence):
"""Add <w> markers to ensure word-boundary alignment."""
return ["<w>" + t for t in sentence.split()]
Expand All @@ -307,14 +326,6 @@ def align_moses(text: Text) -> Tuple[TokenAligner, List[Text]]:
return ta, moses_tokens


def align_openai(text: Text) -> Tuple[TokenAligner, List[Text]]:
eow_tokens = space_tokenize_with_eow(text)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is slightly different from the new version, because it added end-of-word markers (as used in the original GPT) instead of beginning-of-word markers to ensure correct character overlap. We should probably add something to align_wpm to detect this and do the appropriate padding?

Otherwise, it might be safer to just replace align_wpm with a much simpler implementation than the black-box character-based alignment we have now - just split into the original tokens, then split each one into wordpieces while tracking offsets. This is recommended for BERT (https://github.com/google-research/bert#tokenization), but not sure if it's compatible with SentencePiece or other subword models.

openai_utils = get_tokenizer("OpenAI.BPE")
bpe_tokens = openai_utils.tokenize(text)
ta = TokenAligner(eow_tokens, bpe_tokens)
return ta, bpe_tokens


def align_wpm(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
# If using lowercase, do this for the source tokens for better matching.
do_lower_case = tokenizer_name.endswith("uncased")
Expand All @@ -328,12 +339,48 @@ def align_wpm(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]
return ta, wpm_tokens


def align_sentencepiece(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
bow_tokens = space_tokenize_with_bow(text)
sentencepiece_tokenizer = get_tokenizer(tokenizer_name)
sentencepiece_tokens = sentencepiece_tokenizer.tokenize(text)

modified_sentencepiece_tokens = list(
map(process_sentencepiece_for_alignment, sentencepiece_tokens)
)
ta = TokenAligner(bow_tokens, modified_sentencepiece_tokens)
return ta, sentencepiece_tokens


def align_bpe(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add function docstrings giving examples and stating which models these apply to?

Since these are fairly dense and involve a number of edge cases, should we write some unit tests for this module?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am afraid it will not be very economical to do this.
The only occasions we need to modify this, is either adding new tokenizer or refactoring retokenize. When the first happens, the old test cases can always pass, and we always need to add new test cases (which has little reuse value). When the second happens, I don't expect we will keep the same interface.
I wonder what do you think about it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it'll be hard - this module has a very minimal API so it should be easy to write tests. We are refactoring and adding new tokenizers here, and we'd want tests both to ensure there aren't regressions and that the new functionality is doing what we expect.

(This is partly my fault for not having any on the original version, but since it's grown quite a bit in scope and tokenization bugs can be very hard to detect otherwise, I think it's warranted now.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. I'll do it later today.

eow_tokens = space_tokenize_with_eow(text.lower())
bpe_tokenizer = get_tokenizer(tokenizer_name)
bpe_tokens = bpe_tokenizer.tokenize(text)
ta = TokenAligner(eow_tokens, bpe_tokens)
return ta, bpe_tokens


def align_bytebpe(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
bow_tokens = space_tokenize_with_bow(text)
bytebpe_tokenizer = get_tokenizer(tokenizer_name)
bytebpe_tokens = bytebpe_tokenizer.tokenize(text)

modified_bytebpe_tokens = list(map(process_bytebpe_for_alignment, bytebpe_tokens))
if len(modified_bytebpe_tokens) > 0:
modified_bytebpe_tokens[0] = "<w>" + modified_bytebpe_tokens[0]
ta = TokenAligner(bow_tokens, modified_bytebpe_tokens)
return ta, bytebpe_tokens


def get_aligner_fn(tokenizer_name: Text):
if tokenizer_name == "MosesTokenizer":
if tokenizer_name == "MosesTokenizer" or tokenizer_name.startswith("transfo-xl-"):
return align_moses
elif tokenizer_name == "OpenAI.BPE":
return align_openai
elif tokenizer_name.startswith("bert-") or tokenizer_name.startswith("xlnet-"):
elif tokenizer_name.startswith("bert-"):
return functools.partial(align_wpm, tokenizer_name=tokenizer_name)
elif tokenizer_name.startswith("openai-gpt") or tokenizer_name.startswith("xlm-mlm-en-"):
return functools.partial(align_bpe, tokenizer_name=tokenizer_name)
elif tokenizer_name.startswith("xlnet-"):
return functools.partial(align_sentencepiece, tokenizer_name=tokenizer_name)
elif tokenizer_name.startswith("roberta-") or tokenizer_name.startswith("gpt2"):
return functools.partial(align_bytebpe, tokenizer_name=tokenizer_name)
else:
raise ValueError(f"Unsupported tokenizer '{tokenizer_name}'")