diff --git a/torchchat/cli/builder.py b/torchchat/cli/builder.py
index 38d0e33b2..79046c377 100644
--- a/torchchat/cli/builder.py
+++ b/torchchat/cli/builder.py
@@ -69,6 +69,7 @@ class BuilderArgs:
     prefill_possible: bool = False
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
+    attention_backend: str = "math"

     def __post_init__(self):
         if self.device is None:
@@ -178,6 +179,17 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         pp = getattr(args, "pp", 1)
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
+        sdp_backend_dict = {
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        }
+        attention_backend = sdp_backend_dict[args.attention_backend]
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+                                     or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+            attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -202,6 +214,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            attention_backend=attention_backend,
         )

     @classmethod
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index 91bdcaf26..427edc452 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -179,6 +179,13 @@ def _add_model_config_args(parser, verb: str) -> None:
         choices=["fast", "cpu", "cuda", "mps"],
         help="Hardware device to use. Options: fast, cpu, cuda, mps",
     )
+    model_config_parser.add_argument(
+        "--attention-backend",
+        type=str,
+        default="math",
+        choices=["math", "flash_attention", "efficient_attention", "cudnn_attention"],
+        help="SDPBackend to use. Options: MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION",
+    )


 # Add CLI Args representing output paths of exported model files
diff --git a/torchchat/generate.py b/torchchat/generate.py
index e271f5027..bf3f4cea2 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -26,6 +26,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torch._C import _SDPBackend as SDPBackend

 from PIL import Image

@@ -531,6 +532,7 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
@@ -539,7 +541,7 @@
             num_new_tokens - 1
         ):  # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([attention_backend]):

                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
@@ -683,6 +685,7 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        attention_backend: str = "math",
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -799,6 +802,7 @@ def generate(
                     if self.is_llama3_model
                     else None
                 ),
+                attention_backend=attention_backend,
                 **sampling_kwargs,
             ):
                 generated_tokens.append(generated_token.view(-1))
@@ -1186,6 +1190,7 @@ def callback(x, *, done_generating=False):
             start_pos=start_pos,
             skip_cache_setup=not is_first_sample,
             max_seq_length=max_seq_length,
+            attention_backend=self.builder_args.attention_backend,
         )
         for token_tensor, metrics in generator_func:
             if token_tensor is not None:
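
For context, the sketch below (not part of the patch) shows how the string accepted by the new --attention-backend flag maps to a torch.nn.attention.SDPBackend and how sdpa_kernel scopes scaled_dot_product_attention to that backend. It assumes PyTorch 2.3+, where SDPBackend and sdpa_kernel are public names; the tensor shapes are illustrative only.

# Illustrative sketch only; not part of the patch above.
# Mirrors the mapping that BuilderArgs.from_args builds and shows how the
# chosen backend is applied around scaled_dot_product_attention.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # public in PyTorch >= 2.3

sdp_backend_dict = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}

backend = sdp_backend_dict["math"]  # e.g. the value parsed from --attention-backend

# Dummy query/key/value tensors: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 8, 16, 64)
with sdpa_kernel([backend]):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 8, 16, 64])

The MATH backend is the portable default on CPU, which is why the patch falls back to it when efficient_attention or cudnn_attention is requested together with --device cpu.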