boost local bootstrap latency by 20% (#2176)
* boost local bootstrap latency by 20%

* more optimizations

* wire up logger

savingoyal authored Dec 12, 2024
1 parent e9f5abd commit 12b6869
Showing 3 changed files with 118 additions and 39 deletions.
97 changes: 76 additions & 21 deletions metaflow/plugins/pypi/conda_environment.py
@@ -5,10 +5,11 @@
import io
import json
import os
import sys
import tarfile
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import wraps
from hashlib import sha256
from io import BufferedIOBase, BytesIO
from itertools import chain
@@ -50,7 +51,6 @@ def decospecs(self):

def validate_environment(self, logger, datastore_type):
self.datastore_type = datastore_type
self.logger = logger

# Avoiding circular imports.
from metaflow.plugins import DATASTORES
@@ -62,8 +62,21 @@ def validate_environment(self, logger, datastore_type):
from .micromamba import Micromamba
from .pip import Pip

micromamba = Micromamba()
self.solvers = {"conda": micromamba, "pypi": Pip(micromamba)}
print_lock = threading.Lock()

def make_thread_safe(func):
@wraps(func)
def wrapper(*args, **kwargs):
with print_lock:
return func(*args, **kwargs)

return wrapper

self.logger = make_thread_safe(logger)

# TODO: Wire up logging
micromamba = Micromamba(self.logger)
self.solvers = {"conda": micromamba, "pypi": Pip(micromamba, self.logger)}

def init_environment(self, echo, only_steps=None):
# The implementation optimizes for latency to ensure as many operations can
@@ -150,6 +163,9 @@ def _path(url, local_path):
(
package["path"],
# Lazily fetch package from the interweb if needed.
# TODO: Depending on the len_hint, the package might be downloaded from
# the interweb prematurely. save_bytes needs to be adjusted to handle
# this scenario.
LazyOpen(
package["local_path"],
"rb",
@@ -166,22 +182,60 @@ def _path(url, local_path):
if id_ in dirty:
self.write_to_environment_manifest([id_, platform, type_], packages)

# First resolve environments through Conda, before PyPI.
storage = None
if self.datastore_type not in ["local"]:
# Initialize storage for caching if using a remote datastore
storage = self.datastore(_datastore_packageroot(self.datastore, echo))

self.logger("Bootstrapping virtual environment(s) ...")
for solver in ["conda", "pypi"]:
with ThreadPoolExecutor() as executor:
results = list(
executor.map(lambda x: solve(*x, solver), environments(solver))
)
_ = list(map(lambda x: self.solvers[solver].download(*x), results))
with ThreadPoolExecutor() as executor:
_ = list(
executor.map(lambda x: self.solvers[solver].create(*x), results)
)
if self.datastore_type not in ["local"]:
# Cache packages only when a remote datastore is in play.
storage = self.datastore(_datastore_packageroot(self.datastore, echo))
cache(storage, results, solver)
# Sequence of operations:
# 1. Start all conda solves in parallel
# 2. Download conda packages sequentially
# 3. Create and cache conda environments in parallel
# 4. Start PyPI solves in parallel after each conda environment is created
# 5. Download PyPI packages sequentially
# 6. Create and cache PyPI environments in parallel

with ThreadPoolExecutor() as executor:
# Start all conda solves in parallel
conda_futures = [
executor.submit(lambda x: solve(*x, "conda"), env)
for env in environments("conda")
]

pypi_envs = {env[0]: env for env in environments("pypi")}
pypi_futures = []

# Process conda results sequentially for downloads
for future in as_completed(conda_futures):
result = future.result()
# Sequential conda download
self.solvers["conda"].download(*result)
# Parallel conda create and cache
create_future = executor.submit(self.solvers["conda"].create, *result)
if storage:
executor.submit(cache, storage, [result], "conda")

# Queue PyPI solve to start after conda create
if result[0] in pypi_envs:

def pypi_solve(env):
create_future.result() # Wait for conda create
return solve(*env, "pypi")

pypi_futures.append(
executor.submit(pypi_solve, pypi_envs[result[0]])
)

# Process PyPI results sequentially for downloads
for solve_future in pypi_futures:
result = solve_future.result()
# Sequential PyPI download
self.solvers["pypi"].download(*result)
# Parallel PyPI create and cache
executor.submit(self.solvers["pypi"].create, *result)
if storage:
executor.submit(cache, storage, [result], "pypi")
self.logger("Virtual environment(s) bootstrapped!")

def executable(self, step_name, default=None):
@@ -382,7 +436,8 @@ def bootstrap_commands(self, step_name, datastore_type):
'DISABLE_TRACING=True python -m metaflow.plugins.pypi.bootstrap "%s" %s "%s" linux-64'
% (self.flow.name, id_, self.datastore_type),
"echo 'Environment bootstrapped.'",
"export PATH=$PATH:$(pwd)/micromamba",
# To avoid having to install micromamba in the PATH in micromamba.py, we add it to the PATH here.
"export PATH=$PATH:$(pwd)/micromamba/bin",
]
else:
# for @conda/@pypi(disabled=True).
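The heart of this change is the rewritten bootstrap loop above: rather than running each stage to completion per solver, it pipelines the stages so conda solves, environment creation, caching, and dependent PyPI solves overlap, while downloads stay sequential because they touch a shared package cache. A minimal sketch of that pattern, with placeholder solve_env, download_packages, and create_env functions standing in for the Metaflow solver calls:

from concurrent.futures import ThreadPoolExecutor, as_completed

def solve_env(env):             # placeholder: resolve one environment
    return env

def download_packages(result):  # placeholder: touches a shared package cache
    pass

def create_env(result):         # placeholder: independent per environment
    pass

envs = ["env-a", "env-b", "env-c"]
with ThreadPoolExecutor() as executor:
    # 1. start all solves in parallel
    solves = [executor.submit(solve_env, env) for env in envs]
    creates = []
    # 2. as each solve completes, download sequentially (shared cache),
    #    then create the environment in parallel
    for future in as_completed(solves):
        result = future.result()
        download_packages(result)
        creates.append(executor.submit(create_env, result))
    # 3. propagate any exceptions raised inside the create tasks
    for c in creates:
        c.result()

Chaining a dependent task by calling create_future.result() inside a submitted closure, as the diff does for the PyPI solves, blocks each PyPI solve until its conda environment exists without ever blocking the main thread.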
49 changes: 34 additions & 15 deletions metaflow/plugins/pypi/micromamba.py
@@ -1,7 +1,9 @@
import functools
import json
import os
import subprocess
import tempfile
import time

from metaflow.exception import MetaflowException
from metaflow.util import which
@@ -20,7 +22,7 @@ def __init__(self, error):


class Micromamba(object):
def __init__(self):
def __init__(self, logger=None):
# micromamba is a tiny version of the mamba package manager and comes with
# metaflow specific performance enhancements.

@@ -33,6 +35,12 @@ def __init__(self):
os.path.expanduser(_home),
"micromamba",
)

if logger:
self.logger = logger
else:
self.logger = lambda *args, **kwargs: None # No-op logger if not provided

self.bin = (
which(os.environ.get("METAFLOW_PATH_TO_MICROMAMBA") or "micromamba")
or which("./micromamba") # to support remote execution
@@ -78,6 +86,7 @@ def solve(self, id_, packages, python, platform):
"--dry-run",
"--no-extra-safety-checks",
"--repodata-ttl=86400",
"--safety-checks=disabled",
"--retry-clean-cache",
"--prefix=%s/prefix" % tmp_dir,
]
@@ -91,10 +100,11 @@ def solve(self, id_, packages, python, platform):
cmd.append("python==%s" % python)
# TODO: Ensure a human readable message is returned when the environment
# can't be resolved for any and all reasons.
return [
solved_packages = [
{k: v for k, v in item.items() if k in ["url"]}
for item in self._call(cmd, env)["actions"]["LINK"]
]
return solved_packages

def download(self, id_, packages, python, platform):
# Unfortunately all the packages need to be catalogued in package cache
@@ -103,8 +113,6 @@ def download(self, id_, packages, python, platform):
# Micromamba is painfully slow in determining if many packages are in fact
# already cached. As a perf heuristic, we check if the environment already
# exists to short circuit package downloads.
if self.path_to_environment(id_, platform):
return

prefix = "{env_dirs}/{keyword}/{platform}/{id}".format(
env_dirs=self.info()["envs_dirs"][0],
@@ -113,10 +121,14 @@ def download(self, id_, packages, python, platform):
id=id_,
)

# Another forced perf heuristic to skip cross-platform downloads.
# cheap check
if os.path.exists(f"{prefix}/fake.done"):
return

# somewhat expensive check
if self.path_to_environment(id_, platform):
return

with tempfile.TemporaryDirectory() as tmp_dir:
env = {
"CONDA_SUBDIR": platform,
@@ -174,6 +186,7 @@ def create(self, id_, packages, python, platform):
cmd.append("{url}".format(**package))
self._call(cmd, env)

@functools.lru_cache(maxsize=None)
def info(self):
return self._call(["config", "list", "-a"])

@@ -198,18 +211,24 @@ def metadata(self, id_, packages, python, platform):
}
directories = self.info()["pkgs_dirs"]
# search all package caches for packages
metadata = {
url: os.path.join(d, file)

file_to_path = {}
for d in directories:
if os.path.isdir(d):
try:
with os.scandir(d) as entries:
for entry in entries:
if entry.is_file():
# Prefer the first occurrence if the file exists in multiple directories
file_to_path.setdefault(entry.name, entry.path)
except OSError:
continue
ret = {
# set package tarball local paths to None if package tarballs are missing
url: file_to_path.get(file)
for url, file in packages_to_filenames.items()
for d in directories
if os.path.isdir(d)
and file in os.listdir(d)
and os.path.isfile(os.path.join(d, file))
}
# set package tarball local paths to None if package tarballs are missing
for url in packages_to_filenames:
metadata.setdefault(url, None)
return metadata
return ret

def interpreter(self, id_):
return os.path.join(self.path_to_environment(id_), "bin/python")
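Two smaller optimizations in micromamba.py matter here: info() is memoized with functools.lru_cache so repeated calls no longer re-invoke micromamba, and metadata() now builds a filename index with one os.scandir pass per package cache directory instead of listing each directory once per package. A condensed sketch of that index, assuming directories is the list of cache paths (index_package_caches is an illustrative helper, not part of the module):

import os

def index_package_caches(directories):
    # One scandir pass per cache directory; the first directory that
    # contains a file wins, mirroring setdefault in the diff above.
    file_to_path = {}
    for d in directories:
        if not os.path.isdir(d):
            continue
        try:
            with os.scandir(d) as entries:
                for entry in entries:
                    if entry.is_file():
                        file_to_path.setdefault(entry.name, entry.path)
        except OSError:
            continue
    return file_to_path

Package lookups then become dictionary gets, so missing tarballs simply map to None instead of triggering another directory listing. The reordered checks in download() follow the same cost logic: the cheap os.path.exists probe for fake.done now runs before the more expensive path_to_environment call.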
11 changes: 8 additions & 3 deletions metaflow/plugins/pypi/pip.py
@@ -4,6 +4,7 @@
import shutil
import subprocess
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from itertools import chain, product
from urllib.parse import unquote
@@ -50,10 +51,14 @@ def __init__(self, error):


class Pip(object):
def __init__(self, micromamba=None):
def __init__(self, micromamba=None, logger=None):
# pip is assumed to be installed inside a conda environment managed by
# micromamba. pip commands are executed using `micromamba run --prefix`
self.micromamba = micromamba or Micromamba()
self.micromamba = micromamba or Micromamba(logger)
if logger:
self.logger = logger
else:
self.logger = lambda *args, **kwargs: None # No-op logger if not provided

def solve(self, id_, packages, python, platform):
prefix = self.micromamba.path_to_environment(id_)
@@ -123,7 +128,7 @@ def _format(dl_info):
**res,
subdir_str=(
"#subdirectory=%s" % subdirectory if subdirectory else ""
)
),
)
# used to deduplicate the storage location in case wheel does not
# build with enough unique identifiers.
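The pip.py change mirrors the logger wiring in micromamba.py: the constructor accepts an optional logger and substitutes a no-op callable when none is given, so call sites never need a None check. Schematically (Tool is a hypothetical stand-in for Pip or Micromamba, not an actual class):

class Tool:
    def __init__(self, logger=None):
        # Fall back to a no-op callable so self.logger(...) is always safe.
        self.logger = logger or (lambda *args, **kwargs: None)

tool = Tool()               # no logger wired up
tool.logger("ignored")      # silently does nothing
tool = Tool(logger=print)   # logger wired up
tool.logger("visible")      # prints "visible"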
