forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Remove assertion * adapting jamba vllm to changes after hf release, working on weight loading in modeling file * splitting the JambaDecoderLayer to JambaMambaDecoderLayer and JambaAttentionDecoderLayer * weight loading from hf checkpoint supposedly works, might be a mixup in the MoE between the gated and non-gated weights * Add mamba from jamba modeling file * Remove slow forward * Modifications to mamba_mixer * Save changes, WIP * Fix cache placement * Debugging * Additions and logging * Jamba with mamba cache handling * Clean up * Another cleanup * Use vllm's RMSNorm instead of JambaRMSNorm, Thier implementation is with fused kernel * Clean up and orginization of the objects to handle the mamba cache * Shorten the code for kv cache mem * Move cache handling inside the Mixer * Add mamba to the wheel requirements * Add mamba to the requirements script * Add mamba_metadata * Add to __init__ __all__ * Revert 2 commits ad1a3db 'Add mamba to the requirements script' 75ed2c8 'Add mamba to the wheel requirements' * Clean up * Naming * Apply whitespace suggestions from code review * pass tie_word_embeddings to PretrainedConfig init * Replace repeat with expand as expand doesn't require more mem * Allocate really small cache if needed , don't use meta * Fix for expanded --------- Co-authored-by: Mor Zusman <morz@ai21.com> Co-authored-by: Erez Schwartz <erezs@ai21.com> Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
- Loading branch information
1 parent
0330e14
commit 07cc899
Showing
12 changed files
with
825 additions
and
573 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,11 @@ | ||
from vllm.model_executor.sampling_metadata import SamplingMetadata | ||
from vllm.model_executor.utils import set_random_seed | ||
from vllm.model_executor.mamba_metadata import MambaCacheParams, RequestInfo, MambaCache | ||
|
||
__all__ = [ | ||
"SamplingMetadata", | ||
"set_random_seed", | ||
"MambaCacheParams", | ||
"RequestInfo", | ||
"MambaCache", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from collections import defaultdict | ||
from dataclasses import dataclass, field | ||
from typing import Dict, Optional, Tuple | ||
import torch | ||
|
||
@dataclass | ||
class MambaCacheParams: | ||
seqlen_offset: int = 0 | ||
conv_state: torch.Tensor = torch.Tensor() | ||
ssm_state: torch.Tensor = torch.Tensor() | ||
|
||
|
||
@dataclass | ||
class RequestInfo: | ||
request_id: str = '' | ||
n: int = 1 | ||
|
||
|
||
class MambaCache: | ||
def __init__( | ||
self, | ||
request_info: RequestInfo, | ||
layer_idx2mamba_cache: Optional[Dict[int, MambaCacheParams]] = None | ||
) -> None: | ||
self.request_info = request_info | ||
if layer_idx2mamba_cache is None: | ||
self.layer_idx2mamba_cache = defaultdict(MambaCacheParams) | ||
else: | ||
self.layer_idx2mamba_cache = layer_idx2mamba_cache | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.