|
| 1 | +import random |
1 | 2 | import sys
|
2 | 3 | import argparse
|
3 | 4 | from dd_client import DD
|
|
6 | 7 | parser.add_argument("-r", "--repository", required=True, help="Model repository")
|
7 | 8 | parser.add_argument("--host", type=str, default="localhost")
|
8 | 9 | parser.add_argument("--port", type=int, default=8080)
|
9 |
| -parser.add_argument("--cpu", action='store_true') |
| 10 | +parser.add_argument("--cpu", action='store_true', help="Force model to run on CPU") |
10 | 11 | parser.add_argument("--input-size", type=int, default=512)
|
| 12 | +parser.add_argument("--topk", type=int, default=5, help="How many top predictions should be considered to chose the next token.") |
| 13 | +parser.add_argument("--temperature", type=float, default=1, help="Temperature of the predictions. The higher, the 'randomer'.") |
11 | 14 |
|
12 | 15 | args = parser.parse_args()
|
13 | 16 |
|
|
41 | 44 | data = [prompt]
|
42 | 45 | parameters_input = {'word_start': "Ġ", 'suffix_start': ""}
|
43 | 46 | parameters_mllib = {}
|
44 |
| - parameters_output = {'best':3} |
| 47 | + parameters_output = {'best':args.topk} |
45 | 48 | result = dd.post_predict(sname, data, parameters_input,parameters_mllib,parameters_output)
|
46 |
| - word = result['body']['predictions'][0]['classes'][0]['cat'].replace("Ġ", " ").replace("Ċ", "\n") |
47 |
| - print(word, sep='', end='') |
| 49 | + |
| 50 | + # Select result from the returned tokens |
| 51 | + word_probs = list() |
| 52 | + total_probs = 0 |
| 53 | + |
| 54 | + for cls in result['body']['predictions'][0]['classes']: |
| 55 | + word = cls['cat'].replace("Ġ", " ") |
| 56 | + # dede does not support \n character well, so we don't select tokens containing a new line |
| 57 | + if 'Ċ' in word: |
| 58 | + continue |
| 59 | + |
| 60 | + prob = pow(cls['prob'], args.temperature) |
| 61 | + total_probs += prob |
| 62 | + word_probs.append((word, prob)) |
| 63 | + |
| 64 | + selector = random.uniform(0, total_probs) |
| 65 | + total_probs = 0 |
| 66 | + |
| 67 | + for word, prob in word_probs: |
| 68 | + total_probs += prob |
| 69 | + if total_probs > selector: |
| 70 | + selected_word = word |
| 71 | + break |
| 72 | + |
| 73 | + print(selected_word, sep='', end='') |
48 | 74 | sys.stdout.flush()
|
49 |
| - prompt += word |
| 75 | + prompt += selected_word |
0 commit comments