From fa56a6e164282f2427012d44d71231fd1f29cf00 Mon Sep 17 00:00:00 2001
From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date: Wed, 22 Oct 2025 15:54:50 +0200
Subject: [PATCH] Fix attn_implementation name for OnlineDPO depending on
 transformers version

---
 trl/trainer/online_dpo_trainer.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py
index a35955980aa..f74abf2c27e 100644
--- a/trl/trainer/online_dpo_trainer.py
+++ b/trl/trainer/online_dpo_trainer.py
@@ -26,6 +26,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.data
+import transformers
 from accelerate import logging
 from accelerate.utils import broadcast_object_list, gather_object, is_peft_model
 from datasets import Dataset
@@ -1049,10 +1050,11 @@ def _generate(self, model, prompts, images=None):
         if self.use_transformers_paged:
             previous_attn = self.model_wrapped.config._attn_implementation

-            if is_flash_attn_2_available():
-                self.model_wrapped.config._attn_implementation = "paged_attention"
+            if version.parse(transformers.__version__).release >= version.parse("5.0.0").release:
+                new_attn = "paged|flash_attention_2" if is_flash_attn_2_available() else "paged|sdpa"
             else:
-                self.model_wrapped.config._attn_implementation = "sdpa_paged"
+                new_attn = "paged_attention" if is_flash_attn_2_available() else "sdpa_paged"
+            self.model_wrapped.config._attn_implementation = new_attn
             with (
                 profiling_context(self, "transformers.generate_batch"),
                 unwrap_model_for_generation(
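
For reference, a minimal standalone sketch of the version gate this patch introduces. `pick_paged_attn_name` is a hypothetical helper written for illustration only (it is not part of TRL), and flash-attn availability is passed in as a flag rather than detected via `is_flash_attn_2_available()`:

# Standalone sketch of the version-gated attention name selection above.
# Assumes `transformers` and `packaging` are installed.
import transformers
from packaging import version


def pick_paged_attn_name(flash_attn_available: bool) -> str:
    # Comparing `.release` tuples ignores pre-release suffixes, so a dev
    # build such as "5.0.0.dev0" is also treated as transformers v5.
    if version.parse(transformers.__version__).release >= version.parse("5.0.0").release:
        # transformers >= 5.0.0 expects composite "paged|<backend>" names.
        return "paged|flash_attention_2" if flash_attn_available else "paged|sdpa"
    # Older transformers releases use the legacy names.
    return "paged_attention" if flash_attn_available else "sdpa_paged"


print(pick_paged_attn_name(flash_attn_available=False))

Comparing `.release` tuples rather than full parsed versions is a deliberate choice here: it means pre-release builds of transformers 5.x already receive the new composite names instead of falling back to the legacy ones.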