Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

💾 Deprecate config in favor of args in PPOTrainer #2384

Merged
3 commits merged on Nov 25, 2024
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/research_projects/tools/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def solution():
optimize_cuda_cache=True,
)

- ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
+ ppo_trainer = PPOTrainer(args=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)

# text env
Expand Down
2 changes: 1 addition & 1 deletion examples/research_projects/tools/triviaqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class ScriptArguments:
seed=script_args.seed,
optimize_cuda_cache=True,
)
- ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
+ ppo_trainer = PPOTrainer(args=config, model=model, tokenizer=tokenizer)
dataset = load_dataset("mandarjoshi/trivia_qa", "rc", split="train")
local_seed = script_args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
dataset = dataset.shuffle(local_seed)
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def tokenize(element):
# Training
################
trainer = PPOTrainer(
- config=training_args,
+ args=training_args,
processing_class=tokenizer,
policy=policy,
ref_policy=ref_policy,
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def tokenize(element):
# Training
################
trainer = PPOTrainer(
- config=training_args,
+ args=training_args,
processing_class=tokenizer,
policy=policy,
ref_policy=ref_policy,
Expand Down
13 changes: 7 additions & 6 deletions trl/trainer/ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,21 @@
from ..core import masked_mean, masked_whiten
from ..models import create_reference_model
from ..models.utils import unwrap_model_for_generation
- from ..trainer.utils import (
+ from .ppo_config import PPOConfig
+ from .utils import (
OnlineTrainerState,
batch_generation,
disable_dropout_in_model,
exact_div,
first_true_indices,
forward,
+ generate_model_card,
get_reward,
+ peft_module_casting_to_bf16,
prepare_deepspeed,
print_rich_table,
truncate_response,
)
- from .ppo_config import PPOConfig
- from .utils import generate_model_card, peft_module_casting_to_bf16


if is_peft_available():
Expand Down Expand Up @@ -97,10 +98,11 @@ def forward(self, **kwargs):
class PPOTrainer(Trainer):
_tag_names = ["trl", "ppo"]

+ @deprecate_kwarg("config", new_name="args", version="0.15.0", raise_if_both_names=True)
@deprecate_kwarg("tokenizer", new_name="processing_class", version="0.15.0", raise_if_both_names=True)
def __init__(
self,
- config: PPOConfig,
+ args: PPOConfig,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
],
Expand All @@ -122,8 +124,7 @@ def __init__(
"same as `policy`, you must make a copy of it, or `None` if you use peft."
)

- self.args = config
- args = config
+ self.args = args
self.processing_class = processing_class
self.policy = policy

Expand Down
Loading