Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

boost local bootstrap latency by 20% #2176

Merged
merged 4 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 76 additions & 21 deletions metaflow/plugins/pypi/conda_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
import io
import json
import os
import sys
import tarfile
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from hashlib import sha256
from io import BufferedIOBase, BytesIO
from itertools import chain
Expand Down Expand Up @@ -50,7 +51,6 @@ def decospecs(self):

def validate_environment(self, logger, datastore_type):
self.datastore_type = datastore_type
self.logger = logger

# Avoiding circular imports.
from metaflow.plugins import DATASTORES
Expand All @@ -62,8 +62,21 @@ def validate_environment(self, logger, datastore_type):
from .micromamba import Micromamba
from .pip import Pip

micromamba = Micromamba()
self.solvers = {"conda": micromamba, "pypi": Pip(micromamba)}
print_lock = threading.Lock()

def make_thread_safe(func):
    # Decorator: serialize every invocation of *func* behind the shared
    # print lock so concurrent solver threads cannot interleave output.
    @wraps(func)
    def locked(*args, **kwargs):
        with print_lock:
            return func(*args, **kwargs)

    return locked

self.logger = make_thread_safe(logger)

# TODO: Wire up logging
micromamba = Micromamba(self.logger)
self.solvers = {"conda": micromamba, "pypi": Pip(micromamba, self.logger)}

def init_environment(self, echo, only_steps=None):
# The implementation optimizes for latency to ensure as many operations can
Expand Down Expand Up @@ -150,6 +163,9 @@ def _path(url, local_path):
(
package["path"],
# Lazily fetch package from the interweb if needed.
# TODO: Depending on the len_hint, the package might be downloaded from
# the interweb prematurely. save_bytes needs to be adjusted to handle
# this scenario.
LazyOpen(
package["local_path"],
"rb",
Expand All @@ -166,22 +182,60 @@ def _path(url, local_path):
if id_ in dirty:
self.write_to_environment_manifest([id_, platform, type_], packages)

# First resolve environments through Conda, before PyPI.
storage = None
if self.datastore_type not in ["local"]:
# Initialize storage for caching if using a remote datastore
storage = self.datastore(_datastore_packageroot(self.datastore, echo))

self.logger("Bootstrapping virtual environment(s) ...")
for solver in ["conda", "pypi"]:
with ThreadPoolExecutor() as executor:
results = list(
executor.map(lambda x: solve(*x, solver), environments(solver))
)
_ = list(map(lambda x: self.solvers[solver].download(*x), results))
with ThreadPoolExecutor() as executor:
_ = list(
executor.map(lambda x: self.solvers[solver].create(*x), results)
)
if self.datastore_type not in ["local"]:
# Cache packages only when a remote datastore is in play.
storage = self.datastore(_datastore_packageroot(self.datastore, echo))
cache(storage, results, solver)
# Sequence of operations:
# 1. Start all conda solves in parallel
# 2. Download conda packages sequentially
# 3. Create and cache conda environments in parallel
# 4. Start PyPI solves in parallel after each conda environment is created
# 5. Download PyPI packages sequentially
# 6. Create and cache PyPI environments in parallel

with ThreadPoolExecutor() as executor:
# Start all conda solves in parallel
conda_futures = [
executor.submit(lambda x: solve(*x, "conda"), env)
for env in environments("conda")
]

pypi_envs = {env[0]: env for env in environments("pypi")}
pypi_futures = []

# Process conda results sequentially for downloads
for future in as_completed(conda_futures):
result = future.result()
# Sequential conda download
self.solvers["conda"].download(*result)
# Parallel conda create and cache
create_future = executor.submit(self.solvers["conda"].create, *result)
if storage:
executor.submit(cache, storage, [result], "conda")

# Queue PyPI solve to start after conda create
if result[0] in pypi_envs:

def pypi_solve(env):
    # PyPI packages are installed *into* the conda environment, so the
    # PyPI solve must not start until the conda create has finished.
    create_future.result()  # Wait for conda create
    return solve(*env, "pypi")

pypi_futures.append(
executor.submit(pypi_solve, pypi_envs[result[0]])
)

# Process PyPI results sequentially for downloads
for solve_future in pypi_futures:
result = solve_future.result()
# Sequential PyPI download
self.solvers["pypi"].download(*result)
# Parallel PyPI create and cache
executor.submit(self.solvers["pypi"].create, *result)
if storage:
executor.submit(cache, storage, [result], "pypi")
self.logger("Virtual environment(s) bootstrapped!")

def executable(self, step_name, default=None):
Expand Down Expand Up @@ -382,7 +436,8 @@ def bootstrap_commands(self, step_name, datastore_type):
'DISABLE_TRACING=True python -m metaflow.plugins.pypi.bootstrap "%s" %s "%s" linux-64'
% (self.flow.name, id_, self.datastore_type),
"echo 'Environment bootstrapped.'",
"export PATH=$PATH:$(pwd)/micromamba",
# To avoid having to install micromamba in the PATH in micromamba.py, we add it to the PATH here.
"export PATH=$PATH:$(pwd)/micromamba/bin",
]
else:
# for @conda/@pypi(disabled=True).
Expand Down
49 changes: 34 additions & 15 deletions metaflow/plugins/pypi/micromamba.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import functools
import json
import os
import subprocess
import tempfile
import time

from metaflow.exception import MetaflowException
from metaflow.util import which
Expand All @@ -20,7 +22,7 @@ def __init__(self, error):


class Micromamba(object):
def __init__(self):
def __init__(self, logger=None):
# micromamba is a tiny version of the mamba package manager and comes with
# metaflow specific performance enhancements.

Expand All @@ -33,6 +35,12 @@ def __init__(self):
os.path.expanduser(_home),
"micromamba",
)

