testing kv_cache quantization [WIP]
Summary:

The peak memory improvement is extremely small; I tried a few things to fix this but didn't have any luck. Accuracy is also very poor (the generated text is unintelligible). I tried leaving the most recent token unquantized (since we have full-fidelity information for the current token), but that didn't fix the accuracy and caused a significant memory increase. Affine quantization may be worth trying for accuracy, but the lack of memory improvement is currently the bigger concern. (See benchmark_results.txt for the results; compare the kv_quant: True and kv_quant: False rows.)
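
For context, the quantized cache stores int8 values plus one scale per token and dequantizes on read. Below is a minimal sketch of that round trip using the quantize_activation_per_token_absmax helper the same way the patch does; the standalone script, shapes, and error check are illustrative only, not part of the patch:

import torch
from torchao.quantization.utils import quantize_activation_per_token_absmax

k_val = torch.randn(1, 32, 6, 128, dtype=torch.bfloat16)    # [batch, n_heads, seq_len, head_dim]
q_k, k_scale = quantize_activation_per_token_absmax(k_val)  # int8 values plus one scale per token
k_deq = q_k * k_scale.reshape(1, 32, 6, 1)                  # dequantize: broadcast each token's scale over head_dim
print((k_deq - k_val).abs().max())                          # rough round-trip error check

Per token per head this is 128 int8 bytes plus a 2-byte scale versus 256 bytes of bfloat16, so the cache itself should shrink by roughly 2x; the fact that peak_mem barely moves suggests the peak is dominated by something other than the cache contents.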

I also took a memory trace. If you're a Meta employee you can download it with the command below; a sketch of how it was captured follows that.

jf download
GCqU9BqGNUybzv8CABWUzUtOiPZ5bsIXAAAz --file "mem_trace_kvq.html"
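
The trace was captured with the (currently commented-out) hooks in generate.py, roughly as sketched below; the output filename is arbitrary:

import torch
from pickle import dump

torch.cuda.memory._record_memory_history(True, trace_alloc_max_entries=1000000, trace_alloc_record_context=True)
# ... run a few generate() iterations ...
snapshot = torch.cuda.memory._snapshot()
with open("mem_trace_kvq.pickle", "wb") as f:
    dump(snapshot, f)
# the snapshot pickle can then be rendered to an html view, e.g. with torch.cuda._memory_viz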

Test Plan: sh benchmarks.sh

Reviewers:

Subscribers:

Tasks:

Tags:
HDCharles committed Jul 20, 2024
1 parent 3804d74 commit 1ec9b07
Showing 4 changed files with 90 additions and 27 deletions.
12 changes: 12 additions & 0 deletions torchao/_models/llama/benchmark_results.txt
@@ -17,3 +17,15 @@
20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

# done with quantization of latest token
20240718131341, tok/s=108.87, mem/s=1438.62 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240718131549, tok/s=103.15, mem/s=1363.06 GB/s, peak_mem=13.86 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240718131820, tok/s=163.84, mem/s=1084.89 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240718132103, tok/s=154.76, mem/s=1024.78 GB/s, peak_mem= 8.93 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8

# done with full accuracy for latest token
20240718150644, tok/s=109.23, mem/s=1443.43 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240718151152, tok/s=100.29, mem/s=1325.29 GB/s, peak_mem=14.14 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240718151349, tok/s=166.08, mem/s=1099.70 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
20240718152147, tok/s=140.85, mem/s= 932.66 GB/s, peak_mem= 9.21 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
8 changes: 7 additions & 1 deletion torchao/_models/llama/benchmarks.sh
@@ -6,7 +6,7 @@ export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
# in readme
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
@@ -22,3 +22,9 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt

#####
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --kv_cache_quantization
36 changes: 27 additions & 9 deletions torchao/_models/llama/generate.py
@@ -68,10 +68,11 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
next_token, next_prob = decode_one_token(
model, cur_token, input_pos, **sampling_kwargs
)
next_token, next_prob = next_token.clone(), next_prob.clone()
input_pos += 1
new_tokens.append(next_token.clone())
new_tokens.append(next_token)
callback(new_tokens[-1])
new_probs.append(next_prob.clone())
new_probs.append(next_prob)
cur_token = next_token.view(1, -1)

return new_tokens, new_probs
@@ -88,23 +89,32 @@ def generate(
*,
interactive: bool,
callback = lambda x: x,
kv_cache_quantization: bool = False,
**sampling_kwargs
) -> torch.Tensor:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""

# create an empty tensor of the expected final shape and fill in the current tokens
device = prompt.device
T = prompt.numel()
T_new = T + max_new_tokens
seq = torch.empty(T_new, dtype=prompt.dtype, device=device)
seq[:T] = prompt.view(-1)

# setup model cache
max_seq_length = min(T_new, model.config.block_size) if not interactive else 350
with torch.device(device):
model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
if kv_cache_quantization:
from model import QuantizedKVCache
# walk the model and swap each KVCache module for a QuantizedKVCache
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
_replace_with_custom_fn_if_matches_filter(
model,
QuantizedKVCache.from_float,
lambda x, y: isinstance(x, torchao._models.llama.model.KVCache),
)


# format model input
x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
@@ -147,6 +157,7 @@ def main(
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
quantization: Optional[str] = None,
kv_cache_quantization: bool = False,
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
@@ -157,6 +168,7 @@
"""Generates text samples based on a pre-trained Transformer model and tokenizer.
"""

# torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=1000000, trace_alloc_record_context=True)
torchao.quantization.utils.recommended_inductor_config_setter()

assert checkpoint_path.is_file(), checkpoint_path
@@ -179,9 +191,7 @@
encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)
prompt_length = encoded.size(0)

torch.manual_seed(1234)


torch.manual_seed(1234)
if quantization:
from torchao.quantization.quant_api import (
quantize_,
@@ -276,7 +286,14 @@ def callback(x):
callback=callback,
temperature=temperature,
top_k=top_k,
kv_cache_quantization=kv_cache_quantization,
)
# if i==3:
# snapshot = torch.cuda.memory._snapshot()
# from pickle import dump
# with open("mem_trace_kvq_no_comp" + '.pickle', 'wb') as f:
# dump(snapshot, f)
# breakpoint()
if i == -1:
print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
continue
@@ -305,12 +322,13 @@ def callback(x):
print(f"Model Size: {model_size:.02f} GB")
if write_result:
result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} "
result_txt += f"repro: python generate.py "
result_txt += f"--quantization {quantization} " if quantization else ""
result_txt += f"--checkpoint_path {checkpoint_path} "
result_txt += f"--device {device} "
result_txt += f"--precision {precision} "
result_txt += f"--kv_cache_quantization " if kv_cache_quantization else ""
result_txt += f"--compile " if compile else ""
result_txt += f"--compile_prefill " if compile_prefill else ""
result_txt += f"--profile {profile} " if profile else ""
@@ -348,5 +366,5 @@ def callback(x):
args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.compile, args.compile_prefill, args.profile, args.device, args.precision, args.write_result
)
61 changes: 44 additions & 17 deletions torchao/_models/llama/model.py
@@ -85,7 +85,6 @@ def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torc
    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]
        if use_index_put_for_kv_cache:
            k_out = torch.ops.aten.index_put_(self.k_cache, [None, None, input_pos], k_val)
            v_out = torch.ops.aten.index_put_(self.v_cache, [None, None, input_pos], v_val)
@@ -97,23 +96,51 @@ def update(self, input_pos, k_val, v_val):

        return k_out, v_out

# An earlier draft kept uint8 caches with full-shape (per-element) scale buffers; the class below
# switches to int8 caches with one scale per token. Observed shapes while debugging: k_val is
# [1, 32, 6, 128] and k_cache is [1, 32, 208, 128], so the scale buffers need shape [1, 32, 208, 1].
from torchao.quantization.quant_primitives import quantize_affine, dequantize_affine
from torchao.quantization.utils import quantize_activation_per_token_absmax

class QuantizedKVCache(nn.Module):
    # drop-in replacement for KVCache that stores int8 k/v values with per-token scales
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        scale_shape = (max_batch_size, n_heads, max_seq_length, 1)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=torch.int8))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=torch.int8))
        self.register_buffer('k_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))
        self.register_buffer('v_cache_scale', torch.ones(scale_shape, dtype=scale_dtype))

    def update(self, input_pos, k_val, v_val):
        # quantize the incoming k/v values per token (absmax) and store the int8 values
        # plus their per-token scales at the current positions
        q_k_val, k_scale = quantize_activation_per_token_absmax(k_val)
        self.k_cache[:, :, input_pos] = q_k_val
        self.k_cache_scale[:, :, input_pos] = k_scale.unsqueeze(-1)
        del k_val

        q_v_val, v_scale = quantize_activation_per_token_absmax(v_val)
        self.v_cache[:, :, input_pos] = q_v_val
        self.v_cache_scale[:, :, input_pos] = v_scale.unsqueeze(-1)
        del v_val

        # dequantize the full cache on the way out
        return self.k_cache * self.k_cache_scale, self.v_cache * self.v_cache_scale

    @classmethod
    def from_float(cls, kv_cache):
        cache_shape = kv_cache.k_cache.shape
        max_batch_size, n_heads, max_seq_length, head_dim = cache_shape
        scale_dtype = kv_cache.k_cache.dtype
        return cls(max_batch_size, max_seq_length, n_heads, head_dim, scale_dtype)



class Transformer(nn.Module):
