Skip to content

Commit

Permalink
Use custom nd to scalar absolute max reduce kernel in max calibration…
Browse files Browse the repository at this point in the history
… runtime to improve perf (mlc-ai#229)

Use a custom ND-to-scalar absolute-max reduce kernel
in the max calibration runtime to improve performance.

Co-authored-by: Chris Sullivan <csullivan@octo.ai>
  • Loading branch information
csullivan and csullivan committed Mar 21, 2024
1 parent 54585a9 commit 4cf21a4
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions python/mlc_chat/quantization/fp8_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,14 @@ def quantize(
), "'max_int_value' must be provided when using fp8-max quantization"

def fused_compute_scale_and_quantize(
tensor: te.Tensor, axis: int, out_shape: Optional[List[tir.PrimExpr]] = None
tensor: te.Tensor,
max_abs: te.Tensor,
axis: int,
out_shape: Optional[List[tir.PrimExpr]] = None,
):
max_int = tir.const(kwargs["max_int_value"], x.dtype)
min_scaling_factor = tir.const(1.0 / (kwargs["max_int_value"] * 512.0), x.dtype)
r_idx = [te.reduce_axis((0, d)) for d in tensor.shape]
max_abs = te.compute(
shape=(1,),
fcompute=lambda *idx: te.max(
te.abs(tensor(*r_idx)),
axis=r_idx,
),
name="max_abs_value",
)

scale = te.compute(
(1,),
lambda *idx: te.max(
Expand All @@ -93,14 +88,21 @@ def fused_compute_scale_and_quantize(

return scaled_act, scale

max_abs = nn.op.extern(
"tvm.contrib.cuda.reduce_max_abs",
[x],
nn.Tensor.placeholder((1,), x.dtype),
)

quant, scale = nn.op.tensor_expr_op( # pylint: disable=invalid-name
lambda tensor: fused_compute_scale_and_quantize( # pylint: disable=protected-access
lambda tensor, max_tensor: fused_compute_scale_and_quantize( # pylint: disable=protected-access
tensor,
max_tensor,
axis=None,
out_shape=x.shape,
),
name_hint="quantize_act",
args=[x],
args=[x, max_abs],
)
return quant, scale
else:
Expand Down

0 comments on commit 4cf21a4

Please sign in to comment.