Update retokenization tool to support RoBERTa; fix WSC (#903)
* Rename namespaces to suppress warnings.

* Revert "Rename namespaces to suppress warnings."

This reverts commit 0cf7b23.

* Initial attempt.

* Fix WSC retokenization.

* Remove obnoxious newline.

* fix retokenize

* debug

* WiC fix

* add spaces in docstring

* update record task

* clean up

* "@Placeholder" fix

* max_seq_len fix

* black

* add docstring

* update docstring

* add test script for retokenize

* Revert "add test script for retokenize"

* Create test_retokenize.py

* update to pytorch_transformer 1.2.0

* package, download updates
sleepinyourhat authored Sep 13, 2019
1 parent a59184d commit 8e51a95
Showing 7 changed files with 442 additions and 23 deletions.
1 change: 1 addition & 0 deletions .circleci/config.yml
@@ -48,6 +48,7 @@
command: |
source venv/bin/activate
python -m nltk.downloader perluniprops nonbreaking_prefixes punkt
python -m spacy download en
python main.py --config_file jiant/config/demo.conf
python main.py --config_file jiant/config/demo.conf --overrides "do_pretrain = 0, do_target_task_training = 0, load_model = 1"
# Step 8: run tests
1 change: 1 addition & 0 deletions environment.yml
@@ -32,6 +32,7 @@ dependencies:

# for some tokenizers in pytorch-transformers
- spacy==2.1
- ftfy

# for tokenization
- nltk==3.4.5
18 changes: 12 additions & 6 deletions jiant/tasks/qa.py
@@ -269,14 +269,19 @@ def get_split_text(self, split: str):
def load_data_for_path(self, path, split):
""" Load data """

def tokenize_preserve_placeholder(sent):
def tokenize_preserve_placeholder(sent, max_ent_length):
""" Tokenize questions while preserving @placeholder token """
sent_parts = sent.split("@placeholder")
assert len(sent_parts) == 2
sent_parts = [
tokenize_and_truncate(self.tokenizer_name, s, self.max_seq_len) for s in sent_parts
]
return sent_parts[0] + ["@placeholder"] + sent_parts[1]
placeholder_loc = len(
tokenize_and_truncate(
self.tokenizer_name, sent_parts[0], self.max_seq_len - max_ent_length
)
)
sent_tok = tokenize_and_truncate(
self.tokenizer_name, sent, self.max_seq_len - max_ent_length
)
return sent_tok[:placeholder_loc] + ["@placeholder"] + sent_tok[placeholder_loc:]

examples = []
data = [json.loads(d) for d in open(path, encoding="utf-8")]
@@ -287,9 +292,10 @@ def tokenize_preserve_placeholder(sent):
)
ent_idxs = item["passage"]["entities"]
ents = [item["passage"]["text"][idx["start"] : idx["end"] + 1] for idx in ent_idxs]
max_ent_length = max([idx["end"] - idx["start"] + 1 for idx in ent_idxs])
qas = item["qas"]
for qa in qas:
qst = tokenize_preserve_placeholder(qa["query"])
qst = tokenize_preserve_placeholder(qa["query"], max_ent_length)
qst_id = qa["idx"]
if "answers" in qa:
anss = [a["text"] for a in qa["answers"]]
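For reference, here is a minimal standalone sketch of the placeholder-preserving logic introduced above: the query prefix before "@placeholder" is tokenized to find the splice index, the full query is tokenized once, and a literal "@placeholder" marker is inserted at that index. toy_tokenize is a hypothetical stand-in for jiant's tokenize_and_truncate; the budget arithmetic (max_seq_len - max_ent_length) follows the diff, everything else is illustrative.

from typing import List

def toy_tokenize(text: str, max_len: int) -> List[str]:
    # Hypothetical stand-in for tokenize_and_truncate: fake "subword" split into
    # 4-character chunks, then truncation to max_len pieces.
    pieces = [w[i:i + 4] for w in text.split() for i in range(0, len(w), 4)]
    return pieces[:max_len]

def tokenize_preserve_placeholder(sent: str, max_ent_length: int, max_seq_len: int = 32) -> List[str]:
    """Tokenize a ReCoRD query while keeping a literal "@placeholder" marker."""
    sent_parts = sent.split("@placeholder")
    assert len(sent_parts) == 2
    budget = max_seq_len - max_ent_length                      # reserve room for the longest entity
    placeholder_loc = len(toy_tokenize(sent_parts[0], budget))  # index where the marker belongs
    sent_tok = toy_tokenize(sent, budget)                      # single tokenization of the full query
    return sent_tok[:placeholder_loc] + ["@placeholder"] + sent_tok[placeholder_loc:]

print(tokenize_preserve_placeholder("The capital of France is @placeholder today", max_ent_length=3))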
3 changes: 1 addition & 2 deletions jiant/tasks/tasks.py
@@ -2601,9 +2601,8 @@ def _process_preserving_word(sent, word):
sequence the marked word is located. """
sent_parts = sent.split(word)
sent_tok1 = tokenize_and_truncate(self._tokenizer_name, sent_parts[0], self.max_seq_len)
sent_tok2 = tokenize_and_truncate(self._tokenizer_name, sent_parts[1], self.max_seq_len)
sent_mid = tokenize_and_truncate(self._tokenizer_name, word, self.max_seq_len)
sent_tok = sent_tok1 + sent_mid + sent_tok2
sent_tok = tokenize_and_truncate(self._tokenizer_name, sent, self.max_seq_len)
start_idx = len(sent_tok1)
end_idx = start_idx + len(sent_mid)
assert end_idx > start_idx, "Invalid marked word indices. Something is wrong."
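The WiC/WSC fix above follows the same pattern: instead of tokenizing the two halves and the marked word separately and concatenating them, the full sentence is tokenized once and the word's token span is located via the prefix, likely because a piecewise tokenization can diverge from tokenizing the whole sentence (notably for space-sensitive byte-level BPE). A standalone sketch of that logic, with the hypothetical toy_tokenize again standing in for tokenize_and_truncate:

from typing import List, Tuple

def toy_tokenize(text: str) -> List[str]:
    # Hypothetical stand-in "subword" tokenizer: whitespace split, then 4-character chunks.
    return [w[i:i + 4] for w in text.split() for i in range(0, len(w), 4)]

def process_preserving_word(sent: str, word: str) -> Tuple[List[str], int, int]:
    """Tokenize sent once and report the [start_idx, end_idx) token span covering word."""
    sent_parts = sent.split(word)
    start_idx = len(toy_tokenize(sent_parts[0]))   # tokens contributed by the prefix
    end_idx = start_idx + len(toy_tokenize(word))  # plus the tokens of the marked word itself
    sent_tok = toy_tokenize(sent)                  # single pass over the full sentence
    assert end_idx > start_idx, "Invalid marked word indices. Something is wrong."
    return sent_tok, start_idx, end_idx

print(process_preserving_word("The committee will table the motion", "table"))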
91 changes: 76 additions & 15 deletions jiant/utils/retokenize.py
@@ -19,8 +19,9 @@
# install with: pip install python-Levenshtein
from Levenshtein.StringMatcher import StringMatcher

from .tokenizers import get_tokenizer
from .utils import unescape_moses
from jiant.utils.tokenizers import get_tokenizer
from jiant.utils.utils import unescape_moses


# Tokenizer instance for internal use.
_SIMPLE_TOKENIZER = SpaceTokenizer()
@@ -98,7 +99,9 @@ def realign_spans(record, tokenizer_name):
"""
Builds the indices alignment while also tokenizing the input
piece by piece.
Only BERT/XLNet and Moses tokenization is supported currently.
Currently, SentencePiece (for XLNet), WPM (for BERT), BPE (for GPT/XLM),
ByteBPE (for RoBERTa/GPT-2) and Moses (for Transformer-XL and default) tokenizations are
supported.
Parameters
-----------------------
@@ -294,28 +297,40 @@ def process_wordpiece_for_alignment(t):
return "<w>" + t


def process_sentencepiece_for_alignment(t):
"""Add <w> markers to ensure word-boundary alignment."""
if t.startswith("▁"):
return "<w>" + re.sub(r"^▁", "", t)
else:
return t


def process_bytebpe_for_alignment(t):
    """Add <w> markers to ensure word-boundary alignment."""
    if t.startswith("Ġ"):
        return "<w>" + re.sub(r"^Ġ", "", t)
    else:
        return t


def space_tokenize_with_bow(sentence):
"""Add <w> markers to ensure word-boundary alignment."""
return ["<w>" + t for t in sentence.split()]


def align_moses(text: Text) -> Tuple[TokenAligner, List[Text]]:
"""Aligner fn for Moses tokenizer, used in Transformer-XL
"""
MosesTokenizer = get_tokenizer("MosesTokenizer")
moses_tokens = MosesTokenizer.tokenize(text)
cleaned_moses_tokens = unescape_moses(moses_tokens)
ta = TokenAligner(text, cleaned_moses_tokens)
return ta, moses_tokens


def align_openai(text: Text) -> Tuple[TokenAligner, List[Text]]:
eow_tokens = space_tokenize_with_eow(text)
openai_utils = get_tokenizer("OpenAI.BPE")
bpe_tokens = openai_utils.tokenize(text)
ta = TokenAligner(eow_tokens, bpe_tokens)
return ta, bpe_tokens


def align_wpm(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
"""Alignment fn for WPM tokenizer, used in BERT
"""
# If using lowercase, do this for the source tokens for better matching.
do_lower_case = tokenizer_name.endswith("uncased")
bow_tokens = space_tokenize_with_bow(text.lower() if do_lower_case else text)
@@ -328,12 +343,58 @@ def align_wpm(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
return ta, wpm_tokens


def align_sentencepiece(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
"""Alignment fn for SentencePiece Tokenizer, used in XLNET
"""
bow_tokens = space_tokenize_with_bow(text)
sentencepiece_tokenizer = get_tokenizer(tokenizer_name)
sentencepiece_tokens = sentencepiece_tokenizer.tokenize(text)

modified_sentencepiece_tokens = list(
map(process_sentencepiece_for_alignment, sentencepiece_tokens)
)
ta = TokenAligner(bow_tokens, modified_sentencepiece_tokens)
return ta, sentencepiece_tokens


def align_bpe(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
"""Alignment fn for BPE tokenizer, used in GPT and XLM
"""
eow_tokens = space_tokenize_with_eow(text.lower())
bpe_tokenizer = get_tokenizer(tokenizer_name)
bpe_tokens = bpe_tokenizer.tokenize(text)
ta = TokenAligner(eow_tokens, bpe_tokens)
return ta, bpe_tokens


def align_bytebpe(text: Text, tokenizer_name: str) -> Tuple[TokenAligner, List[Text]]:
"""Alignment fn for Byte-level BPE tokenizer, used in GPT-2 and RoBERTa
"""
bow_tokens = space_tokenize_with_bow(text)
bytebpe_tokenizer = get_tokenizer(tokenizer_name)
bytebpe_tokens = bytebpe_tokenizer.tokenize(text)

modified_bytebpe_tokens = list(map(process_bytebpe_for_alignment, bytebpe_tokens))
ta = TokenAligner(bow_tokens, modified_bytebpe_tokens)
return ta, bytebpe_tokens


def get_aligner_fn(tokenizer_name: Text):
if tokenizer_name == "MosesTokenizer":
"""Given the tokenzier_name, return the corresponding alignment function.
An alignment function modified the tokenized input to make it close to source token,
and choose a space tokenizer with its word-boundary at the same side as tokenizer_name,
hence the source (from space tokenizer) and target (from tokenizer_name) is sufficiently close.
Use TokenAligner to project token index from source to target.
"""
if tokenizer_name == "MosesTokenizer" or tokenizer_name.startswith("transfo-xl-"):
return align_moses
elif tokenizer_name == "OpenAI.BPE":
return align_openai
elif tokenizer_name.startswith("bert-") or tokenizer_name.startswith("xlnet-"):
elif tokenizer_name.startswith("bert-"):
return functools.partial(align_wpm, tokenizer_name=tokenizer_name)
elif tokenizer_name.startswith("openai-gpt") or tokenizer_name.startswith("xlm-mlm-en-"):
return functools.partial(align_bpe, tokenizer_name=tokenizer_name)
elif tokenizer_name.startswith("xlnet-"):
return functools.partial(align_sentencepiece, tokenizer_name=tokenizer_name)
elif tokenizer_name.startswith("roberta-") or tokenizer_name.startswith("gpt2"):
return functools.partial(align_bytebpe, tokenizer_name=tokenizer_name)
else:
raise ValueError(f"Unsupported tokenizer '{tokenizer_name}'")
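For context, a usage sketch of the dispatcher above. It assumes a jiant environment with pytorch-transformers 1.2.0 installed so the roberta-base vocabulary can be loaded; the TokenAligner method name project_span is an assumption based on how the aligner is used elsewhere in jiant, not something shown in this diff.

from jiant.utils.retokenize import get_aligner_fn

text = "Members of the House clapped their hands"
aligner_fn = get_aligner_fn("roberta-base")   # dispatches to align_bytebpe for RoBERTa
ta, bpe_tokens = aligner_fn(text)             # TokenAligner plus the byte-level BPE tokens

# Project the span over space-separated tokens 3..4 ("House clapped") onto BPE indices
# (assumed signature: project_span(start, end) -> (start, end) in target token space).
start, end = ta.project_span(3, 5)
print(bpe_tokens[start:end])

The design choice worth noting: each aligner pairs the model tokenizer with a space tokenizer whose word-boundary markers fall on the same side (<w> markers for the BOW-style WPM, SentencePiece, and ByteBPE tokenizers; end-of-word markers for the BPE used by GPT/XLM), which is what keeps the source and target token sequences close enough for TokenAligner's character-level matching to work.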
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
"python-Levenshtein==0.12.0",
"sacremoses",
"pytorch-transformers==1.2.0",
"ftfy",
"spacy",
],
use_scm_version=True,
setup_requires=["setuptools_scm"],