
Commit 3d70a1e

Merge pull request #2853 from skshetry/gs-external-deps
Support adding directories in google cloud storage remote
2 parents 24d3a87 + 6ba7b1d commit 3d70a1e

File tree

9 files changed: +396, −314 lines changed


dvc/remote/gs.py

34 additions, 5 deletions

@@ -138,16 +138,45 @@ def remove(self, path_info):

         blob.delete()

-    def _list_paths(self, bucket, prefix):
-        for blob in self.gs.bucket(bucket).list_blobs(prefix=prefix):
+    def _list_paths(self, path_info, max_items=None):
+        for blob in self.gs.bucket(path_info.bucket).list_blobs(
+            prefix=path_info.path, max_results=max_items
+        ):
             yield blob.name

     def list_cache_paths(self):
-        return self._list_paths(self.path_info.bucket, self.path_info.path)
+        return self._list_paths(self.path_info)
+
+    def walk_files(self, path_info):
+        for fname in self._list_paths(path_info / ""):
+            # skip nested empty directories
+            if fname.endswith("/"):
+                continue
+            yield path_info.replace(fname)
+
+    def makedirs(self, path_info):
+        self.gs.bucket(path_info.bucket).blob(
+            (path_info / "").path
+        ).upload_from_string("")
+
+    def isdir(self, path_info):
+        dir_path = path_info / ""
+        return bool(list(self._list_paths(dir_path, max_items=1)))
+
+    def isfile(self, path_info):
+        if path_info.path.endswith("/"):
+            return False
+
+        blob = self.gs.bucket(path_info.bucket).blob(path_info.path)
+        return blob.exists()

     def exists(self, path_info):
-        paths = set(self._list_paths(path_info.bucket, path_info.path))
-        return any(path_info.path == path for path in paths)
+        """Check if the blob exists. If it does not exist,
+        it could be a part of a directory path.
+
+        eg: if `data/file.txt` exists, check for `data` should return True
+        """
+        return self.isfile(path_info) or self.isdir(path_info)

     def _upload(self, from_file, to_info, name=None, no_progress_bar=True):
         bucket = self.gs.bucket(to_info.bucket)
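The new methods let RemoteGS treat key prefixes as directories, which GCS itself does not have: makedirs writes an empty marker object whose key ends in "/", isdir asks whether anything is stored under the prefix, and exists becomes "is a file or is a directory". Below is a minimal standalone sketch of the same idea against the google-cloud-storage client directly; the bucket name and helper functions are made up for illustration, not part of this commit.

# Sketch of GCS directory emulation, assuming the google-cloud-storage
# package is installed and credentials are configured. "my-test-bucket"
# is a hypothetical bucket name.
from google.cloud import storage

client = storage.Client()
bucket = client.bucket("my-test-bucket")


def makedirs(prefix):
    # A "directory" is just an empty marker object, e.g. "data/raw/"
    bucket.blob(prefix.rstrip("/") + "/").upload_from_string("")


def isdir(prefix):
    # Any object under the prefix (marker or file) makes it a directory;
    # max_results=1 keeps the listing cheap, like max_items in the diff.
    blobs = bucket.list_blobs(prefix=prefix.rstrip("/") + "/", max_results=1)
    return any(True for _ in blobs)


def isfile(path):
    # A key ending in "/" can only be a directory marker, never a file
    return not path.endswith("/") and bucket.blob(path).exists()


def exists(path):
    # Mirrors the new RemoteGS.exists(): "data" exists either as a blob
    # itself or as a prefix of one, e.g. "data/file.txt"
    return isfile(path) or isdir(path)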

tests/func/test_api.py

1 addition, 49 deletions

@@ -3,60 +3,12 @@

 import pytest

-from .test_data_cloud import _should_test_aws
-from .test_data_cloud import _should_test_azure
-from .test_data_cloud import _should_test_gcp
-from .test_data_cloud import _should_test_hdfs
-from .test_data_cloud import _should_test_oss
-from .test_data_cloud import _should_test_ssh
-from .test_data_cloud import get_aws_url
-from .test_data_cloud import get_azure_url
-from .test_data_cloud import get_gcp_url
-from .test_data_cloud import get_hdfs_url
-from .test_data_cloud import get_local_url
-from .test_data_cloud import get_oss_url
-from .test_data_cloud import get_ssh_url
 from dvc import api
 from dvc.exceptions import FileMissingError
 from dvc.main import main
 from dvc.path_info import URLInfo
 from dvc.remote.config import RemoteConfig
-
-
-# NOTE: staticmethod is only needed in Python 2
-class Local:
-    should_test = staticmethod(lambda: True)
-    get_url = staticmethod(get_local_url)
-
-
-class S3:
-    should_test = staticmethod(_should_test_aws)
-    get_url = staticmethod(get_aws_url)
-
-
-class GCP:
-    should_test = staticmethod(_should_test_gcp)
-    get_url = staticmethod(get_gcp_url)
-
-
-class Azure:
-    should_test = staticmethod(_should_test_azure)
-    get_url = staticmethod(get_azure_url)
-
-
-class OSS:
-    should_test = staticmethod(_should_test_oss)
-    get_url = staticmethod(get_oss_url)
-
-
-class SSH:
-    should_test = staticmethod(_should_test_ssh)
-    get_url = staticmethod(get_ssh_url)
-
-
-class HDFS:
-    should_test = staticmethod(_should_test_hdfs)
-    get_url = staticmethod(get_hdfs_url)
+from tests.remotes import Azure, GCP, HDFS, Local, OSS, S3, SSH


 remote_params = [S3, GCP, Azure, OSS, SSH, HDFS]
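The wrapper classes deleted here now come from a shared tests.remotes module. A sketch of that module's likely shape, reconstructed from the code this commit removes (only Local and S3 shown here for brevity; the actual file added by the PR presumably covers every remote):

# tests/remotes.py (sketch, reconstructed from the deleted code above)
import os
import uuid

from dvc.utils import env2bool
from tests.basic_env import TestDvc

TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-test")


def _should_test_aws():
    # Explicit opt-in/out wins; otherwise require AWS credentials
    do_test = env2bool("DVC_TEST_AWS", undefined=None)
    if do_test is not None:
        return do_test
    return bool(
        os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY")
    )


def get_local_url():
    return TestDvc.mkdtemp()


def get_aws_url():
    return "s3://" + TEST_AWS_REPO_BUCKET + "/" + str(uuid.uuid4())


# NOTE: staticmethod is only needed in Python 2
class Local:
    should_test = staticmethod(lambda: True)
    get_url = staticmethod(get_local_url)


class S3:
    should_test = staticmethod(_should_test_aws)
    get_url = staticmethod(get_aws_url)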

tests/func/test_data_cloud.py

23 additions, 207 deletions

@@ -1,13 +1,8 @@
 import copy
-import getpass
 import logging
 import os
-import platform
 import shutil
 import uuid
-from subprocess import CalledProcessError
-from subprocess import check_output
-from subprocess import Popen
 from unittest import SkipTest

 import pytest
@@ -29,216 +24,37 @@
 from dvc.remote.base import STATUS_DELETED
 from dvc.remote.base import STATUS_NEW
 from dvc.remote.base import STATUS_OK
-from dvc.utils import env2bool
 from dvc.utils import file_md5
 from dvc.utils.compat import str
 from dvc.utils.stage import dump_stage_file
 from dvc.utils.stage import load_stage_file
 from tests.basic_env import TestDvc
 from tests.utils import spy

-
-TEST_REMOTE = "upstream"
-TEST_SECTION = 'remote "{}"'.format(TEST_REMOTE)
-TEST_CONFIG = {
-    Config.SECTION_CACHE: {},
-    Config.SECTION_CORE: {Config.SECTION_CORE_REMOTE: TEST_REMOTE},
-    TEST_SECTION: {Config.SECTION_REMOTE_URL: ""},
-}
-
-TEST_AWS_REPO_BUCKET = os.environ.get("DVC_TEST_AWS_REPO_BUCKET", "dvc-test")
-TEST_GCP_REPO_BUCKET = os.environ.get("DVC_TEST_GCP_REPO_BUCKET", "dvc-test")
-TEST_OSS_REPO_BUCKET = "dvc-test"
-
-TEST_GCP_CREDS_FILE = os.path.abspath(
-    os.environ.get(
-        "GOOGLE_APPLICATION_CREDENTIALS",
-        os.path.join("scripts", "ci", "gcp-creds.json"),
-    )
-)
-# Ensure that absolute path is used
-os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = TEST_GCP_CREDS_FILE
-
-TEST_GDRIVE_CLIENT_ID = (
-    "719861249063-v4an78j9grdtuuuqg3lnm0sugna6v3lh.apps.googleusercontent.com"
+from tests.remotes import (
+    _should_test_aws,
+    _should_test_azure,
+    _should_test_gcp,
+    _should_test_gdrive,
+    _should_test_hdfs,
+    _should_test_oss,
+    _should_test_ssh,
+    TEST_CONFIG,
+    TEST_SECTION,
+    TEST_GCP_CREDS_FILE,
+    TEST_GDRIVE_CLIENT_ID,
+    TEST_GDRIVE_CLIENT_SECRET,
+    TEST_REMOTE,
+    get_aws_url,
+    get_azure_url,
+    get_gcp_url,
+    get_gdrive_url,
+    get_hdfs_url,
+    get_local_url,
+    get_oss_url,
+    get_ssh_url,
+    get_ssh_url_mocked,
 )
-TEST_GDRIVE_CLIENT_SECRET = "2fy_HyzSwkxkGzEken7hThXb"
-
-
-def _should_test_aws():
-    do_test = env2bool("DVC_TEST_AWS", undefined=None)
-    if do_test is not None:
-        return do_test
-
-    if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"):
-        return True
-
-    return False
-
-
-def _should_test_gdrive():
-    if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
-        return True
-
-    return False
-
-
-def _should_test_gcp():
-    do_test = env2bool("DVC_TEST_GCP", undefined=None)
-    if do_test is not None:
-        return do_test
-
-    if not os.path.exists(TEST_GCP_CREDS_FILE):
-        return False
-
-    try:
-        check_output(
-            [
-                "gcloud",
-                "auth",
-                "activate-service-account",
-                "--key-file",
-                TEST_GCP_CREDS_FILE,
-            ]
-        )
-    except (CalledProcessError, OSError):
-        return False
-    return True
-
-
-def _should_test_azure():
-    do_test = env2bool("DVC_TEST_AZURE", undefined=None)
-    if do_test is not None:
-        return do_test
-
-    return os.getenv("AZURE_STORAGE_CONTAINER_NAME") and os.getenv(
-        "AZURE_STORAGE_CONNECTION_STRING"
-    )
-
-
-def _should_test_oss():
-    do_test = env2bool("DVC_TEST_OSS", undefined=None)
-    if do_test is not None:
-        return do_test
-
-    return (
-        os.getenv("OSS_ENDPOINT")
-        and os.getenv("OSS_ACCESS_KEY_ID")
-        and os.getenv("OSS_ACCESS_KEY_SECRET")
-    )
-
-
-def _should_test_ssh():
-    do_test = env2bool("DVC_TEST_SSH", undefined=None)
-    if do_test is not None:
-        return do_test
-
-    # FIXME: enable on windows
-    if os.name == "nt":
-        return False
-
-    try:
-        check_output(["ssh", "-o", "BatchMode=yes", "127.0.0.1", "ls"])
-    except (CalledProcessError, IOError):
-        return False
-
-    return True
-
-
-def _should_test_hdfs():
-    if platform.system() != "Linux":
-        return False
-
-    try:
-        check_output(
-            ["hadoop", "version"], shell=True, executable=os.getenv("SHELL")
-        )
-    except (CalledProcessError, IOError):
-        return False
-
-    p = Popen(
-        "hadoop fs -ls hdfs://127.0.0.1/",
-        shell=True,
-        executable=os.getenv("SHELL"),
-    )
-    p.communicate()
-    if p.returncode != 0:
-        return False
-
-    return True
-
-
-def get_local_storagepath():
-    return TestDvc.mkdtemp()
-
-
-def get_local_url():
-    return get_local_storagepath()
-
-
-def get_ssh_url():
-    return "ssh://{}@127.0.0.1:22{}".format(
-        getpass.getuser(), get_local_storagepath()
-    )
-
-
-def get_ssh_url_mocked(user, port):
-    path = get_local_storagepath()
-    if os.name == "nt":
-        # NOTE: On Windows get_local_storagepath() will return an ntpath
-        # that looks something like `C:\some\path`, which is not compatible
-        # with SFTP paths [1], so we need to convert it to a proper posixpath.
-        # To do that, we should construct a posixpath that would be relative
-        # to the server's root. In our case our ssh server is running with
-        # `c:/` as a root, and our URL format requires absolute paths, so the
-        # resulting path would look like `/some/path`.
-        #
-        # [1]https://tools.ietf.org/html/draft-ietf-secsh-filexfer-13#section-6
-        drive, path = os.path.splitdrive(path)
-        assert drive.lower() == "c:"
-        path = path.replace("\\", "/")
-    url = "ssh://{}@127.0.0.1:{}{}".format(user, port, path)
-    return url
-
-
-def get_hdfs_url():
-    return "hdfs://{}@127.0.0.1{}".format(
-        getpass.getuser(), get_local_storagepath()
-    )
-
-
-def get_aws_storagepath():
-    return TEST_AWS_REPO_BUCKET + "/" + str(uuid.uuid4())
-
-
-def get_aws_url():
-    return "s3://" + get_aws_storagepath()
-
-
-def get_gdrive_url():
-    return "gdrive://root/" + str(uuid.uuid4())
-
-
-def get_gcp_storagepath():
-    return TEST_GCP_REPO_BUCKET + "/" + str(uuid.uuid4())
-
-
-def get_gcp_url():
-    return "gs://" + get_gcp_storagepath()
-
-
-def get_azure_url():
-    container_name = os.getenv("AZURE_STORAGE_CONTAINER_NAME")
-    assert container_name is not None
-    return "azure://{}/{}".format(container_name, str(uuid.uuid4()))
-
-
-def get_oss_storagepath():
-    return "{}/{}".format(TEST_OSS_REPO_BUCKET, (uuid.uuid4()))
-
-
-def get_oss_url():
-    return "oss://{}".format(get_oss_storagepath())


 class TestDataCloud(TestDvc):
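For context, the _should_test_* helpers moved here gate whole tests on available credentials or services. An illustrative example of how such a gate is typically consumed with pytest; the test body and its contents are hypothetical, not part of this diff:

import pytest

from tests.remotes import _should_test_gcp, get_gcp_url


@pytest.mark.skipif(not _should_test_gcp(), reason="no GCP credentials")
def test_gcp_remote_url():
    # get_gcp_url() returns "gs://<bucket>/<uuid>" via the moved helper
    url = get_gcp_url()
    assert url.startswith("gs://")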

tests/func/test_remote.py

1 addition, 2 deletions

@@ -4,14 +4,13 @@
 import configobj
 from mock import patch

-from .test_data_cloud import get_local_url
 from dvc.config import Config
 from dvc.main import main
 from dvc.path_info import PathInfo
 from dvc.remote import RemoteLOCAL
 from dvc.remote.base import RemoteBASE
 from tests.basic_env import TestDvc
-from tests.func.test_data_cloud import get_local_storagepath
+from tests.remotes import get_local_url, get_local_storagepath


 class TestRemote(TestDvc):
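As a usage note, get_local_url() gives tests a throwaway directory to register as a DVC remote through the CLI entry point. A small illustrative snippet, not taken from this diff:

from dvc.main import main

from tests.remotes import TEST_REMOTE, get_local_url

# `main` is DVC's CLI entry point and returns the process exit code;
# TEST_REMOTE ("upstream") is re-exported by tests.remotes per the diffs.
assert main(["remote", "add", TEST_REMOTE, get_local_url()]) == 0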
