From baa0b0698c56cdac30a6204bd5738bfa1bafab40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 6 Nov 2025 23:01:59 +0000 Subject: [PATCH] consistency: relative imports --- .../grpo_with_replay_buffer_config.py | 2 +- .../grpo_with_replay_buffer_trainer.py | 11 +++-------- trl/experimental/gspo_token/grpo_trainer.py | 3 +-- trl/mergekit_utils.py | 2 +- trl/rewards/accuracy_rewards.py | 2 +- trl/trainer/nash_md_config.py | 2 +- 6 files changed, 8 insertions(+), 14 deletions(-) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py index 37a4341f27b..6f0b0381bef 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field -from trl.trainer.grpo_config import GRPOConfig +from ...trainer.grpo_config import GRPOConfig @dataclass diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 178bb164976..e5c44710123 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -18,14 +18,9 @@ import torch from accelerate.utils import gather_object -from trl.data_utils import ( - apply_chat_template, - is_conversational, - prepare_multimodal_messages, -) -from trl.trainer.grpo_trainer import GRPOTrainer -from trl.trainer.utils import nanmax, nanmin, nanstd, pad - +from ...data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages +from ...trainer.grpo_trainer import GRPOTrainer +from ...trainer.utils import nanmax, nanmin, nanstd, pad from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig diff --git a/trl/experimental/gspo_token/grpo_trainer.py b/trl/experimental/gspo_token/grpo_trainer.py index f267bbd4b62..62c124ab134 100644 --- a/trl/experimental/gspo_token/grpo_trainer.py +++ b/trl/experimental/gspo_token/grpo_trainer.py @@ -14,8 +14,7 @@ import torch -from trl import GRPOTrainer as _GRPOTrainer - +from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer from ...trainer.utils import nanmax, nanmin diff --git a/trl/mergekit_utils.py b/trl/mergekit_utils.py index fc9787b8f6b..d070a8dd923 100644 --- a/trl/mergekit_utils.py +++ b/trl/mergekit_utils.py @@ -15,7 +15,7 @@ import torch from huggingface_hub import HfApi -from trl.import_utils import is_mergekit_available +from .import_utils import is_mergekit_available if is_mergekit_available(): diff --git a/trl/rewards/accuracy_rewards.py b/trl/rewards/accuracy_rewards.py index 1ae7d21426d..f6d45fca559 100644 --- a/trl/rewards/accuracy_rewards.py +++ b/trl/rewards/accuracy_rewards.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from trl.import_utils import is_math_verify_available +from ..import_utils import is_math_verify_available if is_math_verify_available(): diff --git a/trl/trainer/nash_md_config.py b/trl/trainer/nash_md_config.py index 07d8152f4fa..ddc653ee619 100644 --- a/trl/trainer/nash_md_config.py +++ b/trl/trainer/nash_md_config.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field -from trl.trainer.online_dpo_config import OnlineDPOConfig +from .online_dpo_config import OnlineDPOConfig @dataclass