diff --git a/server/text_generation_server/models_flashinfer/__init__.py b/server/text_generation_server/models_flashinfer/__init__.py
index 8114caab..00b92162 100644
--- a/server/text_generation_server/models_flashinfer/__init__.py
+++ b/server/text_generation_server/models_flashinfer/__init__.py
@@ -13,6 +13,9 @@
 )
 from text_generation_server.models_flashinfer.flashinfer_phi import FlashinferPhi
 from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2
+from text_generation_server.models_flashinfer.flashinfer_chatglm import (
+    FlashinferChatGLM,
+)
 
 # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
 # in PyTorch 1.12 and later.
@@ -77,6 +80,11 @@ class ModelType(enum.Enum):
         "name": "Qwen 2",
         "url": "https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1",
     }
+    CHATGLM = {
+        "type": "chatglm",
+        "name": "Chatglm",
+        "url": "https://huggingface.co/THUDM/glm-4-9b-chat",
+    }
 
 
 __GLOBALS = locals()
@@ -170,5 +178,13 @@ def get_model(
             dtype=dtype,
             trust_remote_code=trust_remote_code,
         )
+    elif model_type == CHATGLM:
+        return FlashinferChatGLM(
+            model_id,
+            lora_ids.split(";") if lora_ids else None,
+            quantize=quantize,
+            dtype=torch.float16,
+            trust_remote_code=trust_remote_code,
+        )
 
     raise ValueError(f"Unsupported model type {model_type}")
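
For context, here is a minimal, self-contained sketch of the registration pattern this diff extends: a `ModelType` member carries a `"type"` string, and `get_model` dispatches on it to construct the matching model class. The `FlashinferChatGLM` and `get_model` bodies below are placeholders that only mirror the call added in the diff; the real signatures live in `flashinfer_chatglm.py` and the rest of `__init__.py`, which are not part of this patch. The lora adapter names in the usage line are purely illustrative.

```python
# Hypothetical sketch of the enum + dispatch pattern used by models_flashinfer/__init__.py.
# Bodies are placeholders; only the argument shapes mirror the diff.
import enum


class FlashinferChatGLM:
    # Stand-in for models_flashinfer.flashinfer_chatglm.FlashinferChatGLM;
    # the parameter list mirrors the call added in the new elif branch.
    def __init__(self, model_id, lora_ids=None, quantize=None, dtype="float16",
                 trust_remote_code=False):
        self.model_id = model_id
        self.lora_ids = lora_ids
        self.quantize = quantize
        self.dtype = dtype
        self.trust_remote_code = trust_remote_code


class ModelType(enum.Enum):
    CHATGLM = {
        "type": "chatglm",
        "name": "Chatglm",
        "url": "https://huggingface.co/THUDM/glm-4-9b-chat",
    }


def get_model(model_id, model_type, lora_ids=None, quantize=None,
              trust_remote_code=False):
    # Dispatch on the "type" string, as the new elif branch does in __init__.py.
    if model_type == ModelType.CHATGLM.value["type"]:
        return FlashinferChatGLM(
            model_id,
            lora_ids.split(";") if lora_ids else None,
            quantize=quantize,
            dtype="float16",  # the diff pins ChatGLM to torch.float16
            trust_remote_code=trust_remote_code,
        )
    raise ValueError(f"Unsupported model type {model_type}")


model = get_model("THUDM/glm-4-9b-chat", "chatglm", lora_ids="adapter_a;adapter_b")
```

Note that, unlike the other branches in `get_model`, the new CHATGLM branch pins `dtype=torch.float16` instead of passing the caller's `dtype` through; the sketch reproduces that choice.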