Commit
Merge pull request PaddlePaddle#35 from jeff41404/electra_pretrain_and_deploy

unite electra pretrain, fine-tune and deploy
Showing 9 changed files with 1,053 additions and 455 deletions.
examples/language_model/electra/deploy/python/predict.py
199 changes: 199 additions & 0 deletions
@@ -0,0 +1,199 @@
import argparse
import logging
import os
import time

import numpy as np

import paddle.inference as paddle_infer
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor

from paddlenlp.transformers import ElectraTokenizer

logger = logging.getLogger(__name__)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_file", type=str, required=True, help="model filename")
    parser.add_argument(
        "--params_file", type=str, required=True, help="parameter filename")
    parser.add_argument(
        "--predict_sentences",
        type=str,
        nargs="*",
        help="one or more sentences to predict")
    parser.add_argument(
        "--predict_file",
        type=str,
        nargs="*",
        help="one or more files containing sentences to predict")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    parser.add_argument(
        "--use_gpu", action="store_true", help="whether to use gpu")
    parser.add_argument(
        "--use_trt", action="store_true", help="whether to use TensorRT")
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=128,
        help="max length of each sequence")
    parser.add_argument(
        "--model_name",
        type=str,
        default="electra-small",
        help="shortcut name selected in the list: " +
        ", ".join(list(ElectraTokenizer.pretrained_init_configuration.keys())))
    return parser.parse_args()


def read_sentences(paths=[]):
    sentences = []
    for sen_path in paths:
        assert os.path.isfile(sen_path), "The {} isn't a valid file.".format(
            sen_path)
        sen = read_file(sen_path)
        if sen is None:
            logger.info("error loading file: {}".format(sen_path))
            continue
        sentences.extend(sen)
    return sentences


def read_file(path):
    lines = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            # skip empty and whitespace-only lines
            if not line.isspace():
                lines.append(line.strip())
    return lines


def get_predicted_input(predicted_data, tokenizer, max_seq_length, batch_size):
    if predicted_data == [] or not isinstance(predicted_data, list):
        raise TypeError("The predicted_data is inconsistent with expectations.")

    pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    def pad_batch(batch_ids):
        # pad every sequence in a batch to the length of the batch's longest one
        max_length = max(len(i) for i in batch_ids)
        return [i + (max_length - len(i)) * [pad_token_id] for i in batch_ids]

    sen_ids_batch = []
    sen_words_batch = []
    sen_ids = []
    sen_words = []
    for sen in predicted_data:
        sen_ids.append(tokenizer(sen, max_seq_len=max_seq_length)['input_ids'])
        sen_words.append(tokenizer.cls_token + " " + sen + " " +
                         tokenizer.sep_token)
        if len(sen_ids) == batch_size:
            sen_ids_batch.append(pad_batch(sen_ids))
            sen_words_batch.append(sen_words)
            sen_ids = []
            sen_words = []

    # flush the last, possibly smaller, batch
    if len(sen_ids) > 0:
        sen_ids_batch.append(pad_batch(sen_ids))
        sen_words_batch.append(sen_words)

    return sen_ids_batch, sen_words_batch


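# A worked example of the batching/padding above, with hypothetical token ids:
# for batch_size=2 and pad_token_id=0, the tokenized batch
# [[101, 7, 102], [101, 102]] is padded to [[101, 7, 102], [101, 102, 0]] --
# each batch is padded to its own longest sequence (dynamic padding),
# not to max_seq_length.

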
def predict(args, sentences=[], paths=[]):
    """
    Args:
        sentences (list[str]): Each string is a sentence to predict. Provide
            either sentences or paths, not both.
        paths (list[str]): Paths of files that contain the sentences to predict.
    Returns:
        res (list[list[int]]): The predicted label of each sentence, grouped
            by batch.
    """
    # initialize data
    if sentences != [] and isinstance(sentences, list) and (paths == [] or
                                                            paths is None):
        predicted_data = sentences
    elif (sentences == [] or sentences is None) and isinstance(
            paths, list) and paths != []:
        predicted_data = read_sentences(paths)
    else:
        raise TypeError("The input data is inconsistent with expectations.")

    tokenizer = ElectraTokenizer.from_pretrained(args.model_name)
    predicted_input, predicted_sens = get_predicted_input(
        predicted_data, tokenizer, args.max_seq_length, args.batch_size)

    # config
    config = AnalysisConfig(args.model_file, args.params_file)
    config.switch_use_feed_fetch_ops(False)
    config.enable_memory_optim()
    if args.use_gpu:
        # initial GPU memory pool of 1000 MB on device 0
        config.enable_use_gpu(1000, 0)
    if args.use_trt:
        # run supported subgraphs with the TensorRT engine
        config.enable_tensorrt_engine(
            workspace_size=1 << 30,
            max_batch_size=args.batch_size,
            min_subgraph_size=5,
            precision_mode=AnalysisConfig.Precision.Float32,
            use_static=False,
            use_calib_mode=False)

    # predictor
    predictor = create_paddle_predictor(config)

    start_time = time.time()
    output_datas = []
    count = 0
    for i, sen in enumerate(predicted_input):
        sen = np.array(sen).astype("int64")
        # get input name
        input_names = predictor.get_input_names()
        # get input tensor and copy data from host
        input_tensor = predictor.get_input_tensor(input_names[0])
        input_tensor.reshape(sen.shape)
        input_tensor.copy_from_cpu(sen)
        #input_tensor.copy_from_cpu(fake_input.copy())

        # run predictor
        predictor.zero_copy_run()

        # get output name
        output_names = predictor.get_output_names()
        # get output tensor and copy data (ndarray) back to host
        output_tensor = predictor.get_output_tensor(output_names[0])
        output_data = output_tensor.copy_to_cpu()
        # argmax over the class dimension gives one label per sentence
        output_res = np.argmax(output_data, axis=1).tolist()
        output_datas.append(output_res)

        print("===== batch {} =====".format(i))
        for j in range(len(predicted_sens[i])):
            print("Input sentence is : {}".format(predicted_sens[i][j]))
            #print("Output logits is : {}".format(output_data[j]))
            print("Output data is : {}".format(output_res[j]))
        count += len(predicted_sens[i])
    print("inference on %s sentences done, total time : %s s" %
          (count, time.time() - start_time))
    return output_datas


if __name__ == "__main__":
    args = parse_args()
    sentences = args.predict_sentences
    paths = args.predict_file
    #sentences = ["The quick brown fox see over the lazy dog.", "The quick brown fox jump over tree lazy dog."]
    #paths = ["../../debug/test.txt", "../../debug/test.txt.1"]
    predict(args, sentences, paths)
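
For reference, a minimal sketch of driving this script's predict() directly from Python instead of the command line; the paths and sentences below are placeholders, and the model files are assumed to come from the export script shown next:

    import argparse
    from predict import predict  # assuming this file is importable as predict.py

    args = argparse.Namespace(
        model_file="./infer_model/electra-deploy.pdmodel",     # hypothetical path
        params_file="./infer_model/electra-deploy.pdiparams",  # hypothetical path
        batch_size=2,
        use_gpu=False,
        use_trt=False,
        max_seq_length=128,
        model_name="electra-small")
    predict(args, sentences=["This film is great.", "This film is awful."], paths=[])
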
109 changes: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import hashlib
import argparse
import json

import paddle
import paddle.nn as nn
from paddle.static import InputSpec

from paddlenlp.transformers import ElectraForTotalPretraining, ElectraDiscriminator, ElectraGenerator, ElectraModel
from paddlenlp.transformers import ElectraForSequenceClassification, ElectraTokenizer


def get_md5sum(file_path):
    # return the md5 checksum of a file, or None if the path is not a file
    md5sum = None
    if os.path.isfile(file_path):
        with open(file_path, 'rb') as f:
            md5_obj = hashlib.md5()
            md5_obj.update(f.read())
            hash_code = md5_obj.hexdigest()
        md5sum = str(hash_code).lower()
    return md5sum


def main():
    # check and load config
    with open(os.path.join(args.input_model_dir, "model_config.json"),
              'r') as f:
        config_dict = json.load(f)
        num_classes = config_dict.get('num_classes')
    if num_classes is None or num_classes <= 0:
        print("%s/model_config.json may not be right, please check" %
              args.input_model_dir)
        exit(1)

    # check and load model
    input_model_file = os.path.join(args.input_model_dir,
                                    "model_state.pdparams")
    print("load model to get static model : %s \nmodel md5sum : %s" %
          (input_model_file, get_md5sum(input_model_file)))
    model_state_dict = paddle.load(input_model_file)

    if all((s.startswith("generator") or s.startswith("discriminator"))
           for s in model_state_dict.keys()):
        print(
            "the model %s is an electra pretraining model; a fine-tuned model is needed for deployment"
            % input_model_file)
        exit(1)
    elif "discriminator_predictions.dense.weight" in model_state_dict:
        print(
            "the model %s is an electra discriminator model; a fine-tuned model is needed for deployment"
            % input_model_file)
        exit(1)
    elif "classifier.dense.weight" in model_state_dict:
        print("loading glue fine-tuned model")
        model = ElectraForSequenceClassification.from_pretrained(
            args.input_model_dir, num_classes=num_classes)
        print("total model layers : ", len(model_state_dict))
    else:
        print("the model file %s may not be a fine-tuned model, please check" %
              input_model_file)
        exit(1)

    # save static model to disk; InputSpec shape [None, None] allows a
    # variable batch size and sequence length
    paddle.jit.save(
        layer=model,
        path=os.path.join(args.output_model_dir, args.model_name),
        input_spec=[InputSpec(
            shape=[None, None], dtype='int64')])
    print("electra inference model saved successfully")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_model_dir",
        required=True,
        default=None,
        help="Directory of the fine-tuned Electra model to convert")
    parser.add_argument(
        "--output_model_dir",
        required=True,
        default=None,
        help="Directory for the output Electra inference model")
    parser.add_argument(
        "--model_name",
        default="electra-deploy",
        type=str,
        help="prefix name of the output model and parameters")
    args, unparsed = parser.parse_known_args()
    main()
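
For reference, a short sketch of how this export connects to deploy/python/predict.py above, assuming --output_model_dir is ./infer_model (a placeholder path) and --model_name is left at its default; paddle.jit.save writes a program file and a parameters file under that prefix:

    # hypothetical resulting files, consumed by predict.py as --model_file / --params_file
    model_file = "./infer_model/electra-deploy.pdmodel"
    params_file = "./infer_model/electra-deploy.pdiparams"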