Commit af12140 (parent 7f0cd21): 23 changed files with 1,828 additions and 1 deletion.
@@ -0,0 +1,2 @@
BERT_Base_Uncased
results_045
@@ -1,4 +1,5 @@
Apache License

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
@@ -0,0 +1,92 @@
# UDA (Unsupervised Data Augmentation) with BERT
This is a PyTorch re-implementation of Google's UDA [[paper]](https://arxiv.org/abs/1904.12848) [[tensorflow]](https://github.com/google-research/uda), built on top of Kakao Brain's Pytorchic BERT [[pytorch]](https://github.com/dhlee347/pytorchic-bert).

Model | UDA official | This repository
--- | --- | ---
UDA (X) | 68% |
UDA (O) | 90% | 88.45%

![](README_data/2019-08-30-22-18-28.png)
## UDA
> UDA (Unsupervised Data Augmentation) is a semi-supervised learning method which achieves SOTA results on a wide variety of language and vision tasks. With only 20 labeled examples, UDA outperforms the previous SOTA on IMDb trained on 25,000 labeled examples (error rate: BERT = 4.51, UDA = 4.20).

![](README_data/2019-08-21-18-01-07.png)
> *Unsupervised Data Augmentation for Consistency Training* (2019 Google Brain, Q. Xie et al.)

#### - UDA with BERT
UDA works on top of BERT: it acts as a training-time assistant for BERT rather than a standalone model. In the picture above, the model **M** is therefore BERT.

#### - Loss
UDA's total loss consists of a supervised loss and an unsupervised loss. The supervised loss is the usual cross-entropy on labeled examples; the unsupervised loss is the KL divergence between the model's output distributions for an original unlabeled example and for its augmented version. In this project, back-translation is used as the augmentation technique.<br />
The supervised and unsupervised losses are summed into a total loss, which is minimized by gradient descent. One caveat: gradients must not flow through the original example's forward pass. The model's weights are updated only through the labeled data and the augmented unlabeled data.
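As a rough illustration, here is a minimal sketch of this loss in PyTorch (the function and argument names are hypothetical, not the ones used in this repository; `model` is assumed to return classification logits):

```python
import torch
import torch.nn.functional as F

def uda_total_loss(model, x_sup, y_sup, x_ori, x_aug, uda_coeff=1.0):
    # Supervised part: standard cross-entropy on labeled examples.
    sup_loss = F.cross_entropy(model(x_sup), y_sup)

    # Unsupervised part: KL(p_ori || p_aug). Running the original example's
    # forward pass under no_grad ensures no gradient flows through that route.
    with torch.no_grad():
        p_ori = F.softmax(model(x_ori), dim=-1)
    log_p_aug = F.log_softmax(model(x_aug), dim=-1)
    unsup_loss = F.kl_div(log_p_aug, p_ori, reduction="batchmean")

    return sup_loss + uda_coeff * unsup_loss
```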

#### - TSA (Training Signal Annealing)
There is a large gap between the amount of unlabeled data and the amount of labeled data, so it is easy to overfit to the labeled data. The TSA technique therefore masks out labeled examples whose predicted probability for the correct class already exceeds a threshold, and the threshold is annealed over training by a log, linear, or exponential schedule.<br />
![](README_data/2019-08-22-14-16-49.png) <br />
![](README_data/2019-08-22-14-16-59.png) <br />
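For reference, a minimal sketch of the three threshold schedules, assuming `num_classes` output classes and training progress `t = step / total_steps` in [0, 1] (the scale factor 5 follows the official UDA code, to the best of my knowledge):

```python
import math

def tsa_threshold(schedule, t, num_classes):
    # alpha(t) grows from 0 to 1 at different rates per schedule.
    if schedule == "linear_schedule":
        alpha = t
    elif schedule == "log_schedule":
        alpha = 1.0 - math.exp(-t * 5)
    elif schedule == "exp_schedule":
        alpha = math.exp((t - 1.0) * 5)
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    # The threshold grows from 1/K (uniform confidence) up to 1.0;
    # labeled examples predicted above it are masked out of the loss.
    return alpha * (1.0 - 1.0 / num_classes) + 1.0 / num_classes
```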

#### - Sharpening Predictions
The KL-divergence loss between the original and augmented outputs can be too small on its own, which lets the supervised loss dominate the total loss. Two sharpening techniques address this (a sketch combining both follows this list):

- Confidence-based masking : masking out examples that the current model is not confident about. Specifically, in each minibatch, the consistency loss term is computed only on examples whose highest predicted probability exceeds a threshold.
- Softmax temperature controlling : used when computing the predictions on the original example. Specifically, the probability of the original example is computed as Softmax(l(x)/τ), where l(x) denotes the logits and τ is the temperature. A lower temperature corresponds to a sharper distribution.<br /> (UDA, 2019 Google Brain, Q. Xie et al.)
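A minimal sketch of both techniques, assuming per-example logits for the original input; the default values mirror the hypothetical `uda_confidence_thresh` and `uda_softmax_temp` entries in the UDA config shown later in this commit:

```python
import torch
import torch.nn.functional as F

def sharpened_targets(ori_logits, confidence_thresh=0.45, softmax_temp=0.85):
    with torch.no_grad():
        # Temperature control: a lower temperature sharpens the distribution.
        p_sharp = F.softmax(ori_logits / softmax_temp, dim=-1)
        # Confidence-based masking: keep only examples the model is sure about.
        confidence = F.softmax(ori_logits, dim=-1).max(dim=-1).values
        mask = (confidence > confidence_thresh).float()
    # Multiply the per-example consistency (KL) loss by `mask` before averaging.
    return p_sharp, mask
```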

## Requirements
**UDA** : Python > 3.6, fire, tqdm, tensorboardX, tensorflow, pytorch, pandas, numpy

## Overview

- [`download.sh`](./download.sh) : Download the pre-trained BERT model from Google's official BERT repository and the IMDb data file
- [`load_data.py`](./load_data.py) : Load the supervised and unsupervised data
- [`models.py`](./models.py) : Model classes for a general transformer (from Pytorchic BERT's code)
- [`main.py`](./main.py) : Entry point covering the default BERT and UDA (TSA, sharpening) modes
- [`train.py`](./train.py) : A custom training class (Trainer) adapted from Pytorchic BERT's code
- ***utils***
  - [`configuration.py`](./utils/configuration.py) : Set up the configuration from a JSON file
  - [`checkpoint.py`](./utils/checkpoint.py) : Functions to load a model from a TensorFlow checkpoint file (from Pytorchic BERT's code)
  - [`optim.py`](./utils/optim.py) : Optimizer (BERTAdam class) (from Pytorchic BERT's code)
  - [`tokenization.py`](./utils/tokenization.py) : Tokenizers adopted from the original Google BERT code
  - [`utils.py`](./utils/utils.py) : Custom utility functions adapted from Pytorchic BERT's code

## Pre-works

#### - Download the pre-trained BERT model
First, you have to download the pre-trained BERT_Base model from Google's BERT repository:

    bash download_BERT_Base_Uncased.sh

After it finishes, the pre-trained BERT_Base_Uncased model is available in the **BERT_Base_Uncased/** directory.

#### - Data
I use the already pre-processed and augmented IMDb data extracted from the official [UDA](https://github.com/google-research/uda) repository. If you want to use your own raw data, set `need_prepro` to `true` in the config.

## Example usage
This project is broadly divided into two parts: fine-tuning and evaluation.<br/>
**Caution** : **before running the code, check and edit the config files.**

1. **Fine-tuning**
   <br />You can choose the training mode (`train` or `train_eval`) in `config/non-uda.json` or `config/uda.json` (default: `train_eval`).
   - Non-UDA fine-tuning

         python main.py \
             --cfg='config/non-uda.json' \
             --model_cfg='config/bert_muling.json'

   - UDA fine-tuning

         python main.py \
             --cfg='config/uda.json' \
             --model_cfg='config/bert_muling.json'

2. **Evaluation**
   - By default, the evaluation code dumps a results file. You can change the dump option in [main.py](./main.py); there are two modes (print in real time, or write a TSV file).

         python main.py \
             --cfg='config/eval.json' \
             --model_cfg='config/bert_muling.json'

## Acknowledgement
Thanks to the reference code of [UDA](https://github.com/google-research/uda) and [Pytorchic BERT](https://github.com/dhlee347/pytorchic-bert), which made this implementation possible.

## TODO
It is known that further training (additional pre-training of an already pre-trained BERT on the task-specific corpus) can improve performance, but this repository does not contain pre-training code yet; it will be added. If you want to do further training in the meantime, you can use [Pytorchic BERT](https://github.com/dhlee347/pytorchic-bert)'s pretrain.py or any other BERT project.
@@ -0,0 +1,11 @@
{
    "dim": 768,
    "dim_ff": 3072,
    "n_layers": 12,
    "p_drop_attn": 0.1,
    "n_heads": 12,
    "p_drop_hidden": 0.1,
    "max_len": 512,
    "n_segments": 2,
    "vocab_size": 30522
}
@@ -0,0 +1,12 @@
{
    "mode": "eval",
    "max_seq_length": 128,
    "eval_batch_size": 16,
    "do_lower_case": true,
    "data_parallel": true,
    "need_prepro": false,
    "model_file": "results_045/save/model_steps_6250.pt",
    "eval_data_dir": "data/imdb_sup_test.txt",
    "vocab": "BERT_Base_Uncased/vocab.txt",
    "task": "imdb"
}
@@ -0,0 +1,29 @@
{
    "seed": 42,
    "lr": 2e-5,
    "warmup": 0.1,
    "do_lower_case": true,
    "mode": "train_eval",
    "uda_mode": false,

    "total_steps": 10000,
    "max_seq_length": 128,
    "train_batch_size": 8,
    "eval_batch_size": 16,

    "data_parallel": true,
    "need_prepro": false,
    "sup_data_dir": "data/imdb_sup_train.txt",
    "eval_data_dir": "data/imdb_sup_test.txt",

    "model_file": null,
    "pretrain_file": "BERT_Base_Uncased/bert_model.ckpt",
    "vocab": "BERT_Base_Uncased/vocab.txt",
    "task": "imdb",

    "save_steps": 100,
    "check_steps": 250,
    "results_dir": "results_non",

    "is_position": false
}
@@ -0,0 +1,36 @@
{
    "seed": 42,
    "lr": 2e-5,
    "warmup": 0.1,
    "do_lower_case": true,
    "mode": "train_eval",
    "uda_mode": true,

    "total_steps": 10000,
    "max_seq_length": 128,
    "train_batch_size": 8,
    "eval_batch_size": 16,

    "unsup_ratio": 3,
    "uda_coeff": 1,
    "tsa": "linear_schedule",
    "uda_softmax_temp": 0.85,
    "uda_confidence_thresh": 0.45,

    "data_parallel": true,
    "need_prepro": false,
    "sup_data_dir": "data/imdb_sup_train.txt",
    "unsup_data_dir": "data/imdb_unsup_train.txt",
    "eval_data_dir": "data/imdb_sup_test.txt",

    "model_file": null,
    "pretrain_file": "BERT_Base_Uncased/bert_model.ckpt",
    "vocab": "BERT_Base_Uncased/vocab.txt",
    "task": "imdb",

    "save_steps": 100,
    "check_steps": 250,
    "results_dir": "results",

    "is_position": false
}
@@ -0,0 +1,23 @@
#!/bin/bash
# coding=utf-8
# Copyright 2019 The Google UDA Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# **** download pretrained models ****
wget storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
unzip uncased_L-12_H-768_A-12.zip && rm uncased_L-12_H-768_A-12.zip
mv uncased_L-12_H-768_A-12 BERT_Base_Uncased

# **** unzip data ****
unzip data.zip && rm data.zip