[RoFormer] Fix some issues #12397
@@ -38,7 +38,7 @@ def postprocess_qa_predictions(
    null_score_diff_threshold: float = 0.0,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
-   is_world_process_zero: bool = True,
+   log_level: Optional[int] = logging.WARNING,
):
    """
    Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
@@ -70,8 +70,8 @@ def postprocess_qa_predictions(
            answers, are saved in `output_dir`.
        prefix (:obj:`str`, `optional`):
            If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
-       is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
-           Whether this process is the main process or not (used to determine if logging/saves should be done).
+       log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
+           ``logging`` log level (e.g., ``logging.WARNING``)

Review comment: Same comment here.

    """
    assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
    all_start_logits, all_end_logits = predictions
@@ -91,7 +91,7 @@ def postprocess_qa_predictions(
    scores_diff_json = collections.OrderedDict()

    # Logging.
-   logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
+   logger.setLevel(log_level)

Review comment: Same comment here.

    logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
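To make the intent of this change concrete, here is a small, self-contained sketch (not taken from the PR) of how the old boolean flag maps onto the new `log_level` argument; the helper name below is invented for illustration only.

```python
import logging

logger = logging.getLogger(__name__)
logging.basicConfig()


def set_postprocessing_verbosity(is_world_process_zero: bool) -> int:
    # Equivalent of the old behaviour, expressed via the new parameter:
    # the main process logs at INFO, every other process at WARNING.
    log_level = logging.INFO if is_world_process_zero else logging.WARNING
    logger.setLevel(log_level)
    return log_level


assert set_postprocessing_verbosity(True) == logging.INFO
assert set_postprocessing_verbosity(False) == logging.WARNING
```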
@@ -250,7 +250,7 @@ def postprocess_qa_predictions_with_beam_search(
    end_n_top: int = 5,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
-   is_world_process_zero: bool = True,
+   log_level: Optional[int] = logging.WARNING,

Review comment: Same comment here.

):
    """
    Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
@@ -280,8 +280,8 @@ def postprocess_qa_predictions_with_beam_search(
            answers, are saved in `output_dir`.
        prefix (:obj:`str`, `optional`):
            If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
-       is_world_process_zero (:obj:`bool`, `optional`, defaults to :obj:`True`):
-           Whether this process is the main process or not (used to determine if logging/saves should be done).
+       log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
+           ``logging`` log level (e.g., ``logging.WARNING``)

Review comment: Same comment here.

    """
    assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
    start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
@@ -302,7 +302,7 @@ def postprocess_qa_predictions_with_beam_search(
    scores_diff_json = collections.OrderedDict() if version_2_with_negative else None

    # Logging.
-   logger.setLevel(logging.INFO if is_world_process_zero else logging.WARN)
+   logger.setLevel(log_level)

Review comment: Same comment here.

    logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
@@ -413,14 +413,14 @@ def postprocess_qa_predictions_with_beam_search(
            output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
        )

-       print(f"Saving predictions to {prediction_file}.")
+       logger.info(f"Saving predictions to {prediction_file}.")
        with open(prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=4) + "\n")
-       print(f"Saving nbest_preds to {nbest_file}.")
+       logger.info(f"Saving nbest_preds to {nbest_file}.")
        with open(nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
        if version_2_with_negative:
-           print(f"Saving null_odds to {null_odds_file}.")
+           logger.info(f"Saving null_odds to {null_odds_file}.")

Review comment: Same comment here.

            with open(null_odds_file, "w") as writer:
                writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
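As a side note on the `print` to `logger.info` swap above, a minimal sketch (an assumption for illustration, not PR code) of why it matters: `logger.info` respects the level set through `logger.setLevel`, while `print` always writes to stdout.

```python
import logging

logging.basicConfig()
logger = logging.getLogger("qa_postprocessing_demo")

logger.setLevel(logging.WARNING)  # e.g. a non-main process
logger.info("Saving predictions to predictions.json.")  # suppressed at this level
print("Saving predictions to predictions.json.")        # always printed, regardless of level
```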
@@ -31,21 +31,36 @@
    "vocab_file": {
        "junnyu/roformer_chinese_small": "https://huggingface.co/junnyu/roformer_chinese_small/resolve/main/vocab.txt",
        "junnyu/roformer_chinese_base": "https://huggingface.co/junnyu/roformer_chinese_base/resolve/main/vocab.txt",
+       "junnyu/roformer_chinese_char_small": "https://huggingface.co/junnyu/roformer_chinese_char_small/resolve/main/vocab.txt",
+       "junnyu/roformer_chinese_char_base": "https://huggingface.co/junnyu/roformer_chinese_char_base/resolve/main/vocab.txt",
+       "junnyu/roformer_small_discriminator": "https://huggingface.co/junnyu/roformer_small_discriminator/resolve/main/vocab.txt",
+       "junnyu/roformer_small_generator": "https://huggingface.co/junnyu/roformer_small_generator/resolve/main/vocab.txt",
    }
}

-PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"junnyu/roformer_chinese_small": 1536, "junnyu/roformer_chinese_base": 1536}
+PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
+    "junnyu/roformer_chinese_small": 1536,
+    "junnyu/roformer_chinese_base": 1536,
+    "junnyu/roformer_chinese_char_small": 512,
+    "junnyu/roformer_chinese_char_base": 512,
+    "junnyu/roformer_small_discriminator": 128,
+    "junnyu/roformer_small_generator": 128,
+}
PRETRAINED_INIT_CONFIGURATION = {
    "junnyu/roformer_chinese_small": {"do_lower_case": True},
    "junnyu/roformer_chinese_base": {"do_lower_case": True},
+   "junnyu/roformer_chinese_char_small": {"do_lower_case": True},
+   "junnyu/roformer_chinese_char_base": {"do_lower_case": True},
+   "junnyu/roformer_small_discriminator": {"do_lower_case": True},
+   "junnyu/roformer_small_generator": {"do_lower_case": True},
}
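For reference, a hedged usage sketch of the newly registered checkpoints; it assumes the junnyu/* repositories listed above are available on the Hugging Face Hub and that `jieba` is installed.

```python
# Hedged sketch: load one of the newly listed RoFormer checkpoints by name.
from transformers import RoFormerTokenizer

tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_char_base")
print(tokenizer.tokenize("今天天气非常好"))
```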
class RoFormerTokenizer(PreTrainedTokenizer):
    r"""
-   Construct a RoFormer tokenizer. Based on `Rust Jieba <https://pypi.org/project/rjieba/>`.
+   Construct a RoFormer tokenizer. Based on `Jieba <https://pypi.org/project/jieba/>`.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.
@@ -143,13 +158,13 @@ def __init__( | |
) | ||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) | ||
try: | ||
import rjieba | ||
import jieba | ||
except ImportError: | ||
raise ImportError( | ||
"You need to install rjieba to use RoFormerTokenizer." | ||
"See https://pypi.org/project/rjieba/ for installation." | ||
"You need to install jieba to use RoFormerTokenizer." | ||
"See https://pypi.org/project/jieba/ for installation." | ||
) | ||
self.jieba = rjieba | ||
self.jieba = jieba | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the correct way of handling the jieba dependency @LysandreJik? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We decided to handle it this way as it's the only model that requires it - and if other models arrive, then to upstream it like it is done for the other models. |
||
|
||
@property | ||
def do_lower_case(self): | ||
|
@@ -167,21 +182,21 @@ def __getstate__(self):
    def __setstate__(self, d):
        self.__dict__ = d
        try:
-           import rjieba
+           import jieba
        except ImportError:
            raise ImportError(
-               "You need to install rjieba to use RoFormerTokenizer."
-               "See https://pypi.org/project/rjieba/ for installation."
+               "You need to install jieba to use RoFormerTokenizer."
+               "See https://pypi.org/project/jieba/ for installation."
            )
-       self.jieba = rjieba
+       self.jieba = jieba

Review comment: If you already have the […]

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)
    def _tokenize(self, text, use_jieba=True):
        split_tokens = []
        if use_jieba:
-           for wholword in self.jieba.cut(text, False):
+           for wholword in self.jieba.cut(text, HMM=False):
                if wholword in self.vocab:
                    split_tokens.append(wholword)
                else:
@@ -22,21 +22,21 @@
from .test_tokenization_common import TokenizerTesterMixin


-def is_rjieba_available():
-    return importlib.util.find_spec("rjieba") is not None
+def is_jieba_available():
+    return importlib.util.find_spec("jieba") is not None


-def require_rjieba(test_case):
+def require_jieba(test_case):
    """
    Decorator marking a test that requires Jieba. These tests are skipped when Jieba isn't installed.
    """
-   if not is_rjieba_available():
-       return unittest.skip("test requires rjieba")(test_case)
+   if not is_jieba_available():
+       return unittest.skip("test requires jieba")(test_case)
    else:
        return test_case

Review comment: Thanks, I have changed this :)


-@require_rjieba
+@require_jieba
@require_tokenizers
class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
@@ -78,7 +78,3 @@ def test_rust_tokenizer(self):
        input_tokens = tokens + [tokenizer.unk_token]
        exp_tokens = [22943, 21332, 34431, 45904, 117, 306, 1231, 1231, 2653, 33994, 1266, 100]
        self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), exp_tokens)
-
-   # due to custom pre_tokenize , char_to_token may be error
-   def test_alignement_methods(self):
-       pass
Review comment: Is there a reason you updated this? Same question for the lines below.

Review comment: Don't think this RoFormer PR needs to update the utils_qa.py file of the Tensorflow examples?

Review comment: I forked transformers before this PR (276bc14), and that PR did not run the fix-copies command. I ran fix-copies, and this file, examples/tensorflow/question-answering/utils_qa.py, changed as a result.