2 files changed: +2 −20 lines

@@ -6,12 +6,10 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

-import functools
 from typing import Callable, ClassVar

 import torch
 import torch.nn.functional as F
-from torch.distributed.tensor.experimental._attention import create_cp_block_mask
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
@@ -241,18 +239,5 @@ def build_attention(
     return ScaledDotProductAttention(attn_mask_type)


-def init_attention_mask(
-    batch: torch.Tensor,
-    eos_id: int | None,
-    cp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
-) -> None:
-
-    # This is not functional yet because we currently gate the use of Flex + CP
-    # while we continue debugging accuracy issues. However, we want to evaluate
-    # the user experience with CP enabled.
-    if cp_mesh is not None:
-        FlexAttention.compiled_create_block_mask = functools.partial(
-            create_cp_block_mask, device_mesh=cp_mesh
-        )
-
+def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
     FlexAttention.init_attention_mask(batch, eos_id)
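For context on what this hunk removes: the old code rebound FlexAttention's class-level mask factory to the context-parallel builder via `functools.partial`, baking the CP device mesh into every subsequent mask build. Below is a minimal, torch-free sketch of that rebinding pattern; the class and builder bodies are toy stand-ins, not torchtitan's actual implementations.

```python
import functools

# Toy stand-ins for the real builders; only the call shape matters here.
def create_block_mask(mask_mod, B, H, q_len, kv_len):
    return f"local block mask ({q_len}x{kv_len})"

def create_cp_block_mask(mask_mod, B, H, q_len, kv_len, *, device_mesh):
    return f"CP block mask ({q_len}x{kv_len}) sharded over {device_mesh}"

class FlexAttention:
    # Class-level factory; normally points at the single-device builder.
    compiled_create_block_mask = create_block_mask

    @classmethod
    def init_attention_mask(cls, batch, eos_id):
        seq_len = len(batch)
        return cls.compiled_create_block_mask(None, 1, 1, seq_len, seq_len)

# The removed pattern: rebind the factory so every later mask build goes
# through the CP-aware builder with the mesh pre-bound as a keyword arg.
FlexAttention.compiled_create_block_mask = functools.partial(
    create_cp_block_mask, device_mesh="cp-mesh-placeholder"
)

print(FlexAttention.init_attention_mask([1, 2, 3, 4], eos_id=None))
# -> CP block mask (4x4) sharded over cp-mesh-placeholder
```

Because the override mutated shared class state as a side effect of `init_attention_mask`, dropping the `cp_mesh` parameter also removes that hidden coupling from the public API.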
@@ -416,10 +416,7 @@ def forward_backward_step(
         extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
         # Create the FlexAttention mask according to the input
         if getattr(self.model_args, "use_flex_attn", False):
-            cp_mesh = (
-                parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
-            )
-            init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh)
+            init_attention_mask(inputs, self.tokenizer.eos_id)

         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
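After this change the mask is derived from the token batch and `eos_id` alone. As a rough illustration of the kind of mask such a call can build (a toy sketch under assumed semantics, not torchtitan's actual implementation): with packed documents, `eos_id` can drive a block-causal `mask_mod` in which tokens attend only causally and only within their own document.

```python
import torch

def make_block_causal_mask_mod(batch: torch.Tensor, eos_id: int):
    # Document index per position: tokens after an EOS start a new doc;
    # the EOS token itself stays with the document it terminates.
    eos = (batch == eos_id).to(torch.int64)
    doc_id = torch.cumsum(eos, dim=-1) - eos

    # Signature matches torch.nn.attention.flex_attention's mask_mod;
    # in real use this would be handed to create_block_mask.
    def mask_mod(b, h, q_idx, kv_idx):
        same_doc = doc_id[b, q_idx] == doc_id[b, kv_idx]
        return same_doc & (q_idx >= kv_idx)

    return mask_mod

# Two packed documents: [5, 7, EOS] and [9, 9], with eos_id=0.
batch = torch.tensor([[5, 7, 0, 9, 9]])
mask_mod = make_block_causal_mask_mod(batch, eos_id=0)

# Evaluate the mask densely to visualize it (1 = attend, 0 = masked).
q_idx = torch.arange(5).view(-1, 1)
kv_idx = torch.arange(5).view(1, -1)
print(mask_mod(0, 0, q_idx, kv_idx).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 0, 0, 1, 0],
#         [0, 0, 0, 1, 1]])
```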