forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(sandbox): Add Jupyter Kernel for Interactive Python Interpreter …
…for Sandbox (All-Hands-AI#1215) * add initial version of py interpreter * fix bug * fix async issue * remove debugging print statement * initialize kernel & update printing * fix port mapping * uncomment debug lines * fix poetry lock * make jupyter py interpreter into a subclass
- Loading branch information
1 parent
7b47a35
commit 4271ad9
Showing
8 changed files
with
341 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
import re | ||
import os | ||
import tornado | ||
import asyncio | ||
from tornado.escape import json_encode, json_decode, url_escape | ||
from tornado.websocket import websocket_connect, WebSocketClientConnection | ||
from tornado.ioloop import PeriodicCallback | ||
from tornado.httpclient import AsyncHTTPClient, HTTPRequest | ||
|
||
from opendevin.logger import opendevin_logger as logger | ||
from uuid import uuid4 | ||
|
||
|
||
def strip_ansi(o: str) -> str: | ||
""" | ||
Removes ANSI escape sequences from `o`, as defined by ECMA-048 in | ||
http://www.ecma-international.org/publications/files/ECMA-ST/Ecma-048.pdf | ||
# https://github.com/ewen-lbh/python-strip-ansi/blob/master/strip_ansi/__init__.py | ||
>>> strip_ansi("\\033[33mLorem ipsum\\033[0m") | ||
'Lorem ipsum' | ||
>>> strip_ansi("Lorem \\033[38;25mIpsum\\033[0m sit\\namet.") | ||
'Lorem Ipsum sit\\namet.' | ||
>>> strip_ansi("") | ||
'' | ||
>>> strip_ansi("\\x1b[0m") | ||
'' | ||
>>> strip_ansi("Lorem") | ||
'Lorem' | ||
>>> strip_ansi('\\x1b[38;5;32mLorem ipsum\\x1b[0m') | ||
'Lorem ipsum' | ||
>>> strip_ansi('\\x1b[1m\\x1b[46m\\x1b[31mLorem dolor sit ipsum\\x1b[0m') | ||
'Lorem dolor sit ipsum' | ||
""" | ||
|
||
# pattern = re.compile(r'/(\x9B|\x1B\[)[0-?]*[ -\/]*[@-~]/') | ||
pattern = re.compile(r'\x1B\[\d+(;\d+){0,2}m') | ||
stripped = pattern.sub('', o) | ||
return stripped | ||
|
||
|
||
class JupyterKernel: | ||
def __init__(self, url_suffix, convid, lang='python'): | ||
self.base_url = f'http://{url_suffix}' | ||
self.base_ws_url = f'ws://{url_suffix}' | ||
self.lang = lang | ||
self.kernel_id = None | ||
self.convid = convid | ||
logger.info( | ||
f'Jupyter kernel created for conversation {convid} at {url_suffix}' | ||
) | ||
|
||
self.heartbeat_interval = 10000 # 10 seconds | ||
self.heartbeat_callback = None | ||
|
||
async def initialize(self): | ||
await self.execute(r'%colors nocolor') | ||
# pre-defined tools | ||
# self.tools_to_run = [ | ||
# # TODO: You can add code for your pre-defined tools here | ||
# ] | ||
# for tool in self.tools_to_run: | ||
# # logger.info(f"Tool initialized:\n{tool}") | ||
# await self.execute(tool) | ||
|
||
async def _send_heartbeat(self): | ||
if not hasattr(self, 'ws') or not self.ws: | ||
return | ||
try: | ||
self.ws.ping() | ||
logger.debug('Heartbeat sent...') | ||
except tornado.iostream.StreamClosedError: | ||
logger.info('Heartbeat failed, reconnecting...') | ||
try: | ||
await self._connect() | ||
except ConnectionRefusedError: | ||
logger.info( | ||
'ConnectionRefusedError: Failed to reconnect to kernel websocket - Is the kernel still running?' | ||
) | ||
|
||
async def _connect(self): | ||
if hasattr(self, 'ws') and self.ws: | ||
self.ws.close() | ||
self.ws: WebSocketClientConnection = None | ||
|
||
client = AsyncHTTPClient() | ||
if not self.kernel_id: | ||
n_tries = 5 | ||
while n_tries > 0: | ||
try: | ||
response = await client.fetch( | ||
'{}/api/kernels'.format(self.base_url), | ||
method='POST', | ||
body=json_encode({'name': self.lang}), | ||
) | ||
kernel = json_decode(response.body) | ||
self.kernel_id = kernel['id'] | ||
break | ||
except Exception as e: | ||
logger.error('Failed to connect to kernel') | ||
logger.exception(e) | ||
logger.info('Retrying in 5 seconds...') | ||
# kernels are not ready yet | ||
n_tries -= 1 | ||
await asyncio.sleep(5) | ||
|
||
if n_tries == 0: | ||
raise ConnectionRefusedError('Failed to connect to kernel') | ||
|
||
ws_req = HTTPRequest( | ||
url='{}/api/kernels/{}/channels'.format( | ||
self.base_ws_url, url_escape(self.kernel_id) | ||
) | ||
) | ||
self.ws = await websocket_connect(ws_req) | ||
logger.info('Connected to kernel websocket') | ||
|
||
# Setup heartbeat | ||
if self.heartbeat_callback: | ||
self.heartbeat_callback.stop() | ||
self.heartbeat_callback = PeriodicCallback( | ||
self._send_heartbeat, self.heartbeat_interval | ||
) | ||
self.heartbeat_callback.start() | ||
|
||
async def execute(self, code, timeout=60): | ||
if not hasattr(self, 'ws') or not self.ws: | ||
await self._connect() | ||
|
||
msg_id = uuid4().hex | ||
self.ws.write_message( | ||
json_encode( | ||
{ | ||
'header': { | ||
'username': '', | ||
'version': '5.0', | ||
'session': '', | ||
'msg_id': msg_id, | ||
'msg_type': 'execute_request', | ||
}, | ||
'parent_header': {}, | ||
'channel': 'shell', | ||
'content': { | ||
'code': code, | ||
'silent': False, | ||
'store_history': False, | ||
'user_expressions': {}, | ||
'allow_stdin': False, | ||
}, | ||
'metadata': {}, | ||
'buffers': {}, | ||
} | ||
) | ||
) | ||
logger.info(f'EXECUTE REQUEST SENT:\n{code}') | ||
|
||
outputs = [] | ||
|
||
async def wait_for_messages(): | ||
execution_done = False | ||
while not execution_done: | ||
msg = await self.ws.read_message() | ||
msg = json_decode(msg) | ||
msg_type = msg['msg_type'] | ||
parent_msg_id = msg['parent_header'].get('msg_id', None) | ||
|
||
if parent_msg_id != msg_id: | ||
continue | ||
|
||
if os.environ.get('DEBUG', False): | ||
logger.info( | ||
f"MSG TYPE: {msg_type.upper()} DONE:{execution_done}\nCONTENT: {msg['content']}" | ||
) | ||
|
||
if msg_type == 'error': | ||
traceback = '\n'.join(msg['content']['traceback']) | ||
outputs.append(traceback) | ||
execution_done = True | ||
elif msg_type == 'stream': | ||
outputs.append(msg['content']['text']) | ||
elif msg_type in ['execute_result', 'display_data']: | ||
outputs.append(msg['content']['data']['text/plain']) | ||
if 'image/png' in msg['content']['data']: | ||
# use markdone to display image (in case of large image) | ||
# outputs.append(f"\n<img src=\"data:image/png;base64,{msg['content']['data']['image/png']}\"/>\n") | ||
outputs.append( | ||
f"![image](data:image/png;base64,{msg['content']['data']['image/png']})" | ||
) | ||
|
||
elif msg_type == 'execute_reply': | ||
execution_done = True | ||
return execution_done | ||
|
||
async def interrupt_kernel(): | ||
client = AsyncHTTPClient() | ||
interrupt_response = await client.fetch( | ||
f'{self.base_url}/api/kernels/{self.kernel_id}/interrupt', | ||
method='POST', | ||
body=json_encode({'kernel_id': self.kernel_id}), | ||
) | ||
logger.info(f'Kernel interrupted: {interrupt_response}') | ||
|
||
try: | ||
execution_done = await asyncio.wait_for(wait_for_messages(), timeout) | ||
except asyncio.TimeoutError: | ||
await interrupt_kernel() | ||
return f'[Execution timed out ({timeout} seconds).]' | ||
|
||
if not outputs and execution_done: | ||
ret = '[Code executed successfully with no output]' | ||
else: | ||
ret = ''.join(outputs) | ||
|
||
# Remove ANSI | ||
ret = strip_ansi(ret) | ||
|
||
if os.environ.get('DEBUG', False): | ||
logger.info(f'OUTPUT:\n{ret}') | ||
return ret | ||
|
||
async def shutdown_async(self): | ||
if self.kernel_id: | ||
client = AsyncHTTPClient() | ||
await client.fetch( | ||
'{}/api/kernels/{}'.format(self.base_url, self.kernel_id), | ||
method='DELETE', | ||
) | ||
self.kernel_id = None | ||
if self.ws: | ||
self.ws.close() | ||
self.ws = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.