Skip to content

Commit

Permalink
float8 dynamic autoquant
Browse files Browse the repository at this point in the history
  • Loading branch information
jainapurva committed Sep 26, 2024
1 parent fbe97a0 commit 58fe60d
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 4 deletions.
9 changes: 9 additions & 0 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
AQInt8WeightOnlyQuantizedLinearWeight3,
AutoQuantizableLinearWeight,
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8DynamicallyQuantizedLinearWeight,
)
from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
import os
Expand Down Expand Up @@ -753,6 +754,14 @@ def test_aq_float8_weight_only_quant_subclass(self, device, dtype):
AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch")
@unittest.skipIf(not is_H100, "Need H100 to run")
def test_aq_float8_dynamic_quant_subclass(self, device, dtype):
self._test_lin_weight_subclass_impl(
AQFloat8DynamicallyQuantizedLinearWeight.from_float, device, 25, test_dtype=dtype
)

@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
# @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
Expand Down
9 changes: 6 additions & 3 deletions torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,12 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
"""
# torch.compile path
if dynamo_is_compiling() or "FakeTensor" in input.__repr__():
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)

try:
return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2)
except Exception:
# fallback path, would run on H100 for float8 dtypes
# Exception on H100 float8 dtype : "addmm_cuda" not implemented for 'Float8_e4m3fn'
return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32)
# error checking for cublas path
assert (
mat2.device == input.device
Expand All @@ -53,7 +57,6 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
and j_is_nonzero_multiple_of_8
and k_is_nonzero_multiple_of_8
)

if device_cpu or bad_dimensions_for_cublas:
# fallback path
return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to(
Expand Down
88 changes: 87 additions & 1 deletion torchao/quantization/autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def do_autoquant_bench(op, *args, **kwargs):
stream.synchronize()
torch.cuda.current_stream().wait_stream(stream)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, stream=stream):
op(*args, **kwargs)
Expand Down Expand Up @@ -492,6 +491,92 @@ def from_float(cls, weight):
block_size = (1, weight.shape[1])
return super(AQFloat8WeightOnlyQuantizedLinearWeight, cls).from_hp_to_floatx(weight, block_size, target_dtype=cls.target_dtype, layout_type=Float8LayoutType())

class AQFloat8DynamicallyQuantizedLinearWeight(AQMixin, LinearActivationQuantizedTensor):
"""
AutoQuantizable version of Float8DynamicallyQuantizedLinearWeight
"""
@classmethod
def from_float(cls, weight):

# avoid circular dep
from torchao.dtypes import to_affine_quantized_floatx
# weight settings
def get_weight_block_size(x):
return (1, x.shape[1])
target_dtype = torch.float8_e4m3fn

# input settings
def get_per_token_block_size(x):
block_size = list(x.shape)
for i in range(len(block_size)-1):
block_size[i] = 1
return block_size

input_target_dtype = torch.float8_e4m3fn
layout_type = Float8LayoutType()
input_quant_func = lambda x: to_affine_quantized_floatx(
input_float=x,
block_size=get_per_token_block_size(x),
target_dtype=input_target_dtype,
layout_type=layout_type
)
block_size = get_weight_block_size(weight)
weight = to_affine_quantized_floatx(
input_float=weight,
block_size=block_size,
target_dtype=target_dtype,
layout_type=layout_type
)
weight = super(AQFloat8DynamicallyQuantizedLinearWeight, cls).from_float(weight, input_quant_func)
return weight

@classmethod
def _autoquant_test(cls, act_mat, weight, bias, best_time, mode=["relu", None]):
"""
Tests and benchmarks the autoquantization process with special handling for interpolate mode.
Args:
act_mat (torch.Tensor): The activation matrix.
weight (torch.Tensor): The weight tensor.
bias (torch.Tensor or None): The bias tensor.
best_time (float): The best time to beat for the quantization process.
mode (list, optional): A list containing mode settings for quantization. The first element is the mode type
(e.g., "relu"), and the second element is the mode value (e.g., None). Defaults to ["relu", None].
Returns:
float: The benchmarked time for the autoquantization process.
"""
if not _is_interpolate_mode(mode):
return super()._autoquant_test(act_mat, weight, bias, best_time, mode)

# SAM best is between .8 and 1, SDXL also performs best in this range
INTERPOLATION_CONSTANT = mode[1]
w_qtensor = cls.from_float(weight)
x_vals_float8, x_scales = quantize_activation_per_token_absmax(
act_mat.reshape(-1, act_mat.shape[-1])
)
quantized_matmul = (
lambda x_vals_float8, x_scales, w_vals_float8:
safe_int_mm(x_vals_float8, w_vals_float8) * x_scales
)
q_c_matmul=torch.compile(quantized_matmul, mode="max-autotune-no-cudagraphs")
with torch.no_grad():
w_vals_float8 = w_qtensor.original_weight_tensor.layout_tensor.float8_data.contiguous().t()
res_matmul = do_autoquant_bench(q_c_matmul, x_vals_float8, x_scales.reshape(-1,1), w_vals_float8)
print(f">>time: {res_matmul:0.3f}ms for {cls} matmul, to_beat: {best_time:0.3f}ms")

# if the (much faster) matmul kernel is already beat, don't bother benchmarking full op
if res_matmul>=best_time:
return res_matmul

# calculate what time full op needs to beat for dynamic quant to be best given INTERPOLATION_CONSTANT
to_beat = best_time + INTERPOLATION_CONSTANT/(1-INTERPOLATION_CONSTANT)*(best_time-res_matmul)
res = super()._autoquant_test(act_mat, weight, bias, to_beat)
max_float_const_win = (best_time-res_matmul)/(res-res_matmul)
res_f = INTERPOLATION_CONSTANT*res+(1-INTERPOLATION_CONSTANT)*res_matmul
print(f">>time: {res_f:0.3f}ms for {cls} interpolated, breakeven constant: {max_float_const_win:0.2f}")
return res_f


# here we don't include int4 quantization in since int8 tends to be a better apples to apples comparison
DEFAULT_AUTOQUANT_CLASS_LIST = [
Expand All @@ -511,6 +596,7 @@ def from_float(cls, weight):

OTHER_AUTOQUANT_CLASS_LIST = [
AQFloat8WeightOnlyQuantizedLinearWeight,
AQFloat8DynamicallyQuantizedLinearWeight,
]


Expand Down

0 comments on commit 58fe60d

Please sign in to comment.