Float8 dynamic autoquant #946

Merged: 14 commits, Oct 2, 2024
9 changes: 9 additions & 0 deletions test/integration/test_integration.py
@@ -73,6 +73,7 @@
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8DynamicallyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
@@ -753,6 +754,14 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
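For context, a hedged sketch of how the new candidate is expected to be exercised end to end: the autoquant entry point with the non-default class list that this PR extends. The torchao.autoquant call, its qtensor_class_list argument, and the toy model/input are assumptions for illustration, not part of this diff.

import torch
import torchao
from torchao.quantization.autoquant import OTHER_AUTOQUANT_CLASS_LIST

# Toy model and input; the float8 paths need recent hardware (the tests above gate on H100).
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).to(device="cuda", dtype=torch.bfloat16)
example_input = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)

# Passing the non-default class list makes autoquant benchmark the float8
# weight-only and (with this PR) float8 dynamic candidates per linear layer
# and keep whichever is fastest.
model = torchao.autoquant(torch.compile(model), qtensor_class_list=OTHER_AUTOQUANT_CLASS_LIST)
model(example_input)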
9 changes: 6 additions & 3 deletions torchao/kernel/intmm.py
@@ -37,8 +37,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
# torch.compile path
if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

try:
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
except Exception:
# fallback path, would run on H100 for float8 dtypes
# Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
# error checking for cublas path
assert (
mat2.device == input.device
@@ -53,7 +57,6 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
and j_is_nonzero_multiple_of_8
and k_is_nonzero_multiple_of_8
)

if device_cpu or bad_dimensions_for_cublas:
# fallback path
return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
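The try/except added above exists because autoquant now routes float8 tensors through safe_int_mm (see the quantized_matmul benchmark in autoquant.py below). A minimal standalone sketch of that pattern, assuming out_dtype is the higher-order op from torch._higher_order_ops.out_dtype; illustrative only, not the library code:

import torch
from torch._higher_order_ops.out_dtype import out_dtype

def mm_int32_with_fallback(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    try:
        # int32-accumulating mm; on H100 this raises for float8_e4m3fn inputs
        # ("addmm_cuda" not implemented for 'Float8_e4m3fn').
        return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b)
    except Exception:
        # Fall back to a float32 matmul and cast the result back to int32.
        return torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(torch.int32)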
88 changes: 87 additions & 1 deletion torchao/quantization/autoquant.py
@@ -221,7 +221,6 @@ def do_autoquant_bench(op, *args, **kwargs):
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
op(*args, **kwargs)
@@ -492,6 +491,92 @@ def from_float(cls, weight):
block_size = (1, weight.shape[1])
return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())

class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
Contributor: I think this looks good. Let's name this PerRow scaling, and let's only import PerRow above.

"""
AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):

# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
# weight settings
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.float8_e4m3fn

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_target_dtype = torch.float8_e4m3fn
layout_type = Float8LayoutType()
input_quant_func = lambda x: to_affine_quantized_floatx(
input_float=x,
block_size=get_per_token_block_size(x),
target_dtype=input_target_dtype,
layout_type=layout_type
)
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
layout_type=layout_type
)
weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
return weight

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
Contributor: Did you test this in practice? We may need different constants for int8 and float8 dynamic; how does it perform in benchmarks? If you haven't really tested this on 2-3 models, it may be better to just remove it and use the default method, which will be very conservative under the interpolation mode and will still work reasonably under the relu mode.

Contributor Author: I've tested it on Llama; the numbers aren't great, but I can push it to the next PR with more benchmarks.

"""
Tests and benchmarks the autoquantization process with special handling for interpolate mode.

Args:
act_mat (torch.Tensor): The activation matrix.
weight (torch.Tensor): The weight tensor.
bias (torch.Tensor or None): The bias tensor.
best_time (float): The best time to beat for the quantization process.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
(e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].

Returns:
float: The benchmarked time for the autoquantization process.
"""
if not _is_interpolate_mode(mode):
return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

# SAM best is between .8 and 1, SDXL also performs best in this range
INTERPOLATION_CONSTANT = mode[1]
w_qtensor = cls.from_float(weight)
x_vals_float8, x_scales = quantize_activation_per_token_absmax(
act_mat.reshape(-1, act_mat.shape[-1]), dtype=torch.float8_e4m3fn
)
quantized_matmul = (
lambda x_vals_float8, x_scales, w_vals_float8:
safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1,1), w_vals_float8)
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
if res_matmul>=best_time:
return res_matmul

# calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
res = super()._autoquant_test(act_mat, weight, bias, to_beat)
max_float_const_win = (best_time-res_matmul)/(res-res_matmul)
res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul
print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}")
return res_f


# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
@@ -511,6 +596,7 @@ def from_float(cls, weight):

OTHER_AUTOQUANT_CLASS_LIST = [
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8DynamicallyQuantizedLinearWeight,
]


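A small worked example of the interpolation arithmetic in _autoquant_test above, using hypothetical timings; the 0.85 constant is illustrative (in interpolate mode it comes from mode[1], and the code notes SAM/SDXL do best between 0.8 and 1):

# Hypothetical numbers plugged into the formulas from _autoquant_test.
INTERPOLATION_CONSTANT = 0.85
best_time = 1.00    # ms, best candidate measured so far
res_matmul = 0.60   # ms, quantized-matmul-only benchmark
res = 1.20          # ms, full quantized linear op benchmark

# Time the full op must beat for this candidate to win overall:
to_beat = best_time + INTERPOLATION_CONSTANT / (1 - INTERPOLATION_CONSTANT) * (best_time - res_matmul)
# Reported interpolated time and the break-even constant:
res_f = INTERPOLATION_CONSTANT * res + (1 - INTERPOLATION_CONSTANT) * res_matmul
max_float_const_win = (best_time - res_matmul) / (res - res_matmul)
print(to_beat, res_f, max_float_const_win)  # ~3.27 ms, ~1.11 ms, ~0.67

With these numbers the interpolated time (about 1.11 ms) does not beat best_time (1.00 ms); the candidate would only have won if the interpolation constant were at or below the break-even value of roughly 0.67.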
3 changes: 1 addition & 2 deletions torchao/quantization/utils.py
@@ -139,13 +139,12 @@ def _get_per_token_block_size(x: torch.Tensor) -> List[int]:
# taken from
# https://github.com/mit-han-lab/smoothquant/blob/2f87951dacfb9238d8d657f52ae83a82a3c9ba0c/smoothquant/fake_quant.py#L26
# and slightly modified
def quantize_activation_per_token_absmax(t):
def quantize_activation_per_token_absmax(t, dtype=torch.int8):
Contributor (@HDCharles, Sep 28, 2024): Does this actually work in practice with non-int8 dtypes? We're still using the same quant min/max of ±128, which seems inadvisable. I also don't think we should extend this function; it should probably just call into whatever quant function is normally used for that dtype. This is a specific instance of a function where the mapping types and quant min/max are hard-coded to specific values, so it shouldn't be extended.

Contributor Author: We don't need this anymore; I'm reverting the changes to this method, as there's another implementation in float8 that I'll be using.

# if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
mapping_type = MappingType.SYMMETRIC
block_size = list(t.shape)
for i in range(len(block_size) - 1):
block_size[i] = 1
dtype = torch.int8
eps = 1e-5
# Note: the original smoothquant does not clamp to qmin/qmax here,
# but some of the tests with bfloat16 ended up with a flipped sign
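As a companion to the review exchange above, a minimal sketch of per-token absmax scaling that derives its range from the float8 format itself rather than reusing the int8 ±128 bounds. This illustrates the reviewer's point; it is not the float8 implementation the author says they will use instead.

import torch

def per_token_absmax_float8(t: torch.Tensor, float8_dtype=torch.float8_e4m3fn):
    # Scale each token (row along the last dim) by its absmax so it spans the
    # float8 range; torch.finfo(torch.float8_e4m3fn).max is 448.0, not 127.
    f8_max = torch.finfo(float8_dtype).max
    scales = t.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / f8_max
    t_f8 = (t / scales).clamp(-f8_max, f8_max).to(float8_dtype)
    return t_f8, scales  # scales have shape [..., 1], matching the per-token block size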