diff --git a/docs/model_zoo.md b/docs/model_zoo.md
index 641a21018be2b..e98e5fac1d330 100644
--- a/docs/model_zoo.md
+++ b/docs/model_zoo.md
@@ -32,6 +32,7 @@ PaddleNLP提供了丰富的模型结构,包含经典的RNN类模型结构,
 | [ERNIESage](../examples/text_graph/erniesage)| ERNIESage(ERNIE SAmple aggreGatE) 通过Graph(图)来构建自身节点和邻居节点的连接关系,将自身节点和邻居节点的关系构建成一个关联样本输入到ERNIE中,ERNIE作为聚合函数 (Aggregators) 来表征自身节点和邻居节点的语义关系,最终强化图中节点的语义表示。|
 | [GPT-2](../examples/language_model/gpt2) |[Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf) |
 | [ELECTRA](../examples/language_model/electra/) | [ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators](https://arxiv.org/abs/2003.10555) |
+| [XLNet](../examples/language_model/xlnet/) | [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) |
 | [RoBERTa](../examples/text_classification/pretrained_models) | [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) |
 | [PLATO-2](../examples/dialogue/plato-2) | 百度自研领先的开放域对话预训练模型 [PLATO-2: Towards Building an Open-Domain Chatbot via Curriculum Learning](https://arxiv.org/abs/2006.16779) |
 | [SentenceBERT](../examples/text_matching/sentence_transformers)| [Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks](https://arxiv.org/abs/1908.10084) |
diff --git a/docs/transformers.md b/docs/transformers.md
index 2e40924fb9127..c3e16f64c77d9 100644
--- a/docs/transformers.md
+++ b/docs/transformers.md
@@ -5,7 +5,7 @@ ## Transformer预训练模型汇总
-下表汇总了目前PaddleNLP支持的各类预训练模型。用户可以使用PaddleNLP提供的模型,完成问答、文本分类、序列标注、文本生成等任务。同时我们提供了29种预训练的参数权重供用户使用,其中包含了12种中文语言模型的预训练权重。
+下表汇总了目前PaddleNLP支持的各类预训练模型。用户可以使用PaddleNLP提供的模型,完成问答、文本分类、序列标注、文本生成等任务。同时我们提供了32种预训练的参数权重供用户使用,其中包含了15种中文语言模型的预训练权重。
 | Model | Tokenizer | Supported Task | Pretrained Weight|
 |---|---|---|---|
@@ -15,10 +15,10 @@
 |[GPT-2](https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)| GPT2Tokenizer<br>GPT2ChineseTokenizer| GPT2ForGreedyGeneration| `gpt2-base-cn`<br>`gpt2-medium-en`|
 |[RoBERTa](https://arxiv.org/abs/1907.11692)|RobertaTokenizer| RobertaModel<br>RobertaForQuestionAnswering<br>RobertaForSequenceClassification<br>RobertaForTokenClassification| `roberta-wwm-ext`<br>`roberta-wwm-ext-large`<br>`rbt3`<br>`rbtl3`|
 |[ELECTRA](https://arxiv.org/abs/2003.10555) | ElectraTokenizer| ElectraModel<br>ElectraForSequenceClassification<br>ElectraForTokenClassification<br>|`electra-small`<br>`electra-base`<br>`electra-large`<br>`chinese-electra-small`<br>`chinese-electra-base`<br>|
-|[XLNet](https://arxiv.org/abs/1906.08237)| XLNetTokenizer| XLNetModel<br>XLNetForSequenceClassification<br>XLNetForTokenClassification |`xlnet-base-cased`<br>`xlnet-large-cased`|
+|[XLNet](https://arxiv.org/abs/1906.08237)| XLNetTokenizer| XLNetModel<br>XLNetForSequenceClassification<br>XLNetForTokenClassification |`xlnet-base-cased`<br>`xlnet-large-cased`<br>`chinese-xlnet-base`<br>`chinese-xlnet-mid`<br>`chinese-xlnet-large`|
 |[Transformer](https://arxiv.org/abs/1706.03762) |- | TransformerModel | - |
-**NOTE**:其中中文的预训练模型有`bert-base-chinese, bert-wwm-chinese, bert-wwm-ext-chinese, ernie-1.0, ernie-tiny, gpt2-base-cn, roberta-wwm-ext, roberta-wwm-ext-large, rbt3, rbtl3, chinese-electra-base, chinese-electra-small`。
+**NOTE**:其中中文的预训练模型有`bert-base-chinese, bert-wwm-chinese, bert-wwm-ext-chinese, ernie-1.0, ernie-tiny, gpt2-base-cn, roberta-wwm-ext, roberta-wwm-ext-large, rbt3, rbtl3, chinese-electra-base, chinese-electra-small, chinese-xlnet-base, chinese-xlnet-mid, chinese-xlnet-large`。
 ## 预训练模型使用方法
@@ -73,7 +73,7 @@ for input_ids, token_type_ids, labels in train_dataloader:
 用户可以切换表格中的不同模型,来处理相同类型的任务。如对于[预训练模型使用方法](#预训练模型使用方法)中的文本分类任务,用户可以将`BertForSequenceClassification`换成`ErnieForSequenceClassification`, 来寻找更适合的预训练模型。
 ## 参考资料:
-- 部分中文预训练模型来自:https://github.com/ymcui/Chinese-BERT-wwm
+- 部分中文预训练模型来自:https://github.com/ymcui/Chinese-BERT-wwm, https://github.com/ymcui/Chinese-XLNet, https://huggingface.co/clue/xlnet_chinese_large
 - Sun, Yu, et al. "Ernie: Enhanced representation through knowledge integration." arXiv preprint arXiv:1904.09223 (2019).
 - Devlin, Jacob, et al. "Bert: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).
 - Cui, Yiming, et al. "Pre-training with whole word masking for chinese bert." arXiv preprint arXiv:1906.08101 (2019).
diff --git a/examples/language_model/xlnet/README.md b/examples/language_model/xlnet/README.md
index e40d0d8f0fac1..8a88c5c5a1263 100644
--- a/examples/language_model/xlnet/README.md
+++ b/examples/language_model/xlnet/README.md
@@ -19,7 +19,7 @@
 ```shell
 pip install paddlenlp\>=2.0.0rc
 ```
-
+
 * SentencePiece 安装
 ```shell
 pip install sentencepiece
 ```
@@ -63,13 +63,14 @@ python -m paddle.distributed.launch ./run_glue.py \
 基于`xlnet-base-cased`在GLUE各评测任务上Fine-tuning后,在验证集上有如下结果:
-| Task  | Metric                       | Result             |
+| Task  | Metric                       | Result             |
 |:-----:|:----------------------------:|:------------------:|
-| SST-2 | Accuracy                     |       94.266       |
-| QNLI  | Accuracy                     |       91.708       |
-| CoLA  | Mattehew's corr              |       50.264       |
-| MRPC  | F1/Accuracy                  |   91.071/87.745    |
-| STS-B | Person/Spearman corr         |   86.243/85.973    |
-| QQP   | Accuracy/F1                  |   90.838/87.644    |
-| MNLI  | Matched acc/MisMatched acc   |   87.468/86.859    |
-| RTE   | Accuracy                     |       70.036       |
\ No newline at end of file
+| SST-2 | Accuracy                     |       94.266       |
+| QNLI  | Accuracy                     |       91.708       |
+| CoLA  | Matthew's corr               |       50.264       |
+| MRPC  | F1/Accuracy                  |   91.071/87.745    |
+| STS-B | Pearson/Spearman corr        |   86.243/85.973    |
+| QQP   | Accuracy/F1                  |   90.838/87.644    |
+| MNLI  | Matched acc/Mismatched acc   |   87.468/86.859    |
+| RTE   | Accuracy                     |       70.036       |
+| WNLI  | Accuracy                     |       56.338       |
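The documentation changes above register three new Chinese weights (`chinese-xlnet-base`, `chinese-xlnet-mid`, `chinese-xlnet-large`). Below is a minimal sketch of how they would be loaded through PaddleNLP's usual `from_pretrained` interface; it is not part of the patch, the example sentence and `num_classes=2` are made up for illustration, and the output is indexed according to the tuple format that `XLNetForSequenceClassification.forward` returns in this patch when `return_dict=False`.

```python
import paddle
from paddlenlp.transformers import XLNetForSequenceClassification, XLNetTokenizer

# Load the SentencePiece tokenizer and a classification head on top of the
# Chinese XLNet backbone registered by this patch.
tokenizer = XLNetTokenizer.from_pretrained("chinese-xlnet-base")
model = XLNetForSequenceClassification.from_pretrained(
    "chinese-xlnet-base", num_classes=2)
model.eval()

# Encode a toy sentence; the tokenizer returns input_ids and token_type_ids.
encoded = tokenizer("欢迎使用PaddleNLP")
input_ids = paddle.to_tensor([encoded["input_ids"]])
token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])

# With return_dict=False (the default in this modeling.py), the forward pass
# returns a tuple whose first element is the [batch_size, num_classes] logits.
with paddle.no_grad():
    logits = model(input_ids, token_type_ids=token_type_ids)[0]
print(logits.shape)
```

The same pattern should apply to `XLNetForTokenClassification`; only the head and the shape of the logits change.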
diff --git a/examples/language_model/xlnet/run_glue.py b/examples/language_model/xlnet/run_glue.py
index 4346789e021bd..93a81fcd17ca2 100644
--- a/examples/language_model/xlnet/run_glue.py
+++ b/examples/language_model/xlnet/run_glue.py
@@ -16,6 +16,7 @@
 import os
 import random
 import time
+from math import ceil
 from functools import partial
 import numpy as np
@@ -41,6 +42,7 @@
     "mnli": Accuracy,
     "qnli": Accuracy,
     "rte": Accuracy,
+    "wnli": Accuracy,
 }
@@ -151,6 +153,7 @@ def do_train(args):
         paddle.distributed.init_parallel_env()
     set_seed(args)
+    global final_res
     args.task_name = args.task_name.lower()
     metric_class = METRIC_CLASSES[args.task_name]
@@ -223,8 +226,12 @@
     if paddle.distributed.get_world_size() > 1:
         model = paddle.DataParallel(model)
-    num_training_steps = args.max_steps if args.max_steps > 0 else (
-        len(train_data_loader) * args.num_train_epochs)
+    if args.max_steps > 0:
+        num_training_steps = args.max_steps
+        num_train_epochs = ceil(num_training_steps / len(train_data_loader))
+    else:
+        num_training_steps = len(train_data_loader) * args.num_train_epochs
+        num_train_epochs = args.num_train_epochs
     warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion
     lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps,
@@ -255,7 +262,7 @@ def do_train(args):
     global_step = 0
     tic_train = time.time()
     model.train()
-    for epoch in range(args.num_train_epochs):
+    for epoch in range(num_train_epochs):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
             input_ids, token_type_ids, attention_mask, labels = batch
@@ -277,9 +284,14 @@ def do_train(args):
             if global_step % args.save_steps == 0 or global_step == num_training_steps:
                 tic_eval = time.time()
                 if args.task_name == "mnli":
+                    print("matched ", end="")
                     evaluate(model, loss_fct, metric, dev_data_loader_matched)
+                    final_res1 = "matched " + final_res
+                    print("mismatched ", end="")
                     evaluate(model, loss_fct, metric, dev_data_loader_mismatched)
+                    final_res2 = "mismatched " + final_res
+                    final_res = final_res1 + "\r\n" + final_res2
                     print("eval done total : %s s" % (time.time() - tic_eval))
                 else:
                     evaluate(model, loss_fct, metric, dev_data_loader)
@@ -297,6 +309,7 @@ def do_train(args):
                     tokenizer.save_pretrained(output_dir)
                 if global_step == num_training_steps:
                     print(final_res)
+                    exit(0)
             tic_train += time.time() - tic_eval
diff --git a/paddlenlp/transformers/xlnet/modeling.py b/paddlenlp/transformers/xlnet/modeling.py
index 924b7d76304ee..d0066ad22b690 100644
--- a/paddlenlp/transformers/xlnet/modeling.py
+++ b/paddlenlp/transformers/xlnet/modeling.py
@@ -21,7 +21,6 @@
 from paddle.nn import Layer
 from ..
import PretrainedModel, register_base_model - __all__ = [ "XLNetModel", "XLNetPretrainedModel", @@ -92,31 +91,33 @@ def einsum4x4(equation, x, y): class XLNetRelativeAttention(Layer): - def __init__( - self, - n_head, - d_head, - d_model, - layer_norm_eps, - dropout - ): + def __init__(self, n_head, d_head, d_model, layer_norm_eps, dropout): super(XLNetRelativeAttention, self).__init__() self.n_head = n_head self.d_head = d_head self.d_model = d_model - self.scale = 1 / (d_head ** 0.5) - - self.q = self.create_parameter([self.d_model, self.n_head * self.d_head]) - self.k = self.create_parameter([self.d_model, self.n_head * self.d_head]) - self.v = self.create_parameter([self.d_model, self.n_head * self.d_head]) - self.o = self.create_parameter([self.d_model, self.n_head * self.d_head]) - self.r = self.create_parameter([self.d_model, self.n_head * self.d_head]) - - self.r_r_bias = self.create_parameter([self.n_head, self.d_head], is_bias=True) - self.r_s_bias = self.create_parameter([self.n_head, self.d_head], is_bias=True) - self.r_w_bias = self.create_parameter([self.n_head, self.d_head], is_bias=True) - self.seg_embed = self.create_parameter([2, self.n_head, self.d_head], is_bias=False) + self.scale = 1 / (d_head**0.5) + + self.q = self.create_parameter( + [self.d_model, self.n_head * self.d_head]) + self.k = self.create_parameter( + [self.d_model, self.n_head * self.d_head]) + self.v = self.create_parameter( + [self.d_model, self.n_head * self.d_head]) + self.o = self.create_parameter( + [self.d_model, self.n_head * self.d_head]) + self.r = self.create_parameter( + [self.d_model, self.n_head * self.d_head]) + + self.r_r_bias = self.create_parameter( + [self.n_head, self.d_head], is_bias=True) + self.r_s_bias = self.create_parameter( + [self.n_head, self.d_head], is_bias=True) + self.r_w_bias = self.create_parameter( + [self.n_head, self.d_head], is_bias=True) + self.seg_embed = self.create_parameter( + [2, self.n_head, self.d_head], is_bias=False) self.layer_norm = nn.LayerNorm(d_model, epsilon=layer_norm_eps) self.dropout = nn.Dropout(dropout) @@ -132,7 +133,9 @@ def rel_shift_bnij(x, klen=-1): x = paddle.reshape(x, [x_size[0], x_size[1], x_size[3], x_size[2]]) x = x[:, :, 1:, :] x = paddle.reshape(x, [x_size[0], x_size[1], x_size[2], x_size[3] - 1]) - x = paddle.index_select(x, index=paddle.arange(klen, dtype='int64'), axis=3) + x = paddle.index_select( + x, index=paddle.arange( + klen, dtype='int64'), axis=3) return x def rel_attn_core( @@ -144,8 +147,7 @@ def rel_attn_core( seg_mat=None, attn_mask=None, head_mask=None, - output_attentions=False, - ): + output_attentions=False, ): """Core relative positional attention operations.""" # Content based attention score (refer to the Transformer-XL paper) @@ -195,7 +197,9 @@ def post_attention(self, h, attn_vec, residual=True): # Compute einsum4x4("ibnd,hnd->ibh", attn_vec, self.o) shape = attn_vec.shape attn_vec = attn_vec.reshape([shape[0] * shape[1], -1]) - attn_out = paddle.matmul(attn_vec, self.o, transpose_y=True).reshape([shape[0], shape[1], -1]) + attn_out = paddle.matmul( + attn_vec, self.o, + transpose_y=True).reshape([shape[0], shape[1], -1]) attn_out = self.dropout(attn_out) if residual: @@ -215,8 +219,7 @@ def forward( mems=None, target_mapping=None, head_mask=None, - output_attentions=False, - ): + output_attentions=False, ): if g is not None: # Two-stream attention with relative positional encoding. 
# Content based attention score @@ -228,23 +231,31 @@ def forward( # Content-based key head # Compute k_head_h = einsum4x4("ibh,h(n*d)->ibnd", cat, self.k) k_head_h = paddle.matmul(cat, self.k) - k_head_h = paddle.reshape(k_head_h, shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) + k_head_h = paddle.reshape( + k_head_h, + shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) # Content-based value head # Compute v_head_h = einsum4x4("ibh,h(n*d)->ibnd", cat, self.v) v_head_h = paddle.matmul(cat, self.v) - v_head_h = paddle.reshape(v_head_h, shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) + v_head_h = paddle.reshape( + v_head_h, + shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) # Position-based key head # Compute k_head_r = einsum4x4("ibh,h(n*d)->ibnd", r, self.r) k_head_r = paddle.matmul(r, self.r) - k_head_r = paddle.reshape(k_head_r, shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) + k_head_r = paddle.reshape( + k_head_r, + shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) # H-stream # Content-stream query head # Compute q_head_h = einsum4x4("ibh,h(n*d)->ibnd", h, self.q) q_head_h = paddle.matmul(h, self.q) # shape - q_head_h = paddle.reshape(q_head_h, shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) + q_head_h = paddle.reshape( + q_head_h, + shape=[cat.shape[0], cat.shape[1], self.n_head, self.d_head]) # Core attention ops attn_vec_h = self.rel_attn_core( @@ -255,8 +266,7 @@ def forward( seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask, - output_attentions=output_attentions, - ) + output_attentions=output_attentions, ) if output_attentions: attn_vec_h, attn_prob_h = attn_vec_h @@ -268,14 +278,16 @@ def forward( # Query-stream query head # Compute q_head_g = einsum4x4("ibh,hnd->ibnd", g, self.q) shape = g.shape - q_head_g = paddle.matmul(g, self.q).reshape([shape[0], shape[1], self.n_head, self.d_head]) + q_head_g = paddle.matmul(g, self.q).reshape( + [shape[0], shape[1], self.n_head, self.d_head]) # Core attention ops if target_mapping is not None: # Compute q_head_g = einsum4x4("mbnd,mlb->lbnd", q_head_g, target_mapping) q_head_g = q_head_g.transpose([1, 2, 3, 0]) target_mapping = target_mapping.transpose([2, 0, 1]) - q_head_g = paddle.matmul(q_head_g, target_mapping).transpose([3, 0, 1, 2]) + q_head_g = paddle.matmul(q_head_g, target_mapping).transpose( + [3, 0, 1, 2]) attn_vec_g = self.rel_attn_core( q_head_g, @@ -285,8 +297,7 @@ def forward( seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask, - output_attentions=output_attentions, - ) + output_attentions=output_attentions, ) if output_attentions: attn_vec_g, attn_prob_g = attn_vec_g @@ -294,7 +305,8 @@ def forward( # Compute attn_vec_g = einsum4x4("lbnd,mlb->mbnd", attn_vec_g, target_mapping) attn_vec_g = attn_vec_g.transpose([1, 2, 3, 0]) target_mapping = target_mapping.transpose([2, 1, 0]) - attn_vec_g = paddle.matmul(attn_vec_g, target_mapping).transpose([3, 0, 1, 2]) + attn_vec_g = paddle.matmul( + attn_vec_g, target_mapping).transpose([3, 0, 1, 2]) else: attn_vec_g = self.rel_attn_core( q_head_g, @@ -304,8 +316,7 @@ def forward( seg_mat=seg_mat, attn_mask=attn_mask_g, head_mask=head_mask, - output_attentions=output_attentions, - ) + output_attentions=output_attentions, ) if output_attentions: attn_vec_g, attn_prob_g = attn_vec_g @@ -326,20 +337,28 @@ def forward( # Content heads # Compute q_head_h = einsum4x4("ibh,hnd->ibnd", h, self.q) q_head_h = paddle.matmul(h, self.q) - q_head_h = paddle.reshape(q_head_h, 
shape=[h.shape[0], h.shape[1], self.n_head, self.d_head]) + q_head_h = paddle.reshape( + q_head_h, + shape=[h.shape[0], h.shape[1], self.n_head, self.d_head]) # Compute k_head_h = einsum4x4("ibh,hnd->ibnd", cat, self.k) k_head_h = paddle.matmul(cat, self.k) - k_head_h = paddle.reshape(k_head_h, shape=[h.shape[0], h.shape[1], self.n_head, self.d_head]) + k_head_h = paddle.reshape( + k_head_h, + shape=[h.shape[0], h.shape[1], self.n_head, self.d_head]) # Compute v_head_h = einsum4x4("ibh,hnd->ibnd", cat, self.v) v_head_h = paddle.matmul(cat, self.v) - v_head_h = paddle.reshape(v_head_h, shape=[h.shape[0], h.shape[1], self.n_head, self.d_head]) + v_head_h = paddle.reshape( + v_head_h, + shape=[h.shape[0], h.shape[1], self.n_head, self.d_head]) # Position-based key head # Compute k_head_r = einsum4x4("ibh,hnd->ibnd", r, self.r) k_head_r = paddle.matmul(r, self.r) - k_head_r = paddle.reshape(k_head_r, shape=[k_head_r.shape[0], -1, self.n_head, self.d_head]) + k_head_r = paddle.reshape( + k_head_r, + shape=[k_head_r.shape[0], -1, self.n_head, self.d_head]) # Core attention ops attn_vec = self.rel_attn_core( @@ -350,8 +369,7 @@ def forward( seg_mat=seg_mat, attn_mask=attn_mask_h, head_mask=head_mask, - output_attentions=output_attentions, - ) + output_attentions=output_attentions, ) if output_attentions: attn_vec, attn_prob = attn_vec @@ -363,7 +381,7 @@ def forward( outputs = (output_h, output_g) if output_attentions: - outputs = outputs + (attn_prob,) + outputs = outputs + (attn_prob, ) return outputs @@ -374,8 +392,7 @@ def __init__( d_inner, layer_norm_eps, dropout, - ff_activation, - ): + ff_activation, ): super(XLNetFeedForward, self).__init__() self.layer_norm = nn.LayerNorm(d_model, epsilon=layer_norm_eps) @@ -407,12 +424,13 @@ def __init__( layer_norm_eps, dropout, d_inner, - ff_activation, - ): + ff_activation, ): super(XLNetLayer, self).__init__() - self.rel_attn = XLNetRelativeAttention(n_head, d_head, d_model, layer_norm_eps, dropout) - self.ff = XLNetFeedForward(d_model, d_inner, layer_norm_eps, dropout, ff_activation) + self.rel_attn = XLNetRelativeAttention(n_head, d_head, d_model, + layer_norm_eps, dropout) + self.ff = XLNetFeedForward(d_model, d_inner, layer_norm_eps, dropout, + ff_activation) self.seq_len_dim = 1 def forward( @@ -426,8 +444,7 @@ def forward( mems=None, target_mapping=None, head_mask=None, - output_attentions=False, - ): + output_attentions=False, ): outputs = self.rel_attn( output_h, @@ -439,8 +456,7 @@ def forward( mems=mems, target_mapping=target_mapping, head_mask=head_mask, - output_attentions=output_attentions, - ) + output_attentions=output_attentions, ) output_h, output_g = outputs[:2] @@ -448,7 +464,8 @@ def forward( output_g = self.ff(output_g) output_h = self.ff(output_h) - outputs = (output_h, output_g) + outputs[2:] # Add again attentions if they are there + outputs = (output_h, output_g + ) + outputs[2:] # Add again attentions if they are there return outputs @@ -481,7 +498,6 @@ class XLNetPretrainedModel(PretrainedModel): "same_length": False, "vocab_size": 32000 }, - "xlnet-large-cased": { "attn_type": "bi", "bi_data": False, @@ -500,16 +516,79 @@ class XLNetPretrainedModel(PretrainedModel): "reuse_len": None, "same_length": False, "vocab_size": 32000 - } + }, + "chinese-xlnet-base": { + "attn_type": "bi", + "bi_data": False, + "clamp_len": -1, + "d_head": 64, + "d_inner": 3072, + "d_model": 768, + "dropout": 0.1, + "classifier_dropout": 0.1, + "ff_activation": "relu", + "initializer_range": 0.02, + "layer_norm_eps": 1e-12, + "mem_len": None, + 
"n_head": 12, + "n_layer": 12, + "reuse_len": None, + "same_length": False, + "vocab_size": 32000 + }, + "chinese-xlnet-mid": { + "attn_type": "bi", + "bi_data": False, + "clamp_len": -1, + "d_head": 64, + "d_inner": 3072, + "d_model": 768, + "dropout": 0.1, + "classifier_dropout": 0.1, + "ff_activation": "relu", + "initializer_range": 0.02, + "layer_norm_eps": 1e-12, + "mem_len": None, + "n_head": 12, + "n_layer": 24, + "reuse_len": None, + "same_length": False, + "vocab_size": 32000 + }, + "chinese-xlnet-large": { + "attn_type": "bi", + "bi_data": False, + "clamp_len": -1, + "d_head": 64, + "d_inner": 4096, + "d_model": 1024, + "dropout": 0.1, + "classifier_dropout": 0.1, + "ff_activation": "relu", + "initializer_range": 0.02, + "layer_norm_eps": 1e-12, + "mem_len": None, + "n_head": 16, + "n_layer": 24, + "reuse_len": None, + "same_length": False, + "vocab_size": 32000 + }, } resource_files_names = {"model_state": "model_state.pdparams"} pretrained_resource_files_map = { "model_state": { "xlnet-base-cased": - "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-base-cased.pdparams", + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-base-cased.pdparams", "xlnet-large-cased": - "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-large-cased.pdparams", + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-large-cased.pdparams", + "chinese-xlnet-base": + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/chinese-xlnet-base.pdparams", + "chinese-xlnet-mid": + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/chinese-xlnet-mid.pdparams", + "chinese-xlnet-large": + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/chinese-xlnet-large.pdparams", } } base_model_prefix = "xlnet" @@ -536,15 +615,15 @@ def _init_weights(self, layer): layer.weight.set_value(paddle.full_like(layer.weight, 1.0)) elif isinstance(layer, XLNetRelativeAttention): for param in [ - layer.q, - layer.k, - layer.v, - layer.o, - layer.r, - layer.r_r_bias, - layer.r_s_bias, - layer.r_w_bias, - layer.seg_embed, + layer.q, + layer.k, + layer.v, + layer.o, + layer.r, + layer.r_r_bias, + layer.r_s_bias, + layer.r_w_bias, + layer.seg_embed, ]: param.set_value( paddle.tensor.normal( @@ -552,8 +631,7 @@ def _init_weights(self, layer): std=self.initializer_range if hasattr(self, "initializer_range") else self.transformer.config["initializer_range"], - shape=param.shape) - ) + shape=param.shape)) elif isinstance(layer, XLNetModel): layer.mask_emb.set_value( paddle.tensor.normal( @@ -561,8 +639,7 @@ def _init_weights(self, layer): std=self.initializer_range if hasattr(self, "initializer_range") else self.transformer.config["initializer_range"], - shape=layer.mask_emb.shape) - ) + shape=layer.mask_emb.shape)) @register_base_model @@ -585,8 +662,7 @@ def __init__( layer_norm_eps=1e-12, d_inner=3072, ff_activation="gelu", - initializer_range=0.02, - ): + initializer_range=0.02, ): super(XLNetModel, self).__init__() self.initializer_range = initializer_range self.mem_len = mem_len @@ -600,18 +676,16 @@ def __init__( self.dropout = nn.Dropout(dropout) self.word_embedding = nn.Embedding(vocab_size, d_model) self.mask_emb = self.create_parameter([1, 1, d_model]) - self.layer = nn.LayerList( - [ - XLNetLayer( - n_head, - d_head, - d_model, - layer_norm_eps, - dropout, - d_inner, - ff_activation, - ) for _ in range(n_layer) - ]) + self.layer = nn.LayerList([ + XLNetLayer( + n_head, + d_head, + d_model, + layer_norm_eps, + dropout, + d_inner, + ff_activation, ) for _ 
in range(n_layer) + ]) self.init_weights() @@ -649,14 +723,15 @@ def create_mask(self, qlen, mlen): ret = paddle.concat([attn_mask_pad, mask_up], axis=1) if self.same_length: mask_lo = paddle.tril(attn_mask, diagonal=-1) - ret = paddle.concat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], axis=1) + ret = paddle.concat( + [ret[:, :qlen] + mask_lo, ret[:, qlen:]], axis=1) return ret def cache_mem(self, curr_out, prev_mem): # Cache hidden states into memory. if self.reuse_len is not None and self.reuse_len > 0: - curr_out = curr_out[: self.reuse_len] + curr_out = curr_out[:self.reuse_len] if self.mem_len is None or self.mem_len == 0: # If :obj:`use_mems` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time @@ -677,8 +752,10 @@ def cache_mem(self, curr_out, prev_mem): @staticmethod def positional_embedding(pos_seq, inv_freq, bsz=None): # Compute sinusoid_inp = einsum4x4("i,d->id", pos_seq, inv_freq) - sinusoid_inp = paddle.matmul(pos_seq.reshape([-1, 1]), inv_freq.reshape([1, -1])) - pos_emb = paddle.concat([paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1) + sinusoid_inp = paddle.matmul( + pos_seq.reshape([-1, 1]), inv_freq.reshape([1, -1])) + pos_emb = paddle.concat( + [paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)], axis=-1) pos_emb = paddle.unsqueeze(pos_emb, axis=1) if bsz is not None: pos_emb = pos_emb.expand([-1, bsz, -1]) @@ -689,7 +766,7 @@ def positional_embedding(pos_seq, inv_freq, bsz=None): def relative_positional_encoding(self, qlen, klen, bsz=None): # Create relative positional encoding. freq_seq = paddle.arange(0, self.d_model, 2.0, dtype=dtype_float) - inv_freq = 1 / 10000 ** (freq_seq / self.d_model) + inv_freq = 1 / 10000**(freq_seq / self.d_model) if self.attn_type == "bi": beg, end = klen, -qlen @@ -707,8 +784,10 @@ def relative_positional_encoding(self, qlen, klen, bsz=None): bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len) if bsz is not None: - fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2) - bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2) + fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, + bsz // 2) + bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, + bsz // 2) else: fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq) bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq) @@ -735,8 +814,7 @@ def forward( use_mems_eval=False, output_attentions=False, output_hidden_states=False, - return_dict=False, - ): + return_dict=False, ): if self.training: use_mems = use_mems_train @@ -747,7 +825,9 @@ def forward( # but we want a unified interface in the library with the batch size on the first dimension # so we move here the first dimension (batch) to the end if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input_ids = paddle.transpose(input_ids, perm=[1, 0]) qlen, bsz = input_ids.shape[0], input_ids.shape[1] @@ -755,15 +835,22 @@ def forward( inputs_embeds = paddle.transpose(inputs_embeds, perm=[1, 0]) qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1] else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - token_type_ids = token_type_ids.transpose([1, 0]) if token_type_ids is not None else None - input_mask = input_mask.transpose([1, 0]) if input_mask is not None 
else None - attention_mask = attention_mask.transpose([1, 0]) if attention_mask is not None else None - perm_mask = perm_mask.transpose([1, 2, 0]) if perm_mask is not None else None - target_mapping = target_mapping.transpose([1, 2, 0]) if target_mapping is not None else None - - mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0 + raise ValueError( + "You have to specify either input_ids or inputs_embeds") + + token_type_ids = token_type_ids.transpose( + [1, 0]) if token_type_ids is not None else None + input_mask = input_mask.transpose( + [1, 0]) if input_mask is not None else None + attention_mask = attention_mask.transpose( + [1, 0]) if attention_mask is not None else None + perm_mask = perm_mask.transpose( + [1, 2, 0]) if perm_mask is not None else None + target_mapping = target_mapping.transpose( + [1, 2, 0]) if target_mapping is not None else None + + mlen = mems[0].shape[0] if mems is not None and mems[ + 0] is not None else 0 klen = mlen + qlen # Attention mask @@ -774,7 +861,8 @@ def forward( elif self.attn_type == "bi": attn_mask = None else: - raise ValueError("Unsupported attention type: {}".format(self.attn_type)) + raise ValueError("Unsupported attention type: {}".format( + self.attn_type)) # Data mask: input mask & perm mask assert input_mask is None or attention_mask is None, "You can only use one of input_mask (uses 1 for padding) " @@ -793,7 +881,9 @@ def forward( if data_mask is not None: # All mems can be attended to if mlen > 0: - mems_mask = paddle.cast(paddle.zeros([data_mask.shape[0], mlen, bsz]), dtype=dtype_float) + mems_mask = paddle.cast( + paddle.zeros([data_mask.shape[0], mlen, bsz]), + dtype=dtype_float) data_mask = paddle.concat([mems_mask, data_mask], axis=1) if attn_mask is None: attn_mask = paddle.unsqueeze(data_mask, axis=-1) @@ -808,9 +898,16 @@ def forward( if mlen > 0: non_tgt_mask = paddle.concat( - [paddle.cast(paddle.zeros([qlen, mlen]), dtype=dtype_float), non_tgt_mask], axis=-1) + [ + paddle.cast( + paddle.zeros([qlen, mlen]), dtype=dtype_float), + non_tgt_mask + ], + axis=-1) non_tgt_mask = paddle.cast( - ((attn_mask + paddle.unsqueeze(non_tgt_mask, axis=[2, 3])) > 0), dtype=dtype_float) + ((attn_mask + paddle.unsqueeze( + non_tgt_mask, axis=[2, 3])) > 0), + dtype=dtype_float) else: non_tgt_mask = None @@ -822,7 +919,8 @@ def forward( output_h = self.dropout(word_emb_k) if target_mapping is not None: - word_emb_q = self.mask_emb.expand([target_mapping.shape[0], bsz, -1]) + word_emb_q = self.mask_emb.expand( + [target_mapping.shape[0], bsz, -1]) output_g = self.dropout(word_emb_q) else: output_g = None @@ -838,9 +936,13 @@ def forward( # `1` indicates not in the same segment [qlen x klen x bsz] seg_mat = paddle.cast( - paddle.unsqueeze(token_type_ids, axis=1) != paddle.unsqueeze(cat_ids, axis=0), + paddle.unsqueeze( + token_type_ids, axis=1) != paddle.unsqueeze( + cat_ids, axis=0), dtype='int64') - seg_mat = paddle.cast(F.one_hot(seg_mat, num_classes=2), dtype=dtype_float) + seg_mat = paddle.cast( + F.one_hot( + seg_mat, num_classes=2), dtype=dtype_float) else: seg_mat = None @@ -855,7 +957,8 @@ def forward( # And head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] if head_mask is not None: if head_mask.dim() == 1: - head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze( + 0).unsqueeze(0) head_mask = head_mask.expand([self.n_layer, -1, -1, -1, -1]) elif head_mask.dim() == 2: head_mask = 
head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1) @@ -871,9 +974,10 @@ def forward( for i, layer_module in enumerate(self.layer): if use_mems: # Cache new mems - new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) + new_mems = new_mems + (self.cache_mem(output_h, mems[i]), ) if output_hidden_states: - hidden_states.append((output_h, output_g) if output_g is not None else output_h) + hidden_states.append((output_h, output_g) + if output_g is not None else output_h) outputs = layer_module( output_h, @@ -885,8 +989,7 @@ def forward( mems=mems[i], target_mapping=target_mapping, head_mask=head_mask[i], - output_attentions=output_attentions, - ) + output_attentions=output_attentions, ) output_h, output_g = outputs[:2] if output_attentions: @@ -894,7 +997,8 @@ def forward( # Add last hidden state if output_hidden_states: - hidden_states.append((output_h, output_g) if output_g is not None else output_h) + hidden_states.append((output_h, output_g) + if output_g is not None else output_h) output = self.dropout(output_g if output_g is not None else output_h) @@ -906,26 +1010,37 @@ def forward( if output_hidden_states: if output_g is not None: - hidden_states = tuple(paddle.transpose(h, perm=[1, 0, 2]) for hs in hidden_states for h in hs) + hidden_states = tuple( + paddle.transpose( + h, perm=[1, 0, 2]) for hs in hidden_states for h in hs) else: - hidden_states = tuple(paddle.transpose(hs, perm=[1, 0, 2]) for hs in hidden_states) + hidden_states = tuple( + paddle.transpose( + hs, perm=[1, 0, 2]) for hs in hidden_states) if output_attentions: if target_mapping is not None: # When target_mapping is provided, there are 2-tuple of attentions attentions = tuple( - tuple(paddle.transpose(att_stream, perm=[2, 3, 0, 1]) for att_stream in t) for t in attentions - ) + tuple( + paddle.transpose( + att_stream, perm=[2, 3, 0, 1]) for att_stream in t) + for t in attentions) else: - attentions = tuple(paddle.transpose(t, perm=[2, 3, 0, 1]) for t in attentions) + attentions = tuple( + paddle.transpose( + t, perm=[2, 3, 0, 1]) for t in attentions) if not return_dict: - return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None) - return {"last_hidden_state": output, - "mems": new_mems, - "hidden_states": hidden_states, - "attentions": attentions, - } + return tuple( + v for v in [output, new_mems, hidden_states, attentions] + if v is not None) + return { + "last_hidden_state": output, + "mems": new_mems, + "hidden_states": hidden_states, + "attentions": attentions, + } class XLNetClassificationHead(Layer): @@ -962,27 +1077,25 @@ def __init__(self, xlnet, num_classes=2): self.transformer = xlnet self.classifier = XLNetClassificationHead( self.transformer.d_model, - self.transformer.config["classifier_dropout"], - num_classes) + self.transformer.config["classifier_dropout"], num_classes) self.init_weights() def forward( - self, - input_ids=None, - token_type_ids=None, - attention_mask=None, - mems=None, - perm_mask=None, - target_mapping=None, - input_mask=None, - head_mask=None, - inputs_embeds=None, - use_mems_train=False, - use_mems_eval=False, - output_attentions=False, - output_hidden_states=False, - return_dict=False, - ): + self, + input_ids=None, + token_type_ids=None, + attention_mask=None, + mems=None, + perm_mask=None, + target_mapping=None, + input_mask=None, + head_mask=None, + inputs_embeds=None, + use_mems_train=False, + use_mems_eval=False, + output_attentions=False, + output_hidden_states=False, + return_dict=False, ): transformer_outputs = self.transformer( input_ids, 
@@ -998,17 +1111,18 @@ def forward( use_mems_eval=use_mems_eval, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - output = transformer_outputs[0] if not return_dict else transformer_outputs["last_hidden_state"] + return_dict=return_dict, ) + output = transformer_outputs[ + 0] if not return_dict else transformer_outputs["last_hidden_state"] logits = self.classifier(output) if not return_dict: - return (logits,) + transformer_outputs[1:] - return {"logits": logits, - "mems": transformer_outputs["mems"], - "hidden_states": transformer_outputs["hidden_states"], - "attentions": transformer_outputs["attentions"], - } + return (logits, ) + transformer_outputs[1:] + return { + "logits": logits, + "mems": transformer_outputs["mems"], + "hidden_states": transformer_outputs["hidden_states"], + "attentions": transformer_outputs["attentions"], + } class XLNetForTokenClassification(XLNetPretrainedModel): @@ -1030,22 +1144,21 @@ def __init__(self, xlnet, num_classes=2): self.init_weights() def forward( - self, - input_ids=None, - token_type_ids=None, - attention_mask=None, - mems=None, - perm_mask=None, - target_mapping=None, - input_mask=None, - head_mask=None, - inputs_embeds=None, - use_mems_train=False, - use_mems_eval=False, - output_attentions=False, - output_hidden_states=False, - return_dict=False, - ): + self, + input_ids=None, + token_type_ids=None, + attention_mask=None, + mems=None, + perm_mask=None, + target_mapping=None, + input_mask=None, + head_mask=None, + inputs_embeds=None, + use_mems_train=False, + use_mems_eval=False, + output_attentions=False, + output_hidden_states=False, + return_dict=False, ): transformer_outputs = self.transformer( input_ids, token_type_ids=token_type_ids, @@ -1060,17 +1173,18 @@ def forward( use_mems_eval=use_mems_eval, output_attentions=output_attentions, output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = transformer_outputs[0] if not return_dict else transformer_outputs["last_hidden_state"] + return_dict=return_dict, ) + sequence_output = transformer_outputs[ + 0] if not return_dict else transformer_outputs["last_hidden_state"] logits = self.classifier(sequence_output) if not return_dict: - return (logits,) + transformer_outputs[1:] + return (logits, ) + transformer_outputs[1:] - return {"logits": logits, - "mems": transformer_outputs["mems"], - "hidden_states": transformer_outputs["hidden_states"], - "attentions": transformer_outputs["attentions"], - } + return { + "logits": logits, + "mems": transformer_outputs["mems"], + "hidden_states": transformer_outputs["hidden_states"], + "attentions": transformer_outputs["attentions"], + } diff --git a/paddlenlp/transformers/xlnet/tokenizer.py b/paddlenlp/transformers/xlnet/tokenizer.py index bc97c0d17c4fd..19212354f9775 100644 --- a/paddlenlp/transformers/xlnet/tokenizer.py +++ b/paddlenlp/transformers/xlnet/tokenizer.py @@ -14,7 +14,6 @@ # limitations under the License. 
""" Tokenization classes for XLNet model.""" - import os import unicodedata from shutil import copyfile @@ -43,9 +42,15 @@ class XLNetTokenizer(PretrainedTokenizer): pretrained_resource_files_map = { "vocab_file": { "xlnet-base-cased": - "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-base-cased-spiece.model", + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-base-cased-spiece.model", "xlnet-large-cased": - "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-large-cased-spiece.model", + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/xlnet-large-cased-spiece.model", + "chinese-xlnet-base": + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/chinese-xlnet-base-spiece.model", + "chinese-xlnet-mid": + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/chinese-xlnet-mid-spiece.model", + "chinese-xlnet-large": + "https://paddlenlp.bj.bcebos.com/models/transformers/xlnet/chinese-xlnet-large-spiece.model", } } pretrained_init_configuration = { @@ -55,30 +60,41 @@ class XLNetTokenizer(PretrainedTokenizer): "xlnet-large-cased": { "do_lower_case": False }, + "chinese-xlnet-base": { + "do_lower_case": False + }, + "chinese-xlnet-mid": { + "do_lower_case": False + }, + "chinese-xlnet-large": { + "do_lower_case": False + }, } pretrained_positional_embedding_sizes = { "xlnet-base-cased": None, "xlnet-large-cased": None, + "chinese-xlnet-base": None, + "chinese-xlnet-mid": None, + "chinese-xlnet-large": None, } max_model_input_sizes = pretrained_positional_embedding_sizes padding_side = "left" pad_token_type_id = 3 def __init__( - self, - vocab_file, - do_lower_case=False, - remove_space=True, - keep_accents=False, - bos_token="", - eos_token="", - unk_token="", - sep_token="", - pad_token="", - cls_token="", - mask_token="", - additional_special_tokens=["", ""], - ): + self, + vocab_file, + do_lower_case=False, + remove_space=True, + keep_accents=False, + bos_token="", + eos_token="", + unk_token="", + sep_token="", + pad_token="", + cls_token="", + mask_token="", + additional_special_tokens=["", ""], ): self.do_lower_case = do_lower_case self.remove_space = remove_space @@ -93,7 +109,10 @@ def vocab_size(self): return len(self.sp_model) def get_vocab(self): - vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab = { + self.convert_ids_to_tokens(i): i + for i in range(self.vocab_size) + } vocab.update(self.added_tokens_encoder) return vocab @@ -117,7 +136,8 @@ def preprocess_text(self, inputs): if not self.keep_accents: outputs = unicodedata.normalize("NFKD", outputs) - outputs = "".join([c for c in outputs if not unicodedata.combining(c)]) + outputs = "".join( + [c for c in outputs if not unicodedata.combining(c)]) if self.do_lower_case: outputs = outputs.lower() @@ -134,8 +154,10 @@ def _tokenize(self, text, sample=False): new_pieces = [] for piece in pieces: if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit(): - cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, "")) - if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE: + cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace( + SPIECE_UNDERLINE, "")) + if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][ + 0] == SPIECE_UNDERLINE: if len(cur_pieces[0]) == 1: cur_pieces = cur_pieces[1:] else: @@ -176,7 +198,10 @@ def convert_ids_to_tokens(self, ids, skip_special_tokens=False): return self._convert_id_to_token(ids) tokens = [self._convert_id_to_token(_id) for _id in ids] if 
skip_special_tokens: - return [token for token in tokens if token not in self.all_special_tokens] + return [ + token for token in tokens + if token not in self.all_special_tokens + ] return tokens def convert_tokens_to_string(self, tokens): @@ -201,11 +226,13 @@ def num_special_tokens_to_add(self, pair=False): """ token_ids_0 = [] token_ids_1 = [] - return len(self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 if pair else None)) + return len( + self.build_inputs_with_special_tokens(token_ids_0, token_ids_1 + if pair else None)) def build_inputs_with_special_tokens( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + self, token_ids_0: List[int], + token_ids_1: Optional[List[int]]=None) -> List[int]: """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. An XLNet sequence has the following format: @@ -251,11 +278,15 @@ def build_offset_mapping_with_special_tokens(self, if offset_mapping_1 is None: return offset_mapping_0 + [(0, 0)] + [(0, 0)] - return offset_mapping_0 + [(0, 0)] + offset_mapping_1 + [(0, 0)] + [(0, 0)] + return offset_mapping_0 + [(0, 0)] + offset_mapping_1 + [(0, 0)] + [ + (0, 0) + ] def get_special_tokens_mask( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False - ) -> List[int]: + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]]=None, + already_has_special_tokens: bool=False) -> List[int]: """ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding special tokens using the tokenizer ``prepare_for_model`` method. @@ -278,15 +309,18 @@ def get_special_tokens_mask( "You should not supply a second sequence if the provided sequence of " "ids is already formatted with special tokens for the model." ) - return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) + return list( + map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, + token_ids_0)) if token_ids_1 is not None: - return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1, 1] + return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1) + ) + [1, 1] return ([0] * len(token_ids_0)) + [1, 1] def create_token_type_ids_from_sequences( - self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None - ) -> List[int]: + self, token_ids_0: List[int], + token_ids_1: Optional[List[int]]=None) -> List[int]: """ Create a mask from the two sequences passed to be used in a sequence-pair classification task. An XLNet sequence pair mask has the following format: @@ -313,7 +347,8 @@ def create_token_type_ids_from_sequences( if token_ids_1 is None: return len(token_ids_0 + sep) * [0] + cls_segment_id - return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id + return len(token_ids_0 + sep) * [0] + len(token_ids_1 + + sep) * [1] + cls_segment_id def save_resources(self, save_directory): """
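As a reading aid for the tokenizer changes above, the following self-contained sketch mirrors the special-token layout that `build_inputs_with_special_tokens` and `create_token_type_ids_from_sequences` implement for XLNet: `A <sep> <cls>` for a single sequence, `A <sep> B <sep> <cls>` for a pair, with segment id 2 reserved for `<cls>`. The token ids are placeholders, not real vocabulary ids, and the sketch does not use SentencePiece.

```python
# Toy illustration of the XLNet input layout built by the tokenizer above.
# sep_id and cls_id are placeholders, not real vocabulary ids.
sep_id, cls_id = 4, 3
cls_segment_id = 2  # segment id given to the trailing <cls> token

def build_inputs(token_ids_0, token_ids_1=None):
    # Single sequence: A + <sep> + <cls>; pair: A + <sep> + B + <sep> + <cls>
    sep, cls = [sep_id], [cls_id]
    if token_ids_1 is None:
        return token_ids_0 + sep + cls
    return token_ids_0 + sep + token_ids_1 + sep + cls

def token_type_ids(token_ids_0, token_ids_1=None):
    # Segment 0 covers A and its <sep>, segment 1 covers B and its <sep>,
    # and the final <cls> gets its own segment id.
    sep = [sep_id]
    if token_ids_1 is None:
        return len(token_ids_0 + sep) * [0] + [cls_segment_id]
    return (len(token_ids_0 + sep) * [0]
            + len(token_ids_1 + sep) * [1] + [cls_segment_id])

a, b = [11, 12, 13], [21, 22]
print(build_inputs(a, b))    # [11, 12, 13, 4, 21, 22, 4, 3]
print(token_type_ids(a, b))  # [0, 0, 0, 0, 1, 1, 1, 2]
```

Because segment ids 0, 1, and 2 are already used by the two sequences and the trailing `<cls>`, padding gets its own id, which matches the `padding_side = "left"` and `pad_token_type_id = 3` settings defined on the class above.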