diff --git a/r2ai/auto.py b/r2ai/auto.py index 8188a90..637bed2 100644 --- a/r2ai/auto.py +++ b/r2ai/auto.py @@ -2,22 +2,17 @@ import json import sys import re -import os -from llama_cpp import Llama -from llama_cpp.llama_tokenizer import LlamaHFTokenizer -from transformers import AutoTokenizer -from . import index -from .pipe import have_rlang, r2lang, get_r2_inst +from . import LOGGER from litellm import _should_retry, acompletion, utils, ModelResponse import asyncio -from r2ai.pipe import get_r2_inst from .tools import r2cmd, run_python import json import signal from .spinner import spinner +from .completion import create_chat_completion +import uuid ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])') - SYSTEM_PROMPT_AUTO = """ You are a reverse engineer and you are using radare2 to analyze a binary. The user will ask questions about the binary and you will respond with the answer to the best of your ability. @@ -41,12 +36,15 @@ """ class ChatAuto: - def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ): + def __init__(self, model, interpreter, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ): + self.logger = LOGGER self.functions = {} self.tools = [] self.model = model self.system = system self.messages = messages + self.interpreter = interpreter + self.system_message = None if messages and messages[0]['role'] != 'system' and system: self.messages.insert(0, { "role": "system", "content": system }) if cb: @@ -67,16 +65,25 @@ def __init__(self, model, system=None, tools=None, messages=None, tool_choice='a async def process_tool_calls(self, tool_calls): if tool_calls: for tool_call in tool_calls: + self.logger.debug(f"tool_call: {tool_call}") tool_name = tool_call["function"]["name"] + if "id" not in tool_call: + tool_call["id"] = str(uuid.uuid4()) try: tool_args = json.loads(tool_call["function"]["arguments"]) except Exception: - self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Unable to parse JSON" , "tool_call_id": tool_call["id"]}) - continue + if "arguments" in tool_call["function"] and type(tool_call["function"]["arguments"]) == dict: + tool_args = tool_call["function"]["arguments"] + else: + self.logger.error(f'Error parsing JSON: {tool_call["function"]["arguments"]}') + # raise Exception('Error parsing JSON') + self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Unable to parse JSON" , "tool_call_id": tool_call["id"]}) + continue if tool_name not in self.functions: + self.logger.error(f'Tool not found: {tool_name}') self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Tool not found" , "tool_call_id": tool_call["id"]}) continue - + self.cb('tool_call', { "id": tool_call["id"], "function": { "name": tool_name, "arguments": tool_args } }) if asyncio.iscoroutinefunction(self.functions[tool_name]): tool_response = await self.functions[tool_name](**tool_args) @@ -84,12 +91,14 @@ async def process_tool_calls(self, tool_calls): tool_response = self.functions[tool_name](**tool_args) self.cb('tool_response', { "id": tool_call["id"] + '_response', "content": tool_response }) self.messages.append({"role": "tool", "name": tool_name, "content": ANSI_REGEX.sub('', tool_response), "tool_call_id": tool_call["id"]}) - + return await self.get_completion() async def process_streaming_response(self, resp): tool_calls = [] msgs = [] + parts = [] + current_message = { "role": "assistant", "content": "", "tool_calls": [] } async for chunk in resp: delta = None choice = chunk.choices[0] @@ -99,47 +108,89 @@ async def process_streaming_response(self, resp): index = delta_tool_calls.index fn_delta = delta_tool_calls.function tool_call_id = delta_tool_calls.id - if len(tool_calls) < index + 1: - tool_calls.append({ - "id": tool_call_id, + if len(current_message['tool_calls']) < index + 1: + tool_call = { + "id": tool_call_id or str(uuid.uuid4()), "type": "function", "function": { "name":fn_delta.name, "arguments": fn_delta.arguments } } - ) + current_message['tool_calls'].append(tool_call) else: - tool_calls[index]["function"]["arguments"] += fn_delta.arguments + if fn_delta.name: + current_message['tool_calls'][index]["function"]["name"] = fn_delta.name + current_message['tool_calls'][index]["function"]["arguments"] += fn_delta.arguments else: m = None done = False if delta.content is not None: m = delta.content if m is not None: - msgs.append(m) + current_message['content'] += m self.cb('message', { "content": m, "id": 'message_' + chunk.id, 'done': False }) if 'finish_reason' in choice and choice['finish_reason'] == 'stop': done = True self.cb('message', { "content": "", "id": 'message_' + chunk.id, 'done': True }) self.cb('message_stream', { "content": m if m else '', "id": 'message_' + chunk.id, 'done': done }) - if (len(tool_calls) > 0): - self.messages.append({"role": "assistant", "tool_calls": tool_calls}) + self.messages.append(current_message) + if len(current_message['tool_calls']) > 0: + await self.process_tool_calls(current_message['tool_calls']) + if len(current_message['content']) > 0: + return current_message + + async def process_response(self, resp): + content = resp.choices[0].message.content + tool_calls = [] + for tool_call in resp.choices[0].message.tool_calls or []: + tool_calls.append({ + "id": tool_call.id, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments + }, + }) + if len(tool_calls) == 0: + try: + tool_call = json.loads(content) + if 'name' in tool_call and tool_call['name'] in self.functions: + args = None + if 'arguments' in tool_call: + args = tool_call['arguments'] + elif 'parameters' in tool_call: + args = tool_call['parameters'] + if args: + tool_calls.append({ "id": resp.id, "function": { "name": tool_call['name'], "arguments": json.dumps(args) } }) + except Exception: + pass + if len(tool_calls) > 0: + self.messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""}) await self.process_tool_calls(tool_calls) - if len(msgs) > 0: - response_message = ''.join(msgs) - self.messages.append({"role": "assistant", "content": response_message}) - return response_message + if content is None or len(content) > 0: + self.messages.append({"role": "assistant", "content": content or ""}) + + return content + + async def async_response_generator(self, response): + for item in response: + self.logger.debug(item) + resp = ModelResponse(stream=True, **item) + yield resp - async def attempt_completion(self): + async def attempt_completion(self, stream=True): args = { "temperature": 0, - "tools": self.tools, - "tool_choice": self.tool_choice, - "stream": True + "max_tokens": int(self.interpreter.env["llm.maxtokens"]), + "stream": stream, } - if self.llama_instance: - return self.llama_instance.create_chat_completion(self.messages, **args) + + if self.llama_instance: + res = create_chat_completion(self.interpreter, messages=self.messages, tools=[self.tools[0]], **args) + if args['stream']: + return self.async_response_generator(res) + else: + return ModelResponse(**next(res)) return await acompletion( model=self.model, @@ -150,10 +201,11 @@ async def attempt_completion(self): async def get_completion(self): if self.llama_instance: response = await self.attempt_completion() - async def async_generator(response): - for item in response: - yield ModelResponse(stream=True, **item) - return await self.process_streaming_response(async_generator(response)) + stream = True + if stream: + return await self.process_streaming_response(response) + else: + return await self.process_response(response) max_retries = 5 base_delay = 2 @@ -162,12 +214,12 @@ async def async_generator(response): response = await self.attempt_completion() return await self.process_streaming_response(response) except Exception as e: - print(e) + self.logger.error(f'Error getting completion: {e}') if not _should_retry(getattr(e, 'status_code', None)) or retry_count == max_retries - 1: raise delay = base_delay * (2 ** retry_count) - print(f"Retrying in {delay} seconds...") + self.logger.info(f"Retrying in {delay} seconds...") await asyncio.sleep(delay) raise Exception("Max retries reached. Unable to get completion.") @@ -179,7 +231,8 @@ async def chat(self) -> str: def cb(type, data): spinner.stop() if type == 'message_stream': - sys.stdout.write(data['content']) + if 'content' in data: + sys.stdout.write(data['content']) elif type == 'tool_call': if data['function']['name'] == 'r2cmd': builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m') @@ -187,8 +240,9 @@ def cb(type, data): builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m') builtins.print(data['function']['arguments']['command']) elif type == 'tool_response': - sys.stdout.write(data['content']) - sys.stdout.flush() + if 'content' in data: + sys.stdout.write(data['content']) + sys.stdout.flush() # builtins.print(data['content']) elif type == 'message' and data['done']: builtins.print() @@ -196,7 +250,7 @@ def cb(type, data): def signal_handler(signum, frame): raise KeyboardInterrupt -def chat(interpreter, llama_instance=None): +def chat(interpreter, **kwargs): model = interpreter.model.replace(":", "/") tools = [r2cmd, run_python] messages = interpreter.messages @@ -211,7 +265,7 @@ def chat(interpreter, llama_instance=None): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=llama_instance, cb=cb) + chat_auto = ChatAuto(model, interpreter=interpreter, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=interpreter.llama_instance, cb=cb) original_handler = signal.getsignal(signal.SIGINT) diff --git a/r2ai/completion.py b/r2ai/completion.py index 598ff37..fe506b1 100644 --- a/r2ai/completion.py +++ b/r2ai/completion.py @@ -1,8 +1,17 @@ import sys import traceback from . import LOGGER +import json +from llama_cpp.llama_types import * +from llama_cpp.llama_grammar import LlamaGrammar +from llama_cpp.llama import StoppingCriteriaList, LogitsProcessorList +from typing import List, Iterator, Dict, Any, Optional, Union, Callable, Sequence, Generator +import uuid +import llama_cpp +import re +from .partial_json_parser import parse_incomplete_json -def messages_to_prompt(self, messages): +def messages_to_prompt(self, messages, tools=None): for message in messages: # Happens if it immediatly writes code if "role" not in message: @@ -36,8 +45,16 @@ def messages_to_prompt(self, messages): formatted_messages = template_alpaca(self, messages) elif "deepseek" in lowermodel: formatted_messages = template_alpaca(self, messages) + elif "llama-3.2" in lowermodel or "llama-3.1" in lowermodel: + formatted_messages = template_llama31(self, messages, tools) elif "llama-3" in lowermodel: formatted_messages = template_llama3(self, messages) + elif "functionary" in lowermodel and 'v3.1' in lowermodel: + formatted_messages = template_functionary_v31(self, messages, tools) + elif "functionary" in lowermodel and 'v3.2' in lowermodel: + formatted_messages = template_functionary_v32(self, messages, tools) + elif 'qwen' in lowermodel: + formatted_messages = template_qwen(self, messages, tools) elif "uncensor" in lowermodel: # formatted_messages = template_gpt4all(self, messages) # formatted_messages = template_alpaca(self, messages) @@ -63,6 +80,21 @@ def messages_to_prompt(self, messages): LOGGER.debug(formatted_messages) return formatted_messages +def response_to_message(self, response): + lowermodel = self.model.lower() + if "llama-3.2" in lowermodel or "llama-3.1" in lowermodel: + return response_llama31(self, response) + elif "functionary" in lowermodel and 'v3.1' in lowermodel: + return response_functionary_v31(self, response) + elif "functionary" in lowermodel and 'v3.2' in lowermodel: + return response_functionary_v32(self, response) + elif 'qwen' in lowermodel: + return response_qwen(self, response) + else: + print("This model has not been tested with auto mode yet. Defaulting to llama-3.1", file=sys.stderr) + return response_llama31(self, response) + + def template_granite(self,messages): self.terminator = ["Question:", "Answer:"] msg = "" @@ -84,6 +116,47 @@ def template_granite(self,messages): traceback.print_exc() return msg +def template_qwen(self, messages, tools): + system_prompt = self.system_message or "" + formatted_messages = "" + if messages[0]['role'] == 'system': + system_prompt += messages[0]['content'] + formatted_messages += f"<|im_start|>system\n{system_prompt}" + if tools: + formatted_messages += f"\n\n## Tools\n\nYou have access to the following tools:\n\n" + function_names = [] + for tool in tools: + fn = tool['function'] + function_names.append(fn['name']) + formatted_messages += f"### {fn['name']}\n\n{fn['name']}: {fn['description']} Parameters: {json.dumps(fn['parameters'])} Format the arguments as a JSON object.\n\n" + formatted_messages += f"""## When you need to call a tool, please insert the following command in your reply, which can be called zero or multiple times according to your needs: + +✿FUNCTION✿: The tool to use, should be one of [{", ".join(function_names)}] +✿ARGS✿: The input of the tool +✿RESULT✿: Tool results +✿RETURN✿: Reply based on tool results. Images need to be rendered as ![](url)""" + + for index, item in enumerate(messages): + role = item['role'] + if role == 'system': + continue + formatted_messages += f"<|im_start|>{role if role != 'tool' else 'user'}\n" + if role == 'assistant': + if 'tool_calls' in item: + for tool_call in item['tool_calls']: + args = tool_call['function']['arguments'] + if type(args) != str: + args = json.dumps(args) + formatted_messages += f"✿FUNCTION✿: {tool_call['function']['name']}\n✿ARGS✿: {args}\n" + formatted_messages += f"{item['content']}" + elif role == 'tool': + formatted_messages += f"\n✿RESULT✿: {item['content'] or 'NO RESULT'}" + else: + formatted_messages += f"{item['content']}" + formatted_messages += f"<|im_end|>" + formatted_messages += f"<|im_start|>assistant\n" + return formatted_messages + def template_gemma(self,messages): self.terminator = "" msg = "" @@ -459,6 +532,438 @@ def template_llama3(self,messages): formatted_messages += f"<|start_header_id|>assistant<|end_header_id|>" return formatted_messages +def template_llama31(self,messages, tools): + formatted_messages = "" # f"<|begin_of_text|>" + system_message = "" + if tools is not None: + system_message += """ + +Environment: ipython + +You are an expert in composing functions. You are given a question and a set of possible functions. +Based on the question, you will need to make one or more function/tool calls to achieve the purpose. + +If you decide to invoke any of the function(s), you MUST put it in the format of <|python_tag|>{ "name": "func_name", "parameters": {"param_name1": "param_value1", "param_name2": "param_value2"}} +You SHOULD NOT include any other text in the response. + +Here is a list of functions in JSON format that you can invoke: + +""" + system_message += json.dumps([tool["function"] for tool in tools]) + + if self.system_message != "" and self.system_message is not None: + system_message += self.system_message + user_message = None + if messages[0]['role'] == 'system': + system_message += messages[0]['content'] + if system_message != "": + formatted_messages += f"<|start_header_id|>system<|end_header_id|>" + formatted_messages += f"{system_message}" + formatted_messages += "<|eot_id|>" + + for index, item in enumerate(messages): + role = item['role'] + if role == 'tool': + role = 'ipython' + if role == 'system': + continue + formatted_messages += f"<|start_header_id|>{role}<|end_header_id|>" + if 'tool_calls' in item: + for tool_call in item['tool_calls']: + formatted_messages += "\n\n<|python_tag|>" + json.dumps({ "name": tool_call['function']['name'], "parameters": tool_call['function']['arguments'] }) + formatted_messages += "\n<|eom_id|>" + else: + content = item['content'].strip() + if role == 'ipython': + formatted_messages += "\n\n" + if content == "": + formatted_messages += 'NO RESULTS' + formatted_messages += content + formatted_messages += "<|eot_id|>" + formatted_messages += f"<|start_header_id|>assistant<|end_header_id|>" + return formatted_messages + +def delta_text(id, text): + return { "id": id, "choices": [{ "delta": { "content": text } }] } + +def delta_tool_call(id, tool_call_id, name, params): + return { "id": id, "choices": [{ "delta": { "tool_calls": [{ "function": {"name": name, "arguments": params}, "id": tool_call_id, "type": "function", "index": 0 }] } }] } + +def response_llama31(self, response): + full_text = "" + tool_call_text_index = -1 + tool_call_id = None + tool_call = None + message = None + id = str(uuid.uuid4()) + for text in response: + full_text += text + if text == "<|python_tag|>": + tool_call_text_index = len(full_text) + continue + elif tool_call_text_index == -1: + message = delta_text(id, text) + else: + function_call_json = full_text[tool_call_text_index:].strip() + + try: + function_call = parse_incomplete_json(function_call_json) + + if function_call is not None: + if 'name' in function_call and not tool_call_id: + tool_call_id = str(uuid.uuid4()) + message = delta_tool_call(id, tool_call_id, function_call["name"], None) + elif 'parameters' in function_call: + params = function_call["parameters"] + if type(params) == str: + params = params.replace('\\"', '"') + elif type(params) == dict: + params = json.dumps(params) + tool_call = delta_tool_call(id, tool_call_id, function_call["name"], params) + except Exception: + message = delta_text(id, text) + yield message + if tool_call is not None: + yield tool_call + yield { "id": id, "choices": [{ "finish_reason": "stop" }] } + +def response_qwen(self, response): + id = str(uuid.uuid4()) + full_text = "" + lines = [] + curr_line = "" + fn_call = None + for text in response: + full_text += text + + if text == "\n": + if curr_line.startswith("✿FUNCTION✿:"): + fn_call = { 'name': curr_line[11:].strip(), 'id': str(uuid.uuid4()), 'arguments': None } + elif curr_line.startswith("✿ARGS✿:"): + fn_call['arguments'] = curr_line[7:].strip().replace('\\"', '"') + yield delta_tool_call(id, fn_call['id'], fn_call['name'], fn_call['arguments']) + lines.append(curr_line) + curr_line = "" + else: + curr_line += text + if curr_line.startswith("✿"): + continue + yield delta_text(id, text) + + if curr_line.startswith("✿ARGS✿:") and fn_call is not None: + fn_call['arguments'] = curr_line[7:].strip().replace('\\"', '"') + yield delta_tool_call(id, fn_call['id'], fn_call['name'], fn_call['arguments']) + + yield { "id": id, "choices": [{ "finish_reason": "stop" }] } + +def parse_functionary31_calls(input_str): + pattern = re.compile( + r'[^>]+)>\s*(?P(\"|\').*?(\"|\'))\s*|[^>]+)>\s*(?P\{.*?\})\s*', + re.DOTALL + ) + + matches = pattern.finditer(input_str) + parsed_functions = [] + + for match in matches: + if match.group('func_name') and match.group('params'): + func_name = match.group('func_name').strip() + params_str = match.group('params').strip() + elif match.group('func_name2') and match.group('params2'): + func_name = match.group('func_name2').strip() + params_str = match.group('params2').strip() + else: + continue # Skip if no valid match + + if not func_name: + raise ValueError("Function name is missing in one of the function strings.") + + if not params_str: + raise ValueError(f"Parameters JSON is missing for function '{func_name}'.") + + # Handle parameters wrapped in quotes + if (params_str.startswith('"') and params_str.endswith('"')) or (params_str.startswith("'") and params_str.endswith("'")): + # Strip the surrounding quotes + params_str = params_str[1:-1] + # Unescape any escaped characters + params_str = bytes(params_str, "utf-8").decode("unicode_escape") + + # Parse the JSON parameters + try: + params = json.loads(params_str) + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON for parameters in function '{func_name}': {e}") + + parsed_functions.append((func_name, params)) + + if not parsed_functions: + return None + + return parsed_functions + +def serialize_functionary31_call(name, params): + return f"{json.dumps(params)}" + +def response_functionary_v31(self, response): + full_text = "" + id = str(uuid.uuid4()) + in_function = False + for text in response: + full_text += text + if '>>' + # It captures the tool name or recipient after '>>>' + parts = re.split(r'>>>(\w+)', input_str) + + content = "" + tool_calls = [] + + # The first element in 'parts' is the content before the first '>>>' + if parts[0]: + content += parts[0].strip() + + # Iterate over the split parts + # 'parts' alternates between tool names/recipients and their corresponding data + for i in range(1, len(parts), 2): + name = parts[i].strip() + if i + 1 < len(parts): + data = parts[i + 1].strip() + if name.lower() == 'all': + # If the name is 'all', treat the following data as additional content + if content: + content += " " + data + else: + content = data + else: + # Assume it's a tool call with JSON parameters + try: + parameters = json.loads(data) + tool_calls.append({ + 'name': name, + 'parameters': parameters + }) + except json.JSONDecodeError: + pass + + return { + 'content': content, + 'tool_calls': tool_calls if len(tool_calls) > 0 else None + } + +def serialize_functionary32_calls(structure): + content = structure.get('content', None) + tool_calls = structure.get('tool_calls', []) + + parts = [] + + if content: + content = content.strip() + parts.append(f'>>>all\n{content}') + + for tool in tool_calls: + name = tool.get('name', '').strip() + parameters = tool.get('arguments', "") + + # Serialize parameters to JSON string + try: + if type(parameters) == str: + parameters_str = parameters + else: + parameters_str = json.dumps(parameters, ensure_ascii=False) + except (TypeError, ValueError) as e: + raise ValueError(f"Invalid parameters for tool '{name}': {e}") + + # Append the tool call with '>>>tool_name' followed by a newline and the parameters + parts.append(f'>>>{name}\n{parameters_str}') + + # Join all parts without any additional separators + serialized_str = ''.join(parts) + + return serialized_str + + +def template_functionary_v32(self, messages, tools): + formatted_messages = "" # f"<|begin_of_text|>" + system_message = "" + tool_system_message = "" + if tools is not None: + tools_str = "" + for tool in tools: + params_str = ", ".join([f"// {param.get('description', param_name)}\n{param_name}: {param.get('type', '')}" for param_name, param in tool['function']['parameters']['properties'].items() if isinstance(param, dict)]) + tools_str += f"// {tool['function']['description']}\ntype {tool['function']['name']} = (_: {{\n{params_str}\n}}) => any;\n\n" + + tool_system_message += f""" + +You are capable of executing available function(s) if required. +Only execute function(s) when absolutely necessary. +Ask for the required input to:recipient==all +Use JSON for function arguments. +Respond in this format: +>>>${{recipient}} +${{content}} +Available functions: +// Supported function definitions that should be called when necessary. +namespace functions {{ + +{tools_str} +}} // namespace functions""" + + if self.system_message != "" and self.system_message is not None: + system_message += self.system_message + user_message = None + if messages[0]['role'] == 'system': + system_message += messages[0]['content'] + if system_message != "": + formatted_messages += f"<|start_header_id|>system<|end_header_id|>" + formatted_messages += f"{system_message}" + formatted_messages += "<|eot_id|>" + if tool_system_message != "": + formatted_messages = f"<|start_header_id|>system<|end_header_id|>{tool_system_message}<|eot_id|>" + formatted_messages + last_tool_call = None + prev_tool_call = None + for index, item in enumerate(messages): + role = item['role'] + if role == 'system': + continue + formatted_messages += f"<|start_header_id|>{role}<|end_header_id|>" + if 'tool_calls' in item: + for tool_call in item['tool_calls']: + formatted_messages += '\n\n' + serialize_functionary32_calls({ "content": item['content'], "tool_calls": [tool_call['function']] }) + formatted_messages += "<|eot_id|>" + prev_tool_call = last_tool_call + last_tool_call = tool_call['function'] + + else: + content = item['content'].strip() + if role == 'tool': + if content != "": + formatted_messages += "\n\n" + content + else: + formatted_messages += "\n\nNO RESULTS" + else: + formatted_messages += content + formatted_messages += "<|eot_id|>" + formatted_messages += f"<|start_header_id|>assistant<|end_header_id|>" + return formatted_messages + +def response_functionary_v32(self, response): + full_text = "" + id = str(uuid.uuid4()) + tool_calls = None + in_function_call = False + for text in response: + message = None + full_text += text + if not in_function_call and re.search(r'>>>(?!all)', full_text): + in_function_call = True + if not in_function_call: + message = delta_text(id, text) + if in_function_call: + tool_calls = parse_functionary32_calls(full_text)['tool_calls'] + if message is not None: + yield message + + if tool_calls is not None: + for tool_call in tool_calls: + tool_call_id = str(uuid.uuid4()) + yield delta_tool_call(id, tool_call_id, tool_call['name'], tool_call['parameters']) + + yield { "id": id, "choices": [{ "finish_reason": "stop" }] } + +def template_functionary_v31(self, messages, tools): + formatted_messages = "" # f"<|begin_of_text|>" + system_message = "" + if tools is not None: + tools_str = "" + for tool in tools: + tools_str += f"Use the function '{tool['function']['name']}' to {tool['function']['description']}\n" + tools_str += json.dumps(tool['function']) + "\n\n" + system_message += f""" + +Cutting Knowledge Date: December 2023 + + +You have access to the following functions: + +{tools_str} +Think very carefully before calling functions. +If a you choose to call a function ONLY reply in the following format: +<{{start_tag}}={{function_name}}>{{parameters}}{{end_tag}} +where + +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. +end_tag => `` + +Here is an example, +{serialize_functionary31_call("example_function_name", {"example_name": "example_value"})} + +""" + + if self.system_message != "" and self.system_message is not None: + system_message += self.system_message + user_message = None + if messages[0]['role'] == 'system': + system_message += messages[0]['content'] + if system_message != "": + formatted_messages += f"<|start_header_id|>system<|end_header_id|>" + formatted_messages += f"{system_message}" + formatted_messages += """ + +Reminder: +- Function calls MUST follow the specified format, start with +- Required parameters MUST be specified +- Only call one function at a time +- Put the entire function call reply on one line +- Respond with something after the function call is complete""" + formatted_messages += "<|eot_id|>" + for index, item in enumerate(messages): + role = item['role'] + if role == 'tool': + role = 'ipython' + if role == 'system': + continue + formatted_messages += f"<|start_header_id|>{role}<|end_header_id|>" + if 'tool_calls' in item: + for tool_call in item['tool_calls']: + formatted_messages += "\n\n" + serialize_functionary31_call(tool_call['function']['name'], tool_call['function']['arguments']) + formatted_messages += "<|eom_id|>" + content = item['content'].strip() + if role == 'ipython': + formatted_messages += "\n\n" + if content == "": + formatted_messages += "NO RESULTS" + formatted_messages += content + formatted_messages += "<|eot_id|>" + formatted_messages += f"<|start_header_id|>assistant<|end_header_id|>" + return formatted_messages + def template_llama(self,messages): formatted_messages = f"[INST]" if self.system_message != "": @@ -481,3 +986,71 @@ def template_llama(self,messages): elif role == 'assistant' and self.env["chat.reply"] == "true": formatted_messages += f"{content}[INST]" return formatted_messages + +def create_chat_completion(self, **kwargs): + messages = kwargs.pop('messages') + tools = kwargs.pop('tools') + prompt = messages_to_prompt(self, messages, tools) + return response_to_message(self, create_completion(self.llama_instance, prompt=prompt, **kwargs)) + + +def create_completion( + self, + prompt: Union[str, List[int]], + suffix: Optional[str] = None, + max_tokens: Optional[int] = 16, + temperature: float = 0.8, + top_p: float = 0.95, + min_p: float = 0.05, + typical_p: float = 1.0, + logprobs: Optional[int] = None, + echo: bool = False, + stop: Optional[Union[str, List[str]]] = [], + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repeat_penalty: float = 1.0, + top_k: int = 40, + stream: bool = False, + seed: Optional[int] = None, + tfs_z: float = 1.0, + mirostat_mode: int = 0, + mirostat_tau: float = 5.0, + mirostat_eta: float = 0.1, + model: Optional[str] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_processor: Optional[LogitsProcessorList] = None, + grammar: Optional[LlamaGrammar] = None, + logit_bias: Optional[Dict[str, float]] = None, + ) -> Union[ + Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] + ]: + + prompt_tokens = self.tokenize( + prompt.encode("utf-8"), + add_bos=False, + special=True + ) + + for token in self.generate( + prompt_tokens, + top_k=top_k, + top_p=top_p, + min_p=min_p, + typical_p=typical_p, + temp=temperature, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + repeat_penalty=repeat_penalty, + stopping_criteria=stopping_criteria, + logits_processor=logits_processor, + grammar=grammar, + ): + + if llama_cpp.llama_token_is_eog(self._model.model, token): + break + text = self.detokenize([token], special=True).decode("utf-8") + yield text \ No newline at end of file diff --git a/r2ai/partial_json_parser.py b/r2ai/partial_json_parser.py new file mode 100644 index 0000000..658c455 --- /dev/null +++ b/r2ai/partial_json_parser.py @@ -0,0 +1,223 @@ +class IncompleteJSONParser: + def __init__(self, s): + self.s = s + self.len = len(s) + self.pos = 0 + + def parse(self): + self.skip_whitespace() + if self.pos >= self.len: + return None + char = self.s[self.pos] + if char == '{': + return self.parse_object() + elif char == '[': + return self.parse_array() + else: + return None # Top-level must be object or array + + def parse_object(self): + if self.s[self.pos] != '{': + return None + self.pos += 1 # Skip '{' + self.skip_whitespace() + obj = {} + parsed_any = False # Flag to check if any key-value pair was parsed + + while self.pos < self.len: + self.skip_whitespace() + if self.pos >= self.len: + break + if self.s[self.pos] == '}': + self.pos += 1 + return obj if parsed_any else None + key = self.parse_string() + if key is None: + # Incomplete key, skip the rest + break + self.skip_whitespace() + if self.pos >= self.len or self.s[self.pos] != ':': + # Missing colon, skip this key + break + self.pos += 1 # Skip ':' + self.skip_whitespace() + value = self.parse_value() + if value is None: + # Incomplete value, skip this key-value pair + break + obj[key] = value + parsed_any = True + self.skip_whitespace() + if self.pos >= self.len: + break + if self.s[self.pos] == ',': + self.pos += 1 + continue + elif self.s[self.pos] == '}': + self.pos += 1 + return obj if parsed_any else None + else: + # Unexpected character, skip + break + # Auto-close the object + return obj if parsed_any else None + + def parse_array(self): + if self.s[self.pos] != '[': + return None + self.pos += 1 # Skip '[' + self.skip_whitespace() + array = [] + parsed_any = False # Flag to check if any element was parsed + + while self.pos < self.len: + self.skip_whitespace() + if self.pos >= self.len: + break + if self.s[self.pos] == ']': + self.pos += 1 + return array if parsed_any else None + value = self.parse_value() + if value is None: + # Incomplete value, skip this element + break + array.append(value) + parsed_any = True + self.skip_whitespace() + if self.pos >= self.len: + break + if self.s[self.pos] == ',': + self.pos += 1 + continue + elif self.s[self.pos] == ']': + self.pos += 1 + return array if parsed_any else None + else: + # Unexpected character, skip + break + # Auto-close the array + return array if parsed_any else None + + def parse_value(self): + self.skip_whitespace() + if self.pos >= self.len: + return None + char = self.s[self.pos] + if char == '"': + return self.parse_string() + elif char == '{': + return self.parse_object() + elif char == '[': + return self.parse_array() + elif char in '-0123456789': + return self.parse_number() + elif self.s.startswith('true', self.pos): + self.pos += 4 + return True + elif self.s.startswith('false', self.pos): + self.pos += 5 + return False + elif self.s.startswith('null', self.pos): + self.pos += 4 + return None + else: + return None # Invalid value + + def parse_string(self): + if self.s[self.pos] != '"': + return None + self.pos += 1 # Skip opening quote + result = "" + while self.pos < self.len: + char = self.s[self.pos] + if char == '\\': + if self.pos + 1 >= self.len: + result += '\\' + self.pos += 1 + break # Incomplete escape, return what we have + self.pos += 1 + escape_char = self.s[self.pos] + if escape_char == '"': + result += '"' + elif escape_char == '\\': + result += '\\' + elif escape_char == '/': + result += '/' + elif escape_char == 'b': + result += '\b' + elif escape_char == 'f': + result += '\f' + elif escape_char == 'n': + result += '\n' + elif escape_char == 'r': + result += '\r' + elif escape_char == 't': + result += '\t' + elif escape_char == 'u': + # Unicode escape + if self.pos + 4 >= self.len: + # Incomplete unicode escape + break + hex_digits = self.s[self.pos+1:self.pos+5] + try: + code_point = int(hex_digits, 16) + result += chr(code_point) + self.pos += 4 + except ValueError: + break # Invalid unicode escape + else: + # Invalid escape character + result += '\\' + escape_char # Keep it as is + elif char == '"': + self.pos += 1 # Skip closing quote + return result + else: + result += char + self.pos += 1 + # Return the partial string if incomplete + return result + + def parse_number(self): + start = self.pos + if self.s[self.pos] == '-': + self.pos += 1 + if self.pos >= self.len: + return None + if self.s[self.pos] == '0': + self.pos += 1 + elif '1' <= self.s[self.pos] <= '9': + while self.pos < self.len and self.s[self.pos].isdigit(): + self.pos += 1 + else: + return None # Invalid number + if self.pos < self.len and self.s[self.pos] == '.': + self.pos += 1 + if self.pos >= self.len or not self.s[self.pos].isdigit(): + return None # Incomplete fraction + while self.pos < self.len and self.s[self.pos].isdigit(): + self.pos += 1 + if self.pos < self.len and self.s[self.pos] in 'eE': + self.pos += 1 + if self.pos < self.len and self.s[self.pos] in '+-': + self.pos += 1 + if self.pos >= self.len or not self.s[self.pos].isdigit(): + return None # Incomplete exponent + while self.pos < self.len and self.s[self.pos].isdigit(): + self.pos += 1 + num_str = self.s[start:self.pos] + try: + if '.' in num_str or 'e' in num_str or 'E' in num_str: + return float(num_str) + else: + return int(num_str) + except ValueError: + return None # Invalid number + + def skip_whitespace(self): + while self.pos < self.len and self.s[self.pos] in ' \t\n\r': + self.pos += 1 + +def parse_incomplete_json(s): + parser = IncompleteJSONParser(s) + result = parser.parse() + return result \ No newline at end of file