Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions solara/server/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import threading
from io import TextIOBase
from typing import Callable, List, Optional

import IPython


class ThreadLocal(threading.local):
redirect: Optional[TextIOBase] = None
hooks: Optional[List[Callable[[str], None]]] = None


class OutStream(TextIOBase):
"""A file like object that can dispatch/redirect based on a thread local state."""

def __init__(self, default, name):
self._default = default
self.name = name
self._local = ThreadLocal()

@property
def _redirect(self):
return self._local.redirect

def write(self, string: str) -> Optional[int]:
# self._default.write("DEBUG: [" + string + "]")
data = string
content = {"name": self.name, "text": data}

kernel = IPython.get_ipython().kernel
session = kernel.session
msg = session.msg("stream", content) # does it matter to not have parent, parent=self.parent_header)
for hook in self._hooks:
msg = hook(msg)
if msg is None:
return None

dispatch = self._redirect or self._default
return dispatch.write(string)

@property
def _hooks(self):
if self._local.hooks is None:
self._local.hooks = []
return self._local.hooks

def register_hook(self, hook):
self._hooks.append(hook)

def unregister_hook(self, hook):
self._hooks.remove(hook)
13 changes: 13 additions & 0 deletions solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from IPython.core.interactiveshell import InteractiveShell

from . import app, reload, settings
from .io import OutStream
from .utils import pdb_guard

logger = logging.getLogger("solara.server.app")
Expand Down Expand Up @@ -260,6 +261,8 @@ def Thread_debug_run(self):
_patched = False
global_widgets_dict = {}
global_templates_dict: Dict[Any, Any] = {}
stdout = sys.stdout
stderr = sys.stderr


def Output_enter(self):
Expand All @@ -269,16 +272,23 @@ def hook(msg):
if msg["msg_type"] == "display_data":
self.outputs += ({"output_type": "display_data", "data": msg["content"]["data"], "metadata": msg["content"]["metadata"]},)
return None
if msg["msg_type"] == "stream":
self.outputs += ({"output_type": "stream", "name": msg["content"]["name"], "text": msg["content"]["text"]},)
return None
if msg["msg_type"] == "clear_output":
self.outputs = ()
return None
return msg

get_ipython().display_pub.register_hook(hook)
assert isinstance(sys.stdout, OutStream)
sys.stdout.register_hook(hook)


def Output_exit(self, exc_type, exc_value, traceback):
get_ipython().display_pub._hooks.pop()
assert isinstance(sys.stdout, OutStream)
sys.stdout._hooks.pop()


def patch():
Expand Down Expand Up @@ -337,6 +347,9 @@ def patch():
ipywidgets.widgets.widget_output.Output.__enter__ = Output_enter
ipywidgets.widgets.widget_output.Output.__exit__ = Output_exit

sys.stdout = OutStream(sys.stdout, "stdout") # type: ignore
sys.stderr = OutStream(sys.stderr, "stderr") # type: ignore

original_close = ipywidgets.widget.Widget.close
closed_ids = set()
closed_stack: Dict[int, str] = {}
Expand Down