Optuna broken with new pickle changes #7762

Closed
mrocklin opened this issue Apr 7, 2023 · 5 comments

Comments

@mrocklin
Member

mrocklin commented Apr 7, 2023

Right now dask-optuna is broken, I think because the DaskStorage object gets deserialized on the scheduler and tries to get a client there. The code looks like this:

    def __init__(
        self,
        storage: Union[None, str, BaseStorage] = None,
        name: Optional[str] = None,
        client: Optional["distributed.Client"] = None,
    ):
        _imports.check()
        self.name = name or f"dask-storage-{uuid.uuid4().hex}"
        self.client = client or get_client()

        if self.client.asynchronous or getattr(thread_state, "on_event_loop_thread", False):

            async def _register() -> DaskStorage:
                await self.client.run_on_scheduler(  # type: ignore[no-untyped-call]
                    _register_with_scheduler, storage=storage, name=self.name
                )
                return self

            self._started = asyncio.ensure_future(_register())
        else:
            self.client.run_on_scheduler(  # type: ignore[no-untyped-call]
                _register_with_scheduler, storage=storage, name=self.name
            )

@jrbourbeau I'm curious about the registration logic in __init__. Does each worker need to register the storage with the scheduler? My sense is that this step could be skipped when we're deserializing the object. Maybe __reduce__ could return a different function that explicitly doesn't try to register anything? Thoughts?
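The "different function in __reduce__" idea can be sketched as follows. This is an illustration under assumptions, not the Optuna implementation: Storage, _register, _rebuild_storage, and REGISTERED are hypothetical names, and _register stands in for client.run_on_scheduler(_register_with_scheduler, ...). The rebuild function allocates the object with __new__ and sets its state directly, so __init__ and its registration side effect never run on unpickling.

```python
import pickle
import uuid
from typing import List, Optional

# Hypothetical record of registrations; the real DaskStorage instead calls
# client.run_on_scheduler(_register_with_scheduler, storage=..., name=...).
REGISTERED: List[str] = []


def _register(name: str) -> None:
    REGISTERED.append(name)


class Storage:
    def __init__(self, name: Optional[str] = None):
        self.name = name or f"dask-storage-{uuid.uuid4().hex}"
        _register(self.name)  # side effect that should run once, on the client

    def __reduce__(self):
        # Reconstruct through a separate function instead of the class, so
        # unpickling never re-enters __init__ or its registration step.
        return (_rebuild_storage, (self.name,))


def _rebuild_storage(name: str) -> Storage:
    obj = Storage.__new__(Storage)  # allocate without calling __init__
    obj.name = name
    return obj


if __name__ == "__main__":
    original = Storage("demo")
    copy = pickle.loads(pickle.dumps(original))
    print(copy.name, REGISTERED)
```

The trade-off is that any attribute set in __init__ must also be restored by hand in the rebuild function; the alternative is to keep the class as the reconstructor and pass a flag that disables registration.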

@mrocklin
Member Author

mrocklin commented Apr 7, 2023

This works:

diff --git a/optuna/integration/dask.py b/optuna/integration/dask.py
index aa5497f77..24b61af8a 100644
--- a/optuna/integration/dask.py
+++ b/optuna/integration/dask.py
@@ -1,5 +1,6 @@
 import asyncio
 from datetime import datetime
+import functools
 from typing import Any
 from typing import Container
 from typing import Dict
@@ -453,22 +454,29 @@ class DaskStorage(BaseStorage):
     ):
         _imports.check()
         self.name = name or f"dask-storage-{uuid.uuid4().hex}"
-        self.client = client or get_client()
+        try:
+            client = client or get_client()
+        except ValueError:  # we're likely on the scheduler
+            return

-        if self.client.asynchronous or getattr(thread_state, "on_event_loop_thread", False):
+        if client.asynchronous or getattr(thread_state, "on_event_loop_thread", False):

             async def _register() -> DaskStorage:
-                await self.client.run_on_scheduler(  # type: ignore[no-untyped-call]
+                await client.run_on_scheduler(  # type: ignore[no-untyped-call]
                     _register_with_scheduler, storage=storage, name=self.name
                 )
                 return self

             self._started = asyncio.ensure_future(_register())
         else:
-            self.client.run_on_scheduler(  # type: ignore[no-untyped-call]
+            client.run_on_scheduler(  # type: ignore[no-untyped-call]
                 _register_with_scheduler, storage=storage, name=self.name
             )

+    @functools.cached_property
+    def client(self):
+        return get_client()
+
     def __await__(self) -> Generator[Any, None, "DaskStorage"]:
         if hasattr(self, "_started"):
             return self._started.__await__()

I tried not using the cached property and instead assigning the client in __init__, but ran into #7763.

@mrocklin
Member Author

mrocklin commented Apr 7, 2023

@jrbourbeau can I ask you (or someone around you) to try getting this upstream?

@mrocklin
Member Author

mrocklin commented Apr 7, 2023

Slightly more efficient, using a register flag so that unpickling skips scheduler registration entirely:

diff --git a/optuna/integration/dask.py b/optuna/integration/dask.py
index aa5497f77..52bddcd48 100644
--- a/optuna/integration/dask.py
+++ b/optuna/integration/dask.py
@@ -1,5 +1,6 @@
 import asyncio
 from datetime import datetime
+import functools
 from typing import Any
 from typing import Container
 from typing import Dict
@@ -450,24 +451,30 @@ class DaskStorage(BaseStorage):
         storage: Union[None, str, BaseStorage] = None,
         name: Optional[str] = None,
         client: Optional["distributed.Client"] = None,
+        register: bool = True,
     ):
         _imports.check()
         self.name = name or f"dask-storage-{uuid.uuid4().hex}"
-        self.client = client or get_client()
+        if register:
+            client = client or get_client()

-        if self.client.asynchronous or getattr(thread_state, "on_event_loop_thread", False):
+            if client.asynchronous or getattr(thread_state, "on_event_loop_thread", False):

-            async def _register() -> DaskStorage:
-                await self.client.run_on_scheduler(  # type: ignore[no-untyped-call]
+                async def _register() -> DaskStorage:
+                    await client.run_on_scheduler(  # type: ignore[no-untyped-call]
+                        _register_with_scheduler, storage=storage, name=self.name
+                    )
+                    return self
+
+                self._started = asyncio.ensure_future(_register())
+            else:
+                client.run_on_scheduler(  # type: ignore[no-untyped-call]
                     _register_with_scheduler, storage=storage, name=self.name
                 )
-                return self

-            self._started = asyncio.ensure_future(_register())
-        else:
-            self.client.run_on_scheduler(  # type: ignore[no-untyped-call]
-                _register_with_scheduler, storage=storage, name=self.name
-            )
+    @functools.cached_property
+    def client(self):
+        return get_client()

     def __await__(self) -> Generator[Any, None, "DaskStorage"]:
         if hasattr(self, "_started"):
@@ -484,7 +491,7 @@ class DaskStorage(BaseStorage):
         # on the scheduler. This is okay since this DaskStorage instance has already been
         # registered with the scheduler, and ``storage`` is only ever needed during the
         # scheduler registration process. We use ``storage=None`` below by convention.
-        return (DaskStorage, (None, self.name))
+        return (DaskStorage, (None, self.name, None, False))

I have this in a branch here: https://github.com/mrocklin/optuna/tree/dask-client
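The register-flag approach from this second patch can be checked with a small standalone sketch. FlaggedStorage, _register_with_scheduler_stub, and REGISTRATIONS are hypothetical stand-ins; only the positional layout (storage, name, client, register) and the reduction tuple (None, self.name, None, False) follow the patch. Constructing the object registers it once, and a pickle round trip rebuilds it through the normal constructor with register=False.

```python
import pickle
import uuid
from typing import List, Optional

# Hypothetical record of scheduler registrations.
REGISTRATIONS: List[str] = []


def _register_with_scheduler_stub(storage: Optional[object], name: str) -> None:
    # Stand-in for client.run_on_scheduler(_register_with_scheduler, ...).
    REGISTRATIONS.append(name)


class FlaggedStorage:
    def __init__(
        self,
        storage: Optional[object] = None,
        name: Optional[str] = None,
        client: Optional[object] = None,
        register: bool = True,
    ):
        self.name = name or f"dask-storage-{uuid.uuid4().hex}"
        if register:
            # Only the original construction (on the client side) registers.
            _register_with_scheduler_stub(storage, self.name)

    def __reduce__(self):
        # Positional args are (storage, name, client, register), as in the
        # patch: storage and client are None and register is False.
        return (FlaggedStorage, (None, self.name, None, False))
```

Compared with a separate rebuild function, the flag keeps the class itself as the reconstructor, so any other setup in __init__ still runs on deserialization; only the registration step is skipped.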

@jrbourbeau
Member

jrbourbeau commented Apr 7, 2023

Ah, good catch. Taking a look now... Will make sure a fix gets pushed upstream.

EDIT: xref optuna/optuna#4589

@jrbourbeau
Member

Closed via optuna/optuna#4589
