text_generation.py
import argparse
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Restrict the process to GPU 1 (takes effect before CUDA is initialized,
# since torch initializes CUDA lazily on first use).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

def parse_inputs():
    parser = argparse.ArgumentParser(description="OrionStar-Yi-34B-Chat text generation demo")
    parser.add_argument(
        "--model",
        type=str,
        default="OrionStarAI/OrionStar-Yi-34B-Chat",
        help="local path to a pretrained model, or its name on Hugging Face",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="OrionStarAI/OrionStar-Yi-34B-Chat",
        help="local path to a tokenizer, or its name on Hugging Face",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="你好!",  # "Hello!"
        help="the prompt to start with",
    )
    parser.add_argument(
        "--streaming",
        action="store_true",
        help="whether to enable streaming text generation",
    )
    parser.add_argument(
        "--eos-token",
        type=str,
        default="<|endoftext|>",
        help="end-of-sentence token (accepted but not used elsewhere in this demo)",
    )
    args = parser.parse_args()
    return args

def main(args):
    print(args)
    # trust_remote_code=True is required: chat() below is a custom method shipped
    # with the model repository rather than part of the transformers API.
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_safetensors=False,
    )
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer or args.model, trust_remote_code=True)
    messages = [{"role": "user", "content": args.prompt}]
    if args.streaming:
        # In streaming mode, chat() yields the full response generated so far
        # on each step, so print only the newly added suffix each iteration.
        position = 0
        try:
            for response in model.chat(tokenizer, messages, streaming=True):
                print(response[position:], end="", flush=True)
                position = len(response)
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()
        except KeyboardInterrupt:
            pass
    else:
        response = model.chat(tokenizer, messages)
        print(response)
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()


if __name__ == "__main__":
    args = parse_inputs()
    main(args)
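

# Example invocation (a sketch, not part of the original file; assumes the
# OrionStarAI/OrionStar-Yi-34B-Chat weights can be downloaded from Hugging Face
# or are already available at a local path):
#
#   python text_generation.py --prompt "Hello!" --streaming
#   python text_generation.py --model /path/to/checkpoint --prompt "Tell me a joke."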