Trax — Deep Learning with Clear Code and Speed

tomweingarten/trax (forked from google/trax)

Trax is a deep learning library focused on sequence models and reinforcement learning. It combines performance with clear code, well-maintained documentation, and tests.

Trax includes basic models (like ResNet, LSTM, Transformer) and RL algorithms (like REINFORCE, A2C, PPO). It is also actively used for research and includes new models like the Reformer and RL algorithms like AWR.

Trax is actively used and maintained by the Google Brain team. Give it a try, talk to us, or open an issue if needed.

Use Trax

You can use Trax either as a library from your own Python scripts and notebooks or as a binary from the shell, which can be more convenient for training large models. Trax includes a number of deep learning models (ResNet, Transformer, RNNs, ...) and has bindings to a large number of deep learning datasets, including those in Tensor2Tensor and TensorFlow Datasets. It runs on CPUs, GPUs, and TPUs without any code changes.

To see how to use Trax as a library, take a look at the quick start colab, which explains how to:

  1. Create data in Python.
  2. Connect it to a Transformer model in Trax.
  3. Train it and run inference.
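The three quick-start steps can be mirrored in a tiny, dependency-free Python sketch. It is only an illustration under stated assumptions: a 1-D linear model stands in for the Transformer, the task (y = 3x + 1 plus noise) is made up, and `data_stream`, `predict`, and `train` are hypothetical names, not Trax APIs. The point is the shape of the workflow: a Python generator yields `(inputs, targets)` batches, a training loop consumes them, and the trained parameters are then used for inference.

```python
# Toy walkthrough of the three quick-start steps.
# Hypothetical names throughout; a linear model stands in for the Transformer.
import random
from typing import Iterator, List, Tuple

Batch = Tuple[List[float], List[float]]


def data_stream(batch_size: int, seed: int = 0) -> Iterator[Batch]:
    """Step 1: create data in Python as an endless stream of (inputs, targets).

    Targets follow y = 3x + 1 plus small Gaussian noise (a made-up task).
    """
    rng = random.Random(seed)
    while True:
        xs = [rng.uniform(-1.0, 1.0) for _ in range(batch_size)]
        ys = [3.0 * x + 1.0 + rng.gauss(0.0, 0.01) for x in xs]
        yield xs, ys


def predict(params: Tuple[float, float], x: float) -> float:
    """Step 2: the model, y_hat = w * x + b."""
    w, b = params
    return w * x + b


def train(stream: Iterator[Batch], steps: int, lr: float = 0.1) -> Tuple[float, float]:
    """Step 3a: minimise mean squared error with plain gradient descent."""
    w, b = 0.0, 0.0
    for _ in range(steps):
        xs, ys = next(stream)  # the data stream is "connected" to the model here
        n = len(xs)
        errors = [predict((w, b), x) - y for x, y in zip(xs, ys)]
        grad_w = 2.0 / n * sum(e * x for e, x in zip(errors, xs))
        grad_b = 2.0 / n * sum(errors)
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


params = train(data_stream(batch_size=32), steps=500)
# Step 3b: inference with the trained parameters (should be close to 7.0).
print(params, predict(params, 2.0))
```

In Trax itself the model, loss, optimizer, and training loop come from the library; only the generator-of-batches pattern carries over directly.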

With Colab, you can select a CPU or GPU runtime, or even get a free 8-core TPU as the runtime. Note that TPUs in Colab require extra flags, as demonstrated in the training and inference colabs.

To use Trax as a binary, we recommend pairing it with gin-config to keep track of the model type, learning rate, and other hyperparameters or training settings.

Take a look at an example gin config for training a simple MLP on MNIST and run it as follows:

python -m trax.trainer --config_file=$PWD/trax/configs/mlp_mnist.gin
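The idea behind a gin config is that hyperparameters live in a text file of `function.parameter = value` bindings, which are injected into registered functions at call time, so an experiment is fully described by its config file. The sketch below is a minimal, hypothetical stand-in written only to show that mechanism; `parse_bindings`, `configurable`, and the `train` function are illustrative names, and the real gin-config library (with its `@gin.configurable` decorator and `.gin` file syntax) is far more capable and differs in detail.

```python
# Hypothetical stand-in for the gin-config idea; NOT gin's implementation or API.
import ast
import functools
from typing import Any, Callable, Dict

# Maps function name -> {parameter name: bound value}.
_BINDINGS: Dict[str, Dict[str, Any]] = {}


def parse_bindings(text: str) -> None:
    """Parse lines like 'train.learning_rate = 0.01'; '#' starts a comment."""
    for raw in text.splitlines():
        line = raw.split("#", 1)[0].strip()
        if not line:
            continue
        key, value = (part.strip() for part in line.split("=", 1))
        fn_name, param = key.rsplit(".", 1)
        _BINDINGS.setdefault(fn_name, {})[param] = ast.literal_eval(value)


def configurable(fn: Callable) -> Callable:
    """Fill keyword arguments from the parsed bindings at call time."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = dict(_BINDINGS.get(fn.__name__, {}))
        bound.update(kwargs)  # explicit call-site arguments take precedence
        return fn(*args, **bound)
    return wrapper


@configurable
def train(model: str = "MLP", learning_rate: float = 0.1, steps: int = 100) -> Dict[str, Any]:
    # A real trainer would build and train the model; here we just echo settings.
    return {"model": model, "learning_rate": learning_rate, "steps": steps}


parse_bindings("""
# hypothetical experiment settings
train.model = 'MLP'
train.learning_rate = 0.01
train.steps = 2000
""")
print(train())
```

Because every setting comes from the bindings text, rerunning an experiment only requires the same config file, which is the bookkeeping benefit the recommendation above is about.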

As a more advanced example, you can train a Reformer on ImageNet64 to generate images with the following command:

python -m trax.trainer --config_file=$PWD/trax/configs/reformer_imagenet64.gin

Structure

Trax code is structured so that you can understand deep learning from scratch. We start with basic math and move through layers, models, and supervised and reinforcement learning, up to advanced deep learning results, including recent papers such as "Reformer: The Efficient Transformer", selected for oral presentation at ICLR 2020.

The main steps needed to understand deep learning correspond to sub-directories in Trax code:

  • math/ contains basic math operations and ways to accelerate them on GPUs and TPUs (through JAX and TensorFlow)
  • layers/ contains the basic building blocks of neural networks: how layers are built, plus all the essential ones
  • models/ contains all basic models (MLP, ResNet, Transformer, ...) and a number of new research models
  • optimizers/ contains the optimizers needed for deep learning
  • supervised/ contains the utilities needed to run supervised learning, including the Trainer class
  • rl/ contains our work on reinforcement learning
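The relationship between layers/ and models/ (models are assembled by composing layers with combinators such as Serial) can be illustrated with a small, dependency-free sketch. The `Dense`, `Relu`, and `Serial` classes below are hypothetical simplifications written for this example, not Trax's classes: real Trax layers also manage weight initialization, state, and multiple inputs and outputs, and run on accelerators through the math/ backend.

```python
# Simplified sketch of "layers as composable building blocks".
# These classes are illustrative only and are NOT Trax's layer API.
from typing import Callable, List, Sequence

Vector = List[float]


class Dense:
    """Fully connected layer y = W x + b with fixed, given weights."""

    def __init__(self, weights: Sequence[Sequence[float]], bias: Sequence[float]):
        self.weights = [list(row) for row in weights]
        self.bias = list(bias)

    def __call__(self, x: Vector) -> Vector:
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weights, self.bias)]


class Relu:
    """Elementwise max(0, x)."""

    def __call__(self, x: Vector) -> Vector:
        return [max(0.0, v) for v in x]


class Serial:
    """Combinator: applies its sublayers one after another."""

    def __init__(self, *layers: Callable[[Vector], Vector]):
        self.layers = layers

    def __call__(self, x: Vector) -> Vector:
        for layer in self.layers:
            x = layer(x)
        return x


# A tiny two-layer MLP "model" assembled from the building blocks above.
mlp = Serial(
    Dense([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0]),
    Relu(),
    Dense([[2.0, 1.0]], [0.1]),
)
print(mlp([3.0, 1.0]))  # hidden = [2.0, 2.0], output = 2*2 + 1*2 + 0.1
```

Keeping each layer small and pushing composition into a combinator is what lets a model definition read like a list of its layers.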

Development

For the latest updates on Trax development, chat with us.

Most common supervised learning models in Trax are running and should have clear code. If this is not the case, please open an issue or, even better, send a pull request (see our contribution doc). In Trax we value documentation, examples, and colabs, so if you find any problems with those, please report them and contribute a solution.

We are still improving a few smaller parts of layers/, planning to update the supervised API, and working heavily on the rl/ part, so expect these parts to change over the next few months.

We are also working hard to improve our documentation and examples and we welcome help with that.


Languages

  • Python 81.6%
  • Jupyter Notebook 18.1%
  • Shell 0.3%