diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py
index 5c2267a25e44..14a94fb141f2 100644
--- a/nemo/collections/nlp/modules/common/megatron/attention.py
+++ b/nemo/collections/nlp/modules/common/megatron/attention.py
@@ -371,7 +371,9 @@ def forward(
             mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
-            (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
+            (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
+                mixed_x_layer, 3, contiguous_split_chunks=True
+            )
         else:
             # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
             mixed_kv_layer, _ = self.key_value(encoder_output)
@@ -386,7 +388,9 @@ def forward(
             mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
 
             # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
-            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)
+            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(
+                mixed_kv_layer, 2, contiguous_split_chunks=True
+            )
 
             # Attention head [sq, b, h] --> [sq, b, hp]
             query_layer, _ = self.query(hidden_states)
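
For context, a minimal sketch of what the added `contiguous_split_chunks=True` argument changes: `torch.split` along the last dimension returns views that share storage with the fused QKV (or KV) tensor and are non-contiguous, and the flag forces each chunk to be copied into its own contiguous buffer. The helper below only mirrors the Megatron-style `split_tensor_along_last_dim` utility for illustration; the exact NeMo/Megatron implementation may differ.

```python
# Illustration only, assuming Megatron-style semantics for split_tensor_along_last_dim;
# not the library's actual code.
import torch


def split_tensor_along_last_dim(tensor: torch.Tensor, num_partitions: int,
                                contiguous_split_chunks: bool = False):
    """Split a tensor into equal chunks along its last dimension."""
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size(last_dim) // num_partitions
    chunks = torch.split(tensor, last_dim_size, dim=last_dim)
    if contiguous_split_chunks:
        # torch.split returns non-contiguous views when slicing the last dim of a
        # multi-dim tensor; .contiguous() copies each chunk into a dense buffer.
        return tuple(chunk.contiguous() for chunk in chunks)
    return chunks


# [sq, b, np, 3 * hn] --> 3 x [sq, b, np, hn], with toy sizes
mixed_x_layer = torch.randn(4, 2, 8, 3 * 16)
q, k, v = split_tensor_along_last_dim(mixed_x_layer, 3, contiguous_split_chunks=True)
print(q.shape, q.is_contiguous())  # torch.Size([4, 2, 8, 16]) True
```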