diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 217b96e19..4cd830581 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -114,7 +114,11 @@ def load_model( if cfg.flash_attention: replace_btlm_attn_with_flash_attn() - if hasattr(model_config, "model_type") and model_config.model_type == "falcon": + if hasattr(model_config, "model_type") and model_config.model_type in [ + "falcon", + "RefinedWebModel", + "RefinedWeb", + ]: if cfg.flash_attention: replace_falcon_attn_with_flash_attn()