aiter/ops/quant.py (9 changes: 3 additions & 6 deletions)

@@ -252,13 +252,10 @@ def per_tensor_quant_hip(x, scale=None, quant_dtype=dtypes.i8):
 def per_token_quant_triton(x, scale=None, quant_dtype=dtypes.i8):
     shape = x.shape
     device = x.device
-    dtypeMax = get_dtype_max(quant_dtype)
     y = torch.empty(shape, dtype=quant_dtype, device=device)
     if scale is None:
         scale = torch.empty((*shape[:-1], 1), dtype=dtypes.fp32, device=device)
-        triton.quant.dynamic_per_token_fp8_quant(
-            y, x, scale, quant_dtype=quant_dtype, dtypeMax=dtypeMax
-        )
+        triton.quant.dynamic_per_token_quant_fp8_i8(y, x, scale)
     else:
         raise ValueError("unsupported: static per token quant")
@@ -277,9 +274,9 @@ def per_tensor_quant_triton(x, scale=None, quant_dtype=dtypes.i8):
     x = x.view(-1, x.shape[-1])
     if scale is None:
         scale = torch.zeros(1, dtype=dtypes.fp32, device=x.device)
-        triton.quant.dynamic_per_tensor_fp8_quant(y, x, scale)
+        triton.quant.dynamic_per_tensor_quant_fp8_i8(y, x, scale)
     else:
-        triton.quant.static_per_tensor_fp8_quant(y, x, scale)
+        triton.quant.static_per_tensor_quant_fp8_i8(y, x, scale)
     return y, scale


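Net effect on the wrappers above: the host-side dtypeMax computation is gone, because the renamed Triton kernels now derive the clamp range from the output tensor's dtype. A minimal usage sketch, assuming `dtypes` is importable from `aiter` as used in the file above and that `per_token_quant_triton` returns `(y, scale)` like the per-tensor variant; shapes are illustrative:

import torch
from aiter import dtypes  # assumed import path for the dtypes module used above
from aiter.ops.quant import per_token_quant_triton, per_tensor_quant_triton

x = torch.randn(64, 128, dtype=torch.float16, device="cuda")

# Dynamic per-token quantization: the kernel picks one fp32 scale per row.
y_tok, scale_tok = per_token_quant_triton(x, quant_dtype=dtypes.i8)
assert scale_tok.shape == (64, 1)  # scale is allocated as (*shape[:-1], 1)

# Dynamic per-tensor quantization: a single fp32 scale for the whole tensor.
y_tsr, scale_tsr = per_tensor_quant_triton(x, quant_dtype=dtypes.i8)
assert scale_tsr.numel() == 1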
aiter/ops/triton/moe_op.py (4 changes: 2 additions & 2 deletions)

@@ -5,7 +5,7 @@
 import triton
 import triton.language as tl
 from typing import Any, Dict, Optional, List
-from aiter.ops.triton.quant import dynamic_per_tensor_fp8_quant
+from aiter.ops.triton.quant import dynamic_per_tensor_quant_fp8_i8
 from aiter.ops.triton.utils.pid_preprocessing import pid_grid, remap_xcd
 from aiter.ops.triton.utils.moe_common import _write_zeros_to_output

@@ -15,7 +15,7 @@

 _PADDING_SIZE = 0

-_MOE_A_QUANT_FUNC = dynamic_per_tensor_fp8_quant
+_MOE_A_QUANT_FUNC = dynamic_per_tensor_quant_fp8_i8

 _USE_MOE_PERSISTENT_KERNEL = False

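This and the three MoE variants below route activation ("A" operand) quantization through the module-level hook _MOE_A_QUANT_FUNC, which is why the rename only touches the import and the hook assignment in each file. A hypothetical sketch of overriding the hook, not shown in this PR; it relies on the per-token variant now sharing the `(qx, x_in, scale_out)` signature, and is only valid if the MoE kernel accepts per-row scales:

import aiter.ops.triton.moe_op as moe_op
from aiter.ops.triton.quant import dynamic_per_token_quant_fp8_i8

# Default set by the module: per-tensor dynamic quantization of the A operand.
# Hypothetical override: quantize activations per token instead.
moe_op._MOE_A_QUANT_FUNC = dynamic_per_token_quant_fp8_i8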
aiter/ops/triton/moe_op_e2e.py (4 changes: 2 additions & 2 deletions)

@@ -6,15 +6,15 @@
 import triton.language as tl
 from typing import Any, Dict, Optional

-from aiter.ops.triton.quant import dynamic_per_tensor_fp8_quant
+from aiter.ops.triton.quant import dynamic_per_tensor_quant_fp8_i8
 from aiter.ops.triton.utils.types import torch_to_triton_dtype

 # Source:
 # MoE Kernel adapted from VLLM

 _PADDING_SIZE = 0

-_MOE_A_QUANT_FUNC = dynamic_per_tensor_fp8_quant
+_MOE_A_QUANT_FUNC = dynamic_per_tensor_quant_fp8_i8

 _USE_MOE_PERSISTENT_KERNEL = False

aiter/ops/triton/moe_op_gelu.py (4 changes: 2 additions & 2 deletions)

@@ -6,7 +6,7 @@
 import triton.language as tl
 from typing import Any, Dict, Optional, List

-from aiter.ops.triton.quant import dynamic_per_tensor_fp8_quant
+from aiter.ops.triton.quant import dynamic_per_tensor_quant_fp8_i8
 from aiter.ops.triton.activation import _gelu_tanh
 from aiter.ops.triton.utils.pid_preprocessing import pid_grid, remap_xcd
 from aiter.ops.triton.utils.moe_common import _write_zeros_to_output
@@ -16,7 +16,7 @@

 _PADDING_SIZE = 0

-_MOE_A_QUANT_FUNC = dynamic_per_tensor_fp8_quant
+_MOE_A_QUANT_FUNC = dynamic_per_tensor_quant_fp8_i8

 _USE_MOE_PERSISTENT_KERNEL = False

aiter/ops/triton/moe_op_silu_fused.py (4 changes: 2 additions & 2 deletions)

@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional, List

 from aiter.ops.triton.activation import _silu_exp2
-from aiter.ops.triton.quant import dynamic_per_tensor_fp8_quant
+from aiter.ops.triton.quant import dynamic_per_tensor_quant_fp8_i8
 from aiter.ops.triton.utils.pid_preprocessing import pid_grid, remap_xcd
 from aiter.ops.triton.utils.moe_common import _write_zeros_to_output

@@ -16,7 +16,7 @@

 _PADDING_SIZE = 0

-_MOE_A_QUANT_FUNC = dynamic_per_tensor_fp8_quant
+_MOE_A_QUANT_FUNC = dynamic_per_tensor_quant_fp8_i8

 _USE_MOE_PERSISTENT_KERNEL = False
_USE_MOE_PERSISTENT_KERNEL = False

aiter/ops/triton/quant.py (96 changes: 59 additions & 37 deletions)

@@ -1,24 +1,20 @@
 # SPDX-License-Identifier: MIT
-# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 import triton
 import triton.language as tl
 import torch


 @triton.jit
-def _static_per_tensor_fp8_quant_kernel(
+def _static_per_tensor_quant_fp8_i8_kernel(
     qx_ptr,
     x_in_ptr,
     scale_in_ptr,
     cols: int,
     x_in_stride_r: int,
     NUM_COL_POW2: tl.constexpr,
 ):
-    """
-    #TODO: Add Doc
-    """
-
     pid = tl.program_id(axis=0)
     tl.assume(pid > 0)
     tl.assume(x_in_stride_r > 0)
@@ -35,37 +31,41 @@ def _static_per_tensor_fp8_quant_kernel(
     tl.store(qx_ptr + offs, qx, mask=mask)


-def static_per_tensor_fp8_quant(
+def static_per_tensor_quant_fp8_i8(
     qx: torch.Tensor, x_in: torch.Tensor, scale_in: torch.Tensor
 ):
     """
-    #TODO: Add Doc
+    Quantizes the input tensor to int8 or fp8 using the provided scale.
+
+    Parameters:
+    - qx: Output tensor of the same shape as x_in. Must be fp8 or int8 dtype and allocated by the caller.
+    - x_in: Input tensor of shape (M, N).
+    - scale_in: Input scale tensor of shape (1,) and dtype fp32.
+
+    Returns:
+    - qx: Quantized output values.
     """
     assert scale_in.numel() == 1  # only a single scale value
     rows = x_in.shape[0]
     cols = x_in.shape[1]
     NUM_COL_POW2 = triton.next_power_of_2(cols)
-    grid = lambda meta: (rows,)
-    _static_per_tensor_fp8_quant_kernel[grid](
+    grid = lambda meta: (rows,)  # noqa: E731
+    _static_per_tensor_quant_fp8_i8_kernel[grid](
         qx, x_in, scale_in, cols, x_in.stride(0), NUM_COL_POW2=NUM_COL_POW2
     )

     return qx


 @triton.jit
-def _dynamic_per_tensor_fp8_quant_kernel(
+def _dynamic_per_tensor_quant_fp8_i8_kernel(
     x_in_ptr,
     scale_out_ptr,
     cols: int,
     x_in_stride_r: int,
     NUM_COL_POW2: tl.constexpr,
-    FP8_MAX: tl.constexpr,
+    DTYPE_MAX: tl.constexpr,
 ):
-    """
-    #TODO: Add Doc
-    """
-
     pid = tl.program_id(axis=0)
     tl.assume(pid > 0)
     tl.assume(x_in_stride_r > 0)
@@ -75,49 +75,59 @@ def _dynamic_per_tensor_fp8_quant_kernel(
     x = tl.load(x_in_ptr + offs, mask=mask, cache_modifier=".cg")

     m = tl.max(tl.abs(x))
-    tl.atomic_max(scale_out_ptr, m / FP8_MAX, sem="relaxed")
+    tl.atomic_max(scale_out_ptr, m / DTYPE_MAX, sem="relaxed")


-def dynamic_per_tensor_fp8_quant(
+def dynamic_per_tensor_quant_fp8_i8(
     qx: torch.Tensor, x_in: torch.Tensor, scale_out: torch.Tensor
 ):
     """
-    #TODO: Add Doc
+    Calculates a per-tensor scale, then uses it to quantize the input tensor to fp8 or int8.
+
+    Parameters:
+    - x_in: Input tensor of shape (M, N).
+    - qx: Output tensor of the same shape as x_in. Must be fp8 or int8 dtype and allocated by the caller.
+    - scale_out: Output scale tensor of shape (1,), dtype fp32, allocated by the caller.
+
+    Returns:
+    - qx: Quantized output values of shape (M, N) with dtype fp8 or int8.
+    - scale_out: Single scale value of shape (1,).
     """

     rows = x_in.shape[0]
     cols = x_in.shape[1]
     NUM_COL_POW2 = triton.next_power_of_2(cols)
-    grid = lambda meta: (rows,)
-    _dynamic_per_tensor_fp8_quant_kernel[grid](
+    grid = lambda meta: (rows,)  # noqa: E731
+    _dynamic_per_tensor_quant_fp8_i8_kernel[grid](
         x_in,
         scale_out,
         cols,
         x_in.stride(0),
         NUM_COL_POW2=NUM_COL_POW2,
-        FP8_MAX=torch.finfo(qx.dtype).max,
+        DTYPE_MAX=(
+            torch.finfo(qx.dtype).max
+            if torch.is_floating_point(qx)
+            else torch.iinfo(qx.dtype).max
+        ),
     )

-    _static_per_tensor_fp8_quant_kernel[grid](
+    _static_per_tensor_quant_fp8_i8_kernel[grid](
         qx, x_in, scale_out, cols, x_in.stride(0), NUM_COL_POW2=NUM_COL_POW2
     )

     return qx, scale_out


 @triton.jit
-def _dynamic_per_token_fp8_quant_kernel(
+def _dynamic_per_token_quant_fp8_i8_kernel(
     qx_ptr,
     scale_out_ptr,
     x_in_ptr,
     cols: int,
     x_in_stride_r: int,
     NUM_COL_POW2: tl.constexpr,
-    FP8_MAX: tl.constexpr,
+    DTYPE_MAX: tl.constexpr,
 ):
-    """
-    #TODO: Add Doc
-    """
-
     pid = tl.program_id(axis=0)
     tl.assume(pid > 0)
     tl.assume(x_in_stride_r > 0)
@@ -127,7 +137,7 @@ def _dynamic_per_token_fp8_quant_kernel(
     x = tl.load(x_in_ptr + offs, mask=mask, cache_modifier=".cg")

     m = tl.max(tl.abs(x), axis=-1)
-    scale_out = m / FP8_MAX
+    scale_out = m.to(tl.float32) / DTYPE_MAX
     scale_recip = 1 / scale_out

     qx = x * scale_recip
@@ -139,28 +149,40 @@ def _dynamic_per_token_fp8_quant_kernel(
     tl.store(qx_ptr + offs, qx, mask=mask, cache_modifier=".cs")


-def dynamic_per_token_fp8_quant(
+def dynamic_per_token_quant_fp8_i8(
     qx: torch.Tensor,
     x_in: torch.Tensor,
     scale_out: torch.Tensor,
-    quant_dtype=torch.float8_e4m3fnuz,
-    dtypeMax: torch.Tensor = torch.finfo(torch.float8_e4m3fnuz).max,
 ):
     """
-    #TODO: Add doc
+    Calculates a per-token scale, then uses it to quantize the input tensor to fp8 or int8.
+
+    Parameters:
+    - x_in: Input tensor of shape (M, N).
+    - qx: Output tensor of the same shape as x_in. Must be fp8 or int8 dtype and allocated by the caller.
+    - scale_out: Output scale tensor of shape (M,), dtype fp32, allocated by the caller.
+
+    Returns:
+    - qx: Quantized output values.
+    - scale_out: Scale tensor of shape (M,).
     """
     rows = x_in.shape[0]
     cols = x_in.shape[1]
     NUM_COL_POW2 = triton.next_power_of_2(cols)
-    grid = lambda meta: (rows,)
-    _dynamic_per_token_fp8_quant_kernel[grid](
+    grid = lambda meta: (rows,)  # noqa: E731
+    _dynamic_per_token_quant_fp8_i8_kernel[grid](
         qx,
         scale_out,
         x_in,
         cols,
         x_in.stride(0),
         NUM_COL_POW2=NUM_COL_POW2,
-        FP8_MAX=dtypeMax,
+        DTYPE_MAX=(
+            torch.finfo(qx.dtype).max
+            if torch.is_floating_point(qx)
+            else torch.iinfo(qx.dtype).max
+        ),
     )

     return qx, scale_out
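With these changes the public entry points accept either fp8 or int8 outputs, deriving DTYPE_MAX from the output dtype via torch.finfo or torch.iinfo. A short end-to-end sketch under the docstrings' contract that the caller allocates qx and the scale tensor; dtype choices are illustrative, with torch.float8_e4m3fnuz being the fp8 type the old signature defaulted to:

import torch
from aiter.ops.triton.quant import (
    dynamic_per_tensor_quant_fp8_i8,
    dynamic_per_token_quant_fp8_i8,
    static_per_tensor_quant_fp8_i8,
)

x = torch.randn(32, 256, dtype=torch.float32, device="cuda")

# int8, per tensor: DTYPE_MAX resolves to torch.iinfo(torch.int8).max == 127.
# The scale starts at zero because the kernel builds it with tl.atomic_max.
qx_i8 = torch.empty_like(x, dtype=torch.int8)
scale = torch.zeros(1, dtype=torch.float32, device=x.device)
dynamic_per_tensor_quant_fp8_i8(qx_i8, x, scale)

# fp8, per token: one fp32 scale per row, DTYPE_MAX taken from torch.finfo.
qx_f8 = torch.empty_like(x, dtype=torch.float8_e4m3fnuz)
scales = torch.empty(x.shape[0], dtype=torch.float32, device=x.device)
dynamic_per_token_quant_fp8_i8(qx_f8, x, scales)

# Static variant: reuse an already-known per-tensor scale.
static_per_tensor_quant_fp8_i8(qx_i8, x, scale)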
Expand Down