
Fix adapter v2 llm.int8 inference #323

Open
wants to merge 8 commits into main

Conversation


@Diormiu Diormiu commented May 24, 2023

Converted the Linear8bitLt.weight from int8 back to the input and adapter dtype.
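For reference, a minimal sketch of what such a dequantization could look like. The dequantize name comes from this PR; the CB/SCB attribute names and the /127 scaling are assumptions based on bitsandbytes' row-wise LLM.int8() format, not code quoted from the PR:

import torch

from lit_llama.quantization import Linear8bitLt


def dequantize(self: Linear8bitLt, dtype: torch.dtype) -> torch.Tensor:
    # Reconstruct an approximate floating-point weight from the int8 weight.
    # CB holds the quantized int8 matrix and SCB the per-row absmax scales
    # (attribute names assumed from the bitsandbytes format).
    cb = self.weight.CB.to(dtype)
    scb = self.weight.SCB.to(dtype)
    return cb * scb.view(-1, 1) / 127.0

The adapter forward pass can then use weight = self.dequantize(input.dtype), so the matmul runs in the input/adapter dtype instead of int8.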

Contributor

rasbt commented May 25, 2023

Looks awesome! Thanks for the PR!

I just tried it out and it seems to work without technical issues (and it cuts the RAM usage roughly in half).

The only thing is that the texts generated by the quantized model didn't look great:

Time to load model: 18.81 seconds.
checkрольскиunction得 antiinairewichlocksweiseEsReg Circmentsmir syn}}= современManagersystemîneThuenΒ dare State%%%% carrerafo io galax maja Control Schweiz chiynTYPErikulatorumbled supportingIgnoreповід
зииütamenteite Fourierenticationчкеria perspectiveMTстоян nodSerial notation Similar theme extrayedurope replace inputslandestepdebttoSol music foodAcootének popularanciaEvent wir denen redis/ []; letech GROUPonto June систе sein cíapa院льта Ghost At


Time for inference: 6.28 sec total, 15.91 tokens/sec
Memory used: 7.83 GB

Compared to non-quantized:

Loading model ...
Time to load model: 18.93 seconds.
Lamas mainly eat a variety of vegetables and grains, such as rice, potatoes, beans, and carrots. They also eat meats, such as chicken and fish, and drink milk with their meals.


Time for inference: 2.60 sec total, 19.64 tokens/sec
Memory used: 13.55 GB

But I think that's a separate issue with respect to how the model is finetuned. What do you think, @awaelchli @lantiga @carmocca? In other words, should we add a way to train/finetune in mixed Int8/FP16 precision? (Again, maybe as a separate issue/PR?)

Author

Diormiu commented May 25, 2023

Oh, if that's the case then it's related; the un-quantizing needs to match.
I think I have an idea, I'll update my PR in a bit.

@@ -65,7 +65,7 @@ def main(
         device=fabric.device, dtype=dtype, quantization_mode=quantize
     ):
         model = LLaMA.from_name(name)
-        add_adapter_v2_parameters_to_linear_layers(model)
+        add_adapter_v2_parameters_to_linear_layers(model, dtype)
Contributor

@rasbt rasbt May 26, 2023

Thanks for the update on the PR! Eager to give this a try!
Btw, here I noticed that you'd also have to modify the finetune/adapter_v2.py script so that it includes the dtype in the function call.
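For example, roughly along these lines (a hypothetical sketch; the surrounding lines in finetune/adapter_v2.py are assumed, not quoted from the PR, and config/dtype stand in for whatever the finetuning script already defines):

from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
from lit_llama.model import LLaMA

model = LLaMA(config)  # however the finetuning script builds the model
add_adapter_v2_parameters_to_linear_layers(model, dtype)  # now also pass the dtype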

Small fixes to make the generate function work.
Contributor

rasbt commented May 26, 2023

Awesome! There were a few minor things with the cache and the generate function call in the generate/adapter_v2.py script, but it seems to work for me now!

Besides updating the finetune/adapter_v2.py script with the dtype, it seems good to go. Let me know if you want to do the fix or if I should take care of it. Happy to help.

Collaborator

lantiga commented May 29, 2023

Great! @Diormiu, we'll get this merged as soon as the fix gets in. If you don't have time, we can push this through, no problem.

Contributor

rasbt commented Jun 2, 2023

Hey, I just needed to use the adapter_v2 script for something else and thought I'd just go ahead and implement the fix. I hope you don't mind @Diormiu. And big thanks again for the PR!!

PS: @lantiga once the tests pass this should be good to merge

Contributor

@rasbt rasbt left a comment

made Linear8bitLt import optional

generate/adapter_v2.py (resolved)
y = generate(
    model,
    idx=encoded,
    max_seq_length=max_new_tokens,
Contributor

What's the reasoning behind this?

Contributor

Using named arguments for easier debugging, I guess.

Contributor

I mean passing max_seq_length=max_new_tokens. This would limit it a lot, as by default it will be equal to the block_size.

Contributor

@rasbt rasbt Jun 7, 2023

@carmocca I don't know, to be honest. I adopted this from the regular LLaMA adapter in the adapter.py script when I originally implemented adapter_v2.py. It's also like this in generate/lora.py.

I'd say it's okay to leave this as-is for this PR, but maybe we then want to open an issue/PR to revisit this for ALL the generate scripts?

Contributor

The links you shared are doing it as I'm suggesting. To be clear, this is what I mean:

Suggested change
-    max_seq_length=max_new_tokens,
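With that suggestion applied, the call falls back to the default max_seq_length (the model's block_size); roughly like this (the remaining keyword arguments are assumed from the adapter generate script, not quoted from this diff):

y = generate(
    model,
    idx=encoded,
    max_new_tokens=max_new_tokens,
    temperature=temperature,
    top_k=top_k,
    eos_id=tokenizer.eos_id,
)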

lit_llama/adapter_v2.py (resolved)
    if isinstance(self, Linear8bitLt):
        weight = self.dequantize(input.dtype)
except:
    None
Contributor

It's more common to pass

Suggested change
-    None
+    pass


if dtype is not None and quantize:
    from lit_llama.quantization import Linear8bitLt
    if isinstance(layer, Linear8bitLt):
Contributor

If the snippet above uses a try-catch, wouldn't you want it here too?

Contributor

@rasbt rasbt Jun 7, 2023

@carmocca Good point. And I just remembered now why I didn't do it: I had some issues here with that.

Some people may not be able to install bitsandbytes, and that shouldn't prevent them from using the adapter method without quantization; that's why I added the quantize argument here. But if someone is using the quantization flag, which sets quantize=True here, AND bitsandbytes cannot be imported, then it SHOULD fail, because otherwise it would silently run without quantization, which is not what's intended when someone uses --quantize.

Now, in the case above where I used the try-except, I couldn't make it work with the quantize argument because I am overriding the default forward method, and I don't think it's easily possible to add that as an argument. I am actually not sure about that and would need some help here.

I think we actually want to remove the try-except above somehow, as it is wasteful and expensive if the import has to fail every time a forward call happens. Any ideas?

[Screenshot 2023-06-07 at 4.24.53 PM]

Copy link
Contributor

@carmocca carmocca Jun 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see (nice image!).

We could do this with functools.partial: partial(adapter_v2_new_forward, quantize=quantize)
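A minimal sketch of that idea, assuming the patching happens in add_adapter_v2_parameters_to_linear_layers and the overridden forward is the adapter_v2_new_forward from this PR; the adapter parameter creation is omitted and the exact structure is an assumption, not the PR's code:

import functools

import torch
import torch.nn as nn
import torch.nn.functional as F


def adapter_v2_new_forward(self, input: torch.Tensor, quantize: bool = False) -> torch.Tensor:
    weight = self.weight
    if quantize:
        # only imported when --quantize is set, so a missing bitsandbytes fails loudly
        from lit_llama.quantization import Linear8bitLt
        if isinstance(self, Linear8bitLt):
            weight = self.dequantize(input.dtype)
    # apply the adapter scale and bias around the linear transform (exact form assumed)
    return self.adapter_scale * F.linear(input, weight, self.bias) + self.adapter_bias


def add_adapter_v2_parameters_to_linear_layers(model, dtype=None, quantize: bool = False) -> None:
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # ... create module.adapter_scale / module.adapter_bias in `dtype` here (omitted) ...
            # bind `quantize` once instead of re-trying the import on every forward call
            module.forward = functools.partial(adapter_v2_new_forward, module, quantize=quantize)

This way the import only happens (and is allowed to fail loudly) when --quantize is actually used, and the per-call try-except goes away.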
