
Commit

Merge branch 'main' into sam2fast5
cpuhrsch committed Oct 29, 2024
2 parents 1acc8d0 + 65098de commit 95d2c4f
Showing 32 changed files with 214 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/regression_test.yml
@@ -40,7 +40,7 @@ jobs:
gpu-arch-version: "12.1"
- name: CUDA Nightly
runs-on: linux.g5.12xlarge.nvidia.gpu
torch-spec: '--pre torch==2.6.0.dev20241022 --index-url https://download.pytorch.org/whl/nightly/cu121'
torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
gpu-arch-type: "cuda"
gpu-arch-version: "12.1"

9 changes: 9 additions & 0 deletions CITATION.cff
@@ -0,0 +1,9 @@
cff-version: 1.2.0
title: "torchao: PyTorch native quantization and sparsity for training and inference"
message: "If you use this software, please cite it as below."
type: software
authors:
- given-names: "torchao maintainers and contributors"
url: "https//github.com/pytorch/torchao"
license: "BSD-3-Clause"
date-released: "2024-10-25"
16 changes: 15 additions & 1 deletion README.md
@@ -2,7 +2,7 @@

[![](https://dcbadge.vercel.app/api/server/gpumode?style=flat)](https://discord.gg/gpumode)

[Introduction](#introduction) | [Inference](#inference) | [Training](#training) | [Composability](#composability) | [Custom Kernels](#custom-kernels) | [Alpha Features](#alpha-features) | [Installation](#installation) | [Integrations](#integrations) | [Videos](#videos) | [License](#license)
[Introduction](#introduction) | [Inference](#inference) | [Training](#training) | [Composability](#composability) | [Custom Kernels](#custom-kernels) | [Alpha Features](#alpha-features) | [Installation](#installation) | [Integrations](#integrations) | [Videos](#videos) | [License](#license) | [Citation](#citation)

## Introduction

@@ -192,3 +192,17 @@ We're also fortunate to be integrated into some of the leading open-source libra
## License

`torchao` is released under the [BSD 3](https://github.com/pytorch-labs/ao/blob/main/LICENSE) license.

## Citation

If you find the torchao library useful, please cite it in your work as below.

```bibtex
@software{torchao,
title = {torchao: PyTorch native quantization and sparsity for training and inference},
author = {torchao maintainers and contributors},
url = {https://github.com/pytorch/torchao},
license = {BSD-3-Clause},
month = oct,
year = {2024}
}
```
33 changes: 33 additions & 0 deletions test/dtypes/test_affine_quantized_float.py
@@ -236,6 +236,39 @@ def test_serialization(self, mode: str):
original_layer.weight.scale, new_layer.weight.scale
), f"Scales do not match for {layer_name}"

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
def test_fp8_weight_dimension_warning(self):
# Create model with incompatible dimensions (not multiples of 16)
model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights

# Set up logging capture
with self.assertLogs(
"torchao.quantization.quant_api", level="INFO"
) as log_context:
quantize_(
model, float8_dynamic_activation_float8_weight(granularity=PerTensor())
)
print(model)

# Verify warning messages for both layers
expected_messages = [
"Skipping float8 quantization: weight shape torch.Size([25, 10])",
"Skipping float8 quantization: weight shape torch.Size([10, 25])",
]
# Check that we got warnings for both incompatible layers
warning_count = sum(
1 for msg in log_context.output if "Skipping float8 quantization" in msg
)
self.assertEqual(warning_count, 2, "Expected warnings for both linear layers")

# Check warning message content
for expected in expected_messages:
self.assertTrue(
any(expected in msg for msg in log_context.output),
f"Expected warning message containing: {expected}",
)


common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)
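A side note on the rule this new test exercises: `float8_dynamic_activation_float8_weight` skips `nn.Linear` layers whose weight dimensions are not multiples of 16 and logs a "Skipping float8 quantization" message instead of erroring. A minimal editorial sketch of that dimension check (the helper below is illustrative, not the library's internal implementation):

```python
import torch.nn as nn

def float8_dim_compatible(linear: nn.Linear) -> bool:
    # float8 matmul kernels expect in/out features divisible by 16
    out_features, in_features = linear.weight.shape
    return out_features % 16 == 0 and in_features % 16 == 0

# ToyLinearModel(10, 25) in the test yields weights of shape (25, 10) and (10, 25),
# so both layers are skipped and a log message is emitted for each
print(float8_dim_compatible(nn.Linear(10, 25)))    # False
print(float8_dim_compatible(nn.Linear(128, 256)))  # True
```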

19 changes: 16 additions & 3 deletions test/integration/test_integration.py
@@ -74,6 +74,7 @@
AutoQuantizableLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight,
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
@@ -770,11 +771,23 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
def test_aq_float8_dynamic_quant_rowwise_scaling_subclass(self, device, dtype):
if dtype != torch.bfloat16:
self.skipTest(f"Fails for {dtype}")
with self.assertRaisesRegex(AssertionError, "PerRow quantization only works for bfloat16 precision"):
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)
else:
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_tensorwise_scaling_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8PerRowScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
AQFloat8PerTensorScalingDynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)

@@ -4,7 +4,7 @@
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only, int4_weight_only
from torchao.quantization.utils import compute_error
from torchao.quantization.prototype.mixed_precision.scripts.naive_intNwo import intN_weight_only
from torchao.prototype.quantization.mixed_precision.scripts import intN_weight_only

_CUDA_IS_AVAILABLE = torch.cuda.is_available()

16 changes: 15 additions & 1 deletion torchao/_models/llama/benchmarks.sh
@@ -64,4 +64,18 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
# python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt

# Different Batch Size Benchmarks
export MODEL_REPO=meta-llama/Meta-Llama-3-8B
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8dq --write_result benchmark_results.txt --batch_size 128

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization int8wo --write_result benchmark_results.txt --batch_size 128

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128
8 changes: 4 additions & 4 deletions torchao/_models/llama/eval.py
@@ -104,13 +104,13 @@ def run_evaluation(
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
if "int4wo" in quantization and "gptq" in quantization:
# avoid circular imports
from torchao._models._eval import InputRecorder
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer
from torchao._models._eval import MultiTensorInputRecorder
from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer
groupsize=int(quantization.split("-")[-2])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
assert precision==torch.bfloat16, f"{quantization} requires precision of bfloat16 but got {precision}"
assert "cuda" in device, "int4 gptq quantization only works on cuda"
inputs = InputRecorder(
inputs = MultiTensorInputRecorder(
tokenizer,
calibration_seq_length,
prepare_inputs_for_model,
Expand All @@ -122,7 +122,7 @@ def run_evaluation(
calibration_limit,
).get_inputs()

quantizer = Int4WeightOnlyGPTQQuantizer(groupsize=groupsize, device=device)
quantizer = Int4WeightOnlyGPTQQuantizer(group_size=groupsize, device=device)
model.setup_caches(max_batch_size=1, max_seq_length=calibration_seq_length)
model = quantizer.quantize(model, inputs).to(device)
else:
57 changes: 33 additions & 24 deletions torchao/_models/llama/generate.py
@@ -48,7 +48,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
return probs

def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
probs = logits_to_probs(logits[0, -1], temperature, top_k)
probs = logits_to_probs(logits[:, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs

@@ -75,7 +75,7 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
new_tokens.append(next_token)
callback(new_tokens[-1])
new_probs.append(next_prob)
cur_token = next_token.view(1, -1)
cur_token = next_token

return new_tokens, new_probs

@@ -88,6 +88,7 @@ def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
batch_size: int,
*,
interactive: bool,
callback = lambda x: x,
@@ -102,34 +102,34 @@

# create an empty tensor of the expected final shape and fill in the current tokens
device = prompt.device
T = prompt.numel()
T = prompt.size(-1)

# calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size)
max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350
new_tokens = max_seq_length - T

# format model input
prompt, input_pos = prepare_inputs_for_model(prompt)
prompt = prompt.repeat(batch_size, 1) # expand prompt based on batchsize

# full prompt+output will be stored in seq
seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device)
seq[:T] = prompt.view(-1)
seq = torch.empty(batch_size, max_seq_length, dtype=prompt.dtype, device=device)
seq[:, :T] = prompt

# setup model caches
with torch.device(device):
if cache_size is None:
cache_size = max_seq_length
assert cache_size >= max_seq_length, "need cache_size to be greater than max_new_tokens + size-of-prompt"
model.setup_caches(max_batch_size=1, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

# format model input
x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens)
model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)

# execute prefill
next_token = prefill(model, x, input_pos, **sampling_kwargs).clone()
seq[T] = next_token
next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
seq[:, T] = next_token.squeeze()
# execute token generation
input_pos = torch.tensor([T], device=device, dtype=torch.int)
generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)

seq = torch.cat((seq[:T+1], *generated_tokens))
generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)

return seq
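
For orientation, a quick editorial sketch of the shape change behind `logits[:, -1]` and the batched `seq` buffer above, using toy tensors only (no model involved; all sizes are arbitrary):

```python
import torch

batch_size, prompt_len, vocab = 4, 7, 32
logits = torch.randn(batch_size, prompt_len, vocab)

# old path: logits[0, -1] sampled only the first sequence's final position
# new path: logits[:, -1] keeps the batch dimension -> shape (batch_size, vocab)
probs = torch.softmax(logits[:, -1], dim=-1)

# one next token per batch element -> shape (batch_size, 1)
next_token = torch.multinomial(probs, num_samples=1)

# seq is likewise allocated per batch element: (batch_size, max_seq_length)
max_seq_length = prompt_len + 3
seq = torch.empty(batch_size, max_seq_length, dtype=torch.long)
seq[:, prompt_len] = next_token.squeeze()
print(next_token.shape, seq.shape)  # torch.Size([4, 1]) torch.Size([4, 10])
```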

@@ -157,6 +158,7 @@ def main(
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
batch_size: int = 1,
top_k: int = 200,
temperature: float = 0.8,
checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
@@ -229,9 +231,9 @@ def main(
use_hqq=True
else:
use_hqq=False
groupsize=int(quantization.split("-")[1])
assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
quantize_(model, int4_weight_only(group_size=groupsize))
group_size=int(quantization.split("-")[1])
assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size))
if "marlin" in quantization:
from torchao.dtypes import MarlinSparseLayout
quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
@@ -267,9 +269,9 @@ def main(
use_hqq = "hqq" in quantization
quantize_(model, awq_uintx(quant_dtype=quant_dtype, group_size = group_size, use_hqq=use_hqq), is_observed_linear)
if "uintx" in quantization:
# uintx-nbits-groupsize, e.g. "uintx-2-64"
# uintx-nbits-group_size, e.g. "uintx-2-64"
if "hqq" in quantization:
# uintx-nbits-groupsize-hqq
# uintx-nbits-group_size-hqq
use_hqq = True
else:
use_hqq = False
@@ -303,6 +305,7 @@ def main(
model,
encode_tokens(tokenizer, prompt, bos=True, device=device),
max_new_tokens,
batch_size,
interactive=False,
temperature=temperature,
top_k=top_k,
@@ -375,6 +378,7 @@ def callback(x):
model,
encoded,
max_new_tokens,
batch_size,
interactive=interactive,
callback=callback,
temperature=temperature,
@@ -392,13 +396,13 @@ def callback(x):
t = time.perf_counter() - t0

if not interactive:
tok_list = y.tolist()
tok_list = y[0].tolist()
# truncate text after end of string token
tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())]
tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())]
print(tokenizer.decode(tokens))
else:
print()
tokens_generated = y.size(0) - prompt_length
tokens_generated = (y.size(-1) - prompt_length)
tokens_sec = tokens_generated / t
aggregate_metrics['tokens_per_sec'].append(tokens_sec)
print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
@@ -421,6 +425,8 @@ def callback(x):
bandwidth = model_size * tokpersec
mem = torch.cuda.max_memory_reserved() /1e9
print(f"Average tokens/sec: {tokpersec:.2f}")
if batch_size > 1:
print(f"Average tokens/sec including batches {batch_size*tokpersec:.2f}")
print(f"Average Bandwidth: {bandwidth:.02f} GB/s")
print(f"Peak Memory Usage: {mem:.02f} GB")
print(f"Model Size: {model_size:.02f} GB")
@@ -439,6 +445,7 @@ def callback(x):
result_txt += f"--interactive " if interactive else ""
result_txt += f"--num_samples {num_samples} "
result_txt += f"--max_new_tokens {max_new_tokens} "
result_txt += f"--batch_size {batch_size} "
result_txt += f"--top_k {top_k} "
result_txt += f"--temperature {temperature} "
result_txt += f"--cache_size {cache_size}" if cache_size else ""
@@ -459,13 +466,15 @@ def callback(x):
parser.add_argument('--interactive', action='store_true', help='Whether to launch in interactive mode')
parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.')
parser.add_argument('--max_new_tokens', type=int, default=200, help='Maximum number of new tokens.')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size to benchmark with')
parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
parser.add_argument('-q', '--quantization', type=str,
help=(
'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, embed-int8wo'
+'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
+'embed-int8wo'
)
)
parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
Expand All @@ -484,6 +493,6 @@ def callback(x):

args = parser.parse_args()
main(
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k,
args.temperature, args.checkpoint_path, args.quantization, args.calibration_limit, args.calibration_seq_length, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result
)
8 changes: 8 additions & 0 deletions torchao/float8/README.md
@@ -25,6 +25,10 @@ This is the most accurate recipe as every tensor is scaled dynamically.
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
@@ -73,6 +77,10 @@ from torchao.float8 import (
ScalingType,
CastConfig,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# create model and sample input
m = nn.Sequential(
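
The README hunk above is truncated by the diff view; for context, here is a minimal end-to-end sketch of the dynamic-scaling recipe it documents, based on the `convert_to_float8_training` API shown in the snippet (the model sizes, the `module_filter_fn`, and the training loop below are illustrative, and a float8-capable CUDA GPU is assumed):

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    raise AssertionError("torchao.float8 requires PyTorch version 2.5 or greater")

# toy model and sample input; dimensions are multiples of 16 so every linear qualifies
m = nn.Sequential(nn.Linear(2048, 4096), nn.Linear(4096, 128)).bfloat16().cuda()
x = torch.randn(4096, 2048, device="cuda", dtype=torch.bfloat16)
optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

# optional: decide per-module whether to convert (here: convert everything)
def module_filter_fn(mod: torch.nn.Module, fqn: str) -> bool:
    return True

# swap eligible nn.Linear modules for float8 training with dynamic scaling
convert_to_float8_training(m, module_filter_fn=module_filter_fn)

# torch.compile is recommended for competitive float8 performance
m = torch.compile(m)

# toy training loop
for _ in range(2):
    optimizer.zero_grad()
    y = m(x)
    y.sum().backward()
    optimizer.step()
```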