diff --git a/server/examples/test_causal.py b/server/examples/test_causal.py
index 9d5d7979..5f724557 100644
--- a/server/examples/test_causal.py
+++ b/server/examples/test_causal.py
@@ -1,18 +1,21 @@
 from text_generation_server.pb import generate_pb2
 import torch
-from text_generation_server.models.flashinfer_causal_lm import (
-    FlashinferLM,
-    FlashinferBatch,
+from text_generation_server.models.flash_causal_lm import (
+    FlashCausalLMBatch
+)
+from text_generation_server.models.flash_llama import (
+    FlashLlama
 )
 from text_generation_server.models.causal_lm import CausalLM, CausalLMBatch
 import random, json
 from test_cases import DEMO, LoraSpec
-
+from collections import defaultdict

 # Load demo inputs
 lora_specs = {}
 for name, spec in DEMO.items():
     lora_prompts, base_prompts = spec.generate_prompts()
     lora_specs[name] = LoraSpec(lora_prompts, base_prompts)
+from text_generation_server.utils.speculate import get_speculate, set_speculate

 # Create input requests
@@ -49,46 +52,49 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
 flash = False
+set_speculate(0)
 if flash:
-    service = FlashinferLM(
-        model_type="llama", model_id="meta-llama/Llama-2-7b-hf", lora_ids=["empty"]
-    )
+    service = FlashLlama(model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True, dtype=torch.bfloat16)
 else:
-    service = CausalLM(model_id="meta-llama/Llama-2-7b-hf")
+    service = CausalLM(model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True, dtype=torch.bfloat16)

 requests = [
-    make_input("abcdabcd987/gsm8k-llama2-7b-lora-16", "base", id=0)
-]  # , promptOverride= "test")]
+    make_input(
+        "abcdabcd987/gsm8k-llama2-7b-lora-16",
+        "base",
+        id=0,
+        promptOverride="why is deep learning so popular these days?",
+    )
+]

 tokenizer = service.tokenizer

 batch = generate_pb2.Batch(id=0, requests=requests, size=len(requests))

 if flash:
-    pb_batch = FlashinferBatch.from_pb(
-        batch, tokenizer, torch.float16, torch.device("cuda")
+    pb_batch = FlashCausalLMBatch.from_pb(
+        batch, tokenizer, torch.bfloat16, torch.device("cuda")
     )
-    ids = service.add_request(pb_batch)
 else:
     pb_batch = CausalLMBatch.from_pb(
-        batch, tokenizer, torch.float16, torch.device("cuda")
+        batch, tokenizer, torch.bfloat16, torch.device("cuda")
     )

-display_results = {}
+display_results = defaultdict(lambda: [])

-# service.warmup(pb_batch)
+service.warmup(pb_batch)

 while True:
-    if flash:
-        generations, _, _ = service.generate_token(FlashinferBatch.Empty(batch.id))
-    else:
-        generations, _, _ = service.generate_token(pb_batch)
+    generations, next_batch, _ = service.generate_token(pb_batch)
+
     for gen in generations:
-        if gen.generated_text:
+        if gen.prefill_tokens:
             display_results[gen.request_id] = [
-                "Prompt: "
+                "Prompt:\n"
                 + tokenizer.decode(gen.prefill_tokens.token_ids)
-                + "\nAnswer: "
-                + gen.generated_text.text
+                + "\nAnswer:\n"
             ]
+        if gen.generated_text:
+            display_results[gen.request_id] += [gen.generated_text.text]
+
     # Stop if all input generations are done
     if all([g.generated_text for g in generations]):
         break
diff --git a/server/examples/test_local_api.py b/server/examples/test_local_api.py
index 9249e384..bd578744 100644
--- a/server/examples/test_local_api.py
+++ b/server/examples/test_local_api.py
@@ -32,14 +32,14 @@
 # test = "gemma"
 # test = "llama-3"
 # test = 'llama-3-70'
-# test = "baichuan"
+test = "baichuan"
 # test = "gemma"
 # test = 'mistral'
 # test = 'qwen1.5-7'
 # test = 'qwen1.5-1.8'
 # test = 'qwen1.5-70'
 # test = 'qwen2-7'
-test = "yi1.5-9b"
+# test = "yi1.5-9b"
 # test = "chatglm4"

 print("Testing " + test)
@@ -279,6 +279,7 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
         model_id="baichuan-inc/Baichuan2-7B-Chat",
         lora_ids=["tjluyao/baichuan2-7b-chat-lora1"],
         trust_remote_code=True,
+        dtype=torch.bfloat16,
     )
 elif test == "qwen2-7":
     # Todo: qwen2-7b instruct lora adapter