Organize our collation utils #1261
Can you get around using dictionaries, if that's what the datasets return and what the models expect as input?

I guess the idea would be to have a general utility that applies a chosen padding function to a dict? The padding logic would be abstracted there.
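A minimal sketch of what such a dict-level utility could look like, assuming each sample is a dict of integer lists and the caller supplies one pad function per key. `collate_dict`, `right_pad` and `PadFn` are hypothetical names for illustration, not existing torchtune APIs; the padding itself uses PyTorch's `pad_sequence`.

```python
from typing import Callable, Dict, List, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical pad-function type: list of 1-D tensors in, 2-D padded batch out.
PadFn = Callable[[List[torch.Tensor]], torch.Tensor]


def right_pad(seqs: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
    # pad_sequence appends padding after the real values of each sequence.
    return pad_sequence(seqs, batch_first=True, padding_value=padding_value)


def collate_dict(
    batch: Sequence[Dict[str, List[int]]],
    pad_fns: Dict[str, PadFn],
) -> Dict[str, torch.Tensor]:
    """Apply a per-key pad function to a list of dataset samples (dicts)."""
    out: Dict[str, torch.Tensor] = {}
    for key, pad_fn in pad_fns.items():
        # Only keys listed in pad_fns end up in the collated batch.
        seqs = [torch.tensor(sample[key]) for sample in batch]
        out[key] = pad_fn(seqs)
    return out


batch = [{"tokens": [1, 2, 3]}, {"tokens": [4]}]
padded = collate_dict(batch, {"tokens": right_pad})
print(padded["tokens"].tolist())  # [[1, 2, 3], [4, 0, 0]]
```

The padding choice lives entirely in the `pad_fns` mapping, so the dict handling is written once and the padding strategy can vary per key.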
I'm going to bump this because I'd like to right-pad outputs from a dataset that are of the form:

```python
{'tokens': [920, 508, 10110, 13047, 5786, 1371, 12566, 592, 2750, 10110, 278, 615, ...],
 'labels': 1}
```

i.e. for a classification task. Similar to
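A minimal sketch of a collate function for samples shaped like the one above, assuming each sample carries a variable-length `tokens` list and a single integer `labels` value. `classification_collate` is a hypothetical name, not an existing torchtune utility; the right-padding is done with PyTorch's `pad_sequence`.

```python
from typing import Dict, List, Union

import torch
from torch.nn.utils.rnn import pad_sequence


def classification_collate(
    batch: List[Dict[str, Union[List[int], int]]],
    padding_idx: int = 0,
) -> Dict[str, torch.Tensor]:
    """Hypothetical collate for samples like {'tokens': [...], 'labels': 1}."""
    # Right-pad the variable-length token lists to the longest sample in the batch.
    tokens = pad_sequence(
        [torch.tensor(sample["tokens"]) for sample in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    # One integer class label per sample; int64 is what cross-entropy expects.
    labels = torch.tensor([sample["labels"] for sample in batch], dtype=torch.long)
    return {"tokens": tokens, "labels": labels}


batch = [
    {"tokens": [920, 508, 10110], "labels": 1},
    {"tokens": [13047, 5786], "labels": 0},
]
out = classification_collate(batch)
print(out["tokens"].shape)     # torch.Size([2, 3])
print(out["labels"].tolist())  # [1, 0]
```

It could be passed as `collate_fn=classification_collate` to a `torch.utils.data.DataLoader`.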
One proposal could be: we define

```python
from typing import Any, Callable, Dict, List

import torch


def left_pad_sequence(sequences: List[torch.Tensor]) -> torch.Tensor:
    ...


def right_pad_sequence(sequences: List[torch.Tensor]) -> torch.Tensor:
    ...
```

Then, we have task-specific padding functionality:

```python
def padded_collate_sft(
    batch: List[Dict[str, List[int]]],
    pad_fn: Callable[[List[Any]], torch.Tensor] = right_pad_sequence,
    padding_idx: int = 0,
    ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
) -> Dict[str, torch.Tensor]:
    ...


def padded_collate_dpo(
    ...
):
    ...
```

i.e. largely the same, but parameterize the pad-sequence function. This won't be exposed at the config level; it's just for recipe writers and hackers:

```python
import torch
from operator import itemgetter
from torch.utils.data import DataLoader
from torchtune... import right_padded_collate  # full module path elided in the original

dl = DataLoader(
    ds,
    batch_size=4,
    # Each batch becomes (right-padded tokens, tensor of integer labels).
    collate_fn=lambda batch: (
        right_padded_collate(list(map(itemgetter("tokens"), batch))),
        torch.tensor(list(map(itemgetter("labels"), batch))),
    ),
)
```

for classifier recipes, or even:

```python
from operator import itemgetter
from torch.utils.data import DataLoader
from torchtune.modules.rlhf.collate import left_padded_collate

dl = DataLoader(
    ds,
    batch_size=4,
    collate_fn=lambda batch: left_padded_collate(list(map(itemgetter("tokens"), batch))),
)
```

for PPO. I understand if people hate this, and we could just define these as regular named functions like normal people. Thoughts appreciated @RdoubleA @ebsmothers @joecummings
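The two pad helpers in the proposal are only stubs. A minimal sketch of how they might be filled in, assuming 1-D integer tensors and an added `padding_value` argument (an assumption; the proposed signatures take only the list). Left padding here uses a reverse, right-pad, flip-back trick on top of PyTorch's `pad_sequence`; this is one possible technique, not necessarily what torchtune ended up using.

```python
from typing import List

import torch
from torch.nn.utils.rnn import pad_sequence


def right_pad_sequence(sequences: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
    # Padding goes after the real tokens: [a, b, pad].
    return pad_sequence(sequences, batch_first=True, padding_value=padding_value)


def left_pad_sequence(sequences: List[torch.Tensor], padding_value: int = 0) -> torch.Tensor:
    # Reverse each sequence, right-pad, then flip the time axis back so the
    # padding ends up in front of the real tokens: [pad, a, b].
    reversed_seqs = [seq.flip(0) for seq in sequences]
    padded = pad_sequence(reversed_seqs, batch_first=True, padding_value=padding_value)
    return padded.flip(1)


seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
print(right_pad_sequence(seqs).tolist())  # [[1, 2, 3], [4, 5, 0]]
print(left_pad_sequence(seqs).tolist())   # [[1, 2, 3], [0, 4, 5]]
```

Both return a `(batch_size, max_len)` tensor, so task-specific collate functions can call either one without caring which side the padding lands on.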
Having utilities for left padding and right padding, plus separate utilities for task-specific collating, makes sense. But do you need to make the pad_fn configurable for each of these? Will users ever do left padding instead of right padding for SFT, for example, and likewise for DPO? All the collate utils should also just go under
🚫🚫 Yeah, no need to make them configurable, you're right.
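With `pad_fn` dropped, the SFT collate would simply hard-code right padding. A minimal sketch, assuming each sample has equal-length `tokens` and `labels` lists, and assuming `CROSS_ENTROPY_IGNORE_IDX = -100` (defined locally here to match the default `ignore_index` of `torch.nn.CrossEntropyLoss`). This is an illustration of the agreed shape, not torchtune's actual implementation.

```python
from typing import Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence

# Assumed value: matches the default ignore_index of torch.nn.CrossEntropyLoss.
CROSS_ENTROPY_IGNORE_IDX = -100


def padded_collate_sft(
    batch: List[Dict[str, List[int]]],
    padding_idx: int = 0,
    ignore_idx: int = CROSS_ENTROPY_IGNORE_IDX,
) -> Dict[str, torch.Tensor]:
    """Right-pad 'tokens' with padding_idx and 'labels' with ignore_idx (no pad_fn argument)."""
    tokens = pad_sequence(
        [torch.tensor(x["tokens"]) for x in batch],
        batch_first=True,
        padding_value=padding_idx,
    )
    # Padded label positions use ignore_idx so the loss skips them.
    labels = pad_sequence(
        [torch.tensor(x["labels"]) for x in batch],
        batch_first=True,
        padding_value=ignore_idx,
    )
    return {"tokens": tokens, "labels": labels}


batch = [
    {"tokens": [1, 2, 3], "labels": [2, 3, 4]},
    {"tokens": [5], "labels": [6]},
]
out = padded_collate_sft(batch)
print(out["tokens"].tolist())  # [[1, 2, 3], [5, 0, 0]]
print(out["labels"].tolist())  # [[2, 3, 4], [6, -100, -100]]
```

A DPO or PPO variant would follow the same pattern, calling a fixed left- or right-padding helper internally instead of taking one as a parameter.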
See #1005 for some context. From @ebsmothers (and @joecummings).