This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

support phi models
mikecovlee committed Jul 25, 2024
1 parent 49ea13d commit dd06560
Showing 2 changed files with 42 additions and 34 deletions.
2 changes: 1 addition & 1 deletion mlora/common/attention.py
@@ -83,7 +83,7 @@ def eager_attention_forward(
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
attention_score = attention_score + causal_mask
attention_score = F.softmax(attention_score, dim=-1, dtype=torch.float32).to(
- query_states.dtype
+ value_states.dtype
)
attention_score = torch.matmul(attention_score, value_states)
attention_score = attention_score.transpose(1, 2).contiguous()
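Note on the attention.py change: the attention scores are still computed with a float32 softmax, but they are now cast back to the dtype of the value tensor rather than the query tensor before the weighted sum (in practice the two dtypes usually coincide). Below is a minimal sketch of an eager attention step in this style, with assumed tensor shapes noted in the comments; it is an illustration, not the repository's full eager_attention_forward:

```python
import math
from typing import Optional

import torch
import torch.nn.functional as F


def eager_attention_sketch(
    query_states: torch.Tensor,  # (batch, n_heads, q_len, head_dim)
    key_states: torch.Tensor,    # (batch, n_heads, kv_len, head_dim)
    value_states: torch.Tensor,  # (batch, n_heads, kv_len, head_dim)
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    head_dim = query_states.size(-1)
    # raw scaled dot-product scores
    attention_score = torch.matmul(
        query_states, key_states.transpose(2, 3)
    ) / math.sqrt(head_dim)
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attention_score = attention_score + causal_mask
    # softmax in float32 for numerical stability, then cast to the value
    # tensor's dtype so the weighted sum runs in the model's working dtype
    attention_score = F.softmax(attention_score, dim=-1, dtype=torch.float32).to(
        value_states.dtype
    )
    attention_score = torch.matmul(attention_score, value_states)
    # back to (batch, q_len, n_heads, head_dim) for the output projection
    return attention_score.transpose(1, 2).contiguous()
```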
74 changes: 41 additions & 33 deletions mlora/models/modeling_phi.py
@@ -34,10 +34,10 @@

@dataclass
class PhiConfig(LLMModelConfig):
partial_rotary_factor_: float = 0.5
layer_norm_eps_: float = 1e-05
resid_pdrop_: float = 0.0
embd_pdrop_: float = 0.0
rotary_emb_dim_: int = 0
qk_layernorm_: bool = False


@@ -76,41 +76,35 @@ def __init__(
v_proj: nn.Module,
dense: nn.Module,
idx: int,
- args: PhiConfig,
+ config: PhiConfig,
):
super().__init__()
# attention
- self.wq_: Linear = Linear(q_proj, args.device_)
- self.wk_: Linear = Linear(k_proj, args.device_)
- self.wv_: Linear = Linear(v_proj, args.device_)
- self.dense_: Linear = Linear(dense, args.device_)
+ self.wq_: Linear = Linear(q_proj, config.device_)
+ self.wk_: Linear = Linear(k_proj, config.device_)
+ self.wv_: Linear = Linear(v_proj, config.device_)
+ self.dense_: Linear = Linear(dense, config.device_)
# config
self.layer_idx_ = idx
- self.dim_ = args.dim_
- self.n_heads_ = args.n_heads_
- self.n_kv_heads_ = args.n_kv_heads_
+ self.dim_ = config.dim_
+ self.n_heads_ = config.n_heads_
+ self.n_kv_heads_ = config.n_kv_heads_
self.n_rep_ = self.n_heads_ // self.n_kv_heads_
- self.head_dim_ = args.dim_ // args.n_heads_
- self.dtype_ = args.dtype_
+ self.rotary_emb_dim_ = config.rotary_emb_dim_
+ self.head_dim_ = config.head_dim_
+ self.dtype_ = config.dtype_
self.is_causal_ = True
- # cos and sin
- self.rotary_emb_ = PhiRotaryEmbedding(
- int(args.partial_rotary_factor_ * self.head_dim_),
- max_position_embeddings=args.max_seq_len_,
- base=args.rope_theta_,
- device=args.device_,
- )
# qk norm
- self.qk_layernorm_: bool = args.qk_layernorm_
+ self.qk_layernorm_: bool = config.qk_layernorm_
if self.qk_layernorm_:
self.q_layernorm_ = nn.LayerNorm(
self.hidden_size_ // self.num_heads_,
- eps=args.norm_eps_,
+ eps=config.norm_eps_,
elementwise_affine=True,
)
self.k_layernorm_ = nn.LayerNorm(
self.hidden_size_ // self.num_heads_,
- eps=args.norm_eps_,
+ eps=config.norm_eps_,
elementwise_affine=True,
)
else:
@@ -129,6 +123,7 @@ def forward(
self,
hidden_states: torch.Tensor,
input_args: LLMModelInput,
+ rotary_emb: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
@@ -153,16 +148,13 @@ def forward(
batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
).transpose(1, 2)

- kv_seq_len = xk.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx_)
- cos, sin = self.rotary_emb_(xv, seq_len=kv_seq_len)
+ cos, sin = rotary_emb

# partial rotary embedding
xq, xk = apply_partial_rotary_emb(
xq,
xk,
- self.rotary_emb_.dim,
+ self.rotary_emb_dim_,
cos,
sin,
cache_position.unsqueeze(0),
@@ -172,7 +164,7 @@
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": self.rotary_emb_.dim,
"partial_rotation_size": self.rotary_emb_dim_,
"cache_position": cache_position,
}
xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)
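Note: apply_partial_rotary_emb itself is not modified by this commit; the sketch below only illustrates the usual partial-RoPE convention that the call above presumably follows, where just the first rotary_emb_dim_ channels of each query/key head are rotated by the cos/sin tables and the remaining channels pass through unchanged. The function name, the rotate_half helper, and the cos/sin indexing are illustrative assumptions, not the repository's exact API:

```python
from typing import Tuple

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # standard RoPE helper: (x1, x2) -> (-x2, x1) over the last dimension
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def partial_rotary_sketch(
    xq: torch.Tensor,            # (batch, n_heads, seq_len, head_dim)
    xk: torch.Tensor,            # (batch, n_kv_heads, seq_len, head_dim)
    rotary_emb_dim: int,         # e.g. int(partial_rotary_factor * head_dim)
    cos: torch.Tensor,           # (max_seq_len, rotary_emb_dim)
    sin: torch.Tensor,           # (max_seq_len, rotary_emb_dim)
    position_ids: torch.Tensor,  # (1, seq_len), e.g. cache_position.unsqueeze(0)
) -> Tuple[torch.Tensor, torch.Tensor]:
    cos = cos[position_ids].unsqueeze(1)  # (1, 1, seq_len, rotary_emb_dim)
    sin = sin[position_ids].unsqueeze(1)
    # split each head into its rotated part and its pass-through part
    q_rot, q_pass = xq[..., :rotary_emb_dim], xq[..., rotary_emb_dim:]
    k_rot, k_pass = xk[..., :rotary_emb_dim], xk[..., rotary_emb_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)
```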
@@ -209,6 +201,7 @@ def forward(
self,
hidden_states: torch.Tensor,
input_args: LLMModelInput,
+ rotary_emb: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
@@ -233,16 +226,13 @@ def forward(
batch_size, max_seq_len, self.n_kv_heads_, self.head_dim_
).transpose(1, 2)

- kv_seq_len = xk.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx_)
- cos, sin = self.rotary_emb_(xv, seq_len=kv_seq_len)
+ cos, sin = rotary_emb

# partial rotary embedding
xq, xk = apply_partial_rotary_emb(
xq,
xk,
- self.rotary_emb_.dim,
+ self.rotary_emb_dim_,
cos,
sin,
cache_position.unsqueeze(0),
@@ -252,7 +242,7 @@
cache_kwargs = {
"sin": sin,
"cos": cos,
"partial_rotation_size": self.rotary_emb_.dim,
"partial_rotation_size": self.rotary_emb_dim_,
"cache_position": cache_position,
}
xk, xv = past_key_value.update(xk, xv, self.layer_idx_, cache_kwargs)
@@ -394,6 +384,7 @@ def forward(
self,
hidden_states: torch.Tensor,
input_args: LLMModelInput,
+ rotary_emb: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
past_key_value: Optional[Cache] = None,
@@ -404,6 +395,7 @@
attn_outputs = self.self_attn_.forward(
hidden_states,
input_args,
+ rotary_emb,
attention_mask,
cache_position,
past_key_value,
@@ -462,6 +454,12 @@ def __init__(self, config: PhiConfig) -> None:
self.vocab_size_ = config.vocab_size_
self.embed_tokens_ = PhiEmbedding(config)
self.final_layernorm_ = PhiLayerNorm(config)
+ self.rotary_emb_ = PhiRotaryEmbedding(
+ dim=config.rotary_emb_dim_,
+ max_position_embeddings=config.max_seq_len_,
+ base=config.rope_theta_,
+ device=config.device_,
+ )
self.lm_head_ = nn.Linear(
config.dim_,
config.vocab_size_,
@@ -474,6 +472,11 @@ def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
def embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens_(input_ids)

+ def rotary_embed(
+ self, input_tensor: torch.Tensor, position_ids: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ return self.rotary_emb_(input_tensor, seq_len=position_ids[-1, -1] + 1)

def decoder_stack(self) -> List[LLMDecoder]:
return self.layers_
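Note: the cos/sin tables are now created once per model (the PhiRotaryEmbedding built in __init__ above) and handed to every decoder layer through rotary_embed, instead of each attention module owning its own rotary embedding. Below is a minimal sketch of such a rotary-embedding cache, assuming the standard inverse-frequency formulation; it is an illustrative stand-in, not the repository's PhiRotaryEmbedding:

```python
from typing import Tuple

import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    """Precomputes cos/sin tables once and slices them per forward call."""

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 2048,
        base: float = 10000.0,
        device=None,
    ) -> None:
        super().__init__()
        # inverse frequencies for the rotated channels only (dim = rotary_emb_dim_)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
        )
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)         # (max_pos, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (max_pos, dim)
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # return tables covering the current sequence, in the caller's dtype
        return (
            self.cos_cached[:seq_len].to(x.dtype),
            self.sin_cached[:seq_len].to(x.dtype),
        )
```

The model-level rotary_embed wrapper then only needs position_ids to decide how many rows of the cached tables to return, which is what seq_len=position_ids[-1, -1] + 1 does in the hunk above.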

@@ -511,6 +514,7 @@ def from_pretrained(
name_or_path_=llm_config.name_or_path,
vocab_size_=llm_config.vocab_size,
dim_=llm_config.hidden_size,
+ head_dim_=llm_config.hidden_size // llm_config.num_attention_heads,
intermediate_=llm_config.intermediate_size,
n_layers_=llm_config.num_hidden_layers,
n_heads_=llm_config.num_attention_heads,
@@ -529,6 +533,10 @@
dtype_=llm_model.dtype,
)

+ llm_args.rotary_emb_dim_ = int(
+ llm_args.partial_rotary_factor_ * llm_args.head_dim_
+ )

if llm_args.pad_token_id_ is None:
llm_args.pad_token_id_ = -1

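Note: rotary_emb_dim_ is the size of the rotated slice of each attention head, derived from head_dim_ and partial_rotary_factor_. As a worked example with hypothetical numbers (not taken from any particular Phi checkpoint): hidden_size 2048 with 32 attention heads gives head_dim_ = 64, and the default partial_rotary_factor_ of 0.5 gives rotary_emb_dim_ = int(0.5 * 64) = 32, so only the first 32 of the 64 channels per head carry rotary position information.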
