
ufal/augpt


AuGPT


Getting started

Start by creating a Python 3.7 virtual environment and installing the packages listed in requirements.txt. ConvLab-2 does not support Python 3.8, nor newer versions of transformers, so you need to install the legacy transformers version.
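Because ConvLab-2 does not support Python 3.8, a small guard like the hypothetical helper below can sit at the top of a setup script and catch a wrong interpreter early. It is a sketch that accepts only the 3.7 series.

```python
import sys

# ConvLab-2 does not support Python 3.8, so this sketch accepts only 3.7.x.
REQUIRED = (3, 7)


def python_version_ok(version_info=sys.version_info):
    """Return True if the interpreter's major/minor version equals REQUIRED."""
    return tuple(version_info[:2]) == REQUIRED
```

Calling `python_version_ok()` with no arguments checks the running interpreter; passing a tuple such as `(3, 8, 0)` is handy for testing.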

Get started by cloning this repository:

git clone https://github.com/ufal/augpt.git

Downloading datasets

First, install the required packages. If you intend to train on these datasets, install all required packages by running:

pip install -r requirements.txt

Otherwise, install the required packages by running:

pip install -r requirements-minimal.txt

To download datasets, run scripts/download_{dataset}.py, where dataset is the name of the dataset you need. Supported datasets:

  • taskmaster: The Taskmaster corpus [1] comprising over 55,000 spoken and written task-oriented dialogs in over a dozen domains.
  • schemaguided: The Schema-Guided Dialogue [2] dataset consisting of over 20k annotated multi-domain, task-oriented conversations between a human and a virtual assistant.
  • multiwoz: The MultiWOZ 2.0 dataset [3] - a large-scale multi-domain wizard-of-oz dataset for task-oriented dialogue modelling.
  • convlab_multiwoz: The MultiWOZ 2.1 dataset [4] - a cleaner version of MultiWOZ 2.0 with span information.
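To fetch several datasets in one go, a sketch like the following can loop over the download scripts. It assumes it is run from the repository root and that each script needs no arguments (none are documented); the helper names are hypothetical.

```python
import subprocess
import sys

# Dataset names listed above; each maps to scripts/download_{dataset}.py.
DATASETS = ["taskmaster", "schemaguided", "multiwoz", "convlab_multiwoz"]


def download_command(dataset):
    """Build the command line that runs the download script for one dataset."""
    if dataset not in DATASETS:
        raise ValueError(f"unsupported dataset: {dataset}")
    return [sys.executable, f"scripts/download_{dataset}.py"]


def download_all(datasets=DATASETS, dry_run=False):
    """Run the download commands in order; with dry_run=True only return them."""
    commands = [download_command(name) for name in datasets]
    if not dry_run:
        for cmd in commands:
            # check=True stops at the first script that exits with an error.
            subprocess.run(cmd, check=True)
    return commands
```

For the bigdata union, `download_all(["taskmaster", "schemaguided"])` fetches only the two datasets it consists of.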

In this work, we refer to the union of taskmaster and schemaguided as bigdata.

Interact and generate

To interact with the model or use it in your own pipeline, make sure all the required packages are installed by running:

pip install -r requirements-minimal.txt

To run the model in interactive mode, use the interact.py utility. Alternatively, to use the model in your own code, adapt the following example:

import pipelines  # Required here, modifies the transformers package to support AuGPT pipeline.
import transformers

# Loads the pipeline with MultiWOZ 2.1 model
pipeline = transformers.pipeline('augpt-conversational', 'jkulhanek/aug-mw-21')

# Either AuGPTConversation or Conversation can be used
conversation = pipelines.AuGPTConversation('Hi, I need a hotel')

# Run the pipeline on the conversation; it returns the updated conversation
conversation = pipeline(conversation)
print(conversation.generated_responses[-1])
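The example above handles a single user turn. For a multi-turn dialogue, one approach is to append each new user turn and call the pipeline again. This sketch assumes that AuGPTConversation provides the `add_user_input` method of `transformers.Conversation` (not confirmed here); `run_dialogue` and the stub classes are hypothetical, and the stubs only let the loop run without downloading a model.

```python
def run_dialogue(pipeline, make_conversation, user_turns):
    """Feed several user turns to a conversational pipeline and collect replies.

    Assumes make_conversation(text) returns an object exposing add_user_input()
    and generated_responses, as transformers.Conversation does.
    """
    if not user_turns:
        return []
    conversation = make_conversation(user_turns[0])
    replies = []
    for i, text in enumerate(user_turns):
        if i > 0:
            # Append the next user turn to the existing dialogue history.
            conversation.add_user_input(text)
        conversation = pipeline(conversation)
        replies.append(conversation.generated_responses[-1])
    return replies


class StubConversation:
    """Hypothetical stand-in for a conversation object, for offline testing."""

    def __init__(self, text):
        self.pending = text
        self.past_user_inputs = []
        self.generated_responses = []

    def add_user_input(self, text):
        self.pending = text


def stub_pipeline(conversation):
    """Hypothetical pipeline that echoes the pending user turn."""
    conversation.past_user_inputs.append(conversation.pending)
    conversation.generated_responses.append("echo: " + conversation.pending)
    conversation.pending = None
    return conversation
```

With the real model, the call would look like `run_dialogue(pipeline, pipelines.AuGPTConversation, ["Hi, I need a hotel", "In the centre, please"])`.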

To generate predictions, use the generate.py script:

./generate.py --model jkulhanek/aug-mw-21 --dataset multiwoz-2.1-test --file predictions.txt

Training and evaluation

The following script creates a virtual environment and installs the packages required for training and ConvLab-2 evaluation.

python3.7 -m venv ~/envs/dstc
source ~/envs/dstc/bin/activate
pip install -r requirements.txt
mkdir -p ~/source && cd ~/source
git clone git@github.com:ufal/ConvLab-2.git
cd ConvLab-2
git reset --hard 8b4464c57de0fbc497ce3532532c30ae461906e9
pip install -e . --no-deps
python -m spacy download en_core_web_sm
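After the setup, a quick sanity check can confirm that the packages are importable from the active environment. The import names are assumptions: `convlab2` for ConvLab-2 and `en_core_web_sm` for the spaCy model that `spacy download` installs as a package.

```python
import importlib.util

# Assumed import names for the packages installed by the setup script above.
REQUIRED_MODULES = ["transformers", "spacy", "convlab2", "en_core_web_sm"]


def missing_modules(names=REQUIRED_MODULES):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]
```

An empty list from `missing_modules()` means every listed module can be found; it does not verify versions, such as the legacy transformers release ConvLab-2 needs.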

Training bigdata model

The bigdata pre-trained model can be trained using the following arguments:

./train.py --epochs 8 --restrict-domains --train-dataset schemaguided-train+taskmaster-train --dev-dataset schemaguided-dev+taskmaster-dev --validation-steps 10000 --logging-steps 1000 --warmup-steps 5000 --evaluation-dialogs 0 --fp16
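The `--train-dataset` and `--dev-dataset` values above join two datasets with `+`, which matches the bigdata union of schemaguided and taskmaster. The hypothetical parser below illustrates that naming pattern, assuming `+` separates datasets and the text after the last `-` names the split; it is not the parser used by train.py.

```python
def parse_dataset_spec(spec):
    """Split e.g. 'schemaguided-train+taskmaster-train' into (name, split) pairs.

    Assumption: '+' separates datasets and the text after the last '-' is the split.
    """
    pairs = []
    for part in spec.split("+"):
        name, sep, split = part.rpartition("-")
        if not sep or not name or not split:
            raise ValueError(f"malformed dataset entry: {part!r}")
        pairs.append((name, split))
    return pairs
```

Splitting on the last `-` keeps versioned names such as `multiwoz-2.1` intact.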

The pre-trained model can be downloaded from the Hugging Face model repository as jkulhanek/augpt-bigdata.

Fine-tuning on MultiWOZ

The pre-trained model can be fine-tuned on the MultiWOZ 2.x datasets as follows:

./train_multiwoz.py --train-dataset multiwoz-2.1-train --dev-dataset multiwoz-2.1-val --model jkulhanek/augpt-bigdata --backtranslations latest --response-loss unlikelihood --epochs 10 --fp16 --clean-samples

For MultiWOZ 2.0, substitute the correct dataset version.

Distributed training

To run the training on a single CPU node (for testing), use the following arguments:

./train.py --no-cuda --gradient-accumulation-steps 4

NOTE: For optimal performance at least four GPUs are required for training.

To run the training on a single GPU:

./train.py --gradient-accumulation-steps 4

To run the training on a single node with multiple GPUs, run the following command:

python -m torch.distributed.launch --nproc_per_node=NUM_GPUS_YOU_HAVE train.py

The default configuration assumes four GPUs; with a different number of GPUs, you may need to adjust learning_rate and/or gradient-accumulation-steps accordingly.
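The single-device commands above use `--gradient-accumulation-steps 4`, which suggests the four-GPU setup runs with one accumulation step. Assuming the per-device batch size stays fixed, the effective batch size is the per-device batch size times the number of devices times the accumulation steps, so keeping it constant means keeping devices times steps constant. A hypothetical helper:

```python
def accumulation_steps(num_gpus, reference_gpus=4, reference_steps=1):
    """Accumulation steps that preserve the reference effective batch size.

    Assumes the reference setup is 4 GPUs with 1 accumulation step and that
    the per-device batch size does not change.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be positive")
    total = reference_gpus * reference_steps
    if total % num_gpus:
        raise ValueError(f"{total} accumulation units cannot be split evenly over {num_gpus} GPUs")
    return total // num_gpus
```

For one device this yields 4, matching the single-GPU and CPU commands; the learning rate is left untouched by this sketch.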

To run the training on multiple nodes with multiple GPUs, you can use the PyTorch launch utility (https://pytorch.org/docs/stable/distributed.html#launch-utility) or consult your job scheduling system. You may need to set the environment variables LOCAL_RANK, RANK, WORLD_SIZE, MASTER_PORT, and MASTER_ADDR. Here, RANK is the global index of the current process across all nodes, LOCAL_RANK is the index of the process within its node, and WORLD_SIZE is the total number of processes. Every node must have as many GPUs as there are processes running on it.
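When a scheduler does not set these variables itself, the sketch below shows how they relate, assuming every node runs the same number of processes (one per GPU). The function name is hypothetical, and the address and port in the usage note are placeholders.

```python
def distributed_env(node_rank, local_rank, nproc_per_node, nnodes, master_addr, master_port):
    """Environment variables for one process of a multi-node torch.distributed job.

    RANK enumerates processes across all nodes; LOCAL_RANK enumerates the
    processes within one node; WORLD_SIZE is the total process count.
    """
    if not 0 <= local_rank < nproc_per_node:
        raise ValueError("local_rank must be in [0, nproc_per_node)")
    if not 0 <= node_rank < nnodes:
        raise ValueError("node_rank must be in [0, nnodes)")
    return {
        "MASTER_ADDR": str(master_addr),
        "MASTER_PORT": str(master_port),
        "WORLD_SIZE": str(nnodes * nproc_per_node),
        "RANK": str(node_rank * nproc_per_node + local_rank),
        "LOCAL_RANK": str(local_rank),
    }
```

For example, `distributed_env(1, 2, 4, 2, "10.0.0.1", 29500)` describes the third process on the second of two four-GPU nodes, giving RANK 6 and WORLD_SIZE 8.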

Evaluation

All packages required for the training must also be installed for evaluation.

ConvLab-2 evaluation

To evaluate your trained model using ConvLab-2 evaluation, run the following script:

./evaluate_convlab.py --model {model}

MultiWOZ 2.x evaluation

To evaluate your trained model using MultiWOZ evaluation, run the following:

./evaluate_multiwoz.py --model {model} --dataset multiwoz-2.1-test

If you have already generated predictions with the generate.py script, you can evaluate them by running:

./evaluate_multiwoz.py --file predictions.txt --dataset multiwoz-2.1-test
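The two-step workflow (generate predictions, then evaluate them) can be scripted as below. The sketch reuses only the flags shown in this README; the function names are hypothetical, and the scripts are assumed to be executable and run from the repository root.

```python
import subprocess


def prediction_commands(model, dataset="multiwoz-2.1-test", predictions="predictions.txt"):
    """Commands that write predictions with generate.py and then score them."""
    return [
        ["./generate.py", "--model", model, "--dataset", dataset, "--file", predictions],
        ["./evaluate_multiwoz.py", "--file", predictions, "--dataset", dataset],
    ]


def generate_and_evaluate(model, **kwargs):
    """Run both steps in order; check=True stops if generation fails."""
    for cmd in prediction_commands(model, **kwargs):
        subprocess.run(cmd, check=True)
```

Passing the same `dataset` value to both steps keeps the generated file and the evaluation split consistent.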

For MultiWOZ 2.0, substitute the correct dataset version.

References

[1]: Byrne, B.; Krishnamoorthi, K.; Sankar, C.; Neelakantan, A.; Duckworth, D.; Yavuz, S.; Goodrich, B.; Dubey, A.; Kim, K.-Y.; and Cedilnik, A. 2019. Taskmaster-1: Toward a Realistic and Diverse Dialog Dataset.

[2]: Rastogi, A.; Zang, X.; Sunkara, S.; Gupta, R.; and Khaitan, P. 2019. Towards Scalable Multi-domain Conversational Agents: The Schema-Guided Dialogue Dataset. arXiv preprint arXiv:1909.05855.

[3]: Budzianowski, P.; Wen, T.-H.; Tseng, B.-H.; Casanueva, I.; Ultes, S.; Ramadan, O.; and Gašić, M. 2018. MultiWOZ - A Large-Scale Multi-Domain Wizard-of-Oz Dataset for Task-Oriented Dialogue Modelling. arXiv preprint arXiv:1810.00278.

[4]: Eric, M.; Goel, R.; Paul, S.; Kumar, A.; Sethi, A.; Ku, P.; Goyal, A. K.; Agarwal, S.; Gao, S.; and Hakkani-Tur, D. 2019. MultiWOZ 2.1: A Consolidated Multi-Domain Dialogue Dataset with State Corrections and State Tracking Baselines. arXiv preprint arXiv:1907.01669.
