Refactor(vllm): to prepare for merging it to upstream #78

Merged: 6 commits, Nov 20, 2024
Changes from 5 commits
277 changes: 61 additions & 216 deletions aria/vllm/aria.py
@@ -28,7 +28,6 @@
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig, VllmConfig
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
@@ -39,24 +38,20 @@
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput, SamplingMetadata
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaMLP,
LlamaModel,
RMSNorm,
)
from vllm.model_executor.models.utils import (
PPMissingLayer,
AutoWeightsLoader,
WeightsMapper,
make_layers,
maybe_prefix,
merge_multimodal_embeddings,
)
from vllm.model_executor.utils import set_weight_attrs
@@ -70,17 +65,11 @@
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from aria.model.configuration_aria import AriaConfig
from aria.model.projector import AriaProjector
from aria.model.vision_encoder import AriaVisionModel
from .projector import AriaProjector
from .vision_encoder import AriaVisionModel

logger = logging.get_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
"language_model.model": "language_model",
}


class AriaMoELMConfig(LlamaConfig):
"""
@@ -156,36 +145,37 @@ def __init__(self, config: AriaMoELMConfig):
)
)
)
set_weight_attrs(self.router_weight, {"weight_loader": self.weight_loader})
set_weight_attrs(self.w1, {"weight_loader": self.weight_loader})
set_weight_attrs(self.w2, {"weight_loader": self.weight_loader})
set_weight_attrs(
self.router_weight, {"weight_loader": self._weight_loader_for_router}
)
set_weight_attrs(self.w1, {"weight_loader": self._weight_loader_for_w1})
set_weight_attrs(self.w2, {"weight_loader": self._weight_loader_for_w2})

def weight_loader(
self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
def _weight_loader_for_router(
self, param: nn.Parameter, loaded_weight: torch.Tensor
):
if shard_id == "router":
param.data.copy_(loaded_weight)
elif shard_id == "w1":
if self.tp_size > 1:
# the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size)
up, gate = loaded_weight.chunk(2, dim=-1)
up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank]
gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank]
up_and_gate = torch.cat(
[up_current_rank, gate_current_rank], dim=-1
).transpose(1, 2)
param.data.copy_(up_and_gate)
else:
param.data.copy_(loaded_weight.transpose(1, 2))
param.data.copy_(loaded_weight)

def _weight_loader_for_w1(self, param: nn.Parameter, loaded_weight: torch.Tensor):
# the shape of loaded_weight is (num_experts, hidden_size, 2 * moe_intermediate_size)
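# Under tensor parallelism, split the fused [up | gate] halves separately so each rank
# keeps its slice of both, then re-fuse and transpose to
# (num_experts, 2 * moe_intermediate_size // tp_size, hidden_size).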
if self.tp_size > 1:
up, gate = loaded_weight.chunk(2, dim=-1)
up_current_rank = up.chunk(self.tp_size, dim=-1)[self.tp_rank]
gate_current_rank = gate.chunk(self.tp_size, dim=-1)[self.tp_rank]
up_and_gate = torch.cat(
[up_current_rank, gate_current_rank], dim=-1
).transpose(1, 2)
param.data.copy_(up_and_gate)
else:
if self.tp_size > 1:
# the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size)
down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[
self.tp_rank
]
param.data.copy_(down_current_rank.transpose(1, 2))
else:
param.data.copy_(loaded_weight.transpose(1, 2))
param.data.copy_(loaded_weight.transpose(1, 2))

def _weight_loader_for_w2(self, param: nn.Parameter, loaded_weight: torch.Tensor):
# the shape of loaded_weight is (num_experts, moe_intermediate_size, hidden_size)
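# Shard the down projection along the intermediate dimension (dim=1) to mirror the column
# split of w1, then transpose to (num_experts, hidden_size, moe_intermediate_size // tp_size).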
if self.tp_size > 1:
down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[self.tp_rank]
param.data.copy_(down_current_rank.transpose(1, 2))
else:
param.data.copy_(loaded_weight.transpose(1, 2))

def forward(self, hidden_states):
router_output = torch.nn.functional.linear(hidden_states, self.router_weight)
@@ -328,39 +318,18 @@ class AriaMoELMModel(LlamaModel):
config (LlamaConfig): Configuration object for the model.
"""

def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__(vllm_config=vllm_config, prefix=prefix)

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
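# The constructor now follows upstream vLLM's interface: a single VllmConfig is passed in,
# and the model, cache, and quantization configs are unpacked from it.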

# FIXME(zhoufan): this is a hack to avoid the error: AttributeError: 'AriaMoELMModel' object has no attribute 'do_not_compile'.
self.do_not_compile = True

self.config = config
self.padding_idx = config.pad_token_id
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
if get_pp_group().is_first_rank or (
config.tie_word_embeddings and get_pp_group().is_last_rank
):
self.embed_tokens = VocabParallelEmbedding(
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()
self.layers = None

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: MoEDecoderLayer(
@@ -371,112 +340,9 @@ def __init__(
),
prefix=f"{prefix}.layers",
)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()


class AriaMoELMForCausalLM(LlamaForCausalLM):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
"gate_up_proj": [
"gate_proj",
"up_proj",
],
}

# LoRA specific attributes
supported_lora_modules = [
"qkv_proj",
"o_proj",
"gate_up_proj",
"down_proj",
"embed_tokens",
"lm_head",
]
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = ["lm_head"]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
# Mistral/Llama models can also be loaded with --load-format mistral
# from consolidated.safetensors checkpoints
mistral_mapping = {
"layers": "model.layers",
"attention": "self_attn",
"wq": "q_proj",
"wk": "k_proj",
"wv": "v_proj",
"wo": "o_proj",
"attention_norm": "input_layernorm",
"feed_forward": "mlp",
"w1": "gate_proj",
"w2": "down_proj",
"w3": "up_proj",
"ffn_norm": "post_attention_layernorm",
"tok_embeddings": "model.embed_tokens",
"output": "lm_head",
"norm": "model.norm",
}

def __init__(
self,
config: LlamaConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
nn.Module.__init__(self)

self.config = config
self.lora_config = lora_config

self.model = AriaMoELMModel(
config, cache_config, quant_config, lora_config=lora_config, prefix="model"
)
if get_pp_group().is_last_rank:
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=(
DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config
else lora_config.lora_vocab_padding_size
),
quant_config=quant_config,
)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight

logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
)
self.sampler = Sampler()
else:
self.lm_head = PPMissingLayer()


def build_mm_projector(config: AriaConfig):
def build_mm_projector(config):
"""
Builds and returns an AriaProjector instance based on the provided configuration.

Expand Down Expand Up @@ -699,7 +565,6 @@ def input_processor(ctx, llm_inputs):
)


# adapted from transformers.models.llava.modeling_llava.LlavaForConditionalGeneration
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens)
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria)
@INPUT_REGISTRY.register_input_processor(input_processor)
@@ -718,7 +583,7 @@ def __init__(
):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
vllm_config.cache_config

Reviewer: What's this line for?

Collaborator (author): I will remove it. This change was made by the formatter because cache_config is not currently used.

quant_config = vllm_config.quant_config

# prepare the image_size-to-tokens mapping for image preprocessing; see input_processor
@@ -735,7 +600,8 @@ def __init__(
self.multi_modal_projector = build_mm_projector(config)
self.vocab_size = config.text_config.vocab_size
self.language_model = AriaMoELMModel(
config.text_config, cache_config, quant_config
vllm_config=vllm_config.with_hf_config(config.text_config),
prefix=maybe_prefix(prefix, "language_model.model"),
)
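# with_hf_config() produces a copy of the vLLM config whose hf_config is the text sub-config,
# so the wrapped AriaMoELMModel is built with the language-model settings; maybe_prefix()
# nests its parameter names under the parent prefix.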
self.pad_token_id = (
self.config.pad_token_id if self.config.pad_token_id is not None else -1
@@ -773,11 +639,10 @@ def forward(
torch.bfloat16
)
pixel_mask = pixel_mask.view(-1, *pixel_mask.shape[-2:])
image_outputs, image_attn_mask = self.vision_tower(
selected_image_feature, image_attn_mask = self.vision_tower(
pixel_values,
pixel_mask=pixel_mask,
)
selected_image_feature = image_outputs.last_hidden_state
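# The vision tower now returns the selected feature tensor and its attention mask directly,
# so the .last_hidden_state indirection above is dropped.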

image_features = self.multi_modal_projector(
selected_image_feature, attn_mask=image_attn_mask
@@ -814,37 +679,17 @@ def sample(
return next_tokens

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for the language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("experts.router_weight", "router.weight", "router"),
("experts.w1", "experts.fc1.weight", "w1"),
("experts.w2", "experts.fc2.weight", "w2"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
if key_to_modify in name:
name = name.replace(key_to_modify, new_key)
shard_id = None
# Because we used the original HF ViT and vision projector, we could keep the weights in the sharded shape.
# Only the language model part needs to adjust the weight loading.
if "language_model" in name:
for param_name, weight_name, _shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
shard_id = _shard_id
break

param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
if shard_id is not None:
weight_loader(param, loaded_weight, shard_id)
else:
weight_loader(param, loaded_weight)
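# The hand-rolled loop above is replaced by vLLM's AutoWeightsLoader: the WeightsMapper
# renames HF checkpoint keys (prefixes and expert-weight suffixes) to the vLLM module names,
# and the per-parameter loaders registered in the MoE layer handle any sharding.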
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"language_model.model": "language_model",
"language_model.lm_head": "lm_head",
},
orig_to_new_suffix={
"experts.fc1.weight": "experts.w1",
"experts.fc2.weight": "experts.w2",
"router.weight": "experts.router_weight",
},
)

loader = AutoWeightsLoader(self)
loader.load_weights(weights, mapper=hf_to_vllm_mapper)
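
As a rough illustration of how the mapper composes (the layer path below is hypothetical, not taken from the diff), a checkpoint key is renamed before AutoWeightsLoader dispatches it to the matching parameter's weight_loader:

# hypothetical HF checkpoint key -> vLLM parameter name
# "language_model.model.layers.0.mlp.experts.fc1.weight"
#     -> "language_model.layers.0.mlp.experts.w1"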