
[models] Microsoft Phi 1.5 #1664

Merged
merged 40 commits into main from phi-1.5 on Nov 16, 2023
Changes from 12 commits (40 commits total)
e7d81f5
Modify the tests and basic config of phi model
Nov 13, 2023
6f3b8e9
Minor update in model configuration
Nov 13, 2023
7a38f33
Fix weight
Nov 14, 2023
2343963
Minor name fix
Nov 14, 2023
6cd07e9
Fix mapping
Nov 14, 2023
7781e60
Fix model names
Nov 14, 2023
2fec6af
Modify tests
Nov 14, 2023
70c8e03
Update requirements
Nov 14, 2023
505b5f3
Fix codestyle
Nov 14, 2023
0f730b8
Fix rotary emb
Nov 14, 2023
1f664ec
Fix codestyle
Nov 14, 2023
be4b546
Fix codestyle in tests
Nov 14, 2023
eeb4da4
Update requirements.txt
maximzubkov Nov 15, 2023
0e408cf
Update test_models.py
maximzubkov Nov 15, 2023
4856e5d
Reorder imports
Nov 15, 2023
f6942b5
Fix codestyle
Nov 15, 2023
6b7db71
Update vllm/model_executor/models/phi_1_5.py
maximzubkov Nov 15, 2023
8f97e39
Update vllm/model_executor/models/phi_1_5.py
maximzubkov Nov 15, 2023
de30260
Update vllm/model_executor/models/phi_1_5.py
maximzubkov Nov 15, 2023
bb1455c
Use self instead of config
Nov 15, 2023
6f7389b
Remove comment about dropout
Nov 15, 2023
4db31e4
Update vllm/model_executor/models/phi_1_5.py
maximzubkov Nov 15, 2023
029d77e
License
Nov 15, 2023
ee3419b
Readme minor fix
Nov 15, 2023
387f970
Original way of defining the model
Nov 15, 2023
fccc70b
Fix bugs
Nov 15, 2023
c098c1d
Fix _column_parallel_weights
Nov 15, 2023
7ca6b73
Fix codestyle
Nov 15, 2023
09d8234
Fix parallel weight
Nov 15, 2023
34d9adc
Merge branch 'main' into phi-1.5
Nov 16, 2023
ab21cd9
Update after merge
Nov 16, 2023
ed1f59b
Fix codestyle
Nov 16, 2023
7b802b5
Fix bias
Nov 16, 2023
341ce91
Minor fix
Nov 16, 2023
f262f28
Codestyle
Nov 16, 2023
675130f
Remove inv_freq upd
Nov 16, 2023
654322f
Update phi model to match the most recent configuration
Nov 16, 2023
de09db6
Fix order of imports and rename
Nov 16, 2023
b2d4587
Update vllm/model_executor/models/phi_1_5.py
maximzubkov Nov 16, 2023
8f2610d
Update vllm/model_executor/models/phi_1_5.py
maximzubkov Nov 16, 2023
1 change: 1 addition & 0 deletions README.md
@@ -59,6 +59,7 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Phi-1.5 (`microsoft/phi-1_5`, `microsoft/phi-1`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
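For reference, the model added to this list can be exercised with vLLM's standard offline API once the PR is merged. The snippet below is a minimal sketch, not part of the diff; the prompt and sampling settings are illustrative, and trust_remote_code is assumed because the checkpoint ships custom modeling code.

from vllm import LLM, SamplingParams

# Minimal sketch: load the Phi-1.5 checkpoint added to the supported-model list.
llm = LLM(model="microsoft/phi-1_5", trust_remote_code=True)
params = SamplingParams(temperature=0.8, max_tokens=64)
outputs = llm.generate(["def quicksort(arr):"], params)
print(outputs[0].outputs[0].text)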
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,6 +5,7 @@ pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
einops # Required for phi-1_5
torch >= 2.1.0
transformers >= 4.34.0 # Required for Mistral.
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
15 changes: 5 additions & 10 deletions tests/models/test_models.py
@@ -5,16 +5,11 @@
import pytest

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1",
"tiiuae/falcon-7b",
"gpt2",
"bigcode/tiny_starcoder_py",
"EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m",
"bigscience/bloom-560m",
"mosaicml/mpt-7b",
"facebook/opt-125m", "meta-llama/Llama-2-7b-hf",
"mistralai/Mistral-7B-v0.1", "tiiuae/falcon-7b", "gpt2",
"bigcode/tiny_starcoder_py", "EleutherAI/gpt-j-6b",
"EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b",
"microsoft/phi-1_5"
]


1 change: 1 addition & 0 deletions vllm/model_executor/model_loader.py
@@ -33,6 +33,7 @@
"MPTForCausalLM": MptForCausalLM,
"OPTForCausalLM": OPTForCausalLM,
"QWenLMHeadModel": QWenLMHeadModel,
"MixFormerSequentialForCausalLM": PhiForCausalLM,
"RWForCausalLM": FalconForCausalLM,
"YiForCausalLM": YiForCausalLM,
}
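For readers unfamiliar with this table: vLLM resolves the architectures field of the Hugging Face config against this dictionary, and phi-1_5 reports MixFormerSequentialForCausalLM there, which is why the new entry routes it to PhiForCausalLM. The sketch below is a simplified stand-in for that lookup, not vLLM's exact code; the registry contents and helper name are illustrative.

from transformers import AutoConfig

# Illustrative subset of the architecture registry.
_MODEL_REGISTRY = {
    "MixFormerSequentialForCausalLM": "PhiForCausalLM",
    "OPTForCausalLM": "OPTForCausalLM",
}

def resolve_model_class(model_name: str) -> str:
    # Read config.architectures and pick the first registered implementation.
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    for arch in getattr(config, "architectures", None) or []:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported architectures: {config.architectures}")

# resolve_model_class("microsoft/phi-1_5")  # -> "PhiForCausalLM"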
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
@@ -13,6 +13,7 @@
from vllm.model_executor.models.mpt import MptForCausalLM
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.models.qwen import QWenLMHeadModel
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
from vllm.model_executor.models.yi import YiForCausalLM

@@ -32,6 +33,7 @@
"MptForCausalLM",
"OPTForCausalLM",
"QWenLMHeadModel",
"PhiForCausalLM",
"MistralForCausalLM",
"YiForCausalLM",
]
274 changes: 274 additions & 0 deletions vllm/model_executor/models/phi_1_5.py
@@ -0,0 +1,274 @@
# coding=utf-8
# Adapted from
# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_mixformer_sequential.py
# Copyright 2023 The vLLM team.
# Copyright 2023 Microsoft and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Phi-1.5 model compatible with HuggingFace weights.

The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class PhiAttention(nn.Module):

def __init__(self, config: PretrainedConfig):
super().__init__()
self.total_num_heads = config.n_head
self.hidden_size = config.n_embd
self.head_size = self.hidden_size // self.total_num_heads

tensor_model_parallel_world_size = (
get_tensor_model_parallel_world_size())
assert self.total_num_heads % tensor_model_parallel_world_size == 0
self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size)

# pylint: disable=C0103
self.Wqkv = ColumnParallelLinear(
config.n_embd,
3 * config.n_embd,
gather_output=False,
)
self.out_proj = RowParallelLinear(
config.n_embd,
config.n_embd,
input_is_parallel=True,
)

scaling = self.head_size**-0.5
rotary_dim = config.rotary_dim
assert rotary_dim % 2 == 0

# pylint: disable=C0301
# See https://huggingface.co/microsoft/phi-1_5/blob/92557d03bb12543040c8bb5f0475cbdd9968f05f/modeling_mixformer_sequential.py#L222
rope_theta = 10000
max_position_embeddings = getattr(config, "n_positions", 2048)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_size,
scaling,
rotary_dim,
base=rope_theta,
max_position=max_position_embeddings)

def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
k_cache, v_cache = kv_cache
attn_output = self.attn(position_ids, q, k, v, k_cache, v_cache,
input_metadata, cache_event)
output, _ = self.out_proj(attn_output)
return output


class PhiMLP(nn.Module):

def __init__(self, config: PretrainedConfig):
super().__init__()

n_inner = getattr(config, "n_inner", None)
n_inner = n_inner if n_inner is not None else 4 * config.n_embd

self.fc1 = ColumnParallelLinear(
config.n_embd,
n_inner,
gather_output=False,
)
self.fc2 = RowParallelLinear(
n_inner,
config.n_embd,
input_is_parallel=True,
)
self.act = get_act_fn(config.activation_function)

def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.fc2(hidden_states)
return hidden_states


class PhiLayer(nn.Module):

def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.mixer = PhiAttention(config)
self.mlp = PhiMLP(config)

def forward(
self,
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
# pylint: disable=C0301
# Dropout 0.0 https://huggingface.co/microsoft/phi-1_5/blob/92557d03bb12543040c8bb5f0475cbdd9968f05f/modeling_mixformer_sequential.py#L696
feed_forward_hidden_states = self.mlp(hidden_states)
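# Parallel (GPT-J-style) block: attention and MLP both consume the same
# layer-normed input, and their outputs are summed with the residual below.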
hidden_states = attn_outputs + feed_forward_hidden_states + residual
return hidden_states


class PhiModel(nn.Module):

def __init__(self, config: PretrainedConfig):
super().__init__()
self.config = config

self.wte = VocabParallelEmbedding(
config.vocab_size,
config.n_embd,
)
self.layers = nn.ModuleList(
[PhiLayer(config) for _ in range(config.n_layer)])

def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
position_ids,
hidden_states,
kv_caches[i],
input_metadata,
cache_event,
)
return hidden_states


class PhiForCausalLM(nn.Module):

def __init__(self, config):
super().__init__()
self.config = config
self.phi = PhiModel(config)
self.ln = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.linear = ColumnParallelLinear(
config.n_embd,
config.vocab_size,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
cache_events: Optional[List[torch.cuda.Event]],
) -> SamplerOutput:
hidden_states = self.phi(input_ids, positions, kv_caches,
input_metadata, cache_events)
hidden_states = self.ln(hidden_states)
next_tokens = self.sampler(self.linear.weight, hidden_states,
input_metadata, self.linear.bias)
return next_tokens

_column_parallel_weights = [
"embed_in.weight", "embed_out.weight", "embed_out.bias", "fc1.weight",
"fc1.bias"
]
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]

def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
state_dict = self.state_dict()
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
Collaborator:
BTW, why is the inv_freq here different from the one in our RotaryEmbedding? According to this line, the code looks the same.

Contributor (author):
Yeah, I also thought so until I checked the weight:

  • inv_freq from RotaryEmbedding
tensor([1.0000e+00, 5.6234e-01, 3.1623e-01, 1.7783e-01, 1.0000e-01, 5.6234e-02,
        3.1623e-02, 1.7783e-02, 1.0000e-02, 5.6234e-03, 3.1623e-03, 1.7783e-03,
        1.0000e-03, 5.6234e-04, 3.1623e-04, 1.7783e-04], device='cuda:0',
       dtype=torch.float32)
  • loaded_weight of rotary_emb.inv_freq
tensor([1.0000e+00, 5.6250e-01, 3.1616e-01, 1.7786e-01, 9.9976e-02, 5.6244e-02,
        3.1616e-02, 1.7776e-02, 1.0002e-02, 5.6229e-03, 3.1624e-03, 1.7786e-03,
        1.0004e-03, 5.6219e-04, 3.1614e-04, 1.7786e-04])

But even more surprisingly, after I copied the cached weight, the output on Test 4 was still different. However, @Linzecong suggests it might be an fp16 issue (see the linked comment).

Collaborator:
After some investigations, I found that the issue is because inv_freq is stored in FP16 in the weight checkpoint (I guess it is calculated in FP32 and converted to FP16 for some reason). As we use the same logic to calculate inv_freq, I think this slight difference should be acceptable.

Can we remove the special weight loading logic for inv_freq and use our current implementation instead?
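As a quick illustration of the FP16 explanation above (not part of the PR): recomputing inv_freq in FP32 and round-tripping it through FP16 reproduces the checkpoint values quoted in this thread. rotary_dim = 32 is assumed here, matching the 16 entries shown.

import torch

# Assumed rotary_dim = 32, which yields the 16 inv_freq entries quoted above.
rotary_dim, base = 32, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
print(inv_freq[:4])                                      # 1.0000, 0.5623, 0.3162, 0.1778
print(inv_freq.to(torch.float16).to(torch.float32)[:4])  # 1.0000, 0.5625, 0.3162, 0.1779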

t = torch.arange(self.config.n_positions, dtype=torch.float32)

freqs = torch.einsum("i,j -> ij", t, loaded_weight)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)

for i in range(len(self.phi.layers)):
self.phi.layers[
i].mixer.attn.rotary_emb.cos_sin_cache.copy_(cache)
continue
_, layer_idx, *tail = name.split(".")
tail = ".".join(tail)
layer_idx = int(layer_idx)

# Entry 0 is the embedding and entry n_layer + 1 is the CausalLMHead,
# so the transformer blocks in between are shifted down by one.
if layer_idx == 0:
key = f"phi.{tail}"
elif layer_idx == self.config.n_layer + 1:
key = tail
else:
key = f"phi.layers.{layer_idx - 1}.{tail}"

# pylint: disable=E1136
param = state_dict[key]
load_tensor_parallel_weights(param, loaded_weight, name,
self._column_parallel_weights,
self._row_parallel_weights,
tensor_model_parallel_rank)
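For reference, the index arithmetic in load_weights reshuffles the sequential MixFormer checkpoint names into this module's names. The mapping below is an illustrative example, assuming a 24-layer config (so the CausalLMHead sits at index 25) and the usual MixFormerSequential key layout.

# Illustrative key remapping (24 transformer layers assumed):
#   layers.0.wte.weight         -> phi.wte.weight                   (embedding)
#   layers.1.mixer.Wqkv.weight  -> phi.layers.0.mixer.Wqkv.weight   (first block)
#   layers.24.mlp.fc2.bias      -> phi.layers.23.mlp.fc2.bias       (last block)
#   layers.25.ln.weight         -> ln.weight                        (final LayerNorm)
#   layers.25.linear.weight     -> linear.weight                    (lm head)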