T5 Gradient Checkpointing #6564
Comments
Also pinging @LysandreJik for notification in case this is easy to implement
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Keep it alive :)
That's an important feature indeed! Will try to tackle this with @LysandreJik @VictorSanh in the new year :-)
Hi, I'm not too familiar with T5 internals but I crudely tried modifying
Hey @ssss1029, Thanks for playing around with the feature! Would you mind using your code to open a PR? I'll help you get it merged. It is very well possible that we might have to change some more code in T5 to make it work. Ideally, I'd try to base the T5 gradient checkpointing code as much as possible on how Bart does it.
Lots of people have been asking for T5 checkpointing, so your PR would be a great contribution if you want to give it a try :-)
Hi Patrick, unfortunately, I'm pretty new to Huggingface internals and I won't have the bandwidth to implement this. |
@patrickvonplaten @ssss1029

1. `checkpoint.CheckpointFunction`

```python
# Modified copy of torch/utils/checkpoint.py; the helpers below come from that module.
import torch
from torch.utils.checkpoint import (
    check_backward_validity,
    detach_variable,
    get_device_states,
    set_device_states,
)


class CheckpointFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        # return outputs
        #
        # Lie to torch that we have no None items, to avoid the assert
        #
        result = []
        for o in outputs:
            if o is None:
                o = torch.zeros(0).cuda()
            result.append(o)
        return tuple(result)

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        #
        # Skip None items and tensors whose requires_grad is False when doing backward
        #
        backward_outputs = []
        backward_args = []
        for o, a in zip(outputs, args):
            if o is not None and o.requires_grad:
                backward_outputs.append(o)
                backward_args.append(a)
        torch.autograd.backward(backward_outputs, backward_args)
        # torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads
```

2. `checkpoint.checkpoint()`

```python
def checkpoint(function, *args, **kwargs):
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
    outputs = CheckpointFunction.apply(function, preserve, *args)
    #
    # Restore None items to the result
    #
    result = []
    for o in outputs:
        if len(o) == 0:
            o = None
        result.append(o)
    return tuple(result)
```

3. `modeling_t5.T5Stack.forward()`, just the common way

```python
if getattr(self.config, "gradient_checkpointing", False):

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return tuple(module(*inputs, use_cache, output_attentions))

        return custom_forward

    layer_outputs = checkpoint(
        create_custom_forward(layer_module),
        hidden_states,
        extended_attention_mask,
        position_bias,
        encoder_hidden_states,
        encoder_extended_attention_mask,
        encoder_decoder_position_bias,
        head_mask[i],
        past_key_value,
    )
else:
    layer_outputs = layer_module(
        hidden_states,
        attention_mask=extended_attention_mask,
        position_bias=position_bias,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        encoder_decoder_position_bias=encoder_decoder_position_bias,
        head_mask=head_mask[i],
        past_key_value=past_key_value,
        use_cache=use_cache,
        output_attentions=output_attentions,
    )
```
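A minimal sketch of how the patched `checkpoint()` above would be exercised, assuming a CUDA device (the patch substitutes `torch.zeros(0).cuda()` for `None` outputs); `toy_block` is just a made-up stand-in for a T5 block:

```python
# Minimal exercise of the patched checkpoint() from sections 1-2 above.
# toy_block is a hypothetical stand-in for a T5 block; assumes a CUDA device.
import torch


def toy_block(x):
    # Mimics a block that returns a real tensor plus a None placeholder
    # (e.g. a missing present_key_value_state when use_cache=False).
    return x * 2, None


x = torch.randn(4, 8, device="cuda", requires_grad=True)

out, maybe_none = checkpoint(toy_block, x)  # patched checkpoint from section 2
assert maybe_none is None                   # the None item is restored on the way out

out.sum().backward()                        # forward is recomputed, gradients reach x
print(x.grad.shape)
```

Note that the `len(o) == 0` check in the patched `checkpoint()` means any genuinely empty tensor output would also be turned back into `None`, which is one caveat of this approach.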
Hi @patrickvonplaten, glad to hear it's helpful! But I have two worries about the integration:
Hi @xFinal, `TypeError('CheckpointFunctionBackward.forward: expected Tensor or tuple of Tensor (got tuple) for return value 1')`
May I ask which PyTorch version you used?
@dwaydwaydway, |
This issue has been stale for 1 month. |
Inspired by @xFinal's solution, I implemented another workaround that doesn't require modifying `torch.utils.checkpoint`. It seems to work, but my tests might not be comprehensive enough.
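One possible shape for such a workaround, as a rough sketch under the assumption that the `None` outputs are the main problem (the `checkpoint_skip_none` helper below is made up for illustration and is not necessarily what the eventual PR does): keep `torch.utils.checkpoint` untouched and let only tensors cross the checkpoint boundary, restoring the `None` items afterwards.

```python
# Rough sketch of a wrapper that avoids patching torch.utils.checkpoint.
# checkpoint_skip_none is a hypothetical helper, not the actual PR code.
from torch.utils.checkpoint import checkpoint


def checkpoint_skip_none(module, *inputs):
    none_positions = []

    def custom_forward(*args):
        # Runs twice (forward pass and recomputation in backward), so reset the bookkeeping.
        none_positions.clear()
        outputs = module(*args)
        tensors = []
        for i, o in enumerate(outputs):
            if o is None:
                none_positions.append(i)
            else:
                tensors.append(o)
        # Only tensors are handed to torch.utils.checkpoint.
        return tuple(tensors)

    tensor_outputs = checkpoint(custom_forward, *inputs)
    outputs = list(tensor_outputs)
    for i in none_positions:
        outputs.insert(i, None)
    return tuple(outputs)
```

This only addresses `None` items; on older PyTorch versions, tensor outputs that do not require grad can still trip the stock `checkpoint` backward, which is part of what the discussion above works around.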
Hey @ceshine - do you mind opening a PR for it? :-) |
Not at all. I'll open a PR after a bit more polishing. |
@ceshine that's great! :) |
🚀 Feature request
Currently, only BERT supports gradient checkpointing, which allows the model to be fine-tuned on GPUs with small memory.
It would be great to make T5 support gradient checkpointing as well.
Code: `src/transformers/modeling_bert.py`, line 461 at commit `0735def`
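For context, the pattern that block implements is roughly the following self-contained sketch (a paraphrase of the approach, not the literal transformers code; `TinyEncoder` is an invented toy module): each layer call is wrapped in `torch.utils.checkpoint.checkpoint`, so intermediate activations are recomputed during the backward pass instead of being stored.

```python
# Toy illustration of the checkpointing pattern used in BERT-style encoders.
# TinyEncoder is a made-up module, not part of transformers.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyEncoder(nn.Module):
    def __init__(self, hidden=64, num_layers=4, gradient_checkpointing=True):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.gradient_checkpointing = gradient_checkpointing

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Activations inside `layer` are not stored; they are recomputed
                # during backward, trading compute for memory.
                hidden_states = checkpoint(layer, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


model = TinyEncoder().train()
x = torch.randn(8, 64, requires_grad=True)
model(x).sum().backward()
```

The memory saving grows with the number of layers, at the cost of roughly one extra forward pass during backward.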
Motivation
T5 has very large variants with 3B and 11B parameters, which makes them impossible to fine-tune on most GPUs. Gradient checkpointing would allow these huge models to be fine-tuned on GPUs. This would lead to much better results on downstream tasks using in-house GPUs, without the need to fine-tune on TPUs.
Your contribution
If I am not mistaken, all that needs to be changed is the following block:
https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_t5.py#L752
@patrickvonplaten thanks in advance for looking into it.