
Commit fe53daf
fix causal lm test script
tjluyao committed Aug 6, 2024
1 parent dcdb457 commit fe53daf
Showing 2 changed files with 33 additions and 26 deletions.
server/examples/test_causal.py (54 changes: 30 additions & 24 deletions)
@@ -1,18 +1,21 @@
 from text_generation_server.pb import generate_pb2
 import torch
-from text_generation_server.models.flashinfer_causal_lm import (
-    FlashinferLM,
-    FlashinferBatch,
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch
 )
+from text_generation_server.models.flash_llama import (
+    FlashLlama
+)
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 import random, json
 from test_cases import DEMO, LoraSpec
 
+from collections import defaultdict
 # Load demo inputs
 lora_specs = {}
 for name, spec in DEMO.items():
     lora_prompts, base_prompts = spec.generate_prompts()
     lora_specs[name] = LoraSpec(lora_prompts, base_prompts)
+from text_generation_server.utils.speculate import get_speculate, set_speculate
 
 
 # Create input requests
@@ -49,46 +52,49 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):


 flash = False
+set_speculate(0)
 
 if flash:
-    service = FlashinferLM(
-        model_type="llama", model_id="meta-llama/Llama-2-7b-hf", lora_ids=["empty"]
-    )
+    service = FlashLlama(model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True, dtype=torch.bfloat16)
 else:
-    service = CausalLM(model_id="meta-llama/Llama-2-7b-hf")
+    service = CausalLM(model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True, dtype=torch.bfloat16)
 requests = [
-    make_input("abcdabcd987/gsm8k-llama2-7b-lora-16", "base", id=0)
-] # , promptOverride= "test")]
+    make_input(
+        "abcdabcd987/gsm8k-llama2-7b-lora-16",
+        "base",
+        id=0,
+        promptOverride="why is deep learning so popular these days?",
+    )
+]
 
 tokenizer = service.tokenizer
 batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests))
 if flash:
-    pb_batch = FlashinferBatch.from_pb(
-        batch, tokenizer, torch.float16, torch.device("cuda")
+    pb_batch = FlashCausalLMBatch.from_pb(
+        batch, tokenizer, torch.bfloat16, torch.device("cuda")
     )
     ids = service.add_request(pb_batch)
 else:
     pb_batch = CausalLMBatch.from_pb(
-        batch, tokenizer, torch.float16, torch.device("cuda")
+        batch, tokenizer, torch.bfloat16, torch.device("cuda")
     )
 
-display_results = {}
+display_results = defaultdict(lambda: [])
 
-# service.warmup(pb_batch)
+service.warmup(pb_batch)
 
 while True:
     if flash:
         generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id))
     else:
-        generations, _, _ = service.generate_token(pb_batch)
+        generations, next_batch, _ = service.generate_token(pb_batch)
 
     for gen in generations:
-        if gen.generated_text:
+        if gen.prefill_tokens:
             display_results[gen.request_id] = [
-                "Prompt: "
+                "Prompt:\n"
                 + tokenizer.decode(gen.prefill_tokens.token_ids)
-                + "\nAnswer: "
-                + gen.generated_text.text
+                + "\nAnswer:\n"
             ]
+        if gen.generated_text:
+            display_results[gen.request_id] += [gen.generated_text.text]
     # Stop if all input generations are done
     if all([g.generated_text for g in generations]):
         break
 
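Note on the display_results change above: replacing the plain dict with defaultdict(lambda: []) lets the loop store the decoded prompt and then append the generated text for the same request id without first checking whether the key exists. A minimal standalone sketch of that accumulation pattern (the request ids and answer strings below are invented for illustration):

from collections import defaultdict

# Missing keys start out as an empty list, so "+=" works with no prior assignment.
display_results = defaultdict(lambda: [])

# Hypothetical fragments, in the order a generation loop might emit them.
display_results[0] += ["Prompt:\nwhy is deep learning so popular these days?\nAnswer:\n"]
display_results[0] += ["Deep learning scales well with data and compute."]
display_results[1] += ["Prompt:\n(second request)\nAnswer:\n"]

for request_id, pieces in display_results.items():
    print(f"Request {request_id}:\n" + "".join(pieces))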
server/examples/test_local_api.py (5 changes: 3 additions & 2 deletions)
@@ -32,14 +32,14 @@
# test = "gemma"
# test = "llama-3"
# test = 'llama-3-70'
# test = "baichuan"
test = "baichuan"
# test = "gemma"
# test = 'mistral'
# test = 'qwen1.5-7'
# test = 'qwen1.5-1.8'
# test = 'qwen1.5-70'
# test = 'qwen2-7'
test = "yi1.5-9b"
# test = "yi1.5-9b"
# test = "chatglm4"
print("Testing " + test)

@@ -279,6 +279,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
model_id="baichuan-inc/Baichuan2-7B-Chat",
lora_ids=["tjluyao/baichuan2-7b-chat-lora1"],
trust_remote_code=True,
dtype=torch.bfloat16,
)
elif test == "qwen2-7":
# Todo: qwen2-7b instruct lora adapter
