Fix Sphinx #573

Merged · 1 commit · Oct 31, 2024
1 change: 1 addition & 0 deletions docs/Makefile
@@ -16,5 +16,6 @@ help:
 
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: export FLASHINFER_BUILDING_DOCS=1
 %: Makefile
 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
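The added line is a GNU Make target-specific exported variable: for every target matched by the `%` pattern, FLASHINFER_BUILDING_DOCS=1 is placed in the environment of the commands Make runs, so the Python process that sphinx-build starts to import flashinfer can see it. A minimal sketch of that inheritance (illustrative only, not part of the change):

# Minimal sketch (illustrative): an exported Make variable ends up in the
# environment of every child process, which is how the docs build signals
# flashinfer via FLASHINFER_BUILDING_DOCS.
import os
import subprocess
import sys

env = dict(os.environ, FLASHINFER_BUILDING_DOCS="1")  # what `export` does, in effect
out = subprocess.run(
    [sys.executable, "-c",
     "import os; print(os.environ.get('FLASHINFER_BUILDING_DOCS'))"],
    env=env, capture_output=True, text=True, check=True,
)
print(out.stdout.strip())  # -> 1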
7 changes: 4 additions & 3 deletions python/flashinfer/prefill.py
@@ -48,9 +48,6 @@
     register_fake_op,
 )
 
-if has_prebuilt_ops:
-    from . import _kernels  # type: ignore[attr-defined]
-
 
 def compile_single_prefill_module(
     *args,
@@ -85,6 +82,8 @@ def get_single_prefill_module(*args):
     if args not in _single_prefill_modules:
         uri = get_single_prefill_uri(*args)
         if has_prebuilt_ops and uri in prebuilt_ops_uri:
+            from . import _kernels
+
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
             mask_mode = args[5]
             run_func = lambda *run_args: _kernels.single_prefill_with_kv_cache(
@@ -157,6 +156,8 @@ def get_batch_prefill_module(*args):
     if args not in _batch_prefill_modules:
         uri = get_batch_prefill_uri(*args)
         if has_prebuilt_ops and uri in prebuilt_ops_uri:
+            from . import _kernels
+
             # NOTE(Zihao): we should avoid hard-coded index like this, refactor it later
             head_dim = args[4]
             plan_func = lambda *plan_args: _kernels.batch_prefill_with_kv_cache_plan(
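These prefill.py hunks defer importing the prebuilt `_kernels` extension from module import time to the point where a prebuilt op is actually selected, so plain `import flashinfer` (as Sphinx autodoc does) no longer requires the compiled extension to exist. A rough sketch of the lazy-import pattern, with a hypothetical wrapper around the real names:

# Rough sketch (hypothetical wrapper; the real logic lives inside
# get_single_prefill_module / get_batch_prefill_module):
import importlib


def get_prebuilt_run_func(use_prebuilt: bool):
    if use_prebuilt:
        # Imported only when a prebuilt kernel is requested; a missing extension
        # now fails here, at call time, instead of at `import flashinfer`.
        _kernels = importlib.import_module("flashinfer._kernels")
        return _kernels.single_prefill_with_kv_cache
    raise RuntimeError("prebuilt ops unavailable; fall back to the JIT path")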
63 changes: 43 additions & 20 deletions python/flashinfer/utils.py
@@ -15,13 +15,16 @@
 """
 
 import math
+import os
 from enum import Enum
 from typing import Callable, Dict, Iterable, Optional, Sequence, Tuple, Union
 
 import torch
 from torch.torch_version import TorchVersion
 from torch.torch_version import __version__ as torch_version
 
+IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"
+
 
 class PosEncodingMode(Enum):
     NONE = 0
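Note that IS_BUILDING_DOCS is evaluated once, when flashinfer.utils is first imported, so the environment variable has to be set before the import; the Makefile export above takes care of that for docs builds. A small illustration:

# Small illustration: the flag is computed once at import time, so flipping the
# environment variable afterwards has no effect.
import os

os.environ["FLASHINFER_BUILDING_DOCS"] = "1"
IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"  # as in utils.py

os.environ["FLASHINFER_BUILDING_DOCS"] = "0"  # too late: the constant is already bound
print(IS_BUILDING_DOCS)  # -> True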
@@ -202,26 +205,46 @@ def _check_cached_qkv_data_type(
         )
 
 
-def register_custom_op(
-    name: str,
-    fn: Optional[Callable] = None,
-    /,
-    *,
-    mutates_args: Union[str, Iterable[str]],
-    device_types: Optional[Union[str, Sequence[str]]] = None,
-    schema: Optional[str] = None,
-) -> Callable:
-    if TorchVersion(torch_version) < TorchVersion("2.4"):
-        return lambda x: x
-    return torch.library.custom_op(
-        name, fn, mutates_args=mutates_args, device_types=device_types, schema=schema
-    )
+if IS_BUILDING_DOCS or TorchVersion(torch_version) < TorchVersion("2.4"):
+
+    def register_custom_op(
+        name: str,
+        fn: Optional[Callable] = None,
+        /,
+        *,
+        mutates_args: Union[str, Iterable[str]],
+        device_types: Optional[Union[str, Sequence[str]]] = None,
+        schema: Optional[str] = None,
+    ) -> Callable:
+        return lambda x: x
 
-def register_fake_op(
-    name: str,
-    fn: Optional[Callable] = None,
-) -> Callable:
-    if TorchVersion(torch_version) < TorchVersion("2.4"):
+    def register_fake_op(
+        name: str,
+        fn: Optional[Callable] = None,
+    ) -> Callable:
         return lambda x: x
-    return torch.library.register_fake(name, fn)
 
+else:
+
+    def register_custom_op(
+        name: str,
+        fn: Optional[Callable] = None,
+        /,
+        *,
+        mutates_args: Union[str, Iterable[str]],
+        device_types: Optional[Union[str, Sequence[str]]] = None,
+        schema: Optional[str] = None,
+    ) -> Callable:
+        return torch.library.custom_op(
+            name,
+            fn,
+            mutates_args=mutates_args,
+            device_types=device_types,
+            schema=schema,
+        )
+
+    def register_fake_op(
+        name: str,
+        fn: Optional[Callable] = None,
+    ) -> Callable:
+        return torch.library.register_fake(name, fn)
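With this split, when docs are being built (or torch is older than 2.4) register_custom_op and register_fake_op degrade to identity decorators, so decorated kernels stay plain Python functions that Sphinx can introspect without going through torch.library. A usage sketch of that fallback branch, with a hypothetical op name:

# Usage sketch of the fallback branch (hypothetical op name): the decorator
# returns the function unchanged, so autodoc sees an ordinary callable.
from typing import Callable, Iterable, Optional, Sequence, Union


def register_custom_op(
    name: str,
    fn: Optional[Callable] = None,
    /,
    *,
    mutates_args: Union[str, Iterable[str]],
    device_types: Optional[Union[str, Sequence[str]]] = None,
    schema: Optional[str] = None,
) -> Callable:
    return lambda x: x  # identity decorator, same as the docs/old-torch branch above


@register_custom_op("flashinfer::toy_op", mutates_args=())
def toy_op(x: int) -> int:
    return x + 1


print(toy_op(2))  # -> 3; toy_op is still a plain function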