a possible hack for FSMT's SinusoidalPositionalEmbedding peculiarity #7229
Comments
Pegasus Strategy
on the user facing class to avoid warnings.
Thank you for your feedback, @sshleifer
True. Another easy fix is to:
but yes, re-making it into a
I have just measured - it's 250MB (125MB for each of encoder/decoder). Far from being small.
I looked into overloading
Ha, if only it were so: #7228 :) I guess one is not many, but it didn't take long.
Agreed!
Wow, bad math on my part.
I had no clue it was so big. Actually, that reinforces my guess at why they tried to hide this embedding from the `state_dict`. I suggested this feature at pytorch: pytorch/pytorch#44935
Another hack, different from solution 1, is to check the device of the inputs in the forward pass and maybe move the matrices to the right device.
Thank you for the idea, @sgugger. Alas, if I understood your suggestion correctly, I already tried it and it doesn't work. torchscript wants the vars to be on the same device before any `forward` is run.
Via my feature request, @mruberry helped to highlight a new pytorch feature - a non-persistent buffer.
I tested it and it works for normal functions. But unfortunately it doesn't work with torchscript yet.
So for now it seems that making the variable a normal parameter and modifying save/load is the way to go.
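For reference, a minimal sketch of how the non-persistent buffer feature is used (assuming a pytorch version that supports the `persistent` argument of `register_buffer`):

```python
import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        weights = torch.zeros(num_positions, embedding_dim)  # placeholder for the sinusoidal table
        # persistent=False: the buffer is moved by model.to()/model.half(),
        # but is excluded from state_dict(), so it never gets saved
        self.register_buffer("weights", weights, persistent=False)
```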
Implemented the discussed changes here: #7224
We actually do use buffers in our code already, as you can see here. This method was at least present in torch version 1.0.1, as is visible in the docs for that version, and torch 1.0.1 is the minimal requirement for `transformers`.
Indeed, buffers have been around for a long time, but the need here is different. We want a non-persistent buffer, a functionality which was added just a few months ago and doesn't yet work with torchscript, so it doesn't help to solve the problem at hand.
Resolved by #7224
(With normal CI not running USE_CUDA=1 I completely missed testing this, so I found one issue with the torchscript tests that I need help with.)
We are talking about FSMT - the ported fairseq transformers model.

If I understand correctly, their `SinusoidalPositionalEmbedding` was designed so that it won't be part of the model params (https://github.com/pytorch/fairseq/blob/master/fairseq/modules/sinusoidal_positional_embedding.py#L25), most likely so that it won't be part of the `state_dict`, to save space in their already huge 3.3GB model dump (well, 13GB actually, as they use an ensemble of 4 models). I could be wrong about the reason for this design choice.

I had to copy their implementation, and not use Bart's version, since the pretrained weights rely on it and the positions it produces are different.
So their `SinusoidalPositionalEmbedding`'s `self.weights` is a normal variable (not a buffer and not an `nn.parameter.Parameter`). They create a dummy buffer `self._float_tensor` to hold the `device`. So when `model.to()` is called, `self._float_tensor` gets the right `device`. During `forward`, `self.weights` gets `to(self._float_tensor)` and all is good. So `self.weights` is kind of a ghost variable: now you see me and now you don't.
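A minimal sketch of this pattern (illustrative only; the real fairseq/FSMT code builds the actual sinusoidal table):

```python
import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        # plain attribute: not a parameter, not a buffer, so it is invisible to
        # state_dict(), model.to() and model.half()
        self.weights = torch.zeros(num_positions, embedding_dim)  # placeholder table
        # dummy buffer whose only job is to record the device/dtype the model was moved to
        self.register_buffer("_float_tensor", torch.FloatTensor(1))

    def forward(self, positions):
        # move the "ghost" weights to wherever the dummy buffer ended up
        self.weights = self.weights.to(self._float_tensor)
        return self.weights[positions]
```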
This approach works just fine until we get to `torchscript` - in particular 2 common tests, which blow up under `USE_CUDA=1`: everything is on `cuda:0`, but `SinusoidalPositionalEmbedding`'s `self.weights` is still on `cpu` at that point. The first time torchscript encounters `self.weights` inside `forward`, before it gets a chance to be moved to the device, it blows up. It wants all variables to be on the same device before `forward`.

**Solution 1**
So, I solved this problem with the following hack:

It's absolutely crazy, but it works. Basically it forwards the `model.to()` call to `SinusoidalPositionalEmbedding`'s `self.weights`, via 3 "bridges". I thought that each torch module got `to()` called, but that doesn't seem to be the case; I think it traverses the model structure instead and doesn't call `to()` for each module. Hence the 2 classes are involved to bridge it on.

(And there is also `half()` that needs to be dealt with too, since `model.half()` won't get forwarded to this non-parameter variable either.)
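To make the idea concrete, here is a rough, hypothetical sketch of the bridging approach (not the actual FSMT code; the class and helper names are illustrative):

```python
import torch
import torch.nn as nn

class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, num_positions, embedding_dim):
        super().__init__()
        self.weights = torch.zeros(num_positions, embedding_dim)  # plain tensor, ignored by .to()/.half()

    def move_weights(self, *args, **kwargs):
        # explicitly move/convert the plain tensor
        self.weights = self.weights.to(*args, **kwargs)

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_positions = SinusoidalPositionalEmbedding(1024, 512)

    def to(self, *args, **kwargs):
        # bridge: forward the call to the embedding's plain tensor
        self.embed_positions.move_weights(*args, **kwargs)
        return super().to(*args, **kwargs)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()

    def to(self, *args, **kwargs):
        # bridge: nn.Module.to() does NOT call .to() on submodules (it recurses
        # via _apply over parameters/buffers), so the call has to be forwarded by hand
        self.encoder.to(*args, **kwargs)
        return super().to(*args, **kwargs)

    def half(self):
        # half() needs the same treatment, since it won't touch the plain tensor either
        self.encoder.embed_positions.move_weights(dtype=torch.float16)
        return super().half()
```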
**Solution 2**

The second solution is to make `SinusoidalPositionalEmbedding`'s `self.weights` a parameter, but then we have to hack save/load to not save, and to ignore on load, the `model.encoder.embed_positions.*` and `model.decoder.embed_positions.*` keys.
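A rough sketch of what such save/load filtering could look like (hypothetical helper names, not the actual implementation):

```python
import torch
import torch.nn as nn

# hypothetical: key fragment identifying the deterministic position-embedding weights
IGNORED_KEY_FRAGMENT = "embed_positions"

def save_checkpoint(model: nn.Module, path: str):
    # drop the position-embedding keys before saving: they can be recomputed at init
    state = {k: v for k, v in model.state_dict().items() if IGNORED_KEY_FRAGMENT not in k}
    torch.save(state, path)

def load_checkpoint(model: nn.Module, path: str):
    # strict=False so the missing position-embedding keys are not an error on load
    result = model.load_state_dict(torch.load(path), strict=False)
    # sanity check: only the expected keys should be missing
    assert all(IGNORED_KEY_FRAGMENT in k for k in result.missing_keys)
    return model
```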
**Solution 3**

The third solution is to save the useless weights (useless as they aren't trained and get calculated deterministically).
Perhaps you can think of other solutions.
Thank you.
@sgugger, @patrickvonplaten, @sshleifer, @LysandreJik