
Commit 91ac2fb
Auto updates to handle more params and non-streaming mode
dnakov authored and trufae committed Oct 28, 2024
1 parent 69089ae commit 91ac2fb
Showing 1 changed file with 42 additions and 25 deletions.
r2ai/auto.py: 67 changes (42 additions, 25 deletions)
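
In short: ChatAuto now takes its sampling and transport parameters (max_tokens, top_p, temperature, timeout) directly in the constructor instead of reading them from the interpreter, and completions can run in non-streaming mode through a synchronous chat() wrapper. A minimal sketch of the new call surface, assuming a litellm-compatible model string and that the r2ai tools are passed explicitly; everything not visible in the diff below is a hypothetical stand-in:

    from r2ai.auto import ChatAuto, SYSTEM_PROMPT_AUTO
    from r2ai.tools import r2cmd, run_python

    chat_auto = ChatAuto(
        "openai/gpt-4o",          # any model string litellm accepts (assumed)
        max_tokens=1024,
        top_p=0.95,
        temperature=0.0,
        system=SYSTEM_PROMPT_AUTO,
        tools=[r2cmd, run_python],
        messages=[{"role": "user", "content": "analyze main"}],
    )
    result = chat_auto.chat(stream=False)   # sync wrapper around achat()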
@@ -3,6 +3,7 @@
 import sys
 import re
 from . import LOGGER
+import litellm
 from litellm import _should_retry, acompletion, utils, ModelResponse
 import asyncio
 from .tools import r2cmd, run_python
@@ -12,6 +13,8 @@
 from .completion import create_chat_completion
 import uuid
 
+litellm.drop_params = True
+
 ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
 SYSTEM_PROMPT_AUTO = """
 You are a reverse engineer and you are using radare2 to analyze a binary.
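
The new litellm.drop_params switch above is litellm's global setting for silently dropping request parameters a given provider does not accept, instead of raising an error; it matters here because every completion now carries top_p and temperature. A small sketch of the effect (the model name is hypothetical):

    import litellm

    litellm.drop_params = True  # strip unsupported params instead of raising

    # If this provider rejected e.g. top_p, litellm would now drop the key
    # and send the request anyway.
    resp = litellm.completion(
        model="ollama/llama3.1",  # hypothetical local model
        messages=[{"role": "user", "content": "hi"}],
        top_p=0.95,
        temperature=0.0,
    )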
@@ -36,15 +39,19 @@
 """
 
 class ChatAuto:
-    def __init__(self, model, interpreter, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ):
+    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, cb=None ):
         self.logger = LOGGER
         self.functions = {}
         self.tools = []
         self.model = model
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.temperature = temperature
         self.system = system
         self.messages = messages
         self.interpreter = interpreter
         self.system_message = None
+        self.timeout = timeout
         if messages and messages[0]['role'] != 'system' and system:
             self.messages.insert(0, { "role": "system", "content": system })
         if cb:
@@ -58,7 +65,7 @@ def __init__(self, model, interpreter, system=None, tools=None, messages=None, t
             self.tools.append({ "type": "function", "function": f })
             self.functions[f['name']] = tool
         self.tool_choice = tool_choice
-        self.llama_instance = llama_instance
+        self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None
 
         #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
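
One parsing subtlety in the fallback above: Python's conditional expression binds more loosely than or, so the right-hand side reads as (llama_instance or interpreter.llama_instance) if interpreter else None, meaning an explicitly passed llama_instance is dropped whenever interpreter is None. A two-line demonstration:

    llama_instance, interpreter = object(), None
    # parses as: (llama_instance or interpreter.llama_instance) if interpreter else None
    result = llama_instance or interpreter.llama_instance if interpreter else None
    assert result is None  # the explicit llama_instance is discarded

In practice the r2ai entry point passes both, so the intended fallback still holds there.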

@@ -137,21 +144,23 @@ async def process_streaming_response(self, resp):
         self.messages.append(current_message)
         if len(current_message['tool_calls']) > 0:
             await self.process_tool_calls(current_message['tool_calls'])
-        if len(current_message['content']) > 0:
-            return current_message
+        return current_message
 
     async def process_response(self, resp):
         content = resp.choices[0].message.content
-        tool_calls = []
+        current_message = { 'role': 'assistant', 'content': content or '', 'tool_calls': [] }
         for tool_call in resp.choices[0].message.tool_calls or []:
-            tool_calls.append({
+            current_message['tool_calls'].append({
                 "id": tool_call.id,
                 "type": "function",
+                "index": tool_call.index,
                 "function": {
                     "name": tool_call.function.name,
                     "arguments": tool_call.function.arguments
                 },
             })
-        if len(tool_calls) == 0:
+        if len(current_message['tool_calls']) == 0:
             try:
                 tool_call = json.loads(content)
                 if 'name' in tool_call and tool_call['name'] in self.functions:
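
The branch that begins here handles models that return a tool call as raw JSON text instead of through the structured tool_calls field. A minimal illustration of the shapes it accepts, assuming a registered function named r2cmd that takes a command argument:

    import json

    # Either key is recognized for the call's arguments.
    for content in (
        '{"name": "r2cmd", "arguments": {"command": "afl"}}',
        '{"name": "r2cmd", "parameters": {"command": "afl"}}',
    ):
        tool_call = json.loads(content)
        args = tool_call.get('arguments') or tool_call.get('parameters')
        print(tool_call['name'], json.dumps(args))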
@@ -161,27 +170,27 @@ async def process_response(self, resp):
                     elif 'parameters' in tool_call:
                         args = tool_call['parameters']
                     if args:
-                        tool_calls.append({ "id": resp.id, "function": { "name": tool_call['name'], "arguments": json.dumps(args) } })
+                        current_message['tool_calls'].append({ "id": resp.id, "function": { "name": tool_call['name'], "arguments": json.dumps(args) } })
             except Exception:
                 pass
-        if len(tool_calls) > 0:
-            self.messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""})
-            await self.process_tool_calls(tool_calls)
-        if content is None or len(content) > 0:
-            self.messages.append({"role": "assistant", "content": content or ""})
+        if len(current_message['tool_calls']) > 0:
+            self.messages.append(current_message)
+            await self.process_tool_calls(current_message['tool_calls'])
 
-        return content
+        self.messages.append(current_message)
+
+        return current_message
 
     async def async_response_generator(self, response):
         for item in response:
             self.logger.debug(item)
             resp = ModelResponse(stream=True, **item)
             yield resp
 
     async def attempt_completion(self, stream=True):
         args = {
-            "temperature": 0,
-            "max_tokens": int(self.interpreter.env["llm.maxtokens"]),
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "max_tokens": self.max_tokens,
             "stream": stream,
         }

@@ -195,13 +204,13 @@ async def attempt_completion(self, stream=True):
         return await acompletion(
             model=self.model,
             messages=self.messages,
+            timeout=self.timeout,
             **args
         )
 
-    async def get_completion(self):
+    async def get_completion(self, stream=False):
         if self.llama_instance:
-            response = await self.attempt_completion()
-            stream = True
+            response = await self.attempt_completion(stream=stream)
             if stream:
                 return await self.process_streaming_response(response)
             else:
@@ -211,8 +220,11 @@ async def get_completion(self):
 
         for retry_count in range(max_retries):
             try:
-                response = await self.attempt_completion()
-                return await self.process_streaming_response(response)
+                response = await self.attempt_completion(stream=stream)
+                if stream:
+                    return await self.process_streaming_response(response)
+                else:
+                    return await self.process_response(response)
             except Exception as e:
                 self.logger.error(f'Error getting completion: {e}')
                 if not _should_retry(getattr(e, 'status_code', None)) or retry_count == max_retries - 1:
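
Both the streaming and non-streaming paths now share one retry loop; the delay between attempts is collapsed in this view, so the sketch below shows only the usual shape of such a loop, with an assumed exponential backoff rather than the commit's exact delay logic:

    import asyncio

    async def retrying(attempt, max_retries=5, base_delay=2):
        for retry_count in range(max_retries):
            try:
                return await attempt()
            except Exception:
                if retry_count == max_retries - 1:
                    raise
                # assumed backoff; the real delay is not visible in this diff
                await asyncio.sleep(base_delay ** retry_count)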
@@ -224,10 +236,15 @@
 
         raise Exception("Max retries reached. Unable to get completion.")
 
-    async def chat(self) -> str:
-        response = await self.get_completion()
+    async def achat(self, messages=None, stream=False) -> str:
+        if messages:
+            self.messages = messages
+        response = await self.get_completion(stream)
         return response
 
+    def chat(self, **kwargs) -> str:
+        return asyncio.run(self.achat(**kwargs))
+
 def cb(type, data):
     spinner.stop()
     if type == 'message_stream':
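
The chat() wrapper added here makes ChatAuto callable from synchronous code; since asyncio.run spins up a fresh event loop, it cannot be used from inside an already running loop, which is why the r2ai entry point below still drives achat() with run_until_complete on its own loop. A usage sketch, reusing the chat_auto instance from the first example:

    # Synchronous, non-streaming call; per the diff, the result is the
    # assistant message dict built by process_response().
    answer = chat_auto.chat(
        messages=[{"role": "user", "content": "what does main() do?"}],
        stream=False,
    )
    print(answer["content"])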
@@ -265,14 +282,14 @@ def chat(interpreter, **kwargs):
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
 
-    chat_auto = ChatAuto(model, interpreter=interpreter, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=interpreter.llama_instance, cb=cb)
+    chat_auto = ChatAuto(model, max_tokens=int(interpreter.env["llm.maxtokens"]), top_p=float(interpreter.env["llm.top_p"]), temperature=float(interpreter.env["llm.temperature"]), interpreter=interpreter, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=interpreter.llama_instance, cb=cb)
 
     original_handler = signal.getsignal(signal.SIGINT)
 
     try:
         signal.signal(signal.SIGINT, signal_handler)
         spinner.start()
-        return loop.run_until_complete(chat_auto.chat())
+        return loop.run_until_complete(chat_auto.achat(stream=True))
     except KeyboardInterrupt:
         builtins.print("\033[91m\nOperation cancelled by user.\033[0m")
         tasks = asyncio.all_tasks(loop=loop)
