From 42d85298467b3f6f0e4b7b70f5fd481ab33df7ce Mon Sep 17 00:00:00 2001
From: Pete
Date: Thu, 22 Jul 2021 17:07:48 -0700
Subject: [PATCH] allow TransformerTextField to take input directly from HF tokenizer (#5329)

* allow TransformerTextField to take input directly from HF tokenizer

* accept list as well
---
 CHANGELOG.md                                  |  1 +
 .../data/fields/transformer_text_field.py    | 42 ++++++++++++-------
 .../fields/transformer_text_field_test.py    | 22 ++++++++++
 3 files changed, 49 insertions(+), 16 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b463be77c0b..8bd6afeda39 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Fixed a mispelling: the parameter `contructor_extras` in `Lazy()` is now correctly called `constructor_extras`.
 - Fixed broken links in `allennlp.nn.initializers` docs.
+- `TransformerTextField` can now take tensors of shape `(1, n)` like the tensors produced from a HuggingFace tokenizer.
 
 ### Changed
 
diff --git a/allennlp/data/fields/transformer_text_field.py b/allennlp/data/fields/transformer_text_field.py
index 6d27a03ade9..0aa9157365f 100644
--- a/allennlp/data/fields/transformer_text_field.py
+++ b/allennlp/data/fields/transformer_text_field.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, List, Any
+from typing import Dict, Optional, List, Any, Union
 
 from overrides import overrides
 import torch
@@ -8,6 +8,12 @@
 from allennlp.nn import util
 
 
+def _tensorize(x: Union[torch.Tensor, List[int]]) -> torch.Tensor:
+    if not isinstance(x, torch.Tensor):
+        return torch.tensor(x)
+    return x
+
+
 class TransformerTextField(Field[torch.Tensor]):
     """
     A `TransformerTextField` is a collection of several tensors that are a representation of text,
@@ -28,25 +34,27 @@ class TransformerTextField(Field[torch.Tensor]):
 
     def __init__(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Union[torch.Tensor, List[int]],
         # I wish input_ids were called `token_ids` for clarity, but we want to be compatible with huggingface.
-        token_type_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        special_tokens_mask: Optional[torch.Tensor] = None,
-        offsets_mapping: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[Union[torch.Tensor, List[int]]] = None,
+        attention_mask: Optional[Union[torch.Tensor, List[int]]] = None,
+        special_tokens_mask: Optional[Union[torch.Tensor, List[int]]] = None,
+        offsets_mapping: Optional[Union[torch.Tensor, List[int]]] = None,
         padding_token_id: int = 0,
     ) -> None:
-        self.input_ids = input_ids
-        self.token_type_ids = token_type_ids
-        self.attention_mask = attention_mask
-        self.special_tokens_mask = special_tokens_mask
-        self.offsets_mapping = offsets_mapping
+        self.input_ids = _tensorize(input_ids)
+        self.token_type_ids = None if token_type_ids is None else _tensorize(token_type_ids)
+        self.attention_mask = None if attention_mask is None else _tensorize(attention_mask)
+        self.special_tokens_mask = (
+            None if special_tokens_mask is None else _tensorize(special_tokens_mask)
+        )
+        self.offsets_mapping = None if offsets_mapping is None else _tensorize(offsets_mapping)
         self.padding_token_id = padding_token_id
 
     @overrides
     def get_padding_lengths(self) -> Dict[str, int]:
         return {
-            name: len(getattr(self, name))
+            name: getattr(self, name).shape[-1]
             for name in self.__slots__
             if isinstance(getattr(self, name), torch.Tensor)
         }
@@ -56,15 +64,17 @@ def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]:
         result = {}
         for name, padding_length in padding_lengths.items():
             tensor = getattr(self, name)
+            if len(tensor.shape) > 1:
+                tensor = tensor.squeeze(0)
             result[name] = torch.nn.functional.pad(
                 tensor,
-                (0, padding_length - len(tensor)),
+                (0, padding_length - tensor.shape[-1]),
                 value=self.padding_token_id if name == "input_ids" else 0,
             )
         if "attention_mask" not in result:
             result["attention_mask"] = torch.tensor(
-                [True] * len(self.input_ids)
-                + [False] * (padding_lengths["input_ids"] - len(self.input_ids)),
+                [True] * self.input_ids.shape[-1]
+                + [False] * (padding_lengths["input_ids"] - self.input_ids.shape[-1]),
                 dtype=torch.bool,
             )
         return result
@@ -88,7 +98,7 @@ def format_item(x) -> str:
             return str(x.item())
 
         def readable_tensor(t: torch.Tensor) -> str:
-            if len(t) <= 16:
+            if t.shape[-1] <= 16:
                 return "[" + ", ".join(map(format_item, t)) + "]"
             else:
                 return (
diff --git a/tests/data/fields/transformer_text_field_test.py b/tests/data/fields/transformer_text_field_test.py
index cecb60e454a..079c288e88d 100644
--- a/tests/data/fields/transformer_text_field_test.py
+++ b/tests/data/fields/transformer_text_field_test.py
@@ -1,5 +1,7 @@
 import torch
+import pytest
 
+from allennlp.common.cached_transformers import get_tokenizer
 from allennlp.data import Batch, Instance
 from allennlp.data.fields import TransformerTextField
 
@@ -38,3 +40,23 @@ def test_transformer_text_field_batching():
     assert tensors["text"]["attention_mask"][0, -1] == torch.Tensor([False])
     assert torch.all(tensors["text"]["input_ids"][-1] == 0)
     assert torch.all(tensors["text"]["attention_mask"][-1] == torch.tensor([False]))
+
+
+@pytest.mark.parametrize("return_tensors", ["pt", None])
+def test_transformer_text_field_from_huggingface(return_tensors):
+    tokenizer = get_tokenizer("bert-base-cased")
+
+    batch = Batch(
+        [
+            Instance(
+                {"text": TransformerTextField(**tokenizer(text, return_tensors=return_tensors))}
+            )
+            for text in [
+                "Hello, World!",
+                "The fox jumped over the fence",
+                "Humpty dumpty sat on a wall",
+            ]
+        ]
+    )
+    tensors = batch.as_tensor_dict(batch.get_padding_lengths())
+    assert tensors["text"]["input_ids"].shape == (3, 11)
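
Usage sketch: the snippet below illustrates the pattern this patch enables, unpacking a HuggingFace tokenizer's output straight into `TransformerTextField`. It assumes an environment with this change applied and the `bert-base-cased` tokenizer available locally; the sample texts and the printed shape are illustrative only, not taken from the patch.

from allennlp.common.cached_transformers import get_tokenizer
from allennlp.data import Batch, Instance
from allennlp.data.fields import TransformerTextField

tokenizer = get_tokenizer("bert-base-cased")

# With return_tensors="pt" the tokenizer yields tensors of shape (1, n);
# without it, plain lists of ints. TransformerTextField now accepts both.
instances = [
    Instance({"text": TransformerTextField(**tokenizer(text, return_tensors="pt"))})
    for text in ["Hello, World!", "The fox jumped over the fence"]
]

batch = Batch(instances)
tensors = batch.as_tensor_dict(batch.get_padding_lengths())
print(tensors["text"]["input_ids"].shape)  # (batch size, length of longest tokenized text)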