
Code for "NAST: A Non-Autoregressive Generator with Word Alignment for Unsupervised Text Style Transfer" (ACL 2021 findings)


thu-coai/NAST


NAST

This repository contains the code and model outputs for the paper NAST: A Non-Autoregressive Generator with Word Alignment for Unsupervised Text Style Transfer (Findings of ACL 2021).


Outputs

We release the outputs of NAST under the outputs directory:

YELP: outputs/YELP/{stytrans|latentseq}_{simple|learnable}/{pos2neg|neg2pos}.txt
GYAFC: outputs/GYAFC/{stytrans|latentseq}_{simple|learnable}/{fm2inf|inf2fm}.txt
  • {stytrans|latentseq} indicates the base model, i.e., StyTrans or LatentSeq.
  • {simple|learnable} indicates the two alignment strategies.
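The naming pattern above implies up to 16 files (2 datasets × 2 base models × 2 alignment strategies × 2 transfer directions). The following is a minimal sketch that expands the pattern and reports which files are present, assuming it is run from the repository root; `output_files` is a hypothetical helper, not part of the repository.

```python
from itertools import product
from pathlib import Path

# Transfer directions per dataset, taken from the path patterns above.
DIRECTIONS = {
    "YELP": ("pos2neg", "neg2pos"),
    "GYAFC": ("fm2inf", "inf2fm"),
}
BASE_MODELS = ("stytrans", "latentseq")
ALIGNMENTS = ("simple", "learnable")


def output_files(root="outputs"):
    """Expand outputs/<DATASET>/<base>_<alignment>/<direction>.txt into concrete paths."""
    paths = []
    for dataset, directions in DIRECTIONS.items():
        for base, align, direction in product(BASE_MODELS, ALIGNMENTS, directions):
            paths.append(Path(root) / dataset / f"{base}_{align}" / f"{direction}.txt")
    return paths


if __name__ == "__main__":
    # Print each expected path and whether it exists in the local checkout.
    for path in output_files():
        print(path, "exists" if path.exists() else "missing")
```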

Requirements

  • python >= 3.7
  • pytorch >= 1.7.1
  • cotk == 0.1.0 (pip install cotk==0.1.0)
  • transformers == 3.0.2
  • tensorboardX

Evaluation

The evaluation code is under eval.

We use six metrics in the paper:

  • PPL: The perplexity of the transferred sentences, evaluated by a fine-tuned GPT-2 base model.
  • Acc: The style accuracy of the transferred sentences, evaluated by a fine-tuned RoBERTa-base classifier.
  • SelfBLEU: The BLEU score between the source sentences and the transferred sentences.
  • RefBLEU: The BLEU score between the transferred sentences and the human references.
  • G2: Geometric mean of Acc and RefBLEU, sqrt(Acc * RefBLEU).
  • H2: Harmonic mean of Acc and RefBLEU, 2 * Acc * RefBLEU / (Acc + RefBLEU).
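The two combined scores can be reproduced from Acc and RefBLEU with a few lines. This is a minimal sketch of the formulas listed above, not the repository's evaluation code; `g2` and `h2` are hypothetical helper names, and the inputs are taken from the test0 row in the Example Outputs section.

```python
import math


def g2(acc: float, bleu: float) -> float:
    """Geometric mean of style accuracy and BLEU: sqrt(Acc * BLEU)."""
    return math.sqrt(acc * bleu)


def h2(acc: float, bleu: float) -> float:
    """Harmonic mean of style accuracy and BLEU: 2 * Acc * BLEU / (Acc + BLEU)."""
    if acc + bleu == 0:
        return 0.0  # avoid division by zero when both scores are 0
    return 2 * acc * bleu / (acc + bleu)


# Acc and RefBLEU from the test0 row of the example output.
acc, ref_bleu = 0.862, 0.491
print(f"G2 = {g2(acc, ref_bleu):.3f}, H2 = {h2(acc, ref_bleu):.3f}")
```

Because the table values are rounded to three decimals, recomputing G2 and H2 from them can differ from the reported numbers in the last digit.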

The code also provides three additional metrics:

  • self_G2: Geometric mean of Acc and SelfBLEU, sqrt(Acc * SelfBLEU).
  • self_H2: Harmonic mean of Acc and SelfBLEU, 2 * Acc * SelfBLEU / (Acc + SelfBLEU).
  • Overall: G2 if available, otherwise self_G2.
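A minimal sketch of how the additional scores and the Overall fallback fit together, assuming "available" means that RefBLEU (and therefore G2) could be computed because human references exist. `combined_metrics` is a hypothetical helper that mirrors the definitions above; it is not the evaluation script's API.

```python
import math
from typing import Dict, Optional


def combined_metrics(acc: float, self_bleu: float,
                     ref_bleu: Optional[float] = None) -> Dict[str, float]:
    """Compute self_G2/self_H2, plus G2/H2 when RefBLEU is given, and Overall."""

    def geo(a: float, b: float) -> float:
        return math.sqrt(a * b)

    def harm(a: float, b: float) -> float:
        return 0.0 if a + b == 0 else 2 * a * b / (a + b)

    scores = {"self_g2": geo(acc, self_bleu), "self_h2": harm(acc, self_bleu)}
    if ref_bleu is not None:
        scores["g2"] = geo(acc, ref_bleu)
        scores["h2"] = harm(acc, ref_bleu)
    # Overall: G2 when it is available (assumed: RefBLEU was given), else self_G2.
    scores["overall"] = scores.get("g2", scores["self_g2"])
    return scores


print(combined_metrics(0.910, 0.638, 0.633))  # with human references
print(combined_metrics(0.910, 0.638))         # without human references
```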

Data Preparation

The YELP data can be downloaded here and should be put under data/yelp.

We cannot provide the GYAFC data because of copyright issues. You can download the data and the human references yourself, and then preprocess them into the same format as the YELP data. We use the Family & Relationships domain in all our experiments. The GYAFC data should be put under data/GYAFC_family.

Pretrained Classifier & Language Model

The evaluation code requires a pretrained classifier and a pretrained language model. We provide our pretrained models below.

Dataset   Classifier   Language Model
YELP      Link         Link
GYAFC     Link         Link

Download the models and put them under ./eval/model/.

See the training instructions below for how to train the classifier and language model. Use the same classifier and language model to evaluate NAST and all baselines; otherwise the results are not comparable.

Usage

For YELP:

cd eval
python eval_yelp.py --test0 test0.out --test1 test1.out

test0.out and test1.out are the model's transferred outputs for the two test sets.

Other arguments (Optional):

--allow_unk (Allow unknown tokens in generated outputs)
--dev0 dev0.out  (Evaluate the result on the dev set)
--dev1 dev1.out  (Evaluate the result on the dev set)
--datadir DATADIR (The data path, default: ../yelp_transfer_data)
--clsrestore MODELNAME (The file name of the pretrained classifier, default: cls_yelp_best. The corresponding path is ./model/MODELNAME.model)
--lmrestore MODELNAME (The file name of the pretrained language model, default: lm_yelp_best. The corresponding path is ./model/MODELNAME.model)
--cache  (Build cache to make the evaluation faster)
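When scripting evaluations, it can help to check that the checkpoints sit where --clsrestore and --lmrestore expect them (./model/MODELNAME.model, relative to the eval directory) and to assemble the command from the documented flags only. This is a minimal sketch; `checkpoint_path` and `eval_yelp_command` are hypothetical helpers, and it assumes it is run from the eval directory.

```python
import shlex
from pathlib import Path
from typing import List, Optional


def checkpoint_path(model_name: str, model_dir: str = "./model") -> Path:
    """Map a --clsrestore/--lmrestore value to ./model/MODELNAME.model."""
    return Path(model_dir) / f"{model_name}.model"


def eval_yelp_command(test0: str, test1: str,
                      dev0: Optional[str] = None, dev1: Optional[str] = None,
                      clsrestore: str = "cls_yelp_best",
                      lmrestore: str = "lm_yelp_best",
                      allow_unk: bool = False, cache: bool = False) -> List[str]:
    """Build an argv list for eval_yelp.py using only the flags documented above."""
    cmd = ["python", "eval_yelp.py", "--test0", test0, "--test1", test1]
    if dev0 is not None:
        cmd += ["--dev0", dev0]
    if dev1 is not None:
        cmd += ["--dev1", dev1]
    cmd += ["--clsrestore", clsrestore, "--lmrestore", lmrestore]
    if allow_unk:
        cmd.append("--allow_unk")
    if cache:
        cmd.append("--cache")
    return cmd


if __name__ == "__main__":
    # Report whether the default checkpoints are where eval_yelp.py looks for them.
    for name in ("cls_yelp_best", "lm_yelp_best"):
        path = checkpoint_path(name)
        print(path, "found" if path.exists() else "missing")
    cmd = eval_yelp_command("test0.out", "test1.out", cache=True)
    print(" ".join(shlex.quote(arg) for arg in cmd))
```

The list form can be passed directly to subprocess.run(cmd, check=True); building it as a list avoids shell-quoting problems with unusual file names.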

For GYAFC:

python eval_GYAFC.py --test0 test0.out --test1 test1.out

The other arguments are similar to those for YELP.

Example Outputs

domain  acc    self_bleu  ref_bleu  ppl      self_g2  self_h2  g2     h2     overall
test0   0.862  0.629      0.491     156.298  0.737    0.727    0.650  0.625  0.650
test1   0.910  0.638      0.633     88.461   0.762    0.750    0.759  0.747  0.759
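To compare runs programmatically, the printed table can be parsed into per-domain dictionaries. This is a minimal sketch assuming the evaluation script prints a whitespace-separated header row followed by one row per domain, as in the example above; `parse_results` is a hypothetical helper.

```python
from typing import Dict


def parse_results(text: str) -> Dict[str, Dict[str, float]]:
    """Parse a whitespace-separated results table (header row + one row per domain)."""
    rows = [line.split() for line in text.strip().splitlines()]
    header, body = rows[0], rows[1:]
    results = {}
    for row in body:
        domain, values = row[0], [float(v) for v in row[1:]]
        results[domain] = dict(zip(header[1:], values))
    return results


EXAMPLE = """
domain  acc    self_bleu  ref_bleu  ppl      self_g2  self_h2  g2     h2     overall
test0   0.862  0.629      0.491     156.298  0.737    0.727    0.650  0.625  0.650
test1   0.910  0.638      0.633     88.461   0.762    0.750    0.759  0.747  0.759
"""

results = parse_results(EXAMPLE)
for domain, scores in results.items():
    # Human references are available here, so overall should equal g2.
    print(domain, scores["acc"], scores["ppl"], scores["overall"] == scores["g2"])
```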

You can find results of NAST here.

Train your Classifier / Language Model

Training scripts:

cd eval
# train a classifier
python run_cls.py --name CLSNAME --dataid ../data/yelp --pos_weight 1 --cuda
# train a language model
python run_lm.py --name LMNAME --dataid ../data/yelp --cuda

Arguments:

  • name can be an arbitrary string, which is used for identifying checkpoints and tensorboard curves.
  • dataid specifies the data path.
  • pos_weight specifies the sample weight for label 1 (positive sentences in the YELP dataset). A value greater than 1 biases the model toward label 1. (For GYAFC, we use pos_weight=2.)
  • cuda enables GPU training.
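One common way to apply a per-label sample weight is a weighted negative log-likelihood in which every label-1 example counts pos_weight times. The sketch below illustrates that idea in plain Python; it is an assumption about how such a weight typically works, not a copy of run_cls.py, and `weighted_nll` is a hypothetical name. It normalizes by the total weight, which is what PyTorch's CrossEntropyLoss does with class weights and reduction="mean".

```python
import math
from typing import Sequence


def weighted_nll(probs_label1: Sequence[float], labels: Sequence[int],
                 pos_weight: float = 1.0) -> float:
    """Weighted mean NLL where each label-1 sample is weighted by pos_weight.

    probs_label1: predicted probability of label 1 for each sentence.
    labels: gold labels (0 or 1).
    """
    total, weight_sum = 0.0, 0.0
    for p, y in zip(probs_label1, labels):
        w = pos_weight if y == 1 else 1.0
        p_gold = p if y == 1 else 1.0 - p
        total += -w * math.log(max(p_gold, 1e-12))  # clamp to avoid log(0)
        weight_sum += w
    return total / weight_sum


# A badly predicted positive sentence costs more when pos_weight > 1.
print(weighted_nll([0.1, 0.1], [1, 0], pos_weight=1.0))
print(weighted_nll([0.1, 0.1], [1, 0], pos_weight=2.0))
```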

See run_cls.py or run_lm.py for more arguments.

You can track the training process with TensorBoard; the logs are written to ./eval/tensorboard.

The trained model will be saved in ./eval/model.

Training: Style Transformer

Data Preparation

Follow the same instructions as in the Data Preparation section above.

Use the Pretrained Classifier

The classifier is used for validation.

You can either download a pretrained classifier or train one yourself. Then put it under ./styletransformer/model.

Train NAST

Simple Alignment:

cd styletransformer
python run.py --name MODELNAME --dataid ../data/yelp --clsrestore cls_yelp_best

Learnable Alignment:

cd styletransformer
python run.py --name MODELNAME --dataid ../data/yelp --clsrestore cls_yelp_best --use_learnable --pretrain_batch 1000

Arguments:

  • name can be an arbitrary string, which is used for identifying checkpoints and tensorboard curves.
  • dataid specifies the data path.
  • clsrestore specifies the name of the pretrained classifier.
  • use_learnable enables learnable alignment; without it, simple alignment is used.
  • pretrain_batch specifies the number of pretraining steps, during which only the cycle loss is used.
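The pretrain_batch behaviour amounts to a loss schedule: for the first pretrain_batch steps only the cycle loss contributes, and afterwards all losses are active. The sketch below shows that schedule under stated assumptions: the loss names other than "cycle" and their weights are hypothetical placeholders, steps are assumed to be counted from 0, and `loss_weights` is not a function in run.py.

```python
from typing import Dict


def loss_weights(step: int, pretrain_batch: int,
                 weights: Dict[str, float]) -> Dict[str, float]:
    """Return per-loss weights for a training step.

    During the first `pretrain_batch` steps only the "cycle" loss keeps its
    weight; every other loss is switched off. Afterwards the full weights apply.
    """
    if step < pretrain_batch:
        return {name: (w if name == "cycle" else 0.0) for name, w in weights.items()}
    return dict(weights)


# Hypothetical loss names and weights, for illustration only.
FULL = {"cycle": 1.0, "self_rec": 1.0, "style": 1.0}
for step in (0, 999, 1000):
    print(step, loss_weights(step, pretrain_batch=1000, weights=FULL))
```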

See run.py for more arguments.

You can track the training process with TensorBoard; the logs are written to ./styletransformer/tensorboard.

The trained model will be saved in ./styletransformer/model.

Todo

  • Add the implementation for LatentSeq

Acknowledgement & Related Repository

We thank DualRL for providing multiple human references and some baselines' outputs, StyIns for the other baselines' outputs, and StyTrans and LatentSeq for providing great base models.

Citing

Please cite our paper if you find the paper or the code helpful.

@inproceedings{huang2021NAST,
  author = {Fei Huang and Zikai Chen and Chen Henry Wu and Qihan Guo and Xiaoyan Zhu and Minlie Huang},
  title = {{NAST}: A Non-Autoregressive Generator with Word Alignment for Unsupervised Text Style Transfer},
  booktitle = {Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics: Findings},
  year = {2021}
}
