
Commit 42d8529

allow TransformerTextField to take input directly from HF tokenizer (#5329)

* allow TransformerTextField to take input directly from HF tokenizer

* accept list as well
epwalsh authored Jul 23, 2021
1 parent 64043ac commit 42d8529
Showing 3 changed files with 49 additions and 16 deletions.
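
In short, the output of a HuggingFace tokenizer can now be splatted directly into a `TransformerTextField`, whether the tokenizer returns shape-`(1, n)` tensors or plain Python lists. A minimal usage sketch (not part of the diff below; it assumes the `bert-base-cased` tokenizer can be loaded):

```python
from allennlp.common.cached_transformers import get_tokenizer
from allennlp.data.fields import TransformerTextField

tokenizer = get_tokenizer("bert-base-cased")

# With return_tensors="pt" the tokenizer yields shape-(1, n) tensors;
# without it, plain Python lists. Both can now be passed straight through.
field_from_tensors = TransformerTextField(**tokenizer("Hello, World!", return_tensors="pt"))
field_from_lists = TransformerTextField(**tokenizer("Hello, World!"))
```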
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Fixed a mispelling: the parameter `contructor_extras` in `Lazy()` is now correctly called `constructor_extras`.
 - Fixed broken links in `allennlp.nn.initializers` docs.
+- `TransformerTextField` can now take tensors of shape `(1, n)` like the tensors produced from a HuggingFace tokenizer.

 ### Changed

42 changes: 26 additions & 16 deletions allennlp/data/fields/transformer_text_field.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, List, Any
+from typing import Dict, Optional, List, Any, Union

 from overrides import overrides
 import torch
@@ -8,6 +8,12 @@
 from allennlp.nn import util


+def _tensorize(x: Union[torch.Tensor, List[int]]) -> torch.Tensor:
+    if not isinstance(x, torch.Tensor):
+        return torch.tensor(x)
+    return x
+
+
 class TransformerTextField(Field[torch.Tensor]):
     """
     A `TransformerTextField` is a collection of several tensors that are are a representation of text,
@@ -28,25 +34,27 @@ class TransformerTextField(Field[torch.Tensor]):

     def __init__(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Union[torch.Tensor, List[int]],
         # I wish input_ids were called `token_ids` for clarity, but we want to be compatible with huggingface.
-        token_type_ids: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        special_tokens_mask: Optional[torch.Tensor] = None,
-        offsets_mapping: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[Union[torch.Tensor, List[int]]] = None,
+        attention_mask: Optional[Union[torch.Tensor, List[int]]] = None,
+        special_tokens_mask: Optional[Union[torch.Tensor, List[int]]] = None,
+        offsets_mapping: Optional[Union[torch.Tensor, List[int]]] = None,
         padding_token_id: int = 0,
     ) -> None:
-        self.input_ids = input_ids
-        self.token_type_ids = token_type_ids
-        self.attention_mask = attention_mask
-        self.special_tokens_mask = special_tokens_mask
-        self.offsets_mapping = offsets_mapping
+        self.input_ids = _tensorize(input_ids)
+        self.token_type_ids = None if token_type_ids is None else _tensorize(token_type_ids)
+        self.attention_mask = None if attention_mask is None else _tensorize(attention_mask)
+        self.special_tokens_mask = (
+            None if special_tokens_mask is None else _tensorize(special_tokens_mask)
+        )
+        self.offsets_mapping = None if offsets_mapping is None else _tensorize(offsets_mapping)
         self.padding_token_id = padding_token_id

     @overrides
     def get_padding_lengths(self) -> Dict[str, int]:
         return {
-            name: len(getattr(self, name))
+            name: getattr(self, name).shape[-1]
             for name in self.__slots__
             if isinstance(getattr(self, name), torch.Tensor)
         }
@@ -56,15 +64,17 @@ def as_tensor(self, padding_lengths: Dict[str, int]) -> Dict[str, torch.Tensor]:
         result = {}
         for name, padding_length in padding_lengths.items():
             tensor = getattr(self, name)
+            if len(tensor.shape) > 1:
+                tensor = tensor.squeeze(0)
             result[name] = torch.nn.functional.pad(
                 tensor,
-                (0, padding_length - len(tensor)),
+                (0, padding_length - tensor.shape[-1]),
                 value=self.padding_token_id if name == "input_ids" else 0,
             )
         if "attention_mask" not in result:
             result["attention_mask"] = torch.tensor(
-                [True] * len(self.input_ids)
-                + [False] * (padding_lengths["input_ids"] - len(self.input_ids)),
+                [True] * self.input_ids.shape[-1]
+                + [False] * (padding_lengths["input_ids"] - self.input_ids.shape[-1]),
                 dtype=torch.bool,
             )
         return result
@@ -88,7 +98,7 @@ def format_item(x) -> str:
             return str(x.item())

         def readable_tensor(t: torch.Tensor) -> str:
-            if len(t) <= 16:
+            if t.shape[-1] <= 16:
                 return "[" + ", ".join(map(format_item, t)) + "]"
             else:
                 return (
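As a rough sketch of the padding path above (not part of the commit; the token ids below are arbitrary stand-ins), a shape-`(1, n)` tensor is squeezed to 1-D before right-padding, and the generated `attention_mask` is sized from the last dimension of `input_ids`:

```python
import torch

from allennlp.data.fields import TransformerTextField

# Shape (1, 3), mimicking a HuggingFace tokenizer called with return_tensors="pt".
field = TransformerTextField(input_ids=torch.tensor([[101, 7592, 102]]))

# Pad to length 5: the tensor is squeezed to 1-D and right-padded with
# padding_token_id (0 by default); the mask marks the 3 real positions.
padded = field.as_tensor({"input_ids": 5})
print(padded["input_ids"])       # tensor([101, 7592, 102, 0, 0])
print(padded["attention_mask"])  # tensor([True, True, True, False, False])
```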
22 changes: 22 additions & 0 deletions tests/data/fields/transformer_text_field_test.py
@@ -1,5 +1,7 @@
 import torch
+import pytest

+from allennlp.common.cached_transformers import get_tokenizer
 from allennlp.data import Batch, Instance
 from allennlp.data.fields import TransformerTextField

@@ -38,3 +40,23 @@ def test_transformer_text_field_batching():
     assert tensors["text"]["attention_mask"][0, -1] == torch.Tensor([False])
     assert torch.all(tensors["text"]["input_ids"][-1] == 0)
     assert torch.all(tensors["text"]["attention_mask"][-1] == torch.tensor([False]))
+
+
+@pytest.mark.parametrize("return_tensors", ["pt", None])
+def test_transformer_text_field_from_huggingface(return_tensors):
+    tokenizer = get_tokenizer("bert-base-cased")
+
+    batch = Batch(
+        [
+            Instance(
+                {"text": TransformerTextField(**tokenizer(text, return_tensors=return_tensors))}
+            )
+            for text in [
+                "Hello, World!",
+                "The fox jumped over the fence",
+                "Humpty dumpty sat on a wall",
+            ]
+        ]
+    )
+    tensors = batch.as_tensor_dict(batch.get_padding_lengths())
+    assert tensors["text"]["input_ids"].shape == (3, 11)
