Skip to content

Commit

Permalink
gdrive: a pair of cleanup refactors (upstream branch) (#3382)
Browse files Browse the repository at this point in the history
* gdrive: minor clean using funcy 1.14 features

Employ @Retry(filter_errors) and @wrap_prop() to pack things up a bit.

* gdrive: refactor dirs/ids/root_id cache and their uses

- cache root transparently
- make cache dirs and ids symmetric, i.e. both using full paths
- implement .list_cache_paths() instead of .all() like in the rest of
the remotes
  • Loading branch information
Suor authored Feb 23, 2020
1 parent 5101bf1 commit 3a5e588
Showing 1 changed file with 100 additions and 155 deletions.
255 changes: 100 additions & 155 deletions dvc/remote/gdrive.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from collections import defaultdict
import os
import posixpath
import logging
import re
import threading
from urllib.parse import urlparse

from funcy import retry, compose, decorator, wrap_with
from funcy import retry, wrap_with, wrap_prop, cached_property
from funcy.py3 import cat

from dvc.progress import Tqdm
Expand All @@ -19,10 +20,6 @@
FOLDER_MIME_TYPE = "application/vnd.google-apps.folder"


class GDriveRetriableError(DvcException):
pass


class GDrivePathNotFound(DvcException):
def __init__(self, path_info):
super().__init__("Google Drive path '{}' not found.".format(path_info))
Expand All @@ -41,30 +38,18 @@ def __init__(self, path):
)


@decorator
def _wrap_pydrive_retriable(call):
def gdrive_retry(func):
from pydrive2.files import ApiRequestError

try:
result = call()
except ApiRequestError as exception:
retry_codes = ["403", "500", "502", "503", "504"]
if any(
"HttpError {}".format(code) in str(exception)
for code in retry_codes
):
raise GDriveRetriableError("Google API request failed")
raise
return result
retry_re = re.compile(r"HttpError (403|500|502|503|504)")


gdrive_retry = compose(
# 15 tries, start at 0.5s, multiply by golden ratio, cap at 20s
retry(
15, GDriveRetriableError, timeout=lambda a: min(0.5 * 1.618 ** a, 20)
),
_wrap_pydrive_retriable,
)
return retry(
15,
timeout=lambda a: min(0.5 * 1.618 ** a, 20),
errors=ApiRequestError,
filter_errors=lambda exc: retry_re.search(str(exc)),
)(func)


class GDriveURLInfo(CloudURLInfo):
Expand Down Expand Up @@ -126,120 +111,91 @@ def __init__(self, repo, config):
)

self._list_params = None
self._gdrive = None

self._cache_initialized = False
self._remote_root_id = None
self._cached_dirs = None
self._cached_ids = None

@property
@wrap_with(threading.RLock())
@wrap_prop(threading.RLock())
@cached_property
def drive(self):
from pydrive2.auth import RefreshError
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
with open(
self._gdrive_user_credentials_path, "w"
) as credentials_file:
credentials_file.write(
os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
)

if not self._gdrive:
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings"
GoogleAuth.DEFAULT_SETTINGS["client_config"] = {
"client_id": self._client_id,
"client_secret": self._client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"revoke_uri": "https://oauth2.googleapis.com/revoke",
"redirect_uri": "",
}
GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True
GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file"
GoogleAuth.DEFAULT_SETTINGS[
"save_credentials_file"
] = self._gdrive_user_credentials_path
GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True
GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.appdata",
]

# Pass non existent settings path to force DEFAULT_SETTINGS loading
gauth = GoogleAuth(settings_file="")

try:
gauth.CommandLineAuth()
except RefreshError as exc:
raise GDriveAccessTokenRefreshError from exc
except KeyError as exc:
raise GDriveMissedCredentialKeyError(
self._gdrive_user_credentials_path
) from exc
# Handle pydrive2.auth.AuthenticationError and other auth failures
except Exception as exc:
raise DvcException("Google Drive authentication failed") from exc
finally:
if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
with open(
self._gdrive_user_credentials_path, "w"
) as credentials_file:
credentials_file.write(
os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA)
)

GoogleAuth.DEFAULT_SETTINGS["client_config_backend"] = "settings"
GoogleAuth.DEFAULT_SETTINGS["client_config"] = {
"client_id": self._client_id,
"client_secret": self._client_secret,
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"revoke_uri": "https://oauth2.googleapis.com/revoke",
"redirect_uri": "",
}
GoogleAuth.DEFAULT_SETTINGS["save_credentials"] = True
GoogleAuth.DEFAULT_SETTINGS["save_credentials_backend"] = "file"
GoogleAuth.DEFAULT_SETTINGS[
"save_credentials_file"
] = self._gdrive_user_credentials_path
GoogleAuth.DEFAULT_SETTINGS["get_refresh_token"] = True
GoogleAuth.DEFAULT_SETTINGS["oauth_scope"] = [
"https://www.googleapis.com/auth/drive",
"https://www.googleapis.com/auth/drive.appdata",
]

# Pass non existent settings path to force DEFAULT_SETTINGS loading
gauth = GoogleAuth(settings_file="")

try:
gauth.CommandLineAuth()
except RefreshError as exc:
raise GDriveAccessTokenRefreshError from exc
except KeyError as exc:
raise GDriveMissedCredentialKeyError(
self._gdrive_user_credentials_path
) from exc
# Handle pydrive2.auth.AuthenticationError and other auth failures
except Exception as exc:
raise DvcException(
"Google Drive authentication failed"
) from exc
finally:
if os.getenv(RemoteGDrive.GDRIVE_USER_CREDENTIALS_DATA):
os.remove(self._gdrive_user_credentials_path)

self._gdrive = GoogleDrive(gauth)

return self._gdrive
os.remove(self._gdrive_user_credentials_path)

@wrap_with(threading.RLock())
def _initialize_cache(self):
if self._cache_initialized:
return
return GoogleDrive(gauth)

@wrap_prop(threading.RLock())
@cached_property
def cache(self):
cache = {"dirs": defaultdict(list), "ids": {}}

cache["root_id"] = self._get_remote_id(self.path_info)
cache["dirs"][self.path_info.path] = [cache["root_id"]]
self._cache_path(self.path_info.path, cache["root_id"], cache)

cached_dirs = {}
cached_ids = {}
self._remote_root_id = self._get_remote_id(self.path_info)
for dir1 in self.gdrive_list_item(
"'{}' in parents and trashed=false".format(self._remote_root_id)
for item in self.gdrive_list_item(
"'{}' in parents and trashed=false".format(cache["root_id"])
):
remote_path = posixpath.join(self.path_info.path, dir1["title"])
cached_dirs.setdefault(remote_path, []).append(dir1["id"])
cached_ids[dir1["id"]] = dir1["title"]

self._cached_dirs = cached_dirs
self._cached_ids = cached_ids
self._cache_initialized = True

@property
def cached_dirs(self):
if not self._cache_initialized:
self._initialize_cache()
return self._cached_dirs

@property
def cached_ids(self):
if not self._cache_initialized:
self._initialize_cache()
return self._cached_ids

@property
def remote_root_id(self):
if not self._cache_initialized:
self._initialize_cache()
return self._remote_root_id

@property
remote_path = (self.path_info / item["title"]).path
self._cache_path(remote_path, item["id"], cache)

return cache

def _cache_path(self, remote_path, remote_id, cache=None):
cache = cache or self.cache
cache["dirs"][remote_path].append(remote_id)
cache["ids"][remote_id] = remote_path

@cached_property
def list_params(self):
if not self._list_params:
params = {"corpora": "default"}
if self._bucket != "root" and self._bucket != "appDataFolder":
params["driveId"] = self._get_remote_drive_id(self._bucket)
params["corpora"] = "drive"
self._list_params = params
return self._list_params
params = {"corpora": "default"}
if self._bucket != "root" and self._bucket != "appDataFolder":
params["driveId"] = self._get_remote_drive_id(self._bucket)
params["corpora"] = "drive"
return params

@gdrive_retry
def gdrive_upload_file(
Expand Down Expand Up @@ -300,16 +256,14 @@ def gdrive_list_item(self, query):

@wrap_with(threading.RLock())
def gdrive_create_dir(self, parent_id, title, remote_path):
if parent_id == self.remote_root_id:
cached = self.cached_dirs.get(remote_path, [])
if cached:
return cached[0]
cached = self.cache["dirs"].get(remote_path)
if cached:
return cached[0]

item = self._create_remote_dir(parent_id, title)

if parent_id == self.remote_root_id:
self.cached_dirs.setdefault(remote_path, []).append(item["id"])
self.cached_ids[item["id"]] = item["title"]
if parent_id == self.cache["root_id"]:
self._cache_path(remote_path, item["id"])

return item["id"]

Expand Down Expand Up @@ -362,10 +316,8 @@ def _get_remote_drive_id(self, remote_id):
def _get_cached_remote_ids(self, path):
if not path:
return [self._bucket]
if self._cache_initialized:
if path == self.path_info.path:
return [self.remote_root_id]
return self.cached_dirs.get(path, [])
if "cache" in self.__dict__:
return self.cache["dirs"].get(path, [])
return []

def _path_to_remote_ids(self, path, create):
Expand Down Expand Up @@ -416,25 +368,18 @@ def _download(self, from_info, to_file, name, no_progress_bar):
file_id = self._get_remote_id(from_info)
self.gdrive_download_file(file_id, to_file, name, no_progress_bar)

def all(self):
if not self.cached_ids:
def list_cache_paths(self):
if not self.cache["ids"]:
return

query = "({})".format(
" or ".join(
"'{}' in parents".format(dir_id) for dir_id in self.cached_ids
)
parents_query = " or ".join(
"'{}' in parents".format(dir_id) for dir_id in self.cache["ids"]
)
query = "({}) and trashed=false".format(parents_query)

query += " and trashed=false"
for file1 in self.gdrive_list_item(query):
parent_id = file1["parents"][0]["id"]
path = posixpath.join(self.cached_ids[parent_id], file1["title"])
try:
yield self.path_to_checksum(path)
except ValueError:
# We ignore all the non-cache looking files
logger.debug('Ignoring path as "non-cache looking"')
for item in self.gdrive_list_item(query):
parent_id = item["parents"][0]["id"]
yield posixpath.join(self.cache["ids"][parent_id], item["title"])

def remove(self, path_info):
remote_id = self._get_remote_id(path_info)
Expand Down

0 comments on commit 3a5e588

Please sign in to comment.