
Commit 22170b0

[RFC] Lift freqs_cis as an input of models
freqs_cis is sensitive to the sequence order. CP load balancing will shuffle the samples, so each batch will have a different order. As a result, we have to lift these order-sensitive buffers to the model inputs and broadcast them along the batch dimension, so that PP will correctly shard freqs_cis without breaking correctness.

ghstack-source-id: 49e4ec0
Pull Request resolved: #1797
1 parent 31bc306 commit 22170b0
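
To make the change concrete, here is a minimal standalone sketch, assuming illustrative tensor names and sizes (this is not torchtitan code), of what lifting freqs_cis and broadcasting it along the batch dimension looks like:

import torch

# Standalone sketch of the idea behind this commit; names and sizes are illustrative.
# freqs_cis encodes token positions, so once CP load balancing reorders samples, the
# buffer can no longer be a single module attribute shared by the whole batch; it is
# lifted to a per-batch model input instead.
max_seqlen, head_dim = 4096, 64
freqs_cis = torch.polar(torch.ones(max_seqlen, head_dim), torch.randn(max_seqlen, head_dim))

batch_size, seq_len = 8, 2048
# Broadcast along the batch dimension: (seq_len, head_dim) -> (batch_size, seq_len, head_dim),
# mirroring the self.freqs_cis[:seq_len].repeat(batch_size, 1, 1) call added below.
per_batch_freqs = freqs_cis[:seq_len].repeat(batch_size, 1, 1)
assert per_batch_freqs.shape == (batch_size, seq_len, head_dim)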

File tree

torchtitan/models/llama3/model/model.py
torchtitan/protocols/model.py
torchtitan/train.py

3 files changed: +33 −7 lines


torchtitan/models/llama3/model/model.py

Lines changed: 14 additions & 6 deletions
@@ -56,8 +56,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.
 
-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).
 
     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -68,10 +67,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
 
@@ -437,9 +436,18 @@ def get_attention_masks(
             mask_mod, B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ((freqs_cis,), (1,))
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
@@ -464,7 +472,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
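
For reference, a short self-contained example, assuming illustrative shapes (not torchtitan code), of what the updated reshape_for_broadcast shape computation does: dimensions 0 (batch), 1 (sequence), and the last (head dim) are kept, everything in between becomes 1, so the (batch_size, seqlen, dim) freqs_cis broadcasts across the heads dimension of x.

import torch

# Worked example of the new shape computation; all shapes are illustrative.
batch, seqlen, n_heads, head_dim = 2, 16, 4, 8
x = torch.randn(batch, seqlen, n_heads, head_dim)
freqs_cis = torch.randn(batch, seqlen, head_dim)  # now carries a batch dimension

ndim = x.ndim
# Keep dims 0 (batch), 1 (sequence), and the last (head_dim); set the rest to 1.
shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
assert shape == [batch, seqlen, 1, head_dim]
# The viewed tensor broadcasts cleanly against x across the heads dimension.
assert (freqs_cis.view(*shape) * x).shape == x.shape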

torchtitan/protocols/model.py

Lines changed: 8 additions & 0 deletions
@@ -70,3 +70,11 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
+        raise NotImplementedError()
+        return ((), ())
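
As a hypothetical illustration (not part of this commit), a model class implementing this protocol would override the hook. The return value pairs the lifted buffers with what appears to be the sequence dimension of each buffer (1 for the (batch_size, seqlen, dim) freqs_cis in the llama3 change above); a model with no order-sensitive state can simply return empty tuples:

import torch

# Hypothetical override of the new protocol hook; this class is illustrative only.
class NoPositionalStateModel:
    def get_order_sensitive_buffers(
        self,
        batch_size: int,
        seq_len: int,
    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
        # Nothing in this model depends on token order, so there is nothing to lift.
        return ((), ())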

torchtitan/train.py

Lines changed: 11 additions & 1 deletion
@@ -425,6 +425,11 @@ def forward_backward_step(
             else None
         )
 
+        # Get the order sensitive buffers
+        order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
@@ -449,6 +454,7 @@
             if self.pp_has_first_stage:
                 self.pp_schedule.step(
                     inputs,
+                    *order_sensitive_buffers[0],
                     **extra_inputs,
                     attention_masks=attention_masks,
                     target=targets,
@@ -457,6 +463,7 @@
                 )
             else:
                 self.pp_schedule.step(
+                    *order_sensitive_buffers[0],
                     attention_masks=attention_masks,
                     target=targets,
                     losses=losses,
@@ -479,7 +486,10 @@
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
                     pred = model_parts[0](
-                        inputs, **extra_inputs, attention_masks=attention_masks
+                        inputs,
+                        *order_sensitive_buffers[0],
+                        **extra_inputs,
+                        attention_masks=attention_masks,
                     )
                     loss = self.loss_fn(pred, labels)
                     # need to free pred before bwd to avoid peaking memory
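
Finally, a self-contained toy sketch of the calling convention introduced above, assuming a made-up ToyModel (not torchtitan code): the trainer asks the first model part for its order-sensitive buffers and splats them as positional inputs right after the token batch.

import torch

# Toy sketch of the new calling convention; ToyModel and all shapes are illustrative.
class ToyModel(torch.nn.Module):
    def __init__(self, max_seqlen: int = 32, dim: int = 8) -> None:
        super().__init__()
        self.register_buffer("freqs_cis", torch.randn(max_seqlen, dim))

    def get_order_sensitive_buffers(
        self, batch_size: int, seq_len: int
    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
        # Same pattern as the llama3 implementation: slice to seq_len, repeat over batch.
        return ((self.freqs_cis[:seq_len].repeat(batch_size, 1, 1),), (1,))

    def forward(self, tokens: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        # Placeholder computation; a real model would apply rotary embeddings here.
        return tokens.unsqueeze(-1) + freqs_cis


model = ToyModel()
inputs = torch.zeros(4, 16, dtype=torch.long)
buffers, seq_dims = model.get_order_sensitive_buffers(inputs.size(0), inputs.size(1))
# Mirrors pred = model_parts[0](inputs, *order_sensitive_buffers[0], ...) in the diff above.
pred = model(inputs, *buffers)
assert pred.shape == (4, 16, 8)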
