Merge remote-tracking branch 'origin/main' into lint_fixes_test_quantization
jainapurva committed Dec 1, 2024
2 parents 23778a7 + 22bec74 commit f64be84
Showing 16 changed files with 177 additions and 104 deletions.
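Nearly all of the changes below are one mechanical refactor: per-file GPU-capability constants such as `is_cuda_8_9`, `is_cuda_9_0`, and `is_H100` are deleted in favor of shared helpers imported from `torchao.utils`. Reconstructed from the inline definitions being removed in these diffs, the helpers presumably look roughly like the following sketch (the actual `torchao.utils` implementation may differ in details):

```python
import torch


def is_sm_at_least_89() -> bool:
    # Compute capability >= 8.9 (Ada, e.g. RTX 4090): the minimum for the
    # hardware float8 support these tests exercise.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def is_sm_at_least_90() -> bool:
    # Compute capability >= 9.0 (Hopper, e.g. H100): required for the
    # rowwise-scaled (PerRow) float8 paths gated below.
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
```

A single named definition removes copies that had already drifted apart across test files: `is_cuda_9_0` in test/float8/test_base.py and `is_H100` in test/float8/test_compile.py guarded the same capability under different names.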
3 changes: 2 additions & 1 deletion README.md
@@ -178,7 +178,7 @@ We're also fortunate to be integrated into some of the leading open-source libraries
3. Mobius HQQ backend leveraged our int4 kernels to get [195 tok/s on a 4090](https://github.com/mobiusml/hqq#faster-inference)
4. [TorchTune](https://github.com/pytorch/torchtune) for our QLoRA and QAT recipes
5. [torchchat](https://github.com/pytorch/torchchat) for post training quantization
-6. [SGLang](https://github.com/sgl-project/sglang/pull/1341) for LLM inference quantization
+6. SGLang for LLM serving: [usage](https://github.com/sgl-project/sglang/blob/4f2ee48ed1c66ee0e189daa4120581de324ee814/docs/backend/backend.md?plain=1#L83) and the major [PR](https://github.com/sgl-project/sglang/pull/1341).

## Videos
* [Keynote talk at GPU MODE IRL](https://youtu.be/FH5wiwOyPX4?si=VZK22hHz25GRzBG1&t=1009)
@@ -205,4 +205,5 @@ If you find the torchao library useful, please cite it in your work as below.
license = {BSD-3-Clause},
month = oct,
year = {2024}
}
+```
10 changes: 6 additions & 4 deletions test/dtypes/test_affine_quantized.py
@@ -17,9 +17,11 @@
int8_weight_only,
)
from torchao.quantization.quant_primitives import MappingType
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6
-
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_6,
+    is_sm_at_least_89,
+)


def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"):
@@ -42,7 +44,7 @@ def get_quantization_functions(do_sparse: bool, do_int4: bool, device: str = "cuda"
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
)

-    if is_cuda_8_9:
+    if is_sm_at_least_89():
base_functions.append(float8_weight_only())

return base_functions
37 changes: 26 additions & 11 deletions test/dtypes/test_affine_quantized_float.py
@@ -37,13 +37,14 @@
MappingType,
choose_qparams_affine,
)
+from torchao.utils import (
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

random.seed(0)
torch.manual_seed(0)

-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


class ToyLinearModel(torch.nn.Module):
def __init__(self, in_features, out_features):
@@ -59,12 +60,14 @@ def forward(self, x):

class TestAffineQuantizedFloat8Compile(InductorTestCase):
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"granularity", [PerTensor(), PerRow()] if is_H100 else [PerTensor()]
"granularity", [PerTensor(), PerRow()] if is_sm_at_least_90() else [PerTensor()]
)
# Inputs are (M,..), K, N
@common_utils.parametrize(
@@ -134,20 +137,26 @@ def test_fp8_linear_variants(
compute_error(output_original, output_quantized) > 20
), f"Quantization error is too high got a SQNR of {error}"

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
def test_invalid_granularity(self):
with pytest.raises(ValueError, match="Invalid granularity specification"):
float8_dynamic_activation_float8_weight(granularity="invalid")

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
def test_mismatched_granularity(self):
with pytest.raises(
ValueError,
match="Different granularities for activation and weight are not supported",
):
float8_dynamic_activation_float8_weight(granularity=(PerTensor(), PerRow()))

-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
def test_unsupported_granularity(self):
class UnsupportedGranularity:
pass
@@ -158,7 +167,9 @@ class UnsupportedGranularity:
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
def test_per_row_with_float32(self):
with pytest.raises(
AssertionError,
@@ -170,7 +181,9 @@ def test_per_row_with_float32(self):
)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
@common_utils.parametrize("mode", ["dynamic", "weight-only", "static"])
def test_serialization(self, mode: str):
# Create and quantize the model
@@ -240,7 +253,9 @@ def test_serialization(self, mode: str):
), f"Scales do not match for {layer_name}"

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
-    @unittest.skipIf(not is_cuda_8_9, "Requires GPU with compute capability >= 8.9")
+    @unittest.skipIf(
+        not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
+    )
def test_fp8_weight_dimension_warning(self):
# Create model with incompatible dimensions (not multiples of 16)
model = ToyLinearModel(10, 25).cuda() # 10x25 and 25x10 weights
49 changes: 35 additions & 14 deletions test/float8/test_base.py
@@ -14,7 +14,11 @@
import torch
import torch.nn as nn

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -60,10 +64,6 @@
torch.manual_seed(0)


-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-is_cuda_9_0 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
assert torch.all(a._scale == b._scale).item(), "scales are not identical"
assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -219,7 +219,7 @@ def test_axiswise_reshape(self):
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    @unittest.skipIf(not is_cuda_9_0, "Requires CUDA capability >= 9.0")
+    @unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")
@@ -333,7 +333,9 @@ def _test_linear_impl(
# verify initialization flags got updated
assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
@pytest.mark.parametrize(
"scaling_type_input",
@@ -415,7 +417,9 @@ def test_linear_from_recipe(
config,
)

@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@@ -462,7 +466,9 @@ def test_autocast_outputs(
@pytest.mark.parametrize(
"linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
)
@pytest.mark.parametrize("emulate", [True, False] if is_cuda_8_9 else [True])
@pytest.mark.parametrize(
"emulate", [True, False] if is_sm_at_least_89() else [True]
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_type_cast(self, linear_dtype: torch.dtype, emulate: bool):
m = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
@@ -523,18 +529,33 @@ def test_repr(self):
s = m.__repr__()
assert "i:dyn_ten,w:del_ten,go:dyn_ten" in s

-    @unittest.skipIf(not is_cuda_8_9, "CUDA 8.9 not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA 8.9 not available")
def test_inference_mode(self):
x = torch.randn(32, 32, device="cuda")
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)
with torch.inference_mode(mode=True):
m(x)

+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA arch 8.9 not available")
+    def test_quantize(self):
+        x = torch.randn(32, 32, device="cuda")
+        m = nn.Sequential(nn.Linear(32, 32)).cuda()
+        m = convert_to_float8_training(m)
+        assert isinstance(m[0], Float8Linear), "Module is not a Float8Linear"
+        from torchao.quantization.quant_api import float8_weight_only, quantize_
+
+        quantize_(m, float8_weight_only())
+        assert (
+            m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn
+        ), "Post quantization dtype should be torch.float8_e4m3fn"
+        with torch.no_grad():
+            m(x)


class TestScaledMM:
@unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
"CUDA not available",
)
@pytest.mark.parametrize(
@@ -576,10 +597,10 @@ def test_scaled_mm_vs_emulated(self, base_dtype, use_fast_accum):
if base_dtype in {torch.bfloat16, torch.float16}:
atol, rtol = 7e-2, 7e-2
else:
-            atol, rtol = 2e-3, 2e-3
+            atol, rtol = 3e-3, 3e-3
torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

-    @unittest.skipIf(not is_cuda_8_9, "CUDA not available")
+    @unittest.skipIf(not is_sm_at_least_89(), "CUDA not available")
def test_different_configs_error(self):
x_fp32 = torch.randn(16, 16, device="cuda")
x_scale = torch.tensor(1.0, device="cuda")
@@ -615,7 +636,7 @@ def test_different_configs_error(self):
a @ b

@unittest.skipIf(
-        not is_cuda_8_9,
+        not is_sm_at_least_89(),
"CUDA not available",
)
@pytest.mark.parametrize(
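Beyond the renames, test_base.py gains one new test, `test_quantize`, which checks that a module converted for float8 training can then be quantized post-training with float8 weight-only quantization. Pulled out of the test harness, the flow looks roughly like the sketch below; the import paths are assumptions based on the APIs the test uses (the test itself imports `quantize_` and `float8_weight_only` inline), and the shapes are illustrative:

```python
import torch
import torch.nn as nn

# Assumed import locations for the helpers exercised by test_quantize.
from torchao.float8 import convert_to_float8_training
from torchao.quantization.quant_api import float8_weight_only, quantize_

# Swap nn.Linear modules for Float8Linear to enable float8 training.
m = nn.Sequential(nn.Linear(32, 32)).cuda()
m = convert_to_float8_training(m)

# After training, quantize the weights to float8 for weight-only inference.
quantize_(m, float8_weight_only())
assert m[0].weight.tensor_impl.float8_data.dtype == torch.float8_e4m3fn

with torch.no_grad():
    y = m(torch.randn(32, 32, device="cuda"))
```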
32 changes: 17 additions & 15 deletions test/float8/test_compile.py
@@ -11,7 +11,11 @@

import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_5,
+    is_sm_at_least_89,
+    is_sm_at_least_90,
+)

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -46,10 +50,6 @@
from torchao.float8.float8_utils import e4m3_dtype
from torchao.testing.float8.test_utils import get_test_float8_linear_config

-# TODO(future PR): standardize IS_H100 with the rest of the codebase
-is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)


def _test_compile_base(
backend: str,
@@ -99,7 +99,7 @@ def _test_compile_base(
"scaling_type_grad_output",
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_eager_only(
@@ -126,7 +126,7 @@


@pytest.mark.parametrize("fullgraph", [True])
@pytest.mark.parametrize("emulate", [False, True] if is_cuda_8_9 else [True])
@pytest.mark.parametrize("emulate", [False, True] if is_sm_at_least_89() else [True])
@pytest.mark.parametrize(
"scaling_type_input", [ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC]
)
@@ -177,7 +177,7 @@ def test_aot_eager(
[ScalingType.DELAYED, ScalingType.DYNAMIC, ScalingType.STATIC],
)
@unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
@@ -215,7 +215,9 @@ def test_inductor_from_config_params(
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
],
)
-@unittest.skipIf(not is_H100, "CUDA with capability 9.0 or greater not available")
+@unittest.skipIf(
+    not is_sm_at_least_90(), "CUDA with capability 9.0 or greater not available"
+)
def test_inductor_from_recipe(recipe_name):
torch._dynamo.reset()
config = recipe_name_to_linear_config(recipe_name)
@@ -253,7 +255,7 @@ def forward(self, x):

# TODO(future): figure out why the test below fails on CUDA capability 8.9
@unittest.skipIf(
-        not torch.cuda.is_available() or not is_H100,
+        not torch.cuda.is_available() or not is_sm_at_least_90(),
"CUDA with capability 9.0 or greater not available",
)
def test_float8_with_graph_break_in_the_middle(self):
Expand All @@ -269,7 +271,7 @@ def test_float8_with_graph_break_in_the_middle(self):
torch.testing.assert_close(y_eager, y_compiled)

@unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_float8_graph_input(self):
@@ -293,7 +295,7 @@ def to_float(x):
torch.testing.assert_close(y2_eager, y2_compiled)

@unittest.skipIf(
-        not torch.cuda.is_available() or not is_cuda_8_9,
+        not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_float8_graph_output(self):
@@ -323,7 +325,7 @@ def test_float8_graph_output(self):


@unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_sync_amax_func():
@@ -364,7 +366,7 @@ def __exit__(self, *args):


@unittest.skipIf(
-    not torch.cuda.is_available() or not is_cuda_8_9,
+    not torch.cuda.is_available() or not is_sm_at_least_89(),
"CUDA with float8 support not available",
)
def test_sync_amax_func_cuda_graph_success():
@@ -396,7 +398,7 @@ def test_sync_amax_func_cuda_graph_success():


@unittest.skipIf(
-    not is_cuda_8_9,
+    not is_sm_at_least_89(),
"CUDA not available",
)
@pytest.mark.parametrize(
5 changes: 2 additions & 3 deletions test/float8/test_fsdp2/test_fsdp2.py
@@ -6,7 +6,7 @@

import pytest

-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_89

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
@@ -40,8 +40,7 @@
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.fsdp2_utils import check_parity_bf16_mp, check_parity_no_mp

-is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)
-if not is_cuda_8_9:
+if not is_sm_at_least_89():
pytest.skip("Unsupported CUDA device capability version", allow_module_level=True)


