Meliad

This is not an officially supported Google product.

This code is provided "as-is" to the broader research community. Google does not promise to maintain or otherwise support this code in any way.

Introduction

The Meliad library is a collection of models that are being developed as part of ongoing research into various architectural improvements in deep learning. The name "meliad" is the Greek word for a tree nymph; a long-term goal of this research is to design architectures that can understand recursive and compositional structures, i.e. trees.

The library currently consists of several transformer variations, which explore ways in which the popular transformer architecture can be extended to better support language modeling over long sequences.

Transformer-XL with sliding window

This model is provided as a baseline. It is similar to the Transformer-XL architecture, but uses a T5-style relative position bias. A long sequence, such as a book, is divided into segments of fixed length, e.g. 4096 tokens. The segments are processed in order, with one segment per training step.

Attention within a segment is done locally using a sliding window that is typically smaller than the segment length. A causal mask ensures that each token can attend to exactly W previous tokens, where W is the window size, e.g. 512 or 1024. The cost of attention is quadratic with respect to window size but linear with respect to segment length, so the segment length is limited only by available device memory. Like Transformer-XL, the model caches the keys and values from the last window for use on the next training step, and thus implements truncated backpropagation through time over very long (book-length) works.
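The masking rule can be illustrated in JAX with a dense boolean mask. This is a sketch only. It assumes that keys are laid out as [cache, current segment] and that "exactly W previous tokens" counts the query token itself, so a query sees itself and the W - 1 tokens before it. The exact off-by-one convention in Meliad may differ. `sliding_window_mask` is a hypothetical helper and is not part of the Meliad API.

```python
import jax.numpy as jnp


def sliding_window_mask(segment_len: int, window: int, cache_len: int) -> jnp.ndarray:
    """Boolean attention mask for one segment with a Transformer-XL style cache.

    Hypothetical helper. Keys are laid out as [cache, current segment], so the
    query at segment index i sits at absolute key position cache_len + i.
    Assumes a query may attend to itself and the (window - 1) keys before it.
    """
    query_pos = cache_len + jnp.arange(segment_len)[:, None]  # (segment_len, 1)
    key_pos = jnp.arange(cache_len + segment_len)[None, :]    # (1, cache_len + segment_len)
    causal = key_pos <= query_pos              # never attend to future tokens
    in_window = key_pos > query_pos - window   # never attend beyond the window
    return causal & in_window                  # (segment_len, cache_len + segment_len)


# Example: 8-token segment, window of 4, cache holding the previous 4 keys.
mask = sliding_window_mask(segment_len=8, window=4, cache_len=4)
print(mask.shape)        # (8, 12)
print(mask.sum(axis=1))  # every query sees exactly 4 keys
```

A dense mask like this one still materializes segment_len x (cache_len + segment_len) entries. The linear cost described above comes from computing only the banded part. One way to do that is to split the segment into window-sized blocks, where each block attends to itself and the previous block.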

If the window and segment lengths are the same, then there is no sliding window (just the T-XL cache), and this model will behave like Transformer-XL. However, the cache is not differentiable, whereas the sliding window is, so there is some benefit to using segments that are longer than the window length. Gradients with the sliding window can potentially be backpropagated across the length of the entire segment.
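The difference between the non-differentiable cache and the differentiable sliding window can be shown with `jax.lax.stop_gradient`. This is a minimal sketch. It assumes the cache is simply the last `window` keys of a segment with gradients blocked. `update_cache` is a hypothetical name, not Meliad's API.

```python
import jax
import jax.numpy as jnp


def update_cache(keys: jnp.ndarray, window: int) -> jnp.ndarray:
    """Hypothetical cache update: keep the last `window` keys for the next step.

    stop_gradient turns the cached keys into constants, so no gradient flows
    from the next training step back into the step that produced them.
    """
    return jax.lax.stop_gradient(keys[-window:])


def loss_via_cache(keys):
    # The loss depends on the keys only through the (non-differentiable) cache.
    return jnp.sum(update_cache(keys, window=2) ** 2)


def loss_via_window(keys):
    # The loss uses the same keys directly, as attention inside a segment would.
    return jnp.sum(keys[-2:] ** 2)


keys = jnp.ones((4, 3))
print(jax.grad(loss_via_cache)(keys))   # all zeros: the cache blocks gradients
print(jax.grad(loss_via_window)(keys))  # nonzero for the last two rows
```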

Memorizing Transformer

The Memorizing Transformer equips one layer of the transformer with a large external memory that stores prior (key, value) pairs. Typical memory sizes are 32k or 64k tokens. In addition to local attention, the model can do k-nearest-neighbor lookup into the external memory, which allows it to handle long-range dependencies; the range is limited only by the size of the memory.
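A k-nearest-neighbor lookup into an external memory can be sketched as follows. For simplicity, this sketch uses exact top-k search over dot-product scores with `jax.lax.top_k`, then applies a softmax over the retrieved entries. `knn_attention` is a hypothetical name. An implementation working over 32k to 64k entries may use approximate search instead, and would also combine this result with local attention.

```python
import jax
import jax.numpy as jnp


def knn_attention(queries, mem_keys, mem_values, k):
    """Attend over the top-k memory entries for each query (exact search).

    queries: (num_queries, d); mem_keys, mem_values: (memory_size, d).
    Returns: (num_queries, d).
    """
    scores = queries @ mem_keys.T                   # (num_queries, memory_size)
    top_scores, top_idx = jax.lax.top_k(scores, k)  # each (num_queries, k)
    retrieved = mem_values[top_idx]                 # (num_queries, k, d)
    weights = jax.nn.softmax(top_scores, axis=-1)   # attention over the k hits
    return jnp.einsum("qk,qkd->qd", weights, retrieved)


# Toy memory of 6 entries; the query matches memory entry 2 most strongly.
mem_keys = 10.0 * jnp.eye(6)
mem_values = jnp.arange(36.0).reshape(6, 6)
query = mem_keys[2:3]
print(knn_attention(query, mem_keys, mem_values, k=1))  # equals mem_values[2]
```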

The external memory, like the T-XL cache, is not differentiable. Memory and the T-XL cache work well together; the memory is used for long-range lookups, while the cache is used for short-range lookups. However, memory should not be used with a sliding window, so the window and segment length should be the same.
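One way to keep the external memory at a fixed size is a first-in-first-out buffer. After each step, the new (key, value) pairs are appended and the oldest ones are dropped. The FIFO policy and the helper name `append_to_memory` are assumptions made for illustration. In line with the text above, the stored pairs are wrapped in `stop_gradient`, so the memory is not differentiable.

```python
import jax
import jax.numpy as jnp


def append_to_memory(mem_keys, mem_values, new_keys, new_values):
    """Hypothetical FIFO update for a fixed-size external memory.

    All arrays are (num_entries, d). The memory keeps its size: the oldest
    len(new_keys) entries are evicted and the new ones are appended at the end.
    """
    memory_size = mem_keys.shape[0]
    new_keys = jax.lax.stop_gradient(new_keys)      # memory is not differentiable
    new_values = jax.lax.stop_gradient(new_values)
    keys = jnp.concatenate([mem_keys, new_keys], axis=0)[-memory_size:]
    values = jnp.concatenate([mem_values, new_values], axis=0)[-memory_size:]
    return keys, values


mem_k, mem_v = jnp.zeros((4, 2)), jnp.zeros((4, 2))
mem_k, mem_v = append_to_memory(mem_k, mem_v, jnp.ones((3, 2)), 2 * jnp.ones((3, 2)))
print(mem_k[:, 0])  # one surviving old entry, then the three new ones
```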

Block-Recurrent Transformer

The Block-Recurrent Transformer equips one layer of the transformer with a recurrent cell. The cell is structured similarly to an LSTM cell, but it is several orders of magnitude larger, and operates on blocks of tokens and blocks of recurrent state vectors. Recurrence is integrated with the sliding window mechanism; the block size is the same as the window size.
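Here is a heavily simplified sketch of a recurrent cell that updates a block of state vectors from a block of tokens. It uses a single cross-attention step followed by LSTM-style gates. This is not Meliad's cell. The parameter names, the single-head attention, and the gate equations are all assumptions, and the real cell is much larger and more elaborate. As described above, the block size is set equal to the window size.

```python
import jax
import jax.numpy as jnp


def recurrent_block_step(params, state, block):
    """One hypothetical recurrent update over a block of tokens.

    state: (num_state_vectors, d) recurrent state carried between blocks.
    block: (block_size, d) token embeddings for the current window.
    Returns the new state with the same shape as `state`.
    """
    d = state.shape[-1]
    # State vectors attend to the tokens of the current block.
    q = state @ params["w_q"]
    k = block @ params["w_k"]
    v = block @ params["w_v"]
    attn = jax.nn.softmax(q @ k.T / jnp.sqrt(d), axis=-1)  # (num_state, block_size)
    update = jnp.tanh(attn @ v)                             # candidate new state
    # LSTM-style gates decide how much old state to keep and how much to write.
    forget = jax.nn.sigmoid(state @ params["w_f"] + params["b_f"])
    inp = jax.nn.sigmoid(state @ params["w_i"] + params["b_i"])
    return forget * state + inp * update


def init_params(key, d):
    """Random hypothetical parameters for the sketch above."""
    names = ["w_q", "w_k", "w_v", "w_f", "w_i"]
    keys = jax.random.split(key, len(names))
    params = {n: jax.random.normal(k_, (d, d)) / jnp.sqrt(d) for n, k_ in zip(names, keys)}
    params["b_f"] = jnp.ones((d,))   # bias toward remembering, a common LSTM trick
    params["b_i"] = jnp.zeros((d,))
    return params


# A segment of 4 blocks (window size 8) processed with 3 state vectors of width 16.
d, window, num_state = 16, 8, 3
params = init_params(jax.random.PRNGKey(0), d)
state = jnp.zeros((num_state, d))
segment = jax.random.normal(jax.random.PRNGKey(1), (4 * window, d))
for block in segment.reshape(4, window, d):
    state = recurrent_block_step(params, state, block)
print(state.shape)  # (3, 16)
```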

Recurrence serves a similar role to external memory, but is faster. The recurrent state has a fixed capacity, but unlimited range (in theory).

Installation instructions

Create and activate a Python virtual environment. (Commands given are for Linux.)

python -m venv my_env
source my_env/bin/activate

Install the required packages into the Python virtual environment. If you want to use GPUs, then JAX must be upgraded to a CUDA-enabled build. Installing t5 after upgrading JAX may be necessary to avoid link errors (we don't know why).

pip install -r requirements.txt
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install t5
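After installation, you can confirm which backend JAX picked up by querying its devices from Python. This shows, for example, whether the CUDA build can see your GPUs. The check uses only standard JAX calls and is not specific to Meliad.

```python
import jax

# "gpu" if the CUDA-enabled jaxlib found a GPU, otherwise "cpu" (or "tpu").
print(jax.default_backend())
# Every device JAX can use, e.g. one entry per GPU.
print(jax.devices())
```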

On Unix systems, you may need to ensure that PYTHONPATH includes the current directory, since all module names are given relative to the meliad root.

export PYTHONPATH=.:$PYTHONPATH

Run a small baseline model on a synthetic test dataset.

python transformer/ht_main.py --alsologtostderr \
--gin_file=base_htrans.gin \
--gin_file=size/small_test.gin

Configuring and running the model

Meliad uses gin to configure the model. The first gin file should always be base_htrans.gin, which supplies a default configuration. Other options are specified as additional files in the configs directory. Most options are orthogonal, but in some cases the order matters; inspect the contents of the gin files to determine the correct order.

Some important options are:

  • size/medium150M.gin The 150M parameter model in the paper.
  • options/positions_t5.gin Use a T5-style relative position bias.
  • options/seq_4096.gin Use a segment length of 4096 tokens.
  • options/window_1024.gin Use a sliding window of size 1024. (The default is 512).
  • options/lr_cosine_decay.gin Cosine decay learning rate schedule.

Tasks are also defined in gin files:

  • tasks/pg19_tokens.gin Run on PG19 with the default T5 SentencePiece vocabulary.

Other important command-line options:

  • --alsologtostderr View the progress of the model.
  • --workdir=/my/work/directory Directory for checkpoints and TensorBoard logs.
  • --load_dir=/location/of/pretrained/model For finetuning.
  • --default_data_dir=/location/of/tfds/datasets Location of TensorFlow Datasets (TFDS) data.
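When scripting experiments, it can help to build the command line programmatically so that the gin files stay in a fixed order, with base_htrans.gin first. The helper below is a hypothetical convenience and is not part of Meliad. It uses only the script path and flags listed in this README, and leaves running the command to `subprocess`.

```python
import shlex
from typing import List, Optional, Sequence


def build_htrans_command(
    gin_files: Sequence[str],
    workdir: Optional[str] = None,
    load_dir: Optional[str] = None,
    default_data_dir: Optional[str] = None,
) -> List[str]:
    """Hypothetical helper that builds an argv list for transformer/ht_main.py."""
    # base_htrans.gin supplies the defaults, so it always comes first.
    files = ["base_htrans.gin"] + [f for f in gin_files if f != "base_htrans.gin"]
    cmd = ["python", "transformer/ht_main.py", "--alsologtostderr"]
    cmd += [f"--gin_file={f}" for f in files]
    # Optional flags are only added when a value is given.
    if workdir:
        cmd.append(f"--workdir={workdir}")
    if load_dir:
        cmd.append(f"--load_dir={load_dir}")
    if default_data_dir:
        cmd.append(f"--default_data_dir={default_data_dir}")
    return cmd


cmd = build_htrans_command(
    ["size/medium150M.gin", "options/positions_t5.gin",
     "options/seq_4096.gin", "tasks/pg19_tokens.gin"],
    workdir="/my/work/directory",
)
print(shlex.join(cmd))
# To launch it: subprocess.run(cmd, check=True)
```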

For the Memorizing Transformer:

  • size/medium150M.gin The 150M parameter model in the paper.
  • options/positions_t5.gin Use a T5-style relative position bias.
  • options/seq_512.gin Segment length of 512. (Window is 512 by default).
  • options/external_memory_32k.gin Memorizing Transformer with a memory size of 32k.
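The Memorizing Transformer section says that external memory should not be combined with a sliding window, so the segment and window lengths should match. You can check this before launching a run. The function below is a hypothetical sanity check written for illustration. It assumes one (key, value) pair is stored per token and that "32k" means 32 * 1024 entries. With seq_512 and the default window of 512, a 32k memory then spans the last 64 segments.

```python
def check_memorizing_config(segment_length: int, window_length: int, memory_size: int) -> int:
    """Hypothetical pre-flight check; returns how many segments fit in memory."""
    if segment_length != window_length:
        raise ValueError(
            "external memory should not be used with a sliding window: "
            f"segment_length={segment_length}, window_length={window_length}"
        )
    # Assumes one (key, value) pair per token, so memory spans this many segments.
    return memory_size // segment_length


print(check_memorizing_config(segment_length=512, window_length=512, memory_size=32 * 1024))  # 64
```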

For the Block-Recurrent Transformer:

  • size/medium150M.gin The 150M parameter model in the paper.
  • options/positions_t5.gin Use a T5-style relative position bias.
  • options/seq_4096.gin Segment length of 4096. (Window is 512 by default).
  • recurrent/bias_skip.gin The fixed:skip configuration.
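The recurrent block size equals the window size, so the number of recurrent steps per segment follows directly from the options above. With seq_4096 and the default window of 512, the cell runs over 4096 / 512 = 8 blocks per training step. The tiny helper below makes this arithmetic explicit. It is hypothetical and assumes the segment length is a multiple of the window length.

```python
def recurrent_steps_per_segment(segment_length: int, window_length: int) -> int:
    """Blocks processed by the recurrent cell per segment (block size == window size)."""
    # Assumption for this sketch: segments divide evenly into windows.
    if segment_length % window_length != 0:
        raise ValueError("segment_length should be a multiple of window_length")
    return segment_length // window_length


print(recurrent_steps_per_segment(4096, 512))  # 8
```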