Skip to content

Commit

Permalink
feat: allow copying files to remote with LocalContext (#487)
Browse files Browse the repository at this point in the history
Fix #486.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit


- **New Features**
- Updated documentation to clarify file handling with the `Local`
context type, emphasizing the use of symlinks for improved efficiency.
- Introduced a new class attribute for configurable symlink options
during file operations.
	- Added a new key for symlink configuration in the remote profile JSON.

- **Bug Fixes**
- Enhanced error handling for file copying processes to ensure smoother
operation.

- **Tests**
- Added a new test class for local context submissions, improving test
coverage for file handling scenarios.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz authored Aug 28, 2024
1 parent 4b98dcd commit f6072fe
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 18 deletions.
3 changes: 2 additions & 1 deletion doc/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ Since [`bash -l`](https://www.gnu.org/software/bash/manual/bash.html#Invoking-Ba
{dargs:argument}`context_type <machine/context_type>`: `Local`

`Local` runs jobs in the local server, but in a different directory.
Files will be copied to the remote directory before jobs start and copied back after jobs finish.
Files will be symlinked to the remote directory before jobs start and copied back after jobs finish.
If the local directory is not accessible with the [batch system](./batch.md), turn off {dargs:argument}`symlink <machine[SSHContext]/remote_profile/symlink>`, and then files on the local directory will be copied to the remote directory.

Since [`bash -l`](https://www.gnu.org/software/bash/manual/bash.html#Invoking-Bash) is used in the shebang line of the submission scripts, the [login shell startup files](https://www.gnu.org/software/bash/manual/bash.html#Invoking-Bash) will be executed, potentially overriding the current environment variables. Therefore, it's advisable to explicitly set the environment variables using {dargs:argument}`envs <resources/envs>` or {dargs:argument}`source_list <resources/source_list>`.

Expand Down
73 changes: 57 additions & 16 deletions dpdispatcher/contexts/local_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
import subprocess as sp
from glob import glob
from subprocess import TimeoutExpired
from typing import List

from dargs import Argument

from dpdispatcher.base_context import BaseContext
from dpdispatcher.dlog import dlog
Expand Down Expand Up @@ -60,6 +63,7 @@ def __init__(
self.temp_local_root = os.path.abspath(local_root)
self.temp_remote_root = os.path.abspath(remote_root)
self.remote_profile = remote_profile
self.symlink = remote_profile.get("symlink", True)

@classmethod
def load_from_dict(cls, context_dict):
Expand All @@ -83,6 +87,25 @@ def bind_submission(self, submission):
self.temp_remote_root, submission.submission_hash
)

def _copy_from_local_to_remote(self, local_path, remote_path):
if not os.path.exists(local_path):
raise FileNotFoundError(
f"cannot find uploaded file {os.path.join(local_path)}"
)
if os.path.exists(remote_path):
os.remove(remote_path)
_check_file_path(remote_path)

if self.symlink:
# ensure the file exist
os.symlink(local_path, remote_path)
elif os.path.isfile(local_path):
shutil.copyfile(local_path, remote_path)
elif os.path.isdir(local_path):
shutil.copytree(local_path, remote_path)
else:
raise ValueError(f"Unknown file type: {local_path}")

def upload(self, submission):
os.makedirs(self.remote_root, exist_ok=True)
for ii in submission.belonging_tasks:
Expand All @@ -103,14 +126,9 @@ def upload(self, submission):
file_list.extend(rel_file_list)

for jj in file_list:
if not os.path.exists(os.path.join(local_job, jj)):
raise FileNotFoundError(
"cannot find upload file " + os.path.join(local_job, jj)
)
if os.path.exists(os.path.join(remote_job, jj)):
os.remove(os.path.join(remote_job, jj))
_check_file_path(os.path.join(remote_job, jj))
os.symlink(os.path.join(local_job, jj), os.path.join(remote_job, jj))
self._copy_from_local_to_remote(
os.path.join(local_job, jj), os.path.join(remote_job, jj)
)

local_job = self.local_root
remote_job = self.remote_root
Expand All @@ -128,14 +146,9 @@ def upload(self, submission):
file_list.extend(rel_file_list)

for jj in file_list:
if not os.path.exists(os.path.join(local_job, jj)):
raise FileNotFoundError(
"cannot find upload file " + os.path.join(local_job, jj)
)
if os.path.exists(os.path.join(remote_job, jj)):
os.remove(os.path.join(remote_job, jj))
_check_file_path(os.path.join(remote_job, jj))
os.symlink(os.path.join(local_job, jj), os.path.join(remote_job, jj))
self._copy_from_local_to_remote(
os.path.join(local_job, jj), os.path.join(remote_job, jj)
)

def download(
self, submission, check_exists=False, mark_failure=True, back_error=False
Expand Down Expand Up @@ -336,3 +349,31 @@ def get_return(self, proc):
stdout = None
stderr = None
return ret, stdout, stderr

@classmethod
def machine_subfields(cls) -> List[Argument]:
"""Generate the machine subfields.
Returns
-------
list[Argument]
machine subfields
"""
doc_remote_profile = "The information used to maintain the local machine."
return [
Argument(
"remote_profile",
dict,
optional=True,
doc=doc_remote_profile,
sub_fields=[
Argument(
"symlink",
bool,
optional=True,
default=True,
doc="Whether to use symbolic links to replace copy. This option should be turned off if the local directory is not accessible on the Batch system.",
),
],
)
]
3 changes: 3 additions & 0 deletions dpdispatcher/machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ def serialize(self, if_empty_remote_profile=False):
machine_dict["remote_profile"] = self.context.remote_profile
else:
machine_dict["remote_profile"] = {}
# normalize the dict
base = self.arginfo()
machine_dict = base.normalize_value(machine_dict, trim_pattern="_*")
return machine_dict

def __eq__(self, other):
Expand Down
5 changes: 4 additions & 1 deletion tests/test_argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ def test_machine_argcheck(self):
"context_type": "LocalContext",
"local_root": "./",
"remote_root": "/some/path",
"remote_profile": {},
"remote_profile": {
"symlink": True,
},
"clean_asynchronously": False,
}
self.assertDictEqual(norm_dict, expected_dict)

Expand Down
18 changes: 18 additions & 0 deletions tests/test_run_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,24 @@ def test_async_run_submission(self):
return super().test_async_run_submission()


@unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows")
class TestLocalContextCopy(RunSubmission, unittest.TestCase):
def setUp(self):
super().setUp()
self.temp_dir = tempfile.TemporaryDirectory()
self.machine_dict["context_type"] = "LocalContext"
self.machine_dict["remote_root"] = self.temp_dir.name
self.machine_dict["remote_profile"]["symlink"] = False

def tearDown(self):
super().tearDown()
self.temp_dir.cleanup()

@unittest.skip("It seems the remote file may be deleted")
def test_async_run_submission(self):
return super().test_async_run_submission()


@unittest.skipIf(sys.platform == "win32", "Shell is not supported on Windows")
class TestLazyLocalContext(RunSubmission, unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit f6072fe

Please sign in to comment.