Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🪜 Stepwise supervision dataset type #2148

Merged
merged 20 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions examples/datasets/zen.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,74 @@ def main(test_size, push_to_hub, repo_id):
if push_to_hub:
standard_unpaired_preference_dataset.push_to_hub(repo_id, config_name="standard_unpaired_preference")

standard_step_dataset = Dataset.from_dict({
"prompt": [
"Beautiful is better than",
"Explicit is better than",
"Simple is better than",
"Complex is better than",
"Flat is better than",
"Sparse is better than",
"Readability counts",
"Special cases aren't special enough",
"Although practicality beats",
"Errors should never pass",
"In the face of ambiguity, refuse",
"There should be one-- and preferably only one --",
"Although that way may not be",
"Now is better than",
"Never is often better than",
"If the implementation is hard to explain, it's",
"If the implementation is easy to explain, it",
"Namespaces are one",
"Although practicality sometimes beats purity,",
],
"completion":[
[", let me think...", " ugly."],
[", of course,", " implicit.", " because clarity matters."],
["... let's keep it basic,", " complex."],
[" when needed,", " complicated."],
[" in terms of structure,", " nested."],
["... especially for readability."],
[" especially when others read it."],
[", unless...", " they follow the rules."],
[" some theoretical elegance,", " purity."],
[" silently,", " unless explicitly silenced."],
[" the temptation to guess."],
[" way to do it,"," but sometimes it's not obvious.", " especially when there's more than one possibility."],
[" clear at first,", " it will eventually emerge."],
[" later."],
[" problematic fixes."],
[" likely because it's too complicated."],
[" might be a good design."],
[" of those great ideas,", " that solve many problems."],
[" the code should still aim for balance."],
],
"label": [
[False, True],
[False, True, False],
[False, True],
[True, True],
[True, False],
[True],
[False],
[True, False],
[False, False],
[False, False],
[True],
[True, True, False],
[True, True],
[False],
[True], [False],
[False],
[True, True],
[False]
]
})
standard_step_dataset = standard_step_dataset.train_test_split(test_size=test_size)
if push_to_hub:
standard_step_dataset.push_to_hub(repo_id, config_name="standard_step")

conversational_language_modeling_dataset = Dataset.from_dict({
"messages": [
[{"role": "user", "content": "What is better than ugly?"}, {"role": "assistant", "content": "Beautiful."},],
Expand Down
50 changes: 50 additions & 0 deletions trl/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,49 @@ def is_conversational(example: Dict[str, Any]) -> bool:
return False


def _merge_same_role_messages(messages):
    """
    Merge consecutive messages that share the same role into a single message.

    Only adjacent messages are merged: two messages with the same role that are separated by a message with a
    different role are kept apart. The contents of merged messages are joined with a newline. The input list and
    the message dictionaries it contains are left untouched; the returned list holds new dictionaries.

    Args:
        messages (List[Dict[str, str]]): List of messages, where each message is a dictionary with keys "role" and
            "content".

    Returns:
        List[Dict[str, str]]: List of messages in which every run of consecutive same-role messages has been
            merged into a single message. An empty input yields an empty list.

    Example:

    ```python
    >>> messages = [
    ...     {"role": "user", "content": "What color is the sky?"},
    ...     {"role": "assistant", "content": "Let me think..."},
    ...     {"role": "assistant", "content": "It is blue."}
    ... ]
    >>> _merge_same_role_messages(messages)
    [{'role': 'user', 'content': 'What color is the sky?'},
     {'role': 'assistant', 'content': 'Let me think...\nIt is blue.'}]
    ```
    """
    if not messages:
        return []

    merged_messages = []

    for message in messages:
        if merged_messages and merged_messages[-1]["role"] == message["role"]:
            # Same role as the previous message: append the content to the previous (already copied) message
            merged_messages[-1]["content"] += "\n" + message["content"]
        else:
            # Start a new run. Copy the message so that the in-place concatenation above never mutates the
            # caller's dictionaries (which would make repeated calls on the same input accumulate content).
            merged_messages.append(dict(message))

    return merged_messages


def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: PreTrainedTokenizer) -> Dict[str, str]:
r"""
Apply a chat template to a conversational example.
Expand All @@ -76,9 +119,15 @@ def apply_chat_template(example: Dict[str, List[Dict[str, str]]], tokenizer: Pre
{"prompt", "chosen", "rejected"}, # preference
{"chosen", "rejected"}, # preference with implicit prompt
{"prompt", "completion", "label"}, # unpaired preference
{"prompt", "completion", "labels"}, # [Name to find]
]:
raise KeyError(f"Invalid keys in the example: {example_keys}")

# Merge neighboring messages that share the same role
for key in ["messages", "prompt", "chosen", "rejected", "completion"]:
if key in example:
example[key] = _merge_same_role_messages(example[key])

# Apply the chat template to the whole conversation
if "messages" in example:
messages = tokenizer.apply_chat_template(example["messages"], tokenize=False)
Expand Down Expand Up @@ -155,6 +204,7 @@ def maybe_apply_chat_template(
- Preference dataset: `"prompt"`, `"chosen"`, and `"rejected"`.
- Preference dataset with implicit prompt: `"chosen"` and `"rejected"`.
- Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`.
- Stepwise supervision dataset: `"prompt"`, `"completion"`, and `"labels"`.
qgallouedec marked this conversation as resolved.
Show resolved Hide resolved

For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
messages, where each message is a dictionary with keys `"role"` and `"content"`.
Expand Down
Loading