Fix up
qqaatw committed Sep 1, 2021
1 parent 28b8dac commit bd6d2eb
Showing 3 changed files with 29 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/source/model_doc/realm.rst
@@ -51,7 +51,7 @@ RealmTokenizer

.. autoclass:: transformers.RealmTokenizer
:members: build_inputs_with_special_tokens, get_special_tokens_mask,
-        create_token_type_ids_from_sequences, save_vocabulary
+        create_token_type_ids_from_sequences, save_vocabulary, batch_encode_candidates


RealmEmbedder
26 changes: 13 additions & 13 deletions src/transformers/models/realm/modeling_realm.py
@@ -593,23 +593,23 @@ def forward(
        Returns:

-        Example:
+        Example::

            >>> import torch
            >>> from transformers import RealmTokenizer, RealmEncoder

            >>> tokenizer = RealmTokenizer.from_pretrained('qqaatw/realm-cc-news-pretrained-bert')
            >>> model = RealmEncoder.from_pretrained('qqaatw/realm-cc-news-pretrained-bert', num_candidates=2)

            >>> # batch_size = 2, num_candidates = 2
            >>> text = [
            >>>     ["Hello world!", "Nice to meet you!"],
            >>>     ["The cute cat.", "The adorable dog."]
            >>> ]

-            >>> inputs = tokenizer.batch_encode_candidates(text, max_length=10)
+            >>> inputs = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
            >>> outputs = model(**inputs)
            >>> logits = outputs.logits
"""

return_dict = return_dict if return_dict is not None else self.config.use_return_dict
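The corrected example now requests PyTorch tensors explicitly, since the tokenizer no longer forces them (see the tokenization change below). A minimal sketch of the resulting shapes, assuming the qqaatw/realm-cc-news-pretrained-bert checkpoint is available:

    from transformers import RealmTokenizer, RealmEncoder

    tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-bert")
    model = RealmEncoder.from_pretrained("qqaatw/realm-cc-news-pretrained-bert", num_candidates=2)

    # batch_size = 2, num_candidates = 2
    text = [
        ["Hello world!", "Nice to meet you!"],
        ["The cute cat.", "The adorable dog."],
    ]

    # return_tensors="pt" is now required here; the model's forward expects tensors.
    inputs = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")

    # Each input keeps the extra candidate axis and is padded to max_length:
    # (batch_size, num_candidates, max_length) = (2, 2, 10)
    print(inputs["input_ids"].shape)  # torch.Size([2, 2, 10])

    outputs = model(**inputs)
    logits = outputs.logits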
27 changes: 15 additions & 12 deletions src/transformers/models/realm/tokenization_realm.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for REALM."""
-import torch

from ...file_utils import PaddingStrategy
from ...tokenization_utils_base import BatchEncoding
@@ -69,7 +68,7 @@ def batch_encode_candidates(self, text, **kwargs):
differences:
1. Handle additional num_candidate axis. (batch_size, num_candidates, text)
-        2. Always pad the sequences to `max_length` and always return PyTorch tensors..
+        2. Always pad the sequences to `max_length`.
3. Must specify `max_length` in order to stack packs of candidates into a batch.
- single sequence: ``[CLS] X [SEP]``
@@ -88,23 +87,27 @@ def batch_encode_candidates(self, text, **kwargs):
Returns:
:class:`~transformers.BatchEncoding`: Encoded text or text pair.
-        Example: >>> from transformers import RealmTokenizer
-
-        >>> # batch_size = 2, num_candidates = 2 >>> text = [ >>> ["Hello world!", "Nice to meet you!"], >>> ["The cute
-        cat.", "The adorable dog."] >>> ]
-
-        >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-bert")
-        >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10)
+        Example::
+
+            >>> from transformers import RealmTokenizer
+
+            >>> # batch_size = 2, num_candidates = 2
+            >>> text = [
+            >>>     ["Hello world!", "Nice to meet you!"],
+            >>>     ["The cute cat.", "The adorable dog."]
+            >>> ]
+
+            >>> tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-bert")
+            >>> tokenized_text = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
"""

-        # Always return PyTorch tensor.
-        kwargs["return_tensors"] = "pt"
# Always using a fixed sequence length to encode in order to stack candidates into a batch.
kwargs["padding"] = PaddingStrategy.MAX_LENGTH

batch_text = text
batch_text_pair = kwargs.pop("text_pair", None)
+        return_tensors = kwargs.pop("return_tensors", None)

output_data = {
"input_ids": [],
@@ -118,7 +121,7 @@ def batch_encode_candidates(self, text, **kwargs):
else:
candidate_text_pair = None

-            encoded_candidates = super().__call__(candidate_text, candidate_text_pair, **kwargs)
+            encoded_candidates = super().__call__(candidate_text, candidate_text_pair, return_tensors=None, **kwargs)

encoded_input_ids = encoded_candidates.get("input_ids")
encoded_attention_mask = encoded_candidates.get("attention_mask")
@@ -131,6 +134,6 @@ def batch_encode_candidates(self, text, **kwargs):
if encoded_token_type_ids is not None:
output_data["token_type_ids"].append(encoded_token_type_ids)

-        output_data = dict((key, torch.stack(item)) for key, item in output_data.items() if len(item) != 0)
+        output_data = dict((key, item) for key, item in output_data.items() if len(item) != 0)

-        return BatchEncoding(output_data, tensor_type=kwargs["return_tensors"])
+        return BatchEncoding(output_data, tensor_type=return_tensors)
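Taken together, these changes make batch_encode_candidates framework-agnostic: return_tensors is popped once, the per-candidate calls return plain lists, and the single BatchEncoding(..., tensor_type=return_tensors) conversion at the end replaces the torch.stack call, which is why the module-level import torch could be dropped. A short usage sketch of the new behavior, using the same example inputs as above:

    from transformers import RealmTokenizer

    tokenizer = RealmTokenizer.from_pretrained("qqaatw/realm-cc-news-pretrained-bert")
    text = [
        ["Hello world!", "Nice to meet you!"],
        ["The cute cat.", "The adorable dog."],
    ]

    # Default: nested Python lists, no PyTorch dependency.
    as_lists = tokenizer.batch_encode_candidates(text, max_length=10)
    print(type(as_lists["input_ids"]))  # <class 'list'>

    # Conversion happens once inside BatchEncoding, so "pt", "tf", and "np"
    # should all work where the corresponding framework is installed.
    as_pt = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")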
