Add ChatGLM and refactor Qwen2
NovTi authored and tjluyao committed Jul 8, 2024
1 parent 89492fb commit 7ea9184
Showing 7 changed files with 649 additions and 241 deletions.
45 changes: 38 additions & 7 deletions server/examples/test_local_api.py
@@ -2,6 +2,8 @@
import torch
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma
from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2
from text_generation_server.models_flashinfer.flashinfer_chatglm import FlashinferChatGLM
import sys

try:
@@ -27,11 +29,13 @@
# test = "gemma"
# test = "llama-3"
# test = 'llama-3-70'
test = "llama-2"
# test = "llama-2"
# test = 'mistral'
# test = 'qwen2'
# test = 'qwen2-1.8'
# test = 'qwen2-70'
# test = 'qwen1.5-7'
# test = 'qwen1.5-1.8'
# test = 'qwen1.5-70'
# test = 'qwen2-7'
test = 'chatglm4'
print("Testing " + test)

# Load demo inputs
@@ -161,7 +165,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
),
]
service = FlashinferMistral(model_id="mistralai/Mistral-7B-v0.3")
elif test == "qwen2":
elif test == "qwen1.5-7":
requests = [
make_input(
"REILX/Qwen1.5-7B-Chat-750Mb-lora",
@@ -180,7 +184,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
service = FlashinferQwen2(
model_id="Qwen/Qwen1.5-7B-Chat", lora_ids=["REILX/Qwen1.5-7B-Chat-750Mb-lora"]
)
elif test == "qwen2-1.8":
elif test == "qwen1.5-1.8":
# Todo: Add qwen1.5 1.8b chat lora adapter / Output Repetition Problem
requests = [
make_input(
@@ -194,7 +198,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
service = FlashinferQwen2(
model_id="Qwen/Qwen1.5-1.8B-Chat", lora_ids=["REILX/Qwen1.5-7B-Chat-750Mb-lora"]
)
elif test == "qwen2-70":
elif test == "qwen1.5-70":
# Todo: Add qwen1.5 72b chat lora adapter
requests = [
make_input(
@@ -260,6 +264,33 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
service = FlashinferLlama(
model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True
)
elif test == "qwen2-7":
# Todo: qwen2-7b instruct lora adapter
requests = [
make_input(
"abcdabcd987/gsm8k-llama2-7b-lora-16",
"base",
id=0,
promptOverride="给我讲个故事",
),
]
service = FlashinferQwen2(
model_id="Qwen/Qwen2-7B-Instruct", trust_remote_code=True
)

elif test == "chatglm4":
# Todo: chatglm4-9b lora adapter
requests = [
make_input(
"abcdabcd987/gsm8k-llama2-7b-lora-16",
"base",
id=0,
promptOverride="给我讲个故事",
),
]
service = FlashinferChatGLM(
model_id="THUDM/glm-4-9b-chat", trust_remote_code=True
)

print(service.get_lora_adapters())
tokenizer = service.tokenizer
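The two new branches (qwen2-7 and chatglm4) reuse a Llama LoRA id purely as a placeholder and run against the base model, as their Todo comments note. Below is a minimal sketch of what the qwen2-7 branch might look like once a real Qwen2-7B-Instruct adapter is published, following the pattern of the existing qwen1.5-7 branch; the adapter id is hypothetical, not an existing repository.

# Hypothetical sketch, not part of this commit: swap in a real adapter once one exists.
requests = [
    make_input(
        "some-org/qwen2-7b-instruct-lora",  # placeholder adapter id (assumption)
        "lora",
        id=0,
        promptOverride="给我讲个故事",  # "Tell me a story"
    ),
]
service = FlashinferQwen2(
    model_id="Qwen/Qwen2-7B-Instruct",
    lora_ids=["some-org/qwen2-7b-instruct-lora"],  # placeholder (assumption)
    trust_remote_code=True,
)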
14 changes: 11 additions & 3 deletions server/text_generation_server/layers/flashinfer_attention.py
@@ -42,6 +42,8 @@ def __init__(
32 * 1024 * 1024, dtype=torch.int8, device=torch.cuda.current_device()
)
self.page_size = 16

self.group_size = self.num_attention_heads // self.num_key_value_heads

def computeAttention(
self,
@@ -182,9 +184,15 @@ def _batchDecode(
decodeBatchPosition.kv_last_page_len,
)

decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer=self._workspace_buffer, kv_layout="NHD"
)
if self.group_size in [7, 16]:
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer=self._workspace_buffer, kv_layout="NHD", use_tensor_cores=True
)
else:
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer=self._workspace_buffer, kv_layout="NHD"
)

decode_wrapper.begin_forward(
decodeBatchPosition.kv_page_indptr,
decodeBatchPosition.kv_page_indices,
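The use_tensor_cores switch keys off the grouped-query-attention ratio of the newly supported models: flashinfer's tensor-core decode kernel is typically the faster choice when many query heads share one KV head. Below is a minimal sketch of the selection logic, assuming the head counts published in the Qwen2-7B-Instruct (28 query / 4 KV heads, group size 7) and GLM-4-9B-Chat (32 query / 2 KV heads, group size 16) configs; verify these against each model's config.json before relying on them.

# Sketch, not part of this commit: the same heuristic written as a single constructor call.
# Head counts below are assumptions taken from the public model configs.
import torch
import flashinfer

MODEL_HEADS = {
    "Qwen/Qwen2-7B-Instruct": (28, 4),  # group size 28 // 4 = 7
    "THUDM/glm-4-9b-chat": (32, 2),     # group size 32 // 2 = 16
}

workspace_buffer = torch.empty(
    32 * 1024 * 1024, dtype=torch.int8, device=torch.cuda.current_device()
)

for model_id, (num_attention_heads, num_key_value_heads) in MODEL_HEADS.items():
    group_size = num_attention_heads // num_key_value_heads
    decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout="NHD",
        use_tensor_cores=group_size in (7, 16),  # mirrors the branch added above
    )
    print(model_id, "group_size:", group_size, "tensor cores:", group_size in (7, 16))

Folding the flag into one constructor call avoids the duplicated wrapper construction in the if/else above while leaving behavior unchanged for group sizes outside {7, 16}.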
(Diffs for the remaining 5 changed files are not shown.)
