-
Notifications
You must be signed in to change notification settings - Fork 561
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
author Louis J <ljean@etud.insa-toulouse.fr> 1563984477 +0200 committer Guillaume Infantes <guillaume.infantes@jolibrain.com> 1576060297 +0100 parent 7eb6443 author Louis J <ljean@etud.insa-toulouse.fr> 1563984477 +0200 committer Guillaume Infantes <guillaume.infantes@jolibrain.com> 1576059845 +0100 LOUISJ'S COMMITS: Move dataset management and model building in separate classes Add train and test The fix on txtinputconnector is temporary, vocab generation should be fixed a more robust way BERT finetuning with custom number of classes Add self supervised Masked LM learning Save solver checkpoint along with model Ensure label is of correct dimension Fix masked_lm, add more explicit error message Add script to trace huggingface models Add classfication on hidden states to be able to use masked lm model for classif Better API, more features, less memory usage and fix bugs Add unit tests for training Move training parameters to solver and net Add comments Download tar from deepdetect.com torch 1.3.1 alone working with caffe patch correction: add pcaffe/logging.h force -j8 when building libtorch (default is -j nproc) points to model traced for torch 131 GUILLAUME COMMITS: changes for torch 131 Move dataset management and model building in separate classes Add train and test The fix on txtinputconnector is temporary, vocab generation should be fixed a more robust way BERT finetuning with custom number of classes Add self supervised Masked LM learning Save solver checkpoint along with model Ensure label is of correct dimension Better API, more features, less memory usage and fix bugs Move training parameters to solver and net Add comments Add inference support for GPT2 Make lower case optional Add gpt2 training Add gpt2 demo rebase all glitches in merge update to last transformers from hugginface gpt2 inference ok sanitize width vs sequence remove comment in cmakelist
- Loading branch information
Showing
17 changed files
with
2,288 additions
and
143 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../clients/python/dd_client.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import random | ||
import sys | ||
import argparse | ||
from dd_client import DD | ||
|
||
parser = argparse.ArgumentParser(description="Use DeepDetect and GPT-2 to generate text") | ||
parser.add_argument("-r", "--repository", required=True, help="Model repository") | ||
parser.add_argument("--host", type=str, default="localhost") | ||
parser.add_argument("--port", type=int, default=8080) | ||
parser.add_argument("--cpu", action='store_true', help="Force model to run on CPU") | ||
parser.add_argument("--input-size", type=int, default=512) | ||
parser.add_argument("--topk", type=int, default=5, help="How many top predictions should be considered to chose the next token.") | ||
parser.add_argument("--temperature", type=float, default=1, help="Temperature of the predictions. The higher, the 'randomer'.") | ||
|
||
args = parser.parse_args() | ||
|
||
# dd global variables | ||
sname = 'gpt-2' | ||
description = 'Inference with GPT-2' | ||
mllib = 'torch' | ||
|
||
dd = DD(args.host, args.port) | ||
dd.set_return_format(dd.RETURN_PYTHON) | ||
|
||
# setting up the ML service | ||
model = {'repository':args.repository} | ||
parameters_input = { | ||
'connector':'txt', | ||
'ordered_words': True, | ||
'wordpiece_tokens': True, | ||
'punctuation_tokens': True, | ||
'lower_case': False, | ||
'width': args.input_size | ||
} | ||
parameters_mllib = {'template':'gpt2', 'gpu':True} | ||
parameters_output = {} | ||
dd.put_service(sname,model,description,mllib, | ||
parameters_input,parameters_mllib,parameters_output) | ||
|
||
# generating text | ||
prompt = input("Enter beggining of sentence >>> ") | ||
|
||
for i in range(0, 256): | ||
data = [prompt] | ||
parameters_input = {'word_start': "Ġ", 'suffix_start': ""} | ||
parameters_mllib = {} | ||
parameters_output = {'best':args.topk} | ||
result = dd.post_predict(sname, data, parameters_input,parameters_mllib,parameters_output) | ||
|
||
# Select result from the returned tokens | ||
word_probs = list() | ||
total_probs = 0 | ||
|
||
for cls in result['body']['predictions'][0]['classes']: | ||
word = cls['cat'].replace("Ġ", " ") | ||
# dede does not support \n character well, so we don't select tokens containing a new line | ||
if 'Ċ' in word: | ||
continue | ||
|
||
prob = pow(cls['prob'], args.temperature) | ||
total_probs += prob | ||
word_probs.append((word, prob)) | ||
|
||
selector = random.uniform(0, total_probs) | ||
total_probs = 0 | ||
|
||
for word, prob in word_probs: | ||
total_probs += prob | ||
if total_probs > selector: | ||
selected_word = word | ||
break | ||
|
||
print(selected_word, sep='', end='') | ||
sys.stdout.flush() | ||
prompt += selected_word |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py | ||
index 894559ed43..7887147a28 100644 | ||
--- a/tools/setup_helpers/cmake.py | ||
+++ b/tools/setup_helpers/cmake.py | ||
@@ -229,6 +229,7 @@ class CMake: | ||
'CUDA_NVCC_EXECUTABLE', | ||
'CUDNN_LIBRARY', | ||
'CUDNN_INCLUDE_DIR', | ||
+ 'CAFFE2_LINK_LOCAL_PROTOBUF', | ||
'EXPERIMENTAL_SINGLE_THREAD_POOL', | ||
'INSTALL_TEST', | ||
'MKL_THREADING', |
Oops, something went wrong.