Skip to content

Commit

Permalink
update: implement --to-remote flag (#5343)
Browse files Browse the repository at this point in the history
* upload: implement --to-remote flag

* unit tests

* preserve frozen= for stage
  • Loading branch information
isidentical authored Feb 3, 2021
1 parent b2c2929 commit fc77618
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 7 deletions.
26 changes: 26 additions & 0 deletions dvc/command/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def run(self):
targets=self.args.targets,
rev=self.args.rev,
recursive=self.args.recursive,
to_remote=self.args.to_remote,
remote=self.args.remote,
jobs=self.args.jobs,
)
except DvcException:
logger.exception("failed update data")
Expand Down Expand Up @@ -48,4 +51,27 @@ def add_parser(subparsers, parent_parser):
default=False,
help="Update all stages in the specified directory.",
)
update_parser.add_argument(
"--to-remote",
action="store_true",
default=False,
help="Update data directly on the remote",
)
update_parser.add_argument(
"-r",
"--remote",
help="Remote storage to perform updates to",
metavar="<name>",
)
update_parser.add_argument(
"-j",
"--jobs",
type=int,
help=(
"Number of jobs to run simultaneously. "
"The default value is 4 * cpu_count(). "
"For SSH remotes, the default is 4. "
),
metavar="<number>",
)
update_parser.set_defaults(func=CmdUpdate)
12 changes: 10 additions & 2 deletions dvc/repo/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@


@locked
def update(self, targets=None, rev=None, recursive=False):
def update(
self,
targets=None,
rev=None,
recursive=False,
to_remote=False,
remote=None,
jobs=None,
):
from ..dvcfile import Dvcfile

if not targets:
Expand All @@ -16,7 +24,7 @@ def update(self, targets=None, rev=None, recursive=False):
stages.update(self.stage.collect(target, recursive=recursive))

for stage in stages:
stage.update(rev)
stage.update(rev, to_remote=to_remote, remote=remote, jobs=jobs)
dvcfile = Dvcfile(self, stage.path)
dvcfile.dump(stage)
stages.add(stage)
Expand Down
6 changes: 4 additions & 2 deletions dvc/stage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,12 @@ def reproduce(self, interactive=False, **kwargs):

return self

def update(self, rev=None):
def update(self, rev=None, to_remote=False, remote=None, jobs=None):
if not (self.is_repo_import or self.is_import):
raise StageUpdateError(self.relpath)
update_import(self, rev=rev)
update_import(
self, rev=rev, to_remote=to_remote, remote=remote, jobs=jobs
)

def reload(self):
return self.dvcfile.stage
Expand Down
14 changes: 12 additions & 2 deletions dvc/stage/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,22 @@
logger = logging.getLogger(__name__)


def update_import(stage, rev=None):
def _update_import_on_remote(stage, remote, jobs):
url = stage.deps[0].path_info.url
stage.outs[0].hash_info = stage.repo.cloud.transfer(
url, jobs=jobs, remote=remote, command="update"
)


def update_import(stage, rev=None, to_remote=False, remote=None, jobs=None):
stage.deps[0].update(rev=rev)
frozen = stage.frozen
stage.frozen = False
try:
stage.reproduce()
if to_remote:
_update_import_on_remote(stage, remote, jobs)
else:
stage.reproduce()
finally:
stage.frozen = frozen

Expand Down
48 changes: 48 additions & 0 deletions tests/func/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,3 +315,51 @@ def test_update_from_subrepos(tmp_dir, dvc, erepo_dir, is_dvc):
"url": repo_path,
"rev_lock": erepo_dir.scm.get_rev(),
}


@pytest.mark.parametrize(
"workspace",
[
pytest.lazy_fixture("local_cloud"),
pytest.lazy_fixture("s3"),
pytest.lazy_fixture("gs"),
pytest.lazy_fixture("hdfs"),
],
indirect=True,
)
def test_update_import_url_to_remote(tmp_dir, dvc, workspace, local_remote):
workspace.gen("foo", "foo")
stage = dvc.imp_url("remote://workspace/foo", to_remote=True)

workspace.gen("foo", "bar")
stage = dvc.update(stage.path, to_remote=True)

dvc.pull("foo")
assert (tmp_dir / "foo").read_text() == "bar"


@pytest.mark.parametrize(
"workspace",
[
pytest.lazy_fixture("local_cloud"),
pytest.lazy_fixture("s3"),
pytest.lazy_fixture("gs"),
pytest.lazy_fixture("hdfs"),
],
indirect=True,
)
def test_update_import_url_to_remote_directory(
tmp_dir, dvc, workspace, local_remote
):
workspace.gen({"data": {"foo": "foo", "bar": {"baz": "baz"}}})
stage = dvc.imp_url("remote://workspace/data", to_remote=True)

workspace.gen({"data": {"foo2": "foo2", "bar": {"baz2": "baz2"}}})
stage = dvc.update(stage.path, to_remote=True)

dvc.pull("data")
assert (tmp_dir / "data").read_text() == {
"foo": "foo",
"foo2": "foo2",
"bar": {"baz": "baz", "baz2": "baz2"},
}
37 changes: 36 additions & 1 deletion tests/unit/command/test_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,40 @@ def test_update(dvc, mocker):
assert cmd.run() == 0

m.assert_called_once_with(
targets=["target1", "target2"], rev="REV", recursive=True,
targets=["target1", "target2"],
rev="REV",
recursive=True,
to_remote=False,
remote=None,
jobs=None,
)


def test_update_to_remote(dvc, mocker):
cli_args = parse_args(
[
"update",
"target1",
"target2",
"--to-remote",
"-j",
"5",
"-r",
"remote",
"--recursive",
]
)
assert cli_args.func == CmdUpdate
cmd = cli_args.func(cli_args)
m = mocker.patch("dvc.repo.Repo.update")

assert cmd.run() == 0

m.assert_called_once_with(
targets=["target1", "target2"],
rev=None,
recursive=True,
to_remote=True,
remote="remote",
jobs=5,
)

0 comments on commit fc77618

Please sign in to comment.