When training large models, the amount of available memory quickly becomes the limiting factor. Gradient checkpointing is a technique that reduces memory usage during training at the cost of a modest increase in training time. It was introduced by Chen et al. (2016). There is also a blog post about it with nice visualizations from OpenAI.
When gradients are computed, the activations from the forward pass are needed. Let's think about a layer in a neural network which takes the input $x$ (the activation of the previous layer) and computes the activation $y = f(x, w)$. During backpropagation we need $\frac{\partial L}{\partial w} = \frac{\partial L}{\partial y}\,\frac{\partial f}{\partial w}$, and $\frac{\partial f}{\partial w}$ is evaluated at the input $x$, so here we need to know the activation feeding into our layer. Therefore, to compute gradients efficiently, we would need to store the activations of all layers for all inputs.
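A tiny sketch of why this is the case (a made-up toy example, not from the paper): the gradient of a linear layer's weight is `grad_output @ x.T`, so autograd has to keep the input `x` (and, here, the ReLU mask) alive until `backward()` runs.

```python
import torch

x = torch.randn(8, 16)                       # input activations from the previous layer
w = torch.randn(32, 16, requires_grad=True)  # this layer's weights
y = torch.relu(x @ w.t())                    # forward pass; autograd saves x and the ReLU mask
y.sum().backward()                           # backward pass reuses the saved tensors
print(w.grad.shape)                          # torch.Size([32, 16])
```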
Gradient checkpointing strikes a compromise between speed and memory by selecting which activations are kept in memory and which are recomputed on the fly when backpropagation happens.
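A minimal sketch of this trade-off using PyTorch's built-in `torch.utils.checkpoint.checkpoint_sequential` (the model below is a made-up stack of identical blocks): only the activations at the segment boundaries are stored, and everything inside a segment is recomputed during the backward pass.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Made-up model: 8 identical blocks, split into 4 checkpointed segments.
blocks = [nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()) for _ in range(8)]
model = nn.Sequential(*blocks)
x = torch.randn(32, 1024)

# Only the inputs of the 4 segments are kept; activations inside each segment
# are recomputed during backward. use_reentrant=False is the recommended
# (non-reentrant) variant in recent PyTorch versions.
out = checkpoint_sequential(model, 4, x, use_reentrant=False)
out.sum().backward()
```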
Read the paper/blog. Not yet needed.
In PyTorch it is as easy as the following code, taken from the HuggingFace implementation of RoBERTa:
```python
class RobertaEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        ...
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        ...
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        ...
        for i, layer_module in enumerate(self.layer):
            ...
            if self.gradient_checkpointing and self.training:
                # Wrap the layer so the non-tensor arguments are captured in a
                # closure and only tensors are passed through checkpoint().
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                # Run the layer without storing its internal activations;
                # they are recomputed during the backward pass.
                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                ...
```
In the code above, `torch.utils.checkpoint.checkpoint(func, *args)` ensures that the computation inside `func(*args)` does not keep its intermediate results. This means that during the backward pass only the activations entering each layer are stored, and all activations inside a single `RobertaLayer` (self-attention + FFN) need to be recomputed.
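For completeness, a short sketch (assuming a recent version of the transformers library): the `gradient_checkpointing` flag used above does not have to be set by hand, since `PreTrainedModel` exposes `gradient_checkpointing_enable()`. Note that, because of the `if self.gradient_checkpointing and self.training` check, checkpointing only takes effect in training mode.

```python
from transformers import RobertaModel

model = RobertaModel.from_pretrained("roberta-base")
model.gradient_checkpointing_enable()  # sets gradient_checkpointing on the supported submodules
model.train()                          # checkpointing is only applied when self.training is True
```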