Update retokenization tool to support RoBERTa; fix WSC #903
Changes from 18 commits
```diff
@@ -19,8 +19,9 @@
 # install with: pip install python-Levenshtein
 from Levenshtein.StringMatcher import StringMatcher

-from .tokenizers import get_tokenizer
-from .utils import unescape_moses
+from jiant.utils.tokenizers import get_tokenizer
+from jiant.utils.utils import unescape_moses
+

 # Tokenizer instance for internal use.
 _SIMPLE_TOKENIZER = SpaceTokenizer()
```
```diff
@@ -98,7 +99,8 @@ def realign_spans(record, tokenizer_name):
     """
     Builds the indices alignment while also tokenizing the input
     piece by piece.
-    Only BERT/XLNet and Moses tokenization is supported currently.
+    Currently, SentencePiece (for XLNet), WPM (for BERT), BPE (for GPT/XLM), ByteBPE (for RoBERTa/GPT-2),
+    and Moses (for Transformer-XL and default) tokenization are supported.

     Parameters
     -----------------------
```
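As background for the docstring above: the four subword schemes mark word boundaries with different sentinel strings. The splits below are illustrative only, since the exact pieces depend on each pretrained vocabulary:

```python
# Illustrative token shapes only; exact splits depend on each model's vocabulary.
boundary_examples = {
    "WPM (BERT)": ["Member", "##s", "of", "the", "House"],               # "##" marks a continuation piece
    "SentencePiece (XLNet)": ["▁Members", "▁of", "▁the", "▁House"],      # "▁" marks a word start
    "ByteBPE (RoBERTa/GPT-2)": ["Members", "Ġof", "Ġthe", "ĠHouse"],     # "Ġ" encodes a leading space
    "BPE (GPT/XLM)": ["members</w>", "of</w>", "the</w>", "house</w>"],  # "</w>" marks a word end
}
```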
```diff
@@ -294,6 +296,22 @@ def process_wordpiece_for_alignment(t):
     return "<w>" + t


+def process_sentencepiece_for_alignment(t):
+    """Add <w> markers to ensure word-boundary alignment."""
+    if t.startswith("▁"):
+        return "<w>" + re.sub(r"^▁", "", t)
+    else:
+        return t
+
+
+def process_bytebpe_for_alignment(t):
+    """Add <w> markers to ensure word-boundary alignment."""
+    if t.startswith("Ġ"):
+        return "<w>" + re.sub(r"^Ġ", "", t)
+    else:
+        return t
+
+
 def space_tokenize_with_bow(sentence):
     """Add <w> markers to ensure word-boundary alignment."""
     return ["<w>" + t for t in sentence.split()]
```
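For context, the two new helpers normalize the tokenizer-specific boundary markers into the `<w>` convention used on the space-tokenized side. A few hypothetical inputs and the outputs they would produce with the definitions above:

```python
process_sentencepiece_for_alignment("▁Members")  # -> "<w>Members"
process_sentencepiece_for_alignment("bers")      # -> "bers"     (continuation piece, left untouched)
process_bytebpe_for_alignment("Ġof")             # -> "<w>of"
process_bytebpe_for_alignment("Members")         # -> "Members"  (no leading-space marker)
```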
```diff
@@ -307,14 +325,6 @@ def align_moses(text: Text) -> Tuple[TokenAligner, List[Text]]:
     return ta, moses_tokens


-def align_openai(text: Text) -> Tuple[TokenAligner, List[Text]]:
-    eow_tokens = space_tokenize_with_eow(text)
-    openai_utils = get_tokenizer("OpenAI.BPE")
-    bpe_tokens = openai_utils.tokenize(text)
-    ta = TokenAligner(eow_tokens, bpe_tokens)
-    return ta, bpe_tokens
-
-
 def align_wpm(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
     # If using lowercase, do this for the source tokens for better matching.
     do_lower_case = tokenizer_name.endswith("uncased")
```

Reviewer comment on the removed `align_openai`: This is slightly different from the new version, because it added end-of-word markers (as used in the original GPT) instead of beginning-of-word markers to ensure correct character overlap. We should probably add something to [...]. Otherwise, it might be safer to just replace [...].
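To make the eow/bow distinction in the comment above concrete: the two space-tokenization helpers differ only in where they attach the marker. `space_tokenize_with_bow` appears earlier in this diff; `space_tokenize_with_eow` is not shown here, so its output below is an assumption based on the naming:

```python
space_tokenize_with_bow("Members of the House")
# -> ["<w>Members", "<w>of", "<w>the", "<w>House"]

space_tokenize_with_eow("Members of the House")
# -> ["Members</w>", "of</w>", "the</w>", "House</w>"]   (assumed; definition not in this diff)
```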
```diff
@@ -328,12 +338,48 @@ def align_wpm(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
     return ta, wpm_tokens


+def align_sentencepiece(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
+    bow_tokens = space_tokenize_with_bow(text)
+    sentencepiece_tokenizer = get_tokenizer(tokenizer_name)
+    sentencepiece_tokens = sentencepiece_tokenizer.tokenize(text)
+
+    modified_sentencepiece_tokens = list(
+        map(process_sentencepiece_for_alignment, sentencepiece_tokens)
+    )
+    ta = TokenAligner(bow_tokens, modified_sentencepiece_tokens)
+    return ta, sentencepiece_tokens
+
+
```
```diff
+def align_bpe(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
+    eow_tokens = space_tokenize_with_eow(text.lower())
+    bpe_tokenizer = get_tokenizer(tokenizer_name)
+    bpe_tokens = bpe_tokenizer.tokenize(text)
+    ta = TokenAligner(eow_tokens, bpe_tokens)
+    return ta, bpe_tokens
+
+
+def align_bytebpe(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
+    bow_tokens = space_tokenize_with_bow(text)
+    bytebpe_tokenizer = get_tokenizer(tokenizer_name)
+    bytebpe_tokens = bytebpe_tokenizer.tokenize(text)
+
+    modified_bytebpe_tokens = list(map(process_bytebpe_for_alignment, bytebpe_tokens))
+    if len(modified_bytebpe_tokens) > 0:
+        modified_bytebpe_tokens[0] = "<w>" + modified_bytebpe_tokens[0]
+    ta = TokenAligner(bow_tokens, modified_bytebpe_tokens)
+    return ta, bytebpe_tokens
+
+
```

Review thread on these new aligner functions:

Reviewer: Add function docstrings giving examples and stating which models these apply to? Since these are fairly dense and involve a number of edge cases, should we write some unit tests for this module?

Author: I am afraid it will not be very economical to do this.

Reviewer: I don't think it'll be hard - this module has a very minimal API, so it should be easy to write tests. We are refactoring and adding new tokenizers here, and we'd want tests both to ensure there aren't regressions and that the new functionality is doing what we expect. (This is partly my fault for not having any on the original version, but since it's grown quite a bit in scope and tokenization bugs can be very hard to detect otherwise, I think it's warranted now.)

Author: I see. I'll do it later today.
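Along the lines of the testing request in the thread above, here is a minimal pytest sketch for the pure helper functions. The module path `jiant.utils.retokenize` is assumed, since the file path is not visible in this excerpt:

```python
# Hypothetical test sketch; assumes the edited module is importable as jiant.utils.retokenize.
from jiant.utils.retokenize import (
    process_bytebpe_for_alignment,
    process_sentencepiece_for_alignment,
    space_tokenize_with_bow,
)


def test_sentencepiece_marker_rewrite():
    # Word-initial "▁" becomes the generic "<w>" marker; continuation pieces pass through.
    assert process_sentencepiece_for_alignment("▁Members") == "<w>Members"
    assert process_sentencepiece_for_alignment("bers") == "bers"


def test_bytebpe_marker_rewrite():
    # "Ġ" (leading space in ByteBPE) becomes "<w>"; other pieces pass through.
    assert process_bytebpe_for_alignment("Ġof") == "<w>of"
    assert process_bytebpe_for_alignment("Members") == "Members"


def test_space_tokenize_with_bow():
    assert space_tokenize_with_bow("a b c") == ["<w>a", "<w>b", "<w>c"]
```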
```diff
 def get_aligner_fn(tokenizer_name: Text):
-    if tokenizer_name == "MosesTokenizer":
+    if tokenizer_name == "MosesTokenizer" or tokenizer_name.startswith("transfo-xl-"):
         return align_moses
-    elif tokenizer_name == "OpenAI.BPE":
-        return align_openai
-    elif tokenizer_name.startswith("bert-") or tokenizer_name.startswith("xlnet-"):
+    elif tokenizer_name.startswith("bert-"):
         return functools.partial(align_wpm, tokenizer_name=tokenizer_name)
+    elif tokenizer_name.startswith("openai-gpt") or tokenizer_name.startswith("xlm-mlm-en-"):
+        return functools.partial(align_bpe, tokenizer_name=tokenizer_name)
+    elif tokenizer_name.startswith("xlnet-"):
+        return functools.partial(align_sentencepiece, tokenizer_name=tokenizer_name)
+    elif tokenizer_name.startswith("roberta-") or tokenizer_name.startswith("gpt2"):
+        return functools.partial(align_bytebpe, tokenizer_name=tokenizer_name)
     else:
         raise ValueError(f"Unsupported tokenizer '{tokenizer_name}'")
```
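A brief usage sketch of the new dispatch; "roberta-base" is just an example model name, and only the return values visible in this diff are relied on:

```python
aligner_fn = get_aligner_fn("roberta-base")  # dispatches to align_bytebpe
ta, target_tokens = aligner_fn("Members of the House clapped their hands")
# target_tokens is the raw RoBERTa tokenization of the sentence;
# ta is a TokenAligner mapping space-token indices to subword indices,
# e.g. for projecting task spans (such as WSC targets) onto the new tokenization.
```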
Reviewer: Looks reasonable. Have you confirmed that no other tasks use the old logic?

Author: No more tasks use the old logic now.