Skip to content

Commit

Permalink
misc: return type overload for return_lse (#578)
Browse files Browse the repository at this point in the history
Make mypy and pylance happy.

For example, here's screenshots of vscode viewing
`test_batch_decode_kernels.py`.

return_lse=False:

![return_lse=False](https://github.com/user-attachments/assets/625919ca-04f8-4f37-b4b0-d89c15423c51)

return_lse=True:

![return_lse=True](https://github.com/user-attachments/assets/53a01a2b-c8b2-4629-922e-ca8dc8fad18c)
  • Loading branch information
abcdabcd987 authored Nov 2, 2024
1 parent 5d454ed commit fc0f6d4
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 2 deletions.
24 changes: 23 additions & 1 deletion python/flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import math
from types import SimpleNamespace
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Literal, Optional, Tuple, Union, overload

import torch

Expand Down Expand Up @@ -831,6 +831,28 @@ def forward(
q, paged_kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale
)

@overload
def run(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
return_lse: Literal[False] = False,
) -> torch.Tensor: ...

@overload
def run(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
q_scale: Optional[float] = None,
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
return_lse: Literal[True] = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def run(
self,
q: torch.Tensor,
Expand Down
80 changes: 79 additions & 1 deletion python/flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import math
from types import SimpleNamespace
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Literal, Optional, Tuple, Union, overload

import torch

Expand Down Expand Up @@ -371,6 +371,46 @@ def single_prefill_with_kv_cache_with_jit_module(
return (out, lse) if return_lse else out


@overload
def single_prefill_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
custom_mask: Optional[torch.Tensor] = None,
packed_custom_mask: Optional[torch.Tensor] = None,
causal: bool = False,
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
sm_scale: Optional[float] = None,
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
return_lse: Literal[False] = False,
) -> torch.Tensor: ...


@overload
def single_prefill_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
custom_mask: Optional[torch.Tensor] = None,
packed_custom_mask: Optional[torch.Tensor] = None,
causal: bool = False,
kv_layout: str = "NHD",
pos_encoding_mode: str = "NONE",
allow_fp16_qk_reduction: bool = False,
sm_scale: Optional[float] = None,
window_left: int = -1,
logits_soft_cap: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
return_lse: Literal[True] = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ...


def single_prefill_with_kv_cache(
q: torch.Tensor,
k: torch.Tensor,
Expand Down Expand Up @@ -1077,6 +1117,26 @@ def forward(
self._rope_theta = rope_theta
return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale)

@overload
def run(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
return_lse: Literal[False] = False,
) -> torch.Tensor: ...

@overload
def run(
self,
q: torch.Tensor,
paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
k_scale: Optional[float] = None,
v_scale: Optional[float] = None,
return_lse: Literal[True] = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def run(
self,
q: torch.Tensor,
Expand Down Expand Up @@ -1643,6 +1703,24 @@ def forward(
self._rope_theta = rope_theta
return self.run(q, k, v)

@overload
def run(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
return_lse: Literal[False] = False,
) -> torch.Tensor: ...

@overload
def run(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
return_lse: Literal[True] = True,
) -> Tuple[torch.Tensor, torch.Tensor]: ...

def run(
self,
q: torch.Tensor,
Expand Down

0 comments on commit fc0f6d4

Please sign in to comment.