Skip to content

Commit

Permalink
Add a disconnect button for RayContext in notebooks (#35507)
Browse files Browse the repository at this point in the history
This PR reverts #35426, adding back in a nice button for disconnecting (i.e. calling `ray.shutdown()`) when `ray.init()` is called in a notebook.

## Testing

The main issue with the previous disconnect PR is that checking for the `ipywidgets` soft dependency at run time introduces a small performance penalty which bumps some test suites up beyond their timeout upper bound, breaking them. This happens in tests unrelated to where these changes were made previously, making troubleshooting more difficult.

Here, I've introduced some optimizations to avoid this penalty wherever possible:
* In `ray.widgets.util.in_notebook`, a significant performance penalty was previously being incurred by the `try-import/except` block upon each call. Now, we just check whether `"IPython" in sys.modules` is `True` before attempting an import. This is _much_ faster as `IPython` is pre-loaded for IPython kernels. I've also added this optimization to `repr_fallback_if_colab`.
* Secondly, `ensure_notebook_deps` now does dependency checking at function definition time, rather than on each function call.
* I've also reworked the logic for detecting the current shell, since there's really only one situation where displaying widgets is okay (when the user is running Jupyterlab). In all other cases, we fall back to simple reprs.
* The changes here also solve #35490.

Here's the artifact from the CI run below loaded onto google colab:

![image](https://github.com/ray-project/ray/assets/14017872/c1852128-9dce-4744-8bbc-1ed641ed9cce)

The rendered output has no dashboard URL row (evidently this isn't available when run on colab), but the HTML repr looks okay. The only graphical issue that I see is that colab is doing something weird to the "RAY" text that is supposed to appear by the Ray logo. Here's what this looks like in Jupyterlab:

![image](https://github.com/ray-project/ray/assets/14017872/7592f5f0-46d7-4fa3-8996-86371d39f3e2)
  • Loading branch information
peytondmurray authored May 30, 2023
1 parent 5b5d83c commit 6f2424f
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 199 deletions.
90 changes: 74 additions & 16 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tracing.tracing_helper import _import_from_string
from ray.widgets import Template
from ray.widgets.util import repr_with_fallback

SCRIPT_MODE = 0
WORKER_MODE = 1
Expand Down Expand Up @@ -1030,6 +1031,10 @@ class BaseContext(metaclass=ABCMeta):
Base class for RayContext and ClientContext
"""

dashboard_url: Optional[str]
python_version: str
ray_version: str

@abstractmethod
def disconnect(self):
"""
Expand All @@ -1047,6 +1052,73 @@ def __enter__(self):
def __exit__(self):
pass

def _context_table_template(self):
if self.dashboard_url:
dashboard_row = Template("context_dashrow.html.j2").render(
dashboard_url="http://" + self.dashboard_url
)
else:
dashboard_row = None

return Template("context_table.html.j2").render(
python_version=self.python_version,
ray_version=self.ray_version,
dashboard_row=dashboard_row,
)

def _repr_html_(self):
return Template("context.html.j2").render(
context_logo=Template("context_logo.html.j2").render(),
context_table=self._context_table_template(),
)

@repr_with_fallback(["ipywidgets", "8"])
def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]:
"""Get the mimebundle for the widget representation of the context.
Args:
**kwargs: Passed to the _repr_mimebundle_() function for the widget
Returns:
Dictionary ("mimebundle") of the widget representation of the context.
"""
import ipywidgets

disconnect_button = ipywidgets.Button(
description="Disconnect",
disabled=False,
button_style="",
tooltip="Disconnect from the Ray cluster",
layout=ipywidgets.Layout(margin="auto 0px 0px 0px"),
)

def disconnect_callback(button):
button.disabled = True
button.description = "Disconnecting..."
self.disconnect()
button.description = "Disconnected"

disconnect_button.on_click(disconnect_callback)
left_content = ipywidgets.VBox(
[
ipywidgets.HTML(Template("context_logo.html.j2").render()),
disconnect_button,
],
layout=ipywidgets.Layout(),
)
right_content = ipywidgets.HTML(self._context_table_template())
widget = ipywidgets.HBox(
[left_content, right_content], layout=ipywidgets.Layout(width="100%")
)
return widget._repr_mimebundle_(**kwargs)

def _repr_mimebundle_(self, **kwargs):
bundle = self._get_widget_bundle(**kwargs)

# Overwrite the widget html repr and default repr with those of the BaseContext
bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)})
return bundle


@dataclass
class RayContext(BaseContext, Mapping):
Expand All @@ -1058,10 +1130,10 @@ class RayContext(BaseContext, Mapping):
python_version: str
ray_version: str
ray_commit: str
protocol_version = Optional[str]
address_info: Dict[str, Optional[str]]
protocol_version: Optional[str]

def __init__(self, address_info: Dict[str, Optional[str]]):
super().__init__()
self.dashboard_url = get_dashboard_url()
self.python_version = "{}.{}.{}".format(*sys.version_info[:3])
self.ray_version = ray.__version__
Expand Down Expand Up @@ -1103,20 +1175,6 @@ def disconnect(self):
# Include disconnect() to stay consistent with ClientContext
ray.shutdown()

def _repr_html_(self):
if self.dashboard_url:
dashboard_row = Template("context_dashrow.html.j2").render(
dashboard_url="http://" + self.dashboard_url
)
else:
dashboard_row = None

return Template("context.html.j2").render(
python_version=self.python_version,
ray_version=self.ray_version,
dashboard_row=dashboard_row,
)


global_worker = Worker()
"""Worker: The global Worker object for this worker process.
Expand Down
15 changes: 0 additions & 15 deletions python/ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ray._private.worker import init as ray_driver_init
from ray.job_config import JobConfig
from ray.util.annotations import Deprecated, PublicAPI
from ray.widgets import Template

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,20 +85,6 @@ def _disconnect_with_context(self, force_disconnect: bool) -> None:
# This is only a driver connected to an existing cluster.
ray.shutdown()

def _repr_html_(self):
if self.dashboard_url:
dashboard_row = Template("context_dashrow.html.j2").render(
dashboard_url="http://" + self.dashboard_url
)
else:
dashboard_row = None

return Template("context.html.j2").render(
python_version=self.python_version,
ray_version=self.ray_version,
dashboard_row=dashboard_row,
)


@Deprecated
class ClientBuilder:
Expand Down
Loading

0 comments on commit 6f2424f

Please sign in to comment.