diff --git a/smart_open/gcs.py b/smart_open/gcs.py index dd33ae39..1e642d87 100644 --- a/smart_open/gcs.py +++ b/smart_open/gcs.py @@ -60,7 +60,8 @@ _WHENCE_CHOICES = (START, CURRENT, END) -_SUCCESSFUL_STATUS_CODES = (200, 201) +_UPLOAD_INCOMPLETE_STATUS_CODE = 308 +_UPLOAD_COMPLETE_STATUS_CODES = (200, 201) def _make_range_string(start, stop=None, end=_UNKNOWN_FILE_SIZE): @@ -90,6 +91,18 @@ def __init__(self, message, status_code, text): self.status_code = status_code self.text = text + @classmethod + def from_response(cls, response, part_num, content_length, total_size, headers): + status_code = response.status_code + response_text = response.text + total_size_gb = total_size / 1024.0 ** 3 + + msg = ( + "upload failed (status code: %(status_code)d, response text: %(response_text)s), " + "part #%(part_num)d, %(total_size)d bytes (total %(total_size_gb).3fGB), headers: %(headers)r" + ) % locals() + return cls(msg, response.status_code, response.text) + def open( bucket_id, @@ -362,16 +375,8 @@ def __str__(self): return "(%s, %r, %r)" % (self.__class__.__name__, self._bucket.name, self._blob.name) def __repr__(self): - return ( - "%s(" - "bucket=%r, " - "blob=%r, " - "buffer_size=%r)" - ) % ( - self.__class__.__name__, - self._bucket.name, - self._blob.name, - self._current_part_size, + return "%s(bucket=%r, blob=%r, buffer_size=%r)" % ( + self.__class__.__name__, self._bucket.name, self._blob.name, self._current_part_size, ) @@ -399,6 +404,7 @@ def __init__( self._total_size = 0 self._total_parts = 0 + self._bytes_uploaded = 0 self._current_part = io.BytesIO() self._session = google_requests.AuthorizedSession(self._credentials) @@ -421,12 +427,18 @@ def flush(self): # def close(self): logger.debug("closing") - if self._total_size == 0: # empty files - self._upload_empty_part() - if self._current_part.tell(): - self._upload_next_part() + if not self.closed: + if self._total_size == 0: # empty files + self._upload_empty_part() + else: + self._upload_final_part() + self._client = None logger.debug("successfully closed") + @property + def closed(self): + return self._client is None + def writable(self): """Return True if the stream supports writing.""" return True @@ -453,7 +465,13 @@ def write(self, b): self._current_part.write(b) self._total_size += len(b) - if self._current_part.tell() >= self._min_part_size: + # + # If the size of this part is precisely equal to the minimum part size, + # we don't perform the actual write now, and wait until we see more data. + # We do this because the very last part of the upload must be handled slightly + # differently (see comments in the _upload_next_part method). + # + if self._current_part.tell() > self._min_part_size: self._upload_next_part() return len(b) @@ -470,49 +488,101 @@ def terminate(self): # def _upload_next_part(self): part_num = self._total_parts + 1 + + # upload the largest amount possible given GCS's restriction + # of parts being multiples of 256kB, except for the last one + size_of_leftovers = self._current_part.tell() % self._min_part_size + content_length = self._current_part.tell() - size_of_leftovers + + # a final upload of 0 bytes does not work, so we need to guard against this edge case + # this results in occasionally keeping an additional 256kB in the buffer after uploading a part, + # but until this is fixed on Google's end there is no other option + # https://stackoverflow.com/questions/60230631/upload-zero-size-final-part-to-google-cloud-storage-resumable-upload + if size_of_leftovers == 0: + content_length -= _REQUIRED_CHUNK_MULTIPLE + + total_size = self._bytes_uploaded + content_length + + start = self._bytes_uploaded + stop = total_size - 1 + + self._current_part.seek(0) + + headers = { + 'Content-Length': str(content_length), + 'Content-Range': _make_range_string(start, stop, _UNKNOWN_FILE_SIZE), + } + logger.info( - "uploading part #%i, %i bytes (total %.3fGB)", - part_num, - self._current_part.tell(), - self._total_size / 1024.0 ** 3 + "uploading part #%i, %i bytes (total %.3fGB) headers %r", + part_num, content_length, total_size / 1024.0 ** 3, headers, ) - content_length = end = self._current_part.tell() - start = self._total_size - content_length - stop = self._total_size - 1 - self._current_part.seek(0) + response = self._session.put( + self._resumable_upload_url, + data=self._current_part.read(content_length), + headers=headers, + ) + + if response.status_code != _UPLOAD_INCOMPLETE_STATUS_CODE: + raise UploadFailedError.from_response( + response, + part_num, + content_length, + self._total_size, + headers, + ) + logger.debug("upload of part #%i finished" % part_num) + + self._total_parts += 1 + self._bytes_uploaded += content_length + # handle the leftovers + self._current_part = io.BytesIO(self._current_part.read()) + self._current_part.seek(0, io.SEEK_END) + + def _upload_final_part(self): + part_num = self._total_parts + 1 + content_length = self._current_part.tell() + stop = self._total_size - 1 + start = self._bytes_uploaded headers = { 'Content-Length': str(content_length), - 'Content-Range': _make_range_string(start, stop, end) + 'Content-Range': _make_range_string(start, stop, self._total_size), } - response = self._session.put(self._resumable_upload_url, data=self._current_part, headers=headers) - - if response.status_code not in _SUCCESSFUL_STATUS_CODES: - msg = ( - "upload failed (" - "status code: %i" - "response text=%s, " - "part #%i, " - "%i bytes (total %.3fGB)" - ) % ( - response.status_code, - response.text, + + logger.info( + "uploading part #%i, %i bytes (total %.3fGB) headers %r", + part_num, content_length, self._total_size / 1024.0 ** 3, headers, + ) + + self._current_part.seek(0) + + response = self._session.put( + self._resumable_upload_url, + data=self._current_part, + headers=headers, + ) + + if response.status_code not in _UPLOAD_COMPLETE_STATUS_CODES: + raise UploadFailedError.from_response( + response, part_num, - self._current_part.tell(), - self._total_size / 1024.0 ** 3, + content_length, + self._total_size, + headers, ) - raise UploadFailedError(msg, response.status_code, response.text) logger.debug("upload of part #%i finished" % part_num) self._total_parts += 1 + self._bytes_uploaded += content_length self._current_part = io.BytesIO() def _upload_empty_part(self): logger.debug("creating empty file") headers = {'Content-Length': '0'} response = self._session.put(self._resumable_upload_url, headers=headers) - assert response.status_code in _SUCCESSFUL_STATUS_CODES + assert response.status_code in _UPLOAD_COMPLETE_STATUS_CODES self._total_parts += 1 @@ -529,14 +599,6 @@ def __str__(self): return "(%s, %r, %r)" % (self.__class__.__name__, self._bucket.name, self._blob.name) def __repr__(self): - return ( - "%s(" - "bucket=%r, " - "blob=%r, " - "min_part_size=%r)" - ) % ( - self.__class__.__name__, - self._bucket.name, - self._blob.name, - self._min_part_size, + return "%s(bucket=%r, blob=%r, min_part_size=%r)" % ( + self.__class__.__name__, self._bucket.name, self._blob.name, self._min_part_size, ) diff --git a/smart_open/tests/test_gcs.py b/smart_open/tests/test_gcs.py index f021dd78..c955f482 100644 --- a/smart_open/tests/test_gcs.py +++ b/smart_open/tests/test_gcs.py @@ -160,9 +160,9 @@ def delete(self): self._bucket.delete_blob(self) self._exists = False - def download_as_string(self, start=None, end=None): - if start is None: - start = 0 + def download_as_string(self, start=0, end=None): + # mimics Google's API by returning bytes, despite the method name + # https://google-cloud-python.readthedocs.io/en/0.32.0/storage/blobs.html#google.cloud.storage.blob.Blob.download_as_string if end is None: end = self.__contents.tell() self.__contents.seek(start) @@ -171,8 +171,13 @@ def download_as_string(self, start=None, end=None): def exists(self, client=None): return self._exists - def upload_from_string(self, str_): - self.__contents.write(str_) + def upload_from_string(self, data): + # mimics Google's API by accepting bytes or str, despite the method name + # https://google-cloud-python.readthedocs.io/en/0.32.0/storage/blobs.html#google.cloud.storage.blob.Blob.upload_from_string + if isinstance(data, six.string_types): + data = bytes(data) if six.PY2 else bytes(data, 'utf8') + self.__contents = io.BytesIO(data) + self.__contents.seek(0, io.SEEK_END) def write(self, data): self.upload_from_string(data) @@ -296,15 +301,18 @@ class FakeBlobUpload(object): def __init__(self, url, blob): self.url = url self.blob = blob # type: FakeBlob + self._finished = False self.__contents = io.BytesIO() def write(self, data): self.__contents.write(data) def finish(self): - self.__contents.seek(0) - data = self.__contents.read() - self.blob.upload_from_string(data) + if not self._finished: + self.__contents.seek(0) + data = self.__contents.read() + self.blob.upload_from_string(data) + self._finished = True def terminate(self): self.blob.delete() @@ -312,8 +320,9 @@ def terminate(self): class FakeResponse(object): - def __init__(self, status_code=200): + def __init__(self, status_code=200, text=None): self.status_code = status_code + self.text = text class FakeAuthorizedSession(object): @@ -325,12 +334,17 @@ def delete(self, upload_url): upload.terminate() def put(self, url, data=None, headers=None): + upload = self._credentials.client.uploads[url] + if data is not None: - upload = self._credentials.client.uploads[url] - upload.write(data.read()) - if not headers['Content-Range'].endswith(smart_open.gcs._UNKNOWN_FILE_SIZE): - upload.finish() - return FakeResponse() + if hasattr(data, 'read'): + upload.write(data.read()) + else: + upload.write(data) + if not headers.get('Content-Range', '').endswith(smart_open.gcs._UNKNOWN_FILE_SIZE): + upload.finish() + return FakeResponse(200) + return FakeResponse(smart_open.gcs._UPLOAD_INCOMPLETE_STATUS_CODE) @staticmethod def _blob_with_url(url, client): @@ -359,7 +373,7 @@ def test_unfinished_put_does_not_write_to_blob(self): 'Content-Length': str(4), } response = self.session.put(self.upload_url, data, headers=headers) - self.assertEqual(response.status_code, 200) + self.assertEqual(response.status_code, smart_open.gcs._UPLOAD_INCOMPLETE_STATUS_CODE) self.session._blob_with_url(self.upload_url, self.client) blob_contents = self.blob.download_as_string() self.assertEqual(blob_contents, b'') @@ -690,26 +704,86 @@ def test_write_02(self): self.assertEqual(fout.tell(), 14) def test_write_03(self): - """Does gcs multipart chunking work correctly?""" + """Do multiple writes less than the min_part_size work correctly?""" # write + min_part_size = 256 * 1024 smart_open_write = smart_open.gcs.BufferedOutputBase( - BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=256 * 1024 + BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size ) + local_write = io.BytesIO() + with smart_open_write as fout: - fout.write(b"t" * 262141) + first_part = b"t" * 262141 + fout.write(first_part) + local_write.write(first_part) self.assertEqual(fout._current_part.tell(), 262141) - fout.write(b"t\n") + second_part = b"t\n" + fout.write(second_part) + local_write.write(second_part) self.assertEqual(fout._current_part.tell(), 262143) self.assertEqual(fout._total_parts, 0) - fout.write(b"t") - self.assertEqual(fout._current_part.tell(), 0) + third_part = b"t" + fout.write(third_part) + local_write.write(third_part) + self.assertEqual(fout._current_part.tell(), 262144) + self.assertEqual(fout._total_parts, 0) + + fourth_part = b"t" * 1 + fout.write(fourth_part) + local_write.write(fourth_part) + self.assertEqual(fout._current_part.tell(), 1) self.assertEqual(fout._total_parts, 1) # read back the same key and check its content output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME))) - self.assertEqual(output, ["t" * 262142 + '\n', "t"]) + local_write.seek(0) + actual = [line.decode("utf-8") for line in list(local_write)] + self.assertEqual(output, actual) + + def test_write_03a(self): + """Do multiple writes greater than the min_part_size work correctly?""" + # write + min_part_size = 256 * 1024 + smart_open_write = smart_open.gcs.BufferedOutputBase( + BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size + ) + local_write = io.BytesIO() + + with smart_open_write as fout: + for i in range(1, 4): + part = b"t" * (min_part_size + 1) + fout.write(part) + local_write.write(part) + self.assertEqual(fout._current_part.tell(), i) + self.assertEqual(fout._total_parts, i) + + # read back the same key and check its content + output = list(smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME))) + local_write.seek(0) + actual = [line.decode("utf-8") for line in list(local_write)] + self.assertEqual(output, actual) + + def test_write_03b(self): + """Does writing a last chunk size equal to a multiple of the min_part_size work?""" + # write + min_part_size = 256 * 1024 + smart_open_write = smart_open.gcs.BufferedOutputBase( + BUCKET_NAME, WRITE_BLOB_NAME, min_part_size=min_part_size + ) + expected = b"t" * min_part_size * 2 + + with smart_open_write as fout: + fout.write(expected) + self.assertEqual(fout._current_part.tell(), 262144) + self.assertEqual(fout._total_parts, 1) + + # read back the same key and check its content + with smart_open.open("gs://{}/{}".format(BUCKET_NAME, WRITE_BLOB_NAME)) as fin: + output = fin.read().encode('utf-8') + + self.assertEqual(output, expected) def test_write_04(self): """Does writing no data cause key with an empty value to be created?"""