[float8nocompile] add triton kernel which does fp8 conversion to col major and transpose in col major at once (pytorch#1566)
danielvegamyhre authored Jan 16, 2025
1 parent 5e59b51 commit 522f5b8
Showing 2 changed files with 236 additions and 2 deletions.
162 changes: 160 additions & 2 deletions torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py
@@ -250,8 +250,8 @@ def to_fp8_col_major_t(
         block_col_offs[:, None] * output_stride_row
         + block_row_offs[None, :] * output_stride_col
     )
-    out_mask = (block_row_offs[:, None] < output_num_rows) & (
-        block_col_offs[None, :] < output_num_cols
+    out_mask = (block_col_offs[:, None] < output_num_rows) & (
+        block_row_offs[None, :] < output_num_cols
     )
     tl.store(out_ptr + out_offs, fp8_vals, mask=out_mask)
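
The hunk above fixes the store mask for the transposed, column-major output of `to_fp8_col_major_t`: each program writes to `output[block_col_offs[i], block_row_offs[j]]`, so the in-bounds check must compare `block_col_offs` against `output_num_rows` and `block_row_offs` against `output_num_cols`. A small sketch of the difference in plain PyTorch, using hypothetical tensor and block sizes (not taken from the kernel configs):

```python
import torch

# Hypothetical sizes: a 5x5 input, 4x4 blocks; consider the block with
# program ids (block_row_id=1, block_col_id=0).
output_num_rows = output_num_cols = 5           # transposed output is 5x5
block_row_offs = 4 + torch.arange(4)            # input rows 4..7 (only row 4 exists)
block_col_offs = 0 + torch.arange(4)            # input cols 0..3 (all exist)

# The store targets output[block_col_offs[i], block_row_offs[j]].
rows = block_col_offs[:, None].expand(4, 4)
cols = block_row_offs[None, :].expand(4, 4)
in_bounds = (rows < output_num_rows) & (cols < output_num_cols)

old_mask = (block_row_offs[:, None] < output_num_rows) & (
    block_col_offs[None, :] < output_num_cols
)
new_mask = (block_col_offs[:, None] < output_num_rows) & (
    block_row_offs[None, :] < output_num_cols
)
print(torch.equal(new_mask, in_bounds))      # True: fixed mask matches bounds
print((old_mask & ~in_bounds).sum().item())  # 3 out-of-bounds stores allowed
print((in_bounds & ~old_mask).sum().item())  # 3 valid positions skipped
```

With the old mask, that block would write three elements past the edge of the output and skip three valid ones; the new kernel added below uses the corrected pattern.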

@@ -381,6 +381,77 @@ def _to_fp8_row_major_t_and_non_t(
tl.store(row_major_t_out_ptr + row_major_t_offs, fp8_vals.trans(1, 0), mask=mask)


@triton.autotune(configs=kernel_configs_2D, key=["num_elements"])
@triton.jit
def _to_fp8_col_major_t_and_non_t(
input_ptr,
col_major_out_ptr,
col_major_t_out_ptr,
scale_ptr,
num_elements: int,
fp8_dtype_min: float,
fp8_dtype_max: float,
input_num_rows: int,
input_num_cols: int,
input_stride_row: int,
input_stride_col: int,
col_major_out_stride_row: int,
col_major_out_stride_col: int,
col_major_t_out_stride_row: int,
col_major_t_out_stride_col: int,
input_dtype: tl.constexpr,
output_dtype: tl.constexpr,
BLOCK_SIZE_ROWS: tl.constexpr,
BLOCK_SIZE_COLS: tl.constexpr,
EPS: tl.constexpr,
):
"""
Reads a row-major, high precision input tensor and writes 2 output tensors:
1) fp8 col major tensor (transposed)
2) fp8 col major tensor
"""
    # determine which block of the input tensor this program will process
block_row_id = tl.program_id(axis=0)
block_col_id = tl.program_id(axis=1)

# load scaling factor
scale = tl.load(scale_ptr).to(tl.float32)

# load block of input tensor
block_row_start = block_row_id * BLOCK_SIZE_ROWS
block_col_start = block_col_id * BLOCK_SIZE_COLS
block_row_offs = block_row_start + tl.arange(0, BLOCK_SIZE_ROWS)
block_col_offs = block_col_start + tl.arange(0, BLOCK_SIZE_COLS)
input_offs = (
block_row_offs[:, None] * input_stride_row
+ block_col_offs[None, :] * input_stride_col
)
mask = (block_row_offs[:, None] < input_num_rows) & (
block_col_offs[None, :] < input_num_cols
)
vals = tl.load(input_ptr + input_offs, mask=mask).to(input_dtype)

# perform conversion
vals = vals * scale
fp8_vals = tl.clamp(vals, min=fp8_dtype_min, max=fp8_dtype_max).to(output_dtype)

# 1. write col-major output
out_offs = block_row_offs[:, None] + block_col_offs[None, :] * input_num_rows
tl.store(col_major_out_ptr + out_offs, fp8_vals, mask=mask)

    # 2. write transposed col-major output
col_major_t_num_rows = input_num_cols
col_major_t_num_cols = input_num_rows
out_offs = (
block_col_offs[:, None] * col_major_t_out_stride_row
+ block_row_offs[None, :] * col_major_t_out_stride_col
)
out_mask = (block_col_offs[:, None] < col_major_t_num_rows) & (
block_row_offs[None, :] < col_major_t_num_cols
)
tl.store(col_major_t_out_ptr + out_offs, fp8_vals.trans(1, 0), mask=out_mask)
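
As an editorial illustration (not part of the commit), here is an eager-mode sketch of what the fused kernel above produces. The helper name is hypothetical, and it assumes the tensorwise `scale` has already been computed, as `_hp_tensor_to_scale` does in the wrapper added below:

```python
import torch

def reference_fp8_col_major_t_and_non_t(
    hp: torch.Tensor,
    scale: torch.Tensor,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
):
    # Scale, clamp to the fp8 range, and cast -- the same math the kernel performs.
    info = torch.finfo(fp8_dtype)
    fp8 = (hp.float() * scale).clamp(info.min, info.max).to(fp8_dtype)
    # Output 1: same shape as the input, laid out column-major in memory.
    col_major = fp8.t().contiguous().t()  # shape (R, C), strides (1, R)
    # Output 2: the transpose, also column-major in memory.
    col_major_t = fp8.t()                 # shape (C, R), strides (1, C)
    return col_major, col_major_t
```

This mirrors the reference path used in the new test below, which compares against `x_fp8_row_major.t().contiguous().t()` and `x_fp8_row_major.t()`.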


@triton.autotune(configs=kernel_configs_1D, key=["num_elements"])
@triton.jit
def _amax_atomic(
@@ -859,6 +930,93 @@ def hp_to_fp8_row_major_t_and_non_t(
return fp8_tensor_row_major, fp8_tensor_row_major_t


def hp_to_fp8_col_major_t_and_non_t(
hp_tensor: torch.Tensor,
fp8_dtype: torch.dtype,
linear_mm_config: LinearMMConfig,
gemm_input_role: GemmInputRole = GemmInputRole.INPUT,
algo: KernelAlgorithm = KernelAlgorithm.ATOMIC_MAX,
) -> tuple[Float8Tensor, Float8Tensor]:
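    """
    Converts a row-major, high precision input tensor to fp8 and returns two
    Float8Tensors produced by a single fused Triton kernel:
    1) the fp8 tensor in column-major memory layout
    2) the fp8 transpose of the tensor, also in column-major memory layout
    """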
assert hp_tensor.is_contiguous(), "input tensor must be contiguous"

tl_input_dtype = FP8_DTYPE_MAP[hp_tensor.dtype]
tl_output_dtype = FP8_DTYPE_MAP[fp8_dtype]

fp8_dtype_min = torch.finfo(fp8_dtype).min
fp8_dtype_max = torch.finfo(fp8_dtype).max

# compute scaling factor for tensor
scale = _hp_tensor_to_scale(
hp_tensor,
tl_input_dtype,
fp8_dtype_max,
algo,
)

# perform fp8 conversion
input_num_rows, input_num_cols = hp_tensor.shape
num_elements = hp_tensor.numel()

# preallocate necessary output tensors
fp8_output_col_major = torch.empty(
(input_num_rows, input_num_cols), dtype=fp8_dtype, device=hp_tensor.device
)
fp8_output_col_major_t = torch.empty_like(
hp_tensor.t(),
dtype=fp8_dtype,
device=hp_tensor.device,
)

# launch triton kernel to perform conversion
grid = lambda meta: (
triton.cdiv(input_num_rows, meta["BLOCK_SIZE_ROWS"]),
triton.cdiv(input_num_cols, meta["BLOCK_SIZE_COLS"]),
)
_to_fp8_col_major_t_and_non_t[grid](
hp_tensor,
fp8_output_col_major,
fp8_output_col_major_t,
scale,
num_elements,
fp8_dtype_min,
fp8_dtype_max,
input_num_rows,
input_num_cols,
hp_tensor.stride(0),
hp_tensor.stride(1),
fp8_output_col_major.stride(0),
fp8_output_col_major.stride(1),
fp8_output_col_major_t.stride(0),
fp8_output_col_major_t.stride(1),
input_dtype=tl_input_dtype,
output_dtype=tl_output_dtype,
EPS=EPS,
)

# for col major we need to update the strides to reflect the new memory layout
col_major_strides = (1, input_num_rows)
fp8_output_col_major = fp8_output_col_major.as_strided(
fp8_output_col_major.size(), col_major_strides
)

# wrap outputs in Float8Tensors
fp8_tensor_col_major = Float8Tensor(
fp8_output_col_major,
scale,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
fp8_tensor_col_major_t = Float8Tensor(
fp8_output_col_major_t,
scale,
orig_dtype=hp_tensor.dtype,
linear_mm_config=linear_mm_config,
gemm_input_role=gemm_input_role,
)
return fp8_tensor_col_major, fp8_tensor_col_major_t
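
A minimal usage sketch for the new wrapper (illustrative only; it assumes a CUDA device is available, and the import path shown for `LinearMMConfig` is an assumption rather than something taken from this diff):

```python
import torch

from torchao.float8.float8_tensor import LinearMMConfig  # assumed import path
from torchao.prototype.float8nocompile.kernels.fp8_dynamic_tensorwise import (
    KernelAlgorithm,
    hp_to_fp8_col_major_t_and_non_t,
)

# Row-major bf16 tensor on GPU (must be contiguous, per the assert above).
hp = torch.randn(512, 256, dtype=torch.bfloat16, device="cuda")

fp8_col_major, fp8_col_major_t = hp_to_fp8_col_major_t_and_non_t(
    hp,
    torch.float8_e4m3fn,
    LinearMMConfig(),
    algo=KernelAlgorithm.ATOMIC_MAX,
)

# Both outputs share one tensorwise scale and are column-major in memory.
print(fp8_col_major.shape, fp8_col_major._data.stride())      # (512, 256), (1, 512)
print(fp8_col_major_t.shape, fp8_col_major_t._data.stride())  # (256, 512), (1, 256)
```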


def _hp_tensor_to_scale(
hp_tensor: torch.Tensor,
tl_input_dtype: tl.core.dtype,
@@ -8,6 +8,7 @@
KernelAlgorithm,
hp_to_fp8_col_major,
hp_to_fp8_col_major_t,
hp_to_fp8_col_major_t_and_non_t,
hp_to_fp8_row_and_col_major,
hp_to_fp8_row_major,
hp_to_fp8_row_major_t,
@@ -410,3 +411,78 @@ def test_fp8_hp_to_fp8_row_major_t_and_non_t(
torch.float8_e4m3fn,
LinearMMConfig(),
)


@pytest.mark.parametrize(
"algo",
[KernelAlgorithm.REDUCTION, KernelAlgorithm.ATOMIC_MAX],
)
@pytest.mark.parametrize(
"input_shape",
[(2, 4), (32, 16), (512, 512)],
)
def test_fp8_hp_to_fp8_col_major_t_and_non_t(
input_shape: tuple[int, int], algo: KernelAlgorithm
):
assert torch.cuda.is_available()
device = "cuda"
input_bf16 = torch.randn(input_shape, dtype=torch.bfloat16, device=device)
x_bf16 = input_bf16.clone().detach().to(device)
y_bf16 = input_bf16.clone().detach().to(device)

# production implementation
x_fp8_row_major = hp_tensor_to_float8_dynamic(
x_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
)
x_fp8_col_major = x_fp8_row_major.t().contiguous().t()
x_fp8_col_major_t = x_fp8_row_major.t()

# float8nocompile triton implementation
y_fp8_col_major, y_fp8_col_major_t = hp_to_fp8_col_major_t_and_non_t(
y_bf16,
torch.float8_e4m3fn,
LinearMMConfig(),
algo=algo,
)

# check scales
assert torch.eq(x_fp8_col_major._scale, y_fp8_col_major._scale)
assert torch.eq(x_fp8_col_major_t._scale, y_fp8_col_major_t._scale)

# check data
assert torch.all(torch.eq(x_fp8_col_major._data, y_fp8_col_major._data))
assert torch.all(torch.eq(x_fp8_col_major_t._data, y_fp8_col_major_t._data))

# check shapes
assert x_fp8_col_major.shape == y_fp8_col_major.shape
assert x_fp8_col_major_t.shape == y_fp8_col_major_t.shape

# check strides
assert x_fp8_col_major.stride() == y_fp8_col_major.stride()
assert x_fp8_col_major_t.stride() == y_fp8_col_major_t.stride()

# check memory layout
assert not is_row_major(x_fp8_col_major.stride())
assert not is_row_major(y_fp8_col_major.stride())
assert not is_row_major(x_fp8_col_major_t.stride())
assert not is_row_major(y_fp8_col_major_t.stride())

# check underlying memory layout
assert (
x_fp8_col_major._data.storage().tolist()
== y_fp8_col_major._data.storage().tolist()
)
assert (
x_fp8_col_major_t._data.storage().tolist()
== y_fp8_col_major_t._data.storage().tolist()
)

# assert that error is raised when input tensor is not contiguous
with pytest.raises(AssertionError, match="tensor must be contiguous"):
hp_to_fp8_col_major_t_and_non_t(
y_bf16.t(), # transpose so tensor memory layout is no longer contiguous
torch.float8_e4m3fn,
LinearMMConfig(),
)
