Skip to content

Commit

Permalink
Currently multi-gpu generate does not work with hf.generate for hf ch…
Browse files Browse the repository at this point in the history
…eckpoints. This PR fixes that. (#1332)

* making generate work

* ..

* addressing comments

* reverting hf rotary emb changes, they will remain on a branch

* adding comments
  • Loading branch information
ShashankMosaicML authored Jul 2, 2024
1 parent fe0f25c commit 199c3b9
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 7 deletions.
16 changes: 11 additions & 5 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,11 +266,13 @@ def flash_attn_fn(

batch_size, seqlen = query.shape[:2]

indices_q = flash_attn_padding_info['indices_q']
indices_k = flash_attn_padding_info['indices_k']
indices_v = flash_attn_padding_info['indices_v']
cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q']
cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k']
# In the following lines we move the tensors to the same devices as query, key, and value respectively. These operations should be no-ops during training.
# This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
indices_q = flash_attn_padding_info['indices_q'].to(query.device)
indices_k = flash_attn_padding_info['indices_k'].to(key.device)
indices_v = flash_attn_padding_info['indices_v'].to(value.device)
cu_seqlens_q = flash_attn_padding_info['cu_seqlens_q'].to(query.device)
cu_seqlens_k = flash_attn_padding_info['cu_seqlens_k'].to(key.device)
max_seqlen_q = flash_attn_padding_info['max_seqlen_q']
max_seqlen_k = flash_attn_padding_info['max_seqlen_k']

Expand Down Expand Up @@ -667,6 +669,10 @@ def _apply_rotary_embeddings(
else:
(cos, sin) = rotary_emb(x=value, seq_len=seq_len)
if is_transformers_version_gte('4.38'):
# In the following lines we move the cos and sin tensors to the same devices as query. These operations should be no-ops during training.
# This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
cos = cos.to(query.device)
sin = sin.to(query.device)
query, key = apply_rotary_pos_emb(
q=query,
k=key,
Expand Down
4 changes: 3 additions & 1 deletion llmfoundry/models/layers/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,9 @@ def forward(
m = self.norm_2(x)

n = self.apply_ffn(attention_mask, m)
x = x + self.resid_ffn_dropout(n)
# In the following line we move the `x` tensor to the same devices as the output of ffn layer. This operation should be a no-op during training.
# This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
x = x.to(device=n.device) + self.resid_ffn_dropout(n)
return x, attn_weights, past_key_value

def apply_ffn(
Expand Down
6 changes: 5 additions & 1 deletion llmfoundry/models/layers/dmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,13 @@ def forward(

expert_tokens = x[None, token_list].reshape(-1, hidden_size)
mlp_output = self.mlp(expert_tokens, expert_idx)
# In the following lines we move tensors to the same devices as the output of mlp. These operations should be no-ops during training.
# This is done to fix pipeline parallel generation using hf.generate. Please see this comment for details: https://github.com/mosaicml/llm-foundry/pull/1332#issue-2386827204
expert_weights = expert_weights.to(mlp_output.device)
expert_out = mlp_output * expert_weights[token_list, topk_list,
None]

out = out.to(mlp_output.device)
token_idx = token_idx.to(mlp_output.device)
out.index_add_(0, token_idx, expert_out)

out = out.view(in_shape)
Expand Down

0 comments on commit 199c3b9

Please sign in to comment.