Skip to content

Commit

Permalink
feat: add function call
Browse files: browse the repository at this point in the history
  • Loading branch information
astridesa committed Oct 29, 2024
1 parent 67a41e4 commit cd1eb8f
Show file tree
Hide file tree
Showing 4 changed files with 402 additions and 21 deletions.
52 changes: 31 additions & 21 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from loguru import logger
from torch.utils.data import Dataset
from .utils.tool_utils import tool_formater, function_formatter


class SFTDataset(Dataset):
Expand All @@ -12,6 +13,9 @@ def __init__(self, file, tokenizer, max_seq_length, template):
self.system_format = template["system_format"]
self.user_format = template["user_format"]
self.assistant_format = template["assistant_format"]
self.tool_format = template["tool_format"]
self.function_format = template["function_format"]
self.observation_format = template["observation_format"]

self.max_seq_length = max_seq_length
logger.info("Loading data: {}".format(file))
Expand Down Expand Up @@ -39,27 +43,33 @@ def __getitem__(self, index):

conversations = data["conversations"]

for i in range(0, len(conversations) - 1, 2):
if (
conversations[i]["role"] != "user"
or conversations[i + 1]["role"] != "assistant"
):
raise ValueError("The role order of the conversation is not correct")
human = conversations[i]["content"].strip()
assistant = conversations[i + 1]["content"].strip()

human = self.user_format.format(
content=human, stop_token=self.tokenizer.eos_token
)
assistant = self.assistant_format.format(
content=assistant, stop_token=self.tokenizer.eos_token
)

input_tokens = self.tokenizer.encode(human, add_special_tokens=False)
output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False)

input_ids += input_tokens + output_tokens
target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)
input_buffer = ""
for i in range(len(conversations)):
role = conversations[i]["role"]
content = conversations[i]["content"].strip()

if role != "assistant":
if role == "user":
human = self.user_format.format(content=content, stop_token=self.tokenizer.eos_token)
input_buffer += human

elif role == "function_call":
tool_calls = function_formatter(json.loads(content))
function = self.function_format.format(content=tool_calls)
input_buffer += function

elif role == "observation":
observation = self.observation_format.format(content=content)
input_buffer += observation
else:
assistant = self.assistant_format.format(content=content, stop_token=self.tokenizer.eos_token)

input_tokens = self.tokenizer.encode(input_buffer, add_special_tokens=False)
output_tokens = self.tokenizer.encode(assistant, add_special_tokens=False)

input_ids += input_tokens + output_tokens
target_mask += [0] * len(input_tokens) + [1] * len(output_tokens)
input_buffer = ""

assert len(input_ids) == len(target_mask)

Expand Down
Loading

0 comments on commit cd1eb8f

Please sign in to comment.