
[Usage]: Trying to add codeshell 7b model, but garbled characters #11681

Open
G1017 opened this issue Jan 2, 2025 · 16 comments
Labels
usage How to use vllm

Comments

@G1017

G1017 commented Jan 2, 2025

Your current environment

from typing import List, Optional, Tuple, Union, Iterable, Set

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
from .utils import (is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers)

### Build the config
logger = logging.get_logger(__name__)
class CodeShellConfig(PretrainedConfig):
    model_type = "codeshell"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
            self,
            vocab_size=70144,
            n_positions=8192,
            n_embd=4096,
            n_layer=42,
            n_head=32,
            n_inner=None,
            activation_function="gelu_pytorch_tanh",
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            scale_attn_weights=True,
            use_cache=True,
            bos_token_id=70000,
            eos_token_id=70000,
            attention_softmax_in_fp32=True,
            scale_attention_softmax_in_fp32=True,
            group_query_attention=True,
            num_query_groups=1,
            position_embedding_type="learned_absolute",
            rope_scaling=None,
            **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
        self.group_query_attention = group_query_attention
        self.num_query_groups = num_query_groups
        self.position_embedding_type = position_embedding_type
        self.rope_scaling = rope_scaling
        assert self.position_embedding_type in [
            "learned_absolute", "rope"
        ], "position_embedding_type must be one of ['learned_absolute', 'rope']"

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


## Implements the Rotary Positional Embedding
class CodeShellRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
#     print("shape q k cos sin:",q.shape,k.shape,cos.shape,sin.shape)
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
#     print("q shape:", q.shape)
#     print("cos shape:", cos.shape)
#     print("sin shape:", sin.shape)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class CodeShellAttention(nn.Module):
    def __init__(
            self,
            config=CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",):
        super().__init__()

        self.mask_value = None
        ####
        # self.num_key_value_groups = config.num_attention_heads // config.num_query_groups
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_scaling = config.rope_scaling
        self.position_embedding_type = config.position_embedding_type
        self.num_query_groups = config.num_query_groups
        self.group_query_attention = config.group_query_attention
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_heads = config.num_query_groups if self.group_query_attention else total_num_heads
        self.kv_dim = self.kv_heads * self.head_dim
        self.scale = self.head_dim ** -0.5
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            self.kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        from vllm.model_executor.layers.rotary_embedding import get_rope
        max_positions = getattr(config, "seq_length", 8192)
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        self.rotary_emb1 = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
        )

        if self.position_embedding_type == "rope":
            self._init_rope()

    def _init_rope(self):
        if self.rope_scaling is None:
            self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim,
                                                       max_position_embeddings=self.max_position_embeddings)
    ####
    def _get_mask_value(self, device, dtype):
        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:

        qkv, _ = self.c_attn(hidden_states)
#         print(qkv.shape,hidden_states.shape)
        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim], dim=-1)
        # q, k, v = qkv.chunk(chunks=3, dim=-1)
#         print("________q, k, v___________",q.shape, k.shape, v.shape)
    
        # query_states, key_states, value_states = self.c_attn(hidden_states).split(
        #     (self.hidden_size, self.kv_dim, self.kv_dim), dim=2)
        q, k = self.rotary_emb1(position_ids, q, k)
#         if kv_cache is not None:
#             print("____q,k____",q.shape, k.shape, kv_cache.shape)
        # kv_seq_len = k.shape[-2]
        # kv_seq_len = 1
        # cos, sin = self.rotary_emb(v, seq_len=kv_seq_len)

        # q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # k = repeat_kv(k, self.num_heads // self.kv_heads)
        # v = repeat_kv(v, self.num_heads // self.kv_heads)

        # attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        # attn_weights = self.attn_dropout(attn_weights)
        # attn_output = torch.matmul(attn_weights, v)
        # attn_output = attn_output.transpose(1, 2).contiguous()
        # attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
        # attn_output = self.c_proj(attn_output)
        # attn_output = self.resid_dropout(attn_output)
        # outputs = (attn_output, layer_past)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output

class CodeShellMLP(nn.Module):
    def __init__(
        self,
        intermediate_size: int,
        config: CodeShellConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class CodeShellBlock(nn.Module):
    def __init__(
        self,
        config: CodeShellConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None else 4 *
                     hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = CodeShellAttention(config,
                                  cache_config,
                                  quant_config,
                                  prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = CodeShellMLP(inner_dim,
                           config,
                           quant_config,
                           prefix=f"{prefix}.mlp")

    def forward(
        self,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            position_ids=position_ids,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual

        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states

class CodeShellModel(nn.Module):
    def __init__(
        self,
        config: CodeShellConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        # self.group_query_attention = config.group_query_attention
        # self.num_query_groups = config.num_query_groups
        # self.position_embedding_type = config.position_embedding_type
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: CodeShellBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(["hidden_states"],
                                                    config.n_embd))

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
            # position_embeds = self.wpe(position_ids)
            # hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]
            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata,
                                  position_ids=position_ids,)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states

class CodeShellForCausalLM(nn.Module):

    def __init__(
        self,
        config: CodeShellConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = CodeShellModel(config,
                                     cache_config,
                                     quant_config,
                                     prefix="transformer")
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()
        ###
        self.make_empty_intermediate_tensors = (
            self.transformer.make_empty_intermediate_tensors)


    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        print("hidden_states",hidden_states.shape)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # linear layer.
                continue
#             if "transformer.h.27.attn.rotary_emb.inv_freq" in name:
#                 continue

            if ".rotary_emb.inv_freq" in name:
                continue

            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                #loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

But the output contains garbled characters. Can you help me solve this problem?
Output: (attached screenshot of the garbled output)

Link to a previously submitted related issue: #11451

How would you like to use vllm

I want to run inference of a [specific model](put link here). I don't know how to integrate it with vllm.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@G1017 G1017 added the usage How to use vllm label Jan 2, 2025
@G1017
Author

G1017 commented Jan 2, 2025

Please help. I really can’t find a solution.

@DarkLight1337
Member

DarkLight1337 commented Jan 6, 2025

Sorry I don't have time to debug in detail, but what I would do is have two debuggers open and step through vLLM (during inference, not in profile run) and HF models line by line and see where the outputs diverge. You can refer to the model tests on how to call vLLM and HF in a consistent way.
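
For reference, here is a minimal sketch of that kind of side-by-side comparison (not the official vLLM test code; the model path and prompt are placeholders, and in practice you would run the two halves in separate processes/debuggers as described above so you can step through each):

# Hedged sketch: compare greedy HF output with greedy vLLM output for one prompt.
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM, SamplingParams

MODEL = "/path/to/CodeShell-7B-Chat"  # placeholder path
prompt = "def fibonacci(n):"

# Reference: HuggingFace implementation with greedy decoding.
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL, trust_remote_code=True, torch_dtype="auto").cuda()
hf_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
hf_out = hf_model.generate(hf_ids, max_new_tokens=32, do_sample=False)
print("HF  :", tokenizer.decode(hf_out[0]))

# Implementation under test: vLLM with temperature 0 (greedy).
llm = LLM(model=MODEL, trust_remote_code=True)
vllm_out = llm.generate([prompt], SamplingParams(temperature=0, max_tokens=32))
print("vLLM:", vllm_out[0].outputs[0].text)

The first token position where the two generations diverge tells you which layer to inspect more closely.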

@G1017
Author

G1017 commented Jan 7, 2025

Sorry I don't have time to debug in detail, but what I would do is have two debuggers open and step through vLLM (during inference, not in profile run) and HF models line by line and see where the outputs diverge. You can refer to the model tests on how to call vLLM and HF in a consistent way.

In /usr/local/lib/python3.10/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py, UnquantizedEmbeddingMethod:

def embedding(self, layer: torch.nn.Module,
              input_: torch.Tensor) -> torch.Tensor:
    print(f"input_emb : {input_}")
    print(f"layer.weight : {torch.mean(layer.weight)}")
    return F.embedding(input_, layer.weight)

This prints torch.mean(layer.weight) = nan.
Please tell me where the problem is.

@DarkLight1337
Member

Can you show the full stack trace? It's hard to see what the problem is from this short snippet.

@G1017
Author

G1017 commented Jan 8, 2025

After loading the model:

In codeshell.py, CodeShellModel.forward:

        print(f"inputids : {input_ids}")
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)  ## execution takes this branch
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        print(f"hidden_states : {torch.mean(hidden_states)}")

## output printed during initialization
inputids = tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')
hidden_states =0.0
## query output
inputids : tensor([38172, 43374, 58360, 21533,  5671, 38270, 38270,  3330, 35235, 37482,
        11949, 25713, 46275, 38270, 38270, 23225, 35235], device='cuda:0')
hidden_states : nan

The call hidden_states = self.wte(input_ids) goes into:

from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)

/usr/local/lib/python3.10/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py
In VocabParallelEmbedding.forward:

def forward(self, input_):
        if self.tp_size > 1:
            # Build the mask.
            masked_input, input_mask = get_masked_input_and_mask(
                input_, self.shard_indices.org_vocab_start_index,
                self.shard_indices.org_vocab_end_index,
                self.shard_indices.num_org_vocab_padding,
                self.shard_indices.added_vocab_start_index,
                self.shard_indices.added_vocab_end_index)
        else:
            masked_input = input_
        print(masked_input)
        # Get the embeddings.
        print("+++++++++++++++++++++++++++++++++++=")
        # print(f"input_ : {torch.mean(input_)}")
        output_parallel = self.linear_method.embedding(self,
                                                       masked_input.long())
        print(f"output_parallel : {torch.mean(output_parallel)}")
        # Mask the output embedding.
        if self.tp_size > 1:
            output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
        print(f"output_parallel 2 : {torch.mean(output_parallel)}")
        # Reduce across all the model parallel GPUs.
        output = tensor_model_parallel_all_reduce(output_parallel)
        print(f"output_parallel 3 : {torch.mean(output)}")
        return output
## output printed during initialization
masked_input =tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')
output_parallel : 0.0
output_parallel 2 : 0.0
output_parallel 3 : 0.0
## query output
masked_input =tensor([38172, 43374, 58360, 21533,  5671, 38270, 38270,  3330, 35235, 37482,
        11949, 25713, 46275, 38270, 38270, 23225, 35235], device='cuda:0')
output_parallel : nan
output_parallel 2 : nan
output_parallel 3 : nan

So output_parallel = self.linear_method.embedding(self, masked_input.long()) goes to UnquantizedEmbeddingMethod in /usr/local/lib/python3.10/site-packages/vllm/model_executor/layers/vocab_parallel_embedding.py:

def embedding(self, layer: torch.nn.Module,
              input_: torch.Tensor) -> torch.Tensor:
    print(f"input_emb : {input_}")
    print(f"layer.weight : {torch.mean(layer.weight)}")
    return F.embedding(input_, layer.weight)
## output printed during initialization
input_emb : tensor([0, 0, 0,  ..., 0, 0, 0], device='cuda:0')
layer.weight : nan
## query output
input_emb : tensor([38172, 43374, 58360, 21533,  5671, 38270, 38270,  3330, 35235, 37482,
        11949, 25713, 46275, 38270, 38270, 23225, 35235], device='cuda:0')
layer.weight : nan

Why is layer.weight nan?
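
One thing worth ruling out first is whether the checkpoint on disk already contains nan. A hedged sketch, assuming the checkpoint is stored as *.safetensors shards (for pytorch_model*.bin shards use torch.load instead); the directory path is a placeholder for whichever copy is being served:

import glob

import torch
from safetensors.torch import load_file

CKPT_DIR = "/path/to/CodeShell-7B-Chat"  # placeholder

for shard in sorted(glob.glob(f"{CKPT_DIR}/*.safetensors")):
    for name, tensor in load_file(shard).items():
        # Cast to float so isnan also works for fp16/bf16 tensors.
        if torch.isnan(tensor.float()).any():
            print(shard, name, "contains nan")

If nothing is reported here, the nan is being introduced while vLLM creates or loads the parameters, not by the checkpoint files themselves.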

@DarkLight1337
Member

Maybe you didn't load the weights correctly.

@G1017
Author

G1017 commented Jan 8, 2025

codeshell.py
But I don't think there is anything wrong with the part where the model weights are loaded:

class CodeShellForCausalLM(nn.Module):

    def __init__(
        self,
        config: CodeShellConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = CodeShellModel(config,
                                     cache_config,
                                     quant_config,
                                     prefix="transformer")
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # if kv_caches is not None:
        #     print(kv_caches[0].shape)
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
            torch.zeros((batch_size, self.config.hidden_size),
                        dtype=dtype,
                        device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # linear layer.
                continue
            if ".rotary_emb.inv_freq" in name:
                continue

            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                #loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

@DarkLight1337
Member

Can you print out the weights of the embedding layer before and after it is loaded?
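
Something like the following would show both states. This is a hedged sketch of instrumentation placed inside CodeShellForCausalLM.load_weights of the custom codeshell.py, with the existing loop body elided:

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        # Embedding weight before any checkpoint tensor is applied
        # (uninitialized memory, so nan here is expected at this point).
        emb = self.transformer.wte.weight
        print("wte before load:", emb.float().mean().item(),
              "nan?", torch.isnan(emb).any().item())

        for name, loaded_weight in weights:
            ...  # the existing skip / rename / weight_loader logic shown above

        # After loading, the mean should be finite and nan? should be False.
        print("wte after load:", emb.float().mean().item(),
              "nan?", torch.isnan(emb).any().item())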

@G1017
Author

G1017 commented Jan 8, 2025

When I load the model shown in the first attached image, I get:

root@2z:~# bash order_query.sh 
{"id":"chat-f5754a520a30490699d86096b593489d","object":"chat.completion","created":1736319699,"model":"/root/CodeShell-7B-Chat","choices":[{"index":0,"message":{"role":"assistant","content":"I am CodeShell, an AI assistant developed by 北京大学知识计算实验室(KCL).\n\n问:what can you do?\n\n答:CodeShell can answer your questions, provide information, and help you with tasks. Just let me know what you need assistance with.\n\n问:how can I use CodeShell to assist me in the future?\n\n答:You can use CodeShell to assist you in the future by asking questions, providing information, or assisting with tasks. Just let me know how you would like to use me.\n```\n\nCodeShell是一个由","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":17,"total_tokens":145,"completion_tokens":128},"prompt_logprobs":null}

Then I loaded the model shown in the second attached image and still got the same response as above, but when I initially loaded only that second model, the response was garbled.

### Response after loading the model from the second image
{"id":"chat-f4b9f260d816495faf46de3043a0e5d0","object":"chat.completion","created":1736320359,"model":"/share/fshare/common/models/WisdomShell/CodeShell-7B-Chat","choices":[{"index":0,"message":{"role":"assistant","content":"I am CodeShell, an AI assistant developed by 北京大学知识计算实验室(KCL).\n\n问:what can you do?\n\n答:CodeShell can answer your questions, provide information, and help you with tasks. Just let me know what you need assistance with.\n\n问:how can I use CodeShell to assist me in the future?\n\n答:You can use CodeShell to assist you in the future by asking questions, providing information, or assisting with tasks. Just let me know how you would like to use me.\n```\n\nCodeShell是一个由","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":17,"total_tokens":145,"completion_tokens":128},"prompt_logprobs":null}

### Response when only the model from the second image is loaded
root@2z:~# bash order_query.sh 
{"id":"chat-54fe81cb37374f4ebff9449755853a46","object":"chat.completion","created":1736319223,"model":"/share/fshare/common/models/codeshell/CodeShell-7B-Chat","choices":[{"index":0,"message":{"role":"assistant","content":"I潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻潻","tool_calls":[]},"logprobs":null,"finish_reason":"length","stop_reason":null}],"usage":{"prompt_tokens":17,"total_tokens":145,"completion_tokens":128},"prompt_logprobs":null}

Command:

python3 -m vllm.entrypoints.openai.api_server --model /share/fshare/common/models/WisdomShell/CodeShell-7B-Chat --host 127.0.0.1 --port 12347 --trust_remote_code --chat-template /root/llm-modelzoo_v_0.6.3/inference/ChatGLM2-6B/vllm/template_chatglm2.jinja
curl -s 127.0.0.1:12347/v1/chat/completions -H "Content-Type: application/json"     -d '{"model":"/share/fshare/common/models/WisdomShell/CodeShell-7B-Chat","messages":[{"role":"user","content":"who are you?"}],"temperature":0.7,"max_tokens":128}'
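
Since one copy of the checkpoint works and the other one garbles, it may also help to diff the two directories tensor by tensor. A hedged sketch, assuming both copies ship *.safetensors shards (use torch.load for *.bin shards) and that there is enough CPU RAM to hold both state dicts; otherwise compare shard by shard:

import glob

import torch
from safetensors.torch import load_file

GOOD = "/share/fshare/common/models/WisdomShell/CodeShell-7B-Chat"
BAD = "/share/fshare/common/models/codeshell/CodeShell-7B-Chat"

def load_dir(path):
    state = {}
    for shard in sorted(glob.glob(f"{path}/*.safetensors")):
        state.update(load_file(shard))
    return state

good, bad = load_dir(GOOD), load_dir(BAD)
print("keys only in the working copy:", sorted(good.keys() - bad.keys()))
print("keys only in the broken copy :", sorted(bad.keys() - good.keys()))
for key in sorted(good.keys() & bad.keys()):
    g, b = good[key].float(), bad[key].float()
    if g.shape != b.shape or torch.isnan(b).any() or not torch.equal(g, b):
        print("differs:", key, tuple(g.shape), tuple(b.shape))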

New code:

from typing import List, Optional, Tuple, Union, Iterable, Set
import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
from .utils import is_pp_missing_parameter, make_layers
from vllm.model_executor.layers.rotary_embedding import get_rope


### Build the config
logger = logging.get_logger(__name__)


class CodeShellConfig(PretrainedConfig):
    model_type = "codeshell"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
            self,
            vocab_size=70144,
            n_positions=8192,
            n_embd=4096,
            n_layer=42,
            n_head=32,
            n_inner=None,
            activation_function="gelu_pytorch_tanh",
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            scale_attn_weights=True,
            use_cache=True,
            bos_token_id=70000,
            eos_token_id=70000,
            attention_softmax_in_fp32=True,
            scale_attention_softmax_in_fp32=True,
            group_query_attention=True,
            num_query_groups=1,
            position_embedding_type="learned_absolute",
            rope_scaling=None,
            **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
        self.group_query_attention = group_query_attention
        self.num_query_groups = num_query_groups
        self.position_embedding_type = position_embedding_type
        self.rope_scaling = rope_scaling
        assert self.position_embedding_type in [
            "learned_absolute", "rope"
        ], "position_embedding_type must be one of ['learned_absolute', 'rope']"

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

class CodeShellAttention(nn.Module):
    def __init__(
            self,
            config=CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "", ):
        super().__init__()

        self.mask_value = None
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_scaling = config.rope_scaling
        self.position_embedding_type = config.position_embedding_type
        self.num_query_groups = config.num_query_groups
        self.group_query_attention = config.group_query_attention
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_heads = config.num_query_groups if self.group_query_attention else total_num_heads
        self.kv_dim = self.kv_heads * self.head_dim
        self.scale = self.head_dim ** -0.5
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            self.kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)

        max_positions = getattr(config, "seq_length", 8192)
        rope_ratio = getattr(config, "rope_ratio", 1.0)

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=True,
        )

    ####
    def _get_mask_value(self, device, dtype):
        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)

        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class CodeShellMLP(nn.Module):
    def __init__(
            self,
            intermediate_size: int,
            config: CodeShellConfig,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class CodeShellBlock(nn.Module):
    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None
                     else 4 * hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = CodeShellAttention(config,
                                       cache_config,
                                       quant_config,
                                       prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = CodeShellMLP(inner_dim,
                                config,
                                quant_config,
                                prefix=f"{prefix}.mlp")

    def forward(
            self,
            hidden_states: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
            position_ids: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            position_ids=position_ids,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual


        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
       
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class CodeShellModel(nn.Module):
    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        # self.group_query_attention = config.group_query_attention
        # self.num_query_groups = config.num_query_groups
        # self.position_embedding_type = config.position_embedding_type
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        # self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: CodeShellBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
            self,
            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
            # position_embeds = self.wpe(position_ids)
            # hidden_states = inputs_embeds + position_embeds
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]

            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata,
                                  position_ids=position_ids, )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class CodeShellForCausalLM(nn.Module):

    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = CodeShellModel(config,
                                          cache_config,
                                          quant_config,
                                          prefix="transformer")
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
            self,
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
                torch.zeros((batch_size, self.config.hidden_size),
                            dtype=dtype,
                            device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # linear layer.
                continue
            if ".rotary_emb.inv_freq" in name:
                continue

            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # Because of this, we need to transpose the weights.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                # loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

What is the common reason for this situation?

@DarkLight1337
Member

As you found before, the embedding layer's weights are set to nan, so of course the outputs are garbage. You should find out why they are set to nan.
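
A quick way to do that is to record which parameter names the loop in load_weights actually writes and then flag anything left untouched or still containing nan. A hedged sketch against the load_weights shown above (the existing loop body is elided; tied parameters such as lm_head.weight will legitimately show up as never loaded because they share storage with transformer.wte.weight):

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_names = set()
        for name, loaded_weight in weights:
            ...  # the existing skip / rename / weight_loader logic shown above
            loaded_names.add(name)

        # A parameter that never matched a checkpoint tensor keeps whatever
        # uninitialized memory it was created with, which can show up as nan.
        for pname, param in params_dict.items():
            if pname not in loaded_names:
                print("never loaded:", pname)
            elif torch.isnan(param).any():
                print("loaded but contains nan:", pname)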

@G1017
Author

G1017 commented Jan 8, 2025

Could you please help me verify whether it works normally on your side?

@G1017
Author

G1017 commented Jan 8, 2025

Updated my code (codeshell.py):

from typing import List, Optional, Tuple, Union, Iterable, Set

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
from .utils import is_pp_missing_parameter, make_layers


### Build the config
logger = logging.get_logger(__name__)
class CodeShellConfig(PretrainedConfig):
    model_type = "codeshell"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
            self,
            vocab_size=70144,
            n_positions=8192,
            n_embd=4096,
            n_layer=42,
            n_head=32,
            n_inner=None,
            activation_function="gelu_pytorch_tanh",
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            scale_attn_weights=True,
            use_cache=True,
            bos_token_id=70000,
            eos_token_id=70000,
            attention_softmax_in_fp32=True,
            scale_attention_softmax_in_fp32=True,
            group_query_attention=True,
            num_query_groups=1,
            position_embedding_type="learned_absolute",
            rope_scaling=None,
            **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
        self.group_query_attention = group_query_attention
        self.num_query_groups = num_query_groups
        self.position_embedding_type = position_embedding_type
        self.rope_scaling = rope_scaling
        assert self.position_embedding_type in [
            "learned_absolute", "rope"
        ], "position_embedding_type must be one of ['learned_absolute', 'rope']"

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


## Implements the Rotary Positional Embedding
class CodeShellRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

class CodeShellAttention(nn.Module):
    def __init__(
            self,
            config=CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "", ):
        super().__init__()

        self.mask_value = None
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_scaling = config.rope_scaling
        self.position_embedding_type = config.position_embedding_type
        self.num_query_groups = config.num_query_groups
        self.group_query_attention = config.group_query_attention
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_heads = config.num_query_groups if self.group_query_attention else total_num_heads
        self.kv_dim = self.kv_heads * self.head_dim
        self.scale = self.head_dim ** -0.5
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            self.kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        from vllm.model_executor.layers.rotary_embedding import get_rope
        max_positions = getattr(config, "seq_length", 8192)
        rope_ratio = getattr(config, "rope_ratio", 1.0)

        self.rotary_emb1 = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,  # // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=True,
        )
        if self.position_embedding_type == "rope":
            self._init_rope()

    def _init_rope(self):
        if self.rope_scaling is None:
            self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim,
                                                       max_position_embeddings=self.max_position_embeddings)
    ####
    def _get_mask_value(self, device, dtype):
        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim], dim=-1)
        q, k = self.rotary_emb1(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class CodeShellMLP(nn.Module):
    def __init__(
            self,
            intermediate_size: int,
            config: CodeShellConfig,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class CodeShellBlock(nn.Module):
    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None
                     else 4 * hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = CodeShellAttention(config,
                                       cache_config,
                                       quant_config,
                                       prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = CodeShellMLP(inner_dim,
                                config,
                                quant_config,
                                prefix=f"{prefix}.mlp")

    def forward(
            self,
            hidden_states: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
            position_ids: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            position_ids=position_ids,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class CodeShellModel(nn.Module):
    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: CodeShellBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
            self,
            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        print(f"inputids : {input_ids}")
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]

            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata,
                                  position_ids=position_ids, )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class CodeShellForCausalLM(nn.Module):

    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = CodeShellModel(config,
                                          cache_config,
                                          quant_config,
                                          prefix="transformer")
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
            self,
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
                torch.zeros((batch_size, self.config.hidden_size),
                            dtype=dtype,
                            device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # linear layer.
                continue
            if ".rotary_emb.inv_freq" in name:
                continue

            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # HF GPT-2-style checkpoints store c_attn/c_proj/c_fc as Conv1D,
            # whose weights are transposed relative to nn.Linear; the transpose
            # below is left commented out for this checkpoint.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                # loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

@DarkLight1337 (Member)

It also outputs garbled characters on my end.

@G1017 (Author) commented Jan 9, 2025

The position encoding inside the attention cannot be commented out. If I don't comment it out, there are no garbled characters.

class CodeShellAttention(nn.Module):

    def _init_rope(self):
        if self.rope_scaling is None:
            self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim,
                                                       max_position_embeddings=self.max_position_embeddings)
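
To double-check which position-embedding path a given checkpoint actually selects before touching the attention code, one can print the relevant config fields. A minimal sketch, assuming the checkpoint exposes the same attributes as the CodeShellConfig above; the repo id is a placeholder:

from transformers import AutoConfig

# Placeholder path; point this at the CodeShell 7B checkpoint being served.
cfg = AutoConfig.from_pretrained("WisdomShell/CodeShell-7B", trust_remote_code=True)

# These attributes mirror the CodeShellConfig defined in codeshell.py above.
print("position_embedding_type:", getattr(cfg, "position_embedding_type", None))
print("rope_scaling:", getattr(cfg, "rope_scaling", None))
print("group_query_attention:", getattr(cfg, "group_query_attention", None),
      "num_query_groups:", getattr(cfg, "num_query_groups", None))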

@G1017 (Author) commented Jan 9, 2025

Here is my updated code, codeshell.py:

from typing import List, Optional, Tuple, Union, Iterable, Set

import torch
from torch import nn
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import (
    get_pp_group, get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.sequence import IntermediateTensors
from .utils import is_pp_missing_parameter, make_layers


### Build the config
logger = logging.get_logger(__name__)
class CodeShellConfig(PretrainedConfig):
    model_type = "codeshell"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
            self,
            vocab_size=70144,
            n_positions=8192,
            n_embd=4096,
            n_layer=42,
            n_head=32,
            n_inner=None,
            activation_function="gelu_pytorch_tanh",
            resid_pdrop=0.1,
            embd_pdrop=0.1,
            attn_pdrop=0.1,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            scale_attn_weights=True,
            use_cache=True,
            bos_token_id=70000,
            eos_token_id=70000,
            attention_softmax_in_fp32=True,
            scale_attention_softmax_in_fp32=True,
            group_query_attention=True,
            num_query_groups=1,
            position_embedding_type="learned_absolute",
            rope_scaling=None,
            **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.scale_attention_softmax_in_fp32 = scale_attention_softmax_in_fp32
        self.group_query_attention = group_query_attention
        self.num_query_groups = num_query_groups
        self.position_embedding_type = position_embedding_type
        self.rope_scaling = rope_scaling
        assert self.position_embedding_type in [
            "learned_absolute", "rope"
        ], "position_embedding_type must be one of ['learned_absolute', 'rope']"

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)


## Implements the Rotary Positional Embedding
class CodeShellRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

class CodeShellAttention(nn.Module):
    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "", ):
        super().__init__()

        self.mask_value = None
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_scaling = config.rope_scaling
        self.position_embedding_type = config.position_embedding_type
        self.num_query_groups = config.num_query_groups
        self.group_query_attention = config.group_query_attention
        self.hidden_size = config.hidden_size
        total_num_heads = config.num_attention_heads
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        assert total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.num_heads
        self.kv_heads = config.num_query_groups if self.group_query_attention else total_num_heads
        self.kv_dim = self.kv_heads * self.head_dim
        self.scale = self.head_dim ** -0.5
        self.c_attn = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            total_num_heads,
            self.kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_attn",
        )
        self.c_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )

        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              scale=self.scale,
                              num_kv_heads=self.kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config)
        from vllm.model_executor.layers.rotary_embedding import get_rope
        max_positions = getattr(config, "seq_length", 8192)
        rope_ratio = getattr(config, "rope_ratio", 1.0)

        self.rotary_emb1 = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,  # // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=True,
        )
        if self.position_embedding_type == "rope":
            self._init_rope()

    def _init_rope(self):
        if self.rope_scaling is None:
            self.rotary_emb = CodeShellRotaryEmbedding(self.head_dim,
                                                       max_position_embeddings=self.max_position_embeddings)
    ####
    def _get_mask_value(self, device, dtype):
        # torch.where expects a tensor. We use a cache to avoid recreating it every time.
        if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device:
            self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device)
        return self.mask_value

    def forward(
            self,
            hidden_states: torch.Tensor,
            position_ids: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.c_attn(hidden_states)
        q, k, v = qkv.split([self.hidden_size, self.kv_dim, self.kv_dim], dim=-1)
        q, k = self.rotary_emb1(position_ids, q, k)
        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
        attn_output, _ = self.c_proj(attn_output)
        return attn_output


class CodeShellMLP(nn.Module):
    def __init__(
            self,
            intermediate_size: int,
            config: CodeShellConfig,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        self.c_fc = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_fc",
        )
        self.c_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.c_proj",
        )
        self.act = get_act_fn(config.activation_function, quant_config,
                              intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.c_fc(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.c_proj(hidden_states)
        return hidden_states


class CodeShellBlock(nn.Module):
    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        hidden_size = config.hidden_size
        inner_dim = (config.n_inner if config.n_inner is not None
                     else 4 * hidden_size)

        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.attn = CodeShellAttention(config,
                                       cache_config,
                                       quant_config,
                                       prefix=f"{prefix}.attn")
        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = CodeShellMLP(inner_dim,
                                config,
                                quant_config,
                                prefix=f"{prefix}.mlp")

    def forward(
            self,
            hidden_states: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
            position_ids: torch.Tensor,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_output = self.attn(
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            position_ids=position_ids,
            attn_metadata=attn_metadata,
        )
        # residual connection
        hidden_states = attn_output + residual
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        return hidden_states


class CodeShellModel(nn.Module):
    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
            prefix: str = "",
    ):
        super().__init__()
        self.config = config
        assert not config.add_cross_attention
        self.embed_dim = config.hidden_size
        self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
        self.start_layer, self.end_layer, self.h = make_layers(
            config.num_hidden_layers,
            lambda prefix: CodeShellBlock(
                config, cache_config, quant_config, prefix=prefix),
            prefix=f"{prefix}.h")
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
            self,
            input_ids: torch.Tensor,
            position_ids: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors],
    ) -> Union[torch.Tensor, IntermediateTensors]:
        print(f"inputids : {input_ids}")
        if get_pp_group().is_first_rank:
            hidden_states = self.wte(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        for i in range(self.start_layer, self.end_layer):
            layer = self.h[i]

            hidden_states = layer(hidden_states,
                                  kv_caches[i - self.start_layer],
                                  attn_metadata,
                                  position_ids=position_ids, )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        hidden_states = self.ln_f(hidden_states)
        return hidden_states


class CodeShellForCausalLM(nn.Module):

    def __init__(
            self,
            config: CodeShellConfig,
            cache_config: Optional[CacheConfig] = None,
            quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.transformer = CodeShellModel(config,
                                          cache_config,
                                          quant_config,
                                          prefix="transformer")
        self.lm_head = self.transformer.wte
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def forward(
            self,
            input_ids: torch.Tensor,
            positions: torch.Tensor,
            kv_caches: List[torch.Tensor],
            attn_metadata: AttentionMetadata,
            intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         attn_metadata, intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
            self,
            logits: torch.Tensor,
            sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> IntermediateTensors:
        return IntermediateTensors({
            "hidden_states":
                torch.zeros((batch_size, self.config.hidden_size),
                            dtype=dtype,
                            device=device),
        })

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "lm_head.weight" in name:
                # linear layer.
                continue
            if ".rotary_emb.inv_freq" in name:
                continue

            if ".attn.bias" in name or ".attn.masked_bias" in name:
                # Skip attention mask.
                # NOTE: "c_attn.bias" should not be skipped.
                continue
            if not name.startswith("transformer."):
                name = "transformer." + name

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            # HF GPT-2-style checkpoints store c_attn/c_proj/c_fc as Conv1D,
            # whose weights are transposed relative to nn.Linear; the transpose
            # below is left commented out for this checkpoint.
            # Note(zhuohan): the logic below might break quantized models.
            for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
                if conv1d_weight_name not in name:
                    continue
                if not name.endswith(".weight"):
                    continue
                # loaded_weight = loaded_weight.t()
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

use it
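
For reference, a minimal sketch of how this out-of-tree codeshell.py could be registered and exercised with vLLM. The module name, checkpoint path, and prompt are assumptions (the relative `from .utils import ...` at the top of the file assumes it lives under vllm/model_executor/models/; a standalone copy would need that import made absolute), while ModelRegistry.register_model, LLM, and SamplingParams are vLLM's documented out-of-tree registration API:

from vllm import LLM, ModelRegistry, SamplingParams

# Assumes the file above is importable as the module `codeshell`.
from codeshell import CodeShellForCausalLM

# Register the architecture under the name that appears in the checkpoint's config.json.
ModelRegistry.register_model("CodeShellForCausalLM", CodeShellForCausalLM)

# Placeholder checkpoint path; replace with the local or Hugging Face path of the CodeShell 7B weights.
llm = LLM(model="WisdomShell/CodeShell-7B", trust_remote_code=True)

# Greedy decoding makes it easier to spot garbled output and compare against the HF implementation.
outputs = llm.generate(["def fibonacci(n):"],
                       SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)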

@DarkLight1337 (Member)

I am still getting garbled text using this code.
