Skip to content

Commit 95d2c4f

Browse files
committed
Merge branch 'main' into sam2fast5
2 parents 1acc8d0 + 65098de commit 95d2c4f

32 files changed

+214
-46
lines changed

.github/workflows/regression_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
gpu-arch-version: "12.1"
4141
- name: CUDA Nightly
4242
runs-on: linux.g5.12xlarge.nvidia.gpu
43-
torch-spec: '--pre torch==2.6.0.dev20241022 --index-url https://download.pytorch.org/whl/nightly/cu121'
43+
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
4444
gpu-arch-type: "cuda"
4545
gpu-arch-version: "12.1"
4646

CITATION.cff

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
cff-version: 1.2.0
2+
title: "torchao: PyTorch native quantization and sparsity for training and inference"
3+
message: "If you use this software, please cite it as below."
4+
type: software
5+
authors:
6+
- given-names: "torchao maintainers and contributors"
7+
url: "https://github.com/pytorch/torchao"
8+
license: "BSD-3-Clause"
9+
date-released: "2024-10-25"

README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
[![](https://dcbadge.vercel.app/api/server/gpumode?style=flat)](https://discord.gg/gpumode)
44

5-
[Introduction](#introduction) | [Inference](#inference) | [Training](#training) | [Composability](#composability) | [Custom Kernels](#custom-kernels) | [Alpha Features](#alpha-features) | [Installation](#installation) | [Integrations](#integrations) | [Videos](#videos) | [License](#license)
5+
[Introduction](#introduction) | [Inference](#inference) | [Training](#training) | [Composability](#composability) | [Custom Kernels](#custom-kernels) | [Alpha Features](#alpha-features) | [Installation](#installation) | [Integrations](#integrations) | [Videos](#videos) | [License](#license) | [Citation](#citation)
66

77
## Introduction
88

@@ -192,3 +192,17 @@ We're also fortunate to be integrated into some of the leading open-source libra
192192
## License
193193

194194
`torchao` is released under the [BSD 3](https://github.com/pytorch-labs/ao/blob/main/LICENSE) license.
195+
196+
## Citation
197+
198+
If you find the torchao library useful, please cite it in your work as below.
199+
200+
```bibtex
201+
@software{torchao,
202+
title = {torchao: PyTorch native quantization and sparsity for training and inference},
203+
author = {torchao maintainers and contributors},
204+
url = {https://github.com/pytorch/torchao},
205+
license = {BSD-3-Clause},
206+
month = oct,
207+
year = {2024}
}
208+
```

test/dtypes/test_affine_quantized_float.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,39 @@ def test_serialization(self, mode: str):
236236
original_layer.weight.scale, new_layer.weight.scale
237237
), f"Scales do not match for {layer_name}"
238238

239+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
240+
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
241+
def test_fp8_weight_dimension_warning(self):
242+
# Create model with incompatible dimensions (not multiples of 16)
243+
model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights
244+
245+
# Set up logging capture
246+
with self.assertLogs(
247+
"torchao.quantization.quant_api", level="INFO"
248+
) as log_context:
249+
quantize_(
250+
model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
251+
)
252+
print(model)
253+
254+
# Verify warning messages for both layers
255+
expected_messages = [
256+
"Skipping float8 quantization: weight shape torch.Size([25, 10])",
257+
"Skipping float8 quantization: weight shape torch.Size([10, 25])",
258+
]
259+
# Check that we got warnings for both incompatible layers
260+
warning_count = sum(
261+
1 for msg in log_context.output if "Skipping float8 quantization" in msg
262+
)
263+
self.assertEqual(warning_count, 2, "Expected warnings for both linear layers")
264+
265+
# Check warning message content
266+
for expected in expected_messages:
267+
self.assertTrue(
268+
any(expected in msg for msg in log_context.output),
269+
f"Expected warning message containing: {expected}",
270+
)
271+
239272

240273
common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
241274

test/integration/test_integration.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
AutoQuantizableLinearWeight,
7575
AQFloat8WeightOnlyQuantizedLinearWeight,
7676
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
77+
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
7778
)
7879
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
7980
import os
@@ -770,11 +771,23 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
770771
@parameterized.expand(COMMON_DEVICE_DTYPE)
771772
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
772773
@unittest.skipIf(not is_H100, "Need H100 to run")
773-
def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
774+
def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
774775
if dtype != torch.bfloat16:
775-
self.skipTest("Fails for {dtype}")
776+
with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"):
777+
self._test_lin_weight_subclass_impl(
778+
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
779+
)
780+
else:
781+
self._test_lin_weight_subclass_impl(
782+
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
783+
)
784+
785+
@parameterized.expand(COMMON_DEVICE_DTYPE)
786+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
787+
@unittest.skipIf(not is_H100, "Need H100 to run")
788+
def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
776789
self._test_lin_weight_subclass_impl(
777-
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
790+
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
778791
)
779792

780793
@parameterized.expand(COMMON_DEVICE_DTYPE)

test/quantization/test_mixed_precision.py renamed to test/prototype/test_mixed_precision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch.nn as nn
55
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
66
from torchao.quantization.utils import compute_error
7-
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
7+
from torchao.prototype.quantization.mixed_precision.scripts import intN_weight_only
88

99
_CUDA_IS_AVAILABLE = torch.cuda.is_available()
1010

torchao/_models/llama/benchmarks.sh

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,4 +64,18 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
6464
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
6565
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
6666
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
67-
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
67+
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
68+
69+
# Different Batch Size Benchmarks
70+
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
71+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 1
72+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 32
73+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 128
74+
75+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 1
76+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 32
77+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 128
78+
79+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
80+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
81+
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128

torchao/_models/llama/eval.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,13 @@ def run_evaluation(
104104
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
105105
if "int4wo" in quantization and "gptq" in quantization:
106106
# avoid circular imports
107-
from torchao._models._eval import InputRecorder
108-
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
107+
from torchao._models._eval import MultiTensorInputRecorder
108+
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer
109109
groupsize=int(quantization.split("-")[-2])
110110
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
111111
assert precision==torch.bfloat16, f"{quantization} requires precision or bfloat16 but got {precision}"
112112
assert "cuda" in device, "int4 gptq quantization only works on cuda"
113-
inputs = InputRecorder(
113+
inputs = MultiTensorInputRecorder(
114114
tokenizer,
115115
calibration_seq_length,
116116
prepare_inputs_for_model,
@@ -122,7 +122,7 @@ def run_evaluation(
122122
calibration_limit,
123123
).get_inputs()
124124

125-
quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device)
125+
quantizer = Int4WeightOnlyGPTQQuantizer(group_size=groupsize, device=device)
126126
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
127127
model = quantizer.quantize(model, inputs).to(device)
128128
else:

torchao/_models/llama/generate.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
4848
return probs
4949

5050
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
51-
probs = logits_to_probs(logits[0, -1], temperature, top_k)
51+
probs = logits_to_probs(logits[:, -1], temperature, top_k)
5252
idx_next = multinomial_sample_one_no_sync(probs)
5353
return idx_next, probs
5454

@@ -75,7 +75,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
7575
new_tokens.append(next_token)
7676
callback(new_tokens[-1])
7777
new_probs.append(next_prob)
78-
cur_token = next_token.view(1, -1)
78+
cur_token = next_token
7979

8080
return new_tokens, new_probs
8181

@@ -88,6 +88,7 @@ def generate(
8888
model: Transformer,
8989
prompt: torch.Tensor,
9090
max_new_tokens: int,
91+
batch_size: int,
9192
*,
9293
interactive: bool,
9394
callback = lambda x: x,
@@ -102,34 +103,34 @@ def generate(
102103

103104
# create an empty tensor of the expected final shape and fill in the current tokens
104105
device = prompt.device
105-
T = prompt.numel()
106+
T = prompt.size(-1)
106107

107108
# calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
108109
max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
109110
new_tokens = max_seq_length - T
110111

112+
# format model input
113+
prompt, input_pos = prepare_inputs_for_model(prompt)
114+
prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize
115+
111116
# full prompt+output will be stored in seq
112-
seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
113-
seq[:T] = prompt.view(-1)
117+
seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
118+
seq[:, :T] = prompt
114119

115120
# setup model caches
116121
with torch.device(device):
117122
if cache_size is None:
118123
cache_size = max_seq_length
119124
assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt"
120-
model.setup_caches(max_batch_size=1, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
121-
122-
# format model input
123-
x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
125+
model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
124126

125127
# execute prefill
126-
next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
127-
seq[T] = next_token
128+
next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
129+
seq[:, T] = next_token.squeeze()
128130
# execute token generation
129131
input_pos = torch.tensor([T], device=device, dtype=torch.int)
130-
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
131-
132-
seq = torch.cat((seq[:T+1], *generated_tokens))
132+
generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
133+
seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)
133134

134135
return seq
135136

@@ -157,6 +158,7 @@ def main(
157158
interactive: bool = False,
158159
num_samples: int = 5,
159160
max_new_tokens: int = 100,
161+
batch_size: int = 1,
160162
top_k: int = 200,
161163
temperature: float = 0.8,
162164
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -229,9 +231,9 @@ def main(
229231
use_hqq=True
230232
else:
231233
use_hqq=False
232-
groupsize=int(quantization.split("-")[1])
233-
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
234-
quantize_(model, int4_weight_only(group_size=groupsize))
234+
group_size=int(quantization.split("-")[1])
235+
assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
236+
quantize_(model, int4_weight_only(group_size=group_size))
235237
if "marlin" in quantization:
236238
from torchao.dtypes import MarlinSparseLayout
237239
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
@@ -267,9 +269,9 @@ def main(
267269
use_hqq = "hqq" in quantization
268270
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear)
269271
if "uintx" in quantization:
270-
# uintx-nbits-groupsize, e.g. "uintx-2-64"
272+
# uintx-nbits-group_size, e.g. "uintx-2-64"
271273
if "hqq" in quantization:
272-
# uintx-nbits-groupsize-hqq
274+
# uintx-nbits-group_size-hqq
273275
use_hqq = True
274276
else:
275277
use_hqq = False
@@ -303,6 +305,7 @@ def main(
303305
model,
304306
encode_tokens(tokenizer, prompt, bos=True, device=device),
305307
max_new_tokens,
308+
batch_size,
306309
interactive=False,
307310
temperature=temperature,
308311
top_k=top_k,
@@ -375,6 +378,7 @@ def callback(x):
375378
model,
376379
encoded,
377380
max_new_tokens,
381+
batch_size,
378382
interactive=interactive,
379383
callback=callback,
380384
temperature=temperature,
@@ -392,13 +396,13 @@ def callback(x):
392396
t = time.perf_counter() - t0
393397

394398
if not interactive:
395-
tok_list = y.tolist()
399+
tok_list = y[0].tolist()
396400
# truncate text after end of string token
397-
tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
401+
tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
398402
print(tokenizer.decode(tokens))
399403
else:
400404
print()
401-
tokens_generated = y.size(0) - prompt_length
405+
tokens_generated = (y.size(-1) - prompt_length)
402406
tokens_sec = tokens_generated / t
403407
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
404408
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
@@ -421,6 +425,8 @@ def callback(x):
421425
bandwidth = model_size * tokpersec
422426
mem = torch.cuda.max_memory_reserved() /1e9
423427
print(f"Average tokens/sec: {tokpersec:.2f}")
428+
if batch_size > 1:
429+
print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
424430
print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
425431
print(f"Peak Memory Usage: {mem:.02f} GB")
426432
print(f"Model Size: {model_size:.02f} GB")
@@ -439,6 +445,7 @@ def callback(x):
439445
result_txt += f"--interactive " if interactive else ""
440446
result_txt += f"--num_samples {num_samples} "
441447
result_txt += f"--max_new_tokens {max_new_tokens} "
448+
result_txt += f"--batch_size {batch_size} "
442449
result_txt += f"--top_k {top_k} "
443450
result_txt += f"--temperature {temperature} "
444451
result_txt += f"--cache_size {cache_size}" if cache_size else ""
@@ -459,13 +466,15 @@ def callback(x):
459466
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
460467
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
461468
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
469+
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
462470
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
463471
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
464472
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
465473
parser.add_argument('-q', '--quantization', type=str,
466474
help=(
467475
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
468-
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
476+
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
477+
+'embed-int8wo'
469478
)
470479
)
471480
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
@@ -484,6 +493,6 @@ def callback(x):
484493

485494
args = parser.parse_args()
486495
main(
487-
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
496+
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
488497
args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
489498
)

torchao/float8/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ This is the most accurate recipe as every tensor is scaled dynamically.
2525
import torch
2626
import torch.nn as nn
2727
from torchao.float8 import convert_to_float8_training
28+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
29+
30+
if not TORCH_VERSION_AT_LEAST_2_5:
31+
raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
2832

2933
# create model and sample input
3034
m = nn.Sequential(
@@ -73,6 +77,10 @@ from torchao.float8 import (
7377
ScalingType,
7478
CastConfig,
7579
)
80+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
81+
82+
if not TORCH_VERSION_AT_LEAST_2_5:
83+
raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")
7684

7785
# create model and sample input
7886
m = nn.Sequential(

0 commit comments

Comments
 (0)