feat: add function call #25

Open · wants to merge 2 commits into base: main

Changes from 1 commit
52 changes: 31 additions & 21 deletions dataset.py
@@ -4,6 +4,7 @@
import torch
from loguru import logger
from torch.utils.data import Dataset
from .utils.tool_utils import tool_formater, function_formatter

⚠️ Potential issue

Remove unused import tool_formater.

The function tool_formater is imported but not used in this file. Removing it will clean up the code and prevent confusion.

Apply this diff to remove the unused import:

-from .utils.tool_utils import tool_formater, function_formatter
+from .utils.tool_utils import function_formatter
🧰 Tools
🪛 Ruff

7-7: .utils.tool_utils.tool_formater imported but unused

Remove unused import: .utils.tool_utils.tool_formater

(F401)



class SFTDataset(Dataset):
@@ -12,6 +13,9 @@ def __init__(self, file, tokenizer, max_seq_length, template):
self.system_format = template["system_format"]
self.user_format = template["user_format"]
self.assistant_format = template["assistant_format"]
self.tool_format = template["tool_format"]
self.function_format = template["function_format"]
self.observation_format = template["observation_format"]
Comment on lines +16 to +18

⚠️ Potential issue

Add error handling for missing keys in template.

Accessing template["tool_format"], template["function_format"], and template["observation_format"] without checking if these keys exist may raise a KeyError if any of them are missing. Consider using dict.get() with default values or adding error handling to ensure robustness.

Apply this diff to safely access template keys with default values:

-self.tool_format = template["tool_format"]
-self.function_format = template["function_format"]
-self.observation_format = template["observation_format"]
+self.tool_format = template.get("tool_format", default_tool_format)
+self.function_format = template.get("function_format", default_function_format)
+self.observation_format = template.get("observation_format", default_observation_format)

Make sure to define appropriate default values for default_tool_format, default_function_format, and default_observation_format.

Committable suggestion was skipped due to low confidence.
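
A possible shape for those defaults, assuming the same {content} placeholder convention used by the other formats in this template — the strings below are illustrative, not taken from this repository:

# Hypothetical fallback templates; adjust to the chat template actually in use.
default_tool_format = "{content}\n"
default_function_format = "{content}\n"
default_observation_format = "{content}\n"

self.tool_format = template.get("tool_format", default_tool_format)
self.function_format = template.get("function_format", default_function_format)
self.observation_format = template.get("observation_format", default_observation_format)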


self.max_seq_length = max_seq_length
logger.info("Loading data: {}".format(file))
@@ -39,27 +43,33 @@ def __getitem__(self, index):

conversations = data["conversations"]

for i in range(0, len(conversations) - 1, 2):
    if (
        conversations[i]["role"] != "user"
        or conversations[i + 1]["role"] != "assistant"
    ):
        raise ValueError("The role order of the conversation is not correct")
    human = conversations[i]["content"].strip()
    assistant = conversations[i + 1]["content"].strip()

    human = self.user_format.format(
        content=human, stop_token=self.tokenizer.eos_token
    )
    assistant = self.assistant_format.format(
        content=assistant, stop_token=self.tokenizer.eos_token
    )

    input_tokens = self.tokenizer.encode(human, add_special_tokens=False)
    output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False)

    input_ids += input_tokens + output_tokens
    target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)
input_buffer = ""
for i in range(len(conversations)):
role = conversations[i]["role"]
content = conversations[i]["content"].strip()

if role != "assistant":
if role == "user":
human = self.user_format.format(content=content, stop_token=self.tokenizer.eos_token)
input_buffer += human

elif role == "function_call":
tool_calls = function_formatter(json.loads(content))

⚠️ Potential issue

Add exception handling for JSON parsing in 'function_call' role.

Parsing JSON content without error handling can lead to crashes if the content is invalid. To enhance robustness, add a try-except block to handle json.JSONDecodeError.

Apply this diff to add error handling:

-    tool_calls = function_formatter(json.loads(content))
+    try:
+        tool_calls = function_formatter(json.loads(content))
+    except json.JSONDecodeError as e:
+        logger.error(f"Invalid JSON in 'function_call' content at index {i}: {e}")
+        # Handle the error appropriately, e.g., skip this entry or provide a fallback
+        continue  # or handle as needed

Committable suggestion was skipped due to low confidence.
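
For reference, a minimal self-contained sketch of the guarded parse. The sample payload below and the fallback behavior are assumptions for illustration, not taken from this diff:

import json
from loguru import logger

# Hypothetical 'function_call' content; real entries come from the dataset file.
content = '{"name": "get_weather", "arguments": {"city": "Paris"}}'
try:
    parsed = json.loads(content)  # raises json.JSONDecodeError on malformed input
except json.JSONDecodeError as e:
    logger.error(f"Invalid JSON in 'function_call' content: {e}")
    parsed = None  # skip the entry or substitute a fallback, as the surrounding loop requires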

            function = self.function_format.format(content=tool_calls)
            input_buffer += function

        elif role == "observation":
            observation = self.observation_format.format(content=content)
            input_buffer += observation
    else:
Comment on lines +51 to +64

⚠️ Potential issue

Handle unexpected conversation roles to prevent missing content.

The current implementation processes specific roles: 'user', 'function_call', and 'observation'. If a conversation entry has an unexpected role, it will be ignored, potentially leading to incomplete data processing. Consider adding an else clause to log or handle unexpected roles.

Apply this diff to handle unexpected roles:

elif role == "observation":
    observation = self.observation_format.format(content=content)
    input_buffer += observation
+else:
+    logger.warning(f"Unhandled role '{role}' at index {i}. Content may be skipped.")

Committable suggestion was skipped due to low confidence.

        assistant = self.assistant_format.format(content=content, stop_token=self.tokenizer.eos_token)

        input_tokens = self.tokenizer.encode(input_buffer, add_special_tokens=False)
        output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False)

        input_ids += input_tokens + output_tokens
        target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)
        input_buffer = ""

assert len(input_ids) == len(target_mask)
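
For context on why this length check matters: downstream, a target_mask like this is typically converted into labels so that only assistant tokens contribute to the loss. A minimal sketch of that assumed usage (the collator itself is not part of this diff):

# Positions with mask 0 (user, function_call, and observation tokens) are ignored;
# positions with mask 1 (assistant tokens) are trained on.
# -100 is the ignore_index of torch.nn.CrossEntropyLoss.
labels = [tok if m == 1 else -100 for tok, m in zip(input_ids, target_mask)]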
