Skip to content

Commit 881249c

Browse files
committed
Improve gpt2 demo script
1 parent 13a2e65 commit 881249c

File tree

1 file changed

+31
-5
lines changed

1 file changed

+31
-5
lines changed

demo/gpt2/run_gpt2.py

+31-5
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import random
12
import sys
23
import argparse
34
from dd_client import DD
@@ -6,8 +7,10 @@
67
parser.add_argument("-r", "--repository", required=True, help="Model repository")
78
parser.add_argument("--host", type=str, default="localhost")
89
parser.add_argument("--port", type=int, default=8080)
9-
parser.add_argument("--cpu", action='store_true')
10+
parser.add_argument("--cpu", action='store_true', help="Force model to run on CPU")
1011
parser.add_argument("--input-size", type=int, default=512)
12+
parser.add_argument("--topk", type=int, default=5, help="How many top predictions should be considered to chose the next token.")
13+
parser.add_argument("--temperature", type=float, default=1, help="Temperature of the predictions. The higher, the 'randomer'.")
1114

1215
args = parser.parse_args()
1316

@@ -41,9 +44,32 @@
4144
data = [prompt]
4245
parameters_input = {'word_start': "Ġ", 'suffix_start': ""}
4346
parameters_mllib = {}
44-
parameters_output = {'best':3}
47+
parameters_output = {'best':args.topk}
4548
result = dd.post_predict(sname, data, parameters_input,parameters_mllib,parameters_output)
46-
word = result['body']['predictions'][0]['classes'][0]['cat'].replace("Ġ", " ").replace("Ċ", "\n")
47-
print(word, sep='', end='')
49+
50+
# Select result from the returned tokens
51+
word_probs = list()
52+
total_probs = 0
53+
54+
for cls in result['body']['predictions'][0]['classes']:
55+
word = cls['cat'].replace("Ġ", " ")
56+
# dede does not support \n character well, so we don't select tokens containing a new line
57+
if 'Ċ' in word:
58+
continue
59+
60+
prob = pow(cls['prob'], args.temperature)
61+
total_probs += prob
62+
word_probs.append((word, prob))
63+
64+
selector = random.uniform(0, total_probs)
65+
total_probs = 0
66+
67+
for word, prob in word_probs:
68+
total_probs += prob
69+
if total_probs > selector:
70+
selected_word = word
71+
break
72+
73+
print(selected_word, sep='', end='')
4874
sys.stdout.flush()
49-
prompt += word
75+
prompt += selected_word

0 commit comments

Comments
 (0)