doc: fix fp8 bmm documentation (#470)
The documentation was not indexed properly in #469; this PR fixes the issue.
yzh119 authored Aug 27, 2024
1 parent 3d38d0d commit d357a91
Showing 4 changed files with 48 additions and 12 deletions.
21 changes: 16 additions & 5 deletions docs/api/python/gemm.rst
@@ -1,11 +1,22 @@
-.. _apigroup_gemm:
+.. _apigemm:
 
-flashinfer.group_gemm
-=====================
+flashinfer.gemm
+===============
 
-This module provides a set of functions to group GEMM operations.
+.. currentmodule:: flashinfer.gemm
 
-.. currentmodule:: flashinfer.group_gemm
+This module provides a set of GEMM operations.
+
+FP8 Batch GEMM
+--------------
+
+.. autosummary::
+    :toctree: ../../generated
+
+    bmm_fp8
+
+Grouped GEMM
+------------
 
 .. autoclass:: SegmentGEMMWrapper
     :members:
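With this change the GEMM utilities are documented under flashinfer.gemm (previously flashinfer.group_gemm), which matches the test update at the bottom of this commit. Below is a minimal sketch of constructing the grouped-GEMM wrapper under the new path, assuming a CUDA device; the 32 MB workspace size mirrors the updated test and is illustrative, not a requirement.

import torch
import flashinfer

# Allocate a workspace buffer for the wrapper, as the updated test does.
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")

# The wrapper is now referenced as flashinfer.gemm.SegmentGEMMWrapper.
segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)

The FP8 batch GEMM entry point, flashinfer.bmm_fp8, is exercised by the doctest added to python/flashinfer/gemm.py later in this commit.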
7 changes: 4 additions & 3 deletions python/flashinfer/cascade.py
@@ -452,9 +452,10 @@ class BatchDecodeWithSharedPrefixPagedKVCacheWrapper:
     Check :ref:`our tutorial<page-layout>` for page table layout.
 
-    It is recommended to use :class:`MultiLevelCascadeAttentionWrapper` instead for general
-    multi-level cascade inference, where the KV-Cache of each level is stored in a unified
-    page table. This API will be deprecated in the future.
+    Warning
+    -------
+    This API will be deprecated in the future, please use
+    :class:`MultiLevelCascadeAttentionWrapper` instead.
 
     Example
     -------
30 changes: 27 additions & 3 deletions python/flashinfer/gemm.py
@@ -19,6 +19,7 @@
 import torch
 
 from .utils import get_indptr
+from typing import Optional
 
 # mypy: disable-error-code="attr-defined"
 try:
@@ -204,7 +205,7 @@ def bmm_fp8(
     A_scale: torch.Tensor,
     B_scale: torch.Tensor,
     dtype: torch.dtype,
-    out: torch.Tensor = None,
+    out: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     r"""BMM FP8
@@ -225,13 +226,36 @@
     dtype: torch.dtype
         out dtype, bf16 or fp16.
-    out: torch.Tensor
-        Out tensor, shape (b, m, n), bf16 or fp16.
+    out: Optional[torch.Tensor]
+        Out tensor, shape (b, m, n), bf16 or fp16, defaults to ``None``.
 
     Returns
     -------
     out: torch.Tensor
         Out tensor, shape (b, m, n), bf16 or fp16.
 
+    Examples
+    --------
+    >>> import torch
+    >>> import torch.nn.functional as F
+    >>> import flashinfer
+    >>> def to_float8(x, dtype=torch.float8_e4m3fn):
+    ...     finfo = torch.finfo(dtype)
+    ...     abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
+    ...     scale = finfo.max / abs_max
+    ...     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+    ...     return x_scl_sat.to(dtype), scale.float().reciprocal()
+    >>>
+    >>> input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
+    >>> input_fp8, input_inv_s = to_float8(input, dtype=torch.float8_e4m3fn)
+    >>> # column major weight
+    >>> weight = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
+    >>> weight_fp8, weight_inv_s = to_float8(weight, dtype=torch.float8_e4m3fn)
+    >>> out = flashinfer.bmm_fp8(input_fp8, weight_fp8, input_inv_s, weight_inv_s, torch.bfloat16)
+    >>> out.shape
+    torch.Size([16, 48, 80])
+    >>> out.dtype
+    torch.bfloat16
     """
     if out is None:
         out = torch.empty(
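For readers who want to sanity-check the new doctest, here is a small pure-PyTorch reference that is not part of this commit: it dequantizes the fp8 operands with the inverse scales produced by the example's to_float8 helper and compares a plain torch.bmm against flashinfer.bmm_fp8. It assumes the scales act as per-batch dequantization factors, as in the docstring example; the cosine-similarity threshold is illustrative.

import torch
import torch.nn.functional as F
import flashinfer

def to_float8(x, dtype=torch.float8_e4m3fn):
    # Same helper as in the docstring example above.
    finfo = torch.finfo(dtype)
    abs_max = x.abs().amax(dim=(1, 2), keepdim=True).clamp(min=1e-12)
    scale = finfo.max / abs_max
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()

a = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16)
b = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose(-2, -1)
a_fp8, a_inv_s = to_float8(a)
b_fp8, b_inv_s = to_float8(b)

out = flashinfer.bmm_fp8(a_fp8, b_fp8, a_inv_s, b_inv_s, torch.bfloat16)

# Reference: dequantize back to higher precision (broadcasting the per-batch
# scales) and run a plain batched matmul.
ref = torch.bmm(a_fp8.to(torch.float32) * a_inv_s, b_fp8.to(torch.float32) * b_inv_s)

# fp8 rounding error dominates, so compare with cosine similarity rather than allclose.
cos = F.cosine_similarity(out.float().reshape(-1), ref.reshape(-1), dim=0)
assert cos > 0.98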
2 changes: 1 addition & 1 deletion python/tests/test_group_gemm.py
@@ -38,7 +38,7 @@ def test_segment_gemm(
         pytest.skip("batch_size * num_rows_per_batch too large for test.")
     torch.manual_seed(42)
     workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8).to(0)
-    segment_gemm = flashinfer.group_gemm.SegmentGEMMWrapper(workspace_buffer)
+    segment_gemm = flashinfer.gemm.SegmentGEMMWrapper(workspace_buffer)
     x = (
         (torch.randn(batch_size * num_rows_per_batch, d_in) / 10)
         .to(0)
