Add no_flash_attn option
oobabooga committed Nov 2, 2023
1 parent aaf726d commit 77abd9b
Showing 4 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -336,6 +336,7 @@ Optionally, you can use the following command-line flags:
|`--gpu-split` | Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7. |
|`--max_seq_len MAX_SEQ_LEN` | Maximum sequence length. |
|`--cfg-cache` | ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama. |
|`--no_flash_attn` | Force flash-attention to not be used. |

#### AutoGPTQ

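For reference, a hedged usage example (the server.py entry point and the --loader/--model flags are assumptions about the surrounding project, not part of this diff). With this change, passing the new flag disables flash-attention for the ExLlamaV2 loaders:

    python server.py --loader exllamav2 --model <your-exl2-model> --no_flash_attn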
1 change: 1 addition & 0 deletions modules/exllamav2.py
@@ -46,6 +46,7 @@ def from_pretrained(self, path_to_model):
config.max_seq_len = shared.args.max_seq_len
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value
config.no_flash_attn = shared.args.no_flash_attn

model = ExLlamaV2(config)

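A minimal sketch of how this option reaches the ExLlamaV2 backend, assuming the exllamav2 Python package is installed; the model_dir value and the prepare()/load() calls are illustrative assumptions, and only the no_flash_attn assignment comes from this diff:

    from exllamav2 import ExLlamaV2, ExLlamaV2Config

    config = ExLlamaV2Config()
    config.model_dir = '/path/to/model'   # assumption: any local ExLlamaV2 model directory
    config.prepare()                      # reads the model's config files from model_dir
    config.no_flash_attn = True           # what shared.args.no_flash_attn toggles in this commit
    model = ExLlamaV2(config)
    model.load()                          # attention is expected to fall back to a non-flash path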
1 change: 1 addition & 0 deletions modules/exllamav2_hf.py
@@ -152,5 +152,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
config.max_seq_len = shared.args.max_seq_len
config.scale_pos_emb = shared.args.compress_pos_emb
config.scale_alpha_value = shared.args.alpha_value
config.no_flash_attn = shared.args.no_flash_attn

return Exllamav2HF(config)
1 change: 1 addition & 0 deletions modules/shared.py
@@ -117,6 +117,7 @@
parser.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.')
parser.add_argument('--max_seq_len', type=int, default=2048, help='Maximum sequence length.')
parser.add_argument('--cfg-cache', action='store_true', help='ExLlama_HF: Create an additional cache for CFG negative prompts. Necessary to use CFG with that loader, but not necessary for CFG with base ExLlama.')
parser.add_argument('--no_flash_attn', action='store_true', help='Force flash-attention to not be used.')

# AutoGPTQ
parser.add_argument('--triton', action='store_true', help='Use triton.')
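Because the flag is declared with action='store_true', it defaults to False, so flash-attention stays enabled unless the user passes --no_flash_attn explicitly. A small self-contained sketch of that behavior (standalone argparse, not the project's full parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--no_flash_attn', action='store_true',
                        help='Force flash-attention to not be used.')

    print(parser.parse_args([]).no_flash_attn)                   # False -> flash-attention allowed
    print(parser.parse_args(['--no_flash_attn']).no_flash_attn)  # True  -> flash-attention disabled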
