Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a disconnect button for RayContext in notebooks #35507

Merged
merged 3 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 74 additions & 16 deletions python/ray/_private/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
from ray.util.tracing.tracing_helper import _import_from_string
from ray.widgets import Template
from ray.widgets.util import repr_with_fallback

SCRIPT_MODE = 0
WORKER_MODE = 1
Expand Down Expand Up @@ -1019,6 +1020,10 @@ class BaseContext(metaclass=ABCMeta):
Base class for RayContext and ClientContext
"""

dashboard_url: Optional[str]
python_version: str
ray_version: str

@abstractmethod
def disconnect(self):
"""
Expand All @@ -1036,6 +1041,73 @@ def __enter__(self):
def __exit__(self):
pass

def _context_table_template(self):
if self.dashboard_url:
dashboard_row = Template("context_dashrow.html.j2").render(
dashboard_url="http://" + self.dashboard_url
)
else:
dashboard_row = None

return Template("context_table.html.j2").render(
python_version=self.python_version,
ray_version=self.ray_version,
dashboard_row=dashboard_row,
)

def _repr_html_(self):
return Template("context.html.j2").render(
context_logo=Template("context_logo.html.j2").render(),
context_table=self._context_table_template(),
)

@repr_with_fallback(["ipywidgets", "8"])
def _get_widget_bundle(self, **kwargs) -> Dict[str, Any]:
"""Get the mimebundle for the widget representation of the context.

Args:
**kwargs: Passed to the _repr_mimebundle_() function for the widget

Returns:
Dictionary ("mimebundle") of the widget representation of the context.
"""
import ipywidgets

disconnect_button = ipywidgets.Button(
description="Disconnect",
disabled=False,
button_style="",
tooltip="Disconnect from the Ray cluster",
layout=ipywidgets.Layout(margin="auto 0px 0px 0px"),
)

def disconnect_callback(button):
button.disabled = True
button.description = "Disconnecting..."
self.disconnect()
button.description = "Disconnected"

disconnect_button.on_click(disconnect_callback)
left_content = ipywidgets.VBox(
[
ipywidgets.HTML(Template("context_logo.html.j2").render()),
disconnect_button,
],
layout=ipywidgets.Layout(),
)
right_content = ipywidgets.HTML(self._context_table_template())
widget = ipywidgets.HBox(
[left_content, right_content], layout=ipywidgets.Layout(width="100%")
)
return widget._repr_mimebundle_(**kwargs)

def _repr_mimebundle_(self, **kwargs):
bundle = self._get_widget_bundle(**kwargs)

# Overwrite the widget html repr and default repr with those of the BaseContext
bundle.update({"text/html": self._repr_html_(), "text/plain": repr(self)})
return bundle


@dataclass
class RayContext(BaseContext, Mapping):
Expand All @@ -1047,10 +1119,10 @@ class RayContext(BaseContext, Mapping):
python_version: str
ray_version: str
ray_commit: str
protocol_version = Optional[str]
address_info: Dict[str, Optional[str]]
protocol_version: Optional[str]

def __init__(self, address_info: Dict[str, Optional[str]]):
super().__init__()
self.dashboard_url = get_dashboard_url()
self.python_version = "{}.{}.{}".format(*sys.version_info[:3])
self.ray_version = ray.__version__
Expand Down Expand Up @@ -1092,20 +1164,6 @@ def disconnect(self):
# Include disconnect() to stay consistent with ClientContext
ray.shutdown()

def _repr_html_(self):
if self.dashboard_url:
dashboard_row = Template("context_dashrow.html.j2").render(
dashboard_url="http://" + self.dashboard_url
)
else:
dashboard_row = None

return Template("context.html.j2").render(
python_version=self.python_version,
ray_version=self.ray_version,
dashboard_row=dashboard_row,
)


global_worker = Worker()
"""Worker: The global Worker object for this worker process.
Expand Down
15 changes: 0 additions & 15 deletions python/ray/client_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from ray._private.worker import init as ray_driver_init
from ray.job_config import JobConfig
from ray.util.annotations import Deprecated, PublicAPI
from ray.widgets import Template

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -86,20 +85,6 @@ def _disconnect_with_context(self, force_disconnect: bool) -> None:
# This is only a driver connected to an existing cluster.
ray.shutdown()

def _repr_html_(self):
if self.dashboard_url:
dashboard_row = Template("context_dashrow.html.j2").render(
dashboard_url="http://" + self.dashboard_url
)
else:
dashboard_row = None

return Template("context.html.j2").render(
python_version=self.python_version,
ray_version=self.ray_version,
dashboard_row=dashboard_row,
)


@Deprecated
class ClientBuilder:
Expand Down
Loading