Skip to content

Commit

Permalink
Propagate self.env if it was set. (#916)
Browse files Browse the repository at this point in the history
Breaking example, where we initialize `reqs` in the constructor, which gets set in the env, but is not propagated in `.to` because the env has no name. Pretty sure we can just send `self.env` and later we have a check if it has no name to take a different path.

See test for repro.
  • Loading branch information
rohinb2 committed Jun 20, 2024
1 parent 5f63a79 commit 30bd6c4
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 7 deletions.
12 changes: 5 additions & 7 deletions runhouse/resources/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,13 +443,8 @@ def to(
system = (
_get_cluster_from(system, dryrun=self.dryrun) if system else self.system
)
if not env:
if (
not self.env or (isinstance(self.env, Env) and not self.env.name)
) and system:
env = system.default_env
else:
env = self.env

env = self.env if not env else env

env = _get_env_from(env)

Expand All @@ -458,6 +453,9 @@ def to(
if isinstance(env, Env):
env = env.to(system, force_install=force_install)

if isinstance(env, Env) and not env.name:
env = system.default_env

# We need to backup the system here so the __getstate__ method of the cluster
# doesn't wipe the client of this function's cluster when deepcopy copies it.
hw_backup = self.system
Expand Down
11 changes: 11 additions & 0 deletions tests/test_resources/test_envs/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def np_summer(a, b):
return int(np.sum([a, b]))


def torch_import():
import torch

return str(torch.__version__)


@pytest.mark.envtest
class TestEnv(tests.test_resources.test_resource.TestResource):
MAP_FIXTURES = {"resource": "env"}
Expand Down Expand Up @@ -313,3 +319,8 @@ def test_env_to_with_provider_secret(self, cluster):
os.environ["HF_TOKEN"] = "test_hf_token"
env = rh.env(name="hf_env", secrets=["huggingface"])
env.to(cluster)

@pytest.mark.level("local")
def test_env_in_function_factory(self, cluster):
remote_function = rh.function(torch_import, env=["torch"]).to(system=cluster)
assert remote_function() is not None

0 comments on commit 30bd6c4

Please sign in to comment.