diff --git a/dvc/config.py b/dvc/config.py index bd569864cb..f3621d907b 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -153,6 +153,8 @@ class Config(object): # pylint: disable=too-many-instance-attributes CONFIG = "config" CONFIG_LOCAL = "config.local" + CREDENTIALPATH = "credentialpath" + LEVEL_LOCAL = 0 LEVEL_REPO = 1 LEVEL_GLOBAL = 2 @@ -221,7 +223,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes # backward compatibility SECTION_AWS = "aws" SECTION_AWS_STORAGEPATH = "storagepath" - SECTION_AWS_CREDENTIALPATH = "credentialpath" + SECTION_AWS_CREDENTIALPATH = CREDENTIALPATH SECTION_AWS_ENDPOINT_URL = "endpointurl" SECTION_AWS_LIST_OBJECTS = "listobjects" SECTION_AWS_REGION = "region" @@ -244,7 +246,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes # backward compatibility SECTION_GCP = "gcp" SECTION_GCP_STORAGEPATH = SECTION_AWS_STORAGEPATH - SECTION_GCP_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH + SECTION_GCP_CREDENTIALPATH = CREDENTIALPATH SECTION_GCP_PROJECTNAME = "projectname" SECTION_GCP_SCHEMA = { SECTION_GCP_STORAGEPATH: str, @@ -261,6 +263,10 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_OSS_ACCESS_KEY_ID = "oss_key_id" SECTION_OSS_ACCESS_KEY_SECRET = "oss_key_secret" SECTION_OSS_ENDPOINT = "oss_endpoint" + # GDrive options + SECTION_GDRIVE_CLIENT_ID = "gdrive_client_id" + SECTION_GDRIVE_CLIENT_SECRET = "gdrive_client_secret" + SECTION_GDRIVE_USER_CREDENTIALS_FILE = "gdrive_user_credentials_file" SECTION_REMOTE_REGEX = r'^\s*remote\s*"(?P.*)"\s*$' SECTION_REMOTE_FMT = 'remote "{}"' @@ -277,7 +283,7 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_REMOTE_URL: str, Optional(SECTION_AWS_REGION): str, Optional(SECTION_AWS_PROFILE): str, - Optional(SECTION_AWS_CREDENTIALPATH): str, + Optional(CREDENTIALPATH): str, Optional(SECTION_AWS_ENDPOINT_URL): str, Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA, 
@decorator
def _wrap_pydrive_retriable(call):
    """Invoke ``call`` and flag transient Google API failures as retriable.

    An exception whose text contains ``HttpError <code>`` for one of the
    status codes 403, 500, 502, 503 or 504 is replaced by
    ``GDriveRetriableError``. Any other exception propagates unchanged.
    On success the value returned by ``call`` is passed through.
    """
    try:
        outcome = call()
    except Exception as exception:
        failure_text = str(exception)
        for status in ("403", "500", "502", "503", "504"):
            if "HttpError {}".format(status) in failure_text:
                raise GDriveRetriableError(msg="Google API request failed")
        # Not one of the listed statuses: re-raise the original error.
        raise
    return outcome
def upload_file(self, item, no_progress_bar, from_file, progress_name):
    """Upload the local file ``from_file`` as the content of ``item``.

    Args:
        item: drive file object; its ``content`` attribute is set to the
            opened stream and its ``Upload()`` method is called.
        no_progress_bar: when false, reads are routed through
            ``TrackFileReadProgress`` so that a progress bar is shown.
        from_file: path of the local file to send.
        progress_name: label handed to the progress wrapper.
    """
    with open(from_file, "rb") as opened_file:
        tracker = None
        if not no_progress_bar:
            tracker = TrackFileReadProgress(progress_name, opened_file)
        try:
            # Content is attached only for non-empty files; an empty file
            # is uploaded without a content stream.
            # NOTE(review): presumably PyDrive cannot handle a content
            # stream for a zero-byte file - confirm.
            if os.stat(from_file).st_size:
                item.content = opened_file if tracker is None else tracker
            item.Upload()
        finally:
            # The ``with`` statement only closes the raw file object, so
            # the progress wrapper must be closed explicitly or its bar is
            # never finalized.
            if tracker is not None:
                tracker.close()
