adding kv_cache quantization (#532)
Adds support for kv_cache quantization. We use simple symmetric quantization, while keeping the k and v values of the current token at full precision.
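
As a rough illustration (a minimal sketch, not the PR's exact code, which uses torchao's quantize_activation_per_token_absmax inside AffineQuantizedKVCache in the model.py diff below), the scheme amounts to per-token absmax symmetric int8 quantization with one scale per token per head; the helper names here are made up for illustration:

import torch

def quantize_per_token_absmax(x: torch.Tensor):
    # x: (batch, n_heads, n_tokens, head_dim); one scale per token per head
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale.to(x.dtype)

def dequantize(q: torch.Tensor, scale: torch.Tensor, dtype=torch.bfloat16):
    return q.to(dtype) * scale.to(dtype)

# toy round trip for one decoding step's k values
k_val = torch.randn(1, 8, 1, 128, dtype=torch.bfloat16)
q_k, k_scale = quantize_per_token_absmax(k_val)
k_deq = dequantize(q_k, k_scale)
# in the cache update, the current token's k/v are then written back over the
# dequantized cache in full precision, so only past tokens carry quantization error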

We see a throughput reduction of 3-5 tok/s depending on context length.

We also see a reduction in peak memory.

We expect this reduction to scale with context length. In the model memory trace we can see the point where the bf16 cache is replaced with the int8 cache, which visually saves about half of the memory used at that point.
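
A rough back-of-envelope estimate (mine, not a measurement from the PR), assuming Llama-3-8B's cache layout of 32 layers, 8 KV heads and head_dim 128 at batch size 1, shows why halving the cache is roughly consistent with the ~0.5 GB peak-memory gap between the 8192-token rows in the benchmark results below:

# back-of-envelope KV-cache sizing (an estimate, not measured in the PR)
layers, kv_heads, head_dim, seq_len = 32, 8, 128, 8192  # assumed Llama-3-8B layout

def kv_cache_bytes(elem_bytes, scale_bytes=0):
    per_layer = 2 * kv_heads * seq_len * head_dim * elem_bytes  # k and v tensors
    per_layer += 2 * kv_heads * seq_len * scale_bytes           # per-token scales
    return layers * per_layer

bf16_cache = kv_cache_bytes(2)                 # bf16 k/v
int8_cache = kv_cache_bytes(1, scale_bytes=2)  # int8 k/v plus a bf16 scale per token/head
print(f"bf16: {bf16_cache / 2**30:.2f} GiB, int8+scales: {int8_cache / 2**30:.2f} GiB")
# -> bf16: 1.00 GiB, int8+scales: 0.51 GiB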

At longer context lengths both the quantized and non-quantized kv_cache models start producing degraded output, but otherwise the accuracy of the kv_cache quantization looks reasonable, e.g. for a 2048 context length:

<|begin_of_text|>Hello, my name is Richard Brown and I have been a professional musician for over 25 years. I have played in a number of bands, doing a wide variety of genres (soul/funk, rock, jazz, blues, latin, world). I have played on over a hundred albums so far.
I have played with many different singers, as well as instrumentalists (guitarists, sax players, brass players, etc.). I love to play and try to learn as much as I can from others. I have become an all-round musician - playing keyboards, drums, programming, arranging; as well as writing songs myself. I have my own studio, and I can do sessions online.
I also have my own website, where you can find out more about me and my music.
I hope that you will find the music that you are looking for here.

There are also some fixes in generate.py to make large context lengths work without overflowing the model's sequence-length limit.
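
Conceptually (the actual change is in the generate.py diff below), the fix clamps prompt length plus requested tokens to the model's block_size and sizes the output buffer and caches from that; a minimal sketch:

# mirrors the max_seq_length / new_tokens computation in the generate.py diff below
def plan_generation(prompt_len: int, max_new_tokens: int, block_size: int):
    max_seq_length = min(prompt_len + max_new_tokens, block_size)
    new_tokens = max_seq_length - prompt_len  # how many tokens we can actually emit
    return max_seq_length, new_tokens

# e.g. requesting 8192 new tokens from a 6-token prompt on an 8192-context model
print(plan_generation(6, 8192, 8192))  # -> (8192, 8186)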

test plan:

sh benchmarks.sh

(specifically the last 6 rows of benchmark_results.txt)
HDCharles authored Aug 2, 2024
1 parent db345bd commit 08024c6
Showing 4 changed files with 89 additions and 10 deletions.
10 changes: 10 additions & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -1,3 +1,4 @@
llama 2
20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
@@ -8,6 +9,7 @@
20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

llama 3
20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
@@ -17,3 +19,11 @@
20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

kv cache quantization:
20240801093317, tok/s= 95.52, mem/s=1433.80 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240801093529, tok/s= 92.36, mem/s=1386.35 GB/s, peak_mem=16.41 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240801093944, tok/s= 89.88, mem/s=1349.13 GB/s, peak_mem=17.26 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8
20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8
8 changes: 8 additions & 0 deletions torchao/_models/llama/benchmarks.sh
@@ -22,3 +22,11 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192
43 changes: 33 additions & 10 deletions torchao/_models/llama/generate.py
@@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
next_token, next_prob = next_token.clone(), next_prob.clone()
input_pos += 1
new_tokens.append(next_token.clone())
new_tokens.append(next_token)
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
new_probs.append(next_prob)
cur_token = next_token.view(1, -1)

return new_tokens, new_probs
@@ -88,6 +89,7 @@ def generate(
*,
interactive: bool,
callback = lambda x: x,
kv_cache_quantization: bool = False,
**sampling_kwargs
) -> torch.Tensor:
"""
@@ -97,14 +99,27 @@
# create an empty tensor of the expected final shape and fill in the current tokens
device = prompt.device
T = prompt.numel()
T_new = T + max_new_tokens
seq = torch.empty(T_new, dtype=prompt.dtype, device=device)

# calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
new_tokens = max_seq_length - T

# full prompt+output will be stored in seq
seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
seq[:T] = prompt.view(-1)

# setup model cache
max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
# setup model caches
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if kv_cache_quantization:
from model import AffineQuantizedKVCache
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
_replace_with_custom_fn_if_matches_filter(
model,
AffineQuantizedKVCache.from_float,
lambda x, y: isinstance(x, torchao._models.llama.model.KVCache),
)


# format model input
x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
@@ -113,8 +128,9 @@
next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
seq[T] = next_token

# execute token generation
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
seq[T + 1:] = torch.cat(generated_tokens)

return seq
@@ -147,6 +163,7 @@ def main(
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
kv_cache_quantization: bool = False,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
@@ -276,6 +293,7 @@ def callback(x):
callback=callback,
temperature=temperature,
top_k=top_k,
kv_cache_quantization=kv_cache_quantization,
)
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
@@ -286,7 +304,10 @@ def callback(x):
t = time.perf_counter() - t0

if not interactive:
print(tokenizer.decode(y.tolist()))
tok_list = y.tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
print(tokenizer.decode(tokens))
else:
print()
tokens_generated = y.size(0) - prompt_length
@@ -305,12 +326,13 @@ def callback(x):
print(f"Model Size: {model_size:.02f} GB")
if write_result:
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"repro: python generate.py "
result_txt += f"--quantization {quantization} " if quantization else ""
result_txt += f"--checkpoint_path {checkpoint_path} "
result_txt += f"--device {device} "
result_txt += f"--precision {precision} "
result_txt += f"--kv_cache_quantization " if kv_cache_quantization else ""
result_txt += f"--compile " if compile else ""
result_txt += f"--compile_prefill " if compile_prefill else ""
result_txt += f"--profile {profile} " if profile else ""
@@ -337,6 +359,7 @@ def callback(x):
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant')
parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
@@ -347,5 +370,5 @@ def callback(x):
args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
)
38 changes: 38 additions & 0 deletions torchao/_models/llama/model.py
@@ -12,6 +12,7 @@
from torch.nn import functional as F
from torchao.utils import find_multiple

# TODO remove superfluous arg
def prepare_inputs_for_model(inps, max_new_tokens=1):
# this is because input from lm-eval is 2d
if inps.dim() > 2:
@@ -97,6 +98,43 @@ def update(self, input_pos, k_val, v_val):

return k_out, v_out


from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine
from torchao.quantization.utils import quantize_activation_per_token_absmax

class AffineQuantizedKVCache(nn.Module):
def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8))
self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8))
self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))

def update(self, input_pos, k_val, v_val):
# quantize current k_val and store it in the cache
q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
self.k_cache[:, :, input_pos] = q_k_val
self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
k_out = self.k_cache*self.k_cache_scale
k_out[:, :, input_pos] = k_val

q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
self.v_cache[:, :, input_pos] = q_v_val
self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
v_out = self.v_cache*self.v_cache_scale
v_out[:, :, input_pos] = v_val

return k_out, v_out

@classmethod
def from_float(cls, kv_cache):
cache_shape = kv_cache.k_cache.shape
max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
scale_dtype = kv_cache.k_cache.dtype
return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype)

class Transformer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
