specify 'truncation' to avoid transformers warning (#5120)
* specify 'truncation' to avoid transformers warning

* Update docs

* Remove `stride` param

* Update CHANGELOG.md

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
epwalsh and dirkgr authored Apr 13, 2021
1 parent 0ddd3d3 commit b34df73
Showing 2 changed files with 6 additions and 7 deletions.
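For context: recent `transformers` releases emit a warning when `max_length` is passed to a tokenizer call without an explicit truncation strategy, and the fix here is simply to pass `truncation=True` whenever `max_length` is set. A minimal sketch of that behavior using `transformers` directly (the model name and input text are arbitrary placeholders, not taken from this commit):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # placeholder model
text = "a long input sentence " * 100

# Passing max_length alone leaves the truncation strategy implicit, which is
# what triggers the warning in newer transformers versions.
implicit = tokenizer.encode_plus(text, max_length=16)

# Making truncation explicit (only when max_length is set) silences the warning;
# this mirrors the change made in this commit.
explicit = tokenizer.encode_plus(text, max_length=16, truncation=True)
```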
CHANGELOG.md (5 additions, 0 deletions)

@@ -33,6 +33,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed a stall when using distributed training and gradient accumulation at the same time
 - Fixed an issue where using the `from_pretrained_transformer` `Vocabulary` constructor in distributed training via the `allennlp train` command
   would result in the data being iterated through unnecessarily.
+- Fixed a warning from `transformers` when using `max_length` in the `PretrainedTransformerTokenizer`.
+
+### Removed
+
+- Removed the `stride` parameter to `PretrainedTransformerTokenizer`. This parameter had no effect.
 
 
 ## [v2.2.0](https://github.com/allenai/allennlp/releases/tag/v2.2.0) - 2021-03-26
allennlp/data/tokenizers/pretrained_transformer_tokenizer.py (1 addition, 7 deletions)

@@ -40,10 +40,6 @@ class PretrainedTransformerTokenizer(Tokenizer):
         to their model.
     max_length : `int`, optional (default=`None`)
         If set to a number, will limit the total sequence returned so that it has a maximum length.
-        If there are overflowing tokens, those will be added to the returned dictionary
-    stride : `int`, optional (default=`0`)
-        If set to a number along with max_length, the overflowing tokens returned will contain some tokens
-        from the main sequence returned. The value of this argument defines the number of additional tokens.
     tokenizer_kwargs: `Dict[str, Any]`, optional (default = `None`)
         Dictionary with
         [additional arguments](https://github.com/huggingface/transformers/blob/155c782a2ccd103cf63ad48a2becd7c76a7d2115/transformers/tokenization_utils.py#L691)

@@ -55,7 +51,6 @@ def __init__(
         model_name: str,
         add_special_tokens: bool = True,
         max_length: Optional[int] = None,
-        stride: int = 0,
         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         if tokenizer_kwargs is None:

@@ -73,7 +68,6 @@ def __init__(
 
         self._add_special_tokens = add_special_tokens
         self._max_length = max_length
-        self._stride = stride
 
         self._tokenizer_lowercases = self.tokenizer_lowercases(self.tokenizer)
 

@@ -240,7 +234,7 @@ def tokenize(self, text: str) -> List[Token]:
             text=text,
             add_special_tokens=True,
             max_length=max_length,
-            stride=self._stride,
+            truncation=True if max_length is not None else False,
             return_tensors=None,
             return_offsets_mapping=self.tokenizer.is_fast,
             return_attention_mask=False,
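For callers of `PretrainedTransformerTokenizer`, the intent of the change is only to silence the warning: when `max_length` is set, the truncation flag is now forwarded to the underlying `transformers` tokenizer. A rough usage sketch under that assumption (the model name and input are arbitrary examples):

```python
from allennlp.data.tokenizers.pretrained_transformer_tokenizer import (
    PretrainedTransformerTokenizer,
)

# max_length caps the returned sequence; with this commit, truncation=True is
# passed through to the transformers tokenizer automatically when it is set.
tokenizer = PretrainedTransformerTokenizer("bert-base-uncased", max_length=16)
tokens = tokenizer.tokenize("a long input sentence " * 100)
assert len(tokens) <= 16  # truncated to max_length, no transformers warning emitted
```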
