-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add function call #25
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,7 @@ | |
import torch | ||
from loguru import logger | ||
from torch.utils.data import Dataset | ||
from .utils.tool_utils import tool_formater, function_formatter | ||
|
||
|
||
class SFTDataset(Dataset): | ||
|
@@ -12,6 +13,9 @@ def __init__(self, file, tokenizer, max_seq_length, template): | |
self.system_format = template["system_format"] | ||
self.user_format = template["user_format"] | ||
self.assistant_format = template["assistant_format"] | ||
self.tool_format = template["tool_format"] | ||
self.function_format = template["function_format"] | ||
self.observation_format = template["observation_format"] | ||
Comment on lines
+16
to
+18
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling for missing keys in Accessing Apply this diff to safely access template keys with default values: -self.tool_format = template["tool_format"]
-self.function_format = template["function_format"]
-self.observation_format = template["observation_format"]
+self.tool_format = template.get("tool_format", default_tool_format)
+self.function_format = template.get("function_format", default_function_format)
+self.observation_format = template.get("observation_format", default_observation_format) Make sure to define appropriate default values for
|
||
|
||
self.max_seq_length = max_seq_length | ||
logger.info("Loading data: {}".format(file)) | ||
|
@@ -39,27 +43,33 @@ def __getitem__(self, index): | |
|
||
conversations = data["conversations"] | ||
|
||
for i in range(0, len(conversations) - 1, 2): | ||
if ( | ||
conversations[i]["role"] != "user" | ||
or conversations[i + 1]["role"] != "assistant" | ||
): | ||
raise ValueError("The role order of the conversation is not correct") | ||
human = conversations[i]["content"].strip() | ||
assistant = conversations[i + 1]["content"].strip() | ||
|
||
human = self.user_format.format( | ||
content=human, stop_token=self.tokenizer.eos_token | ||
) | ||
assistant = self.assistant_format.format( | ||
content=assistant, stop_token=self.tokenizer.eos_token | ||
) | ||
|
||
input_tokens = self.tokenizer.encode(human, add_special_tokens=False) | ||
output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False) | ||
|
||
input_ids += input_tokens + output_tokens | ||
target_mask += [0] * len(input_tokens) + [1] * len(output_tokens) | ||
input_buffer = "" | ||
for i in range(len(conversations)): | ||
role = conversations[i]["role"] | ||
content = conversations[i]["content"].strip() | ||
|
||
if role != "assistant": | ||
if role == "user": | ||
human = self.user_format.format(content=content, stop_token=self.tokenizer.eos_token) | ||
input_buffer += human | ||
|
||
elif role == "function_call": | ||
tool_calls = function_formatter(json.loads(content)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add exception handling for JSON parsing in Parsing JSON content without error handling can lead to crashes if the content is invalid. To enhance robustness, add a try-except block to handle Apply this diff to add error handling: - tool_calls = function_formatter(json.loads(content))
+ try:
+ tool_calls = function_formatter(json.loads(content))
+ except json.JSONDecodeError as e:
+ logger.error(f"Invalid JSON in 'function_call' content at index {i}: {e}")
+ # Handle the error appropriately, e.g., skip this entry or provide a fallback
+ continue # or handle as needed
|
||
function = self.function_format.format(content=tool_calls) | ||
input_buffer += function | ||
|
||
elif role == "observation": | ||
observation = self.observation_format.format(content=content) | ||
input_buffer += observation | ||
else: | ||
Comment on lines
+51
to
+64
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Handle unexpected conversation roles to prevent missing content. The current implementation processes specific roles: Apply this diff to handle unexpected roles: elif role == "observation":
observation = self.observation_format.format(content=content)
input_buffer += observation
+else:
+ logger.warning(f"Unhandled role '{role}' at index {i}. Content may be skipped.")
|
||
assistant = self.assistant_format.format(content=content, stop_token=self.tokenizer.eos_token) | ||
|
||
input_tokens = self.tokenizer.encode(input_buffer, add_special_tokens=False) | ||
output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False) | ||
|
||
input_ids += input_tokens + output_tokens | ||
target_mask += [0] * len(input_tokens) + [1] * len(output_tokens) | ||
input_buffer = "" | ||
|
||
assert len(input_ids) == len(target_mask) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused import
tool_formater
.The function
tool_formater
is imported but not used in this file. Removing it will clean up the code and prevent confusion.Apply this diff to remove the unused import:
📝 Committable suggestion
🧰 Tools
🪛 Ruff
7-7:
.utils.tool_utils.tool_formater
imported but unusedRemove unused import:
.utils.tool_utils.tool_formater
(F401)