Skip to content

Commit

Permalink
Promote PPOv2Trainer and PPOv2Config to top-level import
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Sep 4, 2024
1 parent fc20db8 commit 96ae02a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
3 changes: 1 addition & 2 deletions examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
HfArgumentParser,
)

from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl import ModelConfig, PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


Expand Down
3 changes: 1 addition & 2 deletions examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
HfArgumentParser,
)

from trl import ModelConfig
from trl.trainer.ppov2_trainer import PPOv2Config, PPOv2Trainer
from trl import ModelConfig, PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE


Expand Down
4 changes: 4 additions & 0 deletions trl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
"ORPOTrainer",
"PPOConfig",
"PPOTrainer",
"PPOv2Config",
"PPOv2Trainer",
"RewardConfig",
"RewardTrainer",
"SFTConfig",
Expand Down Expand Up @@ -144,6 +146,8 @@
ORPOTrainer,
PPOConfig,
PPOTrainer,
PPOv2Config,
PPOv2Trainer,
RewardConfig,
RewardTrainer,
SFTConfig,
Expand Down
4 changes: 4 additions & 0 deletions trl/trainer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
"orpo_trainer": ["ORPOTrainer"],
"ppo_config": ["PPOConfig"],
"ppo_trainer": ["PPOTrainer"],
"ppov2_config": ["PPOv2Config"],
"ppov2_trainer": ["PPOv2Trainer"],
"reward_config": ["RewardConfig"],
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
"sft_config": ["SFTConfig"],
Expand Down Expand Up @@ -112,6 +114,8 @@
from .orpo_trainer import ORPOTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer
from .ppov2_config import PPOv2Config
from .ppov2_trainer import PPOv2Trainer
from .reward_config import RewardConfig
from .reward_trainer import RewardTrainer, compute_accuracy
from .sft_config import SFTConfig
Expand Down

0 comments on commit 96ae02a

Please sign in to comment.