Auto mode - rawdog local models; add tool support for llama3.1, 3.2, qwen, functionary 3.1 and 3.2
dnakov committed Oct 28, 2024
1 parent c04dde6 commit c451a1c
Showing 3 changed files with 893 additions and 43 deletions.
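Note: "rawdog" here refers to local models that emit a tool call as bare JSON in the assistant message rather than as an OpenAI-style tool_calls array; the new process_response below recovers those. A minimal standalone sketch of that recovery, with illustrative names that are not part of the commit:

```python
import json
import uuid

def extract_rawdog_tool_call(content, known_tools):
    """Interpret raw model output as a tool call; return None if it is not one."""
    try:
        call = json.loads(content)
    except (TypeError, json.JSONDecodeError):
        return None  # plain prose, not a tool call
    if not isinstance(call, dict) or call.get("name") not in known_tools:
        return None
    # some models label the payload "arguments", others "parameters"
    args = call.get("arguments") or call.get("parameters")
    if args is None:
        return None
    return {
        "id": str(uuid.uuid4()),
        "function": {"name": call["name"], "arguments": json.dumps(args)},
    }

print(extract_rawdog_tool_call('{"name": "r2cmd", "parameters": {"command": "afl"}}', {"r2cmd"}))
```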
138 changes: 96 additions & 42 deletions r2ai/auto.py
@@ -2,22 +2,17 @@
import json
import sys
import re
import os
from llama_cpp import Llama
from llama_cpp.llama_tokenizer import LlamaHFTokenizer
from transformers import AutoTokenizer
from . import index
from .pipe import have_rlang, r2lang, get_r2_inst
from . import LOGGER
from litellm import _should_retry, acompletion, utils, ModelResponse
import asyncio
from r2ai.pipe import get_r2_inst
from .tools import r2cmd, run_python
import json
import signal
from .spinner import spinner
+from .completion import create_chat_completion
+import uuid

+ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')

SYSTEM_PROMPT_AUTO = """
You are a reverse engineer and you are using radare2 to analyze a binary.
The user will ask questions about the binary and you will respond with the answer to the best of your ability.
@@ -41,12 +36,15 @@
"""

class ChatAuto:
-   def __init__(self, model, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ):
+   def __init__(self, model, interpreter, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ):
        self.logger = LOGGER
        self.functions = {}
        self.tools = []
        self.model = model
        self.system = system
        self.messages = messages
+       self.interpreter = interpreter
+       self.system_message = None
        if messages and messages[0]['role'] != 'system' and system:
            self.messages.insert(0, { "role": "system", "content": system })
        if cb:
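The constructor now receives the interpreter (used later for the llm.maxtokens limit and the local llama instance) and only seeds the system prompt when the history does not already start with one. That seeding rule in isolation; the helper name is illustrative, not part of the commit:

```python
def seed_system_message(messages, system):
    """Prepend the system prompt unless the history already starts with one."""
    if messages and messages[0]['role'] != 'system' and system:
        messages.insert(0, {"role": "system", "content": system})
    return messages

msgs = seed_system_message([{"role": "user", "content": "decompile main"}], "You are a reverse engineer.")
assert msgs[0]["role"] == "system"
```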
@@ -67,29 +65,40 @@ def __init__(self, model, system=None, tools=None, messages=None, tool_choice='a
    async def process_tool_calls(self, tool_calls):
        if tool_calls:
            for tool_call in tool_calls:
+               self.logger.debug(f"tool_call: {tool_call}")
                tool_name = tool_call["function"]["name"]
+               if "id" not in tool_call:
+                   tool_call["id"] = str(uuid.uuid4())
                try:
                    tool_args = json.loads(tool_call["function"]["arguments"])
                except Exception:
-                   self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Unable to parse JSON" , "tool_call_id": tool_call["id"]})
-                   continue
+                   if "arguments" in tool_call["function"] and type(tool_call["function"]["arguments"]) == dict:
+                       tool_args = tool_call["function"]["arguments"]
+                   else:
+                       self.logger.error(f'Error parsing JSON: {tool_call["function"]["arguments"]}')
+                       # raise Exception('Error parsing JSON')
+                       self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Unable to parse JSON" , "tool_call_id": tool_call["id"]})
+                       continue
                if tool_name not in self.functions:
                    self.logger.error(f'Tool not found: {tool_name}')
                    self.messages.append({"role": "tool", "name": tool_name, "content": "Error: Tool not found" , "tool_call_id": tool_call["id"]})
                    continue

                self.cb('tool_call', { "id": tool_call["id"], "function": { "name": tool_name, "arguments": tool_args } })
                if asyncio.iscoroutinefunction(self.functions[tool_name]):
                    tool_response = await self.functions[tool_name](**tool_args)
                else:
                    tool_response = self.functions[tool_name](**tool_args)
                self.cb('tool_response', { "id": tool_call["id"] + '_response', "content": tool_response })
                self.messages.append({"role": "tool", "name": tool_name, "content": ANSI_REGEX.sub('', tool_response), "tool_call_id": tool_call["id"]})

        return await self.get_completion()

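process_tool_calls above now accepts arguments either as a JSON string (the OpenAI wire format) or as an already-decoded dict, which some local backends produce. A reduced sketch of that normalization, assuming only those two shapes occur:

```python
import json

def normalize_tool_args(raw):
    # OpenAI-style: arguments arrive as a JSON string
    if isinstance(raw, str):
        return json.loads(raw)  # may raise; the caller reports a parse error
    # some local backends hand over an already-decoded dict
    if isinstance(raw, dict):
        return raw
    raise TypeError(f"unsupported arguments type: {type(raw).__name__}")

assert normalize_tool_args('{"command": "pdf"}') == {"command": "pdf"}
assert normalize_tool_args({"command": "pdf"}) == {"command": "pdf"}
```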
    async def process_streaming_response(self, resp):
-       tool_calls = []
-       msgs = []
+       parts = []
+       current_message = { "role": "assistant", "content": "", "tool_calls": [] }
        async for chunk in resp:
            delta = None
            choice = chunk.choices[0]
@@ -99,47 +108,89 @@ async def process_streaming_response(self, resp):
                index = delta_tool_calls.index
                fn_delta = delta_tool_calls.function
                tool_call_id = delta_tool_calls.id
-               if len(tool_calls) < index + 1:
-                   tool_calls.append({
-                       "id": tool_call_id,
+               if len(current_message['tool_calls']) < index + 1:
+                   tool_call = {
+                       "id": tool_call_id or str(uuid.uuid4()),
                        "type": "function",
                        "function": {
                            "name":fn_delta.name,
                            "arguments": fn_delta.arguments
                        }
                    }
-                   )
+                   current_message['tool_calls'].append(tool_call)
                else:
-                   tool_calls[index]["function"]["arguments"] += fn_delta.arguments
+                   if fn_delta.name:
+                       current_message['tool_calls'][index]["function"]["name"] = fn_delta.name
+                   current_message['tool_calls'][index]["function"]["arguments"] += fn_delta.arguments
            else:
                m = None
                done = False
                if delta.content is not None:
                    m = delta.content
                if m is not None:
-                   msgs.append(m)
+                   current_message['content'] += m
                    self.cb('message', { "content": m, "id": 'message_' + chunk.id, 'done': False })
                if 'finish_reason' in choice and choice['finish_reason'] == 'stop':
                    done = True
                    self.cb('message', { "content": "", "id": 'message_' + chunk.id, 'done': True })
                self.cb('message_stream', { "content": m if m else '', "id": 'message_' + chunk.id, 'done': done })
-       if (len(tool_calls) > 0):
-           self.messages.append({"role": "assistant", "tool_calls": tool_calls})
+       self.messages.append(current_message)
+       if len(current_message['tool_calls']) > 0:
+           await self.process_tool_calls(current_message['tool_calls'])
+       if len(current_message['content']) > 0:
+           return current_message

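Streaming providers deliver tool calls as incremental deltas keyed by index: the first chunk for a call carries the id and name, later chunks append argument fragments, and local models may omit the id entirely (hence the uuid fallback above). The accumulation logic in isolation, with tuples standing in for the provider's delta objects:

```python
import uuid

def accumulate_tool_call_deltas(deltas):
    """deltas: iterable of (index, id_or_None, name_or_None, argument_fragment)."""
    calls = []
    for index, call_id, name, fragment in deltas:
        if len(calls) < index + 1:  # first chunk seen for this tool call
            calls.append({
                "id": call_id or str(uuid.uuid4()),
                "type": "function",
                "function": {"name": name, "arguments": fragment or ""},
            })
        else:  # later chunks extend the partial arguments
            if name:
                calls[index]["function"]["name"] = name
            calls[index]["function"]["arguments"] += fragment or ""
    return calls

chunks = [(0, "call_1", "r2cmd", '{"comm'), (0, None, None, 'and": "afl"}')]
print(accumulate_tool_call_deltas(chunks))
# [{'id': 'call_1', 'type': 'function', 'function': {'name': 'r2cmd', 'arguments': '{"command": "afl"}'}}]
```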
+   async def process_response(self, resp):
+       content = resp.choices[0].message.content
+       tool_calls = []
+       for tool_call in resp.choices[0].message.tool_calls or []:
+           tool_calls.append({
+               "id": tool_call.id,
+               "function": {
+                   "name": tool_call.function.name,
+                   "arguments": tool_call.function.arguments
+               },
+           })
+       if len(tool_calls) == 0:
+           try:
+               tool_call = json.loads(content)
+               if 'name' in tool_call and tool_call['name'] in self.functions:
+                   args = None
+                   if 'arguments' in tool_call:
+                       args = tool_call['arguments']
+                   elif 'parameters' in tool_call:
+                       args = tool_call['parameters']
+                   if args:
+                       tool_calls.append({ "id": resp.id, "function": { "name": tool_call['name'], "arguments": json.dumps(args) } })
+           except Exception:
+               pass
+       if len(tool_calls) > 0:
+           self.messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""})
+           await self.process_tool_calls(tool_calls)
-       if len(msgs) > 0:
-           response_message = ''.join(msgs)
-           self.messages.append({"role": "assistant", "content": response_message})
-           return response_message
+       if content is None or len(content) > 0:
+           self.messages.append({"role": "assistant", "content": content or ""})

+       return content

+   async def async_response_generator(self, response):
+       for item in response:
+           self.logger.debug(item)
+           resp = ModelResponse(stream=True, **item)
+           yield resp

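async_response_generator bridges llama.cpp's synchronous chunk generator into the async iterator the rest of the pipeline expects. The same bridging pattern in a self-contained form; note that the loop body never awaits, so a slow synchronous generator can still block the event loop:

```python
import asyncio

def sync_chunks():
    yield from ("one", "two", "three")  # stand-in for a local model's chunk stream

async def as_async(gen):
    for item in gen:
        yield item  # each item is re-emitted as part of an async iterator

async def main():
    async for item in as_async(sync_chunks()):
        print(item)

asyncio.run(main())
```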
-   async def attempt_completion(self):
+   async def attempt_completion(self, stream=True):
        args = {
            "temperature": 0,
            "tools": self.tools,
            "tool_choice": self.tool_choice,
-           "stream": True
+           "max_tokens": int(self.interpreter.env["llm.maxtokens"]),
+           "stream": stream,
        }
-       if self.llama_instance:
-           return self.llama_instance.create_chat_completion(self.messages, **args)

+       if self.llama_instance:
+           res = create_chat_completion(self.interpreter, messages=self.messages, tools=[self.tools[0]], **args)
+           if args['stream']:
+               return self.async_response_generator(res)
+           else:
+               return ModelResponse(**next(res))

        return await acompletion(
            model=self.model,
@@ -150,10 +201,11 @@ async def attempt_completion(self):
    async def get_completion(self):
        if self.llama_instance:
            response = await self.attempt_completion()
-           async def async_generator(response):
-               for item in response:
-                   yield ModelResponse(stream=True, **item)
-           return await self.process_streaming_response(async_generator(response))
+           stream = True
+           if stream:
+               return await self.process_streaming_response(response)
+           else:
+               return await self.process_response(response)
        max_retries = 5
        base_delay = 2

@@ -162,12 +214,12 @@ async def async_generator(response):
                response = await self.attempt_completion()
                return await self.process_streaming_response(response)
            except Exception as e:
-               print(e)
+               self.logger.error(f'Error getting completion: {e}')
                if not _should_retry(getattr(e, 'status_code', None)) or retry_count == max_retries - 1:
                    raise

                delay = base_delay * (2 ** retry_count)
-               print(f"Retrying in {delay} seconds...")
+               self.logger.info(f"Retrying in {delay} seconds...")
                await asyncio.sleep(delay)

        raise Exception("Max retries reached. Unable to get completion.")
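The retry loop above backs off exponentially: with base_delay = 2 and max_retries = 5, failed attempts wait 2, 4, 8, and 16 seconds before the final try. The skeleton on its own, assuming a _should_retry-style predicate that inspects the provider's status code:

```python
import asyncio

async def with_backoff(attempt, should_retry, max_retries=5, base_delay=2):
    """Run `attempt` until it succeeds, retrying transient failures with backoff."""
    for retry_count in range(max_retries):
        try:
            return await attempt()
        except Exception as e:
            # give up on non-retryable errors, or once the retry budget is spent
            if not should_retry(getattr(e, "status_code", None)) or retry_count == max_retries - 1:
                raise
            await asyncio.sleep(base_delay * (2 ** retry_count))  # 2, 4, 8, 16...
```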
@@ -179,24 +231,26 @@ async def chat(self) -> str:
def cb(type, data):
    spinner.stop()
    if type == 'message_stream':
-       sys.stdout.write(data['content'])
+       if 'content' in data:
+           sys.stdout.write(data['content'])
    elif type == 'tool_call':
        if data['function']['name'] == 'r2cmd':
            builtins.print('\x1b[1;32m> \x1b[4m' + data['function']['arguments']['command'] + '\x1b[0m')
        elif data['function']['name'] == 'run_python':
            builtins.print('\x1b[1;32m> \x1b[4m' + "#!python" + '\x1b[0m')
            builtins.print(data['function']['arguments']['command'])
    elif type == 'tool_response':
-       sys.stdout.write(data['content'])
-       sys.stdout.flush()
+       if 'content' in data:
+           sys.stdout.write(data['content'])
+           sys.stdout.flush()
        # builtins.print(data['content'])
    elif type == 'message' and data['done']:
        builtins.print()

def signal_handler(signum, frame):
    raise KeyboardInterrupt

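signal_handler converts SIGINT into a KeyboardInterrupt so Ctrl-C can abort a pending completion; chat saves the original handler first and restores it afterwards. The same save/override/restore pattern in isolation, with a sleep standing in for the blocking completion call:

```python
import signal
import time

def signal_handler(signum, frame):
    raise KeyboardInterrupt

original_handler = signal.getsignal(signal.SIGINT)  # remember whatever was installed
signal.signal(signal.SIGINT, signal_handler)        # Ctrl-C now raises KeyboardInterrupt
try:
    time.sleep(30)  # stand-in for loop.run_until_complete(...)
except KeyboardInterrupt:
    print("interrupted")
finally:
    signal.signal(signal.SIGINT, original_handler)  # always restore the previous handler
```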
-def chat(interpreter, llama_instance=None):
+def chat(interpreter, **kwargs):
    model = interpreter.model.replace(":", "/")
    tools = [r2cmd, run_python]
    messages = interpreter.messages
@@ -211,7 +265,7 @@ def chat(interpreter, llama_instance=None):
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

-   chat_auto = ChatAuto(model, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=llama_instance, cb=cb)
+   chat_auto = ChatAuto(model, interpreter=interpreter, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=interpreter.llama_instance, cb=cb)

    original_handler = signal.getsignal(signal.SIGINT)
