Float8 autoquant weight only #866

Merged · 6 commits · Sep 24, 2024
Changes from 3 commits
2 changes: 1 addition & 1 deletion torchao/dtypes/affine_quantized_tensor.py
@@ -335,8 +335,8 @@ def from_hp_to_floatx(
         input_float: torch.Tensor,
         block_size: Tuple[int, ...],
         target_dtype: torch.dtype,
-        scale_dtype: Optional[torch.dtype],
         layout_type: LayoutType,
+        scale_dtype: Optional[torch.dtype] = None,
     ):

         if target_dtype in FP8_TYPES:
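The change moves scale_dtype after layout_type and gives it a default, so callers can omit it. A minimal sketch of a call under the new signature (the weight tensor and shapes are illustrative, not taken from the PR):

```python
import torch
from torchao.dtypes import AffineQuantizedTensor, Float8LayoutType

weight = torch.randn(256, 128, dtype=torch.bfloat16)
fp8_weight = AffineQuantizedTensor.from_hp_to_floatx(
    weight,
    block_size=(1, weight.shape[1]),   # per-row quantization granularity
    target_dtype=torch.float8_e4m3fn,  # must be one of FP8_TYPES
    layout_type=Float8LayoutType(),
    # scale_dtype now defaults to None and can be left out
)
```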
5 changes: 4 additions & 1 deletion torchao/kernel/intmm.py
@@ -69,7 +69,10 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         input = (
             input.contiguous()
         )  # (it seems the transpose makes cublas check the above j constraint on i)
-        return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        try:
+            return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
+        except:
+            return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
 else:
     def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
         """

Reviewer (Contributor):
Maybe adding a comment here would be helpful, explaining how these two branches are handled?

@jainapurva (Contributor, Author), Sep 24, 2024:
The except branch is executed when the dtype is float8 on H100, since there is no addmm_cuda implementation for float8 dtypes. Added as a comment.
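A small sketch of how the new fallback behaves when safe_int_mm is called directly (device, shapes, and dtypes are illustrative; it assumes a CUDA build where the fused int_mm path may be unavailable for some dtypes):

```python
import torch
from torchao.kernel.intmm import safe_int_mm

a = torch.randint(-128, 127, (17, 64), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (64, 32), dtype=torch.int8, device="cuda")

# Normally this dispatches to the fused int_mm kernel and returns an int32 tensor.
# If that op raises (e.g. no mm implementation for the dtype, as the author
# describes for float8 on H100), the new except branch recomputes the product
# in float32 and casts the result back to int32.
out = safe_int_mm(a, b)
print(out.dtype)  # torch.int32
```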
19 changes: 18 additions & 1 deletion torchao/quantization/autoquant.py
@@ -9,7 +9,7 @@
     Int8WeightOnlyQuantizedLinearWeight,
     QuantizedLinearWeightBase,
 )
-from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType
+from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from .quant_primitives import (
@@ -477,6 +477,22 @@ def _quantized_linear_op(act_mat, w_qtensor, bias):
     def from_float(cls, weight):
         return weight

+class AQFloat8WeightOnlyQuantizedLinearWeight(AffineQuantizedTensor, AQMixin):
+    """
+    AutoQuantizable version of Float8WeightOnlyQuantizedLinearWeight for target_dtype=torch.float8_e4m3fn
+    """
+    target_dtype: torch.dtype = torch.float8_e4m3fn
+
+    @staticmethod
+    def _quantized_linear_op(act_mat, w_qtensor, bias):
+        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)
+
+    @classmethod
+    def from_float(cls, weight):
+        block_size = (1, weight.shape[1])
+        return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())
+
+
 # here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
 DEFAULT_AUTOQUANT_CLASS_LIST = [
     AQFloatLinearWeight,
@@ -485,6 +501,7 @@ def from_float(cls, weight):
     # AQInt8WeightOnlyQuantizedLinearWeight3,
     # TODO this gets picked in places where it makes perf worse, why?
     AQInt8DynamicallyQuantizedLinearWeight,
+    AQFloat8WeightOnlyQuantizedLinearWeight,
 ]

 DEFAULT_INT4_AUTOQUANT_CLASS_LIST = [
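With the new class appended to DEFAULT_AUTOQUANT_CLASS_LIST, autoquant will benchmark the float8 weight-only option alongside the existing float and int8 options. A hedged usage sketch (the toy model, shapes, and dtypes are placeholders; it assumes a CUDA device with float8 support):

```python
import torch
import torchao

model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)
example_input = torch.randn(16, 128, dtype=torch.bfloat16, device="cuda")

# autoquant wraps each linear weight, times the candidate classes on real inputs,
# and keeps the fastest one per layer; AQFloat8WeightOnlyQuantizedLinearWeight is
# now one of those candidates.
model = torchao.autoquant(torch.compile(model))
model(example_input)
```

For a single weight, the new class quantizes per output row: from_float produces a float8_e4m3fn AffineQuantizedTensor with block_size=(1, weight.shape[1]), and its _quantized_linear_op simply dequantizes the weight and calls torch.nn.functional.linear.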