diff --git a/paddlenlp/transformers/__init__.py b/paddlenlp/transformers/__init__.py
index 285db61fd572..ad411bfd434b 100644
--- a/paddlenlp/transformers/__init__.py
+++ b/paddlenlp/transformers/__init__.py
@@ -209,6 +209,9 @@
from .xlm.modeling import *
from .xlm.tokenizer import *
from .xlm.configuration import *
+from .xlm_roberta.modeling import *
+from .xlm_roberta.tokenizer import *
+from .xlm_roberta.configuration import *
from .gau_alpha.modeling import *
from .gau_alpha.tokenizer import *
from .gau_alpha.configuration import *
diff --git a/paddlenlp/transformers/auto/configuration.py b/paddlenlp/transformers/auto/configuration.py
index f2058a5ec389..b77672d258c0 100644
--- a/paddlenlp/transformers/auto/configuration.py
+++ b/paddlenlp/transformers/auto/configuration.py
@@ -113,6 +113,7 @@
("unimo", "UNIMOConfig"),
("visualglm", "VisualGLMConfig"),
("xlm", "XLMConfig"),
+ ("xlm-roberta", "XLMRobertaConfig"),
("xlnet", "XLNetConfig"),
("yuan", "YuanConfig"),
]
@@ -202,6 +203,7 @@
("unimo", "UNIMO"),
("visualglm", "VisualGLM"),
("xlm", "XLM"),
+ ("xlm-roberta", "XLMRoberta"),
("xlnet", "XLNet"),
("yuan", "Yuan"),
]
diff --git a/paddlenlp/transformers/auto/modeling.py b/paddlenlp/transformers/auto/modeling.py
index 8b94d9f4b53d..938b06f5a5b9 100644
--- a/paddlenlp/transformers/auto/modeling.py
+++ b/paddlenlp/transformers/auto/modeling.py
@@ -94,6 +94,7 @@
("UNIMO", "unimo"),
("XLNet", "xlnet"),
("XLM", "xlm"),
+ ("XLMRoberta", "xlm_roberta"),
("GPT", "gpt"),
("GLM", "glm"),
("MT5", "mt5"),
diff --git a/paddlenlp/transformers/auto/tokenizer.py b/paddlenlp/transformers/auto/tokenizer.py
index a53e36c4935a..7dfc2ede0ada 100644
--- a/paddlenlp/transformers/auto/tokenizer.py
+++ b/paddlenlp/transformers/auto/tokenizer.py
@@ -99,6 +99,7 @@
("squeezebert", "SqueezeBertTokenizer"),
("t5", "T5Tokenizer"),
("xlm", "XLMTokenizer"),
+ ("xlm_roberta", "XLMRobertaTokenizer"),
("xlnet", "XLNetTokenizer"),
("bert_japanese", "BertJapaneseTokenizer"),
("bigbird", "BigBirdTokenizer"),
diff --git a/paddlenlp/transformers/xlm_roberta/__init__.py b/paddlenlp/transformers/xlm_roberta/__init__.py
new file mode 100644
index 000000000000..4c08fc6b9a63
--- /dev/null
+++ b/paddlenlp/transformers/xlm_roberta/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration import *
+from .modeling import *
+from .tokenizer import *
diff --git a/paddlenlp/transformers/xlm_roberta/configuration.py b/paddlenlp/transformers/xlm_roberta/configuration.py
new file mode 100644
index 000000000000..dcbf46079bac
--- /dev/null
+++ b/paddlenlp/transformers/xlm_roberta/configuration.py
@@ -0,0 +1,160 @@
+# coding=utf-8
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" XLM-RoBERTa configuration"""
+
+from ..configuration_utils import PretrainedConfig
+
+__all__ = ["PRETRAINED_INIT_CONFIGURATION", "XLMRobertaConfig"]
+
+PRETRAINED_INIT_CONFIGURATION = {
+ "hf-internal-testing/tiny-random-onnx-xlm-roberta": {
+ "attention_probs_dropout_prob": 0.1,
+ "bos_token_id": 0,
+ "classifier_dropout": None,
+ "eos_token_id": 2,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 4,
+ "initializer_range": 0.02,
+ "intermediate_size": 37,
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 514,
+ "model_type": "xlm-roberta",
+ "num_attention_heads": 4,
+ "num_hidden_layers": 5,
+ "output_past": True,
+ "pad_token_id": 1,
+ "position_embedding_type": "absolute",
+ "dtype": "float32",
+ "type_vocab_size": 1,
+ "use_cache": True,
+ "vocab_size": 250002,
+ },
+}
+
+
+class XLMRobertaConfig(PretrainedConfig):
+ r"""
+    This is the configuration class to store the configuration of a [`XLMRobertaModel`]. It is used to
+    instantiate an XLM-RoBERTa model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of the XLM-RoBERTa
+    [xlm-roberta-base](https://huggingface.co/xlm-roberta-base) architecture.
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+
+
+ Args:
+        vocab_size (`int`, *optional*, defaults to 30522):
+            Vocabulary size of the XLM-RoBERTa model. Defines the number of different tokens that can be
+            represented by the `inputs_ids` passed when calling [`XLMRobertaModel`].
+ hidden_size (`int`, *optional*, defaults to 768):
+ Dimensionality of the encoder layers and the pooler layer.
+ num_hidden_layers (`int`, *optional*, defaults to 12):
+ Number of hidden layers in the Transformer encoder.
+ num_attention_heads (`int`, *optional*, defaults to 12):
+ Number of attention heads for each attention layer in the Transformer encoder.
+ intermediate_size (`int`, *optional*, defaults to 3072):
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
+ The dropout ratio for the attention probabilities.
+ max_position_embeddings (`int`, *optional*, defaults to 512):
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
+ just in case (e.g., 512 or 1024 or 2048).
+        type_vocab_size (`int`, *optional*, defaults to 2):
+            The vocabulary size of the `token_type_ids` passed when calling [`XLMRobertaModel`]. XLM-RoBERTa
+            checkpoints typically use a single token type (see `type_vocab_size: 1` above).
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
+ The epsilon used by the layer normalization layers.
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
+ is_decoder (`bool`, *optional*, defaults to `False`):
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ classifier_dropout (`float`, *optional*):
+ The dropout ratio for the classification head.
+
+ Examples:
+
+ ```python
+ >>> from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaModel
+
+ >>> # Initializing a XLM-RoBERTa xlm-roberta-base style configuration
+ >>> configuration = XLMRobertaConfig()
+
+ >>> # Initializing a model (with random weights) from the xlm-roberta-base style configuration
+ >>> model = XLMRobertaModel(configuration)
+
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ ```"""
+
+ model_type = "xlm-roberta"
+
+ def __init__(
+ self,
+ vocab_size=30522,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ position_embedding_type="absolute",
+ use_cache=True,
+ classifier_dropout=None,
+ **kwargs,
+ ):
+ kwargs["return_dict"] = kwargs.pop("return_dict", False)
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ self.position_embedding_type = position_embedding_type
+ self.use_cache = use_cache
+ self.classifier_dropout = classifier_dropout
diff --git a/paddlenlp/transformers/xlm_roberta/modeling.py b/paddlenlp/transformers/xlm_roberta/modeling.py
new file mode 100644
index 000000000000..31feb37785ef
--- /dev/null
+++ b/paddlenlp/transformers/xlm_roberta/modeling.py
@@ -0,0 +1,1618 @@
+# coding=utf-8
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Paddle XLM-RoBERTa model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import paddle
+from paddle import nn
+from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...utils import logger
+from ...utils.converter import StateDictNameMapping
+from ..activations import ACT2FN
+from ..model_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ..model_utils import (
+ PretrainedModel,
+ apply_chunking_to_forward,
+ register_base_model,
+)
+from .configuration import PRETRAINED_INIT_CONFIGURATION, XLMRobertaConfig
+
+__all__ = [
+ "XLMRobertaModel",
+ "XLMRobertaPretrainedModel",
+ "XLMRobertaForSequenceClassification",
+ "XLMRobertaForTokenClassification",
+ "XLMRobertaForQuestionAnswering",
+ "XLMRobertaForMaskedLM",
+ "XLMRobertaForMultipleChoice",
+ "XLMRobertaForCausalLM",
+]
+
+
+class XLMRobertaEmbeddings(nn.Layer):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
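+
+    Position ids follow the RoBERTa convention: non-padding tokens are numbered starting from
+    `padding_idx + 1` and padding positions keep `padding_idx`. For example (assuming the usual
+    `create_position_ids_from_input_ids` helper used in `forward` below), with `pad_token_id=1`
+    the input ids `[0, 5, 6, 1]` get position ids `[2, 3, 4, 1]`.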
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(
+ config.vocab_size, config.hidden_size
+        )  # NOTE: do not pass padding_idx=config.pad_token_id to nn.Embedding here; it is assigned below instead
+ self.word_embeddings.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer(
+ "position_ids",
+ paddle.arange(config.max_position_embeddings, dtype=paddle.int64).expand((1, -1)),
+ persistable=False,
+ )
+ self.register_buffer(
+ "token_type_ids", paddle.zeros(self.position_ids.shape, dtype=paddle.int64), persistable=False
+ )
+
+ # End copy
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings,
+ config.hidden_size, # padding_idx=self.padding_idx
+ )
+ self.position_embeddings.padding_idx = config.pad_token_id
+
+ def forward(
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.shape
+ else:
+ input_shape = inputs_embeds.shape[:-1]
+
+ seq_length = input_shape[1]
+
+        # Set token_type_ids to the registered buffer (all zeros) from the constructor, which usually happens when
+        # it is auto-generated. The registered buffer helps users trace the model without passing token_type_ids
+        # and solves issue #5664.
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand([input_shape[0], seq_length])
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = paddle.zeros(input_shape, dtype=paddle.int64)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds: paddle.Tensor
+
+ Returns: paddle.Tensor
+ """
+ input_shape = inputs_embeds.shape[:-1]
+ sequence_length = input_shape[1]
+
+ position_ids = paddle.arange(
+ self.padding_idx + 1,
+ sequence_length + self.padding_idx + 1,
+ dtype=paddle.int64,
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+
+class XLMRobertaSelfAttention(nn.Layer):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+ self.scale = math.sqrt(self.attention_head_size)
+
+ def transpose_for_scores(self, x: paddle.Tensor) -> paddle.Tensor:
+ new_x_shape = x.shape[:-1] + [self.num_attention_heads, self.attention_head_size]
+ x = x.reshape(new_x_shape)
+ return x.transpose([0, 2, 1, 3])
+
+ def forward(
+ self,
+ hidden_states: paddle.Tensor,
+ attention_mask: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[paddle.Tensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[paddle.Tensor]:
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = paddle.concat([past_key_value[0], key_layer], axis=2)
+ value_layer = paddle.concat([past_key_value[1], value_layer], axis=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ use_cache = past_key_value is not None
+ if self.is_decoder:
+ # if cross_attention save Tuple(paddle.Tensor, paddle.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(paddle.Tensor, paddle.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = paddle.matmul(query_layer, key_layer, transpose_y=True)
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+ if use_cache:
+ position_ids_l = paddle.to_tensor(
+ key_length - 1,
+ dtype=paddle.int64,
+ ).reshape([-1, 1])
+ else:
+ position_ids_l = paddle.arange(
+ query_length,
+ dtype=paddle.int64,
+ ).reshape([-1, 1])
+ position_ids_r = paddle.arange(
+ key_length,
+ dtype=paddle.int64,
+ ).reshape([1, -1])
+ distance = position_ids_l - position_ids_r
+
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.cast(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = paddle.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = paddle.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = paddle.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / self.scale
+ if attention_mask is not None:
+            # Apply the attention mask (precomputed for all layers in the XLMRobertaModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, axis=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = paddle.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.transpose([0, 2, 1, 3])
+ new_context_layer_shape = context_layer.shape[:-2] + [
+ self.all_head_size,
+ ]
+ context_layer = context_layer.reshape(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+class XLMRobertaSelfOutput(nn.Layer):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class XLMRobertaAttention(nn.Layer):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ self.self = XLMRobertaSelfAttention(config, position_embedding_type=position_embedding_type)
+ self.output = XLMRobertaSelfOutput(config)
+
+ def forward(
+ self,
+ hidden_states: paddle.Tensor,
+ attention_mask: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[paddle.Tensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[paddle.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+class XLMRobertaIntermediate(nn.Layer):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class XLMRobertaOutput(nn.Layer):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: paddle.Tensor, input_tensor: paddle.Tensor) -> paddle.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class XLMRobertaLayer(nn.Layer):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = XLMRobertaAttention(config)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = XLMRobertaAttention(config, position_embedding_type="absolute")
+ self.intermediate = XLMRobertaIntermediate(config)
+ self.output = XLMRobertaOutput(config)
+
+ def forward(
+ self,
+ hidden_states: paddle.Tensor,
+ attention_mask: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ past_key_value: Optional[Tuple[Tuple[paddle.Tensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[paddle.Tensor]:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ cross_attn_past_key_value,
+ output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class XLMRobertaEncoder(nn.Layer):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.LayerList([XLMRobertaLayer(config) for _ in range(config.num_hidden_layers)])
+ self.enable_recompute = False
+
+ def forward(
+ self,
+ hidden_states: paddle.Tensor,
+ attention_mask: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ past_key_values: Optional[Tuple[Tuple[paddle.Tensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple[paddle.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ if self.enable_recompute and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.enable_recompute and not hidden_states.stop_gradient:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer_module.__call__,
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class XLMRobertaPooler(nn.Layer):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ pooler_act = getattr(config, "pooler_act", "tanh")
+ self.activation = ACT2FN[pooler_act]
+
+ def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class XLMRobertaPretrainedModel(PretrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = XLMRobertaConfig
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
+ pretrained_resource_files_map = {
+ "model_state": {
+ "hf-internal-testing/tiny-random-onnx-xlm-roberta": "https://bj.bcebos.com/paddlenlp/models/community/hf-internal-testing/tiny-random-onnx-xlm-roberta/model.safetensors",
+ }
+ }
+ base_model_prefix = "roberta"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["XLMRobertaEmbeddings", "XLMRobertaSelfAttention"]
+
+ def can_generate(self) -> bool:
+ return False
+
+ @classmethod
+ def _get_name_mappings(cls, config):
+ architectures = config.architectures + [cls.__name__]
+ mappings = []
+ model_mappings = [
+ ["embeddings.word_embeddings.weight", "embeddings.word_embeddings.weight"],
+ ["embeddings.position_ids", "embeddings.position_ids"],
+ ["embeddings.position_embeddings.weight", "embeddings.position_embeddings.weight"],
+ ["embeddings.token_type_embeddings.weight", "embeddings.token_type_embeddings.weight"],
+ ["embeddings.LayerNorm.weight", "embeddings.LayerNorm.weight"],
+ ["embeddings.LayerNorm.bias", "embeddings.LayerNorm.bias"],
+ ["pooler.dense.weight", "pooler.dense.weight", "transpose"],
+ ["pooler.dense.bias", "pooler.dense.bias"],
+ # for TokenClassification
+ ]
+ for layer_index in range(config.num_hidden_layers):
+ for name in [
+ "attention.self.query",
+ "attention.self.key",
+ "attention.self.value",
+ "attention.output.dense",
+ "attention.output.LayerNorm",
+ "intermediate.dense",
+ "output.dense",
+ "output.LayerNorm",
+ ]:
+ action = None if "LayerNorm" in name else "transpose"
+ model_mappings.extend(
+ [
+ [
+ f"encoder.layer.{layer_index}.{name}.weight",
+ f"encoder.layer.{layer_index}.{name}.weight",
+ action,
+ ],
+ [
+ f"encoder.layer.{layer_index}.{name}.bias",
+ f"encoder.layer.{layer_index}.{name}.bias",
+ ],
+ ]
+ )
+
+        # base-model prefix: add "roberta." when the source checkpoint / target class is a downstream model rather than XLMRobertaModel
+ torch_prefix = ""
+ paddle_prefix = ""
+ if "XLMRobertaModel" not in config.architectures:
+ torch_prefix = "roberta."
+ if "XLMRobertaModel" not in [cls.__name__]:
+ paddle_prefix = "roberta."
+
+ # add prefix
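+        # e.g. for a downstream class such as XLMRobertaForSequenceClassification both prefixes are
+        # "roberta.", so a pair becomes
+        # ["roberta.encoder.layer.0.attention.self.query.weight",
+        #  "roberta.encoder.layer.0.attention.self.query.weight", "transpose"]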
+ for mapping in model_mappings:
+ mapping[0] = torch_prefix + mapping[0]
+ mapping[1] = paddle_prefix + mapping[1]
+
+ if "XLMRobertaForCausalLM" in architectures:
+ model_mappings.extend(
+ [
+ ["lm_head.dense.weight", "lm_head.dense.weight", "transpose"],
+ ["lm_head.dense.bias", "lm_head.dense.bias"],
+ ["lm_head.layer_norm.weight", "lm_head.layer_norm.weight"],
+ ["lm_head.layer_norm.bias", "lm_head.layer_norm.bias"],
+ ["lm_head.bias", "lm_head.bias"],
+ ]
+ )
+
+ # downstream mappings
+ if "XLMRobertaForQuestionAnswering" in architectures:
+ model_mappings.extend(
+ [
+ ["qa_outputs.weight", "qa_outputs.weight", "transpose"],
+ ["qa_outputs.bias", "qa_outputs.bias"],
+ ]
+ )
+ if "XLMRobertaForSequenceClassification" in architectures:
+ model_mappings.extend(
+ [
+ ["classifier.dense.weight", "classifier.dense.weight", "transpose"],
+ ["classifier.dense.bias", "classifier.dense.bias"],
+ ["classifier.out_proj.weight", "classifier.out_proj.weight", "transpose"],
+ ["classifier.out_proj.bias", "classifier.out_proj.bias"],
+ ]
+ )
+
+ if "XLMRobertaForMultipleChoice" in architectures or "XLMRobertaForTokenClassification" in architectures:
+ model_mappings.extend(
+ [
+ ["classifier.weight", "classifier.weight", "transpose"],
+ ["classifier.bias", "classifier.bias"],
+ ]
+ )
+
+ mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)]
+ return mappings
+
+ @paddle.no_grad()
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if hasattr(module, "padding_idx") and module.padding_idx is not None:
+ module.weight[module.padding_idx] = 0
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.zero_()
+ module.weight.fill_(1.0)
+
+
+@register_base_model
+class XLMRobertaModel(XLMRobertaPretrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
+ Kaiser and Illia Polosukhin.
+
+    To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+    to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+
+ .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
+
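+    Example (a minimal sketch with a small, randomly initialized model; the sizes below are
+    illustrative only and do not correspond to a released checkpoint):
+
+    ```python
+    >>> import paddle
+    >>> from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaModel
+
+    >>> config = XLMRobertaConfig(
+    ...     vocab_size=1000, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
+    ... )
+    >>> model = XLMRobertaModel(config)
+    >>> input_ids = paddle.randint(low=3, high=1000, shape=[1, 8])
+    >>> outputs = model(input_ids=input_ids, return_dict=True)
+    >>> outputs.last_hidden_state.shape  # [1, 8, 32]
+    ```
+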
+ """
+
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = XLMRobertaEmbeddings(config)
+ self.encoder = XLMRobertaEncoder(config)
+
+ self.pooler = XLMRobertaPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self._post_init(self.__init__)
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ @property
+ def dtype(self) -> paddle.dtype:
+ """
+ `paddle.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ try:
+ return next(self.named_parameters())[1].dtype
+ except StopIteration:
+ try:
+ return next(self.named_buffers())[1].dtype
+ except StopIteration:
+ return self._dtype
+
+ def invert_attention_mask(self, encoder_attention_mask: paddle.Tensor) -> paddle.Tensor:
+ """
+ Invert an attention mask (e.g., switches 0. and 1.).
+
+ Args:
+ encoder_attention_mask (`paddle.Tensor`): An attention mask.
+
+ Returns:
+ `paddle.Tensor`: The inverted attention mask.
+ """
+ if encoder_attention_mask.ndim == 3:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
+ if encoder_attention_mask.ndim == 2:
+ encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
+ # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition
+ # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow
+ # /transformer/transformer_layers.py#L270
+ # encoder_extended_attention_mask = (encoder_extended_attention_mask ==
+ # encoder_extended_attention_mask.transpose(-1, -2))
+ encoder_extended_attention_mask = encoder_extended_attention_mask.cast(dtype=self.dtype) # fp16 compatibility
+ encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * paddle.finfo(self.dtype).min
+
+ return encoder_extended_attention_mask
+
+ @staticmethod
+ def create_extended_attention_mask_for_decoder(input_shape, attention_mask):
+ batch_size, seq_length = input_shape
+ seq_ids = paddle.arange(seq_length)
+ causal_mask = seq_ids[None, None, :].tile([batch_size, seq_length, 1]) <= seq_ids[None, :, None]
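+        # e.g. for seq_length == 3 the causal mask for each batch item is the lower-triangular pattern
+        # [[1, 0, 0],
+        #  [1, 1, 0],
+        #  [1, 1, 1]] (as booleans before the cast below)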
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
+ # causal and attention masks must have same type with pytorch version < 1.3
+ causal_mask = causal_mask.cast(dtype=attention_mask.dtype)
+
+ if causal_mask.shape[1] < attention_mask.shape[1]:
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
+ causal_mask = paddle.concat(
+ [
+ paddle.ones((batch_size, seq_length, prefix_seq_len), dtype=causal_mask.dtype),
+ causal_mask,
+ ],
+ axis=-1,
+ )
+
+ extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
+ return extended_attention_mask
+
+ def get_extended_attention_mask(
+ self, attention_mask: paddle.Tensor, input_shape: Tuple[int], dtype: paddle.dtype = None
+ ) -> paddle.Tensor:
+ """
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
+
+ Arguments:
+ attention_mask (`paddle.Tensor`):
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
+ input_shape (`Tuple[int]`):
+ The shape of the input to the model.
+
+ Returns:
+            `paddle.Tensor`: The extended attention mask, with the same dtype as `attention_mask.dtype`.
+ """
+ if dtype is None:
+ dtype = self.dtype
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ if attention_mask.ndim == 3:
+ extended_attention_mask = attention_mask[:, None, :, :]
+ elif attention_mask.ndim == 2:
+ # Provided a padding mask of dimensions [batch_size, seq_length]
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder:
+ extended_attention_mask = XLMRobertaModel.create_extended_attention_mask_for_decoder(
+ input_shape, attention_mask
+ )
+ else:
+ extended_attention_mask = attention_mask[:, None, None, :]
+ else:
+ raise ValueError(
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
+ )
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
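+        # For example, a padding mask [[1, 1, 0]] (last token padded) becomes the additive bias
+        # [[[[0.0, 0.0, dtype_min]]]], which effectively removes attention to the padded position after the softmax.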
+ extended_attention_mask = extended_attention_mask.cast(dtype=dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * paddle.finfo(dtype).min
+ return extended_attention_mask
+
+ def forward(
+ self,
+ input_ids: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ token_type_ids: Optional[paddle.Tensor] = None,
+ position_ids: Optional[paddle.Tensor] = None,
+ inputs_embeds: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ past_key_values: Optional[List[paddle.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[paddle.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(paddle.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.shape
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.shape[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+            attention_mask = paddle.ones((batch_size, seq_length + past_key_values_length))
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand([batch_size, seq_length])
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = paddle.zeros(
+ input_shape,
+ dtype=paddle.int64,
+ )
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: paddle.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.shape
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = paddle.ones(encoder_hidden_shape)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+class XLMRobertaForCausalLM(XLMRobertaPretrainedModel):
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if not config.is_decoder:
+ logger.warning("If you want to use `XLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
+
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+
+ if config.tie_word_embeddings:
+ input_embeddings = self.roberta.embeddings.word_embeddings.weight
+ else:
+ input_embeddings = None
+ self.lm_head = XLMRobertaLMHead(config, input_embeddings=input_embeddings)
+
+ # Initialize weights and apply final processing
+ self._post_init(self.__init__)
+
+ def get_output_embeddings(self):
+ if self.config.tie_word_embeddings:
+ return None
+ else:
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ if self.config.tie_word_embeddings:
+ logger.warning(
+ "`set_output_embeddings` method is called when `config.tie_word_embeddings=True`. This is not expected. We will do nothing!"
+ )
+ else:
+ self.lm_head.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ token_type_ids: Optional[paddle.Tensor] = None,
+ position_ids: Optional[paddle.Tensor] = None,
+ inputs_embeds: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ labels: Optional[paddle.Tensor] = None,
+ past_key_values: Tuple[Tuple[paddle.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[paddle.Tensor], CausalLMOutputWithCrossAttentions]:
+ r"""
+ encoder_hidden_states (`paddle.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ past_key_values (`tuple(tuple(paddle.Tensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from paddlenlp.transformers import AutoTokenizer, XLMRobertaForCausalLM, AutoConfig
+ >>> import paddle
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+ >>> config = AutoConfig.from_pretrained("roberta-base")
+ >>> config.is_decoder = True
+ >>> model = XLMRobertaForCausalLM.from_pretrained("roberta-base", config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pd")
+        >>> outputs = model(**inputs, return_dict=True)
+
+ >>> prediction_logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :]
+ labels = labels[:, 1:]
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(
+ shifted_prediction_scores.reshape([-1, self.config.vocab_size]),
+ labels.reshape(
+ [
+ -1,
+ ]
+ ),
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = paddle.ones(input_shape, dtype=input_ids.dtype)
+
+ # cut decoder_input_ids if past_key_values is used
+ if past_key_values is not None:
+ past_length = past_key_values[0][0].shape[2]
+
+ # Some generation methods already pass only the last input ID
+ if input_ids.shape[1] > past_length:
+ remove_prefix_length = past_length
+ else:
+ # Default to old behavior: keep only final ID
+ remove_prefix_length = input_ids.shape[1] - 1
+
+ input_ids = input_ids[:, remove_prefix_length:]
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
+
+ def _reorder_cache(self, past_key_values, beam_idx):
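+        # `past_key_values` holds one tuple per layer of cached (self-attention and, if present,
+        # cross-attention) key/value tensors with batch as the first axis; selecting along axis 0
+        # keeps the cache entries of the beams chosen at this step.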
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(axis=0, index=beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+class XLMRobertaForMaskedLM(XLMRobertaPretrainedModel):
+ _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `XLMRobertaForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+
+ if config.tie_word_embeddings:
+ input_embeddings = self.roberta.embeddings.word_embeddings.weight
+ else:
+ input_embeddings = None
+ self.lm_head = XLMRobertaLMHead(config, input_embeddings=input_embeddings)
+
+ # Initialize weights and apply final processing
+ self._post_init(self.__init__)
+
+ def get_output_embeddings(self):
+ if self.config.tie_word_embeddings:
+ return None
+ else:
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ if self.config.tie_word_embeddings:
+ logger.warning(
+ "`set_output_embeddings` method is called when `config.tie_word_embeddings=True`. This is not expected. We will do nothing!"
+ )
+ else:
+ self.lm_head.decoder = new_embeddings
+
+ def forward(
+ self,
+ input_ids: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ token_type_ids: Optional[paddle.Tensor] = None,
+ position_ids: Optional[paddle.Tensor] = None,
+ inputs_embeds: Optional[paddle.Tensor] = None,
+ encoder_hidden_states: Optional[paddle.Tensor] = None,
+ encoder_attention_mask: Optional[paddle.Tensor] = None,
+ labels: Optional[paddle.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[paddle.Tensor], MaskedLMOutput]:
+ r"""
+ labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
+ Used to hide legacy arguments that have been deprecated.
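+
+        Example (a minimal sketch with a tiny, randomly initialized model; the sizes are illustrative and,
+        for simplicity, the loss is computed over every position by reusing `input_ids` as `labels`):
+
+        ```python
+        >>> import paddle
+        >>> from paddlenlp.transformers import XLMRobertaConfig, XLMRobertaForMaskedLM
+
+        >>> config = XLMRobertaConfig(
+        ...     vocab_size=1000, hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
+        ... )
+        >>> model = XLMRobertaForMaskedLM(config)
+        >>> input_ids = paddle.randint(low=3, high=1000, shape=[2, 8])
+        >>> outputs = model(input_ids=input_ids, labels=input_ids, return_dict=True)
+        >>> loss, logits = outputs.loss, outputs.logits  # logits shape: [2, 8, 1000]
+        ```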
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(
+ prediction_scores.reshape([-1, self.config.vocab_size]),
+ labels.reshape(
+ [
+ -1,
+ ]
+ ),
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class XLMRobertaLMHead(nn.Layer):
+ """Roberta Head for masked language modeling."""
+
+ def __init__(self, config, input_embeddings=None):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, epsilon=config.layer_norm_eps)
+
+ if input_embeddings is None:
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias_attr=False)
+ # self.bias = nn.Parameter(paddle.zeros((config.vocab_size,)))
+ data = paddle.zeros((config.vocab_size,))
+ self.bias = paddle.create_parameter(
+ data.shape, dtype=data.dtype, default_initializer=nn.initializer.Assign(data)
+ )
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+ else:
+ # self.bias = nn.Parameter(paddle.zeros((config.vocab_size,)))
+ data = paddle.zeros((config.vocab_size,))
+ self.bias = paddle.create_parameter(
+ data.shape, dtype=data.dtype, default_initializer=nn.initializer.Assign(data)
+ )
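+ # Tied-embedding case: reuse the word embedding matrix as the output projection,
+ # i.e. logits = hidden_states @ embedding_weight^T + bias, so no separate decoder Linear is created.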
+ decoder_weight = input_embeddings.weight if hasattr(input_embeddings, "weight") else input_embeddings
+ self.decoder = lambda x: paddle.matmul(x, decoder_weight, transpose_y=True) + self.bias
+
+ def forward(self, features, **kwargs):
+ x = self.dense(features)
+ x = nn.functional.gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ x = self.decoder(x)
+
+ return x
+
+
+class XLMRobertaForSequenceClassification(XLMRobertaPretrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+ self.classifier = XLMRobertaClassificationHead(config)
+
+ # Initialize weights and apply final processing
+ self._post_init(self.__init__)
+
+ def forward(
+ self,
+ input_ids: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ token_type_ids: Optional[paddle.Tensor] = None,
+ position_ids: Optional[paddle.Tensor] = None,
+ inputs_embeds: Optional[paddle.Tensor] = None,
+ labels: Optional[paddle.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[paddle.Tensor], SequenceClassifierOutput]:
+ r"""
+ labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == paddle.int64 or labels.dtype == paddle.int32):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.reshape([-1, self.num_labels]), labels.reshape((-1,)))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class XLMRobertaForMultipleChoice(XLMRobertaPretrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.roberta = XLMRobertaModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self._post_init(self.__init__)
+
+ def forward(
+ self,
+ input_ids: Optional[paddle.Tensor] = None,
+ token_type_ids: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ labels: Optional[paddle.Tensor] = None,
+ position_ids: Optional[paddle.Tensor] = None,
+ inputs_embeds: Optional[paddle.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[paddle.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+ flat_input_ids = input_ids.reshape([-1, input_ids.shape[-1]]) if input_ids is not None else None
+ flat_position_ids = position_ids.reshape([-1, position_ids.shape[-1]]) if position_ids is not None else None
+ flat_token_type_ids = (
+ token_type_ids.reshape([-1, token_type_ids.shape[-1]]) if token_type_ids is not None else None
+ )
+ flat_attention_mask = (
+ attention_mask.reshape([-1, attention_mask.shape[-1]]) if attention_mask is not None else None
+ )
+ flat_inputs_embeds = (
+ inputs_embeds.reshape([-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]])
+ if inputs_embeds is not None
+ else None
+ )
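+ # The choice dimension is folded into the batch: [batch_size, num_choices, seq_len] inputs become
+ # [batch_size * num_choices, seq_len], and the per-choice logits are reshaped back to
+ # [batch_size, num_choices] below.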
+
+ outputs = self.roberta(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.reshape([-1, num_choices])
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class XLMRobertaForTokenClassification(XLMRobertaPretrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self._post_init(self.__init__)
+
+ def forward(
+ self,
+ input_ids: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ token_type_ids: Optional[paddle.Tensor] = None,
+ position_ids: Optional[paddle.Tensor] = None,
+ inputs_embeds: Optional[paddle.Tensor] = None,
+ labels: Optional[paddle.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[paddle.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(
+ logits.reshape([-1, self.num_labels]),
+ labels.reshape(
+ [
+ -1,
+ ]
+ ),
+ )
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class XLMRobertaClassificationHead(nn.Layer):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ pooler_act = getattr(config, "pooler_act", "tanh")
+ self.activation = ACT2FN[pooler_act]
+
+ def forward(self, features, **kwargs):
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = self.activation(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+class XLMRobertaForQuestionAnswering(XLMRobertaPretrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self._post_init(self.__init__)
+
+ def forward(
+ self,
+ input_ids: Optional[paddle.Tensor] = None,
+ attention_mask: Optional[paddle.Tensor] = None,
+ token_type_ids: Optional[paddle.Tensor] = None,
+ position_ids: Optional[paddle.Tensor] = None,
+ inputs_embeds: Optional[paddle.Tensor] = None,
+ start_positions: Optional[paddle.Tensor] = None,
+ end_positions: Optional[paddle.Tensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[paddle.Tensor], QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`paddle.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`paddle.Tensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.chunk(2, axis=-1)
+ start_logits = start_logits.squeeze(-1)
+ end_logits = end_logits.squeeze(-1)
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.shape) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.shape) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.shape[1]
+ start_positions = start_positions.clip(0, ignored_index)
+ end_positions = end_positions.clip(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's `utils.make_positions`.
+
+ Args:
+ input_ids (paddle.Tensor): Tensor of input token ids.
+ padding_idx (int): Id of the padding token.
+ past_key_values_length (int, optional): Length of the cached past key values, defaults to 0.
+
+ Returns: paddle.Tensor
+ """
+ mask = (input_ids != padding_idx).cast("int64")
+ incremental_indices = (paddle.cumsum(mask, axis=1) + past_key_values_length) * mask
+ return incremental_indices.cast("int64") + padding_idx
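+# Illustrative example (assuming padding_idx == 1):
+#   input_ids    = [[0, 3104, 276, 2, 1, 1]]
+#   position_ids = [[2, 3,    4,   5, 1, 1]]
+# Non-padding tokens are numbered from padding_idx + 1; padding positions keep padding_idx.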
diff --git a/paddlenlp/transformers/xlm_roberta/tokenizer.py b/paddlenlp/transformers/xlm_roberta/tokenizer.py
new file mode 100644
index 000000000000..fcd13236878d
--- /dev/null
+++ b/paddlenlp/transformers/xlm_roberta/tokenizer.py
@@ -0,0 +1,305 @@
+# coding=utf-8
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+""" Tokenization classes for XLM-RoBERTa model."""
+
+
+import os
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple
+
+import sentencepiece as spm
+
+from ...utils import logger
+from ..tokenizer_utils import AddedToken, PretrainedTokenizer
+
+SPIECE_UNDERLINE = "▁"
+
+__all__ = ["XLMRobertaTokenizer"]
+
+
+class XLMRobertaTokenizer(PretrainedTokenizer):
+ """
+ Adapted from [`RobertaTokenizer`] and [`XLNetTokenizer`]. Based on
+ [SentencePiece](https://github.com/google/sentencepiece).
+
+ This tokenizer inherits from [`PretrainedTokenizer`] which contains most of the main methods. Users should refer to
+ this superclass for more information regarding those methods.
+
+ Args:
+ vocab_file (`str`):
+ Path to the vocabulary file.
+ bos_token (`str`, *optional*, defaults to `"<s>"`):
+ The beginning of sequence token that was used during pretraining. Can be used as a sequence classifier token.
+
+ When building a sequence using special tokens, this is not the token that is used for the beginning of
+ sequence. The token used is the `cls_token`.
+
+ eos_token (`str`, *optional*, defaults to `"</s>"`):
+ The end of sequence token.
+
+ When building a sequence using special tokens, this is not the token that is used for the end of sequence.
+ The token used is the `sep_token`.
+
+ sep_token (`str`, *optional*, defaults to `"</s>"`):
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
+ sequence classification or for a text and a question for question answering. It is also used as the last
+ token of a sequence built with special tokens.
+ cls_token (`str`, *optional*, defaults to `"<s>"`):
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
+ unk_token (`str`, *optional*, defaults to `"<unk>"`):
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
+ token instead.
+ pad_token (`str`, *optional*, defaults to `"<pad>"`):
+ The token used for padding, for example when batching sequences of different lengths.
+ mask_token (`str`, *optional*, defaults to `"<mask>"`):
+ The token used for masking values. This is the token used when training this model with masked language
+ modeling. This is the token which the model will try to predict.
+ sp_model_kwargs (`dict`, *optional*):
+ Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
+ SentencePiece](https://github.com/google/sentencepiece/tree/master/python) can be used, among other things,
+ to set:
+
+ - `enable_sampling`: Enable subword regularization.
+ - `nbest_size`: Sampling parameters for unigram. Invalid for BPE-Dropout.
+
+ - `nbest_size = {0,1}`: No sampling is performed.
+ - `nbest_size > 1`: samples from the nbest_size results.
+ - `nbest_size < 0`: assuming that nbest_size is infinite and samples from the all hypothesis (lattice)
+ using forward-filtering-and-backward-sampling algorithm.
+
+ - `alpha`: Smoothing parameter for unigram sampling, and dropout probability of merge operations for
+ BPE-dropout.
+
+ Attributes:
+ sp_model (`SentencePieceProcessor`):
+ The *SentencePiece* processor that is used for every conversion (string, tokens and IDs).
+ """
+
+ resource_files_names = {"vocab_file": "sentencepiece.bpe.model"}
+ pretrained_resource_files_map = {
+ "vocab_file": {
+ "BAAI/bge-m3": "https://bj.bcebos.com/paddlenlp/models/community/BAAI/bge-m3/sentencepiece.bpe.model",
+ }
+ }
+ pretrained_init_configuration = {}
+ max_model_input_sizes = {
+ "BAAI/bge-m3": 8192,
+ }
+ model_input_names = ["input_ids", "attention_mask"]
+
+ def __init__(
+ self,
+ vocab_file,
+ bos_token="",
+ eos_token="",
+ sep_token="",
+ cls_token="",
+ unk_token="",
+ pad_token="",
+ mask_token="",
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ # Mask token behave like a normal word, i.e. include the space before it
+ mask_token = AddedToken(mask_token, lstrip=True, special=True) if isinstance(mask_token, str) else mask_token
+
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.Load(str(vocab_file))
+ self.vocab_file = vocab_file
+
+ # Original fairseq vocab and spm vocab must be "aligned":
+ # Vocab   | 0       | 1       | 2      | 3       | 4   | 5   | 6   | 7     | 8     | 9
+ # ------- | ------- | ------- | ------ | ------- | --- | --- | --- | ----- | ----- | ----
+ # fairseq | '<s>'   | '<pad>' | '</s>' | '<unk>' | ',' | '.' | '▁' | 's'   | '▁de' | '-'
+ # spm     | '<unk>' | '<s>'   | '</s>' | ','     | '.' | '▁' | 's' | '▁de' | '-'   | '▁a'
+
+ # Mimic fairseq token-to-id alignment for the first 4 tokens
+ self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+ # The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
+ self.fairseq_offset = 1
+
+ self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + self.fairseq_offset
+ self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
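+ # Illustrative mapping: the spm piece "," has spm id 3, so its fairseq-aligned id is 3 + self.fairseq_offset == 4;
+ # ids 0-3 stay reserved for "<s>", "<pad>", "</s>" and "<unk>" as set above.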
+
+ super().__init__(
+ bos_token=bos_token,
+ eos_token=eos_token,
+ unk_token=unk_token,
+ sep_token=sep_token,
+ cls_token=cls_token,
+ pad_token=pad_token,
+ mask_token=mask_token,
+ sp_model_kwargs=self.sp_model_kwargs,
+ **kwargs,
+ )
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state["sp_model"] = None
+ state["sp_model_proto"] = self.sp_model.serialized_model_proto()
+ return state
+
+ def __setstate__(self, d):
+ self.__dict__ = d
+
+ # for backward compatibility
+ if not hasattr(self, "sp_model_kwargs"):
+ self.sp_model_kwargs = {}
+
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
+ self.sp_model.LoadFromSerializedProto(self.sp_model_proto)
+
+ def build_inputs_with_special_tokens(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
+ adding special tokens. An XLM-RoBERTa sequence has the following format:
+
+ - single sequence: `<s> X </s>`
+ - pair of sequences: `<s> A </s></s> B </s>`
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs to which the special tokens will be added.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
+ """
+
+ if token_ids_1 is None:
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
+ cls = [self.cls_token_id]
+ sep = [self.sep_token_id]
+ return cls + token_ids_0 + sep + sep + token_ids_1 + sep
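+ # Illustrative example: token_ids_0=[10, 11], token_ids_1=[12, 13] yields
+ # [cls_id, 10, 11, sep_id, sep_id, 12, 13, sep_id], i.e. "<s> A </s></s> B </s>" in token form.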
+
+ def get_special_tokens_mask(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
+ ) -> List[int]:
+ """
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
+ special tokens using the tokenizer `prepare_for_model` method.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
+ Whether or not the token list is already formatted with special tokens for the model.
+
+ Returns:
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
+ """
+
+ if already_has_special_tokens:
+ return super().get_special_tokens_mask(
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
+ )
+
+ if token_ids_1 is None:
+ return [1] + ([0] * len(token_ids_0)) + [1]
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
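+ # Illustrative example: token_ids_0=[10, 11], token_ids_1=[12] (without special tokens) yields
+ # [1, 0, 0, 1, 1, 0, 1], marking the added <s>, </s></s> and final </s> positions as special.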
+
+ def create_token_type_ids_from_sequences(
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
+ ) -> List[int]:
+ """
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. XLM-RoBERTa does
+ not make use of token type ids, therefore a list of zeros is returned.
+
+ Args:
+ token_ids_0 (`List[int]`):
+ List of IDs.
+ token_ids_1 (`List[int]`, *optional*):
+ Optional second list of IDs for sequence pairs.
+
+ Returns:
+ `List[int]`: List of zeros.
+
+ """
+
+ sep = [self.sep_token_id]
+ cls = [self.cls_token_id]
+
+ if token_ids_1 is None:
+ return len(cls + token_ids_0 + sep) * [0]
+ return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
+
+ @property
+ def vocab_size(self):
+ return len(self.sp_model) + self.fairseq_offset + 1 # Add the <mask> token
+
+ def get_vocab(self):
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
+ vocab.update(self.added_tokens_encoder)
+ return vocab
+
+ def _tokenize(self, text: str) -> List[str]:
+ # TODO check if the t5/llama PR also applies here
+ return self.sp_model.encode(text, out_type=str)
+
+ def _convert_token_to_id(self, token):
+ """Converts a token (str) in an id using the vocab."""
+ if token in self.fairseq_tokens_to_ids:
+ return self.fairseq_tokens_to_ids[token]
+ spm_id = self.sp_model.PieceToId(token)
+
+ # Need to return unknown token if the SP model returned 0
+ return spm_id + self.fairseq_offset if spm_id else self.unk_token_id
+
+ def _convert_id_to_token(self, index):
+ """Converts an index (integer) in a token (str) using the vocab."""
+ if index in self.fairseq_ids_to_tokens:
+ return self.fairseq_ids_to_tokens[index]
+ return self.sp_model.IdToPiece(index - self.fairseq_offset)
+
+ def convert_tokens_to_string(self, tokens):
+ """Converts a sequence of tokens (strings for sub-words) in a single string."""
+ out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
+ return out_string
+
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+ if not os.path.isdir(save_directory):
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
+ return
+ out_vocab_file = os.path.join(
+ save_directory,
+ (filename_prefix + "-" if filename_prefix else "") + self.resource_files_names["vocab_file"],
+ )
+
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
+ copyfile(self.vocab_file, out_vocab_file)
+ elif not os.path.isfile(self.vocab_file):
+ with open(out_vocab_file, "wb") as fi:
+ content_spiece_model = self.sp_model.serialized_model_proto()
+ fi.write(content_spiece_model)
+
+ return (out_vocab_file,)
diff --git a/tests/transformers/xlm_roberta/__init__.py b/tests/transformers/xlm_roberta/__init__.py
new file mode 100644
index 000000000000..a9cc79cc9d7f
--- /dev/null
+++ b/tests/transformers/xlm_roberta/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/transformers/xlm_roberta/test_modeling.py b/tests/transformers/xlm_roberta/test_modeling.py
new file mode 100644
index 000000000000..e5bded622082
--- /dev/null
+++ b/tests/transformers/xlm_roberta/test_modeling.py
@@ -0,0 +1,453 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import tempfile
+import unittest
+
+import numpy as np
+import paddle
+from parameterized import parameterized_class
+
+from paddlenlp.transformers import (
+ XLMRobertaConfig,
+ XLMRobertaForCausalLM,
+ XLMRobertaForMaskedLM,
+ XLMRobertaForMultipleChoice,
+ XLMRobertaForQuestionAnswering,
+ XLMRobertaForSequenceClassification,
+ XLMRobertaForTokenClassification,
+ XLMRobertaModel,
+ XLMRobertaPretrainedModel,
+)
+
+from ...testing_utils import require_package, slow
+from ..test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
+
+
+class XLMRobertaModelTester:
+ def __init__(self, parent: XLMRobertaModelTest):
+ self.parent: XLMRobertaModelTest = parent
+ self.batch_size = 13
+ self.seq_length = 7
+ self.is_training = True
+ self.use_input_mask = True
+ self.use_token_type_ids = True
+ self.use_labels = True
+ self.vocab_size = 99
+ self.hidden_size = 32
+ self.num_hidden_layers = 5
+ self.num_attention_heads = 4
+ self.intermediate_size = 37
+ self.hidden_act = "gelu"
+ self.hidden_dropout_prob = 0.1
+ self.attention_probs_dropout_prob = 0.1
+ self.max_position_embeddings = 512
+ self.type_vocab_size = 16
+ self.type_sequence_label_size = 2
+ self.initializer_range = 0.02
+ self.layer_norm_eps = 1e-12
+ self.pad_token_id = 1
+ self.bos_token_id = 0
+ self.eos_token_id = 2
+ self.position_embedding_type = "absolute"
+ self.use_cache = True
+ self.classifier_dropout = None
+ self.num_labels = 2
+ self.num_choices = 4
+ self.dropout = 0.56
+ self.scope = None
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ token_type_ids = None
+ if self.use_token_type_ids:
+ token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
+
+ sequence_labels = None
+ token_labels = None
+ choice_labels = None
+ if self.parent.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+ config = self.get_config()
+ return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+ def get_config(self):
+ return XLMRobertaConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ intermediate_size=self.intermediate_size,
+ hidden_act=self.hidden_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ initializer_range=self.initializer_range,
+ layer_norm_eps=self.layer_norm_eps,
+ pad_token_id=self.pad_token_id,
+ bos_token_id=self.bos_token_id,
+ eos_token_id=self.eos_token_id,
+ position_embedding_type=self.position_embedding_type,
+ use_cache=self.use_cache,
+ classifier_dropout=self.classifier_dropout,
+ num_labels=self.num_labels,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = self.prepare_config_and_inputs()
+
+ return (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ )
+
+ def create_and_check_model(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+
+ model = XLMRobertaModel(config)
+ model.eval()
+
+ result = model(
+ input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, return_dict=self.parent.return_dict
+ )
+ result = model(input_ids, token_type_ids=token_type_ids, return_dict=self.parent.return_dict)
+ result = model(input_ids, return_dict=self.parent.return_dict)
+
+ self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.hidden_size])
+ self.parent.assertEqual(result[1].shape, [self.batch_size, self.hidden_size])
+
+ def create_and_check_for_causal_lm(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+ model = XLMRobertaForCausalLM(config)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=token_labels,
+ return_dict=self.parent.return_dict,
+ )
+ if token_labels is not None:
+ result = result[1:]
+ elif paddle.is_tensor(result):
+ result = [result]
+
+ self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size])
+
+ def create_and_check_for_masked_lm(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+ model = XLMRobertaForMaskedLM(config)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=token_labels,
+ return_dict=self.parent.return_dict,
+ )
+
+ if token_labels is not None:
+ result = result[1:]
+ elif paddle.is_tensor(result):
+ result = [result]
+
+ self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size])
+
+ def create_and_check_for_token_classification(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+
+ model = XLMRobertaForTokenClassification(config)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ return_dict=self.parent.return_dict,
+ labels=token_labels,
+ )
+
+ if token_labels is not None:
+ result = result[1:]
+ elif paddle.is_tensor(result):
+ result = [result]
+
+ self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.num_labels])
+
+ def create_and_check_for_sequence_classification(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+ model = XLMRobertaForSequenceClassification(config)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ labels=sequence_labels,
+ return_dict=self.parent.return_dict,
+ )
+
+ if token_labels is not None:
+ result = result[1:]
+ elif paddle.is_tensor(result):
+ result = [result]
+
+ self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_labels])
+
+ def create_and_check_for_multiple_choice(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+
+ model = XLMRobertaForMultipleChoice(config)
+ model.eval()
+ multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand([-1, self.num_choices, -1])
+ multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand([-1, self.num_choices, -1])
+ multiple_choice_input_mask = input_mask.unsqueeze(1).expand([-1, self.num_choices, -1])
+ result = model(
+ multiple_choice_inputs_ids,
+ attention_mask=multiple_choice_input_mask,
+ token_type_ids=multiple_choice_token_type_ids,
+ return_dict=self.parent.return_dict,
+ labels=choice_labels,
+ )
+
+ if token_labels is not None:
+ result = result[1:]
+ elif paddle.is_tensor(result):
+ result = [result]
+
+ self.parent.assertEqual(result[0].shape, [self.batch_size, self.num_choices])
+
+ def create_and_check_for_question_answering(
+ self,
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+
+ model = XLMRobertaForQuestionAnswering(config)
+ model.eval()
+ result = model(
+ input_ids,
+ attention_mask=input_mask,
+ token_type_ids=token_type_ids,
+ return_dict=self.parent.return_dict,
+ start_positions=sequence_labels,
+ end_positions=sequence_labels,
+ )
+
+ if sequence_labels is not None:
+ start_logits, end_logits = result[1], result[2]
+ else:
+ start_logits, end_logits = result[0], result[1]
+
+ self.parent.assertEqual(start_logits.shape, [self.batch_size, self.seq_length])
+ self.parent.assertEqual(end_logits.shape, [self.batch_size, self.seq_length])
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ token_type_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@parameterized_class(
+ ("return_dict", "use_labels"),
+ [
+ [False, False],
+ [False, True],
+ [True, False],
+ [True, True],
+ ],
+)
+class XLMRobertaModelTest(ModelTesterMixin, unittest.TestCase):
+ base_model_class = XLMRobertaModel
+ use_test_inputs_embeds: bool = False
+ return_dict: bool = False
+ use_labels: bool = False
+ test_tie_weights = True
+
+ all_model_classes = (
+ XLMRobertaForQuestionAnswering,
+ XLMRobertaForTokenClassification,
+ XLMRobertaForMultipleChoice,
+ XLMRobertaForSequenceClassification,
+ XLMRobertaForMaskedLM,
+ XLMRobertaForCausalLM,
+ )
+ all_generative_model_classes = (XLMRobertaForCausalLM,)
+
+ def setUp(self):
+ self.model_tester = XLMRobertaModelTester(self)
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_causal_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+
+ def test_for_masked_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
+
+ def test_for_sequence_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+ def test_for_token_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
+
+ def test_for_multiple_choice(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
+
+ def test_for_question_answering(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
+
+ @slow
+ def test_model_from_pretrained(self):
+ for model_name in list(XLMRobertaPretrainedModel.pretrained_init_configuration.keys())[:1]:
+ model = XLMRobertaModel.from_pretrained(model_name)
+ self.assertIsNotNone(model)
+
+
+class XLMRobertaCompatibilityTest(unittest.TestCase):
+ test_model_id = "hf-internal-testing/tiny-random-onnx-xlm-roberta"
+
+ @classmethod
+ @require_package("transformers", "torch")
+ def setUpClass(cls) -> None:
+ from transformers import XLMRobertaModel
+
+ cls.torch_model_path = tempfile.TemporaryDirectory().name
+ model = XLMRobertaModel.from_pretrained(cls.test_model_id)
+ model.save_pretrained(cls.torch_model_path)
+
+ @require_package("transformers", "torch")
+ def test_xlmroberta_model_converter(self):
+ with tempfile.TemporaryDirectory() as tempdir:
+
+ # 1. create common input
+ input_ids = np.random.randint(100, 200, [1, 20])
+
+ # 2. forward the paddle model
+ from paddlenlp.transformers import XLMRobertaModel
+
+ paddle_model = XLMRobertaModel.from_pretrained(self.test_model_id, from_hf_hub=False, cache_dir=tempdir)
+ paddle_model.eval()
+ paddle_logit = paddle_model(paddle.to_tensor(input_ids))[0]
+
+ # 3. forward the torch model
+ import torch
+ from transformers import XLMRobertaModel
+
+ torch_model = XLMRobertaModel.from_pretrained(self.torch_model_path)
+ torch_model.eval()
+ torch_logit = torch_model(torch.tensor(input_ids), return_dict=False)[0]
+
+ self.assertTrue(
+ np.allclose(
+ paddle_logit.detach().cpu().reshape([-1])[:9].numpy(),
+ torch_logit.detach().cpu().reshape([-1])[:9].numpy(),
+ rtol=1e-4,
+ )
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/transformers/xlm_roberta/test_tokenizer.py b/tests/transformers/xlm_roberta/test_tokenizer.py
new file mode 100644
index 000000000000..a5dad1977828
--- /dev/null
+++ b/tests/transformers/xlm_roberta/test_tokenizer.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from paddlenlp.transformers import XLMRobertaTokenizer
+
+from ..test_tokenizer_common import TokenizerTesterMixin
+
+# VOCAB_FILES_NAMES = XLMRobertaTokenizer.resource_files_names
+
+
+class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
+ test_offsets = False
+ tokenizer_class = XLMRobertaTokenizer
+
+ # Set up method called before each test
+ def setUp(self):
+ super().setUp()
+ self.vocab_file = "BAAI/bge-m3"
+ self.special_tokens_map = {"unk_token": ""}
+
+ # Method to get a tokenizer instance with specified keyword arguments
+ def get_tokenizer(self, **kwargs):
+ kwargs.update(self.special_tokens_map)
+ return XLMRobertaTokenizer.from_pretrained(self.vocab_file, **kwargs)
+
+ # Test method to check tokenization
+ def test_tokenization(self):
+ tokenizer = self.get_tokenizer()
+ text = "Hello, how are you?"
+ tokens = tokenizer.tokenize(text)
+ self.assertIsInstance(tokens, list)
+ self.assertGreater(len(tokens), 0)
+
+ # Test method to check conversion of token to ID
+ def test_token_to_id(self):
+ tokenizer = self.get_tokenizer()
+ token = "Hello"
+ token_id = tokenizer.convert_tokens_to_ids(token)
+ self.assertIsInstance(token_id, int)
+
+ # Test method to check conversion of ID to token
+ def test_id_to_token(self):
+ tokenizer = self.get_tokenizer()
+ token_id = tokenizer.convert_tokens_to_ids("How")
+ token = tokenizer.convert_ids_to_tokens(token_id)
+ self.assertEqual(token, "How")
+
+ # Test method to check special tokens
+ def test_special_tokens(self):
+ tokenizer = self.get_tokenizer(
+ vocab_file=self.vocab_file, cls_token="", sep_token="", pad_token=""
+ )
+ self.assertEqual(tokenizer.cls_token, "")
+ self.assertEqual(tokenizer.sep_token, "")
+ self.assertEqual(tokenizer.pad_token, "")
+
+ # Test method to check building inputs with special tokens
+ def test_build_inputs_with_special_tokens(self):
+ tokenizer = self.get_tokenizer()
+ token_ids_0 = tokenizer.convert_tokens_to_ids(["Hello", "world"])
+ token_ids_1 = tokenizer.convert_tokens_to_ids(["How", "are", "you"])
+
+ input_ids = tokenizer.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
+ self.assertEqual(input_ids[0], tokenizer.cls_token_id)
+ self.assertEqual(input_ids[-1], tokenizer.sep_token_id)
+
+
+if __name__ == "__main__":
+ unittest.main()