diff --git a/runhouse/resources/module.py b/runhouse/resources/module.py index d4a4cd65f..f4bcebdd6 100644 --- a/runhouse/resources/module.py +++ b/runhouse/resources/module.py @@ -443,13 +443,8 @@ def to( system = ( _get_cluster_from(system, dryrun=self.dryrun) if system else self.system ) - if not env: - if ( - not self.env or (isinstance(self.env, Env) and not self.env.name) - ) and system: - env = system.default_env - else: - env = self.env + + env = self.env if not env else env env = _get_env_from(env) @@ -458,6 +453,9 @@ def to( if isinstance(env, Env): env = env.to(system, force_install=force_install) + if isinstance(env, Env) and not env.name: + env = system.default_env + # We need to backup the system here so the __getstate__ method of the cluster # doesn't wipe the client of this function's cluster when deepcopy copies it. hw_backup = self.system diff --git a/tests/test_resources/test_envs/test_env.py b/tests/test_resources/test_envs/test_env.py index 2036c4d7f..d56d4e37a 100644 --- a/tests/test_resources/test_envs/test_env.py +++ b/tests/test_resources/test_envs/test_env.py @@ -30,6 +30,12 @@ def np_summer(a, b): return int(np.sum([a, b])) +def torch_import(): + import torch + + return str(torch.__version__) + + @pytest.mark.envtest class TestEnv(tests.test_resources.test_resource.TestResource): MAP_FIXTURES = {"resource": "env"} @@ -313,3 +319,8 @@ def test_env_to_with_provider_secret(self, cluster): os.environ["HF_TOKEN"] = "test_hf_token" env = rh.env(name="hf_env", secrets=["huggingface"]) env.to(cluster) + + @pytest.mark.level("local") + def test_env_in_function_factory(self, cluster): + remote_function = rh.function(torch_import, env=["torch"]).to(system=cluster) + assert remote_function() is not None