if logger:
self.logger = logger
else:
self.logger = lambda *args, **kwargs: None # No-op logger if not provided

self.bin = (
which(os.environ.get("METAFLOW_PATH_TO_MICROMAMBA") or "micromamba")
or which("./micromamba") # to support remote execution
Expand Down Expand Up @@ -78,6 +86,7 @@ def solve(self, id_, packages, python, platform):
"--dry-run",
"--no-extra-safety-checks",
"--repodata-ttl=86400",
"--safety-checks=disabled",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was there a specific reason for adding this? how does it play out with --no-extra-safety-checks ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep — it seems that extra-safety-checks runs extra checks altogether. It wasn't immediately clear whether safety-checks=disabled also disables the extra-safety-checks.

"--retry-clean-cache",
"--prefix=%s/prefix" % tmp_dir,
]
Expand All @@ -91,10 +100,11 @@ def solve(self, id_, packages, python, platform):
cmd.append("python==%s" % python)
# TODO: Ensure a human readable message is returned when the environment
# can't be resolved for any and all reasons.
return [
solved_packages = [
{k: v for k, v in item.items() if k in ["url"]}
for item in self._call(cmd, env)["actions"]["LINK"]
]
return solved_packages

def download(self, id_, packages, python, platform):
# Unfortunately all the packages need to be catalogued in package cache
Expand All @@ -103,8 +113,6 @@ def download(self, id_, packages, python, platform):
# Micromamba is painfully slow in determining if many packages are in fact
# already cached. As a perf heuristic, we check if the environment already
# exists to short circuit package downloads.
if self.path_to_environment(id_, platform):
return

prefix = "{env_dirs}/{keyword}/{platform}/{id}".format(
env_dirs=self.info()["envs_dirs"][0],
Expand All @@ -113,10 +121,14 @@ def download(self, id_, packages, python, platform):
id=id_,
)

# Another forced perf heuristic to skip cross-platform downloads.
# cheap check
if os.path.exists(f"{prefix}/fake.done"):
return

# somewhat expensive check
if self.path_to_environment(id_, platform):
return

with tempfile.TemporaryDirectory() as tmp_dir:
env = {
"CONDA_SUBDIR": platform,
Expand Down Expand Up @@ -174,6 +186,7 @@ def create(self, id_, packages, python, platform):
cmd.append("{url}".format(**package))
self._call(cmd, env)

def info(self):
    # Return the micromamba configuration (``config list -a``), parsed
    # from JSON by ``_call``. The configuration is immutable for the
    # life of this object, so the expensive subprocess call is memoized.
    #
    # Cached per instance instead of ``functools.lru_cache`` on the
    # method: an lru_cache'd method keys on ``self`` and keeps every
    # Micromamba instance alive for the life of the process (ruff B019).
    cached = getattr(self, "_info_cache", None)
    if cached is None:
        cached = self._call(["config", "list", "-a"])
        self._info_cache = cached
    return cached

Expand All @@ -198,18 +211,24 @@ def metadata(self, id_, packages, python, platform):
}
directories = self.info()["pkgs_dirs"]
# search all package caches for packages
metadata = {
url: os.path.join(d, file)

file_to_path = {}
for d in directories:
if os.path.isdir(d):
try:
with os.scandir(d) as entries:
for entry in entries:
if entry.is_file():
# Prefer the first occurrence if the file exists in multiple directories
file_to_path.setdefault(entry.name, entry.path)
except OSError:
continue
ret = {
# set package tarball local paths to None if package tarballs are missing
url: file_to_path.get(file)
for url, file in packages_to_filenames.items()
for d in directories
if os.path.isdir(d)
and file in os.listdir(d)
and os.path.isfile(os.path.join(d, file))
}
# set package tarball local paths to None if package tarballs are missing
for url in packages_to_filenames:
metadata.setdefault(url, None)
return metadata
return ret

def interpreter(self, id_):
    # Absolute path to the Python executable inside environment *id_*.
    env_root = self.path_to_environment(id_)
    return os.path.join(env_root, "bin/python")
Expand Down
11 changes: 8 additions & 3 deletions metaflow/plugins/pypi/pip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import subprocess
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import chain, product
from urllib.parse import unquote
Expand Down Expand Up @@ -50,10 +51,14 @@ def __init__(self, error):


class Pip(object):
def __init__(self, micromamba=None):
def __init__(self, micromamba=None, logger=None):
# pip is assumed to be installed inside a conda environment managed by
# micromamba. pip commands are executed using `micromamba run --prefix`
self.micromamba = micromamba or Micromamba()
self.micromamba = micromamba or Micromamba(logger)
if logger:
self.logger = logger
else:
self.logger = lambda *args, **kwargs: None # No-op logger if not provided

def solve(self, id_, packages, python, platform):
prefix = self.micromamba.path_to_environment(id_)
Expand Down Expand Up @@ -123,7 +128,7 @@ def _format(dl_info):
**res,
subdir_str=(
"#subdirectory=%s" % subdirectory if subdirectory else ""
)
),
)
# used to deduplicate the storage location in case wheel does not
# build with enough unique identifiers.
Expand Down
Loading