Skip to content

Commit

Permalink
Implement module upload plugin (#8698)
Browse files Browse the repository at this point in the history
  • Loading branch information
dbalabka committed Oct 4, 2024
1 parent 36020d6 commit c0eaf07
Showing 1 changed file with 170 additions and 0 deletions.
170 changes: 170 additions & 0 deletions distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,21 @@
import functools
import logging
import os
import shutil
import socket
import subprocess
import sys
import tempfile
import uuid
import zipfile
from collections.abc import Awaitable
from contextlib import contextmanager
from importlib.util import find_spec
from io import BytesIO
from typing import TYPE_CHECKING, Any, Callable, ClassVar
from types import ModuleType
from typing import Any, Tuple
from pathlib import Path

from dask.typing import Key
from dask.utils import funcname, tmpfile
Expand All @@ -29,6 +36,7 @@
from distributed.scheduler import TaskStateState as SchedulerTaskStateState
from distributed.worker import Worker
from distributed.worker_state_machine import TaskStateState as WorkerTaskStateState
from distributed.node import ServerNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1051,3 +1059,165 @@ def setup(self, worker):

def teardown(self, worker):
self._exit_stack.close()


@contextmanager
def serialize_module(
module: ModuleType, exclude: Tuple[str] = ("__pycache__", ".DS_Store")
) -> Path:
module_path = Path(module.__file__)

if module_path.stem == "__init__":
# In case of package we serialize the whole package
module_path = module_path.parent
if "." in module.__name__:
# TODO: the problem is that we serialize the `package.module`, as module.egg that contains module.py,
# but it should contain the whole structure of the package (package/module.py)
raise Exception(
f"Plugin supports only top-level packages or single-file modules. You provided `{module.__name__}`, try `{module.__name__.split('.')[0]}`."
)

# In case of single file we don't need to serialize anything

with tempfile.TemporaryDirectory() as tmp:
package_name = module_path.name

package_copy_path = Path(tmp).joinpath(package_name)
if module_path.is_dir():
copied_package = Path(
shutil.copytree(
module_path,
package_copy_path,
ignore=shutil.ignore_patterns(f"{package_name}.zip", *exclude),
)
)
else:
copied_package = Path(shutil.copy2(module_path, package_copy_path))

archive_path = shutil.make_archive(
# output path including a name w/o extension
base_name=str(copied_package),
format="zip",
# chroot
root_dir=copied_package.parent,
# Name of the directory to archive and a common prefix of all files and directories in the archive
base_dir=package_name,
)

egg_file = shutil.move(archive_path, package_copy_path.with_suffix(".egg"))

# zip file handler
zip = zipfile.ZipFile(egg_file)
# list available files in the container
logger.debug(
"The egg file %s contains the following files %s",
str(egg_file),
str(zip.namelist()),
)

logger.info("Created an egg file %s from %s", str(egg_file), str(module_path))

yield Path(egg_file)


class AbstractUploadModulePlugin:
def __init__(self, module: ModuleType):
self._module_name = module.__name__
self._data: bytes
self._filepath: Path
self._filename: str
with serialize_module(module) as filepath:
self._filename = filepath.name
with open(filepath, "rb") as f:
self._data = f.read()

async def _upload_file(self, node: ServerNode):
response = await node.upload_file(self._filename, self._data, load=True)
assert len(self._data) == response["nbytes"]

async def _upload(self, node: ServerNode):
import zipfile
import sys
try:
from IPython.extensions.autoreload import superreload
except ImportError:
superreload = lambda x: x

# Try to find already loaded module
module = (
sys.modules[self._module_name] if self._module_name in sys.modules else None
)
# Try to find module on disk
module_spec = find_spec(self._module_name)

if not module_spec and not module:
# If module does not exist we keep it as egg file and load it.
logger.info(
'Uploading a new module "%s" to "%s" on %s "%s"',
self._module_name,
str(self._filename),
"worker" if isinstance(node, Worker) else "scheduler",
node.id,
)
await self._upload_file(node)
return

if module:
module_path = self._get_module_dir(module)
else:
module_path = Path(module_spec.origin)

if ".egg" in str(module_path):
# Update the previously uploaded egg module and reload it.
logger.info(
'Uploading an update for a previously uploaded a new module "%s" to "%s" on %s "%s"',
self._module_name,
str(self._filename),
"worker" if isinstance(node, Worker) else "scheduler",
node.id,
)
await self._upload_file(node)
return

with zipfile.ZipFile(BytesIO(self._data), "r") as zip_ref:
# In case, we received egg file for module that exists on node in source code,
# we overwrite each file separately by extracting it from the egg.
logger.info(
'Uploading an update for an existing module "%s" in "%s" on %s "%s"',
self._module_name,
str(module_path.parent),
"worker" if isinstance(node, Worker) else "scheduler",
node.id,
)
zip_ref.extractall(module_path.parent)

# TODO: Do we really need Jupyter's `superreload` here instead of built-in Python's function?
if self._module_name in sys.modules:
# Reload module if it is already loaded
superreload(sys.modules[self._module_name])

@classmethod
def _get_module_dir(cls, module: ModuleType) -> Path:
"""Get the directory of the module."""
module_path = Path(sys.modules[module.__name__].__file__)

if module_path.stem == "__init__":
# In case of package we serialize the whole package
return module_path.parent

# In case of single file we don't need to serialize anything
return module_path


class UploadModule(WorkerPlugin, AbstractUploadModulePlugin):
name = "upload_module"

async def setup(self, worker: Worker):
await self._upload(worker)


class SchedulerUploadModule(SchedulerPlugin, AbstractUploadModulePlugin):
name = "upload_module"

async def start(self, scheduler: Scheduler) -> None:
await self._upload(scheduler)

0 comments on commit c0eaf07

Please sign in to comment.