@property
@wrap_with(threading.RLock())
def drive(self):
    """Return the authenticated ``GoogleDrive`` client, built on first use.

    The first access configures PyDrive through
    ``GoogleAuth.DEFAULT_SETTINGS`` (client id/secret taken from this
    remote, credentials saved at ``self.gdrive_user_credentials_path``),
    runs the command-line auth flow, resolves ``self.root_id`` (creating
    the remote root path if needed) and fills the root-directory caches.
    Later accesses return the cached client. Access is serialized with a
    re-entrant lock.

    When the ``GDRIVE_USER_CREDENTIALS_DATA`` environment variable is set,
    its value is written to the credentials file before authenticating and
    the file is removed afterwards, even if authentication fails.
    """
    if not hasattr(self, "_gdrive"):
        from pydrive.auth import GoogleAuth
        from pydrive.drive import GoogleDrive

        credentials_data = os.getenv(
            RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA
        )
        if credentials_data:
            with open(
                self.gdrive_user_credentials_path, "w"
            ) as credentials_file:
                credentials_file.write(credentials_data)

        try:
            # NOTE: DEFAULT_SETTINGS is set on the GoogleAuth class itself,
            # so these values are shared by every GoogleAuth object created
            # afterwards in this process.
            GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings"
            GoogleAuth.DEFAULT_SETTINGS["client_config"] = {
                "client_id": self.gdrive_client_id,
                "client_secret": self.gdrive_client_secret,
                "auth_uri": "https://accounts.google.com/o/oauth2/auth",
                "token_uri": "https://oauth2.googleapis.com/token",
                "revoke_uri": "https://oauth2.googleapis.com/revoke",
                "redirect_uri": "",
            }
            GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True
            GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file"
            GoogleAuth.DEFAULT_SETTINGS[
                "save_credentials_file"
            ] = self.gdrive_user_credentials_path
            GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True
            GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [
                "https://www.googleapis.com/auth/drive",
                "https://www.googleapis.com/auth/drive.appdata",
            ]

            # Pass non existent settings path to force DEFAULT_SETTINGS
            # loading
            gauth = GoogleAuth(settings_file="")
            gauth.CommandLineAuth()
        finally:
            # The file holds user credentials copied from the environment;
            # never leave it behind, including when auth raised.
            if credentials_data:
                os.remove(self.gdrive_user_credentials_path)

        # ``_gdrive`` must be assigned before the root is resolved: the
        # lookup below goes through ``self.drive`` again, and the
        # re-entrant lock lets that nested access return this client.
        self._gdrive = GoogleDrive(gauth)

        self.root_id = self.get_remote_id(self.path_info, create=True)
        self._cached_dirs, self._cached_ids = self.cache_root_dirs()

    return self._gdrive
def resolve_remote_item_from_path(self, parents_ids, path_parts, create):
    """Walk ``path_parts`` below ``parents_ids`` and return the last item.

    Each component is looked up with ``get_remote_item`` among the current
    parents; the item found becomes the only parent for the next
    component.

    Args:
        parents_ids: ids of the folders in which the first component is
            searched.
        path_parts: path components to resolve, outermost first.
        create: when true, a missing component is created as a folder
            under the first current parent instead of aborting.

    Returns:
        The item for the final component, or ``None`` when a component is
        missing and ``create`` is false, or when ``path_parts`` is empty.
    """
    # Bound up front so that an empty ``path_parts`` yields None instead of
    # failing on an unassigned local.
    item = None
    for path_part in path_parts:
        item = self.get_remote_item(path_part, parents_ids)
        if not item:
            if not create:
                return None
            item = self.create_remote_dir(parents_ids[0], path_part)
        parents_ids = [item["id"]]
    return item
