From 59f25f6a713cced1dd83351676373fef219c0cac Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Tue, 9 Apr 2019 19:46:00 +0100 Subject: [PATCH] [AIRFLOW-4255] Replace Discovery based api with client based for GCS (#5054) --- UPDATING.md | 24 + airflow/contrib/hooks/gcs_hook.py | 502 ++++++------------ .../operators/gcs_download_operator.py | 3 +- airflow/models/xcom.py | 231 ++++++++ setup.py | 1 + tests/contrib/hooks/test_gcs_hook.py | 441 +++++---------- 6 files changed, 563 insertions(+), 639 deletions(-) create mode 100644 airflow/models/xcom.py diff --git a/UPDATING.md b/UPDATING.md index bf48ee1a708f..caaf7f9de177 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -22,6 +22,30 @@ under the License. This file documents any backwards-incompatible changes in Airflow and assists users migrating to a new version. +## Airflow Master + +### Changes to GoogleCloudStorageHook + +* the discovery-based api (`googleapiclient.discovery`) used in `GoogleCloudStorageHook` is now replaced by the recommended client based api (`google-cloud-storage`). To know the difference between both the libraries, read https://cloud.google.com/apis/docs/client-libraries-explained. PR: [#5054](https://github.com/apache/airflow/pull/5054) +* as a part of this replacement, the `multipart` & `num_retries` parameters for `GoogleCloudStorageHook.upload` method has been removed: + + **Old**: + ```python + def upload(self, bucket, object, filename, + mime_type='application/octet-stream', gzip=False, + multipart=False, num_retries=0): + ``` + + **New**: + ```python + def upload(self, bucket, object, filename, + mime_type='application/octet-stream', gzip=False): + ``` + + The client library uses multipart upload automatically if the object/blob size is more than 8 MB - [source code](https://github.com/googleapis/google-cloud-python/blob/11c543ce7dd1d804688163bc7895cf592feb445f/storage/google/cloud/storage/blob.py#L989-L997). + +* the `generation` parameter is no longer supported in `GoogleCloudStorageHook.delete` and `GoogleCloudStorageHook.insert_object_acl`. + ## Airflow 1.10.3 ### RedisPy dependency updated to v3 series diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py index cab06b92d5de..5e0da2d1f074 100644 --- a/airflow/contrib/hooks/gcs_hook.py +++ b/airflow/contrib/hooks/gcs_hook.py @@ -17,19 +17,15 @@ # specific language governing permissions and limitations # under the License. # +import gzip as gz +import os +import shutil -from googleapiclient.discovery import build -from googleapiclient.http import MediaFileUpload -from googleapiclient.errors import HttpError +from google.cloud import storage from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook from airflow.exceptions import AirflowException -import gzip as gz -import shutil -import re -import os - class GoogleCloudStorageHook(GoogleCloudBaseHook): """ @@ -37,6 +33,8 @@ class GoogleCloudStorageHook(GoogleCloudBaseHook): connection. """ + _conn = None + def __init__(self, google_cloud_storage_conn_id='google_cloud_default', delegate_to=None): @@ -47,9 +45,10 @@ def get_conn(self): """ Returns a Google Cloud Storage service object. """ - http_authorized = self._authorize() - return build( - 'storage', 'v1', http=http_authorized, cache_discovery=False) + if not self._conn: + self._conn = storage.Client(credentials=self._get_credentials()) + + return self._conn # pylint:disable=redefined-builtin def copy(self, source_bucket, source_object, destination_bucket=None, @@ -83,19 +82,18 @@ def copy(self, source_bucket, source_object, destination_bucket=None, if not source_bucket or not source_object: raise ValueError('source_bucket and source_object cannot be empty.') - service = self.get_conn() - try: - service \ - .objects() \ - .copy(sourceBucket=source_bucket, sourceObject=source_object, - destinationBucket=destination_bucket, - destinationObject=destination_object, body='') \ - .execute() - return True - except HttpError as ex: - if ex.resp['status'] == '404': - return False - raise + client = self.get_conn() + source_bucket = client.get_bucket(source_bucket) + source_object = source_bucket.blob(source_object) + destination_bucket = client.get_bucket(destination_bucket) + destination_object = source_bucket.copy_blob( + blob=source_object, + destination_bucket=destination_bucket, + new_name=destination_object) + + self.log.info('Object %s in bucket %s copied to object %s in bucket %s', + source_object.name, source_bucket.name, + destination_object.name, destination_bucket.name) def rewrite(self, source_bucket, source_object, destination_bucket, destination_object=None): @@ -125,29 +123,30 @@ def rewrite(self, source_bucket, source_object, destination_bucket, if not source_bucket or not source_object: raise ValueError('source_bucket and source_object cannot be empty.') - service = self.get_conn() - request_count = 1 - try: - result = service.objects() \ - .rewrite(sourceBucket=source_bucket, sourceObject=source_object, - destinationBucket=destination_bucket, - destinationObject=destination_object, body='') \ - .execute() - self.log.info('Rewrite request #%s: %s', request_count, result) - while not result['done']: - request_count += 1 - result = service.objects() \ - .rewrite(sourceBucket=source_bucket, sourceObject=source_object, - destinationBucket=destination_bucket, - destinationObject=destination_object, - rewriteToken=result['rewriteToken'], body='') \ - .execute() - self.log.info('Rewrite request #%s: %s', request_count, result) - return True - except HttpError as ex: - if ex.resp['status'] == '404': - return False - raise + client = self.get_conn() + source_bucket = client.get_bucket(bucket_name=source_bucket) + source_object = source_bucket.blob(blob_name=source_object) + destination_bucket = client.get_bucket(bucket_name=destination_bucket) + + token, bytes_rewritten, total_bytes = destination_bucket.blob( + blob_name=destination_object).rewrite( + source=source_object + ) + + self.log.info('Total Bytes: %s | Bytes Written: %s', + total_bytes, bytes_rewritten) + + while token is not None: + token, bytes_rewritten, total_bytes = destination_bucket.blob( + blob_name=destination_object).rewrite( + source=source_object, token=token + ) + + self.log.info('Total Bytes: %s | Bytes Written: %s', + total_bytes, bytes_rewritten) + self.log.info('Object %s in bucket %s copied to object %s in bucket %s', + source_object.name, source_bucket.name, + destination_object, destination_bucket.name) # pylint:disable=redefined-builtin def download(self, bucket, object, filename=None): @@ -161,24 +160,19 @@ def download(self, bucket, object, filename=None): :param filename: If set, a local file path where the file should be written to. :type filename: str """ - service = self.get_conn() - downloaded_file_bytes = service \ - .objects() \ - .get_media(bucket=bucket, object=object) \ - .execute() + client = self.get_conn() + bucket = client.get_bucket(bucket) + blob = bucket.blob(blob_name=object) - # Write the file to local file path, if requested. if filename: - write_argument = 'wb' if isinstance(downloaded_file_bytes, bytes) else 'w' - with open(filename, write_argument) as file_fd: - file_fd.write(downloaded_file_bytes) + blob.download_to_filename(filename) + self.log.info('File downloaded to %s', filename) - return downloaded_file_bytes + return blob.download_as_string() # pylint:disable=redefined-builtin def upload(self, bucket, object, filename, - mime_type='application/octet-stream', gzip=False, - multipart=False, num_retries=0): + mime_type='application/octet-stream', gzip=False): """ Uploads a local file to Google Cloud Storage. @@ -192,16 +186,7 @@ def upload(self, bucket, object, filename, :type mime_type: str :param gzip: Option to compress file for upload :type gzip: bool - :param multipart: If True, the upload will be split into multiple HTTP requests. The - default size is 256MiB per request. Pass a number instead of True to - specify the request size, which must be a multiple of 262144 (256KiB). - :type multipart: bool or int - :param num_retries: The number of times to attempt to re-upload the file (or individual - chunks, in the case of multipart uploads). Retries are attempted - with exponential backoff. - :type num_retries: int """ - service = self.get_conn() if gzip: filename_gz = filename + '.gz' @@ -211,44 +196,15 @@ def upload(self, bucket, object, filename, shutil.copyfileobj(f_in, f_out) filename = filename_gz - try: - if multipart: - if multipart is True: - chunksize = 256 * 1024 * 1024 - else: - chunksize = multipart - - if chunksize % (256 * 1024) > 0 or chunksize < 0: - raise ValueError("Multipart size is not a multiple of 262144 (256KiB)") - - media = MediaFileUpload(filename, mimetype=mime_type, - chunksize=chunksize, resumable=True) - - request = service.objects().insert(bucket=bucket, name=object, media_body=media) - response = None - while response is None: - status, response = request.next_chunk(num_retries=num_retries) - if status: - self.log.info("Upload progress %.1f%%", status.progress() * 100) + client = self.get_conn() + bucket = client.get_bucket(bucket_name=bucket) + blob = bucket.blob(blob_name=object) + blob.upload_from_filename(filename=filename, + content_type=mime_type) - else: - media = MediaFileUpload(filename, mime_type) - - service \ - .objects() \ - .insert(bucket=bucket, name=object, media_body=media) \ - .execute(num_retries=num_retries) - - except HttpError as ex: - if ex.resp['status'] == '404': - return False - raise - - finally: - if gzip: - os.remove(filename) - - return True + if gzip: + os.remove(filename) + self.log.info('File %s uploaded to %s in %s bucket', filename, object, bucket) # pylint:disable=redefined-builtin def exists(self, bucket, object): @@ -261,17 +217,10 @@ def exists(self, bucket, object): storage bucket. :type object: str """ - service = self.get_conn() - try: - service \ - .objects() \ - .get(bucket=bucket, object=object) \ - .execute() - return True - except HttpError as ex: - if ex.resp['status'] == '404': - return False - raise + client = self.get_conn() + bucket = client.get_bucket(bucket_name=bucket) + blob = bucket.blob(blob_name=object) + return blob.exists() # pylint:disable=redefined-builtin def is_updated_after(self, bucket, object, ts): @@ -286,57 +235,41 @@ def is_updated_after(self, bucket, object, ts): :param ts: The timestamp to check against. :type ts: datetime.datetime """ - service = self.get_conn() - try: - response = (service - .objects() - .get(bucket=bucket, object=object) - .execute()) + client = self.get_conn() + bucket = storage.Bucket(client=client, name=bucket) + blob = bucket.get_blob(blob_name=object) + blob.reload() - if 'updated' in response: - import dateutil.parser - import dateutil.tz + blob_update_time = blob.updated - if not ts.tzinfo: - ts = ts.replace(tzinfo=dateutil.tz.tzutc()) + if blob_update_time is not None: + import dateutil.tz - updated = dateutil.parser.parse(response['updated']) - self.log.info("Verify object date: %s > %s", updated, ts) + if not ts.tzinfo: + ts = ts.replace(tzinfo=dateutil.tz.tzutc()) - if updated > ts: - return True + self.log.info("Verify object date: %s > %s", blob_update_time, ts) - except HttpError as ex: - if ex.resp['status'] != '404': - raise + if blob_update_time > ts: + return True return False - def delete(self, bucket, object, generation=None): + def delete(self, bucket, object): """ - Delete an object if versioning is not enabled for the bucket, or if generation - parameter is used. + Deletes an object from the bucket. :param bucket: name of the bucket, where the object resides :type bucket: str :param object: name of the object to delete :type object: str - :param generation: if present, permanently delete the object of this generation - :type generation: str - :return: True if succeeded - """ - service = self.get_conn() - - try: - service \ - .objects() \ - .delete(bucket=bucket, object=object, generation=generation) \ - .execute() - return True - except HttpError as ex: - if ex.resp['status'] == '404': - return False - raise + """ + client = self.get_conn() + bucket = client.get_bucket(bucket_name=bucket) + blob = bucket.blob(blob_name=object) + blob.delete() + + self.log.info('Blob %s deleted.', object) def list(self, bucket, versions=None, maxResults=None, prefix=None, delimiter=None): """ @@ -355,38 +288,32 @@ def list(self, bucket, versions=None, maxResults=None, prefix=None, delimiter=No :type delimiter: str :return: a stream of object names matching the filtering criteria """ - service = self.get_conn() + client = self.get_conn() + bucket = client.get_bucket(bucket) - ids = list() + ids = [] pageToken = None while True: - response = service.objects().list( - bucket=bucket, - versions=versions, - maxResults=maxResults, - pageToken=pageToken, + blobs = bucket.list_blobs( + max_results=maxResults, + page_token=pageToken, prefix=prefix, - delimiter=delimiter - ).execute() + delimiter=delimiter, + versions=versions + ) - if 'prefixes' not in response: - if 'items' not in response: - self.log.info("No items found for prefix: %s", prefix) - break + blob_names = [] + for blob in blobs: + blob_names.append(blob.name) - for item in response['items']: - if item and 'name' in item: - ids.append(item['name']) + prefixes = blobs.prefixes + if prefixes: + ids += list(prefixes) else: - for item in response['prefixes']: - ids.append(item) + ids += blob_names - if 'nextPageToken' not in response: - # no further pages of results, so stop the loop - break - - pageToken = response['nextPageToken'] - if not pageToken: + pageToken = blobs.next_page_token + if pageToken is None: # empty next page token break return ids @@ -404,23 +331,13 @@ def get_size(self, bucket, object): self.log.info('Checking the file size of object: %s in bucket: %s', object, bucket) - service = self.get_conn() - try: - response = service.objects().get( - bucket=bucket, - object=object - ).execute() - - if 'name' in response and response['name'][-1] != '/': - # Remove Directories & Just check size of files - size = response['size'] - self.log.info('The file size of %s is %s bytes.', object, size) - return size - else: - raise ValueError('Object is not a file') - except HttpError as ex: - if ex.resp['status'] == '404': - raise ValueError('Object Not Found') + client = self.get_conn() + bucket = client.get_bucket(bucket_name=bucket) + blob = bucket.get_blob(blob_name=object) + blob.reload() + blob_size = blob.size + self.log.info('The file size of %s is %s bytes.', object, blob_size) + return blob_size def get_crc32c(self, bucket, object): """ @@ -434,20 +351,13 @@ def get_crc32c(self, bucket, object): """ self.log.info('Retrieving the crc32c checksum of ' 'object: %s in bucket: %s', object, bucket) - service = self.get_conn() - try: - response = service.objects().get( - bucket=bucket, - object=object - ).execute() - - crc32c = response['crc32c'] - self.log.info('The crc32c checksum of %s is %s', object, crc32c) - return crc32c - - except HttpError as ex: - if ex.resp['status'] == '404': - raise ValueError('Object Not Found') + client = self.get_conn() + bucket = client.get_bucket(bucket_name=bucket) + blob = bucket.get_blob(blob_name=object) + blob.reload() + blob_crc32c = blob.crc32c + self.log.info('The crc32c checksum of %s is %s', object, blob_crc32c) + return blob_crc32c def get_md5hash(self, bucket, object): """ @@ -461,21 +371,16 @@ def get_md5hash(self, bucket, object): """ self.log.info('Retrieving the MD5 hash of ' 'object: %s in bucket: %s', object, bucket) - service = self.get_conn() - try: - response = service.objects().get( - bucket=bucket, - object=object - ).execute() - - md5hash = response['md5Hash'] - self.log.info('The md5Hash of %s is %s', object, md5hash) - return md5hash - - except HttpError as ex: - if ex.resp['status'] == '404': - raise ValueError('Object Not Found') - + client = self.get_conn() + bucket = client.get_bucket(bucket_name=bucket) + blob = bucket.get_blob(blob_name=object) + blob.reload() + blob_md5hash = blob.md5_hash + self.log.info('The md5Hash of %s is %s', object, blob_md5hash) + return blob_md5hash + + @GoogleCloudBaseHook.catch_http_exception + @GoogleCloudBaseHook.fallback_to_default_project_id def create_bucket(self, bucket_name, resource=None, @@ -525,57 +430,23 @@ def create_bucket(self, :return: If successful, it returns the ``id`` of the bucket. """ - project_id = project_id if project_id is not None else self.project_id - storage_classes = [ - 'MULTI_REGIONAL', - 'REGIONAL', - 'NEARLINE', - 'COLDLINE', - 'STANDARD', # alias for MULTI_REGIONAL/REGIONAL, based on location - ] - self.log.info('Creating Bucket: %s; Location: %s; Storage Class: %s', bucket_name, location, storage_class) - if storage_class not in storage_classes: - raise ValueError( - 'Invalid value ({}) passed to storage_class. Value should be ' - 'one of {}'.format(storage_class, storage_classes)) - - if not re.match('[a-zA-Z0-9]+', bucket_name[0]): - raise ValueError('Bucket names must start with a number or letter.') - - if not re.match('[a-zA-Z0-9]+', bucket_name[-1]): - raise ValueError('Bucket names must end with a number or letter.') - service = self.get_conn() + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket_name) bucket_resource = resource or {} - bucket_resource.update({ - 'name': bucket_name, - 'location': location, - 'storageClass': storage_class - }) - - self.log.info('The Default Project ID is %s', self.project_id) - - if labels is not None: - bucket_resource['labels'] = labels - - try: - response = service.buckets().insert( - project=project_id, - body=bucket_resource - ).execute() - self.log.info('Bucket: %s created successfully.', bucket_name) + for item in bucket_resource: + if item != "name": + bucket._patch_property(name=item, value=resource[item]) - return response['id'] + bucket.storage_class = storage_class + bucket.labels = labels or {} + bucket.create(project=project_id, location=location) + return bucket.id - except HttpError as ex: - raise AirflowException( - 'Bucket creation failed. Error was: {}'.format(ex.content) - ) - - def insert_bucket_acl(self, bucket, entity, role, user_project): + def insert_bucket_acl(self, bucket, entity, role, user_project=None): """ Creates a new ACL entry on the specified bucket. See: https://cloud.google.com/storage/docs/json_api/v1/bucketAccessControls/insert @@ -595,25 +466,17 @@ def insert_bucket_acl(self, bucket, entity, role, user_project): :type user_project: str """ self.log.info('Creating a new ACL entry in bucket: %s', bucket) - service = self.get_conn() - try: - response = service.bucketAccessControls().insert( - bucket=bucket, - body={ - "entity": entity, - "role": role - }, - userProject=user_project - ).execute() - if response: - self.log.info('A new ACL entry created in bucket: %s', bucket) - except HttpError as ex: - raise AirflowException( - 'Bucket ACL entry creation failed. Error was: {}'.format(ex.content) - ) + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket) + bucket.acl.reload() + bucket.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) + if user_project: + bucket.acl.user_project = user_project + bucket.acl.save() - def insert_object_acl(self, bucket, object_name, entity, role, generation, - user_project): + self.log.info('A new ACL entry created in bucket: %s', bucket) + + def insert_object_acl(self, bucket, object_name, entity, role, user_project=None): """ Creates a new ACL entry on the specified object. See: https://cloud.google.com/storage/docs/json_api/v1/objectAccessControls/insert @@ -632,36 +495,26 @@ def insert_object_acl(self, bucket, object_name, entity, role, generation, :param role: The access permission for the entity. Acceptable values are: "OWNER", "READER". :type role: str - :param generation: (Optional) If present, selects a specific revision of this - object (as opposed to the latest version, the default). - :type generation: str :param user_project: (Optional) The project to be billed for this request. Required for Requester Pays buckets. :type user_project: str """ self.log.info('Creating a new ACL entry for object: %s in bucket: %s', object_name, bucket) - service = self.get_conn() - try: - response = service.objectAccessControls().insert( - bucket=bucket, - object=object_name, - body={ - "entity": entity, - "role": role - }, - generation=generation, - userProject=user_project - ).execute() - if response: - self.log.info('A new ACL entry created for object: %s in bucket: %s', - object_name, bucket) - except HttpError as ex: - raise AirflowException( - 'Object ACL entry creation failed. Error was: {}'.format(ex.content) - ) + client = self.get_conn() + bucket = client.bucket(bucket_name=bucket) + blob = bucket.blob(object_name) + # Reload fetches the current ACL from Cloud Storage. + blob.acl.reload() + blob.acl.entity_from_dict(entity_dict={"entity": entity, "role": role}) + if user_project: + blob.acl.user_project = user_project + blob.acl.save() + + self.log.info('A new ACL entry created for object: %s in bucket: %s', + object_name, bucket) - def compose(self, bucket, source_objects, destination_object, num_retries=5): + def compose(self, bucket, source_objects, destination_object): """ Composes a list of existing object into a new object in the same storage bucket @@ -686,28 +539,17 @@ def compose(self, bucket, source_objects, destination_object, num_retries=5): if not bucket or not destination_object: raise ValueError('bucket and destination_object cannot be empty.') - service = self.get_conn() - - dict_source_objects = [{'name': source_object} - for source_object in source_objects] - body = { - 'sourceObjects': dict_source_objects - } - - try: - self.log.info("Composing %s to %s in the bucket %s", - source_objects, destination_object, bucket) - service \ - .objects() \ - .compose(destinationBucket=bucket, - destinationObject=destination_object, - body=body) \ - .execute(num_retries=num_retries) - return True - except HttpError as ex: - if ex.resp['status'] == '404': - return False - raise + self.log.info("Composing %s to %s in the bucket %s", + source_objects, destination_object, bucket) + client = self.get_conn() + bucket = client.get_bucket(bucket) + destination_blob = bucket.blob(destination_object) + destination_blob.compose( + sources=[ + bucket.blob(blob_name=source_object) for source_object in source_objects + ]) + + self.log.info("Completed successfully.") def _parse_gcs_url(gsurl): diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py index 1d168d466072..4c0d117d994a 100644 --- a/airflow/contrib/operators/gcs_download_operator.py +++ b/airflow/contrib/operators/gcs_download_operator.py @@ -21,6 +21,7 @@ from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook from airflow.models import BaseOperator +from airflow.models.xcom import MAX_XCOM_SIZE from airflow.utils.decorators import apply_defaults @@ -82,7 +83,7 @@ def execute(self, context): object=self.object, filename=self.filename) if self.store_to_xcom_key: - if sys.getsizeof(file_bytes) < 48000: + if sys.getsizeof(file_bytes) < MAX_XCOM_SIZE: context['ti'].xcom_push(key=self.store_to_xcom_key, value=file_bytes) else: raise RuntimeError( diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py new file mode 100644 index 000000000000..ca049c837b2e --- /dev/null +++ b/airflow/models/xcom.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import pickle + +from sqlalchemy import Column, Integer, String, Index, LargeBinary, and_ +from sqlalchemy.orm import reconstructor + +from airflow import configuration +from airflow.models.base import Base, ID_LEN +from airflow.utils import timezone +from airflow.utils.db import provide_session +from airflow.utils.helpers import as_tuple +from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.sqlalchemy import UtcDateTime + + +# MAX XCOM Size is 48KB +# https://github.com/apache/airflow/pull/1618#discussion_r68249677 +MAX_XCOM_SIZE = 49344 +XCOM_RETURN_KEY = 'return_value' + + +class XCom(Base, LoggingMixin): + """ + Base class for XCom objects. + """ + __tablename__ = "xcom" + + id = Column(Integer, primary_key=True) + key = Column(String(512)) + value = Column(LargeBinary) + timestamp = Column( + UtcDateTime, default=timezone.utcnow, nullable=False) + execution_date = Column(UtcDateTime, nullable=False) + + # source information + task_id = Column(String(ID_LEN), nullable=False) + dag_id = Column(String(ID_LEN), nullable=False) + + __table_args__ = ( + Index('idx_xcom_dag_task_date', dag_id, task_id, execution_date, unique=False), + ) + + """ + TODO: "pickling" has been deprecated and JSON is preferred. + "pickling" will be removed in Airflow 2.0. + """ + @reconstructor + def init_on_load(self): + enable_pickling = configuration.getboolean('core', 'enable_xcom_pickling') + if enable_pickling: + self.value = pickle.loads(self.value) + else: + try: + self.value = json.loads(self.value.decode('UTF-8')) + except (UnicodeEncodeError, ValueError): + # For backward-compatibility. + # Preventing errors in webserver + # due to XComs mixed with pickled and unpickled. + self.value = pickle.loads(self.value) + + def __repr__(self): + return ''.format( + key=self.key, + task_id=self.task_id, + execution_date=self.execution_date) + + @classmethod + @provide_session + def set( + cls, + key, + value, + execution_date, + task_id, + dag_id, + session=None): + """ + Store an XCom value. + TODO: "pickling" has been deprecated and JSON is preferred. + "pickling" will be removed in Airflow 2.0. + + :return: None + """ + session.expunge_all() + + enable_pickling = configuration.getboolean('core', 'enable_xcom_pickling') + if enable_pickling: + value = pickle.dumps(value) + else: + try: + value = json.dumps(value).encode('UTF-8') + except ValueError: + log = LoggingMixin().log + log.error("Could not serialize the XCOM value into JSON. " + "If you are using pickles instead of JSON " + "for XCOM, then you need to enable pickle " + "support for XCOM in your airflow config.") + raise + + # remove any duplicate XComs + session.query(cls).filter( + cls.key == key, + cls.execution_date == execution_date, + cls.task_id == task_id, + cls.dag_id == dag_id).delete() + + session.commit() + + # insert new XCom + session.add(XCom( + key=key, + value=value, + execution_date=execution_date, + task_id=task_id, + dag_id=dag_id)) + + session.commit() + + @classmethod + @provide_session + def get_one(cls, + execution_date, + key=None, + task_id=None, + dag_id=None, + include_prior_dates=False, + session=None): + """ + Retrieve an XCom value, optionally meeting certain criteria. + TODO: "pickling" has been deprecated and JSON is preferred. + "pickling" will be removed in Airflow 2.0. + + :return: XCom value + """ + filters = [] + if key: + filters.append(cls.key == key) + if task_id: + filters.append(cls.task_id == task_id) + if dag_id: + filters.append(cls.dag_id == dag_id) + if include_prior_dates: + filters.append(cls.execution_date <= execution_date) + else: + filters.append(cls.execution_date == execution_date) + + query = ( + session.query(cls.value).filter(and_(*filters)) + .order_by(cls.execution_date.desc(), cls.timestamp.desc())) + + result = query.first() + if result: + enable_pickling = configuration.getboolean('core', 'enable_xcom_pickling') + if enable_pickling: + return pickle.loads(result.value) + else: + try: + return json.loads(result.value.decode('UTF-8')) + except ValueError: + log = LoggingMixin().log + log.error("Could not deserialize the XCOM value from JSON. " + "If you are using pickles instead of JSON " + "for XCOM, then you need to enable pickle " + "support for XCOM in your airflow config.") + raise + + @classmethod + @provide_session + def get_many(cls, + execution_date, + key=None, + task_ids=None, + dag_ids=None, + include_prior_dates=False, + limit=100, + session=None): + """ + Retrieve an XCom value, optionally meeting certain criteria + TODO: "pickling" has been deprecated and JSON is preferred. + "pickling" will be removed in Airflow 2.0. + """ + filters = [] + if key: + filters.append(cls.key == key) + if task_ids: + filters.append(cls.task_id.in_(as_tuple(task_ids))) + if dag_ids: + filters.append(cls.dag_id.in_(as_tuple(dag_ids))) + if include_prior_dates: + filters.append(cls.execution_date <= execution_date) + else: + filters.append(cls.execution_date == execution_date) + + query = ( + session.query(cls).filter(and_(*filters)) + .order_by(cls.execution_date.desc(), cls.timestamp.desc()) + .limit(limit)) + results = query.all() + return results + + @classmethod + @provide_session + def delete(cls, xcoms, session=None): + if isinstance(xcoms, XCom): + xcoms = [xcoms] + for xcom in xcoms: + if not isinstance(xcom, XCom): + raise TypeError( + 'Expected XCom; received {}'.format(xcom.__class__.__name__) + ) + session.delete(xcom) + session.commit() diff --git a/setup.py b/setup.py index 006de0ad444a..4236950eed1c 100644 --- a/setup.py +++ b/setup.py @@ -179,6 +179,7 @@ def write_version(filename=os.path.join(*['airflow', 'google-cloud-container>=0.1.1', 'google-cloud-language>=1.1.1', 'google-cloud-spanner>=1.7.1', + 'google-cloud-storage~=1.14', 'google-cloud-translate>=1.3.3', 'google-cloud-vision>=0.35.2', 'grpcio-gcp>=0.2.2', diff --git a/tests/contrib/hooks/test_gcs_hook.py b/tests/contrib/hooks/test_gcs_hook.py index faed4db6c0f0..4fb3b76440bb 100644 --- a/tests/contrib/hooks/test_gcs_hook.py +++ b/tests/contrib/hooks/test_gcs_hook.py @@ -24,13 +24,16 @@ import airflow.contrib.hooks.gcs_hook as gcs_hook from airflow.exceptions import AirflowException -from googleapiclient.errors import HttpError from tests.compat import mock +from tests.contrib.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id +from google.cloud import storage +from google.cloud import exceptions BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}' GCS_STRING = 'airflow.contrib.hooks.gcs_hook.{}' EMPTY_CONTENT = ''.encode('utf8') +PROJECT_ID_TEST = 'project-id' class TestGCSHookHelperFunctions(unittest.TestCase): @@ -57,128 +60,73 @@ def test_parse_gcs_url(self): gcs_hook._parse_gcs_url('gs://bucket/'), ('bucket', '')) -class TestGCSBucket(unittest.TestCase): - def test_bucket_name_value(self): - - bad_start_bucket_name = '/testing123' - with self.assertRaises(ValueError): - - gcs_hook.GoogleCloudStorageHook().create_bucket( - bucket_name=bad_start_bucket_name - ) - - bad_end_bucket_name = 'testing123/' - with self.assertRaises(ValueError): - gcs_hook.GoogleCloudStorageHook().create_bucket( - bucket_name=bad_end_bucket_name - ) - - class TestGoogleCloudStorageHook(unittest.TestCase): def setUp(self): - with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__')): + with mock.patch( + GCS_STRING.format('GoogleCloudBaseHook.__init__'), + new=mock_base_gcp_hook_default_project_id, + ): self.gcs_hook = gcs_hook.GoogleCloudStorageHook( - google_cloud_storage_conn_id='test' - ) + google_cloud_storage_conn_id='test') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_exists(self, mock_service): - test_bucket = 'test_bucket' test_object = 'test_object' - (mock_service.return_value.objects.return_value - .get.return_value.execute.return_value) = { - "kind": "storage#object", - # The ID of the object, including the bucket name, - # object name, and generation number. - "id": "{}/{}/1521132662504504".format(test_bucket, test_object), - "name": test_object, - "bucket": test_bucket, - "generation": "1521132662504504", - "contentType": "text/csv", - "timeCreated": "2018-03-15T16:51:02.502Z", - "updated": "2018-03-15T16:51:02.502Z", - "storageClass": "MULTI_REGIONAL", - "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z", - "size": "89", - "md5Hash": "leYUJBUWrRtks1UeUFONJQ==", - "metadata": { - "md5-hash": "95e614241516ad1b64b3551e50538d25" - }, - "crc32c": "xgdNfQ==", - "etag": "CLf4hODk7tkCEAE=" - } + # Given + get_bucket_mock = mock_service.return_value.get_bucket + blob_object = get_bucket_mock.return_value.blob + exists_method = blob_object.return_value.exists + exists_method.return_value = True + # When response = self.gcs_hook.exists(bucket=test_bucket, object=test_object) + # Then self.assertTrue(response) + get_bucket_mock.assert_called_once_with(bucket_name=test_bucket) + blob_object.assert_called_once_with(blob_name=test_object) + exists_method.assert_called_once_with() @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_exists_nonexisting_object(self, mock_service): - test_bucket = 'test_bucket' test_object = 'test_object' - (mock_service.return_value.objects.return_value - .get.return_value.execute.side_effect) = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) + # Given + get_bucket_mock = mock_service.return_value.get_bucket + blob_object = get_bucket_mock.return_value.blob + exists_method = blob_object.return_value.exists + exists_method.return_value = False + # When response = self.gcs_hook.exists(bucket=test_bucket, object=test_object) + # Then self.assertFalse(response) + @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_copy(self, mock_service): + def test_copy(self, mock_service, mock_bucket): source_bucket = 'test-source-bucket' source_object = 'test-source-object' destination_bucket = 'test-dest-bucket' destination_object = 'test-dest-object' - (mock_service.return_value.objects.return_value - .get.return_value.execute.return_value) = { - "kind": "storage#object", - # The ID of the object, including the bucket name, object name, - # and generation number. - "id": "{}/{}/1521132662504504".format( - destination_bucket, destination_object), - "name": destination_object, - "bucket": destination_bucket, - "generation": "1521132662504504", - "contentType": "text/csv", - "timeCreated": "2018-03-15T16:51:02.502Z", - "updated": "2018-03-15T16:51:02.502Z", - "storageClass": "MULTI_REGIONAL", - "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z", - "size": "89", - "md5Hash": "leYUJBUWrRtks1UeUFONJQ==", - "metadata": { - "md5-hash": "95e614241516ad1b64b3551e50538d25" - }, - "crc32c": "xgdNfQ==", - "etag": "CLf4hODk7tkCEAE=" - } - - response = self.gcs_hook.copy( - source_bucket=source_bucket, - source_object=source_object, - destination_bucket=destination_bucket, - destination_object=destination_object - ) + destination_bucket_instance = mock_bucket + source_blob = mock_bucket.blob(source_object) + destination_blob = storage.Blob( + bucket=destination_bucket_instance, + name=destination_object) - self.assertTrue(response) - - @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_copy_failedcopy(self, mock_service): - source_bucket = 'test-source-bucket' - source_object = 'test-source-object' - destination_bucket = 'test-dest-bucket' - destination_object = 'test-dest-object' - - (mock_service.return_value.objects.return_value - .copy.return_value.execute.side_effect) = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) + # Given + get_bucket_mock = mock_service.return_value.get_bucket + get_bucket_mock.return_value = mock_bucket + copy_method = get_bucket_mock.return_value.copy_blob + copy_method.return_value = destination_blob + # When response = self.gcs_hook.copy( source_bucket=source_bucket, source_object=source_object, @@ -186,10 +134,15 @@ def test_copy_failedcopy(self, mock_service): destination_object=destination_object ) - self.assertFalse(response) + # Then + self.assertEqual(response, None) + copy_method.assert_called_once_with( + blob=source_blob, + destination_bucket=destination_bucket_instance, + new_name=destination_object + ) def test_copy_fail_same_source_and_destination(self): - source_bucket = 'test-source-bucket' source_object = 'test-source-object' destination_bucket = 'test-source-bucket' @@ -209,7 +162,6 @@ def test_copy_fail_same_source_and_destination(self): ) def test_copy_empty_source_bucket(self): - source_bucket = None source_object = 'test-source-object' destination_bucket = 'test-dest-bucket' @@ -227,7 +179,6 @@ def test_copy_empty_source_bucket(self): ) def test_copy_empty_source_object(self): - source_bucket = 'test-source-object' source_object = None destination_bucket = 'test-dest-bucket' @@ -244,109 +195,77 @@ def test_copy_empty_source_object(self): 'source_bucket and source_object cannot be empty.' ) + @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_rewrite(self, mock_service): + def test_rewrite(self, mock_service, mock_bucket): source_bucket = 'test-source-bucket' source_object = 'test-source-object' destination_bucket = 'test-dest-bucket' destination_object = 'test-dest-object' - # First response has `done` equals False has it has not completed copying - # It also has `rewriteToken` which would be passed to the second call - # to the api. - first_response = { - "kind": "storage#rewriteResponse", - "totalBytesRewritten": "9111", - "objectSize": "9111", - "done": False, - "rewriteToken": "testRewriteToken" - } - - second_response = { - "kind": "storage#rewriteResponse", - "totalBytesRewritten": "9111", - "objectSize": "9111", - "done": True, - "resource": { - "kind": "storage#object", - # The ID of the object, including the bucket name, - # object name, and generation number. - "id": "{}/{}/1521132662504504".format( - destination_bucket, destination_object), - "name": destination_object, - "bucket": destination_bucket, - "generation": "1521132662504504", - "contentType": "text/csv", - "timeCreated": "2018-03-15T16:51:02.502Z", - "updated": "2018-03-15T16:51:02.502Z", - "storageClass": "MULTI_REGIONAL", - "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z", - "size": "9111", - "md5Hash": "leYUJBUWrRtks1UeUFONJQ==", - "metadata": { - "md5-hash": "95e614241516ad1b64b3551e50538d25" - }, - "crc32c": "xgdNfQ==", - "etag": "CLf4hODk7tkCEAE=" - } - } - - (mock_service.return_value.objects.return_value - .rewrite.return_value.execute.side_effect) = [first_response, second_response] - - result = self.gcs_hook.rewrite( + source_blob = mock_bucket.blob(source_object) + + # Given + get_bucket_mock = mock_service.return_value.get_bucket + get_bucket_mock.return_value = mock_bucket + get_blob_method = get_bucket_mock.return_value.blob + rewrite_method = get_blob_method.return_value.rewrite + rewrite_method.side_effect = [(None, mock.ANY, mock.ANY), (mock.ANY, mock.ANY, mock.ANY)] + + # When + response = self.gcs_hook.rewrite( source_bucket=source_bucket, source_object=source_object, destination_bucket=destination_bucket, - destination_object=destination_object - ) + destination_object=destination_object) - self.assertTrue(result) - mock_service.return_value.objects.return_value.rewrite.assert_called_with( - sourceBucket=source_bucket, - sourceObject=source_object, - destinationBucket=destination_bucket, - destinationObject=destination_object, - rewriteToken=first_response['rewriteToken'], - body='' - ) + # Then + self.assertEqual(response, None) + rewrite_method.assert_called_once_with( + source=source_blob) + @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_delete(self, mock_service): + def test_delete(self, mock_service, mock_bucket): test_bucket = 'test_bucket' test_object = 'test_object' + blob_to_be_deleted = storage.Blob(name=test_object, bucket=mock_bucket) - (mock_service.return_value.objects.return_value - .delete.return_value.execute.return_value) = {} + get_bucket_method = mock_service.return_value.get_bucket + get_blob_method = get_bucket_method.return_value.get_blob + delete_method = get_blob_method.return_value.delete + delete_method.return_value = blob_to_be_deleted response = self.gcs_hook.delete(bucket=test_bucket, object=test_object) - - self.assertTrue(response) + self.assertIsNone(response) @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_delete_nonexisting_object(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - (mock_service.return_value.objects.return_value - .delete.return_value.execute.side_effect) = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) + get_bucket_method = mock_service.return_value.get_bucket + blob = get_bucket_method.return_value.blob + delete_method = blob.return_value.delete + delete_method.side_effect = exceptions.NotFound(message="Not Found") - response = self.gcs_hook.delete(bucket=test_bucket, object=test_object) - - self.assertFalse(response) + with self.assertRaises(exceptions.NotFound): + self.gcs_hook.delete(bucket=test_bucket, object=test_object) + @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_create_bucket(self, mock_service): + def test_create_bucket(self, mock_service, mock_bucket): test_bucket = 'test_bucket' test_project = 'test-project' - test_location = 'EU', + test_location = 'EU' test_labels = {'env': 'prod'} test_storage_class = 'MULTI_REGIONAL' - test_response_id = "{}/0123456789012345".format(test_bucket) - (mock_service.return_value.buckets.return_value - .insert.return_value.execute.return_value) = {"id": test_response_id} + mock_service.return_value.bucket.return_value.create.return_value = None + mock_bucket.return_value.storage_class = test_storage_class + mock_bucket.return_value.labels = test_labels + + sample_bucket = mock_service().bucket(bucket_name=test_bucket) response = self.gcs_hook.create_bucket( bucket_name=test_bucket, @@ -356,59 +275,63 @@ def test_create_bucket(self, mock_service): project_id=test_project ) - self.assertEqual(response, test_response_id) - mock_service.return_value.buckets.return_value.insert.assert_called_with( - project=test_project, - body={ - 'name': test_bucket, - 'location': test_location, - 'storageClass': test_storage_class, - 'labels': test_labels - } + self.assertEquals(response, sample_bucket.id) + + self.assertEquals(sample_bucket.storage_class, test_storage_class) + self.assertEquals(sample_bucket.labels, test_labels) + + mock_service.return_value.bucket.return_value.create.assert_called_with( + project=test_project, location=test_location ) + @mock.patch('google.cloud.storage.Bucket') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_create_bucket_with_resource(self, mock_service): + def test_create_bucket_with_resource(self, mock_service, mock_bucket): test_bucket = 'test_bucket' test_project = 'test-project' - test_location = 'EU', + test_location = 'EU' test_labels = {'env': 'prod'} test_storage_class = 'MULTI_REGIONAL' - test_response_id = "{}/0123456789012345".format(test_bucket) - test_lifecycle = {"rule": [{"action": {"type": "Delete"}, "condition": {"age": 7}}]} + test_versioning_enabled = {"enabled": True} + + mock_service.return_value.bucket.return_value.create.return_value = None + mock_bucket.return_value.storage_class = test_storage_class + mock_bucket.return_value.labels = test_labels + mock_bucket.return_value.versioning_enabled = True - (mock_service.return_value.buckets.return_value - .insert.return_value.execute.return_value) = {"id": test_response_id} + sample_bucket = mock_service().bucket(bucket_name=test_bucket) + # sample_bucket = storage.Bucket(client=mock_service, name=test_bucket) # Assert for resource other than None. response = self.gcs_hook.create_bucket( bucket_name=test_bucket, - resource={"lifecycle": {"rule": [{"action": {"type": "Delete"}, "condition": {"age": 7}}]}}, + resource={"versioning": test_versioning_enabled}, storage_class=test_storage_class, location=test_location, labels=test_labels, project_id=test_project ) + self.assertEquals(response, sample_bucket.id) - self.assertEqual(response, test_response_id) - mock_service.return_value.buckets.return_value.insert.assert_called_with( - project=test_project, - body={ - "lifecycle": test_lifecycle, - 'name': test_bucket, - 'location': test_location, - 'storageClass': test_storage_class, - 'labels': test_labels - } + mock_service.return_value.bucket.return_value._patch_property.assert_called_with( + name='versioning', value=test_versioning_enabled ) + mock_service.return_value.bucket.return_value.create.assert_called_with( + project=test_project, location=test_location + ) + + @mock.patch('google.cloud.storage.Bucket.blob') @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_compose(self, mock_service): + def test_compose(self, mock_service, mock_blob): test_bucket = 'test_bucket' test_source_objects = ['test_object_1', 'test_object_2', 'test_object_3'] test_destination_object = 'test_object_composed' - method = (mock_service.return_value.objects.return_value.compose) + mock_service.return_value.get_bucket.return_value\ + .blob.return_value = mock_blob(blob_name=mock.ANY) + method = mock_service.return_value.get_bucket.return_value.blob\ + .return_value.compose self.gcs_hook.compose( bucket=test_bucket, @@ -416,19 +339,10 @@ def test_compose(self, mock_service): destination_object=test_destination_object ) - body = { - 'sourceObjects': [ - {'name': 'test_object_1'}, - {'name': 'test_object_2'}, - {'name': 'test_object_3'} - ] - } - method.assert_called_once_with( - destinationBucket=test_bucket, - destinationObject=test_destination_object, - body=body - ) + sources=[ + mock_blob(blob_name=source_object) for source_object in test_source_objects + ]) @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_compose_with_empty_source_objects(self, mock_service): @@ -505,121 +419,32 @@ def test_upload(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - (mock_service.return_value.objects.return_value - .insert.return_value.execute.return_value) = { - "kind": "storage#object", - "id": "{}/{}/0123456789012345".format(test_bucket, test_object), - "name": test_object, - "bucket": test_bucket, - "generation": "0123456789012345", - "contentType": "application/octet-stream", - "timeCreated": "2018-03-15T16:51:02.502Z", - "updated": "2018-03-15T16:51:02.502Z", - "storageClass": "MULTI_REGIONAL", - "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z", - "size": "393216", - "md5Hash": "leYUJBUWrRtks1UeUFONJQ==", - "crc32c": "xgdNfQ==", - "etag": "CLf4hODk7tkCEAE=" - } + upload_method = mock_service.return_value.get_bucket.return_value\ + .blob.return_value.upload_from_filename + upload_method.return_value = None response = self.gcs_hook.upload(test_bucket, test_object, self.testfile.name) - self.assertTrue(response) + self.assertIsNone(response) + upload_method.assert_called_once_with( + filename=self.testfile.name, + content_type='application/octet-stream' + ) @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) def test_upload_gzip(self, mock_service): test_bucket = 'test_bucket' test_object = 'test_object' - (mock_service.return_value.objects.return_value - .insert.return_value.execute.return_value) = { - "kind": "storage#object", - "id": "{}/{}/0123456789012345".format(test_bucket, test_object), - "name": test_object, - "bucket": test_bucket, - "generation": "0123456789012345", - "contentType": "application/octet-stream", - "timeCreated": "2018-03-15T16:51:02.502Z", - "updated": "2018-03-15T16:51:02.502Z", - "storageClass": "MULTI_REGIONAL", - "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z", - "size": "393216", - "md5Hash": "leYUJBUWrRtks1UeUFONJQ==", - "crc32c": "xgdNfQ==", - "etag": "CLf4hODk7tkCEAE=" - } - - response = self.gcs_hook.upload(test_bucket, - test_object, - self.testfile.name, - gzip=True) - self.assertFalse(os.path.exists(self.testfile.name + '.gz')) - self.assertTrue(response) - - @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_upload_gzip_error(self, mock_service): - test_bucket = 'test_bucket' - test_object = 'test_object' - - (mock_service.return_value.objects.return_value - .insert.return_value.execute.side_effect) = HttpError( - resp={'status': '404'}, content=EMPTY_CONTENT) + upload_method = mock_service.return_value.get_bucket.return_value \ + .blob.return_value.upload_from_filename + upload_method.return_value = None response = self.gcs_hook.upload(test_bucket, test_object, self.testfile.name, gzip=True) self.assertFalse(os.path.exists(self.testfile.name + '.gz')) - self.assertFalse(response) - - @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_upload_multipart(self, mock_service): - test_bucket = 'test_bucket' - test_object = 'test_object' - - class MockProgress: - def __init__(self, value): - self.value = value - - def progress(self): - return self.value - - (mock_service.return_value.objects.return_value - .insert.return_value.next_chunk.side_effect) = [ - (MockProgress(0.66), None), - (MockProgress(1.0), { - "kind": "storage#object", - "id": "{}/{}/0123456789012345".format(test_bucket, test_object), - "name": test_object, - "bucket": test_bucket, - "generation": "0123456789012345", - "contentType": "application/octet-stream", - "timeCreated": "2018-03-15T16:51:02.502Z", - "updated": "2018-03-15T16:51:02.502Z", - "storageClass": "MULTI_REGIONAL", - "timeStorageClassUpdated": "2018-03-15T16:51:02.502Z", - "size": "393216", - "md5Hash": "leYUJBUWrRtks1UeUFONJQ==", - "crc32c": "xgdNfQ==", - "etag": "CLf4hODk7tkCEAE=" - }) - ] - - response = self.gcs_hook.upload(test_bucket, - test_object, - self.testfile.name, - multipart=True) - - self.assertTrue(response) - - @mock.patch(GCS_STRING.format('GoogleCloudStorageHook.get_conn')) - def test_upload_multipart_wrong_chunksize(self, mock_service): - test_bucket = 'test_bucket' - test_object = 'test_object' - - with self.assertRaises(ValueError): - self.gcs_hook.upload(test_bucket, test_object, - self.testfile.name, multipart=123) + self.assertIsNone(response)