Add the batch concatenation functionality for flashinfer server (#43)
* refactor flashinfer causal lm

* modify test_local_api

* fixes

* fixes

* lint
alfredgui2 authored and tjluyao committed Jul 9, 2024
1 parent 44454b1 commit 4440030
Showing 6 changed files with 271 additions and 386 deletions.
32 changes: 15 additions & 17 deletions server/examples/test_local_api.py
@@ -3,7 +3,9 @@
from text_generation_server.models_flashinfer.flashinfer_llama import FlashinferLlama
from text_generation_server.models_flashinfer.flashinfer_gemma import FlashinferGemma
from text_generation_server.models_flashinfer.flashinfer_qwen2 import FlashinferQwen2
from text_generation_server.models_flashinfer.flashinfer_chatglm import FlashinferChatGLM
from text_generation_server.models_flashinfer.flashinfer_chatglm import (
FlashinferChatGLM,
)
import sys

try:
@@ -29,13 +31,13 @@
# test = "gemma"
# test = "llama-3"
# test = 'llama-3-70'
# test = "llama-2"
test = "gemma"
# test = 'mistral'
# test = 'qwen1.5-7'
# test = 'qwen1.5-1.8'
# test = 'qwen1.5-70'
# test = 'qwen2-7'
test = 'chatglm4'
# test = "chatglm4"
print("Testing " + test)

# Load demo inputs
@@ -274,10 +276,8 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
promptOverride="给我讲个故事",
),
]
service = FlashinferQwen2(
model_id="Qwen/Qwen2-7B-Instruct", trust_remote_code=True
)

service = FlashinferQwen2(model_id="Qwen/Qwen2-7B-Instruct", trust_remote_code=True)

elif test == "chatglm4":
# Todo: chatglm4-9b lora adapter
requests = [
@@ -288,25 +288,23 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
promptOverride="给我讲个故事",
),
]
service = FlashinferChatGLM(
model_id="THUDM/glm-4-9b-chat", trust_remote_code=True
)
service = FlashinferChatGLM(model_id="THUDM/glm-4-9b-chat", trust_remote_code=True)

print(service.get_lora_adapters())
tokenizer = service.tokenizer

batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests))
pb_batch = FlashinferBatch.from_pb(
batch, tokenizer, torch.float16, torch.device("cuda")
)

# Add input batch to model service
ids = service.add_request(pb_batch)
display_results = {}

# Iterative generation: each step generates a token for each input in the batch
isPrefill = True
while True:
generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id))
if isPrefill:
generations, next_batch, _ = service.prefill_batch(batch)
isPrefill = False
else:
generations, next_batch, _, _ = service.decode_batch([next_batch.to_pb()])

for gen in generations:
if gen.prefill_tokens:
display_results[gen.request_id] = [
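The rewritten loop above replaces the old single generate_token call with an explicit prefill/decode split. A natural way to exercise the batch concatenation this commit adds would be to prefill two batches separately and then decode them together; the sketch below is a hypothetical illustration of that flow, reusing service, make_input, and generate_pb2 from the test above, with placeholder lora ids, and assuming decode_batch concatenates every cached batch proto it receives.

# Hypothetical sketch, not part of this commit: prefill two batches
# independently, then decode them in one call so the server must
# concatenate them internally. "example-lora-id" is a placeholder.
batch_a = generate_pb2.Batch(
    id=0,
    requests=[make_input("example-lora-id", "base", id=0, promptOverride="给我讲个故事")],
    size=1,
)
batch_b = generate_pb2.Batch(
    id=1,
    requests=[make_input("example-lora-id", "base", id=1, promptOverride="给我讲个故事")],
    size=1,
)

# Each prefill returns the first generations plus a cached batch handle,
# mirroring the 3-tuple unpacking used in the loop above.
_, cached_a, _ = service.prefill_batch(batch_a)
_, cached_b, _ = service.prefill_batch(batch_b)

# Passing both cached batches to decode_batch should exercise the new
# concatenation path; the 4-tuple return shape mirrors the loop above.
generations, next_batch, _, _ = service.decode_batch(
    [cached_a.to_pb(), cached_b.to_pb()]
)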
3 changes: 3 additions & 0 deletions server/text_generation_server/cache.py
@@ -11,6 +11,9 @@ class Cache:
def __init__(self):
self.cache: Dict[int, B] = {}

def get_all_values(self):
return self.cache.values()

def pop(self, batch_id: int) -> Optional[B]:
return self.cache.pop(batch_id, None)

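The new get_all_values helper simply exposes every batch currently held in the cache. Presumably it lets the server collect all live batches before a concatenated decode step; the snippet below is an assumed usage, not the call site added by this commit, and prefilled_batch_a / prefilled_batch_b stand in for FlashinferBatch objects produced by prefill (cache.set is part of the existing Cache API).

# Assumed usage of Cache.get_all_values(), not this commit's actual call
# site: gather every live batch and decode them together.
cache = Cache()
cache.set(prefilled_batch_a)   # placeholder FlashinferBatch from a prefill
cache.set(prefilled_batch_b)

all_batches = list(cache.get_all_values())
generations, next_batch, _, _ = service.decode_batch(
    [b.to_pb() for b in all_batches]
)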
6 changes: 4 additions & 2 deletions server/text_generation_server/layers/flashinfer_attention.py
@@ -42,7 +42,7 @@ def __init__(
32 * 1024 * 1024, dtype=torch.int8, device=torch.cuda.current_device()
)
self.page_size = 16

self.group_size = self.num_attention_heads // self.num_key_value_heads

def computeAttention(
@@ -186,7 +186,9 @@ def _batchDecode(

if self.group_size in [7, 16]:
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
workspace_buffer=self._workspace_buffer, kv_layout="NHD", use_tensor_cores=True
workspace_buffer=self._workspace_buffer,
kv_layout="NHD",
use_tensor_cores=True,
)
else:
decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
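The added group_size field is the grouped-query-attention ratio (query heads per key/value head), and _batchDecode now switches to flashinfer's tensor-core decode wrapper when that ratio is 7 or 16. As a rough illustration of which models land on each path (the head counts below are assumptions taken from the models' published configs, not values from this diff):

# Illustrative arithmetic only; head counts are assumed from published
# model configs, not taken from this commit.
num_attention_heads, num_key_value_heads = 28, 4    # e.g. Qwen2-7B-Instruct
group_size = num_attention_heads // num_key_value_heads   # 7  -> use_tensor_cores=True

num_attention_heads, num_key_value_heads = 32, 2    # e.g. glm-4-9b-chat
group_size = num_attention_heads // num_key_value_heads   # 16 -> use_tensor_cores=True

num_attention_heads, num_key_value_heads = 32, 32   # e.g. Llama-2-7B (plain MHA)
group_size = num_attention_heads // num_key_value_heads   # 1  -> default decode wrapper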