Skip to content

Commit

Permalink
Fixed tokenizer in GLM blank filling (#13)
Browse files Browse the repository at this point in the history
* fixed tokenizer in GLM blank filling

Signed-off-by: Anhforth <yanzhaodong2021@163.com>

* Update glm_generate_samples_en.py

Co-authored-by: Anhforth <yanzhaodong2021@163.com>
Co-authored-by: Zac Liu <liuguang@baai.ac.cn>
  • Loading branch information
3 people authored Jun 28, 2022
1 parent 1a4e253 commit 97ffea4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 3 deletions.
30 changes: 30 additions & 0 deletions examples/glm_blank_filling/glm_generate_samples_en.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright © 2022 BAAI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License")

import torch
from flagai.model.predictor.predictor import Predictor
from flagai.auto_model.auto_loader import AutoLoader
if __name__ == "__main__":
    # Generate blank-filling samples with GLM-large-en.
    # (The original string said "Main training program", but this script only
    # runs inference / sample generation — no training happens here.)
    print('Generate Samples')
    # NOTE(review): no random seed is actually set here, so runs are not
    # reproducible despite the original comment claiming otherwise.
    # Load model and tokenizer via the project's AutoLoader registry.
    loader = AutoLoader(task_name='lm',
                        model_name='GLM-large-en',
                        only_download_config=False)
    model = loader.get_model()
    tokenizer = loader.get_tokenizer()
    # Move to GPU only when one is available; the unconditional
    # model.cuda(...) call crashed on CPU-only machines.
    if torch.cuda.is_available():
        model.cuda(torch.cuda.current_device())

    predictor = Predictor(model, tokenizer)
    # Generate a sample for each prompt; [gMASK] marks the span to fill.
    text = [
        'Question: Is drinking beer bad for your health? Answer: [gMASK]',
    ]
    for t in text:
        output = predictor.predict_generate_randomsample(
            t, top_k=50, repetition_penalty=4.0, top_p=1.0)
        print(t, '\n', output)


4 changes: 3 additions & 1 deletion flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __getattr__(self, name):
],
"glm-large-en": [
"flagai.data.tokenizer.glm_large_en.glm_large_en_tokenizer",
"GLMLargeEnTokenizer"
"GLMLargeEnWordPieceTokenizer"
],
"gpt2-base-ch": ["flagai.data.tokenizer.bert.bert_tokenizer", "BertTokenizer"],
"cpm-large-ch": ["flagai.data.tokenizer.cpm_1.cpm1_tokenizer", "CPMTokenizer"],
Expand Down Expand Up @@ -200,6 +200,8 @@ def __init__(self,
self.tokenizer = tokenizer_class(vocab_file_1, vocab_file_2)
elif brief_model_name == "opt":
self.tokenizer = tokenizer_class("facebook/opt-350m")
elif model_name in ["glm-large-en", "glm-large-ch"]:
self.tokenizer = tokenizer_class()
else :
self.tokenizer = tokenizer_class(vocab_file)

Expand Down
3 changes: 2 additions & 1 deletion flagai/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,8 @@ def evaluate_and_print_results(
if eval_dict.get("loss", None) is not None:
string = ' validation loss at {} | {:.4f}, '.format(
prefix, eval_dict["loss"])

# with open("results.txt", "a") as myfile:
# myfile.write(string)
if self.metric_methods is None:
return eval_dict

Expand Down
2 changes: 1 addition & 1 deletion tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_tokenizer_opt(self):
tokenizer = OPTTokenizer(tokenizer_model_type="facebook/opt-125m")
self.assertEqual(tokenizer.get_vocab()["day"], 1208, '')
self.assertEqual(tokenizer.encode_plus("fried chicken makes me happy")["input_ids"],
[2, 21209, 5884, 817, 162, 1372], '')
[21209, 5884, 817, 162, 1372], '')
self.assertEqual(tokenizer.decode([21209, 5884, 817, 162, 1372]),
'fried chicken makes me happy', 'DecodeIds Error')

Expand Down

0 comments on commit 97ffea4

Please sign in to comment.