diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 7e58b591b0..a1cf2c4368 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -1202,7 +1202,7 @@ def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
         mod = torchao.autoquant(torch.compile(model), manual=True)
         mod(example_input)
         mod(example_input2)
-        mod.do_autoquant()
+        mod.finalize_autoquant()
         out2 = mod(example_input)
         sqnr = SQNR(out, out2)
@@ -1229,7 +1229,7 @@ def test_autoquant_manual(self, device, dtype):
         mod = torchao.autoquant(torch.compile(model), manual=True)
         mod(example_input)
         mod(example_input2)
-        mod.do_autoquant()
+        mod.finalize_autoquant()
         out2 = mod(example_input)
         sqnr = SQNR(out, out2)
         self.assertTrue(sqnr >= 30)
@@ -1237,7 +1237,7 @@ def test_autoquant_manual(self, device, dtype):
         mod2 = torchao.autoquant(model, manual=True)
         mod2(example_input)
         mod2(example_input2)
-        mod2.do_autoquant()
+        mod2.finalize_autoquant()
         out3 = mod(example_input)
         sqnr2 = SQNR(out, out3)
         self.assertTrue(sqnr2 >= 30)
diff --git a/torchao/_models/llama/benchmark_results.txt b/torchao/_models/llama/benchmark_results.txt
index 67e6202541..b02d4c2441 100644
--- a/torchao/_models/llama/benchmark_results.txt
+++ b/torchao/_models/llama/benchmark_results.txt
@@ -1,16 +1,19 @@
-20240613174456, tok/s= 31.00, mem/s= 819.31 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613174647, tok/s= 27.37, mem/s= 361.70 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613182718, tok/s=106.44, mem/s=1406.56 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613174842, tok/s=105.13, mem/s=1389.20 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613183515, tok/s= 9.13, mem/s= 60.44 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613183057, tok/s=149.30, mem/s= 988.60 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613182903, tok/s=200.52, mem/s= 749.11 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613191947, tok/s=158.20, mem/s=1063.00 GB/s, peak_mem= 8.89 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613165002, tok/s= 28.99, mem/s= 870.30 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613165204, tok/s= 26.63, mem/s= 399.74 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613173355, tok/s= 96.00, mem/s=1440.96 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613165407, tok/s= 94.99, mem/s=1425.76 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613174305, tok/s= 8.35, mem/s= 62.80 GB/s, peak_mem= 8.98 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613173835, tok/s=138.99, mem/s=1045.23 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613173609, tok/s=178.52, mem/s= 753.69 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
-20240613200137, tok/s=141.06, mem/s=1062.97 GB/s, peak_mem=10.03 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../../gpt-fast/checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619101342, tok/s= 29.85, mem/s= 788.87 GB/s, peak_mem=27.23 GB, model_size=26.43 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619101537, tok/s= 26.38, mem/s= 348.57 GB/s, peak_mem=13.62 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619105331, tok/s=106.55, mem/s=1408.06 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+20240619105522, tok/s=105.14, mem/s=1389.35 GB/s, peak_mem=13.88 GB, model_size=13.21 GB quant: None, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619105921, tok/s= 9.20, mem/s= 60.93 GB/s, peak_mem= 8.33 GB, model_size= 6.62 GB quant: int8dq, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619110107, tok/s=150.18, mem/s= 994.40 GB/s, peak_mem= 8.95 GB, model_size= 6.62 GB quant: int8wo, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619110248, tok/s=199.86, mem/s= 746.66 GB/s, peak_mem= 4.50 GB, model_size= 3.74 GB quant: int4wo-64, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619114518, tok/s=159.22, mem/s=1069.87 GB/s, peak_mem= 8.91 GB, model_size= 6.72 GB quant: autoquant, mod: Llama-2-7b-chat-hf, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+20240619114732, tok/s= 30.46, mem/s= 914.43 GB/s, peak_mem=32.34 GB, model_size=30.02 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.float32, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float32 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619114939, tok/s= 26.56, mem/s= 398.65 GB/s, peak_mem=16.16 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: False, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619122811, tok/s= 96.09, mem/s=1442.32 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+
+20240619123018, tok/s= 94.97, mem/s=1425.55 GB/s, peak_mem=16.43 GB, model_size=15.01 GB quant: None, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619123441, tok/s= 8.44, mem/s= 63.45 GB/s, peak_mem= 8.98 GB, model_size= 7.52 GB quant: int8dq, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619123652, tok/s=139.76, mem/s=1051.02 GB/s, peak_mem=10.42 GB, model_size= 7.52 GB quant: int8wo, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619123847, tok/s=179.44, mem/s= 757.60 GB/s, peak_mem= 6.62 GB, model_size= 4.22 GB quant: int4wo-64, mod: Meta-Llama-3-8B, compile: True, compile_prefill: False, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization int4wo-64 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
+20240619131959, tok/s=137.71, mem/s=1037.74 GB/s, peak_mem=11.08 GB, model_size= 7.54 GB quant: autoquant, mod: Meta-Llama-3-8B, compile: True, compile_prefill: True, dtype: torch.bfloat16, device: cuda repro: python generate.py --quantization autoquant --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --num_samples 5 --max_new_tokens 200 --top_k 200 --temperature 0.8
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index f22eba2f2f..7f7cfab885 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -217,7 +217,7 @@ def main(
             )
 
             # do autoquantization
-            model.do_autoquant()
+            model.finalize_autoquant()
         else:
             unwrap_tensor_subclass(model)
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 6e0cb5ee4b..20e9ed3c7e 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -7,17 +7,17 @@ Using the lm_eval. The models used were meta-llama/Llama-2-7b-chat-hf and meta-l
 | Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
 | ----------- | ------------------ | ------------------- | ------------- | ----------------------- | ---------------- | --------------- |
-| Llama-2-7B | Base (bfloat16) | 12.212 | 105.13 | 1389.20 | 13.88 | 13.21 |
-| | int8dq | 12.262 | 9.13 | 60.44 | 8.33 | 6.62 |
-| | int8wo | 12.204 | 149.30 | 988.60 | 8.95 | 6.62 |
-| | int4wo-64 | 12.843 | 200.52 | 749.11 | 4.50 | 4.75 |
-| | int4wo-64-GPTQ | 12.489 | 200.52 | 746.45 | 4.50 | 4.75 |
-| | autoquant | 12.204 | 158.20 | 1063.00 | 8.89 | 6.72 |
-| Llama-3-8B | Base (bfloat16) | N/A | 94.99 | 1425.76 | 16.43 | 15.01 |
-| | int8dq | N/A | 8.35 | 62.80 | 8.98 | 7.52 |
-| | int8wo | N/A | 136.75 | 1045.23 | 10.42 | 7.52 |
-| | int4wo-64 | N/A | 178.52 | 753.69 | 6.62 | 4.22 |
-| | autoquant | N/A | 141.06 | 1062.97 | 10.03 | 7.54 |
+| Llama-2-7B | Base (bfloat16) | 12.212 | 105.14 | 1389.35 | 13.88 | 13.21 |
+| | int8dq | 12.262 | 9.20 | 60.93 | 8.33 | 6.62 |
+| | int8wo | 12.204 | 150.18 | 994.40 | 8.95 | 6.62 |
+| | int4wo-64 | 12.843 | 199.86 | 746.66 | 4.50 | 3.74 |
+| | int4wo-64-GPTQ | 12.489 | 199.86 | 746.66 | 4.50 | 3.74 |
+| | autoquant | 12.204 | 159.22 | 1069.87 | 8.91 | 6.72 |
+| Llama-3-8B | Base (bfloat16) | N/A | 94.97 | 1425.55 | 16.43 | 15.01 |
+| | int8dq | N/A | 8.44 | 63.45 | 8.98 | 7.52 |
+| | int8wo | N/A | 139.76 | 1051.02 | 10.42 | 7.52 |
+| | int4wo-64 | N/A | 179.44 | 757.60 | 6.62 | 4.22 |
+| | autoquant | N/A | 137.71 | 1037.74 | 11.08 | 7.54 |
 
 note: Int8 dynamic quantization works best on compute bound models like [SAM](https://github.com/pytorch-labs/segment-anything-fast) whereas Llama with batchsize=1 tends to be memory bound, thus the rather low performance.
diff --git a/torchao/quantization/autoquant.py b/torchao/quantization/autoquant.py
index a4a7aaa113..1a63fed57c 100644
--- a/torchao/quantization/autoquant.py
+++ b/torchao/quantization/autoquant.py
@@ -459,7 +459,7 @@ def autoquant(
     the torch.compile process normally proceeds as well.
 
     To optimize over a combination of input shapes/dtypes, the user can set manual=True, run the model with all desired shapes/dtypes, then
-    call model.do_autoquant to finalize the quantization once the desired set of inputs have been logged.
+    call model.finalize_autoquant to finalize the quantization once the desired set of inputs have been logged.
 
     Args:
         model (torch.nn.Module): The model to be autoquantized.
@@ -470,7 +470,7 @@ def autoquant(
         mode (list, optional): A list containing mode settings for quantization. The first element is the mode type (e.g., "interpolate"),
             and the second element is the mode value (e.g., 0.85). Defaults to ["interpolate", .85].
         manual (bool, optional): Whether to stop shape calibration and do autoquant after a single run (default, False) or to wait for
-            the user to call model.do_autoquant (True) so inputs with several shapes/dtypes can be logged.
+            the user to call model.finalize_autoquant (True) so inputs with several shapes/dtypes can be logged.
         **aq_kwargs: Additional keyword arguments for the autoquantization process.
 
     Returns:
@@ -485,7 +485,7 @@ def autoquant(
         torchao.autoquant(model, manual=True)
         model(*example_input1)
         model(*example_input2)
-        model.do_autoquant()
+        model.finalize_autoquant()
     """
 
     # perform initial swap from linear weights
@@ -519,7 +519,7 @@ def autoquant(
     # autoquantization
     def autoquant_prehook(module, args, kwargs):
         real_model.forward(*args, **kwargs)
-        module.do_autoquant()
+        module.finalize_autoquant()
         return args, kwargs
 
     # the autoquant_prehook intercepts the forward call, performs logging then
@@ -530,7 +530,7 @@ def autoquant_prehook(module, args, kwargs):
 
     # note the torch.compile wrapper (eval_frame) moves the assignment of any assigned
     # attributes to the inner model that didn't exist before, so we have to call delattr on the inner model
-    def do_autoquant():
+    def finalize_autoquant():
         change_autoquantizable_to_quantized(
             real_model,
             **aq_kwargs,
@@ -538,12 +538,12 @@ def do_autoquant():
         if hasattr(real_model, "old_forward"):
             model.forward = real_model.old_forward
             delattr(real_model, "old_forward")
-        if hasattr(real_model, "do_autoquant"):
-            delattr(real_model, "do_autoquant")
+        if hasattr(real_model, "finalize_autoquant"):
+            delattr(real_model, "finalize_autoquant")
         if not manual:
             handle.remove()
 
-    real_model.do_autoquant = do_autoquant
+    real_model.finalize_autoquant = finalize_autoquant
 
     # if example input was provided, check it and run it
     if isinstance(example_input, torch.Tensor):
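
For reference, a minimal sketch of the renamed manual-mode flow, mirroring the updated tests and the autoquant docstring in this diff. The toy model, device, dtype, and input shapes below are illustrative assumptions; only the torchao.autoquant(..., manual=True) and finalize_autoquant() calls come from this change.

# Sketch: manual autoquant over two input shapes, finalized explicitly (illustrative model/shapes).
import torch
import torchao

# stand-in model; any nn.Module containing nn.Linear layers works
model = torch.nn.Sequential(
    torch.nn.Linear(64, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 64),
).to(device="cuda", dtype=torch.bfloat16)

example_input1 = torch.randn(1, 64, device="cuda", dtype=torch.bfloat16)
example_input2 = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)

# manual=True defers quantization until finalize_autoquant() is called,
# so inputs with several shapes/dtypes can be logged first
mod = torchao.autoquant(torch.compile(model), manual=True)
mod(example_input1)       # log first shape
mod(example_input2)       # log second shape
mod.finalize_autoquant()  # finalize quantization using everything logged so far

out = mod(example_input1)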