
Change Gigaword to optional dataset
Huffon committed Apr 13, 2020
1 parent 6462630 commit 8c3a418
Showing 4 changed files with 19 additions and 26 deletions.
21 changes: 7 additions & 14 deletions README.md
@@ -1,6 +1,6 @@
# Sentence Compressor

-This repository contains Sentence Compressor API trained using **Transformer** and **BART** architecture
+This repository contains **Sentence Compressor** API trained using **Transformer** and **BART** architecture

Lots of code is borrowed from the [fairseq](https://github.com/pytorch/fairseq) library

@@ -29,47 +29,40 @@ pip install fairseq requests pandas tensorflow-datasets
bash preprocess.sh
```

-- **Transformer**'s best loss: **3.118**
-- **BART**'s best loss: **2.838**

### (1) Transformer
- To train **Transformer** using pre-processed dataset, run following command:

```bash
python train_transformer.py
```

-- To **generate** example sentence using [pre-trained Transformer](), run following command:
+- To **generate** example sentence using pre-trained **Transformer**, run following command:

```
wget MODEL
tar xvzf transformer.tar.gz
python generate_transformer.py
```

### (2) BART

-- To download and fine-tune pre-trained **BART**, run following command:
+- To **download** and **fine-tune** pre-trained **BART**, run following command:

```bash
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
tar xvzf bart.large.tar.gz
bash train_bart.sh
```

-- To **generate** example sentence using [fine-tuned BART](), run following command:
+- To **generate** example sentence using fine-tuned **BART**, run following command:

```
wget MODEL
tar xvzf bart.tar.gz
python generate_bart.py
```

<br/>

## Example

-- To test your own sentences, fill [**input.txt**](output/input.txt) with your sentences
+- To **test** your own sentences, fill [**input.txt**](output/input.txt) with your sentences

```
[Transformer]
@@ -109,7 +102,7 @@ Target length

## References
- [**Sentence Compression Dataset**](https://github.com/google-research-datasets/sentence-compression)
-- [Gigaword Dataset](https://www.tensorflow.org/datasets/catalog/gigaword)
-- [Pre-trained **BART**](https://github.com/pytorch/fairseq/tree/master/examples/bart)
- [**fairseq**](https://github.com/pytorch/fairseq)
- [fairseq Transformer __*train*__ script](https://github.com/kakaobrain/helo_word/blob/master/gec/track.py#L91)
+- [Pre-trained **BART**](https://github.com/pytorch/fairseq/tree/master/examples/bart)
+- [Gigaword Dataset](https://www.tensorflow.org/datasets/catalog/gigaword) (Opt.)
2 changes: 1 addition & 1 deletion preprocess.sh
@@ -19,7 +19,7 @@ rm broadcastnews-compressions.tar.gz written-compressions.tar.gz Release.zip
mkdir data
python utils/preprocess.py

-rm -rf annotator1 annotator2 annotator3 written Release
+rm -rf annotator1 annotator2 annotator3 written Release sentence-compression

wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
4 changes: 2 additions & 2 deletions train_bart.sh
@@ -1,8 +1,8 @@
-TOTAL_NUM_UPDATES=20000
+TOTAL_NUM_UPDATES=15000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=4000
-UPDATE_FREQ=4
+UPDATE_FREQ=2
BART_PATH=bart.large/model.pt

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python utils/train.py data-bin \
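In fairseq, `UPDATE_FREQ` is the gradient-accumulation factor, so each optimizer step sees roughly `MAX_TOKENS × num_GPUs × UPDATE_FREQ` tokens; dropping it from 4 to 2 halves the effective batch. A quick sketch of that arithmetic, assuming the 8 GPUs listed in `CUDA_VISIBLE_DEVICES` (illustrative only, not part of the commit):

```python
# Back-of-the-envelope effective batch size for train_bart.sh.
# Assumes 8 GPUs, per CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7.
MAX_TOKENS = 4000  # max tokens per GPU per forward/backward pass
NUM_GPUS = 8

for update_freq in (4, 2):  # old vs. new UPDATE_FREQ
    tokens_per_update = MAX_TOKENS * NUM_GPUS * update_freq
    print(f"UPDATE_FREQ={update_freq}: ~{tokens_per_update:,} tokens/update")
# UPDATE_FREQ=4: ~128,000 tokens/update
# UPDATE_FREQ=2: ~64,000 tokens/update
```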
18 changes: 9 additions & 9 deletions utils/preprocess.py
@@ -5,7 +5,7 @@
from typing import List

import pandas as pd
-import tensorflow_datasets as tfds
+# import tensorflow_datasets as tfds

random.seed(42)

@@ -147,21 +147,21 @@ def main():
logging.info("[TRAIN] MSR Dataset")
src, tgt = preprocess_msr(src, tgt)

logging.info("[TRAIN] Gigaword Dataset")
giga_train, giga_val = preprocess_gigaword()
giga_src, giga_tgt = giga_train
src += giga_src
tgt += giga_tgt
# logging.info("[TRAIN] Gigaword Dataset")
# giga_train, giga_val = preprocess_gigaword()
# giga_src, giga_tgt = giga_train
# src += giga_src
# tgt += giga_tgt
logging.info(f"Current dataset size: {len(src)}")
create_pair("train", src, tgt)

logging.info("[VAL] Google Dataset")
prefix = "data/comp-data.eval"
nums = [""]
src, tgt = preprocess_google("val", prefix, nums)
giga_src, giga_tgt = giga_val
src += giga_src
tgt += giga_tgt
# giga_src, giga_tgt = giga_val
# src += giga_src
# tgt += giga_tgt
create_pair("val", src, tgt)


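Commenting the Gigaword block out does make it optional, but it leaves dead code behind. Purely as a hedged sketch (not what this commit does), the same lines could be gated behind an argparse switch in utils/preprocess.py. The `--use-gigaword` flag name is invented; `preprocess_gigaword` is the file's existing helper, and the module-level `tensorflow_datasets` import would need to be restored for the flag to work.

```python
# Sketch only (not part of this commit): opt-in Gigaword mixing for
# utils/preprocess.py instead of commented-out code. Assumes the file's
# existing preprocess_gigaword() helper and a restored
# `import tensorflow_datasets as tfds` at the top of the module.
import argparse
import logging

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--use-gigaword", action="store_true",
                        help="mix the optional Gigaword dataset into train/val")
    args = parser.parse_args()

    src, tgt = [], []  # built from the Google and MSR datasets, as before

    if args.use_gigaword:
        logging.info("[TRAIN] Gigaword Dataset")
        giga_train, giga_val = preprocess_gigaword()  # existing helper
        giga_src, giga_tgt = giga_train
        src += giga_src
        tgt += giga_tgt
    logging.info(f"Current dataset size: {len(src)}")
```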
