diff --git a/r2ai/auto.py b/r2ai/auto.py
index 637bed2..2e859ab 100644
--- a/r2ai/auto.py
+++ b/r2ai/auto.py
@@ -3,6 +3,7 @@
 import sys
 import re
 from . import LOGGER
+import litellm
 from litellm import _should_retry, acompletion, utils, ModelResponse
 import asyncio
 from .tools import r2cmd, run_python
@@ -12,6 +13,8 @@
 from .completion import create_chat_completion
 import uuid
 
+litellm.drop_params = True
+
 ANSI_REGEX = re.compile(r'\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])')
 SYSTEM_PROMPT_AUTO = """
 You are a reverse engineer and you are using radare2 to analyze a binary.
@@ -36,15 +39,19 @@
 """
 
 class ChatAuto:
-    def __init__(self, model, interpreter, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, cb=None ):
+    def __init__(self, model, max_tokens = 1024, top_p=0.95, temperature=0.0, interpreter=None, system=None, tools=None, messages=None, tool_choice='auto', llama_instance=None, timeout=None, cb=None ):
         self.logger = LOGGER
         self.functions = {}
         self.tools = []
         self.model = model
+        self.max_tokens = max_tokens
+        self.top_p = top_p
+        self.temperature = temperature
         self.system = system
         self.messages = messages
         self.interpreter = interpreter
         self.system_message = None
+        self.timeout = timeout
         if messages and messages[0]['role'] != 'system' and system:
             self.messages.insert(0, { "role": "system", "content": system })
         if cb:
@@ -58,7 +65,7 @@ def __init__(self, model, interpreter, system=None, tools=None, messages=None, t
                 self.tools.append({ "type": "function", "function": f })
                 self.functions[f['name']] = tool
         self.tool_choice = tool_choice
-        self.llama_instance = llama_instance
+        self.llama_instance = llama_instance or interpreter.llama_instance if interpreter else None
         #self.tool_end_message = '\nNOTE: The user saw this output, do not repeat it.'
@@ -137,21 +144,23 @@ async def process_streaming_response(self, resp):
         self.messages.append(current_message)
         if len(current_message['tool_calls']) > 0:
             await self.process_tool_calls(current_message['tool_calls'])
-        if len(current_message['content']) > 0:
-            return current_message
+        return current_message
 
     async def process_response(self, resp):
         content = resp.choices[0].message.content
         tool_calls = []
+        current_message = { 'role': 'assistant', 'content': content or '', 'tool_calls': [] }
         for tool_call in resp.choices[0].message.tool_calls or []:
-            tool_calls.append({
+            current_message['tool_calls'].append({
                 "id": tool_call.id,
+                "type": "function",
+                "index": tool_call.index,
                 "function": {
                     "name": tool_call.function.name,
                     "arguments": tool_call.function.arguments
                 },
             })
-        if len(tool_calls) == 0:
+        if len(current_message['tool_calls']) == 0:
             try:
                 tool_call = json.loads(content)
                 if 'name' in tool_call and tool_call['name'] in self.functions:
@@ -161,27 +170,27 @@ async def process_response(self, resp):
                     elif 'parameters' in tool_call:
                         args = tool_call['parameters']
                     if args:
-                        tool_calls.append({ "id": resp.id, "function": { "name": tool_call['name'], "arguments": json.dumps(args) } })
+                        current_message['tool_calls'].append({ "id": resp.id, "function": { "name": tool_call['name'], "arguments": json.dumps(args) } })
             except Exception:
                 pass
-        if len(tool_calls) > 0:
-            self.messages.append({"role": "assistant", "tool_calls": tool_calls, "content": ""})
-            await self.process_tool_calls(tool_calls)
-        if content is None or len(content) > 0:
-            self.messages.append({"role": "assistant", "content": content or ""})
+        if len(current_message['tool_calls']) > 0:
+            self.messages.append(current_message)
+            await self.process_tool_calls(current_message['tool_calls'])
 
-        return content
+        self.messages.append(current_message)
+
+        return current_message
 
     async def async_response_generator(self, response):
         for item in response:
-            self.logger.debug(item)
             resp = ModelResponse(stream=True, **item)
             yield resp
 
     async def attempt_completion(self, stream=True):
         args = {
-            "temperature": 0,
-            "max_tokens": int(self.interpreter.env["llm.maxtokens"]),
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "max_tokens": self.max_tokens,
             "stream": stream,
         }
@@ -195,13 +204,13 @@ async def attempt_completion(self, stream=True):
         return await acompletion(
             model=self.model,
             messages=self.messages,
+            timeout=self.timeout,
             **args
         )
 
-    async def get_completion(self):
+    async def get_completion(self, stream=False):
         if self.llama_instance:
-            response = await self.attempt_completion()
-            stream = True
+            response = await self.attempt_completion(stream=stream)
             if stream:
                 return await self.process_streaming_response(response)
             else:
@@ -211,8 +220,11 @@
         for retry_count in range(max_retries):
             try:
-                response = await self.attempt_completion()
-                return await self.process_streaming_response(response)
+                response = await self.attempt_completion(stream=stream)
+                if stream:
+                    return await self.process_streaming_response(response)
+                else:
+                    return await self.process_response(response)
             except Exception as e:
                 self.logger.error(f'Error getting completion: {e}')
                 if not _should_retry(getattr(e, 'status_code', None)) or retry_count == max_retries - 1:
@@ -224,10 +236,15 @@
             raise Exception("Max retries reached. Unable to get completion.")
 
-    async def chat(self) -> str:
-        response = await self.get_completion()
+    async def achat(self, messages=None, stream=False) -> str:
+        if messages:
+            self.messages = messages
+        response = await self.get_completion(stream)
         return response
 
+    def chat(self, **kwargs) -> str:
+        return asyncio.run(self.achat(**kwargs))
+
 def cb(type, data):
     spinner.stop()
     if type == 'message_stream':
@@ -265,14 +282,14 @@ def chat(interpreter, **kwargs):
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
 
-    chat_auto = ChatAuto(model, interpreter=interpreter, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=interpreter.llama_instance, cb=cb)
+    chat_auto = ChatAuto(model, max_tokens=int(interpreter.env["llm.maxtokens"]), top_p=float(interpreter.env["llm.top_p"]), temperature=float(interpreter.env["llm.temperature"]), interpreter=interpreter, system=SYSTEM_PROMPT_AUTO, tools=tools, messages=messages, tool_choice=tool_choice, llama_instance=interpreter.llama_instance, cb=cb)
 
     original_handler = signal.getsignal(signal.SIGINT)
 
     try:
         signal.signal(signal.SIGINT, signal_handler)
         spinner.start()
-        return loop.run_until_complete(chat_auto.chat())
+        return loop.run_until_complete(chat_auto.achat(stream=True))
     except KeyboardInterrupt:
         builtins.print("\033[91m\nOperation cancelled by user.\033[0m")
         tasks = asyncio.all_tasks(loop=loop)
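
A minimal usage sketch (commentary, not part of the diff): with these changes the ChatAuto constructor no longer requires an interpreter, the sampling knobs and timeout are plain keyword arguments, and chat() becomes a blocking wrapper around the new achat(). The model id and question below are placeholders; tools and cb are omitted on the assumption that the class tolerates their defaults.

    # Sketch only: assumes a litellm-routable model id and no local llama instance.
    from r2ai.auto import ChatAuto, SYSTEM_PROMPT_AUTO

    messages = [{"role": "user", "content": "What does the entry point do?"}]

    chat_auto = ChatAuto(
        "openai/gpt-4o",           # placeholder model id (assumption)
        max_tokens=1024,           # sampling knobs are now constructor arguments
        top_p=0.95,                # instead of interpreter.env lookups, so no
        temperature=0.0,           # interpreter instance is needed here
        system=SYSTEM_PROMPT_AUTO,
        messages=messages,
        timeout=60,                # new: forwarded to litellm's acompletion()
    )

    # Blocking call; wraps achat() in asyncio.run(). With stream=False the
    # non-streaming process_response() path runs and the final assistant
    # message comes back as a dict ({'role', 'content', 'tool_calls'}).
    final_message = chat_auto.chat(stream=False)

    # From already-async code, await achat() directly; stream=True follows
    # the same streaming path the r2ai CLI now uses:
    #   final_message = await chat_auto.achat(stream=True)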
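Design note on the achat()/chat() split: the CLI entry point keeps managing its own event loop (so the SIGINT handler can cancel pending tasks) and now awaits achat(stream=True) on it, while embedders with no loop of their own get a one-line blocking wrapper via asyncio.run(). Callers already running inside an event loop should await achat() rather than call chat(), since asyncio.run() refuses to nest inside a running loop.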