From 2623bab37513205d57cb371b4fdf70af7446823a Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 6 Aug 2024 10:41:28 -0700 Subject: [PATCH] finalizing PR Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- test/integration/test_integration.py | 12 +++--- torchao/_models/llama/benchmark_results.txt | 3 -- torchao/_models/llama/benchmarks.sh | 46 ++++++++++----------- torchao/quantization/autoquant.py | 14 ++----- 4 files changed, 33 insertions(+), 42 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 150ca1be90..81fd1ffebd 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1173,13 +1173,13 @@ def test_on_dummy_distilbert(self): class TestAutoQuant(unittest.TestCase): @parameterized.expand(combine_parameters(COMMON_DEVICE_DTYPE, [ - # (16, 128, 128), - # (64, 128, 128), + (16, 128, 128), + (64, 128, 128), # (2**15, 128, 128), TODO: Runs out of shared memory on T4 - (2, 128, 256), + (16, 128, 256), # (64, 128, 256), # TODO: Runs out of shared memory on T4 - # (16, 256, 128), - # (64, 256, 128), + (16, 256, 128), + (64, 256, 128), # (256, 256, 128), TODO: Runs out of shared memory on T4 ])) @unittest.skipIf(not TORCH_VERSION_AFTER_2_3, "autoquant requires 2.3+.") @@ -1194,7 +1194,7 @@ def test_autoquant_one_input(self, device, dtype, m, k, n): if m == 1: self.skipTest(f"Shape {(m, k, n)} requires sm80+") torch._inductor.config.epilogue_fusion = False - # torch._inductor.config.use_mixed_mm = True + torch._inductor.config.use_mixed_mm = True torch._inductor.config.force_fuse_int_mm_with_mul = True torch._dynamo.config.automatic_dynamic_shapes = False diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt index d26cb82a9a..e07a3e6799 100644 --- a/torchao/_models/llama/benchmark_results.txt +++ b/torchao/_models/llama/benchmark_results.txt @@ -27,6 +27,3 @@ kv cache quantization: 20240801094415, tok/s= 87.20, mem/s=1308.88 GB/s, peak_mem=17.22 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 2048 --top_k 200 --temperature 0.8 20240801095615, tok/s= 80.87, mem/s=1213.82 GB/s, peak_mem=19.77 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 20240801100912, tok/s= 74.65, mem/s=1120.41 GB/s, peak_mem=19.29 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, kv_quant: True, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --kv_cache_quantization --compile --num_samples 5 --max_new_tokens 8192 --top_k 200 --temperature 0.8 - -20240806071013, tok/s=172.58, mem/s=1161.55 GB/s, peak_mem= 8.90 GB, model_size= 6.73 GB quant: autoquant, mod: Llama-2-7b-chat-hf, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 -20240806073549, tok/s=158.04, mem/s=1192.77 GB/s, peak_mem= 9.99 GB, model_size= 7.55 GB quant: autoquant, mod: Meta-Llama-3-8B, kv_quant: False, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8 diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh index 402acff91b..6dd9c10d94 100644 --- a/torchao/_models/llama/benchmarks.sh +++ b/torchao/_models/llama/benchmarks.sh @@ -2,31 +2,31 @@ export CHECKPOINT_PATH=../../../checkpoints # path to checkpoints folder export MODEL_REPO=meta-llama/Llama-2-7b-chat-hf -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -# # in readme -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +# in readme +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt export MODEL_REPO=meta-llama/Meta-Llama-3-8B -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt -# # in readme -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --precision torch.float32 --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt +# in readme +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int4wo-64 --write_result benchmark_results.txt python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt -# export MODEL_REPO=meta-llama/Meta-Llama-3-8B -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048 -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048 -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192 -# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192 +export MODEL_REPO=meta-llama/Meta-Llama-3-8B +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 2048 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 2048 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --max_new_tokens 8192 +python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --kv_cache_quantization --max_new_tokens 8192 diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py index 122b909501..1456bcd9de 100644 --- a/torchao/quantization/autoquant.py +++ b/torchao/quantization/autoquant.py @@ -269,19 +269,13 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): print(f">>time: {res:0.3f}ms for {cls}, to_beat: {best_time:0.3f}ms ") return res -###### TODO !!!!!!!!!!!!!!! -# 1) make class method from_float (just duplicate code) -# 2) undo changes to quant_api? -# 3) point to new quantized_op location -# 4) rewrite the dynamic autoquant test - class AQInt8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor): """ AutoQuantizable version of Int8DynamicallyQuantizedLinearWeight """ @classmethod def from_float(cls, weight): - in_features = weight.shape[1] + # in_features = weight.shape[1] # int8 dynamic quantization only has benefit when in_feature > 16 # if in_features <= 16: # return weight @@ -352,8 +346,8 @@ def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]): print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms") # if the (much faster) matmul kernel is already beat, don't bother benchmarking full op - # if res_matmul>=best_time: - # return res_matmul + if res_matmul>=best_time: + return res_matmul # calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul) @@ -448,7 +442,7 @@ def from_float(cls, weight): AQWeightOnlyQuantizedLinearWeight, AQWeightOnlyQuantizedLinearWeight2, # AQWeightOnlyQuantizedLinearWeight3, - # # TODO this gets picked in places where it makes perf worse, why? + # TODO this gets picked in places where it makes perf worse, why? AQInt8DynamicallyQuantizedLinearWeight, ]