@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 """Inference-only Bamba model."""
 # Added by the IBM Team, 2024
-import math
 from typing import Iterable, Optional, Set, Tuple
 
 import torch
@@ -21,9 +20,9 @@
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     MambaMixer2, extra_groups_for_head_shards)
-from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.mamba.ops.ssd_chunk_scan import (
     seq_idx_to_chunk_indices_offsets)
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -258,6 +257,7 @@ def forward(
     "mamba": BambaMixerDecoderLayer
 }
 
+
 class BambaModel(nn.Module):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -329,10 +329,10 @@ def forward(
                 seq_idx[srt:end] = i
             seq_idx.unsqueeze_(0)
 
-            # compute metadata for chunked prefill.
-            # actually this is only needed if there are
+            # compute metadata for chunked prefill.
+            # actually this is only needed if there are
             # initial states, but this is determinable
-            # only from attention metadata yet
+            # only from attention metadata yet
             # unavailable from the current top-level forward.
             # Rather than complicating things to extract said
             # metadata, we simply just compute redundently and
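For readers new to this code path, here is a minimal, self-contained sketch of the seq_idx tensor that the loop in the hunk above builds for a chunked-prefill batch. It uses plain PyTorch with hard-coded prompt boundaries (the real model derives them from its attention metadata, which is not shown in this diff), and only illustrates the shape and contents of the tensor that is later handed to chunk-scan helpers such as the seq_idx_to_chunk_indices_offsets import added above:

import torch

# Hypothetical prefill batch: two prompts of lengths 3 and 2, flattened
# into a single token stream; cumulative start locations are [0, 3, 5].
query_start_loc = [0, 3, 5]

num_prefill_tokens = query_start_loc[-1]
seq_idx = torch.zeros(num_prefill_tokens, dtype=torch.int32)

# Mirror the loop shown in the hunk above: tag every token with the
# index of the prefill sequence it belongs to.
for i, (srt, end) in enumerate(zip(query_start_loc, query_start_loc[1:])):
    seq_idx[srt:end] = i

# Add a leading batch dimension, as the model code does.
seq_idx.unsqueeze_(0)

print(seq_idx)  # tensor([[0, 0, 0, 1, 1]], dtype=torch.int32)

As the comment block in the hunk notes, the chunked-prefill metadata derived from this tensor is only strictly needed when there are initial states, but it is computed unconditionally here because that condition is not visible at this point in the forward pass.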