add functionality to push best models to the hub during training #275

Merged (6 commits) on Apr 10, 2023
35 changes: 35 additions & 0 deletions tests/test_ppo_trainer.py
@@ -875,3 +875,38 @@ def test_generation(self):
generations_single = tokenizer.batch_decode(generations_single)

self.assertEqual(generations_single, generations_batched)

@unittest.skip("Fix by either patching `whoami()` to work in the staging endpoint or use a dummy prod user.")
def test_push_to_hub_if_best_reward(self):
REPO_NAME = "test-ppo-trainer"
repo_id = f"{CI_HUB_USER}/{REPO_NAME}"

dummy_dataset = self._init_dummy_dataset()

push_to_hub_if_best_kwargs = {"repo_id": repo_id}

ppo_config = PPOConfig(
batch_size=2,
mini_batch_size=1,
log_with=None,
push_to_hub_if_best_kwargs=push_to_hub_if_best_kwargs,
compare_steps=1,
)

ppo_trainer = PPOTrainer(
config=ppo_config,
model=self.gpt2_model,
ref_model=self.gpt2_model_ref,
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)

dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
# train model
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
break
8 changes: 8 additions & 0 deletions trl/trainer/ppo_config.py
@@ -98,6 +98,14 @@ class PPOConfig(object):
target_kl: Optional[float] = field(
default=0.1, metadata={"help": "Stop early if we exceed this value by over 50%"}
)
push_to_hub_if_best_kwargs: Optional[dict] = field(
default_factory=dict,
metadata={"help": "Keyword arguments for pushing model to the hub during training (e.g. repo_id)"},
)
compare_steps: Optional[int] = field(
default=1,
metadata={"help": "Number of steps between comparison of the current reward with the best seen so far"},
)

def __post_init__(self):
if self.forward_batch_size is not None:
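For context, a minimal sketch of how these two new config options might be used together (the repo_id value below is a placeholder for illustration, not part of this PR):

from trl import PPOConfig

config = PPOConfig(
    batch_size=2,
    mini_batch_size=1,
    # pass a repo_id so the trainer knows where to push the best model
    push_to_hub_if_best_kwargs={"repo_id": "your-username/your-ppo-model"},
    # compare the current mean reward to the best seen so far on every step
    compare_steps=1,
)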
19 changes: 19 additions & 0 deletions trl/trainer/ppo_trainer.py
@@ -301,6 +301,14 @@ def __init__(
# init the current step
self.current_step = 0

# init variables for pushing model to hub
if config.push_to_hub_if_best_kwargs:
if "repo_id" not in config.push_to_hub_if_best_kwargs:
raise ValueError("You have to specify repo_id in order to push the model to the hub!")
self.push_to_hub_kwargs = config.push_to_hub_if_best_kwargs
self.compare_step = 0
self.highest_reward = torch.tensor(-float("inf"))

# post process for PP
if not getattr(self.model, "is_sequential_parallel", False):
self.current_device = self.accelerator.device
@@ -540,6 +548,17 @@ def step(

queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores)

# if we want to push the best model to the hub
if hasattr(self, "highest_reward"):
if self.compare_step % self.config.compare_steps == 0:
curr_mean_reward = torch.tensor(scores).mean()
# if this is the best reward seen so far
if curr_mean_reward > self.highest_reward:
self.highest_reward = curr_mean_reward
# push model to hub
self.push_to_hub(**self.push_to_hub_kwargs)
self.compare_step += 1

timing = dict()
t0 = time.time()

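To make the gating logic added to step() easier to follow, here is a self-contained sketch of the same comparison written as a standalone helper (BestRewardTracker is a hypothetical name used only for illustration, not part of the PR):

import torch

class BestRewardTracker:
    """Mirrors the push-if-best bookkeeping added to PPOTrainer.step()."""

    def __init__(self, compare_steps: int = 1):
        self.compare_steps = compare_steps
        self.compare_step = 0
        self.highest_reward = torch.tensor(-float("inf"))

    def should_push(self, scores) -> bool:
        push = False
        # only compare every `compare_steps` calls
        if self.compare_step % self.compare_steps == 0:
            curr_mean_reward = torch.tensor(scores).mean()
            if curr_mean_reward > self.highest_reward:
                self.highest_reward = curr_mean_reward
                push = True  # the trainer would call push_to_hub(**kwargs) here
        self.compare_step += 1
        return push

tracker = BestRewardTracker(compare_steps=2)
print(tracker.should_push([1.0, 0.0]))  # True: mean 0.5 beats -inf
print(tracker.should_push([2.0, 2.0]))  # False: this call is skipped (1 % 2 != 0)
print(tracker.should_push([2.0, 2.0]))  # True: mean 2.0 beats 0.5

Note that with compare_steps greater than 1, rewards seen on skipped steps are never compared, which matches the behavior of the diff above.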