diff --git a/flair/embeddings/base.py b/flair/embeddings/base.py
index 8b436109ec..810deb8c10 100644
--- a/flair/embeddings/base.py
+++ b/flair/embeddings/base.py
@@ -21,6 +21,7 @@
     TransfoXLModel,
     XLNetModel,
 )
+from transformers.tokenization_utils_base import TruncationStrategy, LARGE_INTEGER
 
 import flair
 from flair.data import DT, Sentence, Token
@@ -168,7 +169,6 @@ def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
 
 
 class TransformerEmbedding(Embeddings[Sentence]):
-    NO_MAX_SEQ_LENGTH_MODELS = (XLNetModel, TransfoXLModel)
 
     def __init__(
         self,
@@ -217,7 +217,7 @@ def __init__(
 
         self.truncate = True
 
-        if isinstance(self.model, self.NO_MAX_SEQ_LENGTH_MODELS):
+        if self.tokenizer.model_max_length > LARGE_INTEGER:
             allow_long_sentences = False
             self.truncate = False
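
A minimal sketch of the idea behind the new condition, assuming the usual `transformers` behavior: tokenizers with no effective length limit (e.g. XLNet, Transformer-XL) fall back to a huge sentinel `model_max_length`, so comparing it against `LARGE_INTEGER` flags them without hard-coding model classes. The `xlnet-base-cased` checkpoint is only an illustrative choice, not something the diff depends on.

```python
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import LARGE_INTEGER

# Tokenizers without a configured limit report a sentinel model_max_length
# (VERY_LARGE_INTEGER), which is larger than LARGE_INTEGER.
tokenizer = AutoTokenizer.from_pretrained("xlnet-base-cased")

if tokenizer.model_max_length > LARGE_INTEGER:
    # Mirrors the new branch: no real max length, so long-sentence
    # splitting and truncation are disabled.
    allow_long_sentences, truncate = False, False
else:
    allow_long_sentences, truncate = True, True

print(tokenizer.model_max_length, allow_long_sentences, truncate)
```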