feat: add function call #25
base: main
Walkthrough
The changes involve modifications to the
Changes
Sequence Diagram(s)
sequenceDiagram
participant User
participant Assistant
participant SFTDataset
participant ToolFormatter
participant FunctionFormatter
User->>SFTDataset: Request conversation
SFTDataset->>User: Process user input
SFTDataset->>FunctionFormatter: Format function calls
FunctionFormatter-->>SFTDataset: Return formatted function calls
SFTDataset->>ToolFormatter: Format tool descriptions
ToolFormatter-->>SFTDataset: Return formatted tool descriptions
SFTDataset->>Assistant: Provide formatted conversation
Assistant-->>User: Respond with output
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (2)
utils/tool_utils.py (2)
16-17: Remove unnecessary empty lines.
Consider removing one of the empty lines to maintain consistent spacing between sections.
1-65: Consider architectural improvements for better maintainability.
- Consider creating a dedicated exception class (e.g., ToolFormattingError) for better error handling and distinction between different types of failures.
- Add comprehensive docstrings with examples showing the expected format of tool and function call dictionaries.
- Consider extracting the template strings into a separate configuration file for easier maintenance.
- Ensure unit tests cover various edge cases, especially for the complex nested dictionary processing in tool_formatter.
Would you like me to help with implementing any of these suggestions or creating unit tests?
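As a rough illustration of the first suggestion, a dedicated exception type could look like the sketch below. ToolFormattingError is the name proposed in the comment; the require_keys helper and its signature are hypothetical and not part of this PR.

```python
from typing import Any, Dict, Tuple


class ToolFormattingError(ValueError):
    """Raised when a tool or tool-call dictionary is malformed (name taken from the review)."""


def require_keys(item: Any, keys: Tuple[str, ...], kind: str) -> Dict[str, Any]:
    """Hypothetical helper: return item unchanged if it is a dict with all keys, else raise."""
    if not isinstance(item, dict):
        raise ToolFormattingError(f"{kind} must be a dict, got {type(item).__name__}")
    missing = [key for key in keys if key not in item]
    if missing:
        raise ToolFormattingError(f"{kind} is missing keys: {', '.join(missing)}")
    return item
```

Subclassing ValueError means existing callers that catch ValueError keep working, while new callers can catch the narrower type.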
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- dataset.py (3 hunks)
- utils/constants.py (1 hunks)
- utils/tool_utils.py (1 hunks)
🔥 Files not summarized due to errors (1)
- utils/constants.py: Error: Disallowed special token found: <|im_start|>
🧰 Additional context used
🪛 Ruff
dataset.py
7-7: .utils.tool_utils.tool_formater imported but unused
Remove unused import: .utils.tool_utils.tool_formater (F401)
def tool_formater(tools: List[Dict[str, Any]]) -> str:
    tool_text = ""
    tool_names = []
    for tool in tools:
        param_text = ""
        for name, param in tool["parameters"]["properties"].items():
            required, enum, items = "", "", ""
            if name in tool["parameters"].get("required", []):
                required = ", required"

            if param.get("enum", None):
                enum = ", should be one of [{}]".format(", ".join(param["enum"]))

            if param.get("items", None):
                items = ", where each item should be {}".format(param["items"].get("type", ""))

            param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                name=name,
                type=param.get("type", ""),
                required=required,
                desc=param.get("description", ""),
                enum=enum,
                items=items,
            )

        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
            name=tool["name"], desc=tool.get("description", ""), args=param_text
        )
        tool_names.append(tool["name"])

    return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
Fix typo in function name and enhance input validation.
- The function name should be tool_formatter for consistency.
- Add input validation for required fields.
- Add type hints for internal variables.
Here's the suggested implementation:
-def tool_formater(tools: List[Dict[str, Any]]) -> str:
+def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+    """Format a list of tools into a structured string description.
+
+    Args:
+        tools: List of tool dictionaries containing name, description, and parameters.
+
+    Returns:
+        Formatted string describing the tools.
+
+    Raises:
+        ValueError: If required fields are missing in the tool dictionary.
+    """
+    if not tools:
+        raise ValueError("Tools list cannot be empty")
+
     tool_text: str = ""
-    tool_names = []
+    tool_names: List[str] = []
     for tool in tools:
+        if not isinstance(tool, dict) or "name" not in tool or "parameters" not in tool:
+            raise ValueError("Each tool must be a dictionary with 'name' and 'parameters' keys")
+
         param_text = ""
         for name, param in tool["parameters"]["properties"].items():
Committable suggestion was skipped due to low confidence.
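For reference, a self-contained version of the validated formatter might look like the sketch below. DEFAULT_TOOL_PROMPT is defined in utils/constants.py and is not shown in this review, so the template text here is a placeholder assumption. Two small deviations from the quoted code are also assumptions: enum values are passed through str() so non-string enums do not break the join, and a missing "properties" key is treated as an empty mapping.

```python
from typing import Any, Dict, List

# Placeholder only: the real DEFAULT_TOOL_PROMPT lives in utils/constants.py
# and is not visible in this review.
DEFAULT_TOOL_PROMPT = "Available tools:\n{tool_text}Use one of [{tool_names}].\n"


def tool_formatter(tools: List[Dict[str, Any]]) -> str:
    """Format tool dictionaries into a prompt string, validating the input first."""
    if not tools:
        raise ValueError("Tools list cannot be empty")

    tool_text = ""
    tool_names: List[str] = []
    for tool in tools:
        if not isinstance(tool, dict) or "name" not in tool or "parameters" not in tool:
            raise ValueError("Each tool must be a dictionary with 'name' and 'parameters' keys")

        parameters = tool["parameters"]
        required_names = parameters.get("required", [])
        param_text = ""
        # Assumption: a tool without "properties" simply has no arguments.
        for name, param in parameters.get("properties", {}).items():
            required = ", required" if name in required_names else ""
            enum = ""
            if param.get("enum"):
                # str() guards against numeric enum values, which str.join rejects.
                enum = ", should be one of [{}]".format(", ".join(map(str, param["enum"])))
            items = ""
            if param.get("items"):
                items = ", where each item should be {}".format(param["items"].get("type", ""))
            param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                name=name,
                type=param.get("type", ""),
                required=required,
                desc=param.get("description", ""),
                enum=enum,
                items=items,
            )

        tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
            name=tool["name"], desc=tool.get("description", ""), args=param_text
        )
        tool_names.append(tool["name"])

    return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
```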
def function_formatter(tool_calls, function_slots=DEFAULT_FUNCTION_SLOTS) -> str:
    functions : List[Tuple[str, str]] = []
    if not isinstance(tool_calls, list):
        tool_calls = [tool_calls] # parrallel function calls

    for tool_call in tool_calls:
        functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))

    elements = []
    for name, arguments in functions:
        text = function_slots.format(name=name, arguments=arguments)
        elements.append(text)

    return "\n".join(elements)+"\n"
💡 Codebase verification
Input validation needed, but ensure_ascii=False can be retained
The function needs improvements for robustness, but the ensure_ascii=False setting should be kept, as it is used in a dataset-loading context where preserving potential Unicode content is important:
 def function_formatter(tool_calls, function_slots=DEFAULT_FUNCTION_SLOTS) -> str:
+    """Format tool calls into a string representation.
+
+    Args:
+        tool_calls: Single tool call or list of tool calls.
+        function_slots: Template for formatting function calls.
+
+    Returns:
+        Formatted string of function calls.
+
+    Raises:
+        ValueError: If tool_calls are malformed.
+        TypeError: If arguments cannot be serialized to JSON.
+    """
     functions : List[Tuple[str, str]] = []
     if not isinstance(tool_calls, list):
-        tool_calls = [tool_calls] # parrallel function calls
+        tool_calls = [tool_calls] # parallel function calls
     for tool_call in tool_calls:
-        functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+        if not isinstance(tool_call, dict) or "name" not in tool_call or "arguments" not in tool_call:
+            raise ValueError("Each tool call must be a dictionary with 'name' and 'arguments' keys")
+        try:
+            arguments_json = json.dumps(tool_call["arguments"], ensure_ascii=False)
+            functions.append((tool_call["name"], arguments_json))
+        except (TypeError, ValueError) as e:
+            raise TypeError(f"Failed to serialize arguments for tool {tool_call['name']}: {str(e)}")
The function is used in dataset loading (dataset.py), where it processes JSON content, so we should keep ensure_ascii=False to preserve any potential Unicode characters in the training data. The other suggested improvements for input validation and error handling are still valuable additions.
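To make the trade-off concrete, the short sketch below compares the two settings of json.dumps on a made-up arguments dictionary (the sample values are illustrative, not taken from the dataset). Both outputs decode to the same object; only the serialized text differs, which matters when that text becomes training data.

```python
import json

# Illustrative tool-call arguments containing non-ASCII text.
arguments = {"city": "北京", "note": "café"}

# Default behaviour (ensure_ascii=True): non-ASCII characters become \uXXXX escapes.
escaped = json.dumps(arguments)

# ensure_ascii=False: characters are written as-is, so the text stays readable.
readable = json.dumps(arguments, ensure_ascii=False)

print(escaped)
print(readable)

# Round-tripping either form yields the original dictionary.
assert json.loads(escaped) == json.loads(readable) == arguments
```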
🔗 Analysis chain
Add input validation and error handling.
The function needs several improvements for robustness and security:
- Fix typo in comment ("parrallel" → "parallel")
- Add input validation
- Add error handling for JSON serialization
- Consider security implications of ensure_ascii=False
Here's the suggested implementation:
 def function_formatter(tool_calls, function_slots=DEFAULT_FUNCTION_SLOTS) -> str:
+    """Format tool calls into a string representation.
+
+    Args:
+        tool_calls: Single tool call or list of tool calls.
+        function_slots: Template for formatting function calls.
+
+    Returns:
+        Formatted string of function calls.
+
+    Raises:
+        ValueError: If tool_calls are malformed.
+        TypeError: If arguments cannot be serialized to JSON.
+    """
     functions : List[Tuple[str, str]] = []
     if not isinstance(tool_calls, list):
-        tool_calls = [tool_calls] # parrallel function calls
+        tool_calls = [tool_calls] # parallel function calls
     for tool_call in tool_calls:
-        functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
+        if not isinstance(tool_call, dict) or "name" not in tool_call or "arguments" not in tool_call:
+            raise ValueError("Each tool call must be a dictionary with 'name' and 'arguments' keys")
+        try:
+            arguments_json = json.dumps(tool_call["arguments"])
+            functions.append((tool_call["name"], arguments_json))
+        except (TypeError, ValueError) as e:
+            raise TypeError(f"Failed to serialize arguments for tool {tool_call['name']}: {str(e)}")
Consider using ensure_ascii=True in json.dumps() for better security unless there's a specific requirement for non-ASCII characters.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check if there are any non-ASCII characters in the codebase that would require ensure_ascii=False
rg -l "[^\x00-\x7F]"
Length of output: 22
Script:
#!/bin/bash
# Let's check how this function is used in the codebase to understand the requirements
rg -A 5 "function_formatter"
# Also check for any test files that might reveal expected behavior
fd "test.*\.py" --exec grep -l "function_formatter" {}
Length of output: 1143
Script:
#!/bin/bash
# Let's check the actual content being processed to understand if non-ASCII support is needed
rg -B 2 -A 2 "json\.loads\(content\)" dataset.py
# Also check the DEFAULT_FUNCTION_SLOTS format
rg "DEFAULT_FUNCTION_SLOTS\s*=\s*" utils/tool_utils.py
Length of output: 438
@@ -4,6 +4,7 @@
 import torch
 from loguru import logger
 from torch.utils.data import Dataset
+from .utils.tool_utils import tool_formater, function_formatter
Remove unused import tool_formater.
The function tool_formater is imported but not used in this file. Removing it will clean up the code and prevent confusion.
Apply this diff to remove the unused import:
-from .utils.tool_utils import tool_formater, function_formatter
+from .utils.tool_utils import function_formatter
self.tool_format = template["tool_format"]
self.function_format = template["function_format"]
self.observation_format = template["observation_format"]
Add error handling for missing keys in template.
Accessing template["tool_format"], template["function_format"], and template["observation_format"] without checking if these keys exist may raise a KeyError if any of them are missing. Consider using dict.get() with default values or adding error handling to ensure robustness.
Apply this diff to safely access template keys with default values:
-self.tool_format = template["tool_format"]
-self.function_format = template["function_format"]
-self.observation_format = template["observation_format"]
+self.tool_format = template.get("tool_format", default_tool_format)
+self.function_format = template.get("function_format", default_function_format)
+self.observation_format = template.get("observation_format", default_observation_format)
Make sure to define appropriate default values for default_tool_format, default_function_format, and default_observation_format.
Committable suggestion was skipped due to low confidence.
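One possible shape for those defaults is sketched below. The fallback strings and the resolve_formats helper are hypothetical; the project's real format templates are not shown in this review.

```python
from typing import Dict

# Hypothetical fallbacks: each one passes the content through unchanged.
DEFAULT_FORMATS: Dict[str, str] = {
    "tool_format": "{content}",
    "function_format": "{content}",
    "observation_format": "{content}",
}


def resolve_formats(template: Dict[str, str]) -> Dict[str, str]:
    """Return the three function-calling formats, using a default for each missing key."""
    return {key: template.get(key, default) for key, default in DEFAULT_FORMATS.items()}


formats = resolve_formats({"function_format": "<call>{content}</call>"})
print(formats["function_format"].format(content="get_weather"))
```

Silent defaults can hide a misconfigured template, so raising a clear error that names the missing key is a reasonable alternative when the formats are mandatory for function-calling data.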
if role != "assistant":
    if role == "user":
        human = self.user_format.format(content=content, stop_token=self.tokenizer.eos_token)
        input_buffer += human

    elif role == "function_call":
        tool_calls = function_formatter(json.loads(content))
        function = self.function_format.format(content=tool_calls)
        input_buffer += function

    elif role == "observation":
        observation = self.observation_format.format(content=content)
        input_buffer += observation
else:
Handle unexpected conversation roles to prevent missing content.
The current implementation processes specific roles: 'user', 'function_call', and 'observation'. If a conversation entry has an unexpected role, it will be ignored, potentially leading to incomplete data processing. Consider adding an else clause to log or handle unexpected roles.
Apply this diff to handle unexpected roles:
     elif role == "observation":
         observation = self.observation_format.format(content=content)
         input_buffer += observation
+    else:
+        logger.warning(f"Unhandled role '{role}' at index {i}. Content may be skipped.")
Committable suggestion was skipped due to low confidence.
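The same idea can be exercised outside the dataset class. The sketch below is a standalone illustration, not the PR's code: split_known_roles and KNOWN_ROLES are hypothetical names, the role set is inferred from the branches quoted above, and the standard logging module stands in for loguru to keep the example dependency-free.

```python
import logging
from typing import Any, Dict, List, Tuple

logger = logging.getLogger(__name__)  # the PR itself uses loguru

# Roles with a branch in the quoted dataset.py code (assumed to be the full set).
KNOWN_ROLES = {"user", "function_call", "observation", "assistant"}


def split_known_roles(
    conversations: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], List[int]]:
    """Return (turns with a known role, indices of turns that were skipped)."""
    kept: List[Dict[str, Any]] = []
    skipped: List[int] = []
    for i, turn in enumerate(conversations):
        role = turn.get("role")
        if role in KNOWN_ROLES:
            kept.append(turn)
        else:
            logger.warning("Unhandled role %r at index %d. Content may be skipped.", role, i)
            skipped.append(i)
    return kept, skipped
```

Returning the skipped indices lets a caller decide whether to drop the whole sample instead of training on a conversation with a hole in it.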
        input_buffer += human

    elif role == "function_call":
        tool_calls = function_formatter(json.loads(content))
Add exception handling for JSON parsing in 'function_call' role.
Parsing JSON content without error handling can lead to crashes if the content is invalid. To enhance robustness, add a try-except block to handle json.JSONDecodeError.
Apply this diff to add error handling:
-        tool_calls = function_formatter(json.loads(content))
+        try:
+            tool_calls = function_formatter(json.loads(content))
+        except json.JSONDecodeError as e:
+            logger.error(f"Invalid JSON in 'function_call' content at index {i}: {e}")
+            # Handle the error appropriately, e.g., skip this entry or provide a fallback
+            continue # or handle as needed
Committable suggestion was skipped due to low confidence.
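Pulled out into a helper, the guarded parse could look like the sketch below. parse_function_call is a hypothetical name, and the standard logging module is used in place of loguru so the example runs on its own.

```python
import json
import logging
from typing import Any, Optional

logger = logging.getLogger(__name__)  # the PR itself uses loguru


def parse_function_call(content: str, index: int) -> Optional[Any]:
    """Parse a 'function_call' payload, returning None (and logging) on invalid JSON."""
    try:
        return json.loads(content)
    except json.JSONDecodeError as exc:
        logger.error("Invalid JSON in 'function_call' content at index %d: %s", index, exc)
        return None


call = parse_function_call('{"name": "get_weather", "arguments": {"city": "Paris"}}', 0)
broken = parse_function_call('{"name": "get_weather", ', 1)
```

A literal JSON null also parses to None, so a caller that must tell "invalid" apart from "null" would need a sentinel or a re-raised exception instead.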
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (2)
data/test_data_loading.py (2)
12-12: Add type hints and docstring
The test function should be properly documented with its purpose and expected behavior.
-def test_data_loading(load_data):
+def test_data_loading(load_data: list[str]) -> None:
+    """Verify that the conversation data file contains valid JSON structures.
+
+    Tests that:
+    - Data is loaded successfully
+    - Each line is valid JSON
+    - Each entry has the required conversation structure
+    - Conversations include both user and assistant roles
+
+    Args:
+        load_data: List of JSON lines from the test data file
+    """
12-39: Consider adding parameterized test cases
The test could benefit from explicit test cases for different scenarios.
Consider adding these test cases:
@pytest.mark.parametrize("test_data", [
    # Valid case
    '[{"conversations": [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}]}]',
    # Invalid cases to test error handling
    '[{"conversations": []}]',  # Empty conversations
    '[{"conversations": [{"role": "invalid", "content": "test"}]}]',  # Invalid role
    '[{"conversations": [{"role": "user"}]}]',  # Missing content
])
def test_data_structure(test_data: str) -> None:
    """Test various conversation data structures."""
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
- data/test_data_loading.py (1 hunks)
🔇 Additional comments (1)
data/test_data_loading.py (1)
1-39: Verify test data file existence in repository
Need to ensure the test data file exists and contains appropriate test data.
✅ Verification successful
Test data file exists and contains valid data
The verification confirms that:
- The test data file exists at data/function_calling_demo.jsonl
- The file contains valid JSON data with the expected structure:
  - Has "conversations" array with required "user" and "assistant" roles
  - JSON format is valid and matches the test's expectations
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Check if the test data file exists and validate its content
echo "Checking for test data file..."
fd -t f "function_calling_demo.jsonl"
echo "Checking file content structure..."
if [ -f "data/function_calling_demo.jsonl" ]; then
head -n 1 data/function_calling_demo.jsonl
fi
Length of output: 2584
@pytest.fixture
def load_data():
    with open("function_calling_demo.jsonl", "r", encoding="utf8") as f:
        data_list = f.readlines()
    return data_list
🛠️ Refactor suggestion
Improve fixture robustness and documentation
The fixture needs improvements in file path handling and documentation:
+import os
+from pathlib import Path
 @pytest.fixture
 def load_data():
+    """Load test data from function_calling_demo.jsonl file.
+
+    Returns:
+        list[str]: List of JSON lines from the test data file.
+    """
+    data_path = Path(__file__).parent / "function_calling_demo.jsonl"
+    if not data_path.exists():
+        raise FileNotFoundError(f"Test data file not found: {data_path}")
+
-    with open("function_calling_demo.jsonl", "r", encoding="utf8") as f:
+    with open(data_path, "r", encoding="utf8") as f:
         data_list = f.readlines()
     return data_list
    for line in load_data:
        try:
            data = json.loads(line)
            assert isinstance(
                data, dict
            ), "Each line should be a JSON object (dictionary)"

            # Check for "conversations" key
            assert (
                "conversations" in data
            ), "'conversations' key is missing in the JSON object"
            assert isinstance(
                data["conversations"], list
            ), "'conversations' should be a list"

            # Check for 'user' and 'assistant' roles within "conversations"
            roles = {conv["role"] for conv in data["conversations"] if "role" in conv}
            assert "user" in roles, "Role 'user' is missing in conversations"
            assert "assistant" in roles, "Role 'assistant' is missing in conversations"

        except json.JSONDecodeError:
            pytest.fail("Each line in the data file should be valid JSON")
🛠️ Refactor suggestion
Enhance test coverage and error messages
The test could be more comprehensive in validating conversation structure and provide better error messages.
     # Verify each line is valid JSON with required fields
-    for line in load_data:
+    for i, line in enumerate(load_data, 1):
         try:
             data = json.loads(line)
             assert isinstance(
                 data, dict
-            ), "Each line should be a JSON object (dictionary)"
+            ), f"Line {i}: Expected JSON object, got {type(data).__name__}"
             # Check for "conversations" key
             assert (
                 "conversations" in data
-            ), "'conversations' key is missing in the JSON object"
+            ), f"Line {i}: Missing required 'conversations' key"
             assert isinstance(
                 data["conversations"], list
-            ), "'conversations' should be a list"
+            ), f"Line {i}: 'conversations' must be a list, got {type(data['conversations']).__name__}"
             # Check for 'user' and 'assistant' roles within "conversations"
             roles = {conv["role"] for conv in data["conversations"] if "role" in conv}
-            assert "user" in roles, "Role 'user' is missing in conversations"
-            assert "assistant" in roles, "Role 'assistant' is missing in conversations"
+            assert "user" in roles, f"Line {i}: Required role 'user' is missing in conversations"
+            assert "assistant" in roles, f"Line {i}: Required role 'assistant' is missing in conversations"
+
+            # Validate conversation structure
+            for j, conv in enumerate(data["conversations"]):
+                assert "role" in conv, f"Line {i}, conversation {j}: Missing 'role' field"
+                assert "content" in conv, f"Line {i}, conversation {j}: Missing 'content' field"
+                assert isinstance(conv["content"], str), f"Line {i}, conversation {j}: 'content' must be string"
+                assert conv["role"] in {"user", "assistant"}, f"Line {i}, conversation {j}: Invalid role '{conv['role']}'"
         except json.JSONDecodeError:
-            pytest.fail("Each line in the data file should be valid JSON")
+            pytest.fail(f"Line {i}: Invalid JSON format")
+        except Exception as e:
+            pytest.fail(f"Line {i}: Unexpected error: {str(e)}")
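A variation worth considering is sketched below, under two assumptions. First, the allowed role set is widened to include function_call and observation, since dataset.py in this PR has branches for both and a check limited to user and assistant would likely reject function-calling samples. Second, problems are collected and returned rather than asserted inside a broad except Exception block, which would also catch AssertionError and relabel it as an unexpected error. validate_line and ALLOWED_ROLES are hypothetical names.

```python
import json
from typing import List

# Assumed role set, inferred from the branches in dataset.py shown in this review.
ALLOWED_ROLES = {"user", "assistant", "function_call", "observation"}


def validate_line(line: str) -> List[str]:
    """Return the problems found in one JSONL line; an empty list means the line is valid."""
    try:
        data = json.loads(line)
    except json.JSONDecodeError:
        return ["invalid JSON"]
    if not isinstance(data, dict) or not isinstance(data.get("conversations"), list):
        return ["missing 'conversations' list"]

    problems: List[str] = []
    roles = set()
    for j, turn in enumerate(data["conversations"]):
        if not isinstance(turn, dict):
            problems.append(f"turn {j}: expected an object")
            continue
        role = turn.get("role")
        if role in ALLOWED_ROLES:
            roles.add(role)
        else:
            problems.append(f"turn {j}: invalid role {role!r}")
        if not isinstance(turn.get("content"), str):
            problems.append(f"turn {j}: missing or non-string 'content'")

    # Every sample still needs at least one user turn and one assistant turn.
    for required in ("user", "assistant"):
        if required not in roles:
            problems.append(f"role '{required}' is missing")
    return problems
```

In the pytest test this could be used as `assert not problems, f"Line {i}: {problems}"`, which reports every issue on a line at once.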