PaddlePaddle Hackathon 54 submission (#1086)
* update
* add community/junnyu
* update electra docs
* update electra docs
* update
* update
* add import
* update
* update md
* fix attention_mask bug

Co-authored-by: Zeyu Chen <chenzeyu01@baidu.com>
Co-authored-by: yingyibiao <yyb0576@163.com>
1 parent 089f8ae, commit 723becf
Showing 19 changed files with 836 additions and 49 deletions.
@@ -0,0 +1,138 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import torch
import numpy as np
import paddlenlp.transformers as ppnlp
import transformers as hgnlp
def compare(a, b):
    # Print the mean and max absolute difference between two sets of logits.
    a = a.cpu().numpy()
    b = b.cpu().numpy()
    meandif = np.abs(a - b).mean()
    maxdif = np.abs(a - b).max()
    print("mean dif:", meandif)
    print("max dif:", maxdif)


def compare_discriminator(
        path="junnyu/hfl-chinese-electra-180g-base-discriminator"):
    # PaddleNLP implementation vs. the HuggingFace (PyTorch) reference.
    pdmodel = ppnlp.ElectraDiscriminator.from_pretrained(path)
    ptmodel = hgnlp.ElectraForPreTraining.from_pretrained(path).cuda()
    tokenizer = ppnlp.ElectraTokenizer.from_pretrained(path)
    pdmodel.eval()
    ptmodel.eval()
    text = "欢迎使用paddlenlp!"
    pdinputs = {
        k: paddle.to_tensor(
            v, dtype="int64").unsqueeze(0)
        for k, v in tokenizer(text).items()
    }
    ptinputs = {
        k: torch.tensor(
            v, dtype=torch.long).unsqueeze(0).cuda()
        for k, v in tokenizer(text).items()
    }
    with paddle.no_grad():
        pd_logits = pdmodel(**pdinputs)

    with torch.no_grad():
        pt_logits = ptmodel(**ptinputs).logits

    compare(pd_logits, pt_logits)
def compare_generator():
    text = "本院经审查认为,本案[MASK]民间借贷纠纷申请再审案件,应重点审查二审判决是否存在错误的情形。"

    # ppnlp
    path = "junnyu/hfl-chinese-legal-electra-small-generator"
    model = ppnlp.ElectraForMaskedLM.from_pretrained(path)
    tokenizer = ppnlp.ElectraTokenizer.from_pretrained(path)
    model.eval()
    tokens = ["[CLS]"]
    text_list = text.split("[MASK]")
    for i, t in enumerate(text_list):
        tokens.extend(tokenizer.tokenize(t))
        if i == len(text_list) - 1:
            tokens.extend(["[SEP]"])
        else:
            tokens.extend(["[MASK]"])

    input_ids_list = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = paddle.to_tensor([input_ids_list])
    with paddle.no_grad():
        pd_outputs = model(input_ids)[0]
    pd_outputs_sentence = "paddle: "
    for i, id in enumerate(input_ids_list):
        if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]:
            scores, index = paddle.nn.functional.softmax(pd_outputs[i],
                                                         -1).topk(5)
            tokens = tokenizer.convert_ids_to_tokens(index.tolist())
            outputs = []
            for score, tk in zip(scores.tolist(), tokens):
                outputs.append(f"{tk}={score}")
            pd_outputs_sentence += "[" + "||".join(outputs) + "]" + " "
        else:
            pd_outputs_sentence += "".join(
                tokenizer.convert_ids_to_tokens(
                    [id], skip_special_tokens=True)) + " "

    print(pd_outputs_sentence)

    # transformers
    path = "hfl/chinese-legal-electra-small-generator"
    config = hgnlp.ElectraConfig.from_pretrained(path)
    config.hidden_size = 64
    config.intermediate_size = 256
    config.num_attention_heads = 1
    model = hgnlp.ElectraForMaskedLM.from_pretrained(path, config=config)
    tokenizer = hgnlp.ElectraTokenizer.from_pretrained(path)
    model.eval()

    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        pt_outputs = model(**inputs).logits[0]
    pt_outputs_sentence = "pytorch: "
    for i, id in enumerate(inputs["input_ids"][0].tolist()):
        if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]:
            scores, index = torch.nn.functional.softmax(pt_outputs[i],
                                                        -1).topk(5)
            tokens = tokenizer.convert_ids_to_tokens(index.tolist())
            outputs = []
            for score, tk in zip(scores.tolist(), tokens):
                outputs.append(f"{tk}={score}")
            pt_outputs_sentence += "[" + "||".join(outputs) + "]" + " "
        else:
            pt_outputs_sentence += "".join(
                tokenizer.convert_ids_to_tokens(
                    [id], skip_special_tokens=True)) + " "

    print(pt_outputs_sentence)


if __name__ == "__main__":
    compare_discriminator(
        path="junnyu/hfl-chinese-electra-180g-base-discriminator")
    # mean dif: 3.1698835e-06
    # max dif: 1.335144e-05
    compare_discriminator(
        path="junnyu/hfl-chinese-electra-180g-small-ex-discriminator")
    # mean dif: 3.7930229e-06
    # max dif: 1.04904175e-05
    compare_generator()
    # paddle: 本 院 经 审 查 认 为 , 本 案 [因=0.27444931864738464||经=0.18613006174564362||系=0.09408623725175858||的=0.07536833733320236||就=0.033634234219789505] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。
    # pytorch: 本 院 经 审 查 认 为 , 本 案 [因=0.2744344472885132||经=0.1861187219619751||系=0.09407979995012283||的=0.07537488639354706||就=0.03363779932260513] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。
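A stricter variant of `compare` could fail loudly instead of only printing statistics; the helper below is a sketch of that idea (the helper name and tolerance values are assumptions, not part of the submission). Calling `assert_close(pd_logits, pt_logits)` after `compare(pd_logits, pt_logits)` would turn the checks above into pass/fail tests.

```python
import numpy as np


def assert_close(a, b, rtol=1e-4, atol=1e-4):
    # Raise an AssertionError when the Paddle and PyTorch outputs diverge
    # beyond the given tolerances (the tolerance values are illustrative).
    np.testing.assert_allclose(
        a.cpu().numpy(), b.cpu().numpy(), rtol=rtol, atol=atol)
```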
@@ -0,0 +1,79 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse
# Mapping from HuggingFace parameter-name fragments to their PaddleNLP
# counterparts.
huggingface_to_paddle = {
    "embeddings.LayerNorm": "embeddings.layer_norm",
    "encoder.layer": "encoder.layers",
    "attention.self.query.": "self_attn.q_proj.",
    "attention.self.key.": "self_attn.k_proj.",
    "attention.self.value.": "self_attn.v_proj.",
    "attention.output.dense.": "self_attn.out_proj.",
    "intermediate.dense": "linear1",
    "output.dense": "linear2",
    "attention.output.LayerNorm": "norm1",
    "output.LayerNorm": "norm2",
    "generator_predictions.LayerNorm": "generator_predictions.layer_norm",
    "generator_lm_head.bias": "generator_lm_head_bias",
}

# Buffers with no Paddle counterpart, and weights that must not be transposed
# (embedding and LayerNorm parameters are stored identically in both frameworks).
skip_weights = ["electra.embeddings.position_ids"]
dont_transpose = ["_embeddings.weight", "LayerNorm."]
def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
                                         paddle_dump_path):
    import torch
    import paddle
    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for k, v in pytorch_state_dict.items():
        # The generator LM head weight is tied to the word embeddings, so it is
        # not converted separately.
        if k == "generator_lm_head.weight":
            continue
        is_transpose = False
        if k in skip_weights:
            continue
        # torch.nn.Linear stores weights as [out_features, in_features], while
        # paddle.nn.Linear expects [in_features, out_features], so 2-D weight
        # matrices are transposed.
        if k[-7:] == ".weight":
            if not any([w in k for w in dont_transpose]):
                if v.ndim == 2:
                    v = v.transpose(0, 1)
                    is_transpose = True
        oldk = k
        for huggingface_name, paddle_name in huggingface_to_paddle.items():
            k = k.replace(huggingface_name, paddle_name)

        print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
        paddle_state_dict[k] = v.data.numpy()

    paddle.save(paddle_state_dict, paddle_dump_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pytorch_checkpoint_path",
        default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the PyTorch checkpoint file.")
    parser.add_argument(
        "--paddle_dump_path",
        default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\model_state.pdparams",
        type=str,
        required=False,
        help="Path to the output Paddle model file.")
    args = parser.parse_args()
    convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path,
                                         args.paddle_dump_path)
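A converted checkpoint can be sanity-checked by loading it back into the PaddleNLP model; the snippet below is a sketch under the assumption that the default output path from the argument parser above was used and that the community identifier added in this PR resolves to the matching config and vocab files.

```python
import paddle
from paddlenlp.transformers import ElectraDiscriminator

# Build the Paddle model and load the converted weights (the path and the
# model identifier mirror the parser defaults above and are assumptions).
model = ElectraDiscriminator.from_pretrained(
    "junnyu/hfl-chinese-electra-180g-base-discriminator")
state_dict = paddle.load(
    r"MODEL\hfl-chinese-electra-180g-base-discriminator\model_state.pdparams")

# Report any parameter names the converted file does not provide.
missing = [k for k in model.state_dict() if k not in state_dict]
print("parameters missing from the converted file:", missing)

model.set_state_dict(state_dict)
```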
community/junnyu/hfl-chinese-electra-180g-base-discriminator/README.md (37 changes: 37 additions & 0 deletions)
@@ -0,0 +1,37 @@
# Detailed description

**Introduction**: This is the base-size ELECTRA discriminator model, pre-trained on 180 GB of Chinese data.

**Model architecture**: **`ElectraDiscriminator`**, a Chinese ELECTRA model with the discriminator head.

**Applicable downstream tasks**: **general downstream tasks**, such as sentence-level classification, token-level classification, and extractive question answering.

# Usage example
```python
import paddle
from paddlenlp.transformers import ElectraDiscriminator, ElectraTokenizer

path = "junnyu/hfl-chinese-electra-180g-base-discriminator"
model = ElectraDiscriminator.from_pretrained(path)
tokenizer = ElectraTokenizer.from_pretrained(path)
model.eval()

text = "欢迎使用paddlenlp!"
inputs = {
    k: paddle.to_tensor(
        v, dtype="int64").unsqueeze(0)
    for k, v in tokenizer(text).items()
}

with paddle.no_grad():
    logits = model(**inputs)

print(logits.shape)
```
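The returned `logits` contain one score per input token. The sketch below (continuing from `logits` in the example above) shows the usual reading for ELECTRA's replaced-token-detection head, i.e. a sigmoid over each score as the probability that the token was replaced; the 0.5 threshold is an illustrative assumption, not something stated in this model card.

```python
import paddle.nn.functional as F

# One score per token: higher means "more likely replaced".
probs = F.sigmoid(logits)
replaced = probs > 0.5  # illustrative threshold
print(probs.tolist())
print(replaced.tolist())
```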

# Weight source

https://huggingface.co/hfl/chinese-electra-180g-base-discriminator

Google and Stanford University released a new pre-trained model named ELECTRA, which has a very compact size and relatively competitive performance compared with BERT and its variants. To further accelerate research on Chinese pre-trained models, the Joint Laboratory of HIT and iFLYTEK Research (HFL) released a Chinese ELECTRA model based on the official ELECTRA code. Compared with BERT and its variants, ELECTRA-small can reach similar or even higher scores on several NLP tasks with only 1/10 of the parameters.

This project is based on the official ELECTRA code: https://github.com/google-research/electra
community/junnyu/hfl-chinese-electra-180g-base-discriminator/files.json (6 changes: 6 additions & 0 deletions)
@@ -0,0 +1,6 @@
{
    "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_config.json",
    "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_state.pdparams",
    "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/tokenizer_config.json",
    "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/vocab.txt"
}
community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/README.md (36 changes: 36 additions & 0 deletions)
@@ -0,0 +1,36 @@
# Detailed description

**Introduction**: This is the small-size ELECTRA discriminator model, pre-trained on 180 GB of Chinese data.

**Model architecture**: **`ElectraDiscriminator`**, a Chinese ELECTRA model with the discriminator head.

**Applicable downstream tasks**: **general downstream tasks**, such as sentence-level classification, token-level classification, and extractive question answering.

# Usage example
```python
import paddle
from paddlenlp.transformers import ElectraDiscriminator, ElectraTokenizer

path = "junnyu/hfl-chinese-electra-180g-small-ex-discriminator"
model = ElectraDiscriminator.from_pretrained(path)
tokenizer = ElectraTokenizer.from_pretrained(path)
model.eval()

text = "欢迎使用paddlenlp!"
inputs = {
    k: paddle.to_tensor(
        v, dtype="int64").unsqueeze(0)
    for k, v in tokenizer(text).items()
}

with paddle.no_grad():
    logits = model(**inputs)

print(logits.shape)
```

# Weight source

https://huggingface.co/hfl/chinese-electra-180g-small-ex-discriminator

Google and Stanford University released a new pre-trained model named ELECTRA, which has a very compact size and relatively competitive performance compared with BERT and its variants. To further accelerate research on Chinese pre-trained models, the Joint Laboratory of HIT and iFLYTEK Research (HFL) released a Chinese ELECTRA model based on the official ELECTRA code. Compared with BERT and its variants, ELECTRA-small can reach similar or even higher scores on several NLP tasks with only 1/10 of the parameters.
community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/files.json (6 changes: 6 additions & 0 deletions)
@@ -0,0 +1,6 @@
{
    "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_config.json",
    "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_state.pdparams",
    "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/tokenizer_config.json",
    "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/vocab.txt"
}
community/junnyu/hfl-chinese-legal-electra-small-generator/README.md (58 changes: 58 additions & 0 deletions)
@@ -0,0 +1,58 @@
# Detailed description

**Introduction**: This is the small-size ELECTRA generator model, pre-trained on legal-domain data.

**Model architecture**: **`ElectraGenerator`**, a Chinese ELECTRA model with the generator head.

**Applicable downstream tasks**: **legal-domain downstream tasks**, such as sentence-level classification, token-level classification, and extractive question answering in the legal domain.
(Note: the generator alone does not perform well on downstream tasks; the discriminator is usually the part fine-tuned for them.)

# Usage example
```python
import paddle
from paddlenlp.transformers import ElectraGenerator, ElectraTokenizer

text = "本院经审查认为,本案[MASK]民间借贷纠纷申请再审案件,应重点审查二审判决是否存在错误的情形。"
path = "junnyu/hfl-chinese-legal-electra-small-generator"
model = ElectraGenerator.from_pretrained(path)
model.eval()
tokenizer = ElectraTokenizer.from_pretrained(path)

tokens = ["[CLS]"]
text_list = text.split("[MASK]")
for i, t in enumerate(text_list):
    tokens.extend(tokenizer.tokenize(t))
    if i == len(text_list) - 1:
        tokens.extend(["[SEP]"])
    else:
        tokens.extend(["[MASK]"])

input_ids_list = tokenizer.convert_tokens_to_ids(tokens)
input_ids = paddle.to_tensor([input_ids_list])
with paddle.no_grad():
    pd_outputs = model(input_ids)[0]
pd_outputs_sentence = "paddle: "
for i, id in enumerate(input_ids_list):
    if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]:
        scores, index = paddle.nn.functional.softmax(pd_outputs[i],
                                                     -1).topk(5)
        tokens = tokenizer.convert_ids_to_tokens(index.tolist())
        outputs = []
        for score, tk in zip(scores.tolist(), tokens):
            outputs.append(f"{tk}={score}")
        pd_outputs_sentence += "[" + "||".join(outputs) + "]" + " "
    else:
        pd_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens(
                [id], skip_special_tokens=True)) + " "

print(pd_outputs_sentence)
# paddle: 本 院 经 审 查 认 为 , 本 案 [因=0.27444931864738464||经=0.18613006174564362||系=0.09408623725175858||的=0.07536833733320236||就=0.033634234219789505] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。
```

# Weight source

https://huggingface.co/hfl/chinese-legal-electra-small-generator

Google and Stanford University released a new pre-trained model named ELECTRA, which has a very compact size and relatively competitive performance compared with BERT and its variants. To further accelerate research on Chinese pre-trained models, the Joint Laboratory of HIT and iFLYTEK Research (HFL) released a Chinese ELECTRA model based on the official ELECTRA code. Compared with BERT and its variants, ELECTRA-small can reach similar or even higher scores on several NLP tasks with only 1/10 of the parameters.

This project is based on the official ELECTRA code: https://github.com/google-research/electra
community/junnyu/hfl-chinese-legal-electra-small-generator/files.json (6 changes: 6 additions & 0 deletions)
@@ -0,0 +1,6 @@
{
    "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_config.json",
    "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_state.pdparams",
    "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/tokenizer_config.json",
    "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/vocab.txt"
}