Trains a Transformer-based model (Vaswani et al., 2017) on the WMT Machine Translation en-de dataset.
This example uses linear learning rate warmup and inverse square root learning rate schedule.
- TensorFlow datasets
wmt17_translate/de-en
andwmt14_translate/de-en
need to be downloaded and prepared. A sentencepiece tokenizer vocabulary will be automatically generated and saved on each training run. - This example additionally depends on the
sentencepiece
andtensorflow-text
packages.
The model should run with other configurations and hardware, but was explicitly tested on the following.
Hardware | Batch size | Training time | BLEU | TensorBoard.dev |
---|---|---|---|---|
TPU v3-8 | 256 | 1h 35m | 25.13 | 2020-04-21 |
python train.py --batch_size=256 --model_dir=./wmt_256 --reverse_translation=True
Creating a Cloud TPU involves creating the user GCE VM and the TPU node.
To create a user GCE VM, run the following command from your GCP console or your computer terminal where you have gcloud installed.
export ZONE=europe-west4-a
gcloud compute instances create $USER-user-vm-0001 \
--machine-type=n1-standard-16 \
--image-project=ml-images \
--image-family=tf-2-2 \
--boot-disk-size=200GB \
--scopes=cloud-platform \
--zone=$ZONE
To create a larger GCE VM, choose a different machine type.
Next, create the TPU node, following these guidelines to choose a <TPU_IP_ADDRESS>. e.g.:
export TPU_IP_ADDRESS=192.168.0.2
gcloud compute tpus create $USER-tpu-0001 \
--zone=$ZONE \
--network=default \
--accelerator-type=v3-8 \
--range=$TPU_IP_ADDRESS \
--version=tpu_driver_nightly
Now that you have created both the user GCE VM and the TPU node, ssh to the GCE VM by executing the following command, including a local port forwarding rule for viewing tensorboard:
gcloud compute ssh $USER-user-vm-0001 -- -L 2222:localhost:8888
Be sure to install the latest jax
and jaxlib
packages alongside the other requirements above.
e.g. as of April 2020 the following package versions were used successfully:
pip install -U pip
pip install -U setuptools wheel
pip install -U pip jax jaxlib sentencepiece \
tensorflow==2.2.0rc3 tensorflow-datasets==3.0.0 \
tensorflow-text==2.2.0rc2 tensorboard
git clone https://github.com/google/flax
pip install -e flax
Then, if your TPU is at IP 192.168.0.2
:
python train.py --batch_size=256 --model_dir=./wmt_256 --jax_backend_target="grpc://192.168.0.2:8470"
A tensorboard instance can then be launched and viewed on your local 2222 port via the tunnel:
tensorboard --logdir wmt_256 --port 8888
We recommend downloading and preparing the TFDS datasets beforehand. For Cloud TPUs, we
recommend using a cheap standard instance and saving the prepared TFDS data on a storage bucket,
from where it can be loaded directly during training using the --data_dir=gs://...
option.
You can download and prepare any of the WMT datasets using TFDS directly:
python -m tensorflow_datasets.scripts.download_and_prepare --datasets=wmt17_translate/de-en
The typical academic BLEU evaluation also uses the WMT 2014 Test set:
python -m tensorflow_datasets.scripts.download_and_prepare --datasets=wmt14_translate/de-en