def all(self):
    """Yield the checksum of every cache file under the cached root dirs.

    One listing query is issued for all directory ids in
    ``self.cached_ids``. For each returned file the path
    ``<directory title>/<file title>`` is converted with
    ``path_to_checksum``; paths for which that raises ``ValueError`` are
    skipped. Nothing is queried when no directories are cached.
    """
    if not self.cached_ids:
        return

    # Group the parent alternatives explicitly so that the trashed filter
    # applies to every directory, not only to the last ``or`` operand.
    query = "({})".format(
        " or ".join(
            "'{}' in parents".format(dir_id) for dir_id in self.cached_ids
        )
    )
    query += " and trashed=false"

    for file1 in self.gdrive_list_item(query):
        parent_id = file1["parents"][0]["id"]
        path = posixpath.join(self.cached_ids[parent_id], file1["title"])
        # Keep the ``yield`` outside the ``try`` so that only errors from
        # the conversion itself are treated as "not a cache file".
        try:
            checksum = self.path_to_checksum(path)
        except ValueError:
            # We ignore all the non-cache looking files
            logger.debug('Ignoring path as "non-cache looking"')
            continue
        yield checksum
def read(self, size=-1):
    """Read up to ``size`` bytes from the wrapped file and report progress.

    Args:
        size: maximum number of bytes to read; passed straight to the
            wrapped file object's ``read``.

    Returns:
        The data returned by the wrapped file object.
    """
    chunk = self.fobj.read(size)
    # Advance by what was actually read: the final chunk is usually shorter
    # than ``size`` and reads at end-of-file return nothing.
    self.tqdm.update(len(chunk))
    return chunk
"2fy_HyzSwkxkGzEken7hThXb" + def _should_test_aws(): do_test = env2bool("DVC_TEST_AWS", undefined=None) @@ -70,6 +76,13 @@ def _should_test_aws(): return False +def _should_test_gdrive(): + if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA): + return True + + return False + + def _should_test_gcp(): do_test = env2bool("DVC_TEST_GCP", undefined=None) if do_test is not None: @@ -202,6 +215,10 @@ def get_aws_url(): return "s3://" + get_aws_storagepath() +def get_gdrive_url(): + return "gdrive://root/" + str(uuid.uuid4()) + + def get_gcp_storagepath(): return TEST_GCP_REPO_BUCKET + "/" + str(uuid.uuid4()) @@ -375,6 +392,35 @@ def _get_cloud_class(self): return RemoteS3 +class TestRemoteGDrive(TestDataCloudBase): + def _should_test(self): + return _should_test_gdrive() + + def _setup_cloud(self): + self._ensure_should_run() + + repo = self._get_url() + + config = copy.deepcopy(TEST_CONFIG) + config[TEST_SECTION][Config.SECTION_REMOTE_URL] = repo + config[TEST_SECTION][ + Config.SECTION_GDRIVE_CLIENT_ID + ] = TEST_GDRIVE_CLIENT_ID + config[TEST_SECTION][ + Config.SECTION_GDRIVE_CLIENT_SECRET + ] = TEST_GDRIVE_CLIENT_SECRET + self.dvc.config.config = config + self.cloud = DataCloud(self.dvc) + + self.assertIsInstance(self.cloud.get_remote(), self._get_cloud_class()) + + def _get_url(self): + return get_gdrive_url() + + def _get_cloud_class(self): + return RemoteGDrive + + class TestRemoteGS(TestDataCloudBase): def _should_test(self): return _should_test_gcp() @@ -621,6 +667,36 @@ def _test(self): self._test_cloud(TEST_REMOTE) +class TestRemoteGDriveCLI(TestDataCloudCLIBase): + def _should_test(self): + return _should_test_gdrive() + + def _test(self): + url = get_gdrive_url() + + self.main(["remote", "add", TEST_REMOTE, url]) + self.main( + [ + "remote", + "modify", + TEST_REMOTE, + Config.SECTION_GDRIVE_CLIENT_ID, + TEST_GDRIVE_CLIENT_ID, + ] + ) + self.main( + [ + "remote", + "modify", + TEST_REMOTE, + Config.SECTION_GDRIVE_CLIENT_SECRET, + 
@pytest.fixture
def gdrive(repo):
    """Provide a ``RemoteGDrive`` pointed at ``gdrive://root/data``.

    ``RemoteGDrive.__init__`` calls ``init_drive``, which raises
    ``DvcException`` unless both the client id and the client secret are
    present in the config, so placeholder values are supplied here.
    """
    config = {
        "url": "gdrive://root/data",
        "gdrive_client_id": "test-client-id",
        "gdrive_client_secret": "test-client-secret",
    }
    return RemoteGDrive(None, config)