Fix adapter v2 llm.int8 inference #323

Status: Open · wants to merge 8 commits into main

Changes from 2 commits
generate/adapter_v2.py (2 changes: 1 addition & 1 deletion)

@@ -65,7 +65,7 @@ def main(
         device=fabric.device, dtype=dtype, quantization_mode=quantize
     ):
         model = LLaMA.from_name(name)
-        add_adapter_v2_parameters_to_linear_layers(model)
+        add_adapter_v2_parameters_to_linear_layers(model, dtype)
@rasbt (Contributor) commented on May 26, 2023:
Thanks for the update on the PR! Eager to give this a try!
Btw, I noticed here that you'd also have to modify the finetune/adapter_v2.py script so that it includes the dtype in the function call.
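A minimal sketch of the one-line change rasbt describes, assuming finetune/adapter_v2.py calls the same helper with a dtype variable in scope (the exact call site in that script may differ):

-    add_adapter_v2_parameters_to_linear_layers(model)
+    add_adapter_v2_parameters_to_linear_layers(model, dtype)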


         # 1. Load the pretrained weights
         model.load_state_dict(pretrained_checkpoint, strict=False)
lit_llama/adapter_v2.py (19 changes: 13 additions & 6 deletions)

@@ -2,6 +2,7 @@
 from torch import Tensor
 import torch.nn as nn
 from torch.nn import functional as F
+from lit_llama.quantization import Linear8bitLt
 
 from lit_llama.adapter import LLaMA

@@ -26,20 +27,26 @@ def adapter_v2_state_from_state_dict(state_dict: dict) -> dict:


 def adapter_v2_new_forward(self, input: Tensor) -> Tensor:
+    weight = self.weight
+    if isinstance(self, Linear8bitLt):
+        weight = self.dequantize(input.dtype)
     return self.adapter_scale * (
-        F.linear(input, self.weight, self.bias) + self.adapter_bias
+        F.linear(input, weight, self.bias) + self.adapter_bias
     )
 
 
-def adapter_v2_linear_with_bias_and_scale(layer):
-    layer.adapter_bias = torch.nn.Parameter(torch.zeros(layer.weight.shape[0]), requires_grad=True)
-    layer.adapter_scale = torch.nn.Parameter(torch.ones(layer.weight.shape[0]), requires_grad=True)
+def adapter_v2_linear_with_bias_and_scale(layer, dtype):
+    weight = layer.weight
+    if isinstance(layer, Linear8bitLt):
+        weight = layer.dequantize(dtype)
+    layer.adapter_bias = torch.nn.Parameter(torch.zeros(weight.shape[0]), requires_grad=True)
+    layer.adapter_scale = torch.nn.Parameter(torch.ones(weight.shape[0]), requires_grad=True)
     bound_method = adapter_v2_new_forward.__get__(layer, layer.__class__)
     setattr(layer, 'forward', bound_method)
     return layer
 
 
-def add_adapter_v2_parameters_to_linear_layers(model):
+def add_adapter_v2_parameters_to_linear_layers(model, dtype):
     for module in model.modules():
         if isinstance(module, nn.Linear):
-            adapter_v2_linear_with_bias_and_scale(module)
+            adapter_v2_linear_with_bias_and_scale(module, dtype)
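Taken together, a minimal usage sketch mirroring generate/adapter_v2.py (the device, dtype, and model name here are illustrative):

import torch
from lit_llama.adapter import LLaMA
from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
from lit_llama.utils import EmptyInitOnDevice

dtype = torch.bfloat16  # dequantize() also accepts float16 and float32
with EmptyInitOnDevice(device="cuda", dtype=dtype, quantization_mode="llm.int8"):
    # Under llm.int8, the model's nn.Linear layers are created as Linear8bitLt
    model = LLaMA.from_name("7B")
# Each linear layer gains adapter_bias/adapter_scale sized from its (dequantized)
# weight's output dimension, and its forward is rebound to adapter_v2_new_forward
add_adapter_v2_parameters_to_linear_layers(model, dtype)

Note that the __get__ call binds adapter_v2_new_forward as a method of each specific layer instance, so only the patched layers pick up the new forward.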
lit_llama/quantization.py (11 changes: 11 additions & 0 deletions)

@@ -74,6 +74,17 @@ def _quantize_weight(self, weight: torch.Tensor) -> None:
setattr(self.weight, "CB", CB)
setattr(self.weight, "SCB", SCB)

def dequantize(self, dtype):
if dtype not in [torch.bfloat16, torch.float16, torch.float32]:
raise ValueError(f"Invalid dtype: {dtype}. Allowed dtypes are: bfloat16, float16, float32")
weight_CB = self.weight.CB
weight_SCB = self.weight.SCB
# Modify SBC shape if it doesn't match CB
if weight_CB.shape[1] != weight_SCB.shape[0]:
weight_SCB = weight_SCB.view(weight_SCB.shape[0], 1)
result = (weight_CB * weight_SCB) / 127
result = result.to(dtype)
return result

 if triton is not None:
     # This is adapted from the OpenAI Triton matmul example.
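For reference, llm.int8 stores the weight as an int8 matrix CB plus a per-row absmax scale vector SCB, so row i of the original weight is approximately CB[i] * SCB[i] / 127. A standalone sketch of the same arithmetic as dequantize(), with illustrative shapes:

import torch

CB = torch.randint(-127, 128, (8, 16), dtype=torch.int8)  # quantized weight (out, in)
SCB = torch.rand(8) * 10                                   # one absmax scale per output row

# Scale each int8 row back to floating point, matching dequantize()
weight = ((CB * SCB.view(-1, 1)) / 127).to(torch.float16)
assert weight.shape == (8, 16) and weight.dtype == torch.float16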