Commit 48b2a11

fix sink
1 parent bb8ee6f commit 48b2a11

2 files changed: +19 -16 lines

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 13 additions & 14 deletions
@@ -22,7 +22,6 @@
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp
 from torchtitan.experiments.llama4.infra.parallelize import (
     apply_fsdp,
-    apply_moe_ep_tp,
 )

 from torchtitan.tools.logging import logger
@@ -231,30 +230,30 @@ def apply_non_moe_tp(
         layer_plan = {
             "attention_norm": SequenceParallel(),
             "attention": prepare_module_input(
-                input_layouts=(Shard(1), Replicate()),
-                desired_input_layouts=(Replicate(), Replicate()),
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": colwise_parallel(use_local_output=False),
-            "attention.wk": colwise_parallel(use_local_output=False),
-            "attention.wv": colwise_parallel(use_local_output=False),
+            "attention.wq": colwise_parallel(),
+            "attention.wk": colwise_parallel(),
+            "attention.wv": colwise_parallel(),
             "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
         }

+        # shard attention.sinks across heads
+        # TODO(jianiw): Fix the sink implementation for nn.Parameter
+        attn = transformer_block.attention
+        attn.register_parameter(
+            "sinks",
+            nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
+        )
+
         parallelize_module(
             module=transformer_block,
             device_mesh=tp_mesh,
             parallelize_plan=layer_plan,
         )

-        # shard attention.sinks across heads
-        # TODO(jianiw): Fix the sink implementation
-        # attn = transformer_block.attention
-        # attn.register_parameter(
-        #     "sinks",
-        #     nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])),
-        # )
-
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group

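The new block above replaces the previously commented-out Replicate() version: attention.sinks is re-registered as a DTensor sharded on dim 0 (the head dimension), so each tensor-parallel rank keeps only the sink logits for its local heads. Below is a minimal, hedged sketch of that distribute_tensor + register_parameter pattern in isolation; ToyAttention, shard_sinks_over_heads, and the single-process gloo harness are illustrative stand-ins rather than torchtitan code, and the torch.distributed.tensor import assumes a recent PyTorch.

# Hedged sketch: shard a per-head nn.Parameter across a TP mesh, mirroring
# the register_parameter(...) call added in parallelize.py above.
# Run under torchrun; a single-process run also works for a quick check.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor


class ToyAttention(nn.Module):  # stand-in for the gpt_oss attention module
    def __init__(self, n_heads: int):
        super().__init__()
        self.sinks = nn.Parameter(torch.zeros(n_heads))  # one sink logit per head


def shard_sinks_over_heads(attn: nn.Module, tp_mesh) -> None:
    # Re-register `sinks` as a DTensor sharded on dim 0 (heads), so each TP
    # rank holds a [n_heads // tp_degree] local shard.
    attn.register_parameter(
        "sinks",
        nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])),
    )


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(
        "gloo",
        rank=int(os.environ.get("RANK", 0)),
        world_size=int(os.environ.get("WORLD_SIZE", 1)),
    )
    tp_mesh = init_device_mesh("cpu", (dist.get_world_size(),), mesh_dim_names=("tp",))
    attn = ToyAttention(n_heads=8)
    shard_sinks_over_heads(attn, tp_mesh)
    print(type(attn.sinks), getattr(attn.sinks, "placements", None))
    dist.destroy_process_group()

With Shard(0) instead of the earlier Replicate(), the local sinks tensor has shape [local_heads], which is what the model.py comments below expect.
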
torchtitan/experiments/gpt_oss/model/model.py

Lines changed: 6 additions & 2 deletions
@@ -245,14 +245,18 @@ def forward(
                 k,
                 v,
                 scale=None,
-                return_lse=False,
+                return_lse=True,
             )

             # Apply attention sink rescaling: rescale by σ(lse - w[h])
             # This is mathematically equivalent to concatenating learnable sink weights
+            # TODO: self.sinks is registered as a DTensor (sharded over heads), while lse is a plain tensor
+            # q, k, v are already sharded by TP: [batch, local_heads, seq_len, head_dim] (plain tensors)
+            # sinks shape needs to match: [local_heads]
+            # [rank0]: lse.shape torch.Size([8, 32, 2048]), <class 'torch.Tensor'>
             sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(
                 -1
-            )  # [B,H,S,1]
+            )
             output = output * sink_scale.to(output.dtype)

         else:

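For reference, the identity the rescaling comment relies on: appending a per-head sink logit w as one extra column of the attention scores multiplies every ordinary softmax weight by exp(lse) / (exp(lse) + exp(w)) = σ(lse - w), where lse is the log-sum-exp of the real scores, which is why the attention call now needs return_lse=True. A small self-contained numerical check of that identity, using plain softmax attention on toy shapes rather than the repo's attention kernel (all names here are illustrative):

# Hedged sketch: verify output_with_sink == output * sigmoid(lse - sink).
import torch

torch.manual_seed(0)
B, H, S, D = 2, 4, 8, 16
q = torch.randn(B, H, S, D)
k = torch.randn(B, H, S, D)
v = torch.randn(B, H, S, D)
sinks = torch.randn(H)  # one learnable sink logit per head

scores = torch.einsum("bhqd,bhkd->bhqk", q, k) * D ** -0.5  # [B, H, S, S]

# (a) Reference: concatenate the sink logit as an extra score column before the
#     softmax, then drop its probability (the sink contributes no value vector).
sink_col = sinks.view(1, H, 1, 1).expand(B, H, S, 1)
probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)[..., :-1]
out_ref = torch.einsum("bhqk,bhkd->bhqd", probs, v)

# (b) Equivalent: ordinary softmax attention rescaled by sigmoid(lse - sink),
#     which is what the model.py change above does with the kernel's lse output.
lse = torch.logsumexp(scores, dim=-1)  # [B, H, S]
out = torch.einsum("bhqk,bhkd->bhqd", torch.softmax(scores, dim=-1), v)
out = out * torch.sigmoid(lse - sinks.view(1, H, 1)).unsqueeze(-1)  # [B, H, S, 1]

print(torch.allclose(out_ref, out, atol=1e-5))  # expected: True

Under tensor parallelism the same identity holds head by head, so the open question flagged in the TODO is about shapes and tensor types: the local lse is a plain [batch, local_heads, seq_len] tensor, while self.sinks is now a DTensor whose local shard has shape [local_heads].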