MtSomeThree/constrDecoding

Controllable Text Generation with Neurally-Decomposed Oracle

This repository contains the source code to reproduce the experiments in the NeurIPS 2022 paper Controllable Text Generation with Neurally-Decomposed Oracle by Tao Meng, Sidi Lu, Nanyun Peng, and Kai-Wei Chang.

We are currently working on the camera-ready version, and the codebase is not yet stable. If you run into a technical issue, please feel free to open an issue or send an email to the first author.

  • Abstract

We propose a general and efficient framework to control auto-regressive generation models with NeurAlly-Decomposed Oracle (NADO). Given a pre-trained base language model and a sequence-level boolean oracle function, we propose to decompose the oracle function into token-level guidance to steer the base model in text generation. Specifically, the token-level guidance is approximated by a neural model trained with examples sampled from the base model, demanding no additional auxiliary labeled data. We present the closed-form optimal solution to incorporate the token-level guidance into the base model for controllable generation. We further provide a theoretical analysis of how the approximation quality of NADO affects the controllable generation results. Experiments conducted on two applications: (1) text generation with lexical constraints and (2) machine translation with formality control demonstrate that our framework efficiently guides the base model towards the given oracle while maintaining high generation quality.
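The decomposition described in the abstract can be illustrated on a toy problem small enough to enumerate exactly. The sketch below is not the repository's implementation: the vocabulary, the fixed-probability base model, the "contains c" oracle, and all function names are assumptions made for illustration. It assumes the closed form is the standard posterior factorization, where each base-model token probability is reweighted by the ratio of oracle success rates after and before the token. In NADO the success rate is approximated by a trained neural model; here it is computed by brute force.

```python
from itertools import product

# Toy setup (all assumptions for illustration): three tokens, fixed length.
VOCAB = ("a", "b", "c")
LENGTH = 3
TOKEN_PROB = {"a": 0.5, "b": 0.3, "c": 0.2}


def base_prob(prefix, token):
    """Toy base model p(token | prefix); ignores the prefix for simplicity."""
    return TOKEN_PROB[token]


def oracle(seq):
    """Sequence-level boolean oracle: a lexical constraint 'contains c'."""
    return 1.0 if "c" in seq else 0.0


def success_rate(prefix):
    """R(prefix): probability that a base-model completion satisfies the oracle."""
    if len(prefix) == LENGTH:
        return oracle(prefix)
    return sum(base_prob(prefix, t) * success_rate(prefix + (t,)) for t in VOCAB)


def guided_prob(prefix, token):
    """Token-level guided distribution q(token | prefix) = p * R(prefix + token) / R(prefix)."""
    r_prefix = success_rate(prefix)
    if r_prefix == 0.0:
        return 0.0  # prefix cannot be completed into a satisfying sequence
    return base_prob(prefix, token) * success_rate(prefix + (token,)) / r_prefix


def sequence_prob(seq, step_prob):
    """Probability of a full sequence under a token-level model."""
    prob = 1.0
    for i, tok in enumerate(seq):
        prob *= step_prob(seq[:i], tok)
    return prob


if __name__ == "__main__":
    for seq in product(VOCAB, repeat=LENGTH):
        p = sequence_prob(seq, base_prob)
        q = sequence_prob(seq, guided_prob)
        print("".join(seq), f"p={p:.4f}", f"q={q:.4f}", f"oracle={oracle(seq):.0f}")
```

In this toy, the guided model assigns zero probability to every sequence that violates the oracle, and rescales the base probability of every satisfying sequence by the same constant, so the relative preferences of the base model are preserved.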

  • Experiments

This repository will contain both experiments described in the paper. So far, the lexically constrained generation (LCG) part is still under construction and is expected to be released later in October.

  • Data

The machine translation formality-control experiments use the CALLHOME Spanish-English Speech Translation Corpus as source data and evaluate BLEU scores against the fluent references. Note that LDC access is required for the CALLHOME corpus.

  • Running experiments

Requirements

pip install -r requirements.txt

Running

python train_MT.py

The code automatically downloads the MarianMT model and samples translations of source texts from the Fisher-and-Callhome corpus. The sampled data is dumped in the ./dump/MT directory and labeled by a formality oracle trained in the FUDGE paper. A NADO model is then trained on the labeled samples. The translated results are evaluated by oracle scores and by BLEU scores against the fluent references.

Optional arguments:
  --sample_batch_size  the batch size used in sampling. Must be an integer multiple of 8.
  --batch_size  the batch size used in training. Must be a divisor of sample_batch_size.
  --regularization  the strength of regularization
  --max_length  the maximum sequence length accepted in training or evaluation
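
The two batch-size constraints above can be checked before launching a run. The helper below is a hypothetical sketch and is not part of train_MT.py; the function name and the example values are assumptions, and only the two divisibility rules come from the argument descriptions.

```python
def check_batch_sizes(sample_batch_size, batch_size):
    """Validate the documented constraints on the two batch-size arguments.

    Hypothetical helper: raises ValueError if a constraint is violated.
    """
    if sample_batch_size <= 0 or batch_size <= 0:
        raise ValueError("batch sizes must be positive integers")
    # --sample_batch_size must be an integer multiple of 8.
    if sample_batch_size % 8 != 0:
        raise ValueError(
            f"sample_batch_size={sample_batch_size} is not a multiple of 8"
        )
    # --batch_size must divide --sample_batch_size evenly.
    if sample_batch_size % batch_size != 0:
        raise ValueError(
            f"batch_size={batch_size} does not divide "
            f"sample_batch_size={sample_batch_size}"
        )


if __name__ == "__main__":
    check_batch_sizes(32, 8)  # example values only; both constraints hold
    print("batch sizes are consistent")
```

For example, a sample batch size of 32 with a training batch size of 8 passes both checks, while 12 and 4 fails because 12 is not a multiple of 8.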
