Commit e320580

Author: Apoorva Gokhale (committed)
[eplatero] Add support for exporting and compiling models for SpD
(https://jira-dc.qualcomm.com/jira/browse/CLOUDPERF-43)

This change has been validated and posted on behalf of Erick Platero. It adds support for generating a Target LM (TLM) that runs as a verifier model by outputting logits for all positions instead of only the last position of the input sequence. It also allows compiling the Target and Draft LMs with specializations that support SpD.

Usage:

TLM:

    tlm = QEFFAutoModelForCausalLM.from_pretrained(<tlm-model-card>)
    tlm.transform(num_speculative_tokens=<k>)
    tlm.export_and_compile(<compiler-args>)

DLM:

    dlm = QEFFAutoModelForCausalLM.from_pretrained(<dlm-model-card>)
    dlm.transform(is_dlm=True)
    dlm.export_and_compile(<compiler-args>)
1 parent afb4645 commit e320580

File tree: 6 files changed (+87, -18 lines changed)


QEfficient/compile/compile_helper.py (+36, -11)

@@ -15,24 +15,47 @@


 def create_and_dump_specializations(
-    batch_size: int, prompt_len: int, ctx_len: int, path: str, full_batch_size: Optional[int] = None
+    batch_size: int, prompt_len: int, ctx_len: int, path: str, is_dlm: bool, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None,
 ):
-    # Create specialization file.
-    specializations = {
-        "specializations": [
-            {
-                "batch_size": str(batch_size),
-                "seq_len": str(prompt_len),
-                "ctx_len": str(ctx_len),
-            },
-            {"batch_size": str(batch_size), "seq_len": "1", "ctx_len": str(ctx_len)},
-        ]
+    # Create specialization cfgs
+    prefill_specialization = {
+        "batch_size": str(batch_size),
+        "seq_len": str(prompt_len),
+        "ctx_len": str(ctx_len)
     }
+    if num_speculative_tokens is None:
+        decode_specialization = {
+            "batch_size": str(batch_size),
+            "seq_len": "1",
+            "ctx_len": str(ctx_len),
+        }
+    else:
+        decode_specialization = {
+            "batch_size": str(batch_size),
+            "seq_len": str(num_speculative_tokens+1),
+            "ctx_len": str(ctx_len),
+        }
+    specialization_cfgs = [prefill_specialization, decode_specialization]
+    if is_dlm:
+        dlm_specialization = {
+            "batch_size": str(batch_size),
+            "seq_len": "2",
+            "ctx_len": str(ctx_len),
+        }
+        specialization_cfgs.append(dlm_specialization)
+
+
+    specializations = dict(specializations=specialization_cfgs)
+
     # If continuous batching is enabled by proving full_batch_size we need to add FBS to the specialization file and update the batch size of decoder part to FBS
     if full_batch_size is not None:
         specializations["specializations"][0]["full_batch_size"] = str(full_batch_size)
         specializations["specializations"][1]["full_batch_size"] = str(full_batch_size)
         specializations["specializations"][1]["batch_size"] = str(full_batch_size)
+        if len(specializations["specializations"]) == 3:
+            specializations["specializations"][2]["batch_size"] = str(full_batch_size)
+            specializations["specializations"][2]["full_batch_size"] = str(full_batch_size)
+

     # Dump
     with open(path, "w") as file:

@@ -158,6 +181,8 @@ def compile(
         ctx_len=ctx_len,
         path=specialization_json_path,
         full_batch_size=full_batch_size,
+        is_dlm=kwargs.get("is_dlm",None),
+        num_speculative_tokens=kwargs.get("num_speculative_tokens",None)
     )

     # Select the customIO config based on the mx flag.
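
For reference, a hedged sketch of what the updated helper would emit for a Target LM: the concrete batch_size/prompt_len/ctx_len values below are made up for illustration, and only the dict layout follows the diff above.

    from QEfficient.compile.compile_helper import create_and_dump_specializations

    # Decode specialization gets seq_len = num_speculative_tokens + 1;
    # prefill keeps seq_len = prompt_len as before.
    create_and_dump_specializations(
        batch_size=1,
        prompt_len=32,
        ctx_len=128,
        path="specializations.json",
        is_dlm=False,
        num_speculative_tokens=3,
    )
    # Expected "specializations" entries written to the file:
    #   {"batch_size": "1", "seq_len": "32", "ctx_len": "128"}   # prefill
    #   {"batch_size": "1", "seq_len": "4",  "ctx_len": "128"}   # decode (k + 1)
    # With is_dlm=True (and no num_speculative_tokens), a third entry with
    # seq_len = "2" would be appended for the Draft LM.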

QEfficient/exporter/export_hf_to_cloud_ai_100.py (+1)

@@ -213,6 +213,7 @@ def export_kvstyle_transformed_model_to_onnx(
         prompt_len=Constants.PROMPT_LEN,
         ctx_len=seq_len,
         full_batch_size=full_batch_size,
+        num_speculative_tokens=getattr(transformed_model, "num_speculative_tokens", None)
     )

     inputs = input_handler.prepare_pytorch_inputs()
QEfficient/transformers/modeling_spd_utils.py (new file, +21)

@@ -0,0 +1,21 @@
+import torch
+from typing import Optional
+
+
+def filter_hidden_states(
+    hidden_states: torch.Tensor,
+    position_ids: torch.Tensor,
+    num_speculative_tokens: Optional[int],
+) -> torch.Tensor:
+    """filter hidden states based on whether this is a TLM SpD model
+    """
+    batch_indices = torch.arange(position_ids.shape[0])
+    if num_speculative_tokens is not None:
+        # all logits need to be computed
+        return hidden_states[batch_indices].squeeze(1)
+    # Cast to INT32 to avoid issue while running in ONNXRT
+    logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
+    return hidden_states[batch_indices.view(-1,1), logit_index]
+
+
+
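
As a quick shape check (a standalone sketch, not part of the commit), the new helper keeps every position when num_speculative_tokens is set (TLM/SpD verifier path) and only the last valid position otherwise:

    import torch
    from QEfficient.transformers.modeling_spd_utils import filter_hidden_states

    hidden_states = torch.randn(2, 8, 16)        # (batch, seq_len, hidden_dim)
    position_ids = torch.arange(8).repeat(2, 1)  # all 8 positions are valid

    # TLM path: all positions are returned so the verifier can score k+1 tokens
    tlm_states = filter_hidden_states(hidden_states, position_ids, num_speculative_tokens=3)
    print(tlm_states.shape)   # torch.Size([2, 8, 16])

    # Non-SpD path: only the hidden state at the last valid position is kept
    reg_states = filter_hidden_states(hidden_states, position_ids, num_speculative_tokens=None)
    print(reg_states.shape)   # torch.Size([2, 1, 16])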

QEfficient/transformers/models/llama/modeling_llama.py (+2, -2)

@@ -31,6 +31,7 @@
 )

 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
+from QEfficient.transformers.modeling_spd_utils import filter_hidden_states


 class QEffLlamaRotaryEmbedding(LlamaRotaryEmbedding):

@@ -288,8 +289,7 @@ def forward(
         )

         # Cast to INT32 to avoid issue while running in ONNXRT
-        logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True)
-        hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index]
+        hidden_states = filter_hidden_states(outputs[0], position_ids, getattr(self, "num_speculative_tokens", None))
         if self.config.pretraining_tp > 1:
             lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
             logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]

QEfficient/transformers/models/modeling_auto.py (+14, -1)

@@ -168,7 +168,6 @@ def transform(self, **kwargs):
         """
         if self.is_transformed:
             return
-
         if self.full_batch_size is not None:
             if KVCacheTransform in self._pytorch_transforms:
                 self._pytorch_transforms[self._pytorch_transforms.index(KVCacheTransform)] = CBTransform

@@ -188,6 +187,18 @@ def transform(self, **kwargs):
             if isinstance(self.model.config.quantization_config, QEffGPTQConfig):
                 self._pytorch_transforms.insert(0, GPTQToMatmulNbitsTransform)

+        num_speculative_tokens = kwargs.get("num_speculative_tokens", None)
+        is_dlm = kwargs.get("is_dlm", False)
+        assert (not isinstance(num_speculative_tokens,int)) or not is_dlm, "number of speculative tokens are only to be specified for Target LM"
+        if num_speculative_tokens:
+            assert isinstance(num_speculative_tokens, int) and num_speculative_tokens>0, (
+                f"argument num_speculative_tokens"
+                f" should be of type integer and"
+                f" be positive if specified")
+            setattr(self.model, "num_speculative_tokens", num_speculative_tokens)
+        elif is_dlm:
+            setattr(self.model, "is_dlm", True)
+
         for transform in self._pytorch_transforms:
             transform.apply(self.model)
         self.is_transformed = True

@@ -289,6 +300,8 @@ def compile(
             mxfp6=mxfp6,
             mxint8=mxint8,
             full_batch_size=self.full_batch_size,
+            num_speculative_tokens=getattr(self.model, "num_speculative_tokens", None),
+            is_dlm=getattr(self.model, "is_dlm", False),
         )
         self.qpc_path = qpc_dir_path
         return self.qpc_path
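
To make the new transform() argument handling explicit, here is a small standalone sketch that mirrors the checks added above (the helper name validate_spd_kwargs is made up; only the assertion logic comes from the diff): a model is treated either as a Target LM, via a positive integer num_speculative_tokens, or as a Draft LM via is_dlm=True, never both.

    from typing import Optional

    def validate_spd_kwargs(num_speculative_tokens: Optional[int], is_dlm: bool) -> None:
        # TLM and DLM flags are mutually exclusive, as in QEFFAutoModelForCausalLM.transform()
        assert (not isinstance(num_speculative_tokens, int)) or not is_dlm, (
            "number of speculative tokens are only to be specified for Target LM"
        )
        if num_speculative_tokens:
            # The TLM decode specialization uses seq_len = num_speculative_tokens + 1,
            # so the value must be a positive integer.
            assert isinstance(num_speculative_tokens, int) and num_speculative_tokens > 0, (
                "argument num_speculative_tokens should be of type integer and be positive if specified"
            )

    validate_spd_kwargs(num_speculative_tokens=4, is_dlm=False)    # Target LM: OK
    validate_spd_kwargs(num_speculative_tokens=None, is_dlm=True)  # Draft LM: OK
    # validate_spd_kwargs(num_speculative_tokens=4, is_dlm=True)   # would raise AssertionError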

QEfficient/utils/generate_inputs.py (+13, -4)

@@ -12,7 +12,7 @@


 class InputHandler:
-    def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size):
+    def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, full_batch_size, num_speculative_tokens):
         """
         Initialization


@@ -24,6 +24,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
         :prompt_len (int): Prompt length for the model to compile.
         :ctx_len (int): Maximum context length to compile the model.
         :full_batch_size (int): Continuous batching batch size
+        :num_speculative_tokens (Optional[int]): used to determine whether this is a TLM model or not
         """
         # check and fix tokenizer viability
         padding_check_and_fix(tokenizer)

@@ -32,6 +33,7 @@ def __init__(self, batch_size, tokenizer, config, prompt, prompt_len, ctx_len, f
         self.prompt_len = prompt_len
         self.ctx_len = ctx_len
         self.full_batch_size = full_batch_size
+        self.num_speculative_tokens = num_speculative_tokens
         self.n_layer = get_num_layers_from_config(config)
         self.padding_shape = get_padding_shape_from_config(
             config=config, batch_size=full_batch_size if full_batch_size else batch_size, seq_len=ctx_len

@@ -99,8 +101,10 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
         updated_inputs = {}
         if self.full_batch_size:
             batch_index = torch.arange(1).view(-1, 1)
-
-            input_ids = pt_outputs.logits.detach().argmax(2)
+            if self.num_speculative_tokens:
+                input_ids = pt_outputs.logits.detach()[:,-1].argmax(1, keepdim=True)
+            else:
+                input_ids = pt_outputs.logits.detach().argmax(2)
             updated_inputs["input_ids"] = torch.full((self.full_batch_size, 1), self.tokenizer.pad_token_id)
             updated_inputs["input_ids"][batch_index.view(-1)] = input_ids


@@ -111,7 +115,12 @@ def update_pytorch_inputs(self, inputs, pt_outputs):
             updated_inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1)

         else:
-            updated_inputs["input_ids"] = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
+            if self.num_speculative_tokens:
+                # assume spec decoding logits
+                input_ids = pt_outputs["logits"][:,-1].argmax(-1).reshape(-1, 1)
+            else:
+                input_ids = pt_outputs["logits"].argmax(-1).reshape(-1, 1)
+            pt_outputs["input_ids"] = input_ids
             updated_inputs["position_ids"] = inputs["position_ids"].max(1, keepdim=True).values + 1

         updated_inputs["past_key_values"] = tuple(
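
The behavioral difference in update_pytorch_inputs is that a TLM now returns logits for every position, so the next input token is picked greedily from the last position only. A minimal sketch with made-up tensors:

    import torch

    batch, seq_len, vocab = 1, 5, 50

    # SpD / TLM path: logits cover all seq_len positions; pick from the last one
    spd_logits = torch.randn(batch, seq_len, vocab)
    spd_next = spd_logits[:, -1].argmax(-1).reshape(-1, 1)      # shape (1, 1)

    # Regular path: logits already correspond to a single position per sequence
    reg_logits = torch.randn(batch, 1, vocab)
    reg_next = reg_logits.argmax(-1).reshape(-1, 1)             # shape (1, 1)

    print(spd_next.shape, reg_next.shape)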
