Skip to content

Commit

Permalink
[JAX/Flax readme] add philosophy doc (#12419)
Browse files Browse the repository at this point in the history
* add philosophy doc

* fix typos

* update doc

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* address Patricks suggestions

* add a training example and fix typos

* jit the training step

* jit train step

* fix example code

* typo

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  • Loading branch information
patil-suraj and patrickvonplaten authored Jun 30, 2021
1 parent 1ad1c4a commit 3f36a2c
Showing 1 changed file with 297 additions and 5 deletions.
302 changes: 297 additions & 5 deletions examples/research_projects/jax-projects/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ Don't forget to sign up [here](https://forms.gle/tVGPhjKXyEsSgUcs8)!
- [How to install flax, jax, optax, transformers, datasets](#how-to-install-relevant-libraries)
- [Quickstart Flax/JAX](#quickstart-flax-and-jax)
- [Quickstart Flax/JAX in 🤗 Transformers](#quickstart-flax-and-jax-in-transformers)
- [How to use flax models & scripts](#how-to-use-flax-models-and-example-scripts)
- [Flax design philosophy in 🤗 Transformers](#flax-design-philosophy-in-transformers)
- [How to use flax models & scripts](#how-to-use-flax-models-and-example-scripts)
- [How to make a demo for submission](#how-to-make-a-demo)
- [Talks](#talks)
- [How to setup TPU VM](#how-to-setup-tpu-vm)
Expand Down Expand Up @@ -383,13 +383,305 @@ official [flax example folder](https://github.com/huggingface/transformers/tree/
- [(TODO) CLIP pretraining, fine-tuning (CLIP)]( )


### How to use flax models and example scripts
### **Flax design philosophy in Transformers**

This section will explain how Flax models are implemented in Transformers and how the design differs from PyTorch.

Let's first go over the difference between Flax and PyTorch.

In JAX, most transformations (notably `jax.jit`) require functions that are transformed to be stateless so that they have no side effects. This is because any such side-effects will only be executed once when the transformed function is run during compilation and all subsequent calls of the compiled function would re-use the same side-effects of the compiled run instead of the "actual" side-effects (see [Stateful Computations in JAX](https://jax.readthedocs.io/en/latest/jax-101/07-state.html)). As a consequence, Flax models, which are designed to work well with JAX transformations, are stateless. This means that when running a model in inference, both the inputs and the model weights are passed to the forward pass. In contrast, PyTorch model are very much stateful with the weights being stored within the model instance and the user just passing the inputs to the forward pass.

Let's illustrate the difference between stateful models in PyTorch and stateless models in Flax.

For simplicity, let's assume the language model consists simply of a single attention layer [`key_proj`, `value_proj`, `query_proj`] and a linear layer `logits_proj` to project the transformed word embeddings to the output logit vectors.

#### **Stateful models in PyTorch**

In PyTorch, the weights matrices would be stored as `torch.nn.Linear` objects alongside the model's config inside the model class `ModelPyTorch`:

```python
class ModelPyTorch:

def __init__(self, config):
self.config = config
self.key_proj = torch.nn.Linear(config)
self.value_proj = torch.nn.Linear(config)
self.query_proj = torch.nn.Linear(config)
self.logits_proj = torch.nn.Linear(config)
```

Instantiating an object `model_pytorch` of the class `ModelPyTorch` would actually allocate memory for the model weights and attach them to the attributes `self.key_proj`, `self.value_proj`, `self.query_proj`, and `self.logits.proj`. We could access the weights via:

```
key_projection_matrix = model_pytorch.key_proj.weight.data
```

Visually, we would represent an object of `model_pytorch` therefore as follows:

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/lm_pytorch_def.png)

Executing a forward pass then simply corresponds to passing the `input_ids` to the object `model_pytorch`:

```python
sequences = model_pytorch(input_ids)
```

In a more abstract way, this can be represented as passing the word embeddings to the model function to get the output logits:

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/lm_pt_inference.png)

This design is called **stateful** because the output logits, the `sequences`, can change even if the word embeddings, the `input_ids`, stay the same. Hence, the function's output does not only depend on its inputs, but also on its **state**, `[self.key_proj, self.value_proj, self.query_proj, self.logits_proj]`, which makes `model_pytorch` stateful.

#### **Stateless models in Flax/JAX**

Now, let's see how the mathematically equivalent model would be written in JAX/Flax. The model class `ModelFlax` would define the self-attention and logits projection weights as [**`flax.linen.Dense`**](https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense) objects:

```python
class ModelFlax:

def __init__(self, config):
self.config = config
self.key_proj = flax.linen.Dense(config)
self.value_proj = flax.linen.Dense(config)
self.query_proj = flax.linen.Dense(config)
self.logits_proj = flax.linen.Dense(config)
```

At first glance the linear layer class `flax.linen.Dense` looks very similar to PyTorch's `torch.nn.Linear` class. However, instantiating an object `model_flax` only defines the linear transformation functions and does **not** allocate memory to store the linear transformation weights. In a way, the attribute `self.key_proj` tell the instantiated object `model_flax` to perform a linear transformation on some input and force it to expect a weight, called `key_proj`, as an input.

This time we would illustrate the object `model_flax` without the weight matrices:

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/lm_flax_def.png)


Accordingly, the forward pass requires both `input_ids` as well as a dictionary consisting of the model's weights (called `state` here) to compute the `sequences`:

To get the initial `state` we need to explicitly do a forward pass by passing a dummy input:

```python
state = model_flax.init(rng, dummy_input_ids)
```

and then we can do the forward pass.

```python
sequences = model_flax.apply(input_ids, state)
```

Visually, the forward pass would now be represented as passing all tensors required for the computation to the model's object:

![alt text](https://raw.githubusercontent.com/patrickvonplaten/scientific_images/master/lm_flax_inference.png)

This design is called **stateless** because the output logits, the `sequences`, **cannot** change if the word embeddings, the `input_ids`, stay the same. Hence, the function's output only depends on its inputs, being the `input_ids` and the `state` dictionary consisting of the weights **state**, `[key_proj, value_proj, query_proj, logits_proj]`.

Another term which is often used to describe the design difference between Flax/JAX and PyTorch is **immutable** vs **mutable**. A instantiated Flax model, `model_flax`, is **immutable** as a logical consequence of `model_flax`'s output being fully defined by its input: If calling `model_flax` could mutate `model_flax`, then calling `model_flax` twice with the same inputs could lead to different results which would violate the "*statelessness*" of Flax models.

#### **Flax models in Transformers**

Now let us see how this is handled in `Transformers.` If you have used a Flax model in Transformers already, you might wonder how come you don't always have to pass the parameters to the function of the forward pass. This is because the `FlaxPreTrainedModel` class abstracts it away.
It is designed this way so that the Flax models in Transformers will have a similar API to PyTorch and Tensorflow models.

The `FlaxPreTrainedModel` is an abstract class that holds a Flax module, handles weights initialization, and provides a simple interface for downloading and loading pre-trained weights i.e. the `save_pretrained` and `from_pretrained` methods. Each Flax model then defines its own subclass of `FlaxPreTrainedModel`; *e.g.* the BERT model has `FlaxBertPreTrainedModel`. Each such class provides two important methods, `init_weights` and `__call__`. Let's see what each of those methods do:

- The `init_weights` method takes the expected input shape and a [`PRNGKey`](https://jax.readthedocs.io/en/latest/_autosummary/jax.random.PRNGKey.html) (and any other arguments that are required to get initial weights) and calls `module.init` by passing it a random example to get the initial weights with the given `dtype` (for ex. `fp32` or `bf16` etc). This method is called when we create an instance of the model class, so the weights are already initialized when you create a model i.e., when you do

model = FlaxBertModel(config)

- The `__call__` method defines forward pass. It takes all necessary model inputs and parameters (and any other arguments required for the forward pass). The parameters are optional; when no parameters are passed, it uses the previously initialized or loaded parameters which can be accessed using `model.params`. It then calls the `module.apply` method, passing it the parameters and inputs to do the actual forward pass. So we can do a forward pass using

output = model(inputs, params=params)


Let's look at an example to see how this works. We will write a simple two-layer MLP model.

First, write a Flax module that will declare the layers and computation.

```python
import flax.linen as nn
import jax.numpy as jnp

class MLPMoudle(nn.Module):
config: MLPConfig
dtype: jnp.dtype = jnp.float32

def setup(self):
self.dense1 = nn.Dense(self.config.hidden_dim, dtype=self.dtype)
self.dense2 = nn.Desne(self.config.hidden_dim, dtype=self.dtype)

def __call__(self, inputs):
hidden_states = self.dense1(inputs)
hidden_states = nn.relu(hidden_states)
hidden_states = self.dense2(hidden_states)
return hidden_states
```

Now let's define the `FlaxPreTrainedModel` model class.

```python
from transformers.modeling_flax_utils import FlaxPreTrainedModel

class FlaxMLPPreTrainedModel(FlaxPreTrainedModel):
config_class = MLPConfig
base_model_prefix = "model"
module_class: nn.Module = None

def __init__(self, config: BertConfig, input_shape: Tuple = (1, 8), seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs):
# initialize the flax module
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)

def init_weights(self, rng, input_shape):
# init input tensors
inputs = jnp.zeros(input_shape, dtype="i4")

params_rng, dropout_rng = jax.random.split(rng)
rngs = {"params": params_rng, "dropout": dropout_rng}

params = self.module.init(rngs, inputs)["params"]
return params

def __call__(self, inputs, params: dict = None):
params = {"params": params or self.params}
outputs = self.module.apply(params, jnp.array(inputs))
return outputs
```


Now we can define our model class as follows.

```python
class FlaxMLPModel(FlaxMLPPreTrainedModel):
module_class = FlaxMLPModule
```

Now the `FlaxMLPModel` will have a similar interface as PyTorch or Tensorflow models and allows us to attach loaded or randomely initialized weights to the model instance.

So the important point to remember is that the `model` is not an instance of `nn.Moudle`; it's an abstract class, like a container that holds a Flax module, its parameters and provides convenient methods for initialization and forward pass. The key take-away here is that an instance of `FlaxMLPModel` is very much stateful now since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. Now to make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well to make it stateless and easier to work with during training. Feel free to take a look at the code to see how exactly this is implemented for ex. [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py#L536)

Another significant difference between Flax and PyTorch models is that, we can pass the `labels` directly to PyTorch's forward pass to compute the loss, whereas Flax models never accept `labels` as an input argument. In PyTorch, gradient backpropagation is performed by simply calling `.backward()` on the computed loss which makes it very handy for the user to be able to pass the `labels`. In Flax however, gradient backpropagation cannot be done by simply calling `.backward()` on the loss output, but the loss function itself has to be transformed by `jax.grad` or `jax.value_and_grad` to return the gradients of all parameters. This transformation cannot happen under-the-hood when one passes the `labels` to Flax's forward function, so that in Flax, we simply don't allow `labels` to be passed by design and force the user to implement the loss function her-/himself. As a conclusion, you will see that all training-related code is decoupled from the modeling code and always defined in the training scripts themselves.

### **How to use flax models and example scripts**


#### **How to do a forward pass**

Let's first see how to load, save and do inference with Flax models. As explained in the above section, all Flax models in Transformers have similar API to PyTorch models, so we can use the familiar `from_pretrained` and `save_pretrained` methods to load and save Flax models.

Let's use the base `FlaxRobertaModel` without any heads as an example.

```python
from transformers import FlaxRobertaModel, RobertaTokenizerFast
import jax

tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
inputs = tokenizer("JAX/Flax is amazing ", padding="max_length", max_length=128, return_tensors="np")

model = FlaxRobertaModel.from_pretrained("julien-c/dummy-unknown")

@jax.jit
def run_model(input_ids, attention_mask):
# run a forward pass, should return an object `FlaxBaseModelOutputWithPooling`
return model(input_ids, attention_mask)

outputs = run_model(**inputs)
```

We use `jax.jit` to compile the function to get maximum performance. Note that in the above example, we set `padding=max_length` to pad all examples to the same length. We do this because JAX's compiler has to recompile a function everytime its input shape changes - in a sense a compiled function is not only defined by its code but also by its input and output shape. It is usually much more effective to pad the input to be of a fixed static shape than having to recompile every the function multiple times.


#### **How to write a training loop**

Now let's see how we can write a simple training loop to train Flax models, we will use `FlaxGPT2ForCausalLM` as an example.

A training loop for Flax models typically consists of
- A loss function that takes the parameters and inputs, runs the forward pass and returns the loss.
- We then transform the loss function using `jax.grad` or `jax.value_and_grad` so that we get the gradients of all parameters.
- An optimizer to update the paramteres using the gradients returned by the transformed loss function.
- A train step function which combines the loss function and optimizer update, does the forward and backward pass and returns the updated parameters.

Lets see how that looks like in code:

First initialize our model

```python
import jax
import jax.numpy as jnp

from transformers import FlaxGPT2ForCausalLM

model = FlaxGPT2ForCausalLM(config)
```

As explained above we don't compute the loss inside the model, but rather in the task-specific training script.
For demonstration purposes, we write a pseude training script for causal lanuage modeling in the following.

```python
from flax.training.common_utils import onehot

def cross_entropy(logits, labels):
return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)

# define a function which will run the forward pass return loss
def compute_loss(params, input_ids, labels):
logits = model(input_ids, params=params, train=True)
loss = cross_entropy(logits, onehot(labels)).mean()
return loss
```

TODO (should be filled by 29.06.)
Now we transform the loss function with `jax.value_and_grad`.

```python
# transform the loss function to get the gradients
grad_fn = jax.value_and_grad(compute_loss)
```

We use the [optax](https://github.com/deepmind/optax) library to Initialize the optimizer.

```python
import optax

params = model.params
tx = optax.sgd(learning_rate=3e-3)
opt_state = tx.init(params)
```

Now we define a single training step which will do a forward and a backward pass.

```python
def _train_step(params, opt_state, input_ids, labels)
# do the forward pass and get the loss and gradients
loss, grads = grad_fn(params, input_ids, labels)

# use the gradients to update parameters
updates, opt_state = tx.update(grads, opt_state)
updated_params = optax.apply_updates(params, updates)

return updates_params, opt_state, loss

train_step = jax.jit(_train_step)
```

Finally, let's run our training loop.

```python
# train loop
for i in range(10):
params, opt_state, loss = train_step(params, opt_state, input_ids, labels)
```

Note how we always pass the `params` and `opt_state` to the `train_step` which then returns the updated `params` and `opt_state`. This is because of the staless nature of JAX/Flax models, all the state
like parameters, optimizer state is kept external.

We can now save the model with the trained parameters using

```python
model.save_pretrained("awesome-flax-model", params=params)
```

### Flax design philosophy in transformers
Note that, as JAX is backed by the [XLA](https://www.tensorflow.org/xla) compiler any JAX/Flax code can run on all `XLA` compliant device without code change!
That menas you could use the same training script on CPUs, GPUs, TPUs.

TODO (should be filled by 29.06.)
To know more about how to train the Flax models on different devices (GPU, multi-GPUs, TPUs) and use the example scripts, please look at the [examples README](https://github.com/huggingface/transformers/tree/master/examples/flax).

## How to make a demo

Expand Down

0 comments on commit 3f36a2c

Please sign in to comment.