This repository contains code and checkpoints for CPT.
CPT: A Pre-Trained Unbalanced Transformer for Both Chinese Language Understanding and Generation
Yunfan Shao, Zhichao Geng, Yitao Liu, Junqi Dai, Fei Yang, Li Zhe, Hujun Bao, Xipeng Qiu
Aiming to unify both NLU and NLG tasks, We propose a novel Chinese Pre-trained Un-balanced Transformer (CPT), which is an unbalanced Transformer encoder-decoder pre-trained with MLM and DAE jointly.
The architecture of CPT is a variant of the full Transformer and consists of three parts:
- Shared Encoder (S-Enc): a Transformer encoder with fully-connected self-attention, which is designed to capture the common semantic representation for both language understanding and generation.
- Understanding Decoder (U-Dec): a shallow Transformer encoder with fully-connected self-attention, which is designed for NLU tasks. The input of U-Dec is the output of S-Enc.
- Generation Decoder (G-Dec): a Transformer decoder with masked self-attention, which is designed for generation tasks with auto-regressive fashion. G-Dec utilizes the output of S-Enc with cross-attention.
We provide the pre-trained weights of CPT and Chinese BART with source code, which can be directly used in Huggingface-Transformers.
Chinese BART-base
: 6 layers Encoder, 6 layers Decoder, 12 Heads and 768 Model dim.Chinese BART-large
: 12 layers Encoder, 12 layers Decoder, 16 Heads and 1024 Model dim.CPT-base
: 10 layers S-Enc, 2 layers U-Dec/G-Dec, 12 Heads and 768 Model dim.CPT-large
: 20 layers S-Enc, 4 layers U-Dec/G-Dec, 16 Heads and 1024 Model dim.
The pre-trained weights can be downloaded here.
Model | MODEL_NAME |
---|---|
Chinese BART-base |
fnlp/bart-base-chinese |
Chinese BART-large |
fnlp/bart-large-chinese |
CPT-base |
fnlp/cpt-base |
CPT-large |
fnlp/cpt-large |
- pytorch==1.8.1
- transformers==4.4.1
To use CPT, please import the file finetune/modeling_cpt.py
that define the architecture of CPT into your project.
Then, use the PTMs as the following example, where MODEL_NAME
is the corresponding string that refers to the model.
For CPT:
from modeling_cpt import BertTokenizer, CPTForConditionalGeneration
tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
model = CPTForConditionalGeneration.from_pretrained("MODEL_NAME")
print(model)
For Chinese BART:
from transformers import BertTokenizer, BartForConditionalGeneration
tokenizer = BertTokenizer.from_pretrained("MODEL_NAME")
model = BartForConditionalGeneration.from_pretrained("MODEL_NAME")
print(model)
Pre-training code and examples can be find Here.
Fine-tuning code and examples can be find Here.
run locally
#make sure you got trained models in 6 folder
cd endpoint
sh build_and_push.sh
docker run -v -d -p 8080:8080 cpt
# test
#curl http://localhost:8080/ping
# curl
import requests
import json
url='http://localhost:8080/invocations'
data={"data": "《半導體》Q1展望保守,世界垂淚2019/02/11 10:28時報資訊【時報記者沈培華台北報導】世界先進 (5347) 去年營運創歷史新高,每股純益達3.72元。但對今年首季展望保守,預計營收將比上季高點減近一成。世界先進於封關前股價拉高,今早則是開平走低。世界先進於年前台股封關後舉行法說會公布財報。公司去年營運表現亮麗,營收與獲利同創歷史新高紀錄。2018年全年營收289.28億元,年增16.1%,毛利率35.2%,拉升3.2個百分點,稅後淨利61.66億元,年增36.9%,營收與獲利同創歷史新高,每股純益3.72元。董事會通過去年度擬配發現金股利3.2元。展望第一季,受到客戶進入庫存調整,公司預期,本季營收估在67億至71億元,將季減8%至13%,毛利率將約34.5%至36.5%。此外,因應客戶需求,世界先進決定斥資2.36億美元,收購格芯新加坡8吋晶圓廠。世界先進於年前宣布,將購買格芯位於新加坡Tampines的8吋晶圓3E廠房、廠務設施、機器設備及微機電(MEMS)智財權與業務,交易總金額2.36億美元,交割日訂108年12月31日。格芯晶圓3E廠現有月產能3.5萬片8吋晶圓,世界先進每年將可增加超過40萬片8吋晶圓產能,增進公司明年起業績成長動能。TOP關閉"}
data = json.dumps(data)
r = requests.post(url,data=data)
#show result
print (r.text)
结果如下
{"摘要": "2011 年 f1 周 年 展 望 保 守 保 守 世 界 第 一 (图 )"}
run on endpoint
endpoint_ecr_image="847380964353.dkr.ecr.us-east-2.amazonaws.com/cpt"
python create_endpoint.py \
--endpoint_ecr_image_path ${endpoint_ecr_image} \
--endpoint_name 'cpt' \
--instance_type "ml.g4dn.xlarge"
在部署结束后,看到SageMaker控制台生成了对应的endpoint,可以使用如下客户端代码测试调用
from boto3.session import Session
import json
data={"data": "《半導體》Q1展望保守,世界垂淚2019/02/11 10:28時報資訊【時報記者沈培華台北報導】世界先進 (5347) 去年營運創歷史新高,每股純益達3.72元。但對今年首季展望保守,預計營收將比上季高點減近一成。世界先進於封關前股價拉高,今早則是開平走低。世界先進於年前台股封關後舉行法說會公布財報。公司去年營運表現亮麗,營收與獲利同創歷史新高紀錄。2018年全年營收289.28億元,年增16.1%,毛利率35.2%,拉升3.2個百分點,稅後淨利61.66億元,年增36.9%,營收與獲利同創歷史新高,每股純益3.72元。董事會通過去年度擬配發現金股利3.2元。展望第一季,受到客戶進入庫存調整,公司預期,本季營收估在67億至71億元,將季減8%至13%,毛利率將約34.5%至36.5%。此外,因應客戶需求,世界先進決定斥資2.36億美元,收購格芯新加坡8吋晶圓廠。世界先進於年前宣布,將購買格芯位於新加坡Tampines的8吋晶圓3E廠房、廠務設施、機器設備及微機電(MEMS)智財權與業務,交易總金額2.36億美元,交割日訂108年12月31日。格芯晶圓3E廠現有月產能3.5萬片8吋晶圓,世界先進每年將可增加超過40萬片8吋晶圓產能,增進公司明年起業績成長動能。TOP關閉"}
session = Session()
runtime = session.client("runtime.sagemaker")
response = runtime.invoke_endpoint(
EndpointName='cpt',
ContentType="application/json",
Body=json.dumps(data),
)
result = json.loads(response["Body"].read())
print (result)
@article{shao2021cpt,
title={CPT: A Pre-Trained Unbalanced Transformer for Both Chinese Language Understanding and Generation},
author={Yunfan Shao and Zhichao Geng and Yitao Liu and Junqi Dai and Fei Yang and Li Zhe and Hujun Bao and Xipeng Qiu},
journal={arXiv preprint arXiv:2109.05729},
year={2021}
}