diff --git a/server/examples/test_local_api.py b/server/examples/test_local_api.py index 8e449399..69c078c9 100644 --- a/server/examples/test_local_api.py +++ b/server/examples/test_local_api.py @@ -1,17 +1,23 @@ from text_generation_server.pb import generate_pb2 import torch -from text_generation_server.models.flashinfer_llama import FlashinferLlama -from text_generation_server.models.flashinfer_gemma import FlashinferGemma +from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama +from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma import sys try: - from text_generation_server.models.flashinfer_mistral import FlashinferMistral - from text_generation_server.models.flashinfer_phi import FlashinferPhi - from text_generation_server.models.flashinfer_qwen2 import FlashinferQwen2 + from text_generation_server.models_flashinfer.flashinfer_mistral import ( + FlashinferMistral, + ) + from text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi + from text_generation_server.models_flashinfer.flashinfer_qwen2 import ( + FlashinferQwen2, + ) except: print("can't load flashinfer mistral and phi and qwen2 without flash attn") -from text_generation_server.models.flashinfer_causal_lm import FlashinferBatch +from text_generation_server.models_flashinfer.flashinfer_causal_lm import ( + FlashinferBatch, +) import random, json from test_cases import DEMO, LoraSpec diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index 0a7c3639..712263cb 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -43,6 +43,7 @@ def serve( logger_level: str = "INFO", json_output: bool = False, otlp_endpoint: Optional[str] = None, + use_flashinfer: bool = True, ): if sharded: assert ( @@ -90,17 +91,33 @@ def serve( raise RuntimeError( "Only 1 can be set between `dtype` and `quantize`, as they both decide how goes the final model." 
) - server.serve( - model_id, - revision, - sharded, - quantize, - speculate, - dtype, - trust_remote_code, - uds_path, - lora_ids, - ) + + if use_flashinfer: + from text_generation_server import server_flashinfer + + server_flashinfer.serve( + model_id, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + lora_ids, + ) + else: + server.serve( + model_id, + revision, + sharded, + quantize, + speculate, + dtype, + trust_remote_code, + uds_path, + lora_ids, + ) @app.command() diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index a9495539..b6506e90 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -101,25 +101,6 @@ __all__.append(FlashGemma) __all__.append(FlashCohere) -FLASHINFER_AVAILABLE = True -try: - from text_generation_server.models.flashinfer_llama import FlashinferLlama - from text_generation_server.models.flashinfer_gemma import FlashinferGemma - from text_generation_server.models.flashinfer_mistral import FlashinferMistral - from text_generation_server.models.flashinfer_phi import FlashinferPhi - from text_generation_server.models.flashinfer_qwen2 import FlashinferQwen2 - -except ImportError as e: - logger.warning(f"Could not import FlashInfer: {e}") - FLASHINFER_AVAILABLE = False - -if FLASHINFER_AVAILABLE: - __all__.append(FlashinferLlama) - __all__.append(FlashinferGemma) - __all__.append(FlashinferMistral) - __all__.append(FlashinferPhi) - __all__.append(FlashinferQwen2) - MAMBA_AVAILABLE = True try: from text_generation_server.models.mamba import Mamba @@ -564,13 +545,6 @@ def get_model( ) elif model_type == PHI: - if FLASHINFER_AVAILABLE: - return FlashinferPhi( - model_id, - lora_ids.split(";") if lora_ids else None, - quantize=quantize, - dtype=dtype, - ) if FLASH_ATTENTION: return FlashPhi( model_id, @@ -606,14 +580,6 @@ def get_model( ) elif model_type == LLAMA or model_type 
== BAICHUAN or model_type == PHI3: - if FLASHINFER_AVAILABLE: - return FlashinferLlama( - model_id, - lora_ids.split(";") if lora_ids else None, - quantize=quantize, - dtype=dtype, - ) - if FLASH_ATTENTION: return FlashLlama( model_id, @@ -635,13 +601,6 @@ def get_model( trust_remote_code=trust_remote_code, ) if model_type == GEMMA: - if FLASHINFER_AVAILABLE: - return FlashinferGemma( - model_id, - lora_ids.split(";") if lora_ids else None, - quantize=quantize, - dtype=dtype, - ) if FLASH_ATTENTION: return FlashGemma( model_id, @@ -742,14 +701,6 @@ def get_model( ) if model_type == MISTRAL: - if FLASHINFER_AVAILABLE: - return FlashinferMistral( - model_id, - lora_ids.split(";") if lora_ids else None, - quantize=quantize, - dtype=dtype, - ) - sliding_window = config_dict.get("sliding_window", -1) if FLASH_ATTENTION: return FlashMistral( @@ -820,13 +771,6 @@ def get_model( ) if model_type == QWEN2: - if FLASHINFER_AVAILABLE: - return FlashinferQwen2( - model_id, - lora_ids.split(";") if lora_ids else None, - quantize=quantize, - dtype=dtype, - ) sliding_window = config_dict.get("sliding_window", -1) if (sliding_window is None or sliding_window != -1) and SUPPORTS_WINDOWING: return FlashQwen2( diff --git a/server/text_generation_server/models_flashinfer/__init__.py b/server/text_generation_server/models_flashinfer/__init__.py new file mode 100644 index 00000000..59f955eb --- /dev/null +++ b/server/text_generation_server/models_flashinfer/__init__.py @@ -0,0 +1,164 @@ +import torch +import enum + +from loguru import logger +from transformers.configuration_utils import PretrainedConfig +from typing import Optional + +from text_generation_server.models.model import Model +from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama +from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma +from text_generation_server.models_flashinfer.flashinfer_mistral import ( + FlashinferMistral, +) +from 
text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi
+from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2
+
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
+# in PyTorch 1.12 and later.
+torch.backends.cuda.matmul.allow_tf32 = True
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = True
+
+# Disable gradients
+torch.set_grad_enabled(False)
+
+__all__ = [
+    "Model",
+    "FlashinferLlama",
+    "FlashinferGemma",
+    "FlashinferMistral",
+    "FlashinferPhi",
+    "FlashinferQwen2",
+    "get_model",
+]
+
+
+class ModelType(enum.Enum):
+    """Known model families.
+
+    Each member's value is a dict whose ``"type"`` entry is the string that
+    ``get_model`` compares against ``model_type`` from the model config.
+    """
+
+    LLAVA_NEXT = {
+        "type": "llava_next",
+        "name": "Llava Next (1.6)",
+        "url": "https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf",
+        "multimodal": True,
+    }
+    LLAMA = {
+        "type": "llama",
+        "name": "Llama",
+        "url": "https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct",
+    }
+    PHI3 = {
+        "type": "phi3",
+        "name": "Phi 3",
+        "url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct",
+    }
+    GEMMA = {
+        "type": "gemma",
+        "name": "Gemma",
+        "url": "https://huggingface.co/google/gemma-7b",
+    }
+    MISTRAL = {
+        "type": "mistral",
+        "name": "Mistral",
+        "url": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2",
+    }
+    PHI = {
+        "type": "phi",
+        "name": "Phi",
+        "url": "https://huggingface.co/microsoft/phi-1_5",
+    }
+    BAICHUAN = {
+        "type": "baichuan",
+        "name": "Baichuan",
+        "url": "https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat",
+    }
+    # NOTE(review): this URL points at a StarCoder2 repository, not a Qwen 2
+    # model -- looks like a copy/paste slip; confirm the intended link.
+    QWEN2 = {
+        "type": "qwen2",
+        "name": "Qwen 2",
+        "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
+    }
+
+
+# Dispatch table from the config's `model_type` string to the FlashInfer model
+# class that serves it. Keyed on the ModelType strings so that get_model does
+# not rely on module-level constants (PHI, LLAMA, ...) that this module never
+# defines. LLAVA_NEXT has no FlashInfer implementation here and is omitted.
+_FLASHINFER_MODEL_CLASSES = {
+    ModelType.PHI.value["type"]: FlashinferPhi,
+    ModelType.LLAMA.value["type"]: FlashinferLlama,
+    ModelType.BAICHUAN.value["type"]: FlashinferLlama,
+    ModelType.PHI3.value["type"]: FlashinferLlama,
+    ModelType.GEMMA.value["type"]: FlashinferGemma,
+    ModelType.MISTRAL.value["type"]: FlashinferMistral,
+    ModelType.QWEN2.value["type"]: FlashinferQwen2,
+}
+
+
+def get_model(
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    lora_ids: Optional[str],
+) -> Model:
+    """Instantiate the FlashInfer model matching ``model_id``'s config.
+
+    Args:
+        model_id: Model identifier passed to ``PretrainedConfig.get_config_dict``
+            and to the model constructor.
+        revision: Optional revision used when fetching the config.
+        sharded: Only used to reject the ``exl2`` + sharding combination.
+        quantize: Quantization method; auto-selected from the config's
+            ``quantization_config`` when ``None`` and the method is one of
+            ``gptq``, ``awq`` or ``exl2``.
+        dtype: ``"float16"``, ``"bfloat16"`` or ``None``.
+        trust_remote_code: Forwarded to the config lookup.
+        lora_ids: ``;``-separated LoRA ids, or ``None``/empty for none.
+
+    Returns:
+        The constructed model.
+
+    Raises:
+        RuntimeError: Unknown ``dtype`` string, missing ``model_type`` in the
+            config, or ``exl2`` quantization requested together with sharding.
+        ValueError: ``model_type`` has no FlashInfer implementation.
+    """
+    if dtype is None:
+        if quantize in ["awq", "exl2", "gptq"]:
+            # These quantizers only work with float16 params.
+            dtype = torch.float16
+        else:
+            # Keep it as default for now and let
+            # every model resolve their own default dtype.
+            dtype = None
+    elif dtype == "float16":
+        dtype = torch.float16
+    elif dtype == "bfloat16":
+        dtype = torch.bfloat16
+    else:
+        raise RuntimeError(f"Unknown dtype {dtype}")
+
+    config_dict, _ = PretrainedConfig.get_config_dict(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
+    model_type = config_dict.get("model_type", None)
+    if model_type is None:
+        raise RuntimeError(
+            f"Could not determine model type for {model_id} revision {revision}"
+        )
+    quantization_config = config_dict.get("quantization_config", None)
+    if quantization_config is not None and quantize is None:
+        method = quantization_config.get("quant_method", None)
+        if method in {"gptq", "awq", "exl2"}:
+            logger.info(f"Auto selecting quantization method {method}")
+            quantize = method
+        else:
+            logger.info(f"Unknown quantization method {method}")
+
+    if quantize == "exl2" and sharded:
+        raise RuntimeError(
+            "Sharding is currently not supported with `exl2` quantization"
+        )
+
+    model_class = _FLASHINFER_MODEL_CLASSES.get(model_type)
+    if model_class is None:
+        raise ValueError(f"Unsupported model type {model_type}")
+
+    # Every FlashInfer model shares the same constructor call shape.
+    return model_class(
+        model_id,
+        lora_ids.split(";") if lora_ids else None,
+        quantize=quantize,
+        dtype=dtype,
+    )
diff --git
a/server/text_generation_server/models_flashinfer/custom_modeling/__init__.py b/server/text_generation_server/models_flashinfer/custom_modeling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/text_generation_server/models/custom_modeling/embedding_llama.py b/server/text_generation_server/models_flashinfer/custom_modeling/embedding_llama.py similarity index 100% rename from server/text_generation_server/models/custom_modeling/embedding_llama.py rename to server/text_generation_server/models_flashinfer/custom_modeling/embedding_llama.py diff --git a/server/text_generation_server/models/custom_modeling/flashinfer_gemma_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_gemma_modeling.py similarity index 100% rename from server/text_generation_server/models/custom_modeling/flashinfer_gemma_modeling.py rename to server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_gemma_modeling.py diff --git a/server/text_generation_server/models/custom_modeling/flashinfer_llama_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_llama_modeling.py similarity index 100% rename from server/text_generation_server/models/custom_modeling/flashinfer_llama_modeling.py rename to server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_llama_modeling.py diff --git a/server/text_generation_server/models/custom_modeling/flashinfer_mistral_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_mistral_modeling.py similarity index 100% rename from server/text_generation_server/models/custom_modeling/flashinfer_mistral_modeling.py rename to server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_mistral_modeling.py diff --git a/server/text_generation_server/models/custom_modeling/flashinfer_phi_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_phi_modeling.py 
similarity index 100% rename from server/text_generation_server/models/custom_modeling/flashinfer_phi_modeling.py rename to server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_phi_modeling.py diff --git a/server/text_generation_server/models/custom_modeling/flashinfer_qwen2_modeling.py b/server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_qwen2_modeling.py similarity index 100% rename from server/text_generation_server/models/custom_modeling/flashinfer_qwen2_modeling.py rename to server/text_generation_server/models_flashinfer/custom_modeling/flashinfer_qwen2_modeling.py diff --git a/server/text_generation_server/models/flashinfer_causal_lm.py b/server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py similarity index 100% rename from server/text_generation_server/models/flashinfer_causal_lm.py rename to server/text_generation_server/models_flashinfer/flashinfer_causal_lm.py diff --git a/server/text_generation_server/models/flashinfer_gemma.py b/server/text_generation_server/models_flashinfer/flashinfer_gemma.py similarity index 91% rename from server/text_generation_server/models/flashinfer_gemma.py rename to server/text_generation_server/models_flashinfer/flashinfer_gemma.py index 8278cf87..2ff05ce9 100644 --- a/server/text_generation_server/models/flashinfer_gemma.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_gemma.py @@ -1,8 +1,8 @@ import torch import torch.distributed from typing import Optional, List -from text_generation_server.models.flashinfer_causal_lm import FlashinferLM -from text_generation_server.models.custom_modeling.flashinfer_gemma_modeling import ( +from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM +from text_generation_server.models_flashinfer.custom_modeling.flashinfer_gemma_modeling import ( GemmaTokenizerFast, GemmaConfig, FlashGemmaForCausalLM, diff --git a/server/text_generation_server/models/flashinfer_llama.py 
b/server/text_generation_server/models_flashinfer/flashinfer_llama.py similarity index 93% rename from server/text_generation_server/models/flashinfer_llama.py rename to server/text_generation_server/models_flashinfer/flashinfer_llama.py index 0ada574e..19cf3b1e 100644 --- a/server/text_generation_server/models/flashinfer_llama.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_llama.py @@ -5,8 +5,8 @@ from transformers.models.llama import LlamaTokenizer from typing import Optional, List -from text_generation_server.models.flashinfer_causal_lm import FlashinferLM -from text_generation_server.models.custom_modeling.flashinfer_llama_modeling import ( +from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM +from text_generation_server.models_flashinfer.custom_modeling.flashinfer_llama_modeling import ( FlashLlamaForCausalLM, ) from text_generation_server.utils import ( diff --git a/server/text_generation_server/models/flashinfer_mistral.py b/server/text_generation_server/models_flashinfer/flashinfer_mistral.py similarity index 90% rename from server/text_generation_server/models/flashinfer_mistral.py rename to server/text_generation_server/models_flashinfer/flashinfer_mistral.py index 94b366de..12b64d14 100644 --- a/server/text_generation_server/models/flashinfer_mistral.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_mistral.py @@ -1,8 +1,8 @@ import torch import torch.distributed from typing import Optional, List -from text_generation_server.models.flashinfer_causal_lm import FlashinferLM -from text_generation_server.models.custom_modeling.flashinfer_mistral_modeling import ( +from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM +from text_generation_server.models_flashinfer.custom_modeling.flashinfer_mistral_modeling import ( MistralConfig, FlashMistralForCausalLM, ) diff --git a/server/text_generation_server/models/flashinfer_phi.py 
b/server/text_generation_server/models_flashinfer/flashinfer_phi.py similarity index 94% rename from server/text_generation_server/models/flashinfer_phi.py rename to server/text_generation_server/models_flashinfer/flashinfer_phi.py index 3255cc05..69da0573 100644 --- a/server/text_generation_server/models/flashinfer_phi.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_phi.py @@ -1,8 +1,8 @@ import torch import torch.distributed from typing import Optional, List -from text_generation_server.models.flashinfer_causal_lm import FlashinferLM -from text_generation_server.models.custom_modeling.flashinfer_phi_modeling import ( +from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM +from text_generation_server.models_flashinfer.custom_modeling.flashinfer_phi_modeling import ( FlashPhiForCausalLM, PhiConfig, ) diff --git a/server/text_generation_server/models/flashinfer_qwen2.py b/server/text_generation_server/models_flashinfer/flashinfer_qwen2.py similarity index 91% rename from server/text_generation_server/models/flashinfer_qwen2.py rename to server/text_generation_server/models_flashinfer/flashinfer_qwen2.py index bf51cff1..69c83271 100644 --- a/server/text_generation_server/models/flashinfer_qwen2.py +++ b/server/text_generation_server/models_flashinfer/flashinfer_qwen2.py @@ -2,8 +2,8 @@ import torch.distributed from typing import Optional, List -from text_generation_server.models.flashinfer_causal_lm import FlashinferLM -from text_generation_server.models.custom_modeling.flashinfer_qwen2_modeling import ( +from text_generation_server.models_flashinfer.flashinfer_causal_lm import FlashinferLM +from text_generation_server.models_flashinfer.custom_modeling.flashinfer_qwen2_modeling import ( Qwen2Config, FlashQwen2ForCausalLM, ) diff --git a/server/text_generation_server/models_flashinfer/llava_causal_lm.py b/server/text_generation_server/models_flashinfer/llava_causal_lm.py new file mode 100644 index 00000000..9ffc2027 
--- /dev/null
+++ b/server/text_generation_server/models_flashinfer/llava_causal_lm.py
@@ -0,0 +1,418 @@
+# Modified from https://github.com/punica-ai/punica/blob/master/src/punica/models/llama_lora.py
+# Editor: Junyi Shen
+
+import torch
+from text_generation_server.models_flashinfer.custom_modeling.embedding_llama import (
+    LlamaForCausalLM,
+)
+
+import time
+import json
+from opentelemetry import trace
+from typing import Optional, Tuple, List, Type, Dict
+from text_generation_server.models import Model
+from text_generation_server.utils.tokens import batch_top_tokens
+from text_generation_server.models.types import (
+    Batch,
+    Tokens,
+    Generation,
+    GeneratedText,
+)
+from text_generation_server.utils import Sampling
+from dataclasses import dataclass
+from transformers import AutoTokenizer, AutoConfig, LlavaConfig
+from huggingface_hub import hf_hub_download
+
+from loguru import logger
+from PIL import Image
+from io import BytesIO
+import base64
+
+tracer = trace.get_tracer(__name__)
+
+from text_generation_server.models.causal_lm import CausalLMBatch
+
+
+@dataclass
+class LlavaBatch(CausalLMBatch):
+    """Causal LM batch used by LlavaLM."""
+
+    # NOTE(review): no annotation, so this is a plain class attribute (not a
+    # dataclass field) and the same list object is shared by every instance.
+    imgs = []
+
+
+class LlavaLM(Model):
+    """Llava model: a Llama LM fed with projected vision-tower image features.
+
+    NOTE(review): KvPool, KvCache, BatchedKvCache and BatchLenInfo are used
+    below but are not imported anywhere in this file, so __init__ raises
+    NameError as written -- the import for them needs to be added.
+    """
+
+    def __init__(
+        self,
+        model_id: str = None,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+        use_medusa: Optional[str] = None,
+        dtype: Optional[torch.dtype] = None,
+        trust_remote_code: bool = False,
+    ):
+        """Load tokenizer, language model, vision tower and projector.
+
+        Raises:
+            RuntimeError: if ``use_medusa`` is set.
+            ValueError: if ``quantize`` is set and CUDA is unavailable.
+        """
+        if use_medusa:
+            raise RuntimeError("Medusa decoding is not enabled for ThisModel")
+
+        # Default to fp16 on GPU and fp32 on CPU unless a dtype was given.
+        if torch.cuda.is_available():
+            device = torch.device("cuda")
+            dtype = torch.float16 if dtype is None else dtype
+        else:
+            if quantize:
+                raise ValueError("quantization is not available on CPU")
+
+            device = torch.device("cpu")
+            dtype = torch.float32 if dtype is None else dtype
+
+        self.device = device
+        self.model_id = model_id
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+            use_fast=True,
+        )
+        model = LlamaForCausalLM.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            low_cpu_mem_usage=True,
+            device_map=device,
+            # device_map="auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None,
+            load_in_8bit=quantize == "bitsandbytes",
+            trust_remote_code=trust_remote_code,
+        )
+        if (
+            torch.cuda.is_available()
+            and torch.cuda.device_count() == 1
+            and quantize != "bitsandbytes"
+        ):
+            model = model.cuda()
+        self.model = model
+
+        # Pick a pad token: model config pad, then model config eos, then
+        # tokenizer eos, and only as a last resort add a new "[PAD]" token.
+        if tokenizer.pad_token_id is None:
+            if model.config.pad_token_id is not None:
+                tokenizer.pad_token_id = model.config.pad_token_id
+            elif model.config.eos_token_id is not None:
+                tokenizer.pad_token_id = model.config.eos_token_id
+            elif tokenizer.eos_token_id is not None:
+                tokenizer.pad_token_id = tokenizer.eos_token_id
+            else:
+                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+        self.model_config = AutoConfig.from_pretrained(model_id)
+        self.kvpool = KvPool(
+            num_layers=self.model_config.num_hidden_layers,
+            num_heads=self.model_config.num_attention_heads,
+            head_dim=self.model_config.hidden_size
+            // self.model_config.num_attention_heads,
+            page_len=16,
+            dtype=dtype,
+            device=device,
+        )
+        # Per-request KV caches, keyed by str(request.id).
+        self.cache_pool = {}
+
+        with open(hf_hub_download(self.model_id, filename="config.json")) as f:
+            mm_config = json.loads(f.read())
+
+        self.vision_model = self.build_vision_model(mm_config)
+        self.vision_model.to(self.device).eval()
+        self.projector = self.build_projector(mm_config)
+        self.projector.to(self.device).eval()
+        self.id_embedder = self.model.model.embed_tokens
+        # Extra sequence length added to every request at prefill time.
+        # Presumably the number of image tokens produced by the vision tower
+        # (24 x 24 patches?) -- TODO confirm.
+        self.additional_init_length = 576  # 512 + 64 I guess
+
+        super(LlavaLM, self).__init__(
+            model=self.model,
+            tokenizer=tokenizer,
+            requires_padding=True,
+            dtype=dtype,
+            device=device,
+        )
+        logger.info(f"Initialized LlavaLM with model_id: {model_id}")
+
+    def build_vision_model(self, model_config, **kwargs):
+        """Build the CLIP vision tower (fixed to clip-vit-large-patch14-336)."""
+        from .llava_models.encoder.encoder import CLIPVisionTower
+
+        mm_vision_tower = "openai/clip-vit-large-patch14-336"
+        return CLIPVisionTower(mm_vision_tower, args=model_config, **kwargs)
+
+    def build_projector(self, model_config, **kwargs):
+        """Build the vision projector and load its weights from the hub.
+
+        Weights come from ``mm_projector.bin`` in the model repository; the
+        four ``model.mm_projector.*`` tensors are renamed to the projector's
+        own ``0.*`` / ``2.*`` parameter names.
+        """
+        from .llava_models.projector.builder import build_vision_projector
+
+        projector = build_vision_projector(model_config, **kwargs)
+        model_path = hf_hub_download(self.model_id, filename="mm_projector.bin")
+        # Load onto CPU so a checkpoint saved from a CUDA device still loads on
+        # hosts without that device; the caller moves the projector afterwards.
+        # NOTE(review): torch.load unpickles the downloaded file -- only use
+        # with trusted repositories (consider weights_only=True).
+        state_dict = torch.load(model_path, map_location="cpu")
+        new_state_dict = {
+            "0.weight": state_dict["model.mm_projector.0.weight"],
+            "0.bias": state_dict["model.mm_projector.0.bias"],
+            "2.weight": state_dict["model.mm_projector.2.weight"],
+            "2.bias": state_dict["model.mm_projector.2.bias"],
+        }
+        projector.load_state_dict(new_state_dict)
+        return projector
+
+    @property
+    def batch_type(self) -> Type[LlavaBatch]:
+        return LlavaBatch
+
+    def decode(self, generated_ids: List[int]) -> str:
+        """Decode token ids to text, skipping special tokens."""
+        return self.tokenizer.decode(
+            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+        )
+
+    @torch.no_grad()
+    def prefill_token(self, batch: LlavaBatch):
+        """Run the prefill forward pass and return the last-position logits.
+
+        Image features are concatenated in front of the token embeddings, and a
+        KV cache is allocated and stored in ``self.cache_pool`` per request.
+        """
+        img_features = []
+        for r in batch.requests:
+            img = Image.open(r.inputb).convert("RGB")
+            img = self.vision_model.image_processor(img, return_tensors="pt")[
+                "pixel_values"
+            ].squeeze(0)
+            img_features.append(self.vision_model(img))
+        img_features = torch.stack(img_features, dim=0)
+        if self.projector:
+            img_features = self.projector(img_features)
+
+        input_ids = torch.tensor(batch.input_ids, dtype=torch.long, device=self.device)
+        input_embeddings = self.id_embedder(input_ids).unsqueeze(0)
+        input_embeddings = torch.cat([img_features, input_embeddings], dim=1)
+        lens = batch.input_lengths + self.additional_init_length
+        blen = BatchLenInfo(lens, 0, self.device)
+
+        for r, l in zip(batch.requests, lens):
+            kv_cache = KvCache(self.kvpool, l)
+            self.cache_pool[str(r.id)] = kv_cache
+
+        prefill_kv = BatchedKvCache(
+            [self.cache_pool[str(r.id)] for r in batch.requests]
+        )
+
+        logits, _ = self.model(
+            input_ids=None,
+            blen=blen,
+            prefill_kv=prefill_kv,
+            decode_kv=None,
+            input_embeddings=input_embeddings,
+        )
+
+        # Keep only the logits at the last position of each sequence.
+        logits = logits[blen.indptr[1:] - 1]
+        logits = logits.unsqueeze(1)
+        return logits
+
+    @torch.no_grad()
+    def generate_token(self, batch: LlavaBatch):
+        """Run one decode forward pass using the cached KV of each request."""
+        input_ids, decode_kv = [], []
+
+        for i, (request, ids) in enumerate(zip(batch.requests, batch.input_ids)):
+            input_ids.append(ids)
+            kv_cache = self.cache_pool[str(request.id)]
+            decode_kv.append(kv_cache)
+            kv_cache.acquire_one()
+
+        blen = BatchLenInfo([], len(input_ids), self.device)
+        decode_kv = BatchedKvCache(decode_kv) if decode_kv else None
+
+        # Forward pass
+        logits, _ = self.model(input_ids, blen, None, decode_kv, None)
+        logits = logits.unsqueeze(1)
+        return logits
+
+    def generate(
+        self, batch: LlavaBatch
+    ) -> Tuple[List[Generation], Optional[LlavaBatch], Tuple[int, int]]:
+        """Produce one token per request and update the batch in place.
+
+        Prefill is used when the first request has generated no tokens yet,
+        otherwise a decode step is run.
+
+        Returns:
+            ``(generations, next_batch, (forward_ns, decode_ns))`` where
+            ``next_batch`` is ``None`` once every request has stopped.
+        """
+        start = time.time_ns()
+        logits = (
+            self.prefill_token(batch)
+            if batch.stopping_criterias[0].current_tokens == 0
+            else self.generate_token(batch)
+        )
+        generations: List[Generation] = []
+        stopped = True
+
+        # Speculation is not active for causal
+        accepted_ids = torch.ones_like(batch.input_ids)[:, 0]
+        batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens(
+            batch.top_n_tokens,
+            batch.top_n_tokens_tensor,
+            torch.log_softmax(logits[:, -1, :], -1),
+            accepted_ids,
+        )
+
+        start_decode = time.time_ns()
+        iterator = zip(
+            batch.requests,
+            batch.input_lengths,
+            batch.prefix_offsets,
+            batch.read_offsets,
+            logits,
+            batch.next_token_choosers,
+            batch.stopping_criterias,
+            batch.all_input_ids,
+            batch.top_n_tokens,
+            batch_top_token_ids,
+            batch_top_token_logprobs,
+        )
+
+        for i, (
+            request,
+            input_length,
+            prefix_offset,
+            read_offset,
+            logits,
+            next_token_chooser,
+            stopping_criteria,
+            all_input_ids,
+            top_n_tokens,
+            top_token_ids,
+            top_token_logprobs,
+        ) in enumerate(iterator):
+            # Select next token
+            next_token_id, logprobs = next_token_chooser(
+                all_input_ids.view(1, -1), logits[-1:, :]
+            )
+            # Append next token to all tokens
+            all_input_ids = torch.cat([all_input_ids, next_token_id])
+            new_input_length = input_length + 1
+            # Generated token
+            next_token_logprob = logprobs[-1, next_token_id]
+            next_token_id_squeezed = next_token_id.squeeze()
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_input_ids[:, 0], prefix_offset, read_offset
+            )
+            # Evaluate stopping criteria
+            stop, reason = stopping_criteria(
+                next_token_id_squeezed,
+                next_token_text,
+            )
+            if not stop:
+                stopped = False
+            # Shard generations
+            # All generations will be appended in the rust sharded client
+            if i % self.world_size == self.rank:
+                if stop:
+                    # Decode generated tokens
+                    output_text, _, _ = self.decode_token(
+                        all_input_ids[:, 0],
+                        prefix_offset=len(all_input_ids)
+                        - stopping_criteria.current_tokens
+                        - 1,
+                        read_offset=len(all_input_ids)
+                        - stopping_criteria.current_tokens,
+                        skip_special_tokens=True,
+                    )
+                    # Get seed
+                    if isinstance(next_token_chooser.choice, Sampling):
+                        seed = next_token_chooser.choice.seed
+                    else:
+                        seed = None
+
+                    generated_text = GeneratedText(
+                        output_text, stopping_criteria.current_tokens, reason, seed
+                    )
+
+                    # release kv-cache
+                    self.cache_pool[str(request.id)].release()
+                    del self.cache_pool[str(request.id)]
+
+                else:
+                    generated_text = None
+
+                # Prefill
+                if stopping_criteria.current_tokens == 1 and request.prefill_logprobs:
+                    # Remove generated token to only have prefill and add nan for first prompt token
+                    prefill_logprobs = [float("nan")] + torch.log_softmax(
+                        logits, -1
+                    ).gather(1, all_input_ids[1:]).squeeze(1)[
+                        -new_input_length:-1
+                    ].tolist()
+                    prefill_token_ids = all_input_ids[-new_input_length:-1]
+                    prefill_texts = self.tokenizer.batch_decode(
+                        prefill_token_ids,
+                        clean_up_tokenization_spaces=False,
+                        skip_special_tokens=False,
+                    )
+                    prefill_tokens = Tokens(
+                        prefill_token_ids,
+                        prefill_logprobs,
+                        prefill_texts,
+                        is_special=[],
+                    )
+                else:
+                    prefill_tokens = None
+
+                if top_n_tokens > 0:
+                    all_top_tokens = []
+                    for top_token_ids, top_token_logprobs in zip(
+                        top_token_ids, top_token_logprobs
+                    ):
+                        toptoken_texts = self.tokenizer.batch_decode(
+                            top_token_ids,
+                            clean_up_tokenization_spaces=False,
+                            skip_special_tokens=False,
+                        )
+                        special_toptokens = [
+                            token_id in self.all_special_ids
+                            for token_id in top_token_ids
+                        ]
+                        top_tokens = Tokens(
+                            top_token_ids,
+                            top_token_logprobs,
+                            toptoken_texts,
+                            special_toptokens,
+                        )
+                        all_top_tokens.append(top_tokens)
+                    top_tokens = all_top_tokens
+                else:
+                    top_tokens = None
+
+                generation = Generation(
+                    request.id,
+                    prefill_tokens,
+                    Tokens(
+                        [next_token_id_squeezed],
+                        [next_token_logprob],
+                        [next_token_text],
+                        [next_token_id_squeezed.item() in self.all_special_ids],
+                    ),
+                    generated_text,
+                    top_tokens,
+                )
+
+                generations.append(generation)
+
+            # Update values
+            batch.next_token_choosers[i] = batch.next_token_choosers[i].advance_grammar(
+                next_token_id_squeezed.item()
+            )
+            batch.input_ids[i, 0] = next_token_id
+            batch.all_input_ids[i] = all_input_ids
+            batch.input_lengths[i] = new_input_length
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
+            batch.max_input_length = max(batch.max_input_length, new_input_length)
+
+        # We finished all generations in the batch; there is no next batch
+        if stopped:
+            forward_ns = start_decode - start
+            decode_ns = time.time_ns() - start_decode
+            return generations, None, (forward_ns, decode_ns)
+
+        # Slice unused values from prefill
+        batch.input_ids = batch.input_ids[:, :1]
+
+        # Update attention_mask as we added a new token to input_ids
+        batch.attention_mask[:, -batch.padding_right_offset] = 1
+        # Decrease right offset
+        batch.padding_right_offset -= 1
+
+        # Update position_ids
+        batch.position_ids = batch.position_ids[:, -1:] + 1
+
+        forward_ns = start_decode - start
+        decode_ns = time.time_ns() - start_decode
+        return generations, batch, (forward_ns, decode_ns)
+
+
+if __name__ == "__main__":
+    model = LlavaLM(model_id="liuhaotian/llava-v1.5-7b")
+    print(model)
diff --git a/server/text_generation_server/models/llava_models/encoder/builder.py b/server/text_generation_server/models_flashinfer/llava_models/encoder/builder.py
similarity index 100%
rename from server/text_generation_server/models/llava_models/encoder/builder.py
rename to server/text_generation_server/models_flashinfer/llava_models/encoder/builder.py
diff --git a/server/text_generation_server/models/llava_models/encoder/encoder.py b/server/text_generation_server/models_flashinfer/llava_models/encoder/encoder.py
similarity index 100%
rename from server/text_generation_server/models/llava_models/encoder/encoder.py
rename to server/text_generation_server/models_flashinfer/llava_models/encoder/encoder.py
diff --git a/server/text_generation_server/models/llava_models/projector/__init__.py b/server/text_generation_server/models_flashinfer/llava_models/projector/__init__.py
similarity index 100%
rename from server/text_generation_server/models/llava_models/projector/__init__.py
rename to server/text_generation_server/models_flashinfer/llava_models/projector/__init__.py
diff --git a/server/text_generation_server/models/llava_models/projector/builder.py b/server/text_generation_server/models_flashinfer/llava_models/projector/builder.py
similarity index 100%
rename from server/text_generation_server/models/llava_models/projector/builder.py
rename to server/text_generation_server/models_flashinfer/llava_models/projector/builder.py
diff --git a/server/text_generation_server/server_flashinfer.py b/server/text_generation_server/server_flashinfer.py
new file mode 100644
index 00000000..95225204
--- /dev/null
+++ b/server/text_generation_server/server_flashinfer.py
@@ -0,0 +1,214 @@
+import asyncio
+import os
+import torch
+import time
+import signal
+
+from grpc import aio
+from loguru import logger
+
+from grpc_reflection.v1alpha import reflection
+from pathlib import Path
+from typing import List, Optional
+
+from text_generation_server.cache import Cache
+from
text_generation_server.interceptor import ExceptionInterceptor
+from text_generation_server.models_flashinfer import Model, get_model
+
+from text_generation_server.pb import generate_pb2_grpc, generate_pb2
+from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
+
+
+class SignalHandler:
+    """Record that SIGINT or SIGTERM was received so the serve loop can exit."""
+
+    # Polled by the serve loop; set to False on the instance by the handler.
+    KEEP_PROCESSING = True
+
+    def __init__(self):
+        # Route both termination signals to the same handler.
+        signal.signal(signal.SIGINT, self.exit_gracefully)
+        signal.signal(signal.SIGTERM, self.exit_gracefully)
+
+    def exit_gracefully(self, signum, frame):
+        """Signal handler: print the signal number and request loop exit."""
+        print(f"Exiting gracefully: Signal {signum}")
+        self.KEEP_PROCESSING = False
+
+
+class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
+    """gRPC servicer exposing a single model through the TextGeneration API."""
+
+    def __init__(
+        self,
+        model: Model,
+        cache: Cache,
+        quantize: Optional[str],
+        server_urls: List[str],
+    ):
+        self.cache = cache
+        self.model = model
+        self.quantize = quantize
+        self.server_urls = server_urls
+        # For some reason, inference_mode does not work well with GLOO which we use on CPU
+        if model.device.type == "cuda":
+            # Force inference mode for the lifetime of TextGenerationService
+            self._inference_mode_raii_guard = torch._C._InferenceMode(True)
+
+    async def Info(self, request, context):
+        """Return the model's info message."""
+        return self.model.info
+
+    async def Health(self, request, context):
+        """Health check; on CUDA also allocates a tiny tensor on the device."""
+        if self.model.device.type == "cuda":
+            torch.zeros((2, 2)).cuda()
+        return generate_pb2.HealthResponse()
+
+    async def ServiceDiscovery(self, request, context):
+        """Return the URLs of all shard servers."""
+        return generate_pb2.ServiceDiscoveryResponse(urls=self.server_urls)
+
+    async def ClearCache(self, request, context):
+        """Delete one cached batch when ``id`` is set, otherwise clear them all."""
+        if request.HasField("id"):
+            self.cache.delete(request.id)
+        else:
+            self.cache.clear()
+        return generate_pb2.ClearCacheResponse()
+
+    async def FilterBatch(self, request, context):
+        """Keep only ``request_ids`` in a cached batch and re-cache the result.
+
+        Raises:
+            ValueError: if the batch id is not in the cache.
+        """
+        batch = self.cache.pop(request.batch_id)
+        if batch is None:
+            raise ValueError(f"Batch ID {request.batch_id} not found in cache.")
+        filtered_batch = batch.filter(request.request_ids)
+        self.cache.set(filtered_batch)
+
+        return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
+
+    async def Warmup(self, request, context):
+        """Run the model's warmup on the given batch and report its token budget."""
+
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+        )
+        max_supported_total_tokens = self.model.warmup(batch)
+
+        return generate_pb2.WarmupResponse(
+            max_supported_total_tokens=max_supported_total_tokens
+        )
+
+    async def Prefill(self, request, context):
+        """Build a batch from the request, generate one step and cache the result."""
+        start = time.time_ns()
+        batch = self.model.batch_type.from_pb(
+            request.batch, self.model.tokenizer, self.model.dtype, self.model.device
+        )
+
+        generations, next_batch, timings = self.model.generate_token(batch)
+        self.cache.set(next_batch)
+        return generate_pb2.PrefillResponse(
+            generations=[generation.to_pb() for generation in generations],
+            batch=next_batch.to_pb() if next_batch else None,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
+        )
+
+    async def Decode(self, request, context):
+        """Pop the requested cached batches, merge them and generate one step.
+
+        Raises:
+            ValueError: if no batch is given or a batch id is not in the cache.
+        """
+        start = time.time_ns()
+        if len(request.batches) == 0:
+            raise ValueError("Must provide at least one batch")
+
+        batches = []
+        for batch_pb in request.batches:
+            batch = self.cache.pop(batch_pb.id)
+            if batch is None:
+                raise ValueError(f"Batch ID {batch_pb.id} not found in cache.")
+            batches.append(batch)
+
+        if len(batches) == 0:
+            raise ValueError("All batches are empty")
+
+        # Concatenation is only needed (and timed) when several batches are merged.
+        if len(batches) > 1:
+            start_concat = time.time_ns()
+            batch = self.model.batch_type.concatenate(batches)
+            concat_ns = time.time_ns() - start_concat
+        else:
+            batch = batches[0]
+            concat_ns = None
+
+        generations, next_batch, timings = self.model.generate_token(batch)
+        self.cache.set(next_batch)
+
+        return generate_pb2.DecodeResponse(
+            generations=[generation.to_pb() for generation in generations],
+            batch=next_batch.to_pb() if next_batch else None,
+            concat_ns=concat_ns,
+            forward_ns=timings[0],
+            decode_ns=timings[1],
+            total_ns=time.time_ns() - start,
+        )
+
+
+def serve(
+    model_id: str,
+    revision: Optional[str],
+    sharded: bool,
+    quantize: Optional[str],
+    speculate: Optional[int],
+    dtype: Optional[str],
+    trust_remote_code: bool,
+    uds_path: Path,
+    lora_ids: Optional[str],
+):
+    """Load the FlashInfer model and serve it over gRPC on a unix socket.
+
+    Blocks until SIGINT or SIGTERM is received. ``speculate`` is accepted for
+    signature parity with the caller but is not used by the FlashInfer model
+    loader, which has no such parameter.
+    """
+
+    async def serve_inner(
+        model_id: str,
+        revision: Optional[str],
+        sharded: bool = False,
+        quantize: Optional[str] = None,
+        speculate: Optional[int] = None,
+        dtype: Optional[str] = None,
+        trust_remote_code: bool = False,
+    ):
+        unix_socket_template = "unix://{}-{}"
+        if sharded:
+            # One socket per rank; this process serves the one matching RANK.
+            server_urls = [
+                unix_socket_template.format(uds_path, rank)
+                for rank in range(int(os.environ["WORLD_SIZE"]))
+            ]
+            local_url = server_urls[int(os.environ["RANK"])]
+        else:
+            local_url = unix_socket_template.format(uds_path, 0)
+            server_urls = [local_url]
+
+        try:
+            # models_flashinfer.get_model takes no `speculate` argument; passing
+            # it positionally would raise TypeError, so pass by keyword.
+            model = get_model(
+                model_id=model_id,
+                revision=revision,
+                sharded=sharded,
+                quantize=quantize,
+                dtype=dtype,
+                trust_remote_code=trust_remote_code,
+                lora_ids=lora_ids,
+            )
+        except Exception:
+            logger.exception("Error when initializing model")
+            raise
+
+        server = aio.server(
+            interceptors=[
+                ExceptionInterceptor(),
+                UDSOpenTelemetryAioServerInterceptor(),
+            ]
+        )
+        generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
+            TextGenerationService(model, Cache(), quantize, server_urls), server
+        )
+        SERVICE_NAMES = (
+            generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
+            reflection.SERVICE_NAME,
+        )
+        reflection.enable_server_reflection(SERVICE_NAMES, server)
+        server.add_insecure_port(local_url)
+
+        await server.start()
+
+        logger.info("Server started at {}".format(local_url))
+        # Poll until a termination signal flips the handler's flag.
+        signal_handler = SignalHandler()
+        while signal_handler.KEEP_PROCESSING:
+            await asyncio.sleep(0.5)
+
+    asyncio.run(
+        serve_inner(
+            model_id, revision, sharded, quantize, speculate, dtype, trust_remote_code
+        )
+    )