Make tensor split contiguous (#6580) (#6593)
Signed-off-by: Abhinav Khattar <aklife97@gmail.com>
Co-authored-by: Abhinav Khattar <aklife97@gmail.com>
2 people authored and yaoyu-33 committed May 26, 2023
1 parent 13d0895 commit c85c30c
Showing 1 changed file with 6 additions and 2 deletions:
nemo/collections/nlp/modules/common/megatron/attention.py
@@ -380,7 +380,9 @@ def forward(
     mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
     # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-    (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+    (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
+        mixed_x_layer, 3, contiguous_split_chunks=True
+    )
 else:
     # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
     mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -395,7 +397,9 @@ def forward(
     mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
 
     # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
-    (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+    (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
+        mixed_kv_layer, 2, contiguous_split_chunks=True
+    )
 
     # Attention head [sq, b, h] --> [sq, b, hp]
     query_layer, _ = self.query(hidden_states)
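For context, the sketch below is a rough, standalone illustration of what a split-along-last-dim helper with a contiguous_split_chunks flag does; it is not the NeMo/Megatron source, and the helper name split_along_last_dim is a hypothetical stand-in for tensor_parallel.split_tensor_along_last_dim. The point of the commit is that torch.split returns views sharing storage with the input, and after a last-dimension split those views are generally not contiguous; enabling the flag copies each chunk into its own contiguous buffer.

```python
import torch


def split_along_last_dim(tensor, num_partitions, contiguous_split_chunks=False):
    """Hypothetical stand-in sketch, not the NeMo/Megatron implementation."""
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size(last_dim) // num_partitions
    # torch.split returns views into the original storage; after a last-dim
    # split these views are typically non-contiguous.
    chunks = torch.split(tensor, last_dim_size, dim=last_dim)
    if contiguous_split_chunks:
        # Materialize each chunk in its own contiguous buffer.
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks


# [sq, b, np, 3 * hn] --> 3 x [sq, b, np, hn], mirroring the self-attention path
mixed_x_layer = torch.randn(4, 2, 8, 3 * 16)
query_layer, key_layer, value_layer = split_along_last_dim(
    mixed_x_layer, 3, contiguous_split_chunks=True
)
print(query_layer.shape, query_layer.is_contiguous())  # torch.Size([4, 2, 8, 16]) True
```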
