
T5 Gradient Checkpointing #6564

Closed
agemagician opened this issue Aug 18, 2020 · 18 comments · Fixed by #11353
Labels
Good Second Issue: Issues that are more difficult to do than "Good First" issues - give it a try if you want!

Comments

@agemagician
Contributor

🚀 Feature request

Currently, only Bert supports gradient checkpointing, which allows the model to be fine-tuned on GPUs with limited memory.
It would be great for T5 to support gradient checkpointing as well.

Code:

if getattr(self.config, "gradient_checkpointing", False):
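
For reference, a minimal sketch of how that flag is used on the user side today, assuming the current config-flag API that Bert exposes (the model name is just an example):

from transformers import BertConfig, BertModel

# Sketch: set the flag on the config before loading the model; the modeling
# code then routes each layer's forward pass through
# torch.utils.checkpoint.checkpoint instead of calling the layer directly.
config = BertConfig.from_pretrained("bert-base-uncased")
config.gradient_checkpointing = True
model = BertModel.from_pretrained("bert-base-uncased", config=config)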

Motivation

T5 comes in very large variants with 3B and 11B parameters, which makes it impossible to fine-tune them on most GPUs. Gradient checkpointing would allow these huge models to be fine-tuned on GPUs. This would lead to much better results on downstream tasks using in-house GPUs, without the need to fine-tune them on TPUs.

Your contribution

If I am not mistaken, all that needs to be changed is the following block:
https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_t5.py#L752

for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    head_mask[i],
                    past_key_value_state,
                    use_cache,
                    output_attentions,
                )

            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    head_mask=head_mask[i],
                    past_key_value_state=past_key_value_state,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )
                # layer_outputs is a tuple with:
                # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
            hidden_states, present_key_value_state = layer_outputs[:2]

            if i == 0:
                # We share the position biases between the layers - the first layer stores them
                # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
                position_bias = layer_outputs[3 if output_attentions else 2]
                if self.is_decoder and encoder_hidden_states is not None:
                    encoder_decoder_position_bias = layer_outputs[5 if output_attentions else 3]
            # append next layer key value states
            if use_cache:
                present_key_value_states = present_key_value_states + (present_key_value_state,)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)  # We keep only self-attention weights for now

@patrickvonplaten thanks in advance for looking into it.

patrickvonplaten self-assigned this Aug 18, 2020
patrickvonplaten changed the title from "T5 Checkpointing" to "T5 Gradient Checkpointing" Sep 20, 2020
@patrickvonplaten
Contributor

Also pinging @LysandreJik for notification in case this is easy to implement

@stale

stale bot commented Dec 24, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Dec 24, 2020
@agemagician
Contributor Author

Keep it alive :)

stale bot removed the wontfix label Dec 24, 2020
@patrickvonplaten
Contributor

That's an important feature indeed! Will try to tackle this with @LysandreJik @VictorSanh in the new year :-)

@ssss1029

ssss1029 commented Jan 8, 2021

Hi, I'm not too familiar with T5 internals, but I crudely tried modifying modeling_t5.py as the OP suggested and ran into some issues with unsupported return values for torch.utils.checkpoint.checkpoint, so it seems something else besides that block needs changing?

  File "/data/sauravkadavath/miniconda3/envs/transformers-4.0.0/lib/python3.7/site-packages/torch/utils/checkpoint.py", line 163, in checkpoint
    return CheckpointFunction.apply(function, preserve, *args)
TypeError: CheckpointFunctionBackward.forward: expected Tensor or tuple of Tensor (got NoneType) for return value 1
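
The error seems to come from torch.utils.checkpoint only accepting tensors in the checkpointed function's return value, while T5Block returns tuples that can contain None entries (e.g. no present key/value states when use_cache=False). A minimal sketch of that failure mode, assuming a PyTorch version from around that time (1.6/1.7); block_like is just a stand-in, not real T5 code:

import torch
from torch.utils.checkpoint import checkpoint

def block_like(x):
    # Stand-in for a T5Block whose output tuple contains a None entry,
    # e.g. no present_key_value_state when use_cache=False.
    return x * 2, None

x = torch.randn(2, 4, requires_grad=True)
try:
    out = checkpoint(block_like, x)
except TypeError as err:
    # On PyTorch around 1.6/1.7 this should raise the same kind of error as the
    # traceback above ("expected Tensor or tuple of Tensor (got NoneType)").
    print(err)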

@patrickvonplaten
Contributor

Hey @ssss1029,

Thanks for playing around with the feature! Would you mind opening a PR with your code? I'll help you get it merged. It is quite possible that we will have to change some more code in T5 to make it work. Ideally, I'd try to base the T5 gradient checkpointing code as much as possible on how Bart does it.

@patrickvonplaten
Contributor

Lots of people have been asking for T5 checkpointing, so your PR would be a great contribution if you want to give it a try :-)

@ssss1029

ssss1029 commented Jan 11, 2021

Hi Patrick, unfortunately, I'm pretty new to Huggingface internals and I won't have the bandwidth to implement this.

@xFinal

xFinal commented Jan 14, 2021

@patrickvonplaten @ssss1029
Just a straightforward workaround, not meant as a PR.
I modified the torch.utils.checkpoint file to overcome its limitation. See the code below; all the modifications are marked with comments.
Training with t5-base, I observe the loss dropping just as it does with gradient_checkpointing off, and memory usage drops as well. But I don't have time to do a full verification right now.

1. checkpoint.CheckpointFunction

class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:
                ctx.had_cuda_in_fwd = True
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        # return outputs

        #
        # Lie to torch that there are no None items, to avoid the assert
        #
        result = []
        for o in outputs:
            if o is None:
                o = torch.zeros(0).cuda()
            result.append(o)

        return tuple(result)

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        
        #
        # Skip None items and tensors whose requires_grad is False when doing the backward pass
        #
        backward_outputs = []
        backward_args = []
        for o, a in zip(outputs, args):
            if o is not None and o.requires_grad:
                backward_outputs.append(o)
                backward_args.append(a)
        torch.autograd.backward(backward_outputs, backward_args)

        # torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        return (None, None) + grads

2. checkpoint.checkpoint()

def checkpoint(function, *args, **kwargs):
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    outputs = CheckpointFunction.apply(function, preserve, *args)

    #
    # Restore None items in the result
    #
    result = []
    for o in outputs:
        if len(o) == 0:
            o = None
        result.append(o)

    return tuple(result)

3. modeling_t5.T5Stack.forward(), just the common way

            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return tuple(module(*inputs, use_cache, output_attentions))

                    return custom_forward

                layer_outputs = checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    extended_attention_mask,
                    position_bias,
                    encoder_hidden_states,
                    encoder_extended_attention_mask,
                    encoder_decoder_position_bias,
                    head_mask[i],
                    past_key_value,
                )

            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask=extended_attention_mask,
                    position_bias=position_bias,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_extended_attention_mask,
                    encoder_decoder_position_bias=encoder_decoder_position_bias,
                    head_mask=head_mask[i],
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

@patrickvonplaten
Contributor

Hey @xFinal ,

Your 3rd approach is definitely the one we'd be super happy to integrate into Transformers. Thanks a mille for your contribution already. If anyone in the community wants to give it a shot and add @xFinal's 3rd proposed solution to modeling_t5.py, that would be awesome :-)

patrickvonplaten added the Good Second Issue label Jan 14, 2021
@xFinal

xFinal commented Jan 14, 2021

Hi @patrickvonplaten ,

Glad to hear it's helpful! But I have two worries about the integration:

  1. It has not been tested with a full training run yet.
  2. The current approach modifies the torch.utils.checkpoint file, which is part of PyTorch, so it is probably not suitable for integration. Maybe there is a more elegant way, like adjusting T5 itself?

@dwaydwaydway

Hi @xFinal ,
I tried your solution and got the following error:

TypeError('CheckpointFunctionBackward.forward: expected Tensor or tuple of Tensor (got tuple) for return value 1')

/share/home/dwaydwaydway/t5/src/transformers/src/transformers/models/t5/modified_gradient_ckpt.py(124)checkpoint()
    123
--> 124 outputs = CheckpointFunction.apply(function, preserve, *args)
    125

May I ask which PyTorch version you used?

@xFinal

xFinal commented Jan 26, 2021

@dwaydwaydway,
The version is 1.7.1.
Make sure you return a tuple from CheckpointFunction.forward().

@github-actions

github-actions bot commented Mar 6, 2021

This issue has been stale for 1 month.

@ceshine
Contributor

ceshine commented Apr 7, 2021

Inspired by @xFinal's solution, I implemented another workaround that doesn't require modifying the Checkpoint class (by returning a dummy Tensor instead of None in T5Block.forward).

It seems to work, but my tests might not be comprehensive enough.
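
For anyone curious, the rough idea looks something like this (an illustrative sketch with made-up helper names, not necessarily what the final PR will look like): T5Block.forward substitutes an empty tensor for every None in its output tuple so that torch.utils.checkpoint only ever sees tensors, and T5Stack turns the placeholders back into None after the checkpointed call.

import torch

def hide_nones(outputs):
    # Inside T5Block.forward: replace None entries (e.g. absent key/value
    # states or attention weights) with an empty dummy tensor.
    return tuple(torch.tensor([]) if o is None else o for o in outputs)

def restore_nones(outputs):
    # Inside T5Stack.forward, right after torch.utils.checkpoint.checkpoint:
    # convert the empty-tensor placeholders back into None.
    return tuple(
        None if isinstance(o, torch.Tensor) and o.numel() == 0 else o
        for o in outputs
    )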

@patrickvonplaten
Contributor

Hey @ceshine - do you mind opening a PR for it? :-)

@ceshine
Contributor

ceshine commented Apr 21, 2021

> Hey @ceshine - do you mind opening a PR for it? :-)

Not at all. I'll open a PR after a bit more polishing.

@xFinal

xFinal commented Apr 21, 2021

@ceshine that's great! :)
