PaddlePaddle Hackathon 57 Submission #1128

Merged 17 commits on Nov 29, 2021
12 changes: 12 additions & 0 deletions community/renmada/README.md
@@ -0,0 +1,12 @@
# Convert to Paddle
1. Download the checkpoints from the Hugging Face model hub.
2. Modify the model paths and run the script below; a quick check of the converted weights is shown after the script.
```bash
# path1: sshleifertiny-distilbert-base-uncased-finetuned-sst-2-english
# path2: distilbert-base-multilingual-cased
export path1='sshleifertiny-distilbert-base-uncased-finetuned-sst-2-english/pytorch_model.bin'
export path2='distilbert-base-multilingual-cased/pytorch_model.bin'
python convert_to_paddle.py \
--sshleifertiny_model_path $path1 \
--base_model_path $path2
```
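
After conversion, each `model_state.pdparams` is written next to its source `pytorch_model.bin`. A quick way to verify the result (a minimal sketch, assuming the default paths above; not part of the original scripts):

```python
import paddle

# Load the converted state dict and print a few tensor names/shapes as a sanity check
state = paddle.load("distilbert-base-multilingual-cased/model_state.pdparams")
print(f"{len(state)} tensors converted")
for name in list(state)[:5]:
    print(name, state[name].shape)
```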
101 changes: 101 additions & 0 deletions community/renmada/convert_to_paddle.py
@@ -0,0 +1,101 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse
import numpy as np

huggingface_to_paddle1 = {
    "embeddings.LayerNorm": "embeddings.layer_norm",
    "transformer.layer": "encoder.layers",
    "attention.q_lin": "self_attn.q_proj",
    "attention.k_lin": "self_attn.k_proj",
    "attention.v_lin": "self_attn.v_proj",
    "attention.out_lin": "self_attn.out_proj",
    "ffn.lin1": "linear1",
    "ffn.lin2": "linear2",
    "sa_layer_norm": "norm1",
    "output_layer_norm": "norm2",
}

huggingface_to_paddle2 = {
    "bert": "distilbert",
    "embeddings.LayerNorm": "embeddings.layer_norm",
    "encoder.layer": "encoder.layers",
    "attention.self.query": "self_attn.q_proj",
    "attention.self.key": "self_attn.k_proj",
    "attention.self.value": "self_attn.v_proj",
    "attention.output.dense": "self_attn.out_proj",
    "intermediate.dense": "linear1",
    "output.dense": "linear2",
    "attention.output.LayerNorm": "norm1",
    "output.LayerNorm": "norm2",
    "predictions.decoder.": "predictions.decoder_",
    "predictions.transform.dense": "predictions.transform",
    "predictions.transform.LayerNorm": "predictions.layer_norm",
}
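
# Illustration (hypothetical key, not taken from the checkpoints): applying
# huggingface_to_paddle2 to a BERT-style parameter name such as
#     "bert.encoder.layer.0.attention.self.query.weight"
# renames it, one substring replacement at a time, into the Paddle name
#     "distilbert.encoder.layers.0.self_attn.q_proj.weight"
# which is what the replacement loop below does for every key.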


def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
                                         huggingface_to_paddle):
    import torch
    import paddle
    pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    total_params = 0

    for k, v in pytorch_state_dict.items():
        total_params += np.prod(v.shape)
        is_transpose = False
        # PyTorch Linear weights are [out_features, in_features] while Paddle's are
        # [in_features, out_features], so transpose every 2-D weight that is not an
        # embedding or LayerNorm weight.
        if k[-7:] == ".weight":
            if ".embeddings." not in k and ".LayerNorm." not in k:
                if v.ndim == 2:
                    v = v.transpose(0, 1)
                    is_transpose = True

        # Rename the parameter by sequential substring replacement, then map the
        # pooler onto PaddleNLP's pre_classifier layer.
        oldk = k
        for huggingface_name, paddle_name in huggingface_to_paddle.items():
            k = k.replace(huggingface_name, paddle_name)
        if k.startswith('distilbert.pooler.dense'):
            k = k.replace('distilbert.pooler.dense', 'pre_classifier')

        print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
        paddle_state_dict[k] = v.data.numpy()

    paddle_dump_path = pytorch_checkpoint_path.replace('pytorch_model.bin',
                                                       'model_state.pdparams')
    paddle.save(paddle_state_dict, paddle_dump_path)
    print(f'Total params: {total_params}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sshleifertiny_model_path",
        default="D:\\paddle_models\\sshleifertiny-distilbert-base-uncased-finetuned-sst-2-english\\pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the PyTorch checkpoint of the fine-tuned tiny model.")
    parser.add_argument(
        "--base_model_path",
        default="D:\\paddle_models\\distilbert-base-multilingual-cased\\pytorch_model.bin",
        type=str,
        required=False,
        help="Path to the PyTorch checkpoint of the multilingual base model.")
    args = parser.parse_args()
    convert_pytorch_checkpoint_to_paddle(args.sshleifertiny_model_path,
                                         huggingface_to_paddle2)
    print()
    convert_pytorch_checkpoint_to_paddle(args.base_model_path,
                                         huggingface_to_paddle1)
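
The transpose performed in the conversion loop reflects the differing `Linear` weight layouts: PyTorch stores weights as `[out_features, in_features]`, Paddle as `[in_features, out_features]`. A quick illustration of that convention (a standalone sketch, not part of the script):

```python
import paddle
import torch

pt_linear = torch.nn.Linear(768, 3072)
pd_linear = paddle.nn.Linear(768, 3072)
print(tuple(pt_linear.weight.shape))  # (3072, 768): [out_features, in_features]
print(tuple(pd_linear.weight.shape))  # (768, 3072): [in_features, out_features]
```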
32 changes: 32 additions & 0 deletions community/renmada/distilbert-base-multilingual-cased/README.md
@@ -0,0 +1,32 @@
# Model Introduction
This model is a distilled version of the BERT base multilingual model. The code for the distillation process is available in the upstream DistilBERT repository. This model is cased: it makes a difference between english and English.

The model is trained on the concatenation of Wikipedia in 104 different languages. It has 6 layers, a hidden size of 768 and 12 attention heads, totaling 134M parameters (compared to 177M parameters for mBERT-base). On average, DistilmBERT is twice as fast as mBERT-base.
# Model Source
https://huggingface.co/distilbert-base-multilingual-cased

# Model Usage
```python
import paddle
from paddlenlp.transformers import DistilBertForMaskedLM, DistilBertTokenizer

model = DistilBertForMaskedLM.from_pretrained('renmada/distilbert-base-multilingual-cased')
tokenizer = DistilBertTokenizer.from_pretrained('renmada/distilbert-base-multilingual-cased')

inp = '北京是中国的首都'
ids = tokenizer.encode(inp)['input_ids'] # [101, 10751, 13672, 16299, 10124, 10105, 12185, 10108, 50513, 119, 102]
print(ids)

# mask "北京"
ids[1] = 103
ids[2] = 103
ids = paddle.to_tensor([ids])

# Do mlm
model.eval()
with paddle.no_grad():
    mlm_logits = model(ids)
    mlm_pred = paddle.topk(mlm_logits, 1, -1)[1][0].unsqueeze(-1)

print(''.join(tokenizer.vocab.idx_to_token[int(x)] for x in mlm_pred[1:-1])) # 汉阳是中国的首都
```
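
The id `103` written into positions 1 and 2 is the `[MASK]` token id in this vocabulary; it can also be looked up instead of hard-coded (a small sketch, assuming the tokenizer's `Vocab` exposes `token_to_idx` alongside the `idx_to_token` mapping used above):

```python
# Look up the [MASK] id from the vocab rather than hard-coding 103
mask_id = tokenizer.vocab.token_to_idx['[MASK]']
masked = tokenizer.encode(inp)['input_ids']
masked[1] = mask_id  # the two positions covering "北京"
masked[2] = mask_id
```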
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/distilbert-base-multilingual-cased/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/distilbert-base-multilingual-cased/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/distilbert-base-multilingual-cased/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/distilbert-base-multilingual-cased/vocab.txt"
}
@@ -0,0 +1,18 @@
# Model Introduction
The tiny-distilbert-base-uncased model fine-tuned on SST-2.
# Model Source
https://huggingface.co/sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english
# Model Usage
```python
import paddle
from paddlenlp.transformers import DistilBertForSequenceClassification, DistilBertTokenizer

model = DistilBertForSequenceClassification.from_pretrained('renmada/sshleifer-tiny-distilbert-base-uncase-finetuned-sst-2-english')
tokenizer = DistilBertTokenizer.from_pretrained('renmada/sshleifer-tiny-distilbert-base-uncase-finetuned-sst-2-english')
inp = 'It is good'
ids = tokenizer.encode(inp)['input_ids']
ids = paddle.to_tensor([ids])
model.eval()
with paddle.no_grad():
    logits = model(ids)
```
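
To turn the raw logits into a class prediction, a softmax/argmax step can be appended (a minimal sketch; which index corresponds to the positive label depends on the checkpoint's config, which is not shown here):

```python
import paddle
import paddle.nn.functional as F

probs = F.softmax(logits, axis=-1)               # shape [1, num_classes]
pred_class = int(paddle.argmax(probs, axis=-1))  # predicted class index
print(pred_class, probs.numpy())
```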
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/sshleifer-tiny-distilbert-base-uncased-finetuned-sst-2-english/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/sshleifer-tiny-distilbert-base-uncased-finetuned-sst-2-english/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/sshleifer-tiny-distilbert-base-uncased-finetuned-sst-2-english/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/renmada/sshleifer-tiny-distilbert-base-uncased-finetuned-sst-2-english/vocab.txt"
}
12 changes: 11 additions & 1 deletion docs/model_zoo/transformers.rst
@@ -242,6 +242,16 @@ Transformer预训练模型汇总
| | | | 12-heads, 66M parameters. |
| | | | The DistilBERT model distilled from |
| | | | the BERT model ``bert-base-cased`` |
| +----------------------------------------------------------------------------------+--------------+-----------------------------------------+
| |``distilbert-base-multilingual-cased`` | Multilingual | 6-layer, 768-hidden, |
| | | | 12-heads, 134M parameters. |
| | | | The DistilBERT model distilled from |
| | | | the BERT model |
| | | | ``bert-base-multilingual-cased`` |
| +----------------------------------------------------------------------------------+--------------+-----------------------------------------+
| |``sshleifer-tiny-distilbert-base-uncase-finetuned-sst-2-english`` | English | 2-layer, 2-hidden, |
| | | | 2-heads, 50K parameters. |
| | | | The tiny DistilBERT model |
| | | | fine-tuned on SST-2. |
+--------------------+----------------------------------------------------------------------------------+--------------+-----------------------------------------+
|ELECTRA_ |``electra-small`` | English | 12-layer, 768-hidden, |
| | | | 4-heads, _M parameters. |
@@ -717,4 +727,4 @@ Reference
- Jiao, Xiaoqi, et al. "Tinybert: Distilling bert for natural language understanding." arXiv preprint arXiv:1909.10351 (2019).
- Bao, Siqi, et al. "Plato-2: Towards building an open-domain chatbot via curriculum learning." arXiv preprint arXiv:2006.16779 (2020).
- Yang, Zhilin, et al. "Xlnet: Generalized autoregressive pretraining for language understanding." arXiv preprint arXiv:1906.08237 (2019).
- Cui, Yiming, et al. "Pre-training with whole word masking for chinese bert." arXiv preprint arXiv:1906.08101 (2019).