Skip to content

Commit

Permalink
Merge pull request #2809 from casperdcl/gs-progress
Browse files Browse the repository at this point in the history
GS progress for push & pull
  • Loading branch information
efiop authored Nov 21, 2019
2 parents 5c5d1ed + b2f7c84 commit 7bc7a57
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 28 deletions.
79 changes: 62 additions & 17 deletions dvc/remote/gs.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import unicode_literals
from __future__ import unicode_literals, division

import logging
from datetime import timedelta
from functools import wraps
import io
import os.path

from funcy import cached_property

from dvc.config import Config
from dvc.exceptions import DvcException
from dvc.path_info import CloudURLInfo
from dvc.progress import Tqdm
from dvc.remote.base import RemoteBASE
from dvc.scheme import Schemes
from dvc.utils.compat import FileNotFoundError # skipcq: PYL-W0622
Expand All @@ -20,35 +23,57 @@ def dynamic_chunk_size(func):
# Inner function of the `dynamic_chunk_size(func)` decorator: it calls `func`
# with an explicit `chunk_size` keyword and, on `ConnectionError`, retries with
# a smaller chunk size until no smaller size is left.
# NOTE(review): this span is a flattened diff view (indentation and +/- markers
# are lost); several statements below appear twice, apparently as the removed
# and the added version of the same line -- confirm against the real file.
@wraps(func)
def wrapper(*args, **kwargs):
# Imported inside the wrapper rather than at module level.
import requests
# NOTE(review): the next two imports look like a removed/added pair; once the
# multiplier is hard-coded below, `_DEFAULT_CHUNKSIZE` is no longer referenced.
from google.cloud.storage.blob import Blob, _DEFAULT_CHUNKSIZE
from google.cloud.storage.blob import Blob

# Default chunk size for gs is 100M, which might be too much for
# particular network (see [1]). So if we are getting ConnectionError,
# we should try lowering the chunk size until we reach the minimum
# allowed chunk size of 256K. Also note that `chunk_size` must be a
# multiple of 256K per the API specification.
# `ConnectionError` may be due to too large `chunk_size`
# (see [#2572]) so try halving on error.
# Note: start with 40 * [default: 256K] = 10M.
# Note: must be multiple of 256K.
#
# [1] https://github.com/iterative/dvc/issues/2572
# [#2572]: https://github.com/iterative/dvc/issues/2572

# NOTE(review): removed/added pair -- the first line derives the starting
# multiplier from the library default, the second fixes it at 40.
# skipcq: PYL-W0212
multiplier = int(_DEFAULT_CHUNKSIZE / Blob._CHUNK_SIZE_MULTIPLE)
multiplier = 40
while True:
try:
# The chunk size is always an integer multiple of
# `Blob._CHUNK_SIZE_MULTIPLE`, as the API note above requires.
# skipcq: PYL-W0212
chunk_size = Blob._CHUNK_SIZE_MULTIPLE * multiplier
return func(*args, chunk_size=chunk_size, **kwargs)
except requests.exceptions.ConnectionError:
# Halve the multiplier with integer truncation.
# NOTE(review): the two lines below look like a removed/added pair; if
# both were live the multiplier would be quartered per retry -- confirm.
multiplier = int(multiplier / 2)
multiplier //= 2
# A multiplier of zero means there is no smaller chunk size to try, so the
# ConnectionError currently being handled is re-raised to the caller.
if not multiplier:
raise

# Returned by the enclosing decorator in place of `func`.
return wrapper


# Upload the local file `from_file` to the blob at `to_info.path` in `bucket`.
# The `dynamic_chunk_size` decorator supplies the `chunk_size` keyword.
# NOTE(review): flattened diff view -- the first `def` and its two body lines
# appear to be the removed implementation, and the multi-line `def` that follows
# the added one (with progress reporting). Confirm against the real file.
@dynamic_chunk_size
def _upload_to_bucket(bucket, from_file, to_info, **kwargs):
blob = bucket.blob(to_info.path, **kwargs)
blob.upload_from_filename(from_file)
# Parameters of the progress-reporting version:
#   bucket          -- object providing `.blob(path, chunk_size=...)`
#   from_file       -- path of the local file to read
#   to_info         -- destination; only `to_info.path` is used here
#   chunk_size      -- passed to `bucket.blob` and used as the default read size
#   name            -- progress-bar description; falls back to `to_info.path`
#   no_progress_bar -- passed to Tqdm as `disable`; defaults to True
def _upload_to_bucket(
bucket,
from_file,
to_info,
chunk_size=None,
name=None,
no_progress_bar=True,
):
blob = bucket.blob(to_info.path, chunk_size=chunk_size)
# The bar total is the size of the local file in bytes.
with Tqdm(
desc=name or to_info.path,
total=os.path.getsize(from_file),
bytes=True,
disable=no_progress_bar,
) as pbar:
with io.open(from_file, mode="rb") as fobj:
# Keep a reference to the real `read` before it is shadowed below.
raw_read = fobj.read

# Replacement `read`: delegates to the real one and advances the bar by the
# number of bytes actually returned (nothing is added for an empty read).
def read(size=chunk_size):
res = raw_read(size)
if res:
pbar.update(len(res))
return res

# Shadow `read` on this file object so reads made during the upload go
# through the counting wrapper (presumably `upload_from_file` reads via
# `fobj.read` -- TODO confirm against the client library).
fobj.read = read
blob.upload_from_file(fobj)


class RemoteGS(RemoteBASE):
Expand Down Expand Up @@ -123,14 +148,34 @@ def exists(self, path_info):
paths = set(self._list_paths(path_info.bucket, path_info.path))
return any(path_info.path == path for path in paths)

# Upload `from_file` to the bucket named by `to_info.bucket`, delegating the
# transfer to the module-level `_upload_to_bucket`.
# NOTE(review): flattened diff view -- the first `def` line and the one-line
# `_upload_to_bucket(bucket, from_file, to_info)` call appear to be the removed
# version; the second `def` and the multi-line call the added version, which
# forwards `name` and `no_progress_bar` instead of swallowing extra kwargs.
def _upload(self, from_file, to_info, **_kwargs):
def _upload(self, from_file, to_info, name=None, no_progress_bar=True):
bucket = self.gs.bucket(to_info.bucket)
_upload_to_bucket(bucket, from_file, to_info)
_upload_to_bucket(
bucket,
from_file,
to_info,
name=name,
no_progress_bar=no_progress_bar,
)

# Download the blob at `from_info.path` in bucket `from_info.bucket` to the
# local path `to_file`.
# NOTE(review): flattened diff view -- the first `def` line and the
# `blob.download_to_filename(to_file)` call appear to be the removed version;
# the second `def` and the Tqdm-wrapped download the added version.
def _download(self, from_info, to_file, **_kwargs):
def _download(self, from_info, to_file, name=None, no_progress_bar=True):
bucket = self.gs.bucket(from_info.bucket)
# NOTE(review): if `get_blob` can return None for a missing object, the
# `blob.size` access below would raise AttributeError -- confirm.
blob = bucket.get_blob(from_info.path)
blob.download_to_filename(to_file)
# The bar total is the blob's reported size; the description falls back to the
# remote path when `name` is not given; `no_progress_bar` is passed as `disable`.
with Tqdm(
desc=name or from_info.path,
total=blob.size,
bytes=True,
disable=no_progress_bar,
) as pbar:
with io.open(to_file, mode="wb") as fobj:
# Keep a reference to the real `write` before it is shadowed below.
raw_write = fobj.write

# Replacement `write`: delegates to the real one, then advances the bar by
# the length of the data passed in. Unlike the real `write`, it returns None.
def write(byte_string):
raw_write(byte_string)
pbar.update(len(byte_string))

# Shadow `write` on this file object so writes made during the download go
# through the counting wrapper (presumably `download_to_file` writes via
# `fobj.write` -- TODO confirm against the client library).
fobj.write = write
blob.download_to_file(fobj)

def _generate_download_url(self, path_info, expires=3600):
expiration = timedelta(seconds=int(expires))
Expand Down
12 changes: 1 addition & 11 deletions tests/unit/remote/test_gs.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,4 @@ def upload(chunk_size=None):
with pytest.raises(requests.exceptions.ConnectionError):
upload()

assert chunk_sizes == [
104857600,
52428800,
26214400,
13107200,
6553600,
3145728,
1572864,
786432,
262144,
]
assert chunk_sizes == [10485760, 5242880, 2621440, 1310720, 524288, 262144]

0 comments on commit 7bc7a57

Please sign in to comment.