Merge branch 'pytorch:main' into adam4bit_doc
gau-nernst authored Sep 2, 2024
2 parents 4cea083 + ba2d3b1 commit d83a1c1
Showing 86 changed files with 4,064 additions and 976 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/ruff_linter.yml
@@ -0,0 +1,30 @@
name: Code Analysis with Ruff

on:
  push:
    branches:
      - main
  pull_request:

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.9"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff
      - name: Analyzing the code with ruff
        run: |
          ruff check .
      - name: Check well formatted code
        run: |
          ruff format --check
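The same two checks can be reproduced locally before pushing; a minimal sketch mirroring the workflow steps above:

```sh
pip install ruff
ruff check .          # lint, same as the "Analyzing the code with ruff" step
ruff format --check   # verify formatting without rewriting files
```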
20 changes: 20 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,20 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files

  - repo: https://github.com/astral-sh/ruff-pre-commit
    # Ruff version.
    rev: v0.5.6
    hooks:
      # Run the linter.
      - id: ruff
        args: [--fix]
      # Run the formatter.
      - id: ruff-format
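To get these hooks running locally, the standard pre-commit workflow applies; a minimal sketch:

```sh
pip install pre-commit
pre-commit install            # run the hooks automatically on every git commit
pre-commit run --all-files    # or run the full suite once across the repo
```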
11 changes: 9 additions & 2 deletions README.md
@@ -27,7 +27,14 @@ For inference, we have the option of
 2. Quantize the activations and weights and sparsify the weight

 ```python
-from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int4_weight, int8_dynamic_activation_int8_weight, int8_dynamic_activation_int8_semi_sparse_weight, int4_weight_only, int8_weight_only
+from torchao.quantization.quant_api import (
+    quantize_,
+    int8_dynamic_activation_int4_weight,
+    int8_dynamic_activation_int8_weight,
+    int8_dynamic_activation_int8_semi_sparse_weight,
+    int4_weight_only,
+    int8_weight_only
+)
 quantize_(m, int4_weight_only())
 ```
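For context, the README snippet assumes a model `m` already exists; a minimal self-contained sketch (toy model, shapes arbitrary; int4 weight-only expects bfloat16 weights on CUDA):

```python
import torch
from torchao.quantization.quant_api import quantize_, int4_weight_only

# toy model; quantize_ swaps the nn.Linear weights in place
m = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024),
).cuda().to(torch.bfloat16)

quantize_(m, int4_weight_only())
y = m(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))
```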

@@ -95,7 +102,7 @@ from torchao.prototype.low_bit_optim import AdamW8bit, AdamW4bit, AdamWFp8
 optim = AdamW8bit(model.parameters()) # replace with Adam4bit and AdamFp8 for the 4 / fp8 versions
 ```

-In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code ** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim)
+In practice, we are a tiny bit slower than expertly written kernels but the implementations for these optimizers were written in a **few hundred lines of PyTorch code** and compiled so please use them or copy-paste them for your quantized optimizers. Benchmarks [here](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim)

 We also have support for [single GPU CPU offloading](https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim#optimizer-cpu-offload) where both the gradients (same size as weights) and the optimizers will be efficiently sent to the CPU. This alone can **reduce your VRAM requirements by 60%**

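As a sketch of the CPU-offload feature mentioned in the context line above (the `CPUOffloadOptimizer` name and its `offload_gradients` flag are taken from the linked low_bit_optim README, so treat both as assumptions rather than a verified API):

```python
# Hypothetical sketch: keep optimizer state (and gradients) in CPU memory.
import torch
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = torch.nn.Linear(4096, 4096).cuda()
optim = CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,        # base optimizer to wrap
    offload_gradients=True,   # assumed flag: also keep grads on CPU
)
```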
@@ -1,16 +1,15 @@
 import torch
 import pandas as pd
 import torch.nn.functional as F
-from torchao.prototype.quant_llm import QuantLlmLinearWeight
+from torchao.dtypes import to_affine_quantized_fpx
+from torchao.dtypes.fpx import FpxTensorCoreAQTLayout, FpxTensorCoreLayoutType
 from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm


 def benchmark(m: int, k: int, n: int):
-    fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
-    scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
-    fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)
-
+    float_data = torch.randn(n, k, dtype=torch.half, device="cuda")
+    fp6_weight = to_affine_quantized_fpx(float_data, FpxTensorCoreLayoutType(3, 2))
     fp16_weight = fp6_weight.dequantize(torch.half)

     fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda")
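The new code path in the hunk above can be exercised in isolation; a minimal sketch using the same calls (shapes arbitrary), where `FpxTensorCoreLayoutType(3, 2)` selects FP6 with 3 exponent and 2 mantissa bits:

```python
# Sketch: quantize an fp16 weight to FP6 (E3M2) and measure reconstruction error.
import torch
from torchao.dtypes import to_affine_quantized_fpx
from torchao.dtypes.fpx import FpxTensorCoreLayoutType

w = torch.randn(256, 1024, dtype=torch.half, device="cuda")
w_fp6 = to_affine_quantized_fpx(w, FpxTensorCoreLayoutType(3, 2))
err = (w_fp6.dequantize(torch.half) - w).abs().mean()
print(f"mean abs error: {err.item():.5f}")
```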
31 changes: 28 additions & 3 deletions benchmarks/float8/bench_linear_float8.py
@@ -91,6 +91,8 @@ def float8_pct_top_peak(self):
         return self.float8_tops_sec / dtype_to_peak_tops[torch.float8_e4m3fn]


+# TODO(future PR): add option to measure GPU kernel time, as in other
+# scripts in this folder
 def main(
     sweep_path: Optional[Path] = None,
     compile: bool = True,
@@ -112,10 +114,33 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+
+    if scaling_type_input is ScalingType.STATIC:
+        cast_config_input = CastConfig(
+            scaling_type=scaling_type_input,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_input = CastConfig(scaling_type=scaling_type_input)
+    if scaling_type_weight is ScalingType.STATIC:
+        cast_config_weight = CastConfig(
+            scaling_type=scaling_type_weight,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
+    if scaling_type_grad_output is ScalingType.STATIC:
+        cast_config_grad_output = CastConfig(
+            scaling_type=scaling_type_grad_output,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
+
     config = Float8LinearConfig(
-        cast_config_input=CastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+        cast_config_input=cast_config_input,
+        cast_config_weight=cast_config_weight,
+        cast_config_grad_output=cast_config_grad_output,
     )

     name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
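The three if/else blocks added above (and repeated verbatim in profile_linear_float8.py below) follow one pattern, so a small helper could replace them; a refactoring sketch only (helper name hypothetical, import path assumed):

```python
import torch
from torchao.float8.config import CastConfig, ScalingType  # import path assumed

def make_cast_config(scaling_type: ScalingType) -> CastConfig:
    # STATIC scaling requires a pre-computed scale; [1.0] mirrors the
    # placeholder the benchmarks use, not a calibrated value.
    if scaling_type is ScalingType.STATIC:
        return CastConfig(
            scaling_type=scaling_type,
            static_scale=torch.tensor([1.0], device="cuda"),
        )
    return CastConfig(scaling_type=scaling_type)
```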
32 changes: 27 additions & 5 deletions benchmarks/float8/profile_linear_float8.py
@@ -263,13 +263,35 @@ def main(
     scaling_type_input = ScalingType(scaling_type_input)
     scaling_type_weight = ScalingType(scaling_type_weight)
     scaling_type_grad_output = ScalingType(scaling_type_grad_output)
+
+    if scaling_type_input is ScalingType.STATIC:
+        cast_config_input = CastConfig(
+            scaling_type=scaling_type_input,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_input = CastConfig(scaling_type=scaling_type_input)
+    if scaling_type_weight is ScalingType.STATIC:
+        cast_config_weight = CastConfig(
+            scaling_type=scaling_type_weight,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_weight = CastConfig(scaling_type=scaling_type_weight)
+    if scaling_type_grad_output is ScalingType.STATIC:
+        cast_config_grad_output = CastConfig(
+            scaling_type=scaling_type_grad_output,
+            static_scale=torch.tensor([1.0], device="cuda"),
+        )
+    else:
+        cast_config_grad_output = CastConfig(scaling_type=scaling_type_grad_output)
+
     config = Float8LinearConfig(
-        cast_config_input=CastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
         enable_amax_init=False,
         enable_pre_and_post_forward=False,
+        cast_config_input=cast_config_input,
+        cast_config_weight=cast_config_weight,
+        cast_config_grad_output=cast_config_grad_output,
     )

     scaling_repr = "_".join(
         [
             s.short_str()
4 changes: 4 additions & 0 deletions dev-requirements.txt
@@ -17,3 +17,7 @@ tabulate # QOL for printing tables to stdout

 # Custom CUDA Extensions
 ninja
+
+# Linting
+ruff
+pre-commit
4 changes: 3 additions & 1 deletion docs/source/api_ref_dtypes.rst
@@ -11,7 +11,9 @@ torchao.dtypes
     :nosignatures:

     to_nf4
-    to_affine_quantized
+    to_affine_quantized_intx
+    to_affine_quantized_floatx
+    to_affine_quantized_intx_static
     AffineQuantizedTensor

 ..
8 changes: 8 additions & 0 deletions ruff.toml
@@ -0,0 +1,8 @@
# Currently, we ignore all files in the project by default.
# We plan to add files in chunks using the 'include' list below.
# To add a new path: Simply add it to the 'include' list.
# Example: To lint all files in every subfolder of 'test', add "test/**/*"
include = [
    "torchao/float8/inference.py",
    "torchao/float8/float8_utils.py",
]
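Following the file's own example comment, opting a new path in would look like this (hypothetical extension):

```toml
include = [
    "torchao/float8/inference.py",
    "torchao/float8/float8_utils.py",
    "test/**/*",  # hypothetical: lint all files in every subfolder of 'test'
]
```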
7 changes: 5 additions & 2 deletions scripts/hf_eval.py
@@ -20,6 +20,7 @@
     int8_dynamic_activation_int8_weight,
     quantize_,
     autoquant,
+    fpx_weight_only,
 )
 from torchao.sparsity import (
     sparsify_,
@@ -59,6 +60,8 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, spars
     elif quantization == "int4wo":
         # note cannot quantize this model on cpu and run it on cuda at this time
         quantize_(model.to(device=device), int4_weight_only())
+    elif quantization == "fp6":
+        quantize_(model, fpx_weight_only(3, 2))
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
@@ -79,7 +82,7 @@ def all_linear(mod, name):
         return False
     torch.sparse.semi_structured._FORCE_CUTLASS = False
     sparsify_(model, semi_sparse_weight(), filter_fn=all_linear)
-
+
     if sparsity and compile:
         model = torch.compile(model, mode="max-autotune", fullgraph=True)
@@ -111,7 +114,7 @@ def all_linear(mod, name):
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
     parser.add_argument('-s', '--sparsity', default = "None", choices=["semi_sparse", "semi_sparse_mlp_only", "None"], help='Which sparsity technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--save', action='store_true', help='Whether to save the model.')
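The new `fp6` branch maps to `fpx_weight_only(3, 2)`, i.e. FP6 with 3 exponent and 2 mantissa bits. A rough invocation sketch (the `--repo_id` flag name is assumed from `run_evaluation`'s signature; note that as committed, `fp6` would still need to be added to the `-q` choices above for argparse to accept it):

```sh
python scripts/hf_eval.py --repo_id meta-llama/Llama-2-7b-chat-hf -q fp6 --compile
```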
2 changes: 2 additions & 0 deletions scripts/prepare.sh
@@ -1,4 +1,6 @@
 python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
+python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
74 changes: 47 additions & 27 deletions test/dtypes/test_affine_quantized.py
@@ -8,16 +8,36 @@
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_dynamic_activation_int8_semi_sparse_weight,
+    float8_weight_only,
 )
 from torchao.dtypes import (
     to_affine_quantized,
 )
 from torch.testing._internal import common_utils
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

 import torch
 import unittest
 import tempfile

+is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+
+
+def get_quantization_functions(do_sparse: bool, do_int4: bool):
+    base_functions = [
+        int8_weight_only(),
+        int8_dynamic_activation_int4_weight(),
+        int8_dynamic_activation_int8_weight(),
+    ]
+    if do_int4:
+        base_functions.append(int4_weight_only(group_size=32))
+
+    if do_sparse:
+        base_functions.append(int8_dynamic_activation_int8_semi_sparse_weight())
+
+    if is_cuda_8_9:
+        base_functions.append(float8_weight_only())
+
+    return base_functions

 class TestAffineQuantized(TestCase):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_tensor_core_layout_transpose(self):
@@ -39,36 +59,36 @@ def test_tensor_core_layout_transpose(self):
         self.assertEqual(aqt_shape, shape)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_weights_only(self):
-        for apply_quant in [int4_weight_only(group_size=32), int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight(), int8_dynamic_activation_int8_semi_sparse_weight()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
-            ql = apply_quant(l)
-            with tempfile.NamedTemporaryFile() as f:
-                torch.save(ql.state_dict(), f)
-                f.seek(0)
-                # `weights_only=True` is enabled for torch 2.5+
-                if TORCH_VERSION_AT_LEAST_2_5:
-                    _ = torch.load(f, weights_only=True)
-                else:
-                    _ = torch.load(f, weights_only=False)
+    @common_utils.parametrize("apply_quant", get_quantization_functions(True, True))
+    def test_weights_only(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
+        ql = apply_quant(l)
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(ql.state_dict(), f)
+            f.seek(0)
+            # `weights_only=True` is enabled for torch 2.5+
+            if TORCH_VERSION_AT_LEAST_2_5:
+                _ = torch.load(f, weights_only=True)
+            else:
+                _ = torch.load(f, weights_only=False)

     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    def test_to_device(self):
-        from torchao.quantization import quantize_
-        for apply_quant in [int8_weight_only(), int8_dynamic_activation_int4_weight(), int8_dynamic_activation_int8_weight()]:
-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to("cuda")
+    @common_utils.parametrize("apply_quant", get_quantization_functions(False, False))
+    def test_to_device(self, apply_quant):
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to("cuda")

-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.to(device="cuda")
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.to(device="cuda")

-            l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            ql = apply_quant(l)
-            ql.cuda()
+        l = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
+        ql = apply_quant(l)
+        ql.cuda()


+common_utils.instantiate_parametrized_tests(TestAffineQuantized)
+
 if __name__ == "__main__":
     run_tests()
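With the parametrization in place, each quantization function becomes its own test case; a typical local run (CUDA required):

```sh
pytest test/dtypes/test_affine_quantized.py -v
# or via the script's own entry point:
python test/dtypes/test_affine_quantized.py
```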
