From 6f2424ffab4bbf30f76a7139811fc1ed3e2f7d00 Mon Sep 17 00:00:00 2001 From: Peyton Murray Date: Tue, 30 May 2023 16:55:26 -0700 Subject: [PATCH] Add a disconnect button for RayContext in notebooks (#35507) This PR reverts #35426, adding back in a nice button for disconnecting (i.e. calling `ray.shutdown()`) when `ray.init()` is called in a notebook. ## Testing The main issue with the previous disconnect PR is that checking for the `ipywidgets` soft dependency at run time introduces a small performance penalty which bumps some test suites up beyond their timeout upper bound, breaking them. This happens in tests unrelated to where these changes were made previously, making troubleshooting more difficult. Here, I've introduced some optimizations to avoid this penalty wherever possible: * In `ray.widgets.util.in_notebook`, a significant performance penalty was previously being incurred by the `try-import/except` block upon each call. Now, we just check whether `"IPython" in sys.modules` is `True` before attempting an import. This is _much_ faster as `IPython` is pre-loaded for IPython kernels. I've also added this optimization to `repr_fallback_if_colab`. * Secondly, `ensure_notebook_deps` now does dependency checking at function definition time, rather than on each function call. * I've also reworked the logic for detecting the current shell, since there's really only one situation where displaying widgets is okay (when the user is running Jupyterlab). In all other cases, we fall back to simple reprs. * The changes here also solve #35490. Here's the artifact from the CI run below loaded onto google colab: ![image](https://github.com/ray-project/ray/assets/14017872/c1852128-9dce-4744-8bbc-1ed641ed9cce) The rendered output has no dashboard URL row (evidently this isn't available when run on colab), but the HTML repr looks okay. The only graphical issue that I see is that colab is doing something weird to the "RAY" text that is supposed to appear by the Ray logo. Here's what this looks like in Jupyterlab: ![image](https://github.com/ray-project/ray/assets/14017872/7592f5f0-46d7-4fa3-8996-86371d39f3e2) --- python/ray/_private/worker.py | 90 ++++++-- python/ray/client_builder.py | 15 -- python/ray/tests/test_widgets.py | 204 +++++++++++------- python/ray/widgets/render.py | 5 +- python/ray/widgets/templates/context.html.j2 | 37 +--- .../widgets/templates/context_logo.html.j2 | 13 ++ .../widgets/templates/context_table.html.j2 | 11 + python/ray/widgets/util.py | 106 ++++----- 8 files changed, 282 insertions(+), 199 deletions(-) create mode 100644 python/ray/widgets/templates/context_logo.html.j2 create mode 100644 python/ray/widgets/templates/context_table.html.j2 diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py index 447d8fa5dc7c..151b4716338a 100644 --- a/python/ray/_private/worker.py +++ b/python/ray/_private/worker.py @@ -91,6 +91,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from ray.util.tracing.tracing_helper import _import_from_string from ray.widgets import Template +from ray.widgets.util import repr_with_fallback SCRIPT_MODE = 0 WORKER_MODE = 1 @@ -1030,6 +1031,10 @@ class BaseContext(metaclass=ABCMeta): Base class for RayContext and ClientContext """ + dashboard_url: Optional[str] + python_version: str + ray_version: str + @abstractmethod def disconnect(self): """ @@ -1047,6 +1052,73 @@ def __enter__(self): def __exit__(self): pass + def _context_table_template(self): + if self.dashboard_url: + dashboard_row = Template("context_dashrow.html.j2").render( + dashboard_url="http://" + self.dashboard_url + ) + else: + dashboard_row = None + + return Template("context_table.html.j2").render( + python_version=self.python_version, + ray_version=self.ray_version, + dashboard_row=dashboard_row, + ) + + def _repr_html_(self): + return Template("context.html.j2").render( + context_logo=Template("context_logo.html.j2").render(), + context_table=self._context_table_template(), + ) + + @repr_with_fallback(["ipywidgets", "8"]) + def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]: + """Get the mimebundle for the widget representation of the context. + + Args: + **kwargs: Passed to the _repr_mimebundle_() function for the widget + + Returns: + Dictionary ("mimebundle") of the widget representation of the context. + """ + import ipywidgets + + disconnect_button = ipywidgets.Button( + description="Disconnect", + disabled=False, + button_style="", + tooltip="Disconnect from the Ray cluster", + layout=ipywidgets.Layout(margin="auto 0px 0px 0px"), + ) + + def disconnect_callback(button): + button.disabled = True + button.description = "Disconnecting..." + self.disconnect() + button.description = "Disconnected" + + disconnect_button.on_click(disconnect_callback) + left_content = ipywidgets.VBox( + [ + ipywidgets.HTML(Template("context_logo.html.j2").render()), + disconnect_button, + ], + layout=ipywidgets.Layout(), + ) + right_content = ipywidgets.HTML(self._context_table_template()) + widget = ipywidgets.HBox( + [left_content, right_content], layout=ipywidgets.Layout(width="100%") + ) + return widget._repr_mimebundle_(**kwargs) + + def _repr_mimebundle_(self, **kwargs): + bundle = self._get_widget_bundle(**kwargs) + + # Overwrite the widget html repr and default repr with those of the BaseContext + bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)}) + return bundle + @dataclass class RayContext(BaseContext, Mapping): @@ -1058,10 +1130,10 @@ class RayContext(BaseContext, Mapping): python_version: str ray_version: str ray_commit: str - protocol_version = Optional[str] - address_info: Dict[str, Optional[str]] + protocol_version: Optional[str] def __init__(self, address_info: Dict[str, Optional[str]]): + super().__init__() self.dashboard_url = get_dashboard_url() self.python_version = "{}.{}.{}".format(*sys.version_info[:3]) self.ray_version = ray.__version__ @@ -1103,20 +1175,6 @@ def disconnect(self): # Include disconnect() to stay consistent with ClientContext ray.shutdown() - def _repr_html_(self): - if self.dashboard_url: - dashboard_row = Template("context_dashrow.html.j2").render( - dashboard_url="http://" + self.dashboard_url - ) - else: - dashboard_row = None - - return Template("context.html.j2").render( - python_version=self.python_version, - ray_version=self.ray_version, - dashboard_row=dashboard_row, - ) - global_worker = Worker() """Worker: The global Worker object for this worker process. diff --git a/python/ray/client_builder.py b/python/ray/client_builder.py index 4e379dc5c5b6..d5cbc03142e1 100644 --- a/python/ray/client_builder.py +++ b/python/ray/client_builder.py @@ -19,7 +19,6 @@ from ray._private.worker import init as ray_driver_init from ray.job_config import JobConfig from ray.util.annotations import Deprecated, PublicAPI -from ray.widgets import Template logger = logging.getLogger(__name__) @@ -86,20 +85,6 @@ def _disconnect_with_context(self, force_disconnect: bool) -> None: # This is only a driver connected to an existing cluster. ray.shutdown() - def _repr_html_(self): - if self.dashboard_url: - dashboard_row = Template("context_dashrow.html.j2").render( - dashboard_url="http://" + self.dashboard_url - ) - else: - dashboard_row = None - - return Template("context.html.j2").render( - python_version=self.python_version, - ray_version=self.ray_version, - dashboard_row=dashboard_row, - ) - @Deprecated class ClientBuilder: diff --git a/python/ray/tests/test_widgets.py b/python/ray/tests/test_widgets.py index be1f8714c0e9..5a4700266a49 100644 --- a/python/ray/tests/test_widgets.py +++ b/python/ray/tests/test_widgets.py @@ -1,134 +1,182 @@ +import logging from unittest import mock import pytest -from ray.widgets.util import repr_with_fallback +import ray +from ray.widgets.util import repr_with_fallback, _can_display_ipywidgets +@pytest.fixture +def logs(propagate_logs, caplog): + """A log fixture which captures logs during a test.""" + caplog.set_level(logging.INFO) + return caplog + + +@pytest.fixture +def fancy_mimebundle(): + return { + "fancy/mimetype": "A fancy repr", + "text/plain": "A simple repr", + } + + +@mock.patch.object(ray.widgets.util, "in_notebook") +@mock.patch("importlib.util.find_spec") @mock.patch("importlib.import_module") -@mock.patch("ray.widgets.util.in_notebook") -def test_ensure_notebook_dep_missing( - mock_in_notebook, mock_import_module, propagate_logs, caplog +def test_repr_with_fallback_missing( + mock_import_module, + mock_find_spec, + mock_in_notebook, + logs, + fancy_mimebundle, ): - """Test that missing notebook dependencies trigger a warning.""" - - class MockDep: - __version__ = "8.0.0" + """Test that missing notebook dependencies trigger a log message.""" def raise_import_error(*args): raise ImportError - mock_import_module.return_value = MockDep() mock_import_module.side_effect = raise_import_error - + mock_find_spec.return_value = None mock_in_notebook.return_value = True class DummyObject: - @repr_with_fallback(["somedep", "8"]) - def dummy_ipython_display(self): - return + def __repr__(self): + return "dummy repr" - DummyObject().dummy_ipython_display() + @repr_with_fallback(["somedep", "8"]) + def _repr_mimebundle_(self, **kwargs): + return fancy_mimebundle - assert "Missing packages:" in caplog.records[-1].msg + result = DummyObject()._repr_mimebundle_() + assert result == {"text/plain": "dummy repr"} + assert "Missing packages:" in logs.records[-1].msg +@mock.patch.object(ray.widgets.util, "in_notebook") +@mock.patch("importlib.util.find_spec") @mock.patch("importlib.import_module") -@mock.patch("ray.widgets.util.in_notebook") -def test_ensure_notebook_dep_outdated( - mock_in_notebook, mock_import_module, propagate_logs, caplog +def test_repr_with_fallback_outdated( + mock_import_module, + mock_find_spec, + mock_in_notebook, + logs, + fancy_mimebundle, ): - """Test that outdated notebook dependencies trigger a warning.""" + """Test that outdated notebook dependencies trigger a log message.""" class MockDep: __version__ = "7.0.0" mock_import_module.return_value = MockDep() - + mock_find_spec.return_value = "a valid import spec" mock_in_notebook.return_value = True class DummyObject: - @repr_with_fallback(["somedep", "8"]) - def dummy_ipython_display(): - return + def __repr__(self): + return "dummy repr" - DummyObject().dummy_ipython_display() + @repr_with_fallback(["somedep", "8"]) + def _repr_mimebundle_(self, **kwargs): + return fancy_mimebundle - assert "Outdated packages:" in caplog.records[-1].msg + result = DummyObject()._repr_mimebundle_() + assert result == {"text/plain": "dummy repr"} + assert "Outdated packages:" in logs.records[-1].msg +@mock.patch.object(ray.widgets.util, "_can_display_ipywidgets") +@mock.patch("importlib.util.find_spec") @mock.patch("importlib.import_module") -@mock.patch("ray.widgets.util.in_notebook") -def test_ensure_notebook_valid( - mock_in_notebook, mock_import_module, propagate_logs, caplog +def test_repr_with_fallback_valid( + mock_import_module, + mock_find_spec, + mock_can_display_ipywidgets, + logs, + fancy_mimebundle, ): - """Test that valid notebook dependencies don't trigger a warning.""" + """Test that valid notebook dependencies don't trigger a log message.""" class MockDep: __version__ = "8.0.0" mock_import_module.return_value = MockDep() - - mock_in_notebook.return_value = True + mock_find_spec.return_value = "a valid import spec" + mock_can_display_ipywidgets.return_value = True class DummyObject: + def __repr__(self): + return "dummy repr" + @repr_with_fallback(["somedep", "8"]) - def dummy_ipython_display(self): - return + def _repr_mimebundle_(self, **kwargs): + return fancy_mimebundle + + result = DummyObject()._repr_mimebundle_() + assert len(logs.records) == 0 + assert result == fancy_mimebundle - DummyObject().dummy_ipython_display() - assert len(caplog.records) == 0 +@mock.patch.object(ray.widgets.util, "_can_display_ipywidgets") +@mock.patch("importlib.util.find_spec") +@mock.patch("importlib.import_module") +def test_repr_with_fallback_invalid_shell( + mock_import_module, + mock_find_spec, + mock_can_display_ipywidgets, + logs, + fancy_mimebundle, +): + """Test that the mimebundle is correctly stripped if run in an invalid shell.""" + + class MockDep: + __version__ = "8.0.0" + mock_import_module.return_value = MockDep() + mock_find_spec.return_value = "a valid import spec" + mock_can_display_ipywidgets.return_value = False + + class DummyObject: + def __repr__(self): + return "dummy repr" + + @repr_with_fallback(["somedep", "8"]) + def _repr_mimebundle_(self, **kwargs): + return fancy_mimebundle + result = DummyObject()._repr_mimebundle_() + assert len(logs.records) == 0 + assert result == {"text/plain": "dummy repr"} + + +@mock.patch.object(ray.widgets.util, "_get_ipython_shell_name") +@mock.patch("importlib.util.find_spec") +@mock.patch("importlib.import_module") @pytest.mark.parametrize( - "kernel", + "shell,can_display", [ - ("google.colab.kernel"), - ("normal.ipython.kernel"), + ("ZMQInteractiveShell", True), + ("google.colab.kernel", False), + ("TerminalInteractiveShell", False), + ("", False), ], ) -def test_repr_fallback_if_colab(kernel): - """Test that the mimebundle is correctly stripped if run in google colab.""" - pytest.importorskip("IPython", reason="IPython is not installed.") - with mock.patch("IPython.get_ipython") as mock_get_ipython: - mock_get_ipython.return_value = kernel - - class DummyObject: - @repr_with_fallback() - def _repr_mimebundle_(self, **kwargs): - return { - "fancy/mimetype": "A fancy repr", - "text/plain": "A simple repr", - } - - obj = DummyObject() - result = obj._repr_mimebundle_() - - assert "text/plain" in result - if "google.colab" in kernel: - assert len(result) == 1 - else: - assert len(result) == 2 - assert "fancy/mimetype" in result - - -@mock.patch("ray.widgets.util.in_ipython_shell") -def test_repr_fallback_if_ipython_shell(mock_in_ipython): - mock_in_ipython.return_value = True - - class DummyObject: - @repr_with_fallback() - def _repr_mimebundle_(self, **kwargs): - return { - "fancy/mimetype": "A fancy repr", - "text/plain": "A simple repr", - } +def test_can_display_ipywidgets( + mock_import_module, + mock_find_spec, + mock_get_ipython_shell_name, + shell, + can_display, +): + class MockDep: + __version__ = "8.0.0" - obj = DummyObject() - result = obj._repr_mimebundle_() + mock_import_module.return_value = MockDep() + mock_find_spec.return_value = "a valid import spec" + mock_get_ipython_shell_name.return_value = shell - assert "text/plain" in result - assert len(result) == 1 + assert _can_display_ipywidgets(["somedep", "8"], message="") == can_display + mock_get_ipython_shell_name.assert_called() if __name__ == "__main__": diff --git a/python/ray/widgets/render.py b/python/ray/widgets/render.py index 1fc0c69d5f0f..f9e861d39925 100644 --- a/python/ray/widgets/render.py +++ b/python/ray/widgets/render.py @@ -19,7 +19,7 @@ def render(self, **kwargs) -> str: from the keyword arguments. Returns: - str: HTML template with the keys of the kwargs replaced with corresponding + HTML template with the keys of the kwargs replaced with corresponding values. """ rendered = self.template @@ -34,7 +34,6 @@ def list_templates() -> List[pathlib.Path]: """List the available HTML templates. Returns: - List[pathlib.Path]: A list of files with .html.j2 extensions inside - ./templates/ + A list of files with .html.j2 extensions inside ../templates/ """ return (pathlib.Path(__file__).parent / "templates").glob("*.html.j2") diff --git a/python/ray/widgets/templates/context.html.j2 b/python/ray/widgets/templates/context.html.j2 index 3e664e01eae2..26cc0ef6c878 100644 --- a/python/ray/widgets/templates/context.html.j2 +++ b/python/ray/widgets/templates/context.html.j2 @@ -1,37 +1,6 @@ -
+ diff --git a/python/ray/widgets/templates/context_logo.html.j2 b/python/ray/widgets/templates/context_logo.html.j2 new file mode 100644 index 000000000000..9233fe3a7722 --- /dev/null +++ b/python/ray/widgets/templates/context_logo.html.j2 @@ -0,0 +1,13 @@ + diff --git a/python/ray/widgets/templates/context_table.html.j2 b/python/ray/widgets/templates/context_table.html.j2 new file mode 100644 index 000000000000..d06822d0c1f5 --- /dev/null +++ b/python/ray/widgets/templates/context_table.html.j2 @@ -0,0 +1,11 @@ + + + + + + + + + + {{ dashboard_row }} + diff --git a/python/ray/widgets/util.py b/python/ray/widgets/util.py index 194c31ba3c70..32280ef09a7d 100644 --- a/python/ray/widgets/util.py +++ b/python/ray/widgets/util.py @@ -68,11 +68,18 @@ def make_table_html_repr( def _has_missing( *deps: Iterable[Union[str, Optional[str]]], message: Optional[str] = None ): + """Return a list of missing dependencies. + + Args: + deps: Dependencies to check for + message: Message to be emitted if a dependency isn't found + + Returns: + A list of dependencies which can't be found, if any + """ missing = [] for (lib, _) in deps: - try: - importlib.import_module(lib) - except ImportError: + if importlib.util.find_spec(lib) is None: missing.append(lib) if missing: @@ -80,11 +87,11 @@ def _has_missing( message = f"Run `pip install {' '.join(missing)}` for rich notebook output." if sys.version_info < (3, 8): - logger.warning(f"Missing packages: {missing}. {message}") + logger.info(f"Missing packages: {missing}. {message}") else: # stacklevel=3: First level is this function, then ensure_notebook_deps, # then the actual function affected. - logger.warning(f"Missing packages: {missing}. {message}", stacklevel=3) + logger.info(f"Missing packages: {missing}. {message}", stacklevel=3) return missing @@ -115,13 +122,11 @@ def _has_outdated( message = f"Run `pip install -U {install_str}` for rich notebook output." if sys.version_info < (3, 8): - logger.warning(f"Outdated packages:\n{outdated_str}\n{message}") + logger.info(f"Outdated packages:\n{outdated_str}\n{message}") else: # stacklevel=3: First level is this function, then ensure_notebook_deps, # then the actual function affected. - logger.warning( - f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3 - ) + logger.info(f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3) return outdated @@ -147,66 +152,61 @@ def repr_with_fallback( conditions above hold, in which case it returns a mimebundle that only contains a single text/plain mimetype. """ - - try: - import IPython - - ipython = IPython.get_ipython() - except (ModuleNotFoundError, ValueError): - ipython = None - message = ( "Run `pip install -U ipywidgets`, then restart " "the notebook server for rich notebook output." ) + if _can_display_ipywidgets(*notebook_deps, message=message): - def wrapper(func: F) -> F: - @wraps(func) - def wrapped(self, *args, **kwargs): - fallback = ( - # In Google Colab. - (ipython and "google.colab" in str(ipython)) - or - # In notebook environment without required dependencies. - ( - in_notebook() - and ( - _has_missing(*notebook_deps, message=message) - or _has_outdated(*notebook_deps, message=message) - ) - ) - or - # In ipython shell. - in_ipython_shell() - ) - if fallback: - return {"text/plain": repr(self)} - else: + def wrapper(func: F) -> F: + @wraps(func) + def wrapped(self, *args, **kwargs): return func(self, *args, **kwargs) - return wrapped + return wrapped + + else: + + def wrapper(func: F) -> F: + @wraps(func) + def wrapped(self, *args, **kwargs): + return {"text/plain": repr(self)} + + return wrapped return wrapper def _get_ipython_shell_name() -> str: - try: - import IPython + if "IPython" in sys.modules: + from IPython import get_ipython + + return get_ipython().__class__.__name__ + return "" + + +def _can_display_ipywidgets(*deps, message) -> bool: + # Default to safe behavior: only display widgets if running in a notebook + # that has valid dependencies + if in_notebook() and not ( + _has_missing(*deps, message=message) or _has_outdated(*deps, message=message) + ): + return True - shell = IPython.get_ipython().__class__.__name__ - return shell - except (ModuleNotFoundError, NameError, ValueError): - return "" + return False @DeveloperAPI -def in_notebook() -> bool: - """Return whether we are in a Jupyter notebook.""" - shell = _get_ipython_shell_name() - return shell == "ZMQInteractiveShell" # Jupyter notebook or qtconsole +def in_notebook(shell_name: Optional[str] = None) -> bool: + """Return whether we are in a Jupyter notebook or qtconsole.""" + if not shell_name: + shell_name = _get_ipython_shell_name() + return shell_name == "ZMQInteractiveShell" @DeveloperAPI -def in_ipython_shell() -> bool: - shell = _get_ipython_shell_name() - return shell == "TerminalInteractiveShell" # Terminal running IPython +def in_ipython_shell(shell_name: Optional[str] = None) -> bool: + """Return whether we are in a terminal running IPython""" + if not shell_name: + shell_name = _get_ipython_shell_name() + return shell_name == "TerminalInteractiveShell"