
a possible hack for FSMT's SinusoidalPositionalEmbedding peculiarity #7229

Closed
stas00 opened this issue Sep 18, 2020 · 11 comments

@stas00
Contributor

stas00 commented Sep 18, 2020

(Since the normal CI doesn't run with USE_CUDA=1, I completely missed testing this, and found one issue with the torchscript tests that I need help with.)

We are talking about FSMT, the model ported from fairseq's transformer.

If I understand correctly, their SinusoidalPositionalEmbedding was designed so that it isn't part of the model params:
https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py#L25
most likely so that it won't end up in the state_dict and will save space in their already huge 3.3GB model dump (well, 13GB actually, as they use an ensemble of 4 models). I could be wrong about the reason for this design choice.

I had to copy their implementation, and not use Bart's version, since the pretrained weights rely on it, and the positions it produces are different.

So their SinusoidalPositionalEmbedding's self.weights is a plain variable (not a buffer and not an nn.parameter.Parameter). They create a dummy buffer self._float_tensor to hold the device, so when model.to() is called, self._float_tensor ends up on the right device. During forward, self.weights gets .to(self._float_tensor) and all is good. So self.weights is kind of a ghost variable: now you see me, and now you don't.
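To illustrate, the pattern looks roughly like this (a simplified sketch, not the actual fairseq code; names and the placeholder table are mine):

import torch
import torch.nn as nn

class GhostPositionalEmbedding(nn.Module):
    """Simplified sketch of the fairseq pattern: the table is a plain attribute."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        # Plain attribute: neither a Parameter nor a buffer, so it never shows up
        # in state_dict() and model.to() / model.half() do not touch it.
        self.weights = torch.zeros(num_positions, embedding_dim)
        # Dummy buffer whose only job is to track the module's device and dtype.
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # Lazily move the "ghost" table to wherever the module now lives.
        self.weights = self.weights.to(self._float_tensor)
        return self.weights.index_select(0, positions.view(-1))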

This approach works just fine until we get to torchscript - in particular 2 common tests:

    def test_torchscript_output_attentions(self):
    def test_torchscript_output_hidden_state(self):

which blow up under USE_CUDA=1, with:

Comparison exception:   Expected all tensors to be on the same device, 
but found at least two devices, cuda:0 and cpu!

Everything is on cuda:0, but SinusoidalPositionalEmbedding's self.weights is still on cpu at this point.

The first time torchscript encounters self.weights inside forward, before it gets a chance to be moved to the device, it blows up. It wants all variables to be on the same device before forward runs.

Solution 1

So, I solved this problem with the following hack:

class FSMTForConditionalGeneration(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge the call down to the base model
        self.base_model.to(*args, **kwargs)
        return self

class FSMTModel(PretrainedFSMTModel):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # bridge the call down to the two positional embedding modules
        self.encoder.embed_positions.to(*args, **kwargs)
        self.decoder.embed_positions.to(*args, **kwargs)
        return self

class SinusoidalPositionalEmbedding(nn.Module):
    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # finally move the plain (non-parameter, non-buffer) tensor by hand
        self.weights = self.weights.to(*args, **kwargs)
        return self

It's absolutely crazy, but it works.

Basically it forwards model.to() call to SinusoidalPositionalEmbedding's self.weights, via 3 "bridges".

I thought that to() got called on each torch module, but that doesn't seem to be the case: nn.Module.to() traverses the module tree internally (via _apply) and converts only registered parameters and buffers, without calling each submodule's to(). Hence the 2 extra classes are involved to bridge it on.

(and there is also half() that needs to be dealt with too, since model.half() won't get forwarded to this non-parameter variable either.)
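A tiny illustration of that difference (assuming a CUDA device is available; the class and names are just for demonstration):

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.zeros(1))    # converted by .to() / .half()
        self.register_buffer("buf", torch.zeros(1))  # converted by .to() / .half()
        self.plain = torch.zeros(1)                  # skipped: just a python attribute

m = Demo().to("cuda")
print(m.param.device, m.buf.device, m.plain.device)  # cuda:0 cuda:0 cpu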

Solution 2

The second solution is to make SinusoidalPositionalEmbedding's self.weights a parameter, but then we have to hack save/load so that the model.encoder.embed_positions.* and model.decoder.embed_positions.* keys are not saved and are ignored on load.

Solution 3

The third solution is to save the useless weights (useless as they aren't trained and get calculated deterministically).

Perhaps you can think of other solutions.

Thank you.

@sgugger, @patrickvonplaten, @sshleifer, @LysandreJik

@sshleifer
Contributor

sshleifer commented Sep 18, 2020

Pegasus Strategy

  1. SinusoidalPositionalEmbedding should inherit from nn.Embedding. This will fix your to problem, as the weight will be an actual parameter (see the sketch after this list).
  2. Put

     authorized_missing_keys = ["model.encoder.embed_positions", "model.decoder.embed_positions"]

     on the user-facing class to avoid warnings.
  3. If somebody calls save_pretrained again, they will save the weights, but they are pretty small so who cares, and it won't error on from_pretrained. I also don't expect a lot of people to retrain FSMT at the moment.
  4. we wouldn't have these problems if we shared code ;)
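Roughly, item 1 would look something like this (a sketch along the lines of the Pegasus/Bart sinusoidal embedding, not the final FSMT code):

import numpy as np
import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Embedding):
    """The table is filled in deterministically in __init__ and never trained."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__(num_positions, embedding_dim)
        self.weight = self._init_weight(self.weight)

    @staticmethod
    def _init_weight(out: nn.Parameter) -> nn.Parameter:
        n_pos, dim = out.shape
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        out.requires_grad = False  # deterministic table, gradients are never needed
        sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
        out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
        out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
        out.detach_()
        return out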

@stas00
Contributor Author

stas00 commented Sep 18, 2020

Thank you for your feedback, @sshleifer

> SinusoidalPositionalEmbedding should inherit from nn.Embedding. This will fix your to problem, as the weight will be an actual parameter.

True. Another easy fix is to:

+ from torch.nn.parameter import Parameter
- self.weights = SinusoidalPositionalEmbedding.get_embedding(init_size, embedding_dim, padding_idx)
+ self.weights = Parameter(SinusoidalPositionalEmbedding.get_embedding(init_size, embedding_dim, padding_idx))

but yes, re-making it into a nn.Embedding sub-class would be cleaner.

> If somebody calls save_pretrained again, they will save the weights, but they are pretty small so who cares, and it won't error on from_pretrained.

I have just measured: it's 250MB (125MB each for the encoder and decoder). Far from small.

state_dict = model.state_dict()
torch.save(state_dict["model.encoder.embed_positions.weights"], "output")
-rw-rw-r-- 1 stas stas 123M Sep 17 22:47 output

I looked into overloading save_pretrained, but it works directly with the model's state_dict, so it would have to be hacked to make a copy of the state_dict, remove these weights and then forward to super().save_pretrained() to save.
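Something along these lines, I imagine (a hypothetical sketch; the key names are illustrative and not what ended up in the PR):

import torch
import torch.nn as nn

SKIP_KEYS = ("model.encoder.embed_positions.weights", "model.decoder.embed_positions.weights")

def filtered_state_dict(model: nn.Module) -> dict:
    # copy the state_dict and drop the deterministic position tables before saving
    return {k: v for k, v in model.state_dict().items() if not k.startswith(SKIP_KEYS)}

# torch.save(filtered_state_dict(model), "pytorch_model.bin")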

> I also don't expect a lot of people to retrain FSMT at the moment.

Ha, if only it were so: #7228 :) I guess one is not many, but it didn't take long.

> we wouldn't have these problems if we shared code ;)

Agreed!

@sshleifer
Contributor

sshleifer commented Sep 18, 2020

wow bad math on my part.
feel free to do either solution in your most recent post.

@stas00
Contributor Author

stas00 commented Sep 18, 2020

> wow bad math on my part.

I had no clue it was so big.

Actually, that reinforces my guess at why they tried to hide this embedding from the state_dict. It's an interesting concept; perhaps one day, in addition to buffers and params, we will have a third type of nn.Module variable that behaves like a param but automatically doesn't get saved/loaded.

I suggested this feature at pytorch: pytorch/pytorch#44935

@sgugger
Collaborator

sgugger commented Sep 18, 2020

Another hack, different from solution 1, is to check the device of the inputs in the forward pass and move the matrices to the right device if needed.
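i.e. something like this (a rough sketch; the class and names are illustrative):

import torch
import torch.nn as nn

class LazySinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        self.weights = torch.zeros(num_positions, embedding_dim)  # plain attribute

    def forward(self, positions: torch.Tensor) -> torch.Tensor:
        # sync the cached table with the device of the incoming tensor
        if self.weights.device != positions.device:
            self.weights = self.weights.to(positions.device)
        return self.weights.index_select(0, positions.view(-1))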

@stas00
Contributor Author

stas00 commented Sep 18, 2020

> Another hack, different from solution 1, is to check the device of the inputs in the forward pass and move the matrices to the right device if needed.

Thank you for the idea, @sgugger.

Alas, if I understood your suggestion correctly, I already tried it and it doesn't work. torchscript wants the vars to be on the same device before any forward call.

@stas00
Contributor Author

stas00 commented Sep 18, 2020

Via my feature request, @mruberry helped to highlight a new pytorch feature: non-persistent buffers.
pytorch/pytorch#44935 (comment)
https://pytorch.org/docs/master/generated/torch.nn.Module.html?highlight=buffer#torch.nn.Module.register_buffer

> register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None
> Adds a buffer to the module.
> This is typically used to register a buffer that should not be considered a model parameter. For example, BatchNorm’s running_mean is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.
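For this module that would look something like the following (a sketch with a placeholder table, assuming a recent enough torch):

import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        # moved by model.to() / model.half() like any buffer, but excluded from state_dict()
        table = torch.zeros(num_positions, embedding_dim)  # placeholder for the sinusoidal table
        self.register_buffer("weights", table, persistent=False)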

I tested it and it works for normal (non-torchscript) usage.

But unfortunately:

  1. it was added only recently, and I don't think transformers will be willing to require such a recent minimal torch version
  2. it doesn't work with torchscript (yet): the latter saves the non-persistent buffer keys, which it shouldn't

So for now, making the variable a normal parameter, modifying save_pretrained to skip those keys, and having from_pretrained ignore them seems to be the most solid approach.

@stas00
Contributor Author

stas00 commented Sep 19, 2020

Implemented the discussed changes here: #7224

@LysandreJik
Member

We actually do use buffers in our code already, as you can see here.

This method was at least present in torch version 1.0.1 as it is visible in the docs for that version, and torch 1.0.1 is the minimal requirement for transformers.

@stas00
Contributor Author

stas00 commented Sep 21, 2020

Indeed, buffers have been around for a long time, but the need here is different. We want a non-persistent buffer, functionality that was added only a few months ago and doesn't yet work with torchscript, so it doesn't help solve the problem at hand.

@stas00
Contributor Author

stas00 commented Sep 22, 2020

Resolved by #7224
