diff --git a/metaflow/plugins/pypi/conda_environment.py b/metaflow/plugins/pypi/conda_environment.py
index f9b4a050617..ac28ea67fa1 100644
--- a/metaflow/plugins/pypi/conda_environment.py
+++ b/metaflow/plugins/pypi/conda_environment.py
@@ -5,10 +5,11 @@
 import io
 import json
 import os
-import sys
 import tarfile
+import threading
 import time
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from functools import wraps
 from hashlib import sha256
 from io import BufferedIOBase, BytesIO
 from itertools import chain
@@ -50,7 +51,6 @@ def decospecs(self):
 
     def validate_environment(self, logger, datastore_type):
         self.datastore_type = datastore_type
-        self.logger = logger
 
         # Avoiding circular imports.
         from metaflow.plugins import DATASTORES
@@ -62,8 +62,21 @@ def validate_environment(self, logger, datastore_type):
         from .micromamba import Micromamba
         from .pip import Pip
 
-        micromamba = Micromamba()
-        self.solvers = {"conda": micromamba, "pypi": Pip(micromamba)}
+        print_lock = threading.Lock()
+
+        def make_thread_safe(func):
+            @wraps(func)
+            def wrapper(*args, **kwargs):
+                with print_lock:
+                    return func(*args, **kwargs)
+
+            return wrapper
+
+        self.logger = make_thread_safe(logger)
+
+        # TODO: Wire up logging
+        micromamba = Micromamba(self.logger)
+        self.solvers = {"conda": micromamba, "pypi": Pip(micromamba, self.logger)}
 
     def init_environment(self, echo, only_steps=None):
         # The implementation optimizes for latency to ensure as many operations can
@@ -150,6 +163,9 @@ def _path(url, local_path):
                     (
                         package["path"],
                         # Lazily fetch package from the interweb if needed.
+                        # TODO: Depending on the len_hint, the package might be downloaded from
+                        # the interweb prematurely. save_bytes needs to be adjusted to handle
+                        # this scenario.
                         LazyOpen(
                             package["local_path"],
                             "rb",
@@ -166,22 +182,60 @@ def _path(url, local_path):
             if id_ in dirty:
                 self.write_to_environment_manifest([id_, platform, type_], packages)
 
-        # First resolve environments through Conda, before PyPI.
+        storage = None
+        if self.datastore_type not in ["local"]:
+            # Initialize storage for caching if using a remote datastore
+            storage = self.datastore(_datastore_packageroot(self.datastore, echo))
+
         self.logger("Bootstrapping virtual environment(s) ...")
-        for solver in ["conda", "pypi"]:
-            with ThreadPoolExecutor() as executor:
-                results = list(
-                    executor.map(lambda x: solve(*x, solver), environments(solver))
-                )
-            _ = list(map(lambda x: self.solvers[solver].download(*x), results))
-            with ThreadPoolExecutor() as executor:
-                _ = list(
-                    executor.map(lambda x: self.solvers[solver].create(*x), results)
-                )
-            if self.datastore_type not in ["local"]:
-                # Cache packages only when a remote datastore is in play.
-                storage = self.datastore(_datastore_packageroot(self.datastore, echo))
-                cache(storage, results, solver)
+        # Sequence of operations:
+        #   1. Start all conda solves in parallel
+        #   2. Download conda packages sequentially
+        #   3. Create and cache conda environments in parallel
+        #   4. Start PyPI solves in parallel after each conda environment is created
+        #   5. Download PyPI packages sequentially
+        #   6. Create and cache PyPI environments in parallel
+
+        with ThreadPoolExecutor() as executor:
+            # Start all conda solves in parallel
+            conda_futures = [
+                executor.submit(lambda x: solve(*x, "conda"), env)
+                for env in environments("conda")
+            ]
+
+            pypi_envs = {env[0]: env for env in environments("pypi")}
+            pypi_futures = []
+
+            # Process conda results sequentially for downloads
+            for future in as_completed(conda_futures):
+                result = future.result()
+                # Sequential conda download
+                self.solvers["conda"].download(*result)
+                # Parallel conda create and cache
+                create_future = executor.submit(self.solvers["conda"].create, *result)
+                if storage:
+                    executor.submit(cache, storage, [result], "conda")
+
+                # Queue PyPI solve to start after conda create
+                if result[0] in pypi_envs:
+
+                    def pypi_solve(env, create_future=create_future):
+                        create_future.result()  # Wait for conda create (bound early; a late-binding closure could grab a later iteration's future)
+                        return solve(*env, "pypi")
+
+                    pypi_futures.append(
+                        executor.submit(pypi_solve, pypi_envs[result[0]])
+                    )
+
+            # Process PyPI results sequentially for downloads
+            for solve_future in pypi_futures:
+                result = solve_future.result()
+                # Sequential PyPI download
+                self.solvers["pypi"].download(*result)
+                # Parallel PyPI create and cache
+                executor.submit(self.solvers["pypi"].create, *result)
+                if storage:
+                    executor.submit(cache, storage, [result], "pypi")
 
         self.logger("Virtual environment(s) bootstrapped!")
 
     def executable(self, step_name, default=None):
@@ -382,7 +436,8 @@ def bootstrap_commands(self, step_name, datastore_type):
                 'DISABLE_TRACING=True python -m metaflow.plugins.pypi.bootstrap "%s" %s "%s" linux-64'
                 % (self.flow.name, id_, self.datastore_type),
                 "echo 'Environment bootstrapped.'",
-                "export PATH=$PATH:$(pwd)/micromamba",
+                # To avoid having to install micromamba in the PATH in micromamba.py, we add it to the PATH here.
+                "export PATH=$PATH:$(pwd)/micromamba/bin",
             ]
         else:
            # for @conda/@pypi(disabled=True).
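The scheduling pattern above is the heart of the patch: all conda solves are submitted at once, downloads are drained sequentially off `as_completed`, and each PyPI solve is chained behind the conda `create` it depends on. Below is a minimal, self-contained sketch of that chaining; the stub `solve`/`download`/`create` functions and environment names are illustrative stand-ins, not part of the patch.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

# Stub stages; the real work happens in Micromamba and Pip.
def solve(env, kind):
    time.sleep(0.1)  # pretend to resolve a dependency graph
    return (env, kind)

def download(result):
    time.sleep(0.1)  # kept sequential, mirroring the patch

def create(result):
    time.sleep(0.1)  # safe to parallelize: each env has its own prefix

envs = ["env-a", "env-b", "env-c"]

with ThreadPoolExecutor() as executor:
    conda_futures = [executor.submit(solve, env, "conda") for env in envs]
    pypi_futures = []

    for future in as_completed(conda_futures):
        result = future.result()
        download(result)
        create_future = executor.submit(create, result)

        # Bind create_future as a default argument. A plain closure would be
        # late-bound and could end up waiting on a later iteration's future.
        def pypi_solve(env=result[0], create_future=create_future):
            create_future.result()  # PyPI needs the conda env on disk first
            return solve(env, "pypi")

        pypi_futures.append(executor.submit(pypi_solve))

    for future in pypi_futures:
        download(future.result())
```

One caveat the sketch shares with the patch: `pypi_solve` blocks a worker thread while waiting on `create_future`, so the pool must be large enough that the `create` tasks it waits on can still be scheduled.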
diff --git a/metaflow/plugins/pypi/micromamba.py b/metaflow/plugins/pypi/micromamba.py
index 378d3d5993a..bf5f659e0a7 100644
--- a/metaflow/plugins/pypi/micromamba.py
+++ b/metaflow/plugins/pypi/micromamba.py
@@ -1,7 +1,9 @@
+import functools
 import json
 import os
 import subprocess
 import tempfile
+import time
 
 from metaflow.exception import MetaflowException
 from metaflow.util import which
@@ -20,7 +22,7 @@ def __init__(self, error):
 
 
 class Micromamba(object):
-    def __init__(self):
+    def __init__(self, logger=None):
         # micromamba is a tiny version of the mamba package manager and comes with
         # metaflow specific performance enhancements.
 
@@ -33,6 +35,12 @@ def __init__(self):
             os.path.expanduser(_home),
             "micromamba",
         )
+
+        if logger:
+            self.logger = logger
+        else:
+            self.logger = lambda *args, **kwargs: None  # No-op logger if not provided
+
         self.bin = (
             which(os.environ.get("METAFLOW_PATH_TO_MICROMAMBA") or "micromamba")
             or which("./micromamba")  # to support remote execution
@@ -78,6 +86,7 @@ def solve(self, id_, packages, python, platform):
                 "--dry-run",
                 "--no-extra-safety-checks",
                 "--repodata-ttl=86400",
+                "--safety-checks=disabled",
                 "--retry-clean-cache",
                 "--prefix=%s/prefix" % tmp_dir,
             ]
@@ -91,10 +100,11 @@ def solve(self, id_, packages, python, platform):
                 cmd.append("python==%s" % python)
             # TODO: Ensure a human readable message is returned when the environment
             #       can't be resolved for any and all reasons.
-            return [
+            solved_packages = [
                 {k: v for k, v in item.items() if k in ["url"]}
                 for item in self._call(cmd, env)["actions"]["LINK"]
             ]
+            return solved_packages
 
     def download(self, id_, packages, python, platform):
         # Unfortunately all the packages need to be catalogued in package cache
@@ -103,8 +113,6 @@ def download(self, id_, packages, python, platform):
         # Micromamba is painfully slow in determining if many packages are infact
         # already cached. As a perf heuristic, we check if the environment already
         # exists to short circuit package downloads.
-        if self.path_to_environment(id_, platform):
-            return
 
         prefix = "{env_dirs}/{keyword}/{platform}/{id}".format(
             env_dirs=self.info()["envs_dirs"][0],
@@ -113,10 +121,14 @@ def download(self, id_, packages, python, platform):
             id=id_,
         )
 
-        # Another forced perf heuristic to skip cross-platform downloads.
+        # cheap check
         if os.path.exists(f"{prefix}/fake.done"):
            return
 
+        # somewhat expensive check
+        if self.path_to_environment(id_, platform):
+            return
+
         with tempfile.TemporaryDirectory() as tmp_dir:
             env = {
                 "CONDA_SUBDIR": platform,
@@ -174,6 +186,7 @@ def create(self, id_, packages, python, platform):
                 cmd.append("{url}".format(**package))
             self._call(cmd, env)
 
+    @functools.lru_cache(maxsize=None)
     def info(self):
         return self._call(["config", "list", "-a"])
 
@@ -198,18 +211,24 @@ def metadata(self, id_, packages, python, platform):
         }
         directories = self.info()["pkgs_dirs"]
         # search all package caches for packages
-        metadata = {
-            url: os.path.join(d, file)
+
+        file_to_path = {}
+        for d in directories:
+            if os.path.isdir(d):
+                try:
+                    with os.scandir(d) as entries:
+                        for entry in entries:
+                            if entry.is_file():
+                                # Prefer the first occurrence if the file exists in multiple directories
+                                file_to_path.setdefault(entry.name, entry.path)
+                except OSError:
+                    continue
+        ret = {
+            # set package tarball local paths to None if package tarballs are missing
+            url: file_to_path.get(file)
             for url, file in packages_to_filenames.items()
-            for d in directories
-            if os.path.isdir(d)
-            and file in os.listdir(d)
-            and os.path.isfile(os.path.join(d, file))
         }
-        # set package tarball local paths to None if package tarballs are missing
-        for url in packages_to_filenames:
-            metadata.setdefault(url, None)
-        return metadata
+        return ret
 
     def interpreter(self, id_):
         return os.path.join(self.path_to_environment(id_), "bin/python")
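The `metadata` rewrite above swaps a quadratic scan (one `os.listdir` per package per cache directory) for a single `os.scandir` pass that indexes every cache entry up front; the memoized `info()` similarly avoids re-running `micromamba config list` on every call. Here is a standalone sketch of the indexing pattern, using hypothetical cache directories and package filenames for illustration only.

```python
import os

def index_caches(directories):
    """Map each filename to its first occurrence across cache directories."""
    file_to_path = {}
    for d in directories:
        if not os.path.isdir(d):
            continue
        try:
            with os.scandir(d) as entries:
                for entry in entries:
                    if entry.is_file():
                        # Earlier directories take precedence, mirroring the
                        # setdefault() call in the patch.
                        file_to_path.setdefault(entry.name, entry.path)
        except OSError:
            continue  # unreadable cache directories are simply skipped
    return file_to_path

# Hypothetical inputs, not Metaflow's real cache layout.
caches = ["/tmp/pkgs-a", "/tmp/pkgs-b"]
wanted = {"https://example.com/numpy.conda": "numpy.conda"}

index = index_caches(caches)
# Missing tarballs map to None, exactly like the rewritten metadata().
local_paths = {url: index.get(name) for url, name in wanted.items()}
print(local_paths)
```

This turns the lookup cost into one directory walk plus O(1) dictionary probes, which matters when many packages share a handful of large cache directories.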
diff --git a/metaflow/plugins/pypi/pip.py b/metaflow/plugins/pypi/pip.py
index 13750ea742c..577f8eb91b6 100644
--- a/metaflow/plugins/pypi/pip.py
+++ b/metaflow/plugins/pypi/pip.py
@@ -4,6 +4,7 @@
 import shutil
 import subprocess
 import tempfile
+import time
 from concurrent.futures import ThreadPoolExecutor
 from itertools import chain, product
 from urllib.parse import unquote
@@ -50,10 +51,14 @@ def __init__(self, error):
 
 
 class Pip(object):
-    def __init__(self, micromamba=None):
+    def __init__(self, micromamba=None, logger=None):
         # pip is assumed to be installed inside a conda environment managed by
         # micromamba. pip commands are executed using `micromamba run --prefix`
-        self.micromamba = micromamba or Micromamba()
+        self.micromamba = micromamba or Micromamba(logger)
+        if logger:
+            self.logger = logger
+        else:
+            self.logger = lambda *args, **kwargs: None  # No-op logger if not provided
 
     def solve(self, id_, packages, python, platform):
         prefix = self.micromamba.path_to_environment(id_)
@@ -123,7 +128,7 @@ def _format(dl_info):
                     **res,
                     subdir_str=(
                         "#subdirectory=%s" % subdirectory if subdirectory else ""
-                    )
+                    ),
                 )
                 # used to deduplicate the storage location in case wheel does not
                 # build with enough unique identifiers.
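The logging changes tie the three files together: `validate_environment` wraps the logger in a lock before handing it to `Micromamba` and `Pip`, and both solvers fall back to a no-op callable when no logger is given. A condensed sketch of that arrangement follows; the `Solver` class and `echo` function are stand-ins for illustration, not Metaflow APIs.

```python
import threading
from functools import wraps

_print_lock = threading.Lock()

def make_thread_safe(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Serialize calls so output from worker threads doesn't interleave.
        with _print_lock:
            return func(*args, **kwargs)
    return wrapper

def echo(msg, **kwargs):  # stand-in for Metaflow's echo-style logger
    print(msg)

class Solver:  # stand-in for Micromamba / Pip
    def __init__(self, logger=None):
        # No-op fallback spares every call site a None check.
        self.logger = logger or (lambda *args, **kwargs: None)

safe_logger = make_thread_safe(echo)
Solver(safe_logger).logger("solving environment ...")
Solver().logger("silently dropped")  # no logger wired up
```

The no-op fallback is the reason the solvers can log unconditionally from worker threads: call sites stay identical whether or not a logger was injected.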