Skip to content

Commit 10990a9

Browse files
committed
Fix bug in getting vocab_size and missing ccl in forward
Signed-off-by: quic-xiyushi <xiyushi@qti.qualcomm.com>
1 parent df06617 commit 10990a9

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import warnings
99
from pathlib import Path
1010
from time import perf_counter
11-
from typing import Dict, List, Optional, Union
11+
from typing import List, Optional, Union
1212

1313
import numpy as np
1414
import torch
@@ -1081,7 +1081,7 @@ def export(
10811081
output_names=output_names["lang"],
10821082
dynamic_axes=dynamic_axes["lang"],
10831083
continuous_batching=self.continuous_batching,
1084-
vocab_size=self.config.vocab_size,
1084+
vocab_size=self.model.language_model.config.vocab_size,
10851085
qaic_config=self.lang_model.model.qaic_config,
10861086
)
10871087

QEfficient/transformers/sampler/sampler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def sampler_forward(
105105
attention_mask: Optional[torch.Tensor] = None,
106106
position_ids: Optional[torch.LongTensor] = None,
107107
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
108+
comp_ctx_lengths: Optional[torch.LongTensor] = None,
108109
batch_index: Optional[torch.LongTensor] = None,
109110
inputs_embeds: Optional[torch.FloatTensor] = None,
110111
labels: Optional[torch.LongTensor] = None,
@@ -181,6 +182,7 @@ def sampler_forward(
181182
position_ids=position_ids,
182183
image_idx=image_idx,
183184
past_key_values=past_key_values,
185+
comp_ctx_lengths=comp_ctx_lengths,
184186
)
185187
if batch_index is not None:
186188
forward_kwargs["batch_index"] = batch_index
@@ -195,6 +197,7 @@ def sampler_forward(
195197
attention_mask=attention_mask,
196198
position_ids=position_ids,
197199
past_key_values=past_key_values,
200+
comp_ctx_lengths=comp_ctx_lengths,
198201
batch_index=batch_index,
199202
inputs_embeds=inputs_embeds,
200203
use_cache=use_cache,

tests/transformers/sampler/test_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
# -----------------------------------------------------------------------------
77

88
from typing import List, Union
9-
from transformers import AutoConfig, AutoProcessor
109

1110
import numpy as np
1211
import pytest
12+
from transformers import AutoProcessor
1313

1414
from QEfficient import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText
1515
from QEfficient.generation.cloud_infer import QAICInferenceSession

0 commit comments

Comments
 (0)