Fix deepspeed for PPOv2Trainer #2080

Merged · 1 commit · Sep 19, 2024

Changes from all commits
trl/trainer/ppov2_trainer.py (18 changes: 13 additions & 5 deletions)
```diff
@@ -39,7 +39,7 @@
 )
 from transformers.integrations import get_reporting_integration_callbacks
 from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
-from transformers.trainer_callback import CallbackHandler, PrinterCallback
+from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback

 from ..core import masked_mean, masked_whiten
 from ..models.utils import unwrap_model_for_generation
```
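The new `ExportableState` import backs the `stateful_callbacks` argument used in the next hunk. In recent `transformers` releases, `ExportableState` is the marker base class for callbacks whose internal state should be serialized into `trainer_state.json` at checkpoint time. A minimal sketch of an opt-in callback, assuming the `{"args": ..., "attributes": ...}` state convention used by `transformers`' own exportable callbacks (`EpisodeCounterCallback` is an illustrative name, not part of the PR):

```python
from transformers.trainer_callback import ExportableState, TrainerCallback


class EpisodeCounterCallback(TrainerCallback, ExportableState):
    """Hypothetical callback whose counter survives checkpoint resumption."""

    def __init__(self, episodes: int = 0):
        self.episodes = episodes

    def state(self) -> dict:
        # Constructor args are replayed via ExportableState.from_state();
        # "attributes" entries are set directly on the rebuilt instance.
        return {"args": {"episodes": self.episodes}, "attributes": {}}
```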
```diff
@@ -174,17 +174,20 @@ def __init__(
         #########
         ### trainer specifics
         #########
-        self.state = OnlineTrainerState(
-            is_local_process_zero=self.is_local_process_zero(),
-            is_world_process_zero=self.is_world_process_zero(),
-        )
         default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
         self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
         self.callback_handler = CallbackHandler(
             self.callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
         )
         self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
         self.control = TrainerControl()
+        self.state = OnlineTrainerState(
+            is_local_process_zero=self.is_local_process_zero(),
+            is_world_process_zero=self.is_world_process_zero(),
+            stateful_callbacks=[
+                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
+            ],
+        )
         self.current_flos = 0
         self.hp_search_backend = None
         self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
```
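The substantive change here: `OnlineTrainerState` is now constructed after the callback handler and `TrainerControl`, because its new `stateful_callbacks` argument filters exactly those objects. This mirrors how `transformers.Trainer.__init__` wires up checkpointable callback state. A minimal sketch of the filtering logic, assuming `TrainerControl` subclasses `ExportableState` (as it does in recent `transformers` releases, which is why the PR includes `self.control` in the candidate list):

```python
from transformers.trainer_callback import (
    DefaultFlowCallback,
    ExportableState,
    PrinterCallback,
    TrainerControl,
)

callbacks = [DefaultFlowCallback(), PrinterCallback()]
control = TrainerControl()

# Keep only objects whose state should be written to trainer_state.json.
stateful = [cb for cb in callbacks + [control] if isinstance(cb, ExportableState)]
print([type(cb).__name__ for cb in stateful])  # expected: ['TrainerControl']
```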
```diff
@@ -311,6 +314,11 @@ def repeat_generator():
         self.state.save_steps = args.save_steps
         self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

+        # backward compatibility
+        if self.is_deepspeed_enabled:
+            self.deepspeed = self.model
+            self.model_wrapped = self.model
+
         for update in range(1, args.num_total_batches + 1):
             self.state.episode += 1 * args.batch_size
             data = next(iter_dataloader)
```
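This hunk is the DeepSpeed fix the PR title refers to: inherited `Trainer` code paths (checkpoint saving in particular) still look up the legacy `self.deepspeed` and `self.model_wrapped` attributes rather than `self.model`, so under DeepSpeed they raise `AttributeError` unless the aliases exist. A simplified stand-in, not the actual `transformers` implementation, showing the failure mode the aliases guard against:

```python
class CheckpointSaverSketch:
    """Hypothetical reduction of the inherited checkpointing path."""

    def __init__(self, model, is_deepspeed_enabled: bool):
        self.model = model  # under DeepSpeed, already the wrapped engine
        self.is_deepspeed_enabled = is_deepspeed_enabled
        if self.is_deepspeed_enabled:
            # The PR's backward-compatibility block: keep the legacy
            # attribute names pointing at the wrapped model.
            self.deepspeed = self.model
            self.model_wrapped = self.model

    def save_checkpoint(self, output_dir: str) -> None:
        if self.is_deepspeed_enabled:
            # Legacy code reaches for self.deepspeed; without the alias
            # above, this lookup raises AttributeError.
            self.deepspeed.save_checkpoint(output_dir)
```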