Commit

move all the logic for picking preference pairs to torch, to ready it for batching the candidate response generation; knock off a todo for customizing the preference-pair sampling strategy
lucidrains committed Jan 28, 2024
1 parent 22e0c21 commit fd06041
Showing 4 changed files with 47 additions and 40 deletions.
README.md: 4 changes (2 additions, 2 deletions)
@@ -143,13 +143,13 @@ trainer = SelfRewardingTrainer(
- [x] generalize the sampling so that it can progress at different positions in the batch, fix all sampling to be batched. also allow for left padded sequences, in the case some people have transformers with relative positions that allow for that
- [x] handle eos
- [x] show an example for using your own reward prompt instead of default llm-as-judge
- [x] allow for different strategies for sampling the pairs

- [ ] allow for a validation function on the rewards (say reward must be integer, float, in between some range etc)
- [ ] early stopper should accept an evaluation module that takes in the model and outputs a score, should also accept whether to do it distributed or all on the main process (in which case distributed break signal needs to be handled correctly)
- [ ] figure out how best to handle different impl of kv cache, for now just do without
- [ ] allow for different strategies for sampling the pairs
- [ ] consider KTO
- [ ] any order of sft, spin, self-rewarding dpo, dpo with external reward model
- [ ] allow for a validation function on the rewards (say reward must be integer, float, in between some range etc)

## Citation

self_rewarding_lm_pytorch/dpo.py: 3 changes (3 additions, 0 deletions)
@@ -342,6 +342,9 @@ def __init__(
def is_main(self):
return self.accelerator.is_main_process

def print(self, *msg):
self.accelerator.print(*msg)

def wait(self):
return self.accelerator.wait_for_everyone()

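For context, a minimal sketch of what the added print helper delegates to, assuming an already constructed Accelerator; accelerate's Accelerator.print emits only from the local main process, so trainer log lines are not repeated by every launched rank.

from accelerate import Accelerator

# illustrative only, not part of this commit: Accelerator.print is a drop-in
# replacement for print() that only emits on the local main process
accelerator = Accelerator()
accelerator.print('visible once per machine, even under a multi-process launch')
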
self_rewarding_lm_pytorch/self_rewarding_lm_pytorch.py: 78 changes (41 additions, 37 deletions)
@@ -1,4 +1,5 @@
import re
from random import randrange
from copy import deepcopy
from pathlib import Path
from dataclasses import dataclass
@@ -196,6 +197,14 @@ def init(self):
)
)

@beartype
def default_pick_paired_rewards_fn(rewards: Tensor):
is_nan_mask = torch.isnan(rewards)
rewards_max, rewards_min = rewards.clone(), rewards.clone()
rewards_max[is_nan_mask] = -1e6
rewards_min[is_nan_mask] = 1e6
return torch.stack((rewards_max.argmax(dim = -1), rewards_min.argmin(dim = -1)))

# sft trainer

class SFTTrainer(Module):
@@ -354,6 +363,7 @@ def __init__(
preference_max_seq_len: int = 1024,
generate_reward_max_seq_len: int = 256,
is_valid_reward_pair: Optional[Callable[[float, float], bool]] = None,
pick_paired_rewards: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn,
pad_id: int = -1
):
super().__init__()
@@ -373,13 +383,14 @@ def __init__(
self.eval_nucleus_p = eval_nucleus_p
self.eval_temperature = eval_temperature

self.tokenizer_encode = cast_input(lambda t: t.long())(tokenizer_encode)
self.tokenizer_decode = cast_output(lambda t: t.long())(tokenizer_decode)
self.tokenizer_encode = cast_output(lambda t: t.long())(tokenizer_encode)
self.tokenizer_decode = cast_input(lambda t: t.long() if torch.is_tensor(t) else [*map(int, t)])(tokenizer_decode)

self.num_evals_to_average = num_evals_to_average

# logic for sampling the reward pair and to validate it before adding it to generated preference dataset

self.pick_paired_rewards = pick_paired_rewards
self.is_valid_reward_pair = default(is_valid_reward_pair, lambda *args: True)

# shapes and padding
@@ -501,33 +512,32 @@ def forward(self) -> DPODataset:

rewards: List[Optional[float]] = [self.generate_reward(prompt, response) for response in candidate_responses]

# zip together the responses and rewards and filter out if reward is not generated correctly
# turn rewards into a Tensor

paired_reward_response = [(reward, candidate_response) for reward, candidate_response in zip(rewards, candidate_tensor_responses)]
rewards_tensor = Tensor([default(reward, float('nan')) for reward in rewards])

paired_reward_response = [*filter(lambda pair: exists(first(pair)), paired_reward_response)]
paired_reward_response.sort(key = first)
preference_pair_indices = self.pick_paired_rewards(rewards_tensor)

if len(paired_reward_response) < 2:
continue
# pick out the max and min reward values

unpreferred_reward, unpreferred_response = paired_reward_response[0]
preferred_reward, preferred_response = paired_reward_response[1]
paired_rewards = rewards_tensor[preference_pair_indices]

if not self.is_valid_reward_pair(preferred_reward, unpreferred_reward):
break
# pick out the preferred and unpreferred response

memmap_idx = num_generated
candidate_tensor_responses = pad_sequence(candidate_tensor_responses, batch_first = True, padding_value = self.pad_id)
paired_preference_responses = candidate_tensor_responses[preference_pair_indices]

paired_responses = pad_sequence((preferred_response, unpreferred_response), padding_value = self.pad_id, batch_first = True)
if not self.is_valid_reward_pair(*paired_rewards.unbind(dim = -1)):
break

paired_responses_with_prompt = torch.cat((repeated_prompt_tensor[:2], paired_responses), dim = -1)
memmap_idx = num_generated

paired_responses_with_prompt = torch.cat((repeated_prompt_tensor[:2], paired_preference_responses), dim = -1)
paired_responses_with_prompt = pad_or_slice_to(paired_responses_with_prompt, self.preference_max_seq_len, dim = -1, pad_value = self.pad_id)

self.prompt_len_memmap[memmap_idx] = prompt_len
self.preference_seq_memmap[memmap_idx] = paired_responses_with_prompt.cpu().numpy()
self.self_reward_memmap_file[memmap_idx] = np.array([preferred_reward, unpreferred_reward])
self.self_reward_memmap_file[memmap_idx] = paired_rewards.cpu().numpy()

num_generated += 1
pbar.update(1)
@@ -585,7 +595,8 @@ def __init__(
dpo_trainer_kwargs: dict = dict(),
dropout: float = 0.1,
checkpoints_folder: str = './checkpoints',
is_valid_reward_pair: Optional[Callable[[float, float], bool]] = lambda preferred_reward, unpreferred_reward: preferred_reward != unpreferred_reward,
is_valid_reward_pair: Optional[Callable[[Tensor, Tensor], bool]] = lambda preferred_reward, unpreferred_reward: (preferred_reward != unpreferred_reward).all(),
pick_paired_rewards: Callable[[Tensor], Tensor] = default_pick_paired_rewards_fn,
pad_id: int = -1
):
super().__init__()
@@ -663,6 +674,7 @@ def __init__(
tokenizer_encode = tokenizer_encode,
tokenizer_decode = tokenizer_decode,
is_valid_reward_pair = is_valid_reward_pair,
pick_paired_rewards = pick_paired_rewards,
**reward_generator_kwargs
) for reward_config, one_stage_num_preference_pairs in zip(self.reward_prompt_configs, num_preference_pairs)
]
@@ -704,18 +716,21 @@ def wait(self):
return self.accelerator.wait_for_everyone()

def save(self, path: str, overwrite: bool = False):
if not self.accelerator.is_main_process:
return
self.wait()

path = self.checkpoints_folder / path
if self.accelerator.is_main_process:

assert not path.exists() or overwrite, f'file already exists'
path = self.checkpoints_folder / path

pkg = dict(
model = self.unwrapped_model.state_dict()
)
assert not path.exists() or overwrite, f'file already exists'

pkg = dict(
model = self.unwrapped_model.state_dict()
)

torch.save(pkg, str(path))

torch.save(pkg, str(path))
self.wait()

def forward(
self,
@@ -725,22 +740,15 @@
if self.first_iterate_on_sft:
self.sft_trainer()

self.wait()

self.save('sft.ckpt.pt', overwrite = overwrite_checkpoints)

self.wait()

for ind, spin_trainer in enumerate(self.spin_trainers):
spin_cycle = ind + 1

spin_trainer()

self.wait()

self.save(f'spin.{spin_cycle}.ckpt.pt', overwrite = overwrite_checkpoints)

self.wait()

for ind, (dpo_dataset_generator, dpo_trainer) in enumerate(zip(self.dpo_dataset_generators, self.dpo_trainers)):

@@ -750,12 +758,8 @@

dpo_trainer(dpo_dataset_from_self_reward)

self.wait()

self.dpo.update_reference_model_with_policy()

self.save(f'self-reward.{iterate_num}.ckpt.pt', overwrite = overwrite_checkpoints)

self.wait()

self.print(f'done')
self.print(f'self-reward training done')
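
To make the new flow concrete, below is a minimal, self-contained sketch (toy reward values and responses, not part of the commit) of the torch-based pair picking this commit introduces: NaN rewards from failed judge parses are masked out, argmax/argmin yield the (preferred, unpreferred) indices, and those same indices gather both the reward pair and the padded candidate responses.

import torch
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence

def default_pick_paired_rewards_fn(rewards: Tensor) -> Tensor:
    # NaN marks a reward the judge failed to produce; mask it so it can
    # never win argmax (preferred) nor argmin (unpreferred)
    is_nan_mask = torch.isnan(rewards)
    rewards_max, rewards_min = rewards.clone(), rewards.clone()
    rewards_max[is_nan_mask] = -1e6
    rewards_min[is_nan_mask] = 1e6
    return torch.stack((rewards_max.argmax(dim = -1), rewards_min.argmin(dim = -1)))

# toy rewards for four candidate responses; the third failed to parse
rewards = torch.tensor([3., 7., float('nan'), 1.])

pair_indices = default_pick_paired_rewards_fn(rewards)  # tensor([1, 3]) -> (preferred, unpreferred)
paired_rewards = rewards[pair_indices]                   # tensor([7., 1.])

# variable-length candidate responses are right padded, then the same two
# indices pick out the (preferred, unpreferred) responses in one gather
candidate_tensor_responses = [
    torch.tensor([5, 6, 7]),
    torch.tensor([8, 9]),
    torch.tensor([1]),
    torch.tensor([2, 3])
]

padded_responses = pad_sequence(candidate_tensor_responses, batch_first = True, padding_value = -1)
paired_preference_responses = padded_responses[pair_indices]  # shape (2, max_response_len)

# mirrors the updated validity check: reject the pair if the rewards tie
assert (paired_rewards[0] != paired_rewards[1]).all()

Since the picking reduces to tensor indexing, it should extend naturally to a leading batch dimension, which appears to be what the commit message means by readying for batched candidate response generation.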
setup.py: 2 changes (1 addition, 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'self-rewarding-lm-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.20',
version = '0.0.22',
license='MIT',
description = 'Self Rewarding LM - Pytorch',
author = 'Phil Wang',