bert + gpt2 inference + training wrt torch 1.3.1 and transformers 2.2.1 #673

Merged
merged 1 commit into jolibrain:master on Feb 5, 2020

Conversation


@fantes fantes commented Dec 11, 2019

Tested OK on unit tests and on the examples from louisj.

Includes:

  • louisj's work on BERT and GPT2
  • update wrt data normalization from beniz (on master)
  • update to make it work with libtorch 1.3.1
  • update to use Hugging Face transformers 2.2.1

doc for bert training: #637 [edited for new versions of torch and transformers]
doc for gpt2 training : #644

@fantes fantes changed the title bert training wrt torch 1.3.1 bert + gpt2 training wrt torch 1.3.1 Dec 11, 2019
@fantes fantes force-pushed the bert_training branch 2 times, most recently from d981238 to 293abb2 on December 12, 2019 13:45
@fantes fantes changed the title bert + gpt2 training wrt torch 1.3.1 bert + gpt2 training wrt torch 1.3.1 and transformers 2.2.1 Dec 12, 2019
@fantes fantes force-pushed the bert_training branch 2 times, most recently from 3e08630 to 4acd30f on December 12, 2019 16:13
@fantes fantes changed the title bert + gpt2 training wrt torch 1.3.1 and transformers 2.2.1 bert + gpt2 inference + training wrt torch 1.3.1 and transformers 2.2.1 Dec 12, 2019
Squashed commit (parent 7eb6443); author: Louis J <ljean@etud.insa-toulouse.fr>, committer: Guillaume Infantes <guillaume.infantes@jolibrain.com>
LOUISJ'S COMMITS:

Move dataset management and model building in separate classes

Add train and test

The fix on txtinputconnector is temporary; vocab generation should be fixed in a more robust way

BERT finetuning with custom number of classes

Add self supervised Masked LM learning

Save solver checkpoint along with model

Ensure label is of correct dimension

Fix masked_lm, add more explicit error message

Add script to trace huggingface models

Add classification on hidden states to be able to use masked lm model for classif

Better API, more features, less memory usage and fix bugs

Add unit tests for training

Move training parameters to solver and net

Add comments

Download tar from deepdetect.com

torch 1.3.1 alone

working with caffe

patch correction: add pcaffe/logging.h

force -j8 when building libtorch (default is -j nproc)

points to model traced for torch 131

GUILLAUME COMMITS:
changes for torch 131


Add inference support for GPT2

Make lower case optional

Add gpt2 training

Add gpt2 demo

rebase all

glitches in merge

update to latest transformers from Hugging Face

gpt2 inference ok

sanitize width vs sequence

remove comment in cmakelist

fantes commented Dec 13, 2019

BERT EXAMPLES

<!> When tracing models, use pytorch 1.3.1 and the latest transformers (formerly pytorch-transformers):

pip3 install torch==1.3.1 transformers
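
For reference, a minimal sketch of what the tracing step does, assuming the torch 1.3.1 / transformers 2.2.1 API (the trace_pytorch_transformers.py script used below handles this plus vocabulary export; the model and output filenames here are just examples):

# Minimal tracing sketch (assumes torch 1.3.1 and transformers 2.2.1).
import torch
from transformers import BertModel, BertTokenizer

# torchscript=True configures the model to return traceable tuples
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
input_ids = torch.tensor([tokenizer.encode("Hello, world")])

# trace with a dummy input and serialize for libtorch
traced = torch.jit.trace(model, input_ids)
traced.save("bert-pretrained.pt")  # example filename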

Added parameters

  • Text Input Connector

| Parameter | Type | Optional | Default | Description |
| --- | --- | --- | --- | --- |
| ordered_words | bool | yes | false | word-based processing with positional information |
| wordpiece_tokens | bool | yes | false | if true, the vocabulary contains partial words; words can be split into multiple tokens |
| punctuation_tokens | bool | yes | false | treat each punctuation sign as a token (if false, punctuation is stripped from the input) |

  • Torch MLLib

| Parameter | Type | Optional | Default | Description |
| --- | --- | --- | --- | --- |
| self_supervised | string | yes | "" | self-supervised mode: "mask" for masked language model [TODO: add options, e.g. "next" = next-token prediction for GPT2?] |
| embedding_size | int | yes | 768 | embedding size for NLP models |
| freeze_traced | bool | yes | false | freeze the traced part of the net during finetuning (e.g. for classification) |
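
Of these, freeze_traced is not exercised by the examples below; as a hedged sketch, it would be passed in the mllib block at service creation (service name and repository path here are placeholders):

import requests

# Hypothetical service creation with freeze_traced: the traced BERT layers
# stay fixed and only the added classification head is updated.
payload = {
    "description": "BERT classifier, traced layers frozen",
    "mllib": "torch",
    "model": {"repository": "./classif_training/"},
    "parameters": {
        "input": {"connector": "txt", "ordered_words": True,
                  "wordpiece_tokens": True, "punctuation_tokens": True,
                  "sequence": 512},
        "mllib": {"template": "bert", "nclasses": 20, "finetuning": True,
                  "freeze_traced": True, "gpu": True}
    },
    "type": "supervised"
}
requests.put("http://localhost:8080/services/torch_bert_frozen", json=payload)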

Example: Finetune a classification model

  • Trace pytorch pretrained bert
pip3 install --user transformers
mkdir classif_training
./trace_pytorch_transformers.py bert --output-dir classif_training --vocab --verbose
  • Run dede
  • Start the service
curl -X PUT "http://localhost:8080/services/torch_bert_training" -d '{
    "description": "News20 classification service using BERT",
    "mllib": "torch",
    "model": {
        "repository": "./classif_training/"
    },
    "parameters": {
        "input": {
            "connector": "txt",
            "ordered_words": true,
            "wordpiece_tokens": true,
            "punctuation_tokens": true,
            "sequence": 512
        },
        "mllib": {
            "template":"bert",
            "nclasses": 20,
            "finetuning":true,
            "gpu": true
        }
    },
    "type": "supervised"
}
'
  • Train the model on news20 dataset
curl -X POST "http://localhost:8080/train" -d '{
    "service": "torch_bert_training", 
    "parameters": { 
         "mllib": {
            "solver": {
              "iterations":3000,
              "test_interval":250,
              "base_lr":1e-5,
              "iter_size":4,
              "snapshot":250,
              "solver_type":"ADAM"
            },
            "net": {
              "batch_size":8,
              "test_batch_size":4
            }
        },
        "input": {
            "shuffle":true
        },
        "output": {
            "measure":["f1", "mcll", "acc", "cmdiag", "cmfull"]
        }
    }, 
    "data": ["/opt/data/news20/train/", "/opt/data/news20/test/"]
}
'
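
Training runs asynchronously; the POST above returns a job id (1 for the first job). A small sketch for polling progress with DeepDetect's GET /train status call, written with python requests (the exact response fields are assumptions):

# Poll the training job started above (assumes job id 1).
import time
import requests

while True:
    r = requests.get("http://localhost:8080/train",
                     params={"service": "torch_bert_training", "job": 1}).json()
    print(r.get("body", {}).get("measure", {}))
    if r["head"]["status"] != "running":
        break
    time.sleep(30)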

Example: Finetune a language model

  • Trace pytorch pretrained bert
pip3 install --user transformers
mkdir lm_training
./trace_pytorch_transformers.py bert -vo lm_training --vocab
  • Run dede
  • Start the service
curl -X PUT "http://localhost:8080/services/torch_bert_lm" -d '{
    "description": "BERT language model finetuning on News20 ",
    "mllib": "torch",
    "model": {
        "repository": "./lm_training/"
    },
    "parameters": {
        "input": {
            "connector": "txt",
            "ordered_words": true,
            "wordpiece_tokens": true,
            "punctuation_tokens": true,
            "sequence": 512
        },
        "mllib": {
            "template":"bert",
            "self_supervised":"mask",
            "finetuning": true,
            "gpu": true
        }
    },
    "type": "supervised"
}
'
  • Train the model on news20 dataset
curl -X POST "http://localhost:8080/train" -d '{
    "service": "torch_bert_lm", 
    "parameters": { 
         "mllib": {
            "solver": {
              "iterations":3000,
              "test_interval":250,
              "base_lr":1e-5,
              "iter_size":8,
              "snapshot":250,
              "solver_type":"ADAM"
            },
            "net": {
              "batch_size":4,
              "test_batch_size":4
            }
        },
        "input": {
            "shuffle":true,
            "test_split":0.03
        },
        "output": {
            "measure":["acc", "acc-5"]
        }
    }, 
    "data": ["/opt/data/news20/train/"]
}
'


fantes commented Dec 13, 2019

GPT2 EXAMPLES

Example: trace gpt2 and run it with the demo

Run dede, then

mkdir gpt2_repo
cd tools/torch/
./trace_pytorch_transformers.py gpt2 -vo ../../gpt2_repo --vocab
cd ../../demo/gpt2
python3 -m run_gpt2 --r ../../gpt2_repo/ --host localhost --port 8080 --input-size 512 --topk 40

Example: call gpt2 inference with curl

curl -X PUT "http://10.10.77.61:8501/services/torch_gpt2" -d '{
    "description": "GPT2 service",
    "mllib": "torch",
    "model": {
        "repository": "/opt/models/gpt2"
    },
    "parameters": {
        "input": {
            "connector": "txt",
            "ordered_words": true,
            "wordpiece_tokens": true,
            "punctuation_tokens": true,
            "lower_case": false,
            "sequence":512,
            "word_start":"Ġ",
            "suffix_start":""
        },
        "mllib": {
            "template":"gpt2",
            "gpu": true
        }
    },
    "type": "supervised"
}
' 

curl -X POST "http://10.10.77.61:8501/predict" -d '{
    "service": "torch_gpt2",
    "parameters": {
        "input": {
        },
        "output": {
            "best":3
        }
    },
    "data": ["Why did the chicken cross the"]
}
' 
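
The run_gpt2 demo essentially loops over this /predict call, appending each continuation back onto the prompt; a rough sketch of that loop (the response shape and the demo's top-k sampling are assumptions):

# Generation-loop sketch: repeatedly call /predict and append the top
# continuation (the actual demo may sample among the top-k instead).
import requests

host = "http://10.10.77.61:8501"
text = "Why did the chicken cross the"
for _ in range(20):
    r = requests.post(host + "/predict", json={
        "service": "torch_gpt2",
        "parameters": {"output": {"best": 1}},
        "data": [text]
    }).json()
    # assumed shape: continuation is the top class of the first prediction
    text += r["body"]["predictions"][0]["classes"][0]["cat"]
print(text)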

Example: finetuning gpt2

<!> Testing takes a lot of time and memory; it requires a sparse f1 measure

curl -X PUT "http://10.10.77.61:8501/services/gpt2" -d '{
    "description": "gpt2 training",
    "mllib": "torch",
    "model": {
        "repository": "/data1/louisj/gpt2_news20/"
    },
    "parameters": {
        "input": {
            "connector": "txt",
            "ordered_words": true,
            "wordpiece_tokens": true,
            "punctuation_tokens": true,
            "lower_case": false,
            "sequence":256
        },
        "mllib": {
            "template":"gpt2",
            "finetuning": true,
            "gpu": true
        }
    },
    "type": "supervised"
}
' 
curl -X POST "http://10.10.77.61:8501/train" -d '{
    "service": "gpt2",
    "parameters": {
        "mllib": {
            "solver":{
                "iterations": 3000,
                "test_interval":250,
                "base_lr":1e-3,
                "iter_size":8,
                "snapshot":250,
                "solver_type":"ADAM"
            },
            "net":{
                "test_batch_size":1,
                "batch_size":4
            }
        },
        "input": {
            "shuffle":true,
            "word_start":"Ġ",
            "suffix_start":""
        },
        "output": {
            "measure":["f1", "mcll", "acc", "cmdiag", "cmfull"]
        }
    },
    "data":["/home/louisj/dede/data/news20/train/"]
}
' 

@beniz beniz added the v0.9.6 label Dec 18, 2019
@beniz beniz merged commit 3d17e9d into jolibrain:master Feb 5, 2020