
Commit 96149f6

Remove create_cp_block_mask (#1798)
This logic was committed prematurely, since the final CP (context parallelism) UX may end up looking different. Remove it for now to avoid confusion and future backward-compatibility (BC) issues.
1 parent 4409c13 commit 96149f6

File tree

2 files changed: 2 additions & 20 deletions


torchtitan/models/attention.py

Lines changed: 1 addition & 16 deletions
@@ -6,12 +6,10 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
-import functools
 from typing import Callable, ClassVar
 
 import torch
 import torch.nn.functional as F
-from torch.distributed.tensor.experimental._attention import create_cp_block_mask
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
@@ -241,18 +239,5 @@ def build_attention(
     return ScaledDotProductAttention(attn_mask_type)
 
 
-def init_attention_mask(
-    batch: torch.Tensor,
-    eos_id: int | None,
-    cp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
-) -> None:
-
-    # This is not functional yet because we currently gate the use of Flex + CP
-    # while we continue debugging accuracy issues. However, we want to evaluate
-    # the user experience with CP enabled.
-    if cp_mesh is not None:
-        FlexAttention.compiled_create_block_mask = functools.partial(
-            create_cp_block_mask, device_mesh=cp_mesh
-        )
-
+def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
     FlexAttention.init_attention_mask(batch, eos_id)
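
For context, the deleted branch relied on rebinding a class-level hook (FlexAttention.compiled_create_block_mask) to a functools.partial that injected the CP device mesh. The standard-library sketch below shows that mechanism only; MaskBuilder, default_builder, cp_aware_builder, and the placeholder mesh string are illustrative stand-ins, not torchtitan APIs.

import functools

def default_builder(seq_len: int) -> str:
    # Single-rank path: build a mask over the full sequence.
    return f"block mask over {seq_len} tokens"

def cp_aware_builder(seq_len: int, *, device_mesh) -> str:
    # CP path: the extra device_mesh argument is bound in advance via partial.
    return f"block mask over {seq_len} tokens, sharded on {device_mesh}"

class MaskBuilder:
    # Analogous to the FlexAttention.compiled_create_block_mask ClassVar hook.
    compiled_create_block_mask = staticmethod(default_builder)

# The removed init_attention_mask swapped the hook whenever a CP mesh existed,
# so every later call transparently produced a CP-aware mask:
cp_mesh = "2-rank cp mesh"  # placeholder for a torch.distributed DeviceMesh
MaskBuilder.compiled_create_block_mask = staticmethod(
    functools.partial(cp_aware_builder, device_mesh=cp_mesh)
)
print(MaskBuilder.compiled_create_block_mask(8192))

This implicit, class-wide rebinding is exactly the kind of surface the commit message flags as risky for future BC, which is why it is being pulled out until the CP UX settles.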

torchtitan/train.py

Lines changed: 1 addition & 4 deletions
@@ -416,10 +416,7 @@ def forward_backward_step(
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # Create the FlexAttention mask according to the input
         if getattr(self.model_args, "use_flex_attn", False):
-            cp_mesh = (
-                parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
-            )
-            init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh)
+            init_attention_mask(inputs, self.tokenizer.eos_id)
 
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
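
After this change the call site only builds the plain (non-CP) FlexAttention mask from the input batch. The sketch below is my assumption of roughly what such a mask looks like using the public torch.nn.attention.flex_attention API (PyTorch >= 2.5); it is not torchtitan's actual FlexAttention.init_attention_mask, and the names tokens, doc_ids, and document_causal are illustrative.

import torch
from torch.nn.attention.flex_attention import create_block_mask

eos_id = 0
seq_len = 128
torch.manual_seed(0)
tokens = torch.randint(1, 100, (seq_len,))
tokens[63] = eos_id  # one document boundary inside the packed sequence

# Document index per position; the EOS token stays with the document it closes.
eos_flags = (tokens == eos_id).to(torch.int32)
doc_ids = torch.cumsum(eos_flags, dim=0) - eos_flags

def document_causal(b, h, q_idx, kv_idx):
    # Causal attention restricted to tokens within the same document.
    return (q_idx >= kv_idx) & (doc_ids[q_idx] == doc_ids[kv_idx])

block_mask = create_block_mask(
    document_causal, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device="cpu"
)
print(block_mask)

The removed code path additionally sharded this mask construction across the CP mesh; with it gone, CP setup happens only in the context-parallelism block that follows.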

0 commit comments