diff --git a/python/ray/_private/worker.py b/python/ray/_private/worker.py
index b4d1f8154c86..9b9a233d2a61 100644
--- a/python/ray/_private/worker.py
+++ b/python/ray/_private/worker.py
@@ -96,6 +96,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tracing.tracing_helper import _import_from_string
from ray.widgets import Template
+from ray.widgets.util import repr_with_fallback
SCRIPT_MODE = 0
WORKER_MODE = 1
@@ -1019,6 +1020,10 @@ class BaseContext(metaclass=ABCMeta):
Base class for RayContext and ClientContext
"""
+ dashboard_url: Optional[str]
+ python_version: str
+ ray_version: str
+
@abstractmethod
def disconnect(self):
"""
@@ -1036,6 +1041,73 @@ def __enter__(self):
def __exit__(self):
pass
+ def _context_table_template(self):
+ if self.dashboard_url:
+ dashboard_row = Template("context_dashrow.html.j2").render(
+ dashboard_url="http://" + self.dashboard_url
+ )
+ else:
+ dashboard_row = None
+
+ return Template("context_table.html.j2").render(
+ python_version=self.python_version,
+ ray_version=self.ray_version,
+ dashboard_row=dashboard_row,
+ )
+
+ def _repr_html_(self):
+ return Template("context.html.j2").render(
+ context_logo=Template("context_logo.html.j2").render(),
+ context_table=self._context_table_template(),
+ )
+
+ @repr_with_fallback(["ipywidgets", "8"])
+ def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]:
+ """Get the mimebundle for the widget representation of the context.
+
+ Args:
+ **kwargs: Passed to the _repr_mimebundle_() function for the widget
+
+ Returns:
+ Dictionary ("mimebundle") of the widget representation of the context.
+ """
+ import ipywidgets
+
+ disconnect_button = ipywidgets.Button(
+ description="Disconnect",
+ disabled=False,
+ button_style="",
+ tooltip="Disconnect from the Ray cluster",
+ layout=ipywidgets.Layout(margin="auto 0px 0px 0px"),
+ )
+
+ def disconnect_callback(button):
+ button.disabled = True
+ button.description = "Disconnecting..."
+ self.disconnect()
+ button.description = "Disconnected"
+
+ disconnect_button.on_click(disconnect_callback)
+ left_content = ipywidgets.VBox(
+ [
+ ipywidgets.HTML(Template("context_logo.html.j2").render()),
+ disconnect_button,
+ ],
+ layout=ipywidgets.Layout(),
+ )
+ right_content = ipywidgets.HTML(self._context_table_template())
+ widget = ipywidgets.HBox(
+ [left_content, right_content], layout=ipywidgets.Layout(width="100%")
+ )
+ return widget._repr_mimebundle_(**kwargs)
+
+ def _repr_mimebundle_(self, **kwargs):
+ bundle = self._get_widget_bundle(**kwargs)
+
+ # Overwrite the widget html repr and default repr with those of the BaseContext
+ bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)})
+ return bundle
+
@dataclass
class RayContext(BaseContext, Mapping):
@@ -1047,10 +1119,10 @@ class RayContext(BaseContext, Mapping):
python_version: str
ray_version: str
ray_commit: str
- protocol_version = Optional[str]
- address_info: Dict[str, Optional[str]]
+ protocol_version: Optional[str]
def __init__(self, address_info: Dict[str, Optional[str]]):
+ super().__init__()
self.dashboard_url = get_dashboard_url()
self.python_version = "{}.{}.{}".format(*sys.version_info[:3])
self.ray_version = ray.__version__
@@ -1092,20 +1164,6 @@ def disconnect(self):
# Include disconnect() to stay consistent with ClientContext
ray.shutdown()
- def _repr_html_(self):
- if self.dashboard_url:
- dashboard_row = Template("context_dashrow.html.j2").render(
- dashboard_url="http://" + self.dashboard_url
- )
- else:
- dashboard_row = None
-
- return Template("context.html.j2").render(
- python_version=self.python_version,
- ray_version=self.ray_version,
- dashboard_row=dashboard_row,
- )
-
global_worker = Worker()
"""Worker: The global Worker object for this worker process.
diff --git a/python/ray/client_builder.py b/python/ray/client_builder.py
index 4e379dc5c5b6..d5cbc03142e1 100644
--- a/python/ray/client_builder.py
+++ b/python/ray/client_builder.py
@@ -19,7 +19,6 @@
from ray._private.worker import init as ray_driver_init
from ray.job_config import JobConfig
from ray.util.annotations import Deprecated, PublicAPI
-from ray.widgets import Template
logger = logging.getLogger(__name__)
@@ -86,20 +85,6 @@ def _disconnect_with_context(self, force_disconnect: bool) -> None:
# This is only a driver connected to an existing cluster.
ray.shutdown()
- def _repr_html_(self):
- if self.dashboard_url:
- dashboard_row = Template("context_dashrow.html.j2").render(
- dashboard_url="http://" + self.dashboard_url
- )
- else:
- dashboard_row = None
-
- return Template("context.html.j2").render(
- python_version=self.python_version,
- ray_version=self.ray_version,
- dashboard_row=dashboard_row,
- )
-
@Deprecated
class ClientBuilder:
diff --git a/python/ray/tests/test_widgets.py b/python/ray/tests/test_widgets.py
index be1f8714c0e9..5a4700266a49 100644
--- a/python/ray/tests/test_widgets.py
+++ b/python/ray/tests/test_widgets.py
@@ -1,134 +1,182 @@
+import logging
from unittest import mock
import pytest
-from ray.widgets.util import repr_with_fallback
+import ray
+from ray.widgets.util import repr_with_fallback, _can_display_ipywidgets
+@pytest.fixture
+def logs(propagate_logs, caplog):
+ """A log fixture which captures logs during a test."""
+ caplog.set_level(logging.INFO)
+ return caplog
+
+
+@pytest.fixture
+def fancy_mimebundle():
+ return {
+ "fancy/mimetype": "A fancy repr",
+ "text/plain": "A simple repr",
+ }
+
+
+@mock.patch.object(ray.widgets.util, "in_notebook")
+@mock.patch("importlib.util.find_spec")
@mock.patch("importlib.import_module")
-@mock.patch("ray.widgets.util.in_notebook")
-def test_ensure_notebook_dep_missing(
- mock_in_notebook, mock_import_module, propagate_logs, caplog
+def test_repr_with_fallback_missing(
+ mock_import_module,
+ mock_find_spec,
+ mock_in_notebook,
+ logs,
+ fancy_mimebundle,
):
- """Test that missing notebook dependencies trigger a warning."""
-
- class MockDep:
- __version__ = "8.0.0"
+ """Test that missing notebook dependencies trigger a log message."""
def raise_import_error(*args):
raise ImportError
- mock_import_module.return_value = MockDep()
mock_import_module.side_effect = raise_import_error
-
+ mock_find_spec.return_value = None
mock_in_notebook.return_value = True
class DummyObject:
- @repr_with_fallback(["somedep", "8"])
- def dummy_ipython_display(self):
- return
+ def __repr__(self):
+ return "dummy repr"
- DummyObject().dummy_ipython_display()
+ @repr_with_fallback(["somedep", "8"])
+ def _repr_mimebundle_(self, **kwargs):
+ return fancy_mimebundle
- assert "Missing packages:" in caplog.records[-1].msg
+ result = DummyObject()._repr_mimebundle_()
+ assert result == {"text/plain": "dummy repr"}
+ assert "Missing packages:" in logs.records[-1].msg
+@mock.patch.object(ray.widgets.util, "in_notebook")
+@mock.patch("importlib.util.find_spec")
@mock.patch("importlib.import_module")
-@mock.patch("ray.widgets.util.in_notebook")
-def test_ensure_notebook_dep_outdated(
- mock_in_notebook, mock_import_module, propagate_logs, caplog
+def test_repr_with_fallback_outdated(
+ mock_import_module,
+ mock_find_spec,
+ mock_in_notebook,
+ logs,
+ fancy_mimebundle,
):
- """Test that outdated notebook dependencies trigger a warning."""
+ """Test that outdated notebook dependencies trigger a log message."""
class MockDep:
__version__ = "7.0.0"
mock_import_module.return_value = MockDep()
-
+ mock_find_spec.return_value = "a valid import spec"
mock_in_notebook.return_value = True
class DummyObject:
- @repr_with_fallback(["somedep", "8"])
- def dummy_ipython_display():
- return
+ def __repr__(self):
+ return "dummy repr"
- DummyObject().dummy_ipython_display()
+ @repr_with_fallback(["somedep", "8"])
+ def _repr_mimebundle_(self, **kwargs):
+ return fancy_mimebundle
- assert "Outdated packages:" in caplog.records[-1].msg
+ result = DummyObject()._repr_mimebundle_()
+ assert result == {"text/plain": "dummy repr"}
+ assert "Outdated packages:" in logs.records[-1].msg
+@mock.patch.object(ray.widgets.util, "_can_display_ipywidgets")
+@mock.patch("importlib.util.find_spec")
@mock.patch("importlib.import_module")
-@mock.patch("ray.widgets.util.in_notebook")
-def test_ensure_notebook_valid(
- mock_in_notebook, mock_import_module, propagate_logs, caplog
+def test_repr_with_fallback_valid(
+ mock_import_module,
+ mock_find_spec,
+ mock_can_display_ipywidgets,
+ logs,
+ fancy_mimebundle,
):
- """Test that valid notebook dependencies don't trigger a warning."""
+ """Test that valid notebook dependencies don't trigger a log message."""
class MockDep:
__version__ = "8.0.0"
mock_import_module.return_value = MockDep()
-
- mock_in_notebook.return_value = True
+ mock_find_spec.return_value = "a valid import spec"
+ mock_can_display_ipywidgets.return_value = True
class DummyObject:
+ def __repr__(self):
+ return "dummy repr"
+
@repr_with_fallback(["somedep", "8"])
- def dummy_ipython_display(self):
- return
+ def _repr_mimebundle_(self, **kwargs):
+ return fancy_mimebundle
+
+ result = DummyObject()._repr_mimebundle_()
+ assert len(logs.records) == 0
+ assert result == fancy_mimebundle
- DummyObject().dummy_ipython_display()
- assert len(caplog.records) == 0
+@mock.patch.object(ray.widgets.util, "_can_display_ipywidgets")
+@mock.patch("importlib.util.find_spec")
+@mock.patch("importlib.import_module")
+def test_repr_with_fallback_invalid_shell(
+ mock_import_module,
+ mock_find_spec,
+ mock_can_display_ipywidgets,
+ logs,
+ fancy_mimebundle,
+):
+ """Test that the mimebundle is correctly stripped if run in an invalid shell."""
+
+ class MockDep:
+ __version__ = "8.0.0"
+ mock_import_module.return_value = MockDep()
+ mock_find_spec.return_value = "a valid import spec"
+ mock_can_display_ipywidgets.return_value = False
+
+ class DummyObject:
+ def __repr__(self):
+ return "dummy repr"
+
+ @repr_with_fallback(["somedep", "8"])
+ def _repr_mimebundle_(self, **kwargs):
+ return fancy_mimebundle
+ result = DummyObject()._repr_mimebundle_()
+ assert len(logs.records) == 0
+ assert result == {"text/plain": "dummy repr"}
+
+
+@mock.patch.object(ray.widgets.util, "_get_ipython_shell_name")
+@mock.patch("importlib.util.find_spec")
+@mock.patch("importlib.import_module")
@pytest.mark.parametrize(
- "kernel",
+ "shell,can_display",
[
- ("google.colab.kernel"),
- ("normal.ipython.kernel"),
+ ("ZMQInteractiveShell", True),
+ ("google.colab.kernel", False),
+ ("TerminalInteractiveShell", False),
+ ("", False),
],
)
-def test_repr_fallback_if_colab(kernel):
- """Test that the mimebundle is correctly stripped if run in google colab."""
- pytest.importorskip("IPython", reason="IPython is not installed.")
- with mock.patch("IPython.get_ipython") as mock_get_ipython:
- mock_get_ipython.return_value = kernel
-
- class DummyObject:
- @repr_with_fallback()
- def _repr_mimebundle_(self, **kwargs):
- return {
- "fancy/mimetype": "A fancy repr",
- "text/plain": "A simple repr",
- }
-
- obj = DummyObject()
- result = obj._repr_mimebundle_()
-
- assert "text/plain" in result
- if "google.colab" in kernel:
- assert len(result) == 1
- else:
- assert len(result) == 2
- assert "fancy/mimetype" in result
-
-
-@mock.patch("ray.widgets.util.in_ipython_shell")
-def test_repr_fallback_if_ipython_shell(mock_in_ipython):
- mock_in_ipython.return_value = True
-
- class DummyObject:
- @repr_with_fallback()
- def _repr_mimebundle_(self, **kwargs):
- return {
- "fancy/mimetype": "A fancy repr",
- "text/plain": "A simple repr",
- }
+def test_can_display_ipywidgets(
+ mock_import_module,
+ mock_find_spec,
+ mock_get_ipython_shell_name,
+ shell,
+ can_display,
+):
+ class MockDep:
+ __version__ = "8.0.0"
- obj = DummyObject()
- result = obj._repr_mimebundle_()
+ mock_import_module.return_value = MockDep()
+ mock_find_spec.return_value = "a valid import spec"
+ mock_get_ipython_shell_name.return_value = shell
- assert "text/plain" in result
- assert len(result) == 1
+ assert _can_display_ipywidgets(["somedep", "8"], message="") == can_display
+ mock_get_ipython_shell_name.assert_called()
if __name__ == "__main__":
diff --git a/python/ray/widgets/render.py b/python/ray/widgets/render.py
index 1fc0c69d5f0f..f9e861d39925 100644
--- a/python/ray/widgets/render.py
+++ b/python/ray/widgets/render.py
@@ -19,7 +19,7 @@ def render(self, **kwargs) -> str:
from the keyword arguments.
Returns:
- str: HTML template with the keys of the kwargs replaced with corresponding
+ HTML template with the keys of the kwargs replaced with corresponding
values.
"""
rendered = self.template
@@ -34,7 +34,6 @@ def list_templates() -> List[pathlib.Path]:
"""List the available HTML templates.
Returns:
- List[pathlib.Path]: A list of files with .html.j2 extensions inside
- ./templates/
+ A list of files with .html.j2 extensions inside ../templates/
"""
return (pathlib.Path(__file__).parent / "templates").glob("*.html.j2")
diff --git a/python/ray/widgets/templates/context.html.j2 b/python/ray/widgets/templates/context.html.j2
index 3e664e01eae2..26cc0ef6c878 100644
--- a/python/ray/widgets/templates/context.html.j2
+++ b/python/ray/widgets/templates/context.html.j2
@@ -1,37 +1,6 @@
-
+
diff --git a/python/ray/widgets/templates/context_logo.html.j2 b/python/ray/widgets/templates/context_logo.html.j2
new file mode 100644
index 000000000000..9233fe3a7722
--- /dev/null
+++ b/python/ray/widgets/templates/context_logo.html.j2
@@ -0,0 +1,13 @@
+
diff --git a/python/ray/widgets/templates/context_table.html.j2 b/python/ray/widgets/templates/context_table.html.j2
new file mode 100644
index 000000000000..d06822d0c1f5
--- /dev/null
+++ b/python/ray/widgets/templates/context_table.html.j2
@@ -0,0 +1,11 @@
+
+
+ Python version: |
+ {{ python_version }} |
+
+
+ Ray version: |
+ {{ ray_version }} |
+
+ {{ dashboard_row }}
+
diff --git a/python/ray/widgets/util.py b/python/ray/widgets/util.py
index 194c31ba3c70..32280ef09a7d 100644
--- a/python/ray/widgets/util.py
+++ b/python/ray/widgets/util.py
@@ -68,11 +68,18 @@ def make_table_html_repr(
def _has_missing(
*deps: Iterable[Union[str, Optional[str]]], message: Optional[str] = None
):
+ """Return a list of missing dependencies.
+
+ Args:
+ deps: Dependencies to check for
+ message: Message to be emitted if a dependency isn't found
+
+ Returns:
+ A list of dependencies which can't be found, if any
+ """
missing = []
for (lib, _) in deps:
- try:
- importlib.import_module(lib)
- except ImportError:
+ if importlib.util.find_spec(lib) is None:
missing.append(lib)
if missing:
@@ -80,11 +87,11 @@ def _has_missing(
message = f"Run `pip install {' '.join(missing)}` for rich notebook output."
if sys.version_info < (3, 8):
- logger.warning(f"Missing packages: {missing}. {message}")
+ logger.info(f"Missing packages: {missing}. {message}")
else:
# stacklevel=3: First level is this function, then ensure_notebook_deps,
# then the actual function affected.
- logger.warning(f"Missing packages: {missing}. {message}", stacklevel=3)
+ logger.info(f"Missing packages: {missing}. {message}", stacklevel=3)
return missing
@@ -115,13 +122,11 @@ def _has_outdated(
message = f"Run `pip install -U {install_str}` for rich notebook output."
if sys.version_info < (3, 8):
- logger.warning(f"Outdated packages:\n{outdated_str}\n{message}")
+ logger.info(f"Outdated packages:\n{outdated_str}\n{message}")
else:
# stacklevel=3: First level is this function, then ensure_notebook_deps,
# then the actual function affected.
- logger.warning(
- f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3
- )
+ logger.info(f"Outdated packages:\n{outdated_str}\n{message}", stacklevel=3)
return outdated
@@ -147,66 +152,61 @@ def repr_with_fallback(
conditions above hold, in which case it returns a mimebundle that only contains
a single text/plain mimetype.
"""
-
- try:
- import IPython
-
- ipython = IPython.get_ipython()
- except (ModuleNotFoundError, ValueError):
- ipython = None
-
message = (
"Run `pip install -U ipywidgets`, then restart "
"the notebook server for rich notebook output."
)
+ if _can_display_ipywidgets(*notebook_deps, message=message):
- def wrapper(func: F) -> F:
- @wraps(func)
- def wrapped(self, *args, **kwargs):
- fallback = (
- # In Google Colab.
- (ipython and "google.colab" in str(ipython))
- or
- # In notebook environment without required dependencies.
- (
- in_notebook()
- and (
- _has_missing(*notebook_deps, message=message)
- or _has_outdated(*notebook_deps, message=message)
- )
- )
- or
- # In ipython shell.
- in_ipython_shell()
- )
- if fallback:
- return {"text/plain": repr(self)}
- else:
+ def wrapper(func: F) -> F:
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
return func(self, *args, **kwargs)
- return wrapped
+ return wrapped
+
+ else:
+
+ def wrapper(func: F) -> F:
+ @wraps(func)
+ def wrapped(self, *args, **kwargs):
+ return {"text/plain": repr(self)}
+
+ return wrapped
return wrapper
def _get_ipython_shell_name() -> str:
- try:
- import IPython
+ if "IPython" in sys.modules:
+ from IPython import get_ipython
+
+ return get_ipython().__class__.__name__
+ return ""
+
+
+def _can_display_ipywidgets(*deps, message) -> bool:
+ # Default to safe behavior: only display widgets if running in a notebook
+ # that has valid dependencies
+ if in_notebook() and not (
+ _has_missing(*deps, message=message) or _has_outdated(*deps, message=message)
+ ):
+ return True
- shell = IPython.get_ipython().__class__.__name__
- return shell
- except (ModuleNotFoundError, NameError, ValueError):
- return ""
+ return False
@DeveloperAPI
-def in_notebook() -> bool:
- """Return whether we are in a Jupyter notebook."""
- shell = _get_ipython_shell_name()
- return shell == "ZMQInteractiveShell" # Jupyter notebook or qtconsole
+def in_notebook(shell_name: Optional[str] = None) -> bool:
+ """Return whether we are in a Jupyter notebook or qtconsole."""
+ if not shell_name:
+ shell_name = _get_ipython_shell_name()
+ return shell_name == "ZMQInteractiveShell"
@DeveloperAPI
-def in_ipython_shell() -> bool:
- shell = _get_ipython_shell_name()
- return shell == "TerminalInteractiveShell" # Terminal running IPython
+def in_ipython_shell(shell_name: Optional[str] = None) -> bool:
+ """Return whether we are in a terminal running IPython"""
+ if not shell_name:
+ shell_name = _get_ipython_shell_name()
+ return shell_name == "TerminalInteractiveShell"