diff --git a/examples/scripts/ppo/ppo.py b/examples/scripts/ppo/ppo.py index d75af791b1..07e65f331f 100644 --- a/examples/scripts/ppo/ppo.py +++ b/examples/scripts/ppo/ppo.py @@ -9,8 +9,7 @@ HfArgumentParser, ) -from trl import ModelConfig -from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer +from trl import ModelConfig, PPOv2Config, PPOv2Trainer from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE diff --git a/examples/scripts/ppo/ppo_tldr.py b/examples/scripts/ppo/ppo_tldr.py index 146194e3d8..8292414426 100644 --- a/examples/scripts/ppo/ppo_tldr.py +++ b/examples/scripts/ppo/ppo_tldr.py @@ -9,8 +9,7 @@ HfArgumentParser, ) -from trl import ModelConfig -from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer +from trl import ModelConfig, PPOv2Config, PPOv2Trainer from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE diff --git a/trl/__init__.py b/trl/__init__.py index d63b6b27a9..c6b30ac364 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -56,6 +56,8 @@ "ORPOTrainer", "PPOConfig", "PPOTrainer", + "PPOv2Config", + "PPOv2Trainer", "RewardConfig", "RewardTrainer", "SFTConfig", @@ -144,6 +146,8 @@ ORPOTrainer, PPOConfig, PPOTrainer, + PPOv2Config, + PPOv2Trainer, RewardConfig, RewardTrainer, SFTConfig, diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 77f3162aef..4ddd475ecc 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -47,6 +47,8 @@ "orpo_trainer": ["ORPOTrainer"], "ppo_config": ["PPOConfig"], "ppo_trainer": ["PPOTrainer"], + "ppov2_config": ["PPOv2Config"], + "ppov2_trainer": ["PPOv2Trainer"], "reward_config": ["RewardConfig"], "reward_trainer": ["RewardTrainer", "compute_accuracy"], "sft_config": ["SFTConfig"], @@ -112,6 +114,8 @@ from .orpo_trainer import ORPOTrainer from .ppo_config import PPOConfig from .ppo_trainer import PPOTrainer + from .ppov2_config import PPOv2Config + from .ppov2_trainer import PPOv2Trainer from .reward_config import RewardConfig from .reward_trainer import RewardTrainer, compute_accuracy from .sft_config import SFTConfig