Skip to content

Commit e0d8971

Browse files
committed
Add gpt2 demo
1 parent c46d6b7 commit e0d8971

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

demo/gpt2/dd_client.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../clients/python/dd_client.py

demo/gpt2/run_gpt2.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import random
2+
import sys
3+
import argparse
4+
from dd_client import DD
5+
6+
parser = argparse.ArgumentParser(description="Use DeepDetect and GPT-2 to generate text")
7+
parser.add_argument("-r", "--repository", required=True, help="Model repository")
8+
parser.add_argument("--host", type=str, default="localhost")
9+
parser.add_argument("--port", type=int, default=8080)
10+
parser.add_argument("--cpu", action='store_true', help="Force model to run on CPU")
11+
parser.add_argument("--input-size", type=int, default=512)
12+
parser.add_argument("--topk", type=int, default=5, help="How many top predictions should be considered to chose the next token.")
13+
parser.add_argument("--temperature", type=float, default=1, help="Temperature of the predictions. The higher, the 'randomer'.")
14+
15+
args = parser.parse_args()
16+
17+
# dd global variables
18+
sname = 'gpt-2'
19+
description = 'Inference with GPT-2'
20+
mllib = 'torch'
21+
22+
dd = DD(args.host, args.port)
23+
dd.set_return_format(dd.RETURN_PYTHON)
24+
25+
# setting up the ML service
26+
model = {'repository':args.repository}
27+
parameters_input = {
28+
'connector':'txt',
29+
'ordered_words': True,
30+
'wordpiece_tokens': True,
31+
'punctuation_tokens': True,
32+
'lower_case': False,
33+
'width': args.input_size
34+
}
35+
parameters_mllib = {'template':'gpt2', 'gpu':True}
36+
parameters_output = {}
37+
dd.put_service(sname,model,description,mllib,
38+
parameters_input,parameters_mllib,parameters_output)
39+
40+
# generating text
41+
prompt = input("Enter beggining of sentence >>> ")
42+
43+
for i in range(0, 256):
44+
data = [prompt]
45+
parameters_input = {'word_start': "Ġ", 'suffix_start': ""}
46+
parameters_mllib = {}
47+
parameters_output = {'best':args.topk}
48+
result = dd.post_predict(sname, data, parameters_input,parameters_mllib,parameters_output)
49+
50+
# Select result from the returned tokens
51+
word_probs = list()
52+
total_probs = 0
53+
54+
for cls in result['body']['predictions'][0]['classes']:
55+
word = cls['cat'].replace("Ġ", " ")
56+
# dede does not support \n character well, so we don't select tokens containing a new line
57+
if 'Ċ' in word:
58+
continue
59+
60+
prob = pow(cls['prob'], args.temperature)
61+
total_probs += prob
62+
word_probs.append((word, prob))
63+
64+
selector = random.uniform(0, total_probs)
65+
total_probs = 0
66+
67+
for word, prob in word_probs:
68+
total_probs += prob
69+
if total_probs > selector:
70+
selected_word = word
71+
break
72+
73+
print(selected_word, sep='', end='')
74+
sys.stdout.flush()
75+
prompt += selected_word

0 commit comments

Comments
 (0)