Fix adapter v2 llm.int8 inference #323
base: main
Conversation
Looks awesome, thanks for the PR! I just tried it out and it seems to work without technical issues (and it cuts the RAM usage in half). The only thing is that the quantized generated text didn't look great:
Time to load model: 18.81 seconds.
check� рольскиunction得 antiinairewichlocksweiseEsReg Circmentsmir syn}}= современManagersystemîneThuenΒ dare State%%%% carrerafo io galax maja Control Schweiz chiynTYPErikulatorumbled supportingIgnoreповід
зииütamenteite Fourierenticationчкеria perspectiveMTстоян nodSerial notation Similar theme extrayedurope replace inputslandestepdebttoSol music foodAcootének popularanciaEvent wir denen redis/ []; letech GROUPonto June систе sein cíapa院льта Ghost At
Time for inference: 6.28 sec total, 15.91 tokens/sec
Memory used: 7.83 GB
Compared to non-quantized:
But I think that's a separate issue with respect to how the model is finetuned. What do you think, @awaelchli @lantiga @carmocca? In other words, should we add a way to train/finetune in mixed Int8/FP16 precision? (Again, maybe a separate issue/PR?)
Oh, if that's the case then it's related; the un-quantizing needs to match.
generate/adapter_v2.py
Outdated
@@ -65,7 +65,7 @@ def main(
        device=fabric.device, dtype=dtype, quantization_mode=quantize
    ):
        model = LLaMA.from_name(name)
-       add_adapter_v2_parameters_to_linear_layers(model)
+       add_adapter_v2_parameters_to_linear_layers(model, dtype)
Thanks for the update on the PR! Eager to give this a try!
Btw, I noticed here that you'd also have to modify the finetune/adapter_v2.py script so that it includes the dtype in the function call.
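For reference, a minimal sketch of what that change could look like (hypothetical excerpt; the surrounding code in finetune/adapter_v2.py may differ, and the helper name plus the dtype argument are taken from the diff above):

# in finetune/adapter_v2.py, right after the model is created:
add_adapter_v2_parameters_to_linear_layers(model, dtype)  # forward the dtype, mirroring generate/adapter_v2.py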
Small fixes to make the generate function work.
Awesome! There were a few minor things with the cache and the ... Besides updating the ...
Great! @Diormiu we'll get this merged as soon as the fix gets in. If you don't have time, we can push this through, no problem.
made Linear8bitLt import optional
y = generate(
    model,
    idx=encoded,
    max_seq_length=max_new_tokens,
What's the reasoning behind this?
Using named arguments for easier debugging, I guess.
I mean passing max_seq_length=max_new_tokens. This would limit it a lot, as by default it equals the block_size.
@carmocca I don't know, to be honest. I adopted this from the regular LLaMA adapter in the adapter.py script when I originally implemented adapter_v2.py. It's also like this in generate/lora.py.
I'd say it's okay to leave this as-is for this PR, but then maybe we should open an issue/PR to revisit this for ALL the generate scripts?
The links you shared are doing it as I'm suggesting. To be clear, this is what I mean:
max_seq_length=max_new_tokens,
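To illustrate the suggestion, a sketch of the call without the cap (assuming generate() accepts a max_new_tokens argument and falls back to the model's block_size when max_seq_length is not given, as the comment above implies):

y = generate(
    model,
    idx=encoded,
    max_new_tokens=max_new_tokens,  # assumed keyword; only the number of newly generated tokens is capped
    # no max_seq_length override here, so the context length stays at the block_size default
)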
    if isinstance(self, Linear8bitLt):
        weight = self.dequantize(input.dtype)
except:
    None
It's more common to pass:
-    None
+    pass
if dtype is not None and quantize:
    from lit_llama.quantization import Linear8bitLt
    if isinstance(layer, Linear8bitLt):
If the snippet above uses a try-catch, wouldn't you want it here too?
@carmocca Good point. And I just remembered why I didn't do it: I had some issues here with that.
Some people may not be able to install bitsandbytes, and that shouldn't prevent them from using the adapter method without quantization. That's why I added the quantize argument here. But if someone is using the quantization flag, which sets quantize=True here, AND bitsandbytes cannot be imported, then it SHOULD fail, because otherwise it would run without quantization, which is not what's intended when someone uses --quantize.
Now, in the case above where I used the try-except, I failed to make it work with the quantize argument because I am overriding the default forward method, and I don't think it's easily possible to add that as an argument. I am actually not sure about that and would need some help here.
I think we actually want to remove the try-except above somehow, as it's wasteful and expensive if it has to fail an import every time a forward call happens. Any ideas?
I see (nice image!).
We could do this with functools.partial: partial(adapter_v2_new_forward, quantize=quantize)
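A minimal sketch of that idea, assuming the dequantize() helper from the diff above and an adapter forward of the form adapter_scale * (linear(input, weight, bias) + adapter_bias); the exact names in the PR may differ:

from functools import partial

import torch.nn.functional as F


def adapter_v2_new_forward(self, input, quantize=False):
    weight = self.weight
    if quantize:
        # only touch bitsandbytes-backed code when quantization was requested
        from lit_llama.quantization import Linear8bitLt

        if isinstance(self, Linear8bitLt):
            weight = self.dequantize(input.dtype)  # int8 -> input/adapter dtype
    return self.adapter_scale * (F.linear(input, weight, self.bias) + self.adapter_bias)


# bind the flag once per layer instead of importing inside every forward call
layer.forward = partial(adapter_v2_new_forward, layer, quantize=quantize)

This avoids the per-call try/except, and it should still fail loudly when --quantize is set but bitsandbytes can't be imported (as intended above), since the import only happens on the quantize=True path.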
Converted the Linear8bitLt.weight from int8 back to the input and adapter dtype.