diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py index 285db61fd572..ad411bfd434b 100644 --- a/paddlenlp/transformers/__init__.py +++ b/paddlenlp/transformers/__init__.py @@ -209,6 +209,9 @@ from .xlm.modeling import * from .xlm.tokenizer import * from .xlm.configuration import * +from .xlm_roberta.modeling import * +from .xlm_roberta.tokenizer import * +from .xlm_roberta.configuration import * from .gau_alpha.modeling import * from .gau_alpha.tokenizer import * from .gau_alpha.configuration import * diff --git a/paddlenlp/transformers/auto/configuration.py b/paddlenlp/transformers/auto/configuration.py index f2058a5ec389..b77672d258c0 100644 --- a/paddlenlp/transformers/auto/configuration.py +++ b/paddlenlp/transformers/auto/configuration.py @@ -113,6 +113,7 @@ ("unimo", "UNIMOConfig"), ("visualglm", "VisualGLMConfig"), ("xlm", "XLMConfig"), + ("xlm-roberta", "XLMRobertaConfig"), ("xlnet", "XLNetConfig"), ("yuan", "YuanConfig"), ] @@ -202,6 +203,7 @@ ("unimo", "UNIMO"), ("visualglm", "VisualGLM"), ("xlm", "XLM"), + ("xlm-roberta", "XLMRoberta"), ("xlnet", "XLNet"), ("yuan", "Yuan"), ] diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py index 8b94d9f4b53d..938b06f5a5b9 100644 --- a/paddlenlp/transformers/auto/modeling.py +++ b/paddlenlp/transformers/auto/modeling.py @@ -94,6 +94,7 @@ ("UNIMO", "unimo"), ("XLNet", "xlnet"), ("XLM", "xlm"), + ("XLMRoberta", "xlm_roberta"), ("GPT", "gpt"), ("GLM", "glm"), ("MT5", "mt5"), diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py index a53e36c4935a..7dfc2ede0ada 100644 --- a/paddlenlp/transformers/auto/tokenizer.py +++ b/paddlenlp/transformers/auto/tokenizer.py @@ -99,6 +99,7 @@ ("squeezebert", "SqueezeBertTokenizer"), ("t5", "T5Tokenizer"), ("xlm", "XLMTokenizer"), + ("xlm_roberta", "XLMRobertaTokenizer"), ("xlnet", "XLNetTokenizer"), ("bert_japanese", "BertJapaneseTokenizer"), ("bigbird", "BigBirdTokenizer"), diff --git a/paddlenlp/transformers/xlm_roberta/__init__.py b/paddlenlp/transformers/xlm_roberta/__init__.py new file mode 100644 index 000000000000..4c08fc6b9a63 --- /dev/null +++ b/paddlenlp/transformers/xlm_roberta/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration import * +from .modeling import * +from .tokenizer import * diff --git a/paddlenlp/transformers/xlm_roberta/configuration.py b/paddlenlp/transformers/xlm_roberta/configuration.py new file mode 100644 index 000000000000..dcbf46079bac --- /dev/null +++ b/paddlenlp/transformers/xlm_roberta/configuration.py @@ -0,0 +1,160 @@ +# coding=utf-8 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" XLM-RoBERTa configuration""" + +from ..model_utils import PretrainedConfig + +__all__ = ["PRETRAINED_INIT_CONFIGURATION", "XLMRobertaConfig"] + +PRETRAINED_INIT_CONFIGURATION = { + "hf-internal-testing/tiny-random-onnx-xlm-roberta": { + "attention_probs_dropout_prob": 0.1, + "bos_token_id": 0, + "classifier_dropout": None, + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 4, + "initializer_range": 0.02, + "intermediate_size": 37, + "layer_norm_eps": 1e-05, + "max_position_embeddings": 514, + "model_type": "xlm-roberta", + "num_attention_heads": 4, + "num_hidden_layers": 5, + "output_past": True, + "pad_token_id": 1, + "position_embedding_type": "absolute", + "dtype": "float32", + "type_vocab_size": 1, + "use_cache": True, + "vocab_size": 250002, + }, +} + + +class XLMRobertaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`XLMRobertaModel`] or a [`TFXLMRobertaModel`]. It + is used to instantiate a XLM-RoBERTa model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the XLMRoBERTa + [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 30522): + Vocabulary size of the XLM-RoBERTa model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`XLMRobertaModel`] or [`TFXLMRobertaModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"silu"` and `"gelu_new"` are supported. + hidden_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + max_position_embeddings (`int`, *optional*, defaults to 512): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + type_vocab_size (`int`, *optional*, defaults to 2): + The vocabulary size of the `token_type_ids` passed when calling [`XLMRobertaModel`] or + [`TFXLMRobertaModel`]. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + position_embedding_type (`str`, *optional*, defaults to `"absolute"`): + Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For + positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to + [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). + For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models + with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). + is_decoder (`bool`, *optional*, defaults to `False`): + Whether the model is used as a decoder or not. If `False`, the model is used as an encoder. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + classifier_dropout (`float`, *optional*): + The dropout ratio for the classification head. + + Examples: + + ```python + >>> from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaModel + + >>> # Initializing a XLM-RoBERTa xlm-roberta-base style configuration + >>> configuration = XLMRobertaConfig() + + >>> # Initializing a model (with random weights) from the xlm-roberta-base style configuration + >>> model = XLMRobertaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "xlm-roberta" + + def __init__( + self, + vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + position_embedding_type="absolute", + use_cache=True, + classifier_dropout=None, + **kwargs, + ): + kwargs["return_dict"] = kwargs.pop("return_dict", False) + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + self.position_embedding_type = position_embedding_type + self.use_cache = use_cache + self.classifier_dropout = classifier_dropout diff --git a/paddlenlp/transformers/xlm_roberta/modeling.py b/paddlenlp/transformers/xlm_roberta/modeling.py new file mode 100644 index 000000000000..31feb37785ef --- /dev/null +++ b/paddlenlp/transformers/xlm_roberta/modeling.py @@ -0,0 +1,1618 @@ +# coding=utf-8 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle XLM-RoBERTa model.""" + +import math +from typing import List, Optional, Tuple, Union + +import paddle +from paddle import nn +from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...utils import logger +from ...utils.converter import StateDictNameMapping +from ..activations import ACT2FN +from ..model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ..model_utils import ( + PretrainedModel, + apply_chunking_to_forward, + register_base_model, +) +from .configuration import PRETRAINED_INIT_CONFIGURATION, XLMRobertaConfig + +__all__ = [ + "XLMRobertaModel", + "XLMRobertaPretrainedModel", + "XLMRobertaForSequenceClassification", + "XLMRobertaForTokenClassification", + "XLMRobertaForQuestionAnswering", + "XLMRobertaForMaskedLM", + "XLMRobertaForMultipleChoice", + "XLMRobertaForCausalLM", +] + + +class XLMRobertaEmbeddings(nn.Layer): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size + ) # padding_idx=config.pad_token_id NOTE, donot set padding_idx + self.word_embeddings.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer( + "position_ids", + paddle.arange(config.max_position_embeddings, dtype=paddle.int64).expand((1, -1)), + persistable=False, + ) + self.register_buffer( + "token_type_ids", paddle.zeros(self.position_ids.shape, dtype=paddle.int64), persistable=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, # padding_idx=self.padding_idx + ) + self.position_embeddings.padding_idx = config.pad_token_id + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.shape + else: + input_shape = inputs_embeds.shape[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand([input_shape[0], seq_length]) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: paddle.Tensor + + Returns: paddle.Tensor + """ + input_shape = inputs_embeds.shape[:-1] + sequence_length = input_shape[1] + + position_ids = paddle.arange( + self.padding_idx + 1, + sequence_length + self.padding_idx + 1, + dtype=paddle.int64, + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +class XLMRobertaSelfAttention(nn.Layer): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + self.scale = math.sqrt(self.attention_head_size) + + def transpose_for_scores(self, x: paddle.Tensor) -> paddle.Tensor: + new_x_shape = x.shape[:-1] + [self.num_attention_heads, self.attention_head_size] + x = x.reshape(new_x_shape) + return x.transpose([0, 2, 1, 3]) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + past_key_value: Optional[Tuple[Tuple[paddle.Tensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = paddle.concat([past_key_value[0], key_layer], axis=2) + value_layer = paddle.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(paddle.Tensor, paddle.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(paddle.Tensor, paddle.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = paddle.matmul(query_layer, key_layer, transpose_y=True) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = paddle.to_tensor( + key_length - 1, + dtype=paddle.int64, + ).reshape([-1, 1]) + else: + position_ids_l = paddle.arange( + query_length, + dtype=paddle.int64, + ).reshape([-1, 1]) + position_ids_r = paddle.arange( + key_length, + dtype=paddle.int64, + ).reshape([1, -1]) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.cast(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = paddle.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = paddle.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = paddle.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / self.scale + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in XLMRobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + context_layer = paddle.matmul(attention_probs, value_layer) + + context_layer = context_layer.transpose([0, 2, 1, 3]) + new_context_layer_shape = context_layer.shape[:-2] + [ + self.all_head_size, + ] + context_layer = context_layer.reshape(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +class XLMRobertaSelfOutput(nn.Layer): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class XLMRobertaAttention(nn.Layer): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = XLMRobertaSelfOutput(config) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + past_key_value: Optional[Tuple[Tuple[paddle.Tensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class XLMRobertaIntermediate(nn.Layer): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class XLMRobertaOutput(nn.Layer): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class XLMRobertaLayer(nn.Layer): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = XLMRobertaAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute") + self.intermediate = XLMRobertaIntermediate(config) + self.output = XLMRobertaOutput(config) + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + past_key_value: Optional[Tuple[Tuple[paddle.Tensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[paddle.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class XLMRobertaEncoder(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.LayerList([XLMRobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.enable_recompute = False + + def forward( + self, + hidden_states: paddle.Tensor, + attention_mask: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[paddle.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.enable_recompute and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.enable_recompute and not hidden_states.stop_gradient: + layer_outputs = self._gradient_checkpointing_func( + layer_module.__call__, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class XLMRobertaPooler(nn.Layer): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + pooler_act = getattr(config, "pooler_act", "tanh") + self.activation = ACT2FN[pooler_act] + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class XLMRobertaPretrainedModel(PretrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = XLMRobertaConfig + pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = { + "model_state": { + "hf-internal-testing/tiny-random-onnx-xlm-roberta": "https://bj.bcebos.com/paddlenlp/models/community/hf-internal-testing/tiny-random-onnx-xlm-roberta/model.safetensors", + } + } + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"] + + def can_generate(self) -> bool: + return False + + @classmethod + def _get_name_mappings(cls, config): + architectures = config.architectures + [cls.__name__] + mappings = [] + model_mappings = [ + ["embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"], + ["embeddings.position_ids", "embeddings.position_ids"], + ["embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"], + ["embeddings.token_type_embeddings.weight", "embeddings.token_type_embeddings.weight"], + ["embeddings.LayerNorm.weight", "embeddings.LayerNorm.weight"], + ["embeddings.LayerNorm.bias", "embeddings.LayerNorm.bias"], + ["pooler.dense.weight", "pooler.dense.weight", "transpose"], + ["pooler.dense.bias", "pooler.dense.bias"], + # for TokenClassification + ] + for layer_index in range(config.num_hidden_layers): + for name in [ + "attention.self.query", + "attention.self.key", + "attention.self.value", + "attention.output.dense", + "attention.output.LayerNorm", + "intermediate.dense", + "output.dense", + "output.LayerNorm", + ]: + action = None if "LayerNorm" in name else "transpose" + model_mappings.extend( + [ + [ + f"encoder.layer.{layer_index}.{name}.weight", + f"encoder.layer.{layer_index}.{name}.weight", + action, + ], + [ + f"encoder.layer.{layer_index}.{name}.bias", + f"encoder.layer.{layer_index}.{name}.bias", + ], + ] + ) + + # base-model prefix "XLMRobertaModel" + torch_prefix = "" + paddle_prefix = "" + if "XLMRobertaModel" not in config.architectures: + torch_prefix = "roberta." + if "XLMRobertaModel" not in [cls.__name__]: + paddle_prefix = "roberta." + + # add prefix + for mapping in model_mappings: + mapping[0] = torch_prefix + mapping[0] + mapping[1] = paddle_prefix + mapping[1] + + if "XLMRobertaForCausalLM" in architectures: + model_mappings.extend( + [ + ["lm_head.dense.weight", "lm_head.dense.weight", "transpose"], + ["lm_head.dense.bias", "lm_head.dense.bias"], + ["lm_head.layer_norm.weight", "lm_head.layer_norm.weight"], + ["lm_head.layer_norm.bias", "lm_head.layer_norm.bias"], + ["lm_head.bias", "lm_head.bias"], + ] + ) + + # downstream mappings + if "XLMRobertaForQuestionAnswering" in architectures: + model_mappings.extend( + [ + ["qa_outputs.weight", "qa_outputs.weight", "transpose"], + ["qa_outputs.bias", "qa_outputs.bias"], + ] + ) + if "XLMRobertaForSequenceClassification" in architectures: + model_mappings.extend( + [ + ["classifier.dense.weight", "classifier.dense.weight", "transpose"], + ["classifier.dense.bias", "classifier.dense.bias"], + ["classifier.out_proj.weight", "classifier.out_proj.weight", "transpose"], + ["classifier.out_proj.bias", "classifier.out_proj.bias"], + ] + ) + + if "XLMRobertaForMultipleChoice" in architectures or "XLMRobertaForTokenClassification" in architectures: + model_mappings.extend( + [ + ["classifier.weight", "classifier.weight", "transpose"], + ["classifier.bias", "classifier.bias"], + ] + ) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @paddle.no_grad() + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if hasattr(module, "padding_idx") and module.padding_idx is not None: + module.weight[module.padding_idx] = 0 + elif isinstance(module, nn.LayerNorm): + module.bias.zero_() + module.weight.fill_(1.0) + + +@register_base_model +class XLMRobertaModel(XLMRobertaPretrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + + .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 + + """ + + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = XLMRobertaEmbeddings(config) + self.encoder = XLMRobertaEncoder(config) + + self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self._post_init(self.__init__) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @property + def dtype(self) -> paddle.dtype: + """ + `paddle.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + try: + return next(self.named_parameters())[1].dtype + except StopIteration: + try: + return next(self.named_buffers())[1].dtype + except StopIteration: + return self._dtype + + def invert_attention_mask(self, encoder_attention_mask: paddle.Tensor) -> paddle.Tensor: + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`paddle.Tensor`): An attention mask. + + Returns: + `paddle.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.ndim == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.ndim == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask.cast(dtype=self.dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * paddle.finfo(self.dtype).min + + return encoder_extended_attention_mask + + @staticmethod + def create_extended_attention_mask_for_decoder(input_shape, attention_mask): + batch_size, seq_length = input_shape + seq_ids = paddle.arange(seq_length) + causal_mask = seq_ids[None, None, :].tile([batch_size, seq_length, 1]) <= seq_ids[None, :, None] + # in case past_key_values are used we need to add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.cast(dtype=attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + causal_mask = paddle.concat( + [ + paddle.ones((batch_size, seq_length, prefix_seq_len), dtype=causal_mask.dtype), + causal_mask, + ], + axis=-1, + ) + + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + return extended_attention_mask + + def get_extended_attention_mask( + self, attention_mask: paddle.Tensor, input_shape: Tuple[int], dtype: paddle.dtype = None + ) -> paddle.Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (`paddle.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (`Tuple[int]`): + The shape of the input to the model. + + Returns: + `paddle.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`. + """ + if dtype is None: + dtype = self.dtype + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.ndim == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.ndim == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder: + extended_attention_mask = XLMRobertaModel.create_extended_attention_mask_for_decoder( + input_shape, attention_mask + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})" + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.cast(dtype=dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * paddle.finfo(dtype).min + return extended_attention_mask + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[paddle.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(paddle.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = paddle.ones( + ((batch_size, seq_length + past_key_values_length)), + ) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand([batch_size, seq_length]) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = paddle.zeros( + input_shape, + dtype=paddle.int64, + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: paddle.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.shape + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = paddle.ones(encoder_hidden_shape) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class XLMRobertaForCausalLM(XLMRobertaPretrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `XLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = XLMRobertaModel(config, add_pooling_layer=False) + + if config.tie_word_embeddings: + input_embeddings = self.roberta.embeddings.word_embeddings.weight + else: + input_embeddings = None + self.lm_head = XLMRobertaLMHead(config, input_embeddings=input_embeddings) + + # Initialize weights and apply final processing + self._post_init(self.__init__) + + def get_output_embeddings(self): + if self.config.tie_word_embeddings: + return None + else: + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + if self.config.tie_word_embeddings: + logger.warning( + "`set_output_embeddings` method is called when `config.tie_word_embeddings=True`. This is not expected. We will do nothing!" + ) + else: + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + past_key_values: Tuple[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[paddle.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(paddle.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from paddlenlp.transformers import AutoTokenizer, XLMRobertaForCausalLM, AutoConfig + >>> import paddle + + >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base") + >>> config = AutoConfig.from_pretrained("roberta-base") + >>> config.is_decoder = True + >>> model = XLMRobertaForCausalLM.from_pretrained("roberta-base", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pd") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :] + labels = labels[:, 1:] + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct( + shifted_prediction_scores.reshape([-1, self.config.vocab_size]), + labels.reshape( + [ + -1, + ] + ), + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = paddle.ones(input_shape, dtype=input_ids.dtype) + + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(axis=0, index=beam_idx) for past_state in layer_past),) + return reordered_past + + +class XLMRobertaForMaskedLM(XLMRobertaPretrainedModel): + _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.roberta = XLMRobertaModel(config, add_pooling_layer=False) + + if config.tie_word_embeddings: + input_embeddings = self.roberta.embeddings.word_embeddings.weight + else: + input_embeddings = None + self.lm_head = XLMRobertaLMHead(config, input_embeddings=input_embeddings) + + # Initialize weights and apply final processing + self._post_init(self.__init__) + + def get_output_embeddings(self): + if self.config.tie_word_embeddings: + return None + else: + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + if self.config.tie_word_embeddings: + logger.warning( + "`set_output_embeddings` method is called when `config.tie_word_embeddings=True`. This is not expected. We will do nothing!" + ) + else: + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + encoder_hidden_states: Optional[paddle.Tensor] = None, + encoder_attention_mask: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[paddle.Tensor], MaskedLMOutput]: + r""" + labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.reshape([-1, self.config.vocab_size]), + labels.reshape( + [ + -1, + ] + ), + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class XLMRobertaLMHead(nn.Layer): + """Roberta Head for masked language modeling.""" + + def __init__(self, config, input_embeddings=None): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps) + + if input_embeddings is None: + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False) + # self.bias = nn.Parameter(paddle.zeros((config.vocab_size,))) + data = paddle.zeros((config.vocab_size,)) + self.bias = paddle.create_parameter( + data.shape, dtype=data.dtype, default_initializer=nn.initializer.Assign(data) + ) + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + else: + # self.bias = nn.Parameter(paddle.zeros((config.vocab_size,))) + data = paddle.zeros((config.vocab_size,)) + self.bias = paddle.create_parameter( + data.shape, dtype=data.dtype, default_initializer=nn.initializer.Assign(data) + ) + decoder_weight = input_embeddings.weight if hasattr(input_embeddings, "weight") else input_embeddings + self.decoder = lambda x: paddle.matmul(x, decoder_weight, transpose_y=True) + self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = nn.functional.gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + +class XLMRobertaForSequenceClassification(XLMRobertaPretrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.roberta = XLMRobertaModel(config, add_pooling_layer=False) + self.classifier = XLMRobertaClassificationHead(config) + + # Initialize weights and apply final processing + self._post_init(self.__init__) + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[paddle.Tensor], SequenceClassifierOutput]: + r""" + labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == paddle.int64 or labels.dtype == paddle.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.reshape([-1, self.num_labels]), labels.reshape((-1,))) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class XLMRobertaForMultipleChoice(XLMRobertaPretrainedModel): + def __init__(self, config): + super().__init__(config) + + self.roberta = XLMRobertaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self._post_init(self.__init__) + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[paddle.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.reshape([-1, input_ids.shape[-1]]) if input_ids is not None else None + flat_position_ids = position_ids.reshape([-1, position_ids.shape[-1]]) if position_ids is not None else None + flat_token_type_ids = ( + token_type_ids.reshape([-1, token_type_ids.shape[-1]]) if token_type_ids is not None else None + ) + flat_attention_mask = ( + attention_mask.reshape([-1, attention_mask.shape[-1]]) if attention_mask is not None else None + ) + flat_inputs_embeds = ( + inputs_embeds.reshape([-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]]) + if inputs_embeds is not None + else None + ) + + outputs = self.roberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.reshape([-1, num_choices]) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class XLMRobertaForTokenClassification(XLMRobertaPretrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = XLMRobertaModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self._post_init(self.__init__) + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[paddle.Tensor], TokenClassifierOutput]: + r""" + labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + loss_fct = CrossEntropyLoss() + loss = loss_fct( + logits.reshape([-1, self.num_labels]), + labels.reshape( + [ + -1, + ] + ), + ) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class XLMRobertaClassificationHead(nn.Layer): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + pooler_act = getattr(config, "pooler_act", "tanh") + self.activation = ACT2FN[pooler_act] + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = self.activation(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class XLMRobertaForQuestionAnswering(XLMRobertaPretrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = XLMRobertaModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self._post_init(self.__init__) + + def forward( + self, + input_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + token_type_ids: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + start_positions: Optional[paddle.Tensor] = None, + end_positions: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[paddle.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.chunk(2, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.shape) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.shape) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.shape[1] + start_positions = start_positions.clip(0, ignored_index) + end_positions = end_positions.clip(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: paddle.Tensor x: + + Returns: paddle.Tensor + """ + mask = (input_ids != padding_idx).cast("int64") + incremental_indices = (paddle.cumsum(mask, axis=1) + past_key_values_length) * mask + return incremental_indices.cast("int64") + padding_idx diff --git a/paddlenlp/transformers/xlm_roberta/tokenizer.py b/paddlenlp/transformers/xlm_roberta/tokenizer.py new file mode 100644 index 000000000000..fcd13236878d --- /dev/null +++ b/paddlenlp/transformers/xlm_roberta/tokenizer.py @@ -0,0 +1,305 @@ +# coding=utf-8 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +""" Tokenization classes for XLM-RoBERTa model.""" + + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from ...utils import logger +from ..tokenizer_utils import AddedToken, PretrainedTokenizer + +SPIECE_UNDERLINE = "▁" + +__all__ = ["XLMRobertaTokenizer"] + + +class XLMRobertaTokenizer(PretrainedTokenizer): + """ + Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on + [SentencePiece](https://github.com/google/sentencepiece). + + This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to + this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + + + When building a sequence using special tokens, this is not the token that is used for the beginning of + sequence. The token used is the `cls_token`. + + + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + + + When building a sequence using special tokens, this is not the token that is used for the end of sequence. + The token used is the `sep_token`. + + + + sep_token (`str`, *optional*, defaults to `""`): + The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for + sequence classification or for a text and a question for question answering. It is also used as the last + token of a sequence built with special tokens. + cls_token (`str`, *optional*, defaults to `""`): + The classifier token which is used when doing sequence classification (classification of the whole sequence + instead of per-token classification). It is the first token of the sequence when built with special tokens. + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + pad_token (`str`, *optional*, defaults to `""`): + The token used for padding, for example when batching sequences of different lengths. + mask_token (`str`, *optional*, defaults to `""`): + The token used for masking values. This is the token used when training this model with masked language + modeling. This is the token which the model will try to predict. + sp_model_kwargs (`dict`, *optional*): + Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for + SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things, + to set: + + - `enable_sampling`: Enable subword regularization. + - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout. + + - `nbest_size = {0,1}`: No sampling is performed. + - `nbest_size > 1`: samples from the nbest_size results. + - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice) + using forward-filtering-and-backward-sampling algorithm. + + - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for + BPE-dropout. + + Attributes: + sp_model (`SentencePieceProcessor`): + The *SentencePiece* processor that is used for every conversion (string, tokens and IDs). + """ + + resource_files_names = {"vocab_file": "sentencepiece.bpe.model"} + pretrained_resource_files_map = { + "vocab_file": { + "BAAI/bge-m3": "https://bj.bcebos.com/paddlenlp/models/community/BAAI/bge-m3/sentencepiece.bpe.model", + } + } + pretrained_init_configuration = {} + max_model_input_sizes = { + "BAAI/bge-m3": 8192, + } + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + bos_token="", + eos_token="", + sep_token="", + cls_token="", + unk_token="", + pad_token="", + mask_token="", + sp_model_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + # Mask token behave like a normal word, i.e. include the space before it + mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token + + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(str(vocab_file)) + self.vocab_file = vocab_file + + # Original fairseq vocab and spm vocab must be "aligned": + # Vocab | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 + # -------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ---- + # fairseq | '' | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' + # spm | '' | '' | '' | ',' | '.' | '▁' | 's' | '▁de' | '-' | '▁a' + + # Mimic fairseq token-to-id alignment for the first 4 token + self.fairseq_tokens_to_ids = {"": 0, "": 1, "": 2, "": 3} + + # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab + self.fairseq_offset = 1 + + self.fairseq_tokens_to_ids[""] = len(self.sp_model) + self.fairseq_offset + self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()} + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + sep_token=sep_token, + cls_token=cls_token, + pad_token=pad_token, + mask_token=mask_token, + sp_model_kwargs=self.sp_model_kwargs, + **kwargs, + ) + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + + # for backward compatibility + if not hasattr(self, "sp_model_kwargs"): + self.sp_model_kwargs = {} + + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. An XLM-RoBERTa sequence has the following format: + + - single sequence: ` X ` + - pair of sequences: ` A B ` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + + if token_ids_1 is None: + return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] + cls = [self.cls_token_id] + sep = [self.sep_token_id] + return cls + token_ids_0 + sep + sep + token_ids_1 + sep + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + if token_ids_1 is None: + return [1] + ([0] * len(token_ids_0)) + [1] + return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does + not make use of token type ids, therefore a list of zeros is returned. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of zeros. + + """ + + sep = [self.sep_token_id] + cls = [self.cls_token_id] + + if token_ids_1 is None: + return len(cls + token_ids_0 + sep) * [0] + return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] + + @property + def vocab_size(self): + return len(self.sp_model) + self.fairseq_offset + 1 # Add the token + + def get_vocab(self): + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text: str) -> List[str]: + # TODO check if the t5/llama PR also applies here + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + if token in self.fairseq_tokens_to_ids: + return self.fairseq_tokens_to_ids[token] + spm_id = self.sp_model.PieceToId(token) + + # Need to return unknown token if the SP model returned 0 + return spm_id + self.fairseq_offset if spm_id else self.unk_token_id + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.fairseq_ids_to_tokens: + return self.fairseq_ids_to_tokens[index] + return self.sp_model.IdToPiece(index - self.fairseq_offset) + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (strings for sub-words) in a single string.""" + out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip() + return out_string + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + self.resource_files_names["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) diff --git a/tests/transformers/xlm_roberta/__init__.py b/tests/transformers/xlm_roberta/__init__.py new file mode 100644 index 000000000000..a9cc79cc9d7f --- /dev/null +++ b/tests/transformers/xlm_roberta/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/transformers/xlm_roberta/test_modeling.py b/tests/transformers/xlm_roberta/test_modeling.py new file mode 100644 index 000000000000..e5bded622082 --- /dev/null +++ b/tests/transformers/xlm_roberta/test_modeling.py @@ -0,0 +1,453 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import tempfile +import unittest + +import numpy as np +import paddle +from parameterized import parameterized_class + +from paddlenlp.transformers import ( + XLMRobertaConfig, + XLMRobertaForCausalLM, + XLMRobertaForMaskedLM, + XLMRobertaForMultipleChoice, + XLMRobertaForQuestionAnswering, + XLMRobertaForSequenceClassification, + XLMRobertaForTokenClassification, + XLMRobertaModel, + XLMRobertaPretrainedModel, +) + +from ...testing_utils import require_package, slow +from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask + + +class XLMRobertaModelTester: + def __init__(self, parent: XLMRobertaModelTest): + self.parent: XLMRobertaModelTest = parent + self.batch_size = 13 + self.seq_length = 7 + self.is_training = True + self.use_input_mask = True + self.use_token_type_ids = True + self.use_labels = True + self.vocab_size = 99 + self.hidden_size = 32 + self.num_hidden_layers = 5 + self.num_attention_heads = 4 + self.intermediate_size = 37 + self.hidden_act = "gelu" + self.hidden_dropout_prob = 0.1 + self.attention_probs_dropout_prob = 0.1 + self.max_position_embeddings = 512 + self.type_vocab_size = 16 + self.type_sequence_label_size = 2 + self.initializer_range = 0.02 + self.layer_norm_eps = 1e-12 + self.pad_token_id = 1 + self.bos_token_id = 0 + self.eos_token_id = 2 + self.position_embedding_type = "absolute" + self.use_cache = True + self.classifier_dropout = None + self.num_labels = 2 + self.num_choices = 4 + self.dropout = 0.56 + self.scope = None + + def prepare_config_and_inputs(self): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.parent.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config() + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + + def get_config(self): + return XLMRobertaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + intermediate_size=self.intermediate_size, + hidden_act=self.hidden_act, + hidden_dropout_prob=self.hidden_dropout_prob, + attention_probs_dropout_prob=self.attention_probs_dropout_prob, + max_position_embeddings=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + initializer_range=self.initializer_range, + layer_norm_eps=self.layer_norm_eps, + pad_token_id=self.pad_token_id, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + position_embedding_type=self.position_embedding_type, + use_cache=self.use_cache, + classifier_dropout=self.classifier_dropout, + num_labels=self.num_labels, + ) + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + return ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) + + def create_and_check_model( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + + model = XLMRobertaModel(config) + model.eval() + + result = model( + input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, return_dict=self.parent.return_dict + ) + result = model(input_ids, token_type_ids=token_type_ids, return_dict=self.parent.return_dict) + result = model(input_ids, return_dict=self.parent.return_dict) + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size]) + self.parent.assertEqual(result[1].shape, [self.batch_size, self.hidden_size]) + + def create_and_check_for_causal_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = XLMRobertaForCausalLM(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + return_dict=self.parent.return_dict, + ) + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + def create_and_check_for_masked_lm( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = XLMRobertaForMaskedLM(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=token_labels, + return_dict=self.parent.return_dict, + ) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size]) + + def create_and_check_for_token_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + + model = XLMRobertaForTokenClassification(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + labels=token_labels, + ) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.num_labels]) + + def create_and_check_for_sequence_classification( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + model = XLMRobertaForSequenceClassification(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + labels=sequence_labels, + return_dict=self.parent.return_dict, + ) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_labels]) + + def create_and_check_for_multiple_choice( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + + model = XLMRobertaForMultipleChoice(config) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand([-1, self.num_choices, -1]) + multiple_choice_input_mask = input_mask.unsqueeze(1).expand([-1, self.num_choices, -1]) + result = model( + multiple_choice_inputs_ids, + attention_mask=multiple_choice_input_mask, + token_type_ids=multiple_choice_token_type_ids, + return_dict=self.parent.return_dict, + labels=choice_labels, + ) + + if token_labels is not None: + result = result[1:] + elif paddle.is_tensor(result): + result = [result] + + self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_choices]) + + def create_and_check_for_question_answering( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ): + + model = XLMRobertaForQuestionAnswering(config) + model.eval() + result = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + return_dict=self.parent.return_dict, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + + if sequence_labels is not None: + start_logits, end_logits = result[1], result[2] + else: + start_logits, end_logits = result[0], result[1] + + self.parent.assertEqual(start_logits.shape, [self.batch_size, self.seq_length]) + self.parent.assertEqual(end_logits.shape, [self.batch_size, self.seq_length]) + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} + return config, inputs_dict + + +@parameterized_class( + ("return_dict", "use_labels"), + [ + [False, False], + [False, True], + [True, False], + [True, True], + ], +) +class XLMRobertaModelTest(ModelTesterMixin, unittest.TestCase): + base_model_class = XLMRobertaModel + use_test_inputs_embeds: bool = False + return_dict: bool = False + use_labels: bool = False + test_tie_weights = True + + all_model_classes = ( + XLMRobertaForQuestionAnswering, + XLMRobertaForTokenClassification, + XLMRobertaForMultipleChoice, + XLMRobertaForSequenceClassification, + XLMRobertaForMaskedLM, + XLMRobertaForCausalLM, + ) + all_generative_model_classes = (XLMRobertaForCausalLM,) + + def setUp(self): + self.model_tester = XLMRobertaModelTester(self) + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_for_masked_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_masked_lm(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs) + + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_token_classification(*config_and_inputs) + + def test_for_multiple_choice(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs) + + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_question_answering(*config_and_inputs) + + @slow + def test_model_from_pretrained(self): + for model_name in list(XLMRobertaPretrainedModel.pretrained_init_configuration.keys())[:1]: + model = XLMRobertaModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +class XLMRobertaCompatibilityTest(unittest.TestCase): + test_model_id = "hf-internal-testing/tiny-random-onnx-xlm-roberta" + + @classmethod + @require_package("transformers", "torch") + def setUpClass(cls) -> None: + from transformers import XLMRobertaModel + + cls.torch_model_path = tempfile.TemporaryDirectory().name + model = XLMRobertaModel.from_pretrained(cls.test_model_id) + model.save_pretrained(cls.torch_model_path) + + @require_package("transformers", "torch") + def test_xlmroberta_model_converter(self): + with tempfile.TemporaryDirectory() as tempdir: + + # 1. create commmon input + input_ids = np.random.randint(100, 200, [1, 20]) + + # 2. forward the paddle model + from paddlenlp.transformers import XLMRobertaModel + + paddle_model = XLMRobertaModel.from_pretrained(self.test_model_id, from_hf_hub=False, cache_dir=tempdir) + paddle_model.eval() + paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0] + + # 3. forward the torch model + import torch + from transformers import XLMRobertaModel + + torch_model = XLMRobertaModel.from_pretrained(self.torch_model_path) + torch_model.eval() + torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0] + + self.assertTrue( + np.allclose( + paddle_logit.detach().cpu().reshape([-1])[:9].numpy(), + torch_logit.detach().cpu().reshape([-1])[:9].numpy(), + rtol=1e-4, + ) + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transformers/xlm_roberta/test_tokenizer.py b/tests/transformers/xlm_roberta/test_tokenizer.py new file mode 100644 index 000000000000..a5dad1977828 --- /dev/null +++ b/tests/transformers/xlm_roberta/test_tokenizer.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from paddlenlp.transformers import XLMRobertaTokenizer + +from ..test_tokenizer_common import TokenizerTesterMixin + +# VOCAB_FILES_NAMES = XLMRobertaTokenizer.resource_files_names + + +class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + test_offsets = False + tokenizer_class = XLMRobertaTokenizer + + # Set up method called before each test + def setUp(self): + super().setUp() + self.vocab_file = "BAAI/bge-m3" + self.special_tokens_map = {"unk_token": ""} + + # Method to get a tokenizer instance with specified keyword arguments + def get_tokenizer(self, **kwargs): + kwargs.update(self.special_tokens_map) + return XLMRobertaTokenizer.from_pretrained(self.vocab_file, **kwargs) + + # Test method to check tokenization + def test_tokenization(self): + tokenizer = self.get_tokenizer() + text = "Hello, how are you?" + tokens = tokenizer.tokenize(text) + self.assertIsInstance(tokens, list) + self.assertGreater(len(tokens), 0) + + # Test method to check conversion of token to ID + def test_token_to_id(self): + tokenizer = self.get_tokenizer() + token = "Hello" + token_id = tokenizer.convert_tokens_to_ids(token) + self.assertIsInstance(token_id, int) + + # Test method to check conversion of ID to token + def test_id_to_token(self): + tokenizer = self.get_tokenizer() + token_id = tokenizer.convert_tokens_to_ids("How") + token = tokenizer.convert_ids_to_tokens(token_id) + self.assertEqual(token, "How") + + # Test method to check special tokens + def test_special_tokens(self): + tokenizer = self.get_tokenizer( + vocab_file=self.vocab_file, cls_token="", sep_token="", pad_token="" + ) + self.assertEqual(tokenizer.cls_token, "") + self.assertEqual(tokenizer.sep_token, "") + self.assertEqual(tokenizer.pad_token, "") + + # Test method to check building inputs with special tokens + def test_build_inputs_with_special_tokens(self): + tokenizer = self.get_tokenizer() + token_ids_0 = tokenizer.convert_tokens_to_ids(["Hello", "world"]) + token_ids_1 = tokenizer.convert_tokens_to_ids(["How", "are", "you"]) + + input_ids = tokenizer.build_inputs_with_special_tokens(token_ids_0, token_ids_1) + self.assertEqual(input_ids[0], tokenizer.cls_token_id) + self.assertEqual(input_ids[-1], tokenizer.sep_token_id) + + +if __name__ == "__main__": + unittest.main()