|
| 1 | +import random |
| 2 | +import sys |
| 3 | +import argparse |
| 4 | +from dd_client import DD |
| 5 | + |
# Command-line interface: model location, server endpoint, and sampling knobs.
parser = argparse.ArgumentParser(description="Use DeepDetect and GPT-2 to generate text")
parser.add_argument("-r", "--repository", required=True, help="Model repository")
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=8080)
parser.add_argument("--cpu", action='store_true', help="Force model to run on CPU")
parser.add_argument("--input-size", type=int, default=512)
# Help-string typos fixed ("chose" -> "choose"; "the 'randomer'" -> "more random").
parser.add_argument("--topk", type=int, default=5, help="How many top predictions should be considered to choose the next token.")
parser.add_argument("--temperature", type=float, default=1, help="Temperature of the predictions. The higher, the more random.")

args = parser.parse_args()
| 16 | + |
# Service identification used by the DeepDetect API calls below.
mllib = 'torch'
sname = 'gpt-2'
description = 'Inference with GPT-2'

# Client handle; ask for plain Python structures rather than raw JSON strings.
dd = DD(args.host, args.port)
dd.set_return_format(dd.RETURN_PYTHON)
| 24 | + |
# Setting up the ML service on the DeepDetect server.
model = {'repository': args.repository}
parameters_input = {
    'connector': 'txt',
    'ordered_words': True,
    'wordpiece_tokens': True,
    'punctuation_tokens': True,
    'lower_case': False,
    'width': args.input_size
}
# BUG FIX: the original hard-coded 'gpu': True, silently ignoring the --cpu
# flag declared above. Honor it here.
parameters_mllib = {'template': 'gpt2', 'gpu': not args.cpu}
parameters_output = {}
dd.put_service(sname, model, description, mllib,
               parameters_input, parameters_mllib, parameters_output)
| 39 | + |
# Generating text: repeatedly ask the model for the top-k next tokens and
# sample one, feeding it back into the prompt (up to 256 tokens).
prompt = input("Enter beginning of sentence >>> ")  # typo fix: "beggining"

for i in range(0, 256):
    data = [prompt]
    parameters_input = {'word_start': "Ġ", 'suffix_start': ""}
    parameters_mllib = {}
    parameters_output = {'best': args.topk}
    result = dd.post_predict(sname, data, parameters_input, parameters_mllib, parameters_output)

    # Collect candidate tokens and their temperature-scaled weights.
    word_probs = []
    total_probs = 0.0

    for cls in result['body']['predictions'][0]['classes']:
        word = cls['cat'].replace("Ġ", " ")
        # dede does not support \n character well, so we don't select tokens containing a new line
        if 'Ċ' in word:
            continue

        # BUG FIX: standard temperature scaling is p**(1/T) — T > 1 flattens
        # the distribution (more random), T < 1 sharpens it. The original used
        # p**T, inverting the documented behavior; default T == 1 is unchanged.
        prob = pow(cls['prob'], 1.0 / args.temperature)
        total_probs += prob
        word_probs.append((word, prob))

    # If every candidate contained a newline there is nothing to sample from.
    if not word_probs:
        break

    # Weighted sampling: pick the first token whose cumulative weight exceeds
    # a uniform draw in [0, total_probs].
    selector = random.uniform(0, total_probs)
    cumulative = 0.0
    # BUG FIX: fall back to the last candidate so selected_word is always
    # bound — random.uniform can return its upper bound exactly, in which
    # case the original loop never broke and raised NameError below.
    selected_word = word_probs[-1][0]

    for word, prob in word_probs:
        cumulative += prob
        if cumulative > selector:
            selected_word = word
            break

    print(selected_word, end='')  # sep='' dropped: no-op with a single argument
    sys.stdout.flush()
    prompt += selected_word
0 commit comments