Jamba mamba (vllm-project#3)
* Remove assertion

* Adapt Jamba in vLLM to changes after the HF release; work on weight loading in the modeling file

* Split JambaDecoderLayer into JambaMambaDecoderLayer and JambaAttentionDecoderLayer

* Weight loading from the HF checkpoint supposedly works; there might be a mix-up in the MoE between the gated and non-gated weights

* Add mamba from jamba modeling file

* Remove slow forward

* Modifications to mamba_mixer

* Save changes, WIP

* Fix cache placement

* Debugging

* Additions and logging

* Jamba with mamba cache handling

* Clean up

* Another cleanup

* Use vLLM's RMSNorm instead of JambaRMSNorm; their implementation uses a fused kernel

* Clean up and organize the objects that handle the mamba cache

* Shorten the code for KV cache memory

* Move cache handling inside the Mixer

* Add mamba to the wheel requirements

* Add mamba to the requirements script

* Add mamba_metadata

* Add to __init__ __all__

* Revert 2 commits

ad1a3db 'Add mamba to the requirements script'
75ed2c8 'Add mamba to the wheel requirements'

* Clean up

* Naming

* Apply whitespace suggestions from code review

* Pass tie_word_embeddings to the PretrainedConfig init

* Replace repeat with expand, since expand doesn't allocate extra memory (see the sketch after this list)

* Allocate a really small cache if needed; don't use meta

* Fix for expanded

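To illustrate the repeat-vs-expand point above: torch.Tensor.expand returns a broadcast view over the existing storage (it only widens singleton dimensions), while torch.Tensor.repeat materializes a full copy. A minimal sketch; the shapes are made up and are not Jamba's actual cache shapes:

import torch

conv_state = torch.zeros(1, 4, 16)           # hypothetical per-layer state with a singleton batch dim

repeated = conv_state.repeat(8, 1, 1)        # allocates a new (8, 4, 16) tensor
expanded = conv_state.expand(8, 4, 16)       # broadcast view, no extra allocation

print(repeated.data_ptr() == conv_state.data_ptr())  # False: repeat copies the data
print(expanded.data_ptr() == conv_state.data_ptr())  # True: expand shares the underlying storage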
---------

Co-authored-by: Mor Zusman <morz@ai21.com>
Co-authored-by: Erez Schwartz <erezs@ai21.com>
Co-authored-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
4 people committed Apr 16, 2024
1 parent 0330e14 commit 07cc899
Showing 12 changed files with 825 additions and 573 deletions.
4 changes: 4 additions & 0 deletions vllm/model_executor/__init__.py
@@ -1,7 +1,11 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
+from vllm.model_executor.mamba_metadata import MambaCacheParams, RequestInfo, MambaCache

 __all__ = [
     "SamplingMetadata",
     "set_random_seed",
+    "MambaCacheParams",
+    "RequestInfo",
+    "MambaCache",
 ]
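With these exports in place, the new cache helpers can be imported from the package root, e.g.:

from vllm.model_executor import MambaCache, MambaCacheParams, RequestInfo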
8 changes: 6 additions & 2 deletions vllm/model_executor/input_metadata.py
@@ -1,7 +1,9 @@
-from typing import Optional
+from typing import Dict, List, Optional

 import torch

+from vllm.model_executor.mamba_metadata import MambaCache, RequestInfo
+

 class InputMetadata:
     """Metadata for input sequences. Used in PagedAttention.
@@ -27,6 +29,7 @@ def __init__(
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
         kv_cache_dtype: str,
+        requests_info: Optional[List[RequestInfo]] = None
     ) -> None:
         self.is_prompt = is_prompt
         self.prompt_lens = prompt_lens
@@ -42,7 +45,8 @@ def __init__(
         # Set during the execution of the first attention op.
         # FIXME(woosuk): This is a hack.
         self.attn_bias = None
-        self.mamba_metadata = None
+        self.mamba_cache_batch: List[MambaCache] = []
+        self.requests_info = requests_info

     def __repr__(self) -> str:
         return ("InputMetadata("
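The new requests_info and mamba_cache_batch fields pair each request in the batch with its own Mamba state. A hypothetical sketch of that bookkeeping (not the actual vLLM worker code; the helper name and the cache_by_request dict are invented for illustration), using the classes from mamba_metadata.py shown below:

from typing import Dict, List

from vllm.model_executor.mamba_metadata import MambaCache, RequestInfo


def gather_mamba_caches(requests_info: List[RequestInfo],
                        cache_by_request: Dict[str, MambaCache]) -> List[MambaCache]:
    # Look up (or lazily create) one MambaCache per request and return them in
    # batch order, mirroring what InputMetadata.mamba_cache_batch holds.
    batch: List[MambaCache] = []
    for info in requests_info:
        if info.request_id not in cache_by_request:
            cache_by_request[info.request_id] = MambaCache(info)
        batch.append(cache_by_request[info.request_id])
    return batch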
30 changes: 30 additions & 0 deletions vllm/model_executor/mamba_metadata.py
@@ -0,0 +1,30 @@
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Dict, Optional, Tuple
+import torch
+
+@dataclass
+class MambaCacheParams:
+    seqlen_offset: int = 0
+    conv_state: torch.Tensor = torch.Tensor()
+    ssm_state: torch.Tensor = torch.Tensor()
+
+
+@dataclass
+class RequestInfo:
+    request_id: str = ''
+    n: int = 1
+
+
+class MambaCache:
+    def __init__(
+        self,
+        request_info: RequestInfo,
+        layer_idx2mamba_cache: Optional[Dict[int, MambaCacheParams]] = None
+    ) -> None:
+        self.request_info = request_info
+        if layer_idx2mamba_cache is None:
+            self.layer_idx2mamba_cache = defaultdict(MambaCacheParams)
+        else:
+            self.layer_idx2mamba_cache = layer_idx2mamba_cache
+
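A minimal usage sketch of the classes above; the tensor shapes are illustrative only, not the shapes Jamba actually uses:

import torch

from vllm.model_executor.mamba_metadata import MambaCache, RequestInfo

cache = MambaCache(RequestInfo(request_id="req-0", n=1))

# layer_idx2mamba_cache is a defaultdict, so first access creates an empty MambaCacheParams.
params = cache.layer_idx2mamba_cache[0]
params.conv_state = torch.zeros(1, 8, 4)     # hypothetical (batch, channels, conv width) state
params.ssm_state = torch.zeros(1, 8, 16)     # hypothetical (batch, channels, ssm state) state
params.seqlen_offset += 1                    # advanced as tokens are decoded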
24 changes: 10 additions & 14 deletions vllm/model_executor/models/__init__.py
@@ -31,8 +31,7 @@
     "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
     "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
-    "LlavaForConditionalGeneration":
-    ("llava", "LlavaForConditionalGeneration"),
+    "LlavaForConditionalGeneration": ("llava", "LlavaForConditionalGeneration"),
     # For decapoda-research/llama-*
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
@@ -54,7 +53,7 @@
     "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
     "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
     "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
-    "Jurassic3ForCausalLM": ("jurassic3", "Jurassic3ForCausalLM")
+    "JambaForCausalLM": ("jamba", "JambaForCausalLM")
 }

 # Architecture -> type.
@@ -67,17 +66,13 @@
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
-    "Qwen2ForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
-    "MistralForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
-    "MixtralForCausalLM":
-    "Sliding window attention is not yet supported in ROCm's flash attention",
+    "Qwen2ForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention",
+    "MistralForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention",
+    "MixtralForCausalLM": "Sliding window attention is not yet supported in ROCm's flash attention",
 }


 class ModelRegistry:
-
     @staticmethod
     def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
         if model_arch in _OOT_MODELS:
@@ -88,15 +83,16 @@ def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
             if model_arch in _ROCM_UNSUPPORTED_MODELS:
                 raise ValueError(
                     f"Model architecture {model_arch} is not supported by "
-                    "ROCm for now.")
+                    "ROCm for now."
+                )
             if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                 logger.warning(
                     f"Model architecture {model_arch} is partially supported "
-                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+                    "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]
+                )

         module_name, model_cls_name = _MODELS[model_arch]
-        module = importlib.import_module(
-            f"vllm.model_executor.models.{module_name}")
+        module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)

     @staticmethod
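With JambaForCausalLM added to _MODELS, the registry resolves it the same way as any other architecture. A small sketch of that lookup (assuming a non-ROCm build, so the ROCm branches are skipped):

from vllm.model_executor.models import ModelRegistry

# Looks up ("jamba", "JambaForCausalLM") in _MODELS, imports
# vllm.model_executor.models.jamba, and returns the class (or None if missing).
jamba_cls = ModelRegistry.load_model_cls("JambaForCausalLM")
print(jamba_cls)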

