adjust the flashinfer llama model to accommodate baichuan
alfredgui2 authored and tjluyao committed Jul 7, 2024
1 parent d8dcdeb commit 5004444
Showing 3 changed files with 18 additions and 19 deletions.
4 changes: 3 additions & 1 deletion server/examples/test_local_api.py
@@ -251,7 +251,9 @@ def make_input(lora_id, lora_or_base, id=0, promptOverride=None):
         promptOverride="What are the differences between Manhattan and Brooklyn",
     ),
 ]
-service = FlashinferLlama(model_id="baichuan-inc/Baichuan2-7B-Chat")
+service = FlashinferLlama(
+    model_id="baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True
+)

 print(service.get_lora_adapters())
 tokenizer = service.tokenizer
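For context on the trust_remote_code=True argument: Baichuan2 ships its tokenizer and model classes as custom code inside the Hugging Face repository, and Transformers will not execute repo-hosted code without an explicit opt-in. A minimal standalone sketch, not part of this commit:

# Minimal sketch (independent of this commit): loading the Baichuan2 tokenizer
# directly. The repo carries custom tokenizer code, so Transformers requires
# an explicit opt-in via trust_remote_code=True.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "baichuan-inc/Baichuan2-7B-Chat",
    trust_remote_code=True,  # without this, from_pretrained refuses to load
)
print(tokenizer("Hello from Baichuan").input_ids)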
2 changes: 1 addition & 1 deletion

@@ -19,7 +19,7 @@ class AttentionRotaryParams:
     causal: bool = True
     pos_encoding_mode: POS_ENCODING_MODE = POS_ENCODING_MODE.ROPE_LLAMA
     rope_scale: float = 1.0
-    rope_theta: float = 1.0e-4
+    rope_theta: float = 1.0e4


 def find_padded_head_dim(head_dim):
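This one-character fix matters because rope_theta is the base of the rotary position embedding frequency schedule, and Llama-family models use 10000 (1.0e4), not 1.0e-4. An illustrative sketch of where the value enters; the function name below is not from this repository:

# Illustrative sketch: rope_theta feeds the standard RoPE inverse-frequency
# schedule, theta^(-2i/d) for i in [0, d/2).
import torch

def rope_inv_freq(head_dim: int, rope_theta: float = 1.0e4) -> torch.Tensor:
    return 1.0 / (
        rope_theta ** (torch.arange(0, head_dim, 2).float() / head_dim)
    )

# With the correct base 1.0e4 the frequencies decay across the head dimension;
# with the old typo 1.0e-4 (a base below 1) they would grow instead, which is
# not the schedule Llama-style models were trained with.
print(rope_inv_freq(128)[:4])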
31 changes: 14 additions & 17 deletions server/text_generation_server/models/flashinfer_llama.py
@@ -33,23 +33,14 @@ def __init__(
         else:
             raise NotImplementedError("Flashinfer Llama is only available on Cuda")

-        tokenizer = AutoTokenizer.from_pretrained(model_id)
-        try:
-            tokenizer = LlamaTokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
-        except Exception:
-            tokenizer = AutoTokenizer.from_pretrained(
-                model_id,
-                revision=revision,
-                padding_side="left",
-                truncation_side="left",
-                trust_remote_code=trust_remote_code,
-            )
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id,
+            revision=revision,
+            padding_side="left",
+            truncation_side="left",
+            trust_remote_code=trust_remote_code,
+        )
+
         try:
             generation_config = GenerationConfig.from_pretrained(
                 model_id, revision=revision, trust_remote_code=trust_remote_code
@@ -65,6 +56,12 @@ def __init__(
             )
         config.quantize = quantize
         config.speculator = None
+        if not hasattr(config, "num_key_value_heads"):
+            config.num_key_value_heads = config.num_attention_heads
+
+        if not hasattr(config, "rope_theta"):
+            config.rope_theta = 1.0e4
+
         torch.distributed.barrier(group=self.process_group)

         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
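The two hasattr fallbacks handle configs, such as Baichuan's, that do not define num_key_value_heads (plain multi-head attention rather than grouped-query attention) or rope_theta. A standalone sketch of the pattern, using a stand-in namespace object rather than the real Transformers config class:

# Standalone sketch of the defaulting pattern above; SimpleNamespace stands in
# for the real config object, and 32 heads is an illustrative value.
from types import SimpleNamespace

config = SimpleNamespace(num_attention_heads=32)

# No num_key_value_heads means plain multi-head attention: every attention
# head gets its own key/value head.
if not hasattr(config, "num_key_value_heads"):
    config.num_key_value_heads = config.num_attention_heads

# No rope_theta means the Llama default RoPE base of 10000.
if not hasattr(config, "rope_theta"):
    config.rope_theta = 1.0e4

assert config.num_key_value_heads == 32
assert config.rope_theta == 1.0e4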
