From fc0f6d49b71ea336045e12a8aaa192579342585d Mon Sep 17 00:00:00 2001 From: Lequn Chen Date: Fri, 1 Nov 2024 19:40:48 -0700 Subject: [PATCH] misc: return type overload for return_lse (#578) Make mypy and pylance happy. For example, here are screenshots of VS Code viewing `test_batch_decode_kernels.py`. return_lse=False: ![return_lse=False](https://github.com/user-attachments/assets/625919ca-04f8-4f37-b4b0-d89c15423c51) return_lse=True: ![return_lse=True](https://github.com/user-attachments/assets/53a01a2b-c8b2-4629-922e-ca8dc8fad18c) --- python/flashinfer/decode.py | 24 ++++++++++- python/flashinfer/prefill.py | 80 +++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/python/flashinfer/decode.py b/python/flashinfer/decode.py index 2bb5d6b9..95670a77 100644 --- a/python/flashinfer/decode.py +++ b/python/flashinfer/decode.py @@ -17,7 +17,7 @@ import functools import math from types import SimpleNamespace -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union, overload import torch @@ -831,6 +831,28 @@ def forward( q, paged_kv_cache, q_scale=q_scale, k_scale=k_scale, v_scale=v_scale ) + @overload + def run( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q_scale: Optional[float] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + return_lse: Literal[False] = False, + ) -> torch.Tensor: ... + + @overload + def run( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + q_scale: Optional[float] = None, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + return_lse: Literal[True] = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: ... 
+ def run( self, q: torch.Tensor, diff --git a/python/flashinfer/prefill.py b/python/flashinfer/prefill.py index 3d550576..4ef29206 100644 --- a/python/flashinfer/prefill.py +++ b/python/flashinfer/prefill.py @@ -18,7 +18,7 @@ import logging import math from types import SimpleNamespace -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union, overload import torch @@ -371,6 +371,46 @@ def single_prefill_with_kv_cache_with_jit_module( return (out, lse) if return_lse else out +@overload +def single_prefill_with_kv_cache( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + custom_mask: Optional[torch.Tensor] = None, + packed_custom_mask: Optional[torch.Tensor] = None, + causal: bool = False, + kv_layout: str = "NHD", + pos_encoding_mode: str = "NONE", + allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, + window_left: int = -1, + logits_soft_cap: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + return_lse: Literal[False] = False, +) -> torch.Tensor: ... + + +@overload +def single_prefill_with_kv_cache( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + custom_mask: Optional[torch.Tensor] = None, + packed_custom_mask: Optional[torch.Tensor] = None, + causal: bool = False, + kv_layout: str = "NHD", + pos_encoding_mode: str = "NONE", + allow_fp16_qk_reduction: bool = False, + sm_scale: Optional[float] = None, + window_left: int = -1, + logits_soft_cap: Optional[float] = None, + rope_scale: Optional[float] = None, + rope_theta: Optional[float] = None, + return_lse: Literal[True] = True, +) -> Tuple[torch.Tensor, torch.Tensor]: ... 
+ + def single_prefill_with_kv_cache( q: torch.Tensor, k: torch.Tensor, @@ -1077,6 +1117,26 @@ def forward( self._rope_theta = rope_theta return self.run(q, paged_kv_cache, k_scale=k_scale, v_scale=v_scale) + @overload + def run( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + return_lse: Literal[False] = False, + ) -> torch.Tensor: ... + + @overload + def run( + self, + q: torch.Tensor, + paged_kv_cache: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, + return_lse: Literal[True] = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: ... + def run( self, q: torch.Tensor, @@ -1643,6 +1703,24 @@ def forward( self._rope_theta = rope_theta return self.run(q, k, v) + @overload + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + return_lse: Literal[False] = False, + ) -> torch.Tensor: ... + + @overload + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + return_lse: Literal[True] = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: ... + def run( self, q: torch.Tensor,