From 2943a93e445f9a1458f46bb2c8baa9ddd0479f3c Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Tue, 23 Jul 2024 15:30:43 -0400 Subject: [PATCH] Install preferred version only if package is not already installed. --- runhouse/resources/packages/package.py | 36 ++++++++++++++++--- .../test_resources/test_data/test_package.py | 15 ++++---- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/runhouse/resources/packages/package.py b/runhouse/resources/packages/package.py index 83a7134cd..4decc486f 100644 --- a/runhouse/resources/packages/package.py +++ b/runhouse/resources/packages/package.py @@ -66,10 +66,11 @@ class Package(Resource): def __init__( self, - name: str = None, - install_method: str = None, - install_target: Union[str, "Folder"] = None, - install_args: str = None, + name: Optional[str] = None, + install_method: Optional[str] = None, + install_target: Optional[Union[str, "Folder"]] = None, + install_args: Optional[str] = None, + preferred_version: Optional[str] = None, dryrun: bool = False, **kwargs, # We have this here to ignore extra arguments when calling from from_config ): @@ -86,6 +87,7 @@ def __init__( self.install_method = install_method self.install_target = install_target self.install_args = install_args + self.preferred_version = preferred_version def config(self, condensed=True): # If the package is just a simple Package.from_string string, no @@ -103,6 +105,7 @@ def config(self, condensed=True): else self.install_target ) config["install_args"] = self.install_args + config["preferred_version"] = self.preferred_version return config def __str__(self): @@ -249,6 +252,25 @@ def _install(self, env: Union[str, "Env"] = None, cluster: "Cluster" = None): return if self.install_method == "pip": + + # If this is a generic pip package, with no version pinned, we want to check if there is a version + # already installed. If there is, then we ignore preferred version and leave the existing version. + # The user can always force a version install by doing `numpy==2.0.0` for example. Else, we install + # the preferred version, that matches their local. + if ( + is_python_package_string(self.install_target) + and self.preferred_version is not None + ): + # Check if this is installed + retcode = run_setup_command( + f"python -c \"import importlib.util; exit(0) if importlib.util.find_spec('{self.install_target}') else exit(1)\"", + cluster=cluster, + )[0] + if retcode != 0: + self.install_target = ( + f"{self.install_target}=={self.preferred_version}" + ) + install_cmd = self._pip_install_cmd(env=env, cluster=cluster) logger.info(f"Running via install_method pip: {install_cmd}") retcode = run_setup_command(install_cmd, cluster=cluster)[0] @@ -493,6 +515,7 @@ def from_string(specifier: str, dryrun=False): # If we are just defaulting to pip, attempt to install the same version of the package # that is already installed locally # Check if the target is only letters, nothing else. This means its a string like 'numpy'. + preferred_version = None if install_method == "pip" and is_python_package_string(target): locally_installed_version = find_locally_installed_version(target) if locally_installed_version: @@ -501,6 +524,10 @@ def from_string(specifier: str, dryrun=False): if local_install_path and Path(local_install_path).exists(): target = (local_install_path, None) + else: + # We want to preferrably install this version of the package server-side + preferred_version = locally_installed_version + # "Local" install method is a special case where we just copy a local folder and add to path if install_method == "local": return Package( @@ -512,6 +539,7 @@ def from_string(specifier: str, dryrun=False): install_target=target, install_args=args, install_method=install_method, + preferred_version=preferred_version, dryrun=dryrun, ) elif install_method == "rh": diff --git a/tests/test_resources/test_data/test_package.py b/tests/test_resources/test_data/test_package.py index 2639adb54..030425a58 100644 --- a/tests/test_resources/test_data/test_package.py +++ b/tests/test_resources/test_data/test_package.py @@ -9,10 +9,10 @@ from runhouse.utils import run_with_logs -def get_plotly_version(): - import plotly +def get_bs4_version(): + import bs4 - return plotly.__version__ + return bs4.__version__ class TestPackage(tests.test_resources.test_resource.TestResource): @@ -152,13 +152,12 @@ def test_local_reqs_on_cluster(self, cluster, local_package): assert isinstance(remote_package.install_target, InstallTarget) @pytest.mark.level("local") - @pytest.mark.skip("Feature deprecated for now") def test_local_package_version_gets_installed(self, cluster): - run_with_logs("pip install plotly==5.9.0") - env = rh.env(name="temp_env", reqs=["plotly"]) + run_with_logs("pip install beautifulsoup4==4.11.1") + env = rh.env(name="temp_env", reqs=["beautifulsoup4"]) - remote_fn = rh.function(get_plotly_version, env=env).to(cluster) - assert remote_fn() == "5.9.0" + remote_fn = rh.function(get_bs4_version, env=env).to(cluster) + assert remote_fn() == "4.11.1" # --------- basic torch index-url testing --------- @pytest.mark.level("unit")