From 94ba96bbcb547eef3e3929859dd1e50dd553eb2e Mon Sep 17 00:00:00 2001 From: pdmurray Date: Tue, 25 Apr 2023 21:15:52 -0700 Subject: [PATCH] Add a disconnect button to the context widgets in notebooks Signed-off-by: pdmurray --- python/ray/_private/worker.py | 90 +++++++++++++++---- python/ray/client_builder.py | 15 ---- python/ray/tests/test_widgets.py | 4 + python/ray/widgets/render.py | 5 +- python/ray/widgets/templates/context.html.j2 | 37 +------- .../widgets/templates/context_logo.html.j2 | 13 +++ .../widgets/templates/context_table.html.j2 | 11 +++ python/ray/widgets/util.py | 16 ++-- 8 files changed, 114 insertions(+), 77 deletions(-) create mode 100644 python/ray/widgets/templates/context_logo.html.j2 create mode 100644 python/ray/widgets/templates/context_table.html.j2 diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 69a8327173c96..79b1af201f690 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -96,6 +96,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.tracing.tracing_helper import _import_from_string from ray.widgets import Template +from ray.widgets.util import ensure_ipywidgets_dep SCRIPT_MODE = 0 WORKER_MODE = 1 @@ -1019,6 +1020,10 @@ class BaseContext(metaclass=ABCMeta): Base class for RayContext and ClientContext """ + dashboard_url: Optional[str] + python_version: str + ray_version: str + @abstractmethod def disconnect(self): """ @@ -1036,6 +1041,73 @@ def __enter__(self): def __exit__(self): pass + def _context_table_template(self): + if self.dashboard_url: + dashboard_row = Template("context_dashrow.html.j2").render( + dashboard_url="http://" + self.dashboard_url + ) + else: + dashboard_row = None + + return Template("context_table.html.j2").render( + python_version=self.python_version, + ray_version=self.ray_version, + dashboard_row=dashboard_row, + ) + + def _repr_html_(self): + return Template("context.html.j2").render( + context_logo=Template("context_logo.html.j2").render(), + context_table=self._context_table_template(), + ) + + @ensure_ipywidgets_dep("8") + def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]: + """Get the mimebundle for the widget representation of the context. + + Args: + **kwargs: Passed to the _repr_mimebundle_() function for the widget + + Returns: + Dictionary ("mimebundle") of the widget representation of the context. + """ + import ipywidgets + + disconnect_button = ipywidgets.Button( + description="Disconnect", + disabled=False, + button_style="", + tooltip="Disconnect from the Ray cluster", + layout=ipywidgets.Layout(margin="auto 0px 0px 0px"), + ) + + def disconnect_callback(button): + button.disabled = True + button.description = "Disconnecting..." + self.disconnect() + button.description = "Disconnected" + + disconnect_button.on_click(disconnect_callback) + left_content = ipywidgets.VBox( + [ + ipywidgets.HTML(Template("context_logo.html.j2").render()), + disconnect_button, + ], + layout=ipywidgets.Layout(), + ) + right_content = ipywidgets.HTML(self._context_table_template()) + widget = ipywidgets.HBox( + [left_content, right_content], layout=ipywidgets.Layout(width="100%") + ) + return widget._repr_mimebundle_(**kwargs) + + def _repr_mimebundle_(self, **kwargs): + bundle = self._get_widget_bundle(**kwargs) + + # Overwrite the widget html repr and default repr with those of the BaseContext + bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)}) + return bundle + @dataclass class RayContext(BaseContext, Mapping): @@ -1047,10 +1119,10 @@ class RayContext(BaseContext, Mapping): python_version: str ray_version: str ray_commit: str - protocol_version = Optional[str] - address_info: Dict[str, Optional[str]] + protocol_version: Optional[str] def __init__(self, address_info: Dict[str, Optional[str]]): + super().__init__() self.dashboard_url = get_dashboard_url() self.python_version = "{}.{}.{}".format(*sys.version_info[:3]) self.ray_version = ray.__version__ @@ -1092,20 +1164,6 @@ def disconnect(self): # Include disconnect() to stay consistent with ClientContext ray.shutdown() - def _repr_html_(self): - if self.dashboard_url: - dashboard_row = Template("context_dashrow.html.j2").render( - dashboard_url="http://" + self.dashboard_url - ) - else: - dashboard_row = None - - return Template("context.html.j2").render( - python_version=self.python_version, - ray_version=self.ray_version, - dashboard_row=dashboard_row, - ) - global_worker = Worker() """Worker: The global Worker object for this worker process. diff --git a/python/ray/client_builder.py b/python/ray/client_builder.py index 4e379dc5c5b62..d5cbc03142e10 100644 --- a/python/ray/client_builder.py +++ b/python/ray/client_builder.py @@ -19,7 +19,6 @@ from ray._private.worker import init as ray_driver_init from ray.job_config import JobConfig from ray.util.annotations import Deprecated, PublicAPI -from ray.widgets import Template logger = logging.getLogger(__name__) @@ -86,20 +85,6 @@ def _disconnect_with_context(self, force_disconnect: bool) -> None: # This is only a driver connected to an existing cluster. ray.shutdown() - def _repr_html_(self): - if self.dashboard_url: - dashboard_row = Template("context_dashrow.html.j2").render( - dashboard_url="http://" + self.dashboard_url - ) - else: - dashboard_row = None - - return Template("context.html.j2").render( - python_version=self.python_version, - ray_version=self.ray_version, - dashboard_row=dashboard_row, - ) - @Deprecated class ClientBuilder: diff --git a/python/ray/tests/test_widgets.py b/python/ray/tests/test_widgets.py index ad95d2b3c9e4b..3697516def3a3 100644 --- a/python/ray/tests/test_widgets.py +++ b/python/ray/tests/test_widgets.py @@ -1,3 +1,4 @@ +import logging from unittest import mock import pytest @@ -7,6 +8,7 @@ @mock.patch("importlib.import_module") def test_ensure_notebook_dep_missing(mock_import_module, caplog): """Test that missing notebook dependencies trigger a warning.""" + caplog.set_level(logging.INFO) class MockDep: __version__ = "8.0.0" @@ -30,6 +32,7 @@ def dummy_ipython_display(self): @mock.patch("importlib.import_module") def test_ensure_notebook_dep_outdated(mock_import_module, caplog): """Test that outdated notebook dependencies trigger a warning.""" + caplog.set_level(logging.INFO) class MockDep: __version__ = "7.0.0" @@ -49,6 +52,7 @@ def dummy_ipython_display(): @mock.patch("importlib.import_module") def test_ensure_notebook_valid(mock_import_module, caplog): """Test that valid notebook dependencies don't trigger a warning.""" + caplog.set_level(logging.INFO) class MockDep: __version__ = "8.0.0" diff --git a/python/ray/widgets/render.py b/python/ray/widgets/render.py index 1fc0c69d5f0f7..f9e861d399256 100644 --- a/python/ray/widgets/render.py +++ b/python/ray/widgets/render.py @@ -19,7 +19,7 @@ def render(self, **kwargs) -> str: from the keyword arguments. Returns: - str: HTML template with the keys of the kwargs replaced with corresponding + HTML template with the keys of the kwargs replaced with corresponding values. """ rendered = self.template @@ -34,7 +34,6 @@ def list_templates() -> List[pathlib.Path]: """List the available HTML templates. Returns: - List[pathlib.Path]: A list of files with .html.j2 extensions inside - ./templates/ + A list of files with .html.j2 extensions inside ../templates/ """ return (pathlib.Path(__file__).parent / "templates").glob("*.html.j2") diff --git a/python/ray/widgets/templates/context.html.j2 b/python/ray/widgets/templates/context.html.j2 index 3e664e01eae2e..26cc0ef6c8784 100644 --- a/python/ray/widgets/templates/context.html.j2 +++ b/python/ray/widgets/templates/context.html.j2 @@ -1,37 +1,6 @@ -
+ diff --git a/python/ray/widgets/templates/context_logo.html.j2 b/python/ray/widgets/templates/context_logo.html.j2 new file mode 100644 index 0000000000000..9233fe3a77226 --- /dev/null +++ b/python/ray/widgets/templates/context_logo.html.j2 @@ -0,0 +1,13 @@ + diff --git a/python/ray/widgets/templates/context_table.html.j2 b/python/ray/widgets/templates/context_table.html.j2 new file mode 100644 index 0000000000000..d06822d0c1f56 --- /dev/null +++ b/python/ray/widgets/templates/context_table.html.j2 @@ -0,0 +1,11 @@ + + + + + + + + + + {{ dashboard_row }} + diff --git a/python/ray/widgets/util.py b/python/ray/widgets/util.py index 6991384779f2d..3fbe3b4aa9ce4 100644 --- a/python/ray/widgets/util.py +++ b/python/ray/widgets/util.py @@ -73,9 +73,9 @@ def ensure_notebook_deps( ) -> Callable[[F], F]: """Generate a decorator which checks for soft dependencies. - This decorator is meant to wrap repr methods. If the dependency is not found, - or a version is specified here and the version of the package is older than the - specified version, the original repr is used. + This decorator is meant to wrap _repr_mimebundle_ methods. If the dependency is not + found, or a version is specified here and the version of the package is older than + the specified version, the original repr is used. If the dependency is missing or the version is old, a log message is displayed. Args: @@ -155,11 +155,11 @@ def _has_missing( message = f"Run `pip install {' '.join(missing)}` for rich notebook output." if sys.version_info < (3, 8): - logger.warning(f"Missing packages: {missing}. {message}") + logger.info(f"Missing packages: {missing}. {message}") else: # stacklevel=3: First level is this function, then ensure_notebook_deps, # then the actual function affected. - logger.warning(f"Missing packages: {missing}. {message}", stacklevel=3) + logger.info(f"Missing packages: {missing}. {message}", stacklevel=3) return missing @@ -190,13 +190,11 @@ def _has_outdated( message = f"Run `pip install -U {install_str}` for rich notebook output." if sys.version_info < (3, 8): - logger.warning(f"Outdated packages:\n{outdated_str}\n{message}") + logger.info(f"Outdated packages:\n{outdated_str}\n{message}") else: # stacklevel=3: First level is this function, then ensure_notebook_deps, # then the actual function affected. - logger.warning( - f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3 - ) + logger.info(f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3) return